auth-gate 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {auth_gate-0.2.0/src/auth_gate.egg-info → auth_gate-0.2.2}/PKG-INFO +1 -1
  2. {auth_gate-0.2.0 → auth_gate-0.2.2}/pyproject.toml +1 -1
  3. {auth_gate-0.2.0 → auth_gate-0.2.2}/setup.cfg +1 -1
  4. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/__init__.py +6 -1
  5. auth_gate-0.2.2/src/auth_gate/s2s_auth.py +352 -0
  6. {auth_gate-0.2.0 → auth_gate-0.2.2/src/auth_gate.egg-info}/PKG-INFO +1 -1
  7. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/SOURCES.txt +2 -0
  8. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/conftest.py +19 -3
  9. auth_gate-0.2.2/src/tests/test_s2s_auth.py +295 -0
  10. {auth_gate-0.2.0 → auth_gate-0.2.2}/LICENSE +0 -0
  11. {auth_gate-0.2.0 → auth_gate-0.2.2}/README.md +0 -0
  12. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/config.py +0 -0
  13. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/fastapi_utils.py +0 -0
  14. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/middleware.py +0 -0
  15. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/schemas.py +0 -0
  16. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate/user_auth.py +0 -0
  17. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/dependency_links.txt +0 -0
  18. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/requires.txt +0 -0
  19. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/auth_gate.egg-info/top_level.txt +0 -0
  20. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/__init__.py +0 -0
  21. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_config.py +0 -0
  22. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_fastapi_utils.py +0 -0
  23. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_intergration.py +0 -0
  24. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_middleware.py +0 -0
  25. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_schema.py +0 -0
  26. {auth_gate-0.2.0 → auth_gate-0.2.2}/src/tests/test_user_auth.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: auth-gate
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Enterprise-grade authentication for microservices with Kong and Keycloak integration
5
5
  Home-page: https://github.com/tradelink-org/auth-gate
6
6
  Author: Brian Mburu
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "auth-gate"
7
- version = "0.2.0"
7
+ version = "0.2.2"
8
8
  description = "Enterprise-grade authentication for microservices with Kong and Keycloak integration"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.11"
@@ -1,6 +1,6 @@
1
1
  [metadata]
2
2
  name = auth-gate
3
- version = 0.2.0
3
+ version = 0.2.2
4
4
  author = Brian Mburu
5
5
  author_email = brian.mburu@students.jkuat.ac.ke
6
6
  description = Enterprise-grade authentication for microservices with Kong and Keycloak integration
@@ -25,10 +25,11 @@ from .fastapi_utils import (
25
25
  verify_hmac_signature,
26
26
  )
27
27
  from .middleware import AuthMiddleware
28
+ from .s2s_auth import CircuitBreaker, CircuitBreakerOpenError, ServiceAuthClient
28
29
  from .schemas import ServiceContext, UserContext
29
30
  from .user_auth import UserValidator
30
31
 
31
- __version__ = "0.2.0"
32
+ __version__ = "0.2.2"
32
33
 
33
34
  __all__ = [
34
35
  # Configuration
@@ -39,6 +40,10 @@ __all__ = [
39
40
  "ServiceContext",
40
41
  # User Authentication
41
42
  "UserValidator",
43
+ # Service-to-Service
44
+ "ServiceAuthClient",
45
+ "CircuitBreaker",
46
+ "CircuitBreakerOpenError",
42
47
  # Middleware
43
48
  "AuthMiddleware",
44
49
  # FastAPI Dependencies
@@ -0,0 +1,352 @@
1
+ """
2
+ Service-to-service authentication with resilience patterns
3
+ """
4
+
5
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional, Tuple, Type, Union

import httpx
from fastapi import HTTPException, status

from .config import get_settings
16
+
17
# Module-level logger used for circuit-breaker state transitions and
# service-token fetch outcomes.
logger = logging.getLogger(__name__)
18
+
19
+
20
class CircuitState(Enum):
    """Lifecycle states a CircuitBreaker moves through."""

    # Healthy: calls flow through and counted failures accumulate.
    CLOSED = "closed"
    # Tripped: calls are rejected without being attempted.
    OPEN = "open"
    # Probing: calls are let through to check whether the dependency recovered.
    HALF_OPEN = "half_open"
26
+
27
+
28
class CircuitBreakerOpenError(Exception):
    """Raised in place of running a guarded call while the circuit is OPEN."""
32
+
33
+
34
class CircuitBreaker:
    """
    Async circuit breaker that guards calls to an unreliable dependency.

    State machine (see ``CircuitState``):

    - CLOSED: calls run normally; each counted failure bumps a counter and
      reaching ``failure_threshold`` trips the circuit to OPEN. A success
      resets the counter.
    - OPEN: calls are refused with ``CircuitBreakerOpenError`` until more than
      ``recovery_timeout`` seconds have elapsed since the last failure; the
      next call then moves the circuit to HALF_OPEN.
    - HALF_OPEN: calls are let through; a success closes the circuit, a
      counted failure re-opens it.

    Use either ``async with breaker: ...`` or ``await breaker.call(fn, ...)``.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exception: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception,
        name: Optional[str] = None,
    ):
        """
        Configure the breaker; it starts CLOSED with no recorded failures.

        Args:
            failure_threshold: Counted failures (while CLOSED) that trip the circuit.
            recovery_timeout: Seconds after the last failure before an OPEN
                circuit admits a probe call.
            expected_exception: Exception type(s) counted as failures. Other
                exceptions propagate without touching the breaker's state.
            name: Prefix for log and error messages; "CircuitBreaker" if omitted.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.name = name if name else "CircuitBreaker"

        # Mutable breaker state; transitions happen while holding ``_lock``.
        self._state: CircuitState = CircuitState.CLOSED
        self._failure_count: int = 0
        self._last_failure_time: Optional[datetime] = None
        self._lock = asyncio.Lock()

    @property
    def state(self) -> CircuitState:
        """The breaker's current state."""
        return self._state

    @property
    def is_open(self) -> bool:
        """True while the breaker is rejecting calls."""
        return self.state == CircuitState.OPEN

    @property
    def is_closed(self) -> bool:
        """True while the breaker is in normal operation."""
        return self.state == CircuitState.CLOSED

    async def __aenter__(self):
        """Admit the guarded block, or raise CircuitBreakerOpenError if the circuit refuses it."""
        await self._before_call()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Record the guarded block's outcome; exceptions are never suppressed."""
        if exc_type is None:
            await self._on_success()
        elif issubclass(exc_type, self.expected_exception):
            await self._on_failure()
        return False

    def _recovery_due(self) -> bool:
        """Whether more than ``recovery_timeout`` seconds have passed since the last failure."""
        last_failure = self._last_failure_time
        if last_failure is None:
            return False
        elapsed = datetime.now(timezone.utc) - last_failure
        return elapsed > timedelta(seconds=self.recovery_timeout)

    async def _before_call(self):
        """Gate a call: pass through unless OPEN; when OPEN, probe if the timeout elapsed."""
        async with self._lock:
            if self._state != CircuitState.OPEN:
                return
            if not self._recovery_due():
                raise CircuitBreakerOpenError(f"{self.name}: Circuit is OPEN, rejecting call")
            logger.info(f"{self.name}: Attempting recovery (HALF_OPEN)")
            self._state = CircuitState.HALF_OPEN

    async def _on_success(self):
        """Clear the failure count; a successful HALF_OPEN probe also closes the circuit."""
        async with self._lock:
            previous = self._state
            if previous == CircuitState.OPEN:
                # A call admitted before the circuit opened finished fine;
                # the OPEN state is left untouched.
                return
            self._failure_count = 0
            if previous == CircuitState.HALF_OPEN:
                logger.info(f"{self.name}: Recovery successful, closing circuit")
                self._state = CircuitState.CLOSED
                self._last_failure_time = None

    async def _on_failure(self):
        """Count a failure and trip (or re-trip) the circuit when appropriate."""
        async with self._lock:
            self._failure_count += 1
            self._last_failure_time = datetime.now(timezone.utc)

            current = self._state
            if current == CircuitState.HALF_OPEN:
                logger.warning(f"{self.name}: Recovery failed, reopening circuit")
                self._state = CircuitState.OPEN
                return
            if current != CircuitState.CLOSED:
                return
            if self._failure_count < self.failure_threshold:
                logger.warning(f"{self.name}: Failure {self._failure_count}/{self.failure_threshold}")
                return
            logger.error(
                f"{self.name}: Failure threshold reached ({self._failure_count}), opening circuit"
            )
            self._state = CircuitState.OPEN

    def reset(self):
        """Force the breaker back to CLOSED and forget all recorded failures."""
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time = None
        logger.info(f"{self.name}: Circuit manually reset")

    async def call(self, func, *args, **kwargs):
        """
        Await ``func(*args, **kwargs)`` under circuit-breaker protection.

        Args:
            func: Async callable to invoke.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: If the circuit rejects the call.
            Exception: Anything raised by ``func`` is re-raised unchanged.
        """
        async with self:
            outcome = await func(*args, **kwargs)
        return outcome
170
+
171
+
172
@dataclass
class ServiceToken:
    """
    Service-account access token returned by the Keycloak token endpoint.

    Fields:
        access_token: The raw token string.
        token_type: Scheme used in the Authorization header (e.g. "Bearer").
        expires_in: Token lifetime in seconds, as reported by the endpoint.
        created_at: UTC time the token was obtained; defaults to the moment
            this instance is constructed.
    """

    access_token: str
    token_type: str
    expires_in: int
    # default_factory stamps each instance with its own creation time. A
    # plain ``= datetime.now(...)`` default is evaluated once at import time,
    # which made every later-fetched token look as old as the process and
    # therefore permanently "expired".
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def is_expired(self) -> bool:
        """
        Whether the token should be treated as expired.

        The token counts as expired ``TOKEN_REFRESH_BUFFER`` seconds (from
        settings) before its real expiry, so it is refreshed early. If the
        buffer is >= ``expires_in``, the token is considered expired
        immediately.
        """
        settings = get_settings()
        expiry_time = self.created_at + timedelta(
            seconds=self.expires_in - settings.TOKEN_REFRESH_BUFFER
        )
        return datetime.now(timezone.utc) > expiry_time

    @property
    def authorization_header(self) -> str:
        """Value for the HTTP Authorization header, e.g. "Bearer <token>"."""
        return f"{self.token_type} {self.access_token}"
194
+
195
+
196
class ServiceAuthClient:
    """
    Obtain and cache a Keycloak service-account token (client-credentials grant).

    Token fetches go through a ``CircuitBreaker``. After repeated Keycloak
    failures, callers fail fast with HTTP 503 instead of issuing more requests.
    """

    def __init__(self):
        """
        Read the token endpoint and credentials from settings and set up the
        circuit breaker. The HTTP client is created lazily on first use.
        """
        settings = get_settings()
        self.token_url = f"{settings.KEYCLOAK_REALM_URL}/protocol/openid-connect/token"
        self.client_id = settings.SERVICE_CLIENT_ID
        self.client_secret = settings.SERVICE_CLIENT_SECRET
        self._token: Optional[ServiceToken] = None
        # Serializes the cache check and refresh so concurrent callers don't
        # each fetch a token.
        self._lock = asyncio.Lock()
        self._http_client: Optional[httpx.AsyncClient] = None

        # Transport errors (httpx.RequestError) and non-200 responses (which
        # _fetch_token turns into HTTPException) count as breaker failures.
        self._circuit_breaker: CircuitBreaker = CircuitBreaker(
            failure_threshold=settings.CIRCUIT_BREAKER_FAILURE_THRESHOLD,
            recovery_timeout=settings.CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
            expected_exception=(httpx.RequestError, HTTPException),
            name="ServiceAuthCircuit",
        )

    @property
    async def http_client(self) -> httpx.AsyncClient:
        """
        Return the shared ``httpx.AsyncClient``, creating it on first access.

        This property returns a coroutine, so use ``await self.http_client``.
        """
        if self._http_client is None:
            settings = get_settings()
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(settings.HTTP_TIMEOUT),
                limits=httpx.Limits(max_keepalive_connections=5),
            )
        return self._http_client

    async def _fetch_token(self) -> ServiceToken:
        """
        Request a new token from Keycloak using the client-credentials grant.

        Returns:
            ServiceToken. ``token_type`` defaults to "Bearer" and
            ``expires_in`` to 300 when the response omits them.

        Raises:
            HTTPException: 503 if the endpoint answers with a non-200 status.
            httpx.RequestError: On transport failures (propagated unchanged).
            KeyError: If the response JSON has no ``access_token``.
        """
        client = await self.http_client
        response = await client.post(
            self.token_url,
            data={
                "grant_type": "client_credentials",
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "scope": "openid profile",
            },
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if response.status_code != 200:
            logger.error(f"Failed to get service token: {response.status_code} - {response.text}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Unable to authenticate service",
            )

        token_data = response.json()
        return ServiceToken(
            access_token=token_data["access_token"],
            token_type=token_data.get("token_type", "Bearer"),
            expires_in=token_data.get("expires_in", 300),
        )

    async def get_service_token(self) -> str:
        """
        Get an Authorization header value for service-to-service calls.

        Returns the cached token while it is not expired. Otherwise it fetches
        a new one through the circuit breaker.

        Returns:
            Authorization header value (e.g., "Bearer <token>").

        Raises:
            HTTPException:
                - 503 if the circuit is open, Keycloak is unreachable, or
                  Keycloak answers with a non-200 status.
                - 500 for any other unexpected error.
        """
        async with self._lock:
            # Return cached token if still valid
            if self._token and not self._token.is_expired:
                return self._token.authorization_header

            try:
                self._token = await self._circuit_breaker.call(self._fetch_token)
            except CircuitBreakerOpenError as exc:
                logger.error("Circuit breaker is open - Keycloak is unavailable")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Authentication service temporarily unavailable",
                ) from exc
            except HTTPException:
                # Already carries the intended status/detail (e.g. the 503
                # from _fetch_token). Re-raise it rather than letting the
                # catch-all below downgrade it to a 500.
                raise
            except httpx.RequestError as exc:
                logger.error(f"Failed to connect to Keycloak: {exc}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="Authentication service unavailable",
                ) from exc
            except Exception as exc:
                logger.exception(f"Unexpected error getting service token: {exc}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to authenticate service",
                ) from exc

            if self._token is None:
                logger.error("Failed to obtain service token: token is None")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to authenticate service",
                )

            logger.info(
                f"Obtained new service token for {self.client_id}, "
                f"expires in {self._token.expires_in} seconds"
            )
            return self._token.authorization_header

    @property
    def circuit_state(self) -> CircuitState:
        """Current state of the token-fetch circuit breaker."""
        return self._circuit_breaker.state

    def reset_circuit(self):
        """Manually reset the token-fetch circuit breaker to CLOSED."""
        self._circuit_breaker.reset()

    async def close(self):
        """Close the HTTP client, if one was created, so it can be recreated later."""
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
333
+
334
+
335
# Process-wide ServiceAuthClient, created lazily by get_service_auth_client()
# and discarded by cleanup_service_auth().
_service_auth_client: Optional[ServiceAuthClient] = None


def get_service_auth_client() -> ServiceAuthClient:
    """Return the shared ServiceAuthClient, instantiating it on first use."""
    global _service_auth_client
    shared = _service_auth_client
    if shared is None:
        shared = ServiceAuthClient()
        _service_auth_client = shared
    return shared
345
+
346
+
347
async def cleanup_service_auth():
    """
    Close the shared ServiceAuthClient, if any, and drop the reference so the
    next get_service_auth_client() call builds a fresh instance.
    """
    global _service_auth_client
    shared = _service_auth_client
    if not shared:
        return
    await shared.close()
    _service_auth_client = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: auth-gate
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Enterprise-grade authentication for microservices with Kong and Keycloak integration
5
5
  Home-page: https://github.com/tradelink-org/auth-gate
6
6
  Author: Brian Mburu
@@ -6,6 +6,7 @@ src/auth_gate/__init__.py
6
6
  src/auth_gate/config.py
7
7
  src/auth_gate/fastapi_utils.py
8
8
  src/auth_gate/middleware.py
9
+ src/auth_gate/s2s_auth.py
9
10
  src/auth_gate/schemas.py
10
11
  src/auth_gate/user_auth.py
11
12
  src/auth_gate.egg-info/PKG-INFO
@@ -19,5 +20,6 @@ src/tests/test_config.py
19
20
  src/tests/test_fastapi_utils.py
20
21
  src/tests/test_intergration.py
21
22
  src/tests/test_middleware.py
23
+ src/tests/test_s2s_auth.py
22
24
  src/tests/test_schema.py
23
25
  src/tests/test_user_auth.py
@@ -12,6 +12,7 @@ from httpx import AsyncClient
12
12
 
13
13
  from auth_gate import AuthSettings, ServiceContext, UserContext
14
14
  from auth_gate.config import AuthMode, reset_settings
15
+ from auth_gate.s2s_auth import ServiceAuthClient
15
16
  from auth_gate.user_auth import UserValidator, cleanup_user_validator
16
17
 
17
18
 
@@ -60,9 +61,10 @@ def mock_settings():
60
61
  )
61
62
  with patch("auth_gate.config.get_settings", return_value=settings):
62
63
  with patch("auth_gate.user_auth.get_settings", return_value=settings):
63
- with patch("auth_gate.middleware.get_settings", return_value=settings):
64
- with patch("auth_gate.fastapi_utils.get_settings", return_value=settings):
65
- yield settings
64
+ with patch("auth_gate.s2s_auth.get_settings", return_value=settings):
65
+ with patch("auth_gate.middleware.get_settings", return_value=settings):
66
+ with patch("auth_gate.fastapi_utils.get_settings", return_value=settings):
67
+ yield settings
66
68
 
67
69
 
68
70
  @pytest_asyncio.fixture(autouse=True)
@@ -231,3 +233,17 @@ async def bypass_validator():
231
233
 
232
234
  # Clean up after each test
233
235
  await validator.close()
236
+
237
+
238
@pytest_asyncio.fixture
async def service_auth_client():
    """Yield a fresh ServiceAuthClient and close it after the test."""
    # Reset settings so the client is built from a clean state.
    reset_settings()

    auth_client = ServiceAuthClient()
    yield auth_client

    # Release the client's HTTP resources after each test.
    await auth_client.close()
@@ -0,0 +1,295 @@
1
+ """
2
+ Tests for service-to-service authentication
3
+ """
4
+
5
+ import asyncio
6
+ from datetime import datetime, timedelta, timezone
7
+ from unittest.mock import AsyncMock, patch
8
+
9
+ import httpx
10
+ import pytest
11
+ from fastapi import HTTPException
12
+
13
+ from auth_gate.s2s_auth import (
14
+ CircuitBreaker,
15
+ CircuitBreakerOpenError,
16
+ CircuitState,
17
+ ServiceToken,
18
+ )
19
+
20
+
21
class TestCircuitBreaker:
    """Test CircuitBreaker implementation"""

    @staticmethod
    async def _trip(breaker: CircuitBreaker, failures: int) -> None:
        """Push ``failures`` ValueError-raising calls through ``breaker``."""
        for _ in range(failures):
            # pytest.raises (unlike try/except/pass) also fails if the
            # breaker unexpectedly short-circuits or swallows the error.
            with pytest.raises(ValueError):
                async with breaker:
                    raise ValueError("Test failure")

    @pytest.mark.asyncio
    async def test_circuit_breaker_closed_state(self):
        """Test circuit breaker in closed state allows calls"""
        breaker = CircuitBreaker(failure_threshold=3)

        assert breaker.state == CircuitState.CLOSED
        assert breaker.is_closed is True

        # Successful calls should pass through
        async with breaker:
            pass

        assert breaker.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_circuit_breaker_opens_after_threshold(self):
        """Test circuit breaker opens after failure threshold"""
        breaker = CircuitBreaker(failure_threshold=3)

        await self._trip(breaker, 3)

        assert breaker.state == CircuitState.OPEN
        assert breaker.is_open is True

    @pytest.mark.asyncio
    async def test_circuit_breaker_rejects_calls_when_open(self):
        """Test circuit breaker rejects calls when open"""
        breaker = CircuitBreaker(failure_threshold=2)
        await self._trip(breaker, 2)

        with pytest.raises(CircuitBreakerOpenError):
            async with breaker:
                pass

    @pytest.mark.asyncio
    async def test_circuit_breaker_ignores_unexpected_exceptions(self):
        """Exceptions outside expected_exception propagate but are not counted"""
        breaker = CircuitBreaker(failure_threshold=1, expected_exception=ValueError)

        with pytest.raises(KeyError):
            async with breaker:
                raise KeyError("not a counted failure")

        assert breaker.state == CircuitState.CLOSED
        assert breaker._failure_count == 0

    @pytest.mark.asyncio
    async def test_circuit_breaker_half_open_recovery(self):
        """Test circuit breaker recovery through half-open state"""
        # recovery_timeout=0: any elapsed time (strictly > 0) allows a probe.
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0)
        await self._trip(breaker, 2)
        assert breaker.state == CircuitState.OPEN

        await asyncio.sleep(0.01)

        # Should transition to half-open, allow the probe, then close.
        async with breaker:
            pass

        assert breaker.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_circuit_breaker_half_open_failure(self):
        """Test circuit breaker reopens on half-open failure"""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0)
        await self._trip(breaker, 2)

        await asyncio.sleep(0.01)

        # Fail during half-open
        with pytest.raises(ValueError):
            async with breaker:
                raise ValueError("Recovery failure")

        assert breaker.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_circuit_breaker_call_wrapper(self):
        """Test circuit breaker call wrapper method"""
        breaker = CircuitBreaker(failure_threshold=3)

        async def successful_func(x, y):
            return x + y

        assert await breaker.call(successful_func, 2, 3) == 5

        async def failing_func():
            raise ValueError("Test error")

        with pytest.raises(ValueError):
            await breaker.call(failing_func)

    @pytest.mark.asyncio
    async def test_circuit_breaker_reset(self):
        """Test manual circuit breaker reset"""
        breaker = CircuitBreaker(failure_threshold=1)

        async def failing_func():
            raise ValueError("Test error")

        # A single failure opens the circuit (threshold=1).
        with pytest.raises(ValueError):
            await breaker.call(failing_func)
        assert breaker.state == CircuitState.OPEN

        breaker.reset()

        assert breaker.state == CircuitState.CLOSED
        assert breaker._failure_count == 0
        assert breaker._last_failure_time is None
155
+
156
+
157
class TestServiceToken:
    """Tests for the ServiceToken dataclass"""

    def test_service_token_not_expired(self):
        """A just-issued token is valid and renders the expected header"""
        fresh = ServiceToken(
            access_token="test-token",
            token_type="Bearer",
            expires_in=300,
            created_at=datetime.now(timezone.utc),
        )

        assert fresh.is_expired is False
        assert fresh.authorization_header == "Bearer test-token"

    def test_service_token_expired(self):
        """A token issued longer ago than its lifetime reports expired"""
        issued_at = datetime.now(timezone.utc) - timedelta(seconds=400)
        stale = ServiceToken(
            access_token="test-token",
            token_type="Bearer",
            expires_in=300,
            created_at=issued_at,
        )

        assert stale.is_expired is True
182
+
183
+
184
class TestServiceAuthClient:
    """Tests for ServiceAuthClient token fetching, caching and circuit breaking"""

    # Authorization header the tests expect for the service_token_response fixture.
    EXPECTED_HEADER = "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.test"

    @staticmethod
    def _mock_post():
        """Patcher replacing httpx.AsyncClient.post with a fresh AsyncMock."""
        return patch.object(httpx.AsyncClient, "post", new=AsyncMock())

    @pytest.mark.asyncio
    async def test_get_service_token_success(
        self, mock_settings, service_auth_client, service_token_response
    ):
        """A 200 token response yields the header and caches the token"""
        with self._mock_post() as post:
            post.return_value = httpx.Response(status_code=200, json=service_token_response)
            header = await service_auth_client.get_service_token()

            assert header == self.EXPECTED_HEADER
            assert service_auth_client._token is not None
            post.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_service_token_cached(
        self, mock_settings, service_auth_client, service_token_response
    ):
        """A second call is served from cache without another HTTP request"""
        with self._mock_post() as post:
            post.return_value = httpx.Response(status_code=200, json=service_token_response)

            first = await service_auth_client.get_service_token()
            second = await service_auth_client.get_service_token()

            assert first == second
            post.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_service_token_refresh_on_expiry(
        self, mock_settings, service_auth_client, service_token_response
    ):
        """An expired cached token triggers exactly one refresh"""
        service_auth_client._token = ServiceToken(
            access_token="old-token",
            token_type="Bearer",
            expires_in=300,
            created_at=datetime.now(timezone.utc) - timedelta(seconds=400),
        )

        with self._mock_post() as post:
            post.return_value = httpx.Response(status_code=200, json=service_token_response)

            header = await service_auth_client.get_service_token()

            assert header == self.EXPECTED_HEADER
            post.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_service_token_circuit_breaker_open(self, mock_settings, service_auth_client):
        """An open circuit surfaces as a 503 'temporarily unavailable'"""
        breaker = service_auth_client._circuit_breaker
        breaker._state = CircuitState.OPEN
        breaker._last_failure_time = datetime.now(timezone.utc)

        with pytest.raises(HTTPException) as exc_info:
            await service_auth_client.get_service_token()

        assert exc_info.value.status_code == 503
        assert "temporarily unavailable" in exc_info.value.detail

    @pytest.mark.asyncio
    async def test_get_service_token_connection_error(self, mock_settings, service_auth_client):
        """A transport error surfaces as a 503"""
        with self._mock_post() as post:
            post.side_effect = httpx.RequestError("Connection failed")

            with pytest.raises(HTTPException) as exc_info:
                await service_auth_client.get_service_token()

            assert exc_info.value.status_code == 503

    @pytest.mark.asyncio
    async def test_get_service_token_circuit_breaker_integration(
        self, mock_settings, service_auth_client
    ):
        """Repeated transport failures open the circuit, then calls fail fast"""
        threshold = mock_settings.CIRCUIT_BREAKER_FAILURE_THRESHOLD

        with self._mock_post() as post:
            post.side_effect = httpx.RequestError("Connection failed")

            for _ in range(threshold):
                with pytest.raises(HTTPException):
                    await service_auth_client.get_service_token()

            assert service_auth_client.circuit_state == CircuitState.OPEN

            with pytest.raises(HTTPException) as exc_info:
                await service_auth_client.get_service_token()

            assert "temporarily unavailable" in exc_info.value.detail

    def test_circuit_state_property(self, service_auth_client):
        """A new client starts with a closed circuit"""
        assert service_auth_client.circuit_state == CircuitState.CLOSED

    def test_reset_circuit(self, service_auth_client):
        """reset_circuit() returns an open circuit to CLOSED"""
        service_auth_client._circuit_breaker._state = CircuitState.OPEN

        service_auth_client.reset_circuit()

        assert service_auth_client.circuit_state == CircuitState.CLOSED
File without changes
File without changes