auth-gate 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {auth_gate-0.2.0/src/auth_gate.egg-info → auth_gate-0.2.2}/PKG-INFO +1 -1
- {auth_gate-0.2.0 → auth_gate-0.2.2}/pyproject.toml +1 -1
- {auth_gate-0.2.0 → auth_gate-0.2.2}/setup.cfg +1 -1
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/__init__.py +6 -1
- auth_gate-0.2.2/src/auth_gate/s2s_auth.py +352 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2/src/auth_gate.egg-info}/PKG-INFO +1 -1
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/SOURCES.txt +2 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/conftest.py +19 -3
- auth_gate-0.2.2/src/tests/test_s2s_auth.py +295 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/LICENSE +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/README.md +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/config.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/fastapi_utils.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/middleware.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/schemas.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/user_auth.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/dependency_links.txt +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/requires.txt +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/top_level.txt +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/__init__.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_config.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_fastapi_utils.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_intergration.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_middleware.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_schema.py +0 -0
- {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_user_auth.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "auth-gate"
|
|
7
|
-
version = "0.2.
|
|
7
|
+
version = "0.2.2"
|
|
8
8
|
description = "Enterprise-grade authentication for microservices with Kong and Keycloak integration"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.11"
|
|
@@ -25,10 +25,11 @@ from .fastapi_utils import (
|
|
|
25
25
|
verify_hmac_signature,
|
|
26
26
|
)
|
|
27
27
|
from .middleware import AuthMiddleware
|
|
28
|
+
from .s2s_auth import CircuitBreaker, CircuitBreakerOpenError, ServiceAuthClient
|
|
28
29
|
from .schemas import ServiceContext, UserContext
|
|
29
30
|
from .user_auth import UserValidator
|
|
30
31
|
|
|
31
|
-
__version__ = "0.2.
|
|
32
|
+
__version__ = "0.2.2"
|
|
32
33
|
|
|
33
34
|
__all__ = [
|
|
34
35
|
# Configuration
|
|
@@ -39,6 +40,10 @@ __all__ = [
|
|
|
39
40
|
"ServiceContext",
|
|
40
41
|
# User Authentication
|
|
41
42
|
"UserValidator",
|
|
43
|
+
# Service-to-Service
|
|
44
|
+
"ServiceAuthClient",
|
|
45
|
+
"CircuitBreaker",
|
|
46
|
+
"CircuitBreakerOpenError",
|
|
42
47
|
# Middleware
|
|
43
48
|
"AuthMiddleware",
|
|
44
49
|
# FastAPI Dependencies
|
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Service-to-service authentication with resilience patterns
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional, Tuple, Type, Union

import httpx
from fastapi import HTTPException, status

from .config import get_settings
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CircuitState(Enum):
|
|
21
|
+
"""Circuit breaker states"""
|
|
22
|
+
|
|
23
|
+
CLOSED = "closed" # Normal operation
|
|
24
|
+
OPEN = "open" # Failing, reject calls
|
|
25
|
+
HALF_OPEN = "half_open" # Testing recovery
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class CircuitBreakerOpenError(Exception):
|
|
29
|
+
"""Exception raised when circuit breaker is open"""
|
|
30
|
+
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CircuitBreaker:
|
|
35
|
+
"""
|
|
36
|
+
Circuit breaker implementation for fault tolerance.
|
|
37
|
+
|
|
38
|
+
States:
|
|
39
|
+
- CLOSED: Normal operation, calls pass through
|
|
40
|
+
- OPEN: Service is failing, calls are rejected
|
|
41
|
+
- HALF_OPEN: Testing if service recovered
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
failure_threshold: int = 5,
|
|
47
|
+
recovery_timeout: int = 60,
|
|
48
|
+
expected_exception: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
|
|
49
|
+
name: Optional[str] = None,
|
|
50
|
+
):
|
|
51
|
+
"""
|
|
52
|
+
Initialize circuit breaker.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
failure_threshold: Number of failures before opening circuit
|
|
56
|
+
recovery_timeout: Seconds to wait before attempting recovery
|
|
57
|
+
expected_exception: Exception type to catch
|
|
58
|
+
name: Circuit breaker name for logging
|
|
59
|
+
"""
|
|
60
|
+
self.failure_threshold = failure_threshold
|
|
61
|
+
self.recovery_timeout = recovery_timeout
|
|
62
|
+
self.expected_exception = expected_exception
|
|
63
|
+
self.name = name or "CircuitBreaker"
|
|
64
|
+
|
|
65
|
+
self._state = CircuitState.CLOSED
|
|
66
|
+
self._failure_count = 0
|
|
67
|
+
self._last_failure_time: Optional[datetime] = None
|
|
68
|
+
self._lock = asyncio.Lock()
|
|
69
|
+
|
|
70
|
+
@property
|
|
71
|
+
def state(self) -> CircuitState:
|
|
72
|
+
"""Get current circuit state"""
|
|
73
|
+
return self._state
|
|
74
|
+
|
|
75
|
+
@property
|
|
76
|
+
def is_open(self) -> bool:
|
|
77
|
+
"""Check if circuit is open"""
|
|
78
|
+
return self._state == CircuitState.OPEN
|
|
79
|
+
|
|
80
|
+
@property
|
|
81
|
+
def is_closed(self) -> bool:
|
|
82
|
+
"""Check if circuit is closed"""
|
|
83
|
+
return self._state == CircuitState.CLOSED
|
|
84
|
+
|
|
85
|
+
async def __aenter__(self):
|
|
86
|
+
"""Context manager entry"""
|
|
87
|
+
await self._before_call()
|
|
88
|
+
return self
|
|
89
|
+
|
|
90
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
91
|
+
"""Context manager exit"""
|
|
92
|
+
if exc_type is None:
|
|
93
|
+
await self._on_success()
|
|
94
|
+
elif issubclass(exc_type, self.expected_exception):
|
|
95
|
+
await self._on_failure()
|
|
96
|
+
# Don't suppress the exception
|
|
97
|
+
return False
|
|
98
|
+
|
|
99
|
+
async def _before_call(self):
|
|
100
|
+
"""Check circuit state before allowing call"""
|
|
101
|
+
async with self._lock:
|
|
102
|
+
if self._state == CircuitState.OPEN:
|
|
103
|
+
# Check if recovery timeout has passed
|
|
104
|
+
if self._last_failure_time and datetime.now(
|
|
105
|
+
timezone.utc
|
|
106
|
+
) - self._last_failure_time > timedelta(seconds=self.recovery_timeout):
|
|
107
|
+
logger.info(f"{self.name}: Attempting recovery (HALF_OPEN)")
|
|
108
|
+
self._state = CircuitState.HALF_OPEN
|
|
109
|
+
else:
|
|
110
|
+
raise CircuitBreakerOpenError(f"{self.name}: Circuit is OPEN, rejecting call")
|
|
111
|
+
|
|
112
|
+
async def _on_success(self):
|
|
113
|
+
"""Handle successful call"""
|
|
114
|
+
async with self._lock:
|
|
115
|
+
if self._state == CircuitState.HALF_OPEN:
|
|
116
|
+
logger.info(f"{self.name}: Recovery successful, closing circuit")
|
|
117
|
+
self._state = CircuitState.CLOSED
|
|
118
|
+
self._failure_count = 0
|
|
119
|
+
self._last_failure_time = None
|
|
120
|
+
elif self._state == CircuitState.CLOSED:
|
|
121
|
+
# Reset failure count on success
|
|
122
|
+
self._failure_count = 0
|
|
123
|
+
|
|
124
|
+
async def _on_failure(self):
|
|
125
|
+
"""Handle failed call"""
|
|
126
|
+
async with self._lock:
|
|
127
|
+
self._failure_count += 1
|
|
128
|
+
self._last_failure_time = datetime.now(timezone.utc)
|
|
129
|
+
|
|
130
|
+
if self._state == CircuitState.HALF_OPEN:
|
|
131
|
+
logger.warning(f"{self.name}: Recovery failed, reopening circuit")
|
|
132
|
+
self._state = CircuitState.OPEN
|
|
133
|
+
elif self._state == CircuitState.CLOSED:
|
|
134
|
+
if self._failure_count >= self.failure_threshold:
|
|
135
|
+
logger.error(
|
|
136
|
+
f"{self.name}: Failure threshold reached ({self._failure_count}), "
|
|
137
|
+
f"opening circuit"
|
|
138
|
+
)
|
|
139
|
+
self._state = CircuitState.OPEN
|
|
140
|
+
else:
|
|
141
|
+
logger.warning(
|
|
142
|
+
f"{self.name}: Failure {self._failure_count}/{self.failure_threshold}"
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
def reset(self):
|
|
146
|
+
"""Manually reset circuit breaker"""
|
|
147
|
+
self._state = CircuitState.CLOSED
|
|
148
|
+
self._failure_count = 0
|
|
149
|
+
self._last_failure_time = None
|
|
150
|
+
logger.info(f"{self.name}: Circuit manually reset")
|
|
151
|
+
|
|
152
|
+
async def call(self, func, *args, **kwargs):
|
|
153
|
+
"""
|
|
154
|
+
Execute a function with circuit breaker protection.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
func: Async function to execute
|
|
158
|
+
*args: Positional arguments for func
|
|
159
|
+
**kwargs: Keyword arguments for func
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
Result of func
|
|
163
|
+
|
|
164
|
+
Raises:
|
|
165
|
+
CircuitBreakerOpenError: If circuit is open
|
|
166
|
+
Exception: Any exception raised by func
|
|
167
|
+
"""
|
|
168
|
+
async with self:
|
|
169
|
+
return await func(*args, **kwargs)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@dataclass
class ServiceToken:
    """
    Access token obtained for this service via the client-credentials grant.

    ``created_at`` records when this object was built and is the reference
    point for expiry checks.
    """

    access_token: str  # raw token string
    token_type: str  # Authorization scheme, e.g. "Bearer"
    expires_in: int  # token lifetime in seconds
    # default_factory evaluates the timestamp per instance. A plain
    # ``= datetime.now(...)`` default is evaluated once at import time, which made
    # every token look as old as the process and forced constant re-fetching.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def is_expired(self) -> bool:
        """
        Whether the token should be treated as expired.

        The token is considered expired ``TOKEN_REFRESH_BUFFER`` seconds (read from
        settings) before ``created_at + expires_in``, so it is refreshed early.
        """
        settings = get_settings()
        expiry_time = self.created_at + timedelta(
            seconds=self.expires_in - settings.TOKEN_REFRESH_BUFFER
        )
        return datetime.now(timezone.utc) > expiry_time

    @property
    def authorization_header(self) -> str:
        """Value for the HTTP Authorization header: ``"<token_type> <access_token>"``."""
        return f"{self.token_type} {self.access_token}"
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class ServiceAuthClient:
    """
    Obtains and caches service-account tokens from Keycloak.

    Tokens are fetched with the OAuth2 client-credentials grant and reused until
    ServiceToken.is_expired reports them stale. Fetches run through a
    CircuitBreaker so repeated Keycloak failures fail fast.
    """

    def __init__(self):
        """Read Keycloak/client settings and set up the token cache and circuit breaker."""
        settings = get_settings()
        self.token_url = f"{settings.KEYCLOAK_REALM_URL}/protocol/openid-connect/token"
        self.client_id = settings.SERVICE_CLIENT_ID
        self.client_secret = settings.SERVICE_CLIENT_SECRET
        self._token: Optional[ServiceToken] = None
        # Serializes cache checks and refreshes so concurrent callers share one fetch.
        self._lock = asyncio.Lock()
        self._http_client: Optional[httpx.AsyncClient] = None

        # Transport errors and HTTPExceptions (raised for non-200 responses) count
        # as failures; other exceptions do not affect the circuit.
        self._circuit_breaker: CircuitBreaker = CircuitBreaker(
            failure_threshold=settings.CIRCUIT_BREAKER_FAILURE_THRESHOLD,
            recovery_timeout=settings.CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
            expected_exception=(httpx.RequestError, HTTPException),
            name="ServiceAuthCircuit",
        )

    @property
    async def http_client(self) -> httpx.AsyncClient:
        """Lazily create the shared HTTP client. Must be awaited: ``await self.http_client``."""
        if self._http_client is None:
            settings = get_settings()
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(settings.HTTP_TIMEOUT),
                limits=httpx.Limits(max_keepalive_connections=5),
            )
        return self._http_client

    async def _fetch_token(self) -> ServiceToken:
        """
        Request a new token from the Keycloak token endpoint.

        Returns:
            ServiceToken built from the endpoint's JSON response ("token_type"
            defaults to "Bearer", "expires_in" to 300 when absent).

        Raises:
            HTTPException: 503 if the endpoint answers with a non-200 status.
            httpx.RequestError: If the request itself fails (propagated from httpx).
            KeyError: If the response JSON has no "access_token".
        """
        client = await self.http_client
        response = await client.post(
            self.token_url,
            data={
                "grant_type": "client_credentials",
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "scope": "openid profile",
            },
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if response.status_code != 200:
            logger.error(f"Failed to get service token: {response.status_code} - {response.text}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Unable to authenticate service",
            )

        token_data = response.json()
        return ServiceToken(
            access_token=token_data["access_token"],
            token_type=token_data.get("token_type", "Bearer"),
            expires_in=token_data.get("expires_in", 300),
        )

    async def get_service_token(self) -> str:
        """
        Get an Authorization header value for service-to-service calls.

        Returns the cached token while it is still valid, otherwise fetches a new
        one through the circuit breaker.

        Returns:
            Authorization header value (e.g., "Bearer token")

        Raises:
            HTTPException: 503 if the circuit is open, Keycloak is unreachable or
                answers with a non-200 status; 500 for any other unexpected error.
        """
        async with self._lock:
            # Return cached token if still valid
            if self._token is not None and not self._token.is_expired:
                return self._token.authorization_header

            try:
                token = await self._circuit_breaker.call(self._fetch_token)
            except CircuitBreakerOpenError as e:
                logger.error("Circuit breaker is open - Keycloak is unavailable")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Authentication service temporarily unavailable",
                ) from e
            except HTTPException:
                # Already mapped to an HTTP error by _fetch_token (non-200 -> 503).
                # Without this clause the catch-all below turned it into a 500.
                raise
            except httpx.RequestError as e:
                logger.error(f"Failed to connect to Keycloak: {e}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Authentication service unavailable",
                ) from e
            except Exception as e:
                # logger.exception keeps the traceback for unexpected failures
                # (e.g. a malformed token response).
                logger.exception(f"Unexpected error getting service token: {e}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to authenticate service",
                ) from e

            if token is None:
                logger.error("Failed to obtain service token: token is None")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to authenticate service",
                )

            self._token = token
            logger.info(
                f"Obtained new service token for {self.client_id}, "
                f"expires in {token.expires_in} seconds"
            )
            return token.authorization_header

    @property
    def circuit_state(self) -> CircuitState:
        """Current state of the token-fetch circuit breaker."""
        return self._circuit_breaker.state

    def reset_circuit(self):
        """Manually reset the token-fetch circuit breaker to CLOSED."""
        self._circuit_breaker.reset()

    async def close(self):
        """Close the HTTP client, if one was created, and drop the reference."""
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# Module-wide shared ServiceAuthClient, created on demand by get_service_auth_client()
# and released by cleanup_service_auth().
_service_auth_client: Optional[ServiceAuthClient] = None


def get_service_auth_client() -> ServiceAuthClient:
    """Return the shared ServiceAuthClient, constructing it on first use."""
    global _service_auth_client
    if _service_auth_client is not None:
        return _service_auth_client
    _service_auth_client = ServiceAuthClient()
    return _service_auth_client
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
async def cleanup_service_auth():
    """Close the shared ServiceAuthClient, if one exists, and forget it."""
    global _service_auth_client
    client = _service_auth_client
    if client is None:
        return
    await client.close()
    _service_auth_client = None
|
|
@@ -6,6 +6,7 @@ src/auth_gate/__init__.py
|
|
|
6
6
|
src/auth_gate/config.py
|
|
7
7
|
src/auth_gate/fastapi_utils.py
|
|
8
8
|
src/auth_gate/middleware.py
|
|
9
|
+
src/auth_gate/s2s_auth.py
|
|
9
10
|
src/auth_gate/schemas.py
|
|
10
11
|
src/auth_gate/user_auth.py
|
|
11
12
|
src/auth_gate.egg-info/PKG-INFO
|
|
@@ -19,5 +20,6 @@ src/tests/test_config.py
|
|
|
19
20
|
src/tests/test_fastapi_utils.py
|
|
20
21
|
src/tests/test_intergration.py
|
|
21
22
|
src/tests/test_middleware.py
|
|
23
|
+
src/tests/test_s2s_auth.py
|
|
22
24
|
src/tests/test_schema.py
|
|
23
25
|
src/tests/test_user_auth.py
|
|
@@ -12,6 +12,7 @@ from httpx import AsyncClient
|
|
|
12
12
|
|
|
13
13
|
from auth_gate import AuthSettings, ServiceContext, UserContext
|
|
14
14
|
from auth_gate.config import AuthMode, reset_settings
|
|
15
|
+
from auth_gate.s2s_auth import ServiceAuthClient
|
|
15
16
|
from auth_gate.user_auth import UserValidator, cleanup_user_validator
|
|
16
17
|
|
|
17
18
|
|
|
@@ -60,9 +61,10 @@ def mock_settings():
|
|
|
60
61
|
)
|
|
61
62
|
with patch("auth_gate.config.get_settings", return_value=settings):
|
|
62
63
|
with patch("auth_gate.user_auth.get_settings", return_value=settings):
|
|
63
|
-
with patch("auth_gate.
|
|
64
|
-
with patch("auth_gate.
|
|
65
|
-
|
|
64
|
+
with patch("auth_gate.s2s_auth.get_settings", return_value=settings):
|
|
65
|
+
with patch("auth_gate.middleware.get_settings", return_value=settings):
|
|
66
|
+
with patch("auth_gate.fastapi_utils.get_settings", return_value=settings):
|
|
67
|
+
yield settings
|
|
66
68
|
|
|
67
69
|
|
|
68
70
|
@pytest_asyncio.fixture(autouse=True)
|
|
@@ -231,3 +233,17 @@ async def bypass_validator():
|
|
|
231
233
|
|
|
232
234
|
# Clean up after each test
|
|
233
235
|
await validator.close()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@pytest_asyncio.fixture
async def service_auth_client():
    """Fixture to create and clean up a ServiceAuthClient instance"""
    # Reset settings to ensure clean state
    reset_settings()

    # Create the service-to-service auth client under test
    auth_client = ServiceAuthClient()
    yield auth_client

    # Clean up after each test (closes the lazily created HTTP client, if any)
    await auth_client.close()
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for service-to-service authentication
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from datetime import datetime, timedelta, timezone
|
|
7
|
+
from unittest.mock import AsyncMock, patch
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
import pytest
|
|
11
|
+
from fastapi import HTTPException
|
|
12
|
+
|
|
13
|
+
from auth_gate.s2s_auth import (
|
|
14
|
+
CircuitBreaker,
|
|
15
|
+
CircuitBreakerOpenError,
|
|
16
|
+
CircuitState,
|
|
17
|
+
ServiceToken,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TestCircuitBreaker:
    """Test CircuitBreaker implementation"""

    @staticmethod
    async def _record_failures(breaker, count):
        """Push ``count`` ValueError failures through ``breaker``, asserting each propagates."""
        for _ in range(count):
            with pytest.raises(ValueError):
                async with breaker:
                    raise ValueError("Test failure")

    @pytest.mark.asyncio
    async def test_circuit_breaker_closed_state(self):
        """Test circuit breaker in closed state allows calls"""
        breaker = CircuitBreaker(failure_threshold=3)

        assert breaker.state == CircuitState.CLOSED
        assert breaker.is_closed is True

        # Successful calls should pass through
        async with breaker:
            pass  # Successful operation

        assert breaker.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_circuit_breaker_opens_after_threshold(self):
        """Test circuit breaker opens exactly when the failure threshold is reached"""
        breaker = CircuitBreaker(failure_threshold=3)

        await self._record_failures(breaker, 2)
        assert breaker.state == CircuitState.CLOSED

        await self._record_failures(breaker, 1)
        assert breaker.state == CircuitState.OPEN
        assert breaker.is_open is True

    @pytest.mark.asyncio
    async def test_circuit_breaker_rejects_calls_when_open(self):
        """Test circuit breaker rejects calls when open"""
        breaker = CircuitBreaker(failure_threshold=2)
        await self._record_failures(breaker, 2)

        # Should reject calls when open
        with pytest.raises(CircuitBreakerOpenError):
            async with breaker:
                pass

    @pytest.mark.asyncio
    async def test_circuit_breaker_half_open_recovery(self):
        """Test circuit breaker recovery through half-open state"""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0)  # Instant recovery
        await self._record_failures(breaker, 2)

        assert breaker.state == CircuitState.OPEN

        # Wait a tiny bit so the (zero) recovery timeout has strictly elapsed
        await asyncio.sleep(0.01)

        # Should transition to half-open and allow the call
        async with breaker:
            pass  # Successful call

        assert breaker.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_circuit_breaker_half_open_failure(self):
        """Test circuit breaker reopens on half-open failure"""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0)
        await self._record_failures(breaker, 2)

        # Wait for recovery
        await asyncio.sleep(0.01)

        # Fail during half-open
        with pytest.raises(ValueError):
            async with breaker:
                raise ValueError("Recovery failure")

        assert breaker.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_circuit_breaker_call_wrapper(self):
        """Test circuit breaker call wrapper method"""
        breaker = CircuitBreaker(failure_threshold=3)

        async def successful_func(x, y):
            return x + y

        result = await breaker.call(successful_func, 2, 3)
        assert result == 5

        async def failing_func():
            raise ValueError("Test error")

        with pytest.raises(ValueError):
            await breaker.call(failing_func)

    @pytest.mark.asyncio
    async def test_circuit_breaker_reset(self):
        """Test manual circuit breaker reset"""
        breaker = CircuitBreaker(failure_threshold=1)

        async def failing_func():
            raise ValueError("Test error")

        # Open the circuit by causing a failure
        with pytest.raises(ValueError):
            await breaker.call(failing_func)

        # At this point, the circuit should be open
        assert breaker.state == CircuitState.OPEN

        # Reset manually
        breaker.reset()

        assert breaker.state == CircuitState.CLOSED
        assert breaker._failure_count == 0
        assert breaker._last_failure_time is None
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class TestServiceToken:
    """Test ServiceToken model.

    ``is_expired`` reads TOKEN_REFRESH_BUFFER via get_settings(), so these tests
    request ``mock_settings`` to avoid depending on the ambient configuration.
    """

    def test_service_token_not_expired(self, mock_settings):
        """A fresh token whose lifetime exceeds the refresh buffer is not expired"""
        token = ServiceToken(
            access_token="test-token",
            token_type="Bearer",
            # Derive the lifetime from the configured buffer so the assertion holds
            # whatever TOKEN_REFRESH_BUFFER is set to.
            expires_in=mock_settings.TOKEN_REFRESH_BUFFER + 300,
            created_at=datetime.now(timezone.utc),
        )

        assert token.is_expired is False
        assert token.authorization_header == "Bearer test-token"

    def test_service_token_expired(self, mock_settings):
        """A token created longer ago than its lifetime is expired"""
        token = ServiceToken(
            access_token="test-token",
            token_type="Bearer",
            expires_in=300,
            created_at=datetime.now(timezone.utc) - timedelta(seconds=400),
        )

        assert token.is_expired is True
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class TestServiceAuthClient:
    """Tests for ServiceAuthClient token acquisition, caching and resilience"""

    EXPECTED_HEADER = "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.test"

    @staticmethod
    def _mock_post(**mock_kwargs):
        """Patch httpx.AsyncClient.post with an AsyncMock configured by ``mock_kwargs``."""
        return patch.object(httpx.AsyncClient, "post", new=AsyncMock(**mock_kwargs))

    @staticmethod
    def _ok_response(payload):
        """A 200 token-endpoint response carrying ``payload`` as JSON."""
        return httpx.Response(status_code=200, json=payload)

    @pytest.mark.asyncio
    async def test_get_service_token_success(
        self, mock_settings, service_auth_client, service_token_response
    ):
        """Test successful service token acquisition"""
        ok = self._ok_response(service_token_response)
        with self._mock_post(return_value=ok) as post:
            header = await service_auth_client.get_service_token()

        assert header == self.EXPECTED_HEADER
        assert service_auth_client._token is not None
        post.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_service_token_cached(
        self, mock_settings, service_auth_client, service_token_response
    ):
        """Test service token caching"""
        ok = self._ok_response(service_token_response)
        with self._mock_post(return_value=ok) as post:
            first = await service_auth_client.get_service_token()
            # A second call must be served from the cache
            second = await service_auth_client.get_service_token()

        assert first == second
        # Only the first call should reach the token endpoint
        post.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_service_token_refresh_on_expiry(
        self, mock_settings, service_auth_client, service_token_response
    ):
        """Test service token refresh on expiry"""
        # Seed the cache with a token that is already past its lifetime
        stale_created = datetime.now(timezone.utc) - timedelta(seconds=400)
        service_auth_client._token = ServiceToken(
            access_token="old-token",
            token_type="Bearer",
            expires_in=300,
            created_at=stale_created,
        )

        ok = self._ok_response(service_token_response)
        with self._mock_post(return_value=ok) as post:
            header = await service_auth_client.get_service_token()

        assert header == self.EXPECTED_HEADER
        post.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_service_token_circuit_breaker_open(self, mock_settings, service_auth_client):
        """Test service token with open circuit breaker"""
        breaker = service_auth_client._circuit_breaker
        # Force the circuit open with a fresh failure timestamp
        breaker._state = CircuitState.OPEN
        breaker._last_failure_time = datetime.now(timezone.utc)

        with pytest.raises(HTTPException) as raised:
            await service_auth_client.get_service_token()

        assert raised.value.status_code == 503
        assert "temporarily unavailable" in raised.value.detail

    @pytest.mark.asyncio
    async def test_get_service_token_connection_error(self, mock_settings, service_auth_client):
        """Test service token with connection error"""
        with self._mock_post(side_effect=httpx.RequestError("Connection failed")):
            with pytest.raises(HTTPException) as raised:
                await service_auth_client.get_service_token()

        assert raised.value.status_code == 503

    @pytest.mark.asyncio
    async def test_get_service_token_circuit_breaker_integration(
        self, mock_settings, service_auth_client
    ):
        """Test circuit breaker integration with service token"""
        threshold = mock_settings.CIRCUIT_BREAKER_FAILURE_THRESHOLD
        with self._mock_post(side_effect=httpx.RequestError("Connection failed")):
            # Failures up to the threshold reach the endpoint and fail normally
            for _ in range(threshold):
                with pytest.raises(HTTPException):
                    await service_auth_client.get_service_token()

            # The circuit is now open
            assert service_auth_client.circuit_state == CircuitState.OPEN

            # The next call is refused without contacting Keycloak
            with pytest.raises(HTTPException) as raised:
                await service_auth_client.get_service_token()

            assert "temporarily unavailable" in raised.value.detail

    def test_circuit_state_property(self, service_auth_client):
        """Test circuit state property"""
        assert service_auth_client.circuit_state == CircuitState.CLOSED

    def test_reset_circuit(self, service_auth_client):
        """Test manual circuit reset"""
        service_auth_client._circuit_breaker._state = CircuitState.OPEN
        service_auth_client.reset_circuit()
        assert service_auth_client.circuit_state == CircuitState.CLOSED
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|