byoi 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- byoi/__init__.py +233 -0
- byoi/__main__.py +228 -0
- byoi/cache.py +346 -0
- byoi/config.py +349 -0
- byoi/dependencies.py +144 -0
- byoi/errors.py +360 -0
- byoi/models.py +451 -0
- byoi/pkce.py +144 -0
- byoi/providers.py +434 -0
- byoi/py.typed +0 -0
- byoi/repositories.py +252 -0
- byoi/service.py +723 -0
- byoi/telemetry.py +352 -0
- byoi/tokens.py +340 -0
- byoi/types.py +130 -0
- byoi-0.1.0a1.dist-info/METADATA +504 -0
- byoi-0.1.0a1.dist-info/RECORD +20 -0
- byoi-0.1.0a1.dist-info/WHEEL +4 -0
- byoi-0.1.0a1.dist-info/entry_points.txt +3 -0
- byoi-0.1.0a1.dist-info/licenses/LICENSE +21 -0
byoi/service.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
1
|
+
"""Authentication service for the BYOI library.
|
|
2
|
+
|
|
3
|
+
This is the main service that orchestrates the OAuth/OIDC authentication flow,
|
|
4
|
+
including PKCE, token exchange, and identity linking.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from datetime import datetime, timedelta, timezone
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Self
|
|
12
|
+
from urllib.parse import urlencode
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
|
|
16
|
+
from byoi.errors import (
|
|
17
|
+
CannotUnlinkLastIdentityError,
|
|
18
|
+
IdentityAlreadyLinkedError,
|
|
19
|
+
IdentityNotFoundError,
|
|
20
|
+
InvalidCodeError,
|
|
21
|
+
InvalidStateError,
|
|
22
|
+
StateExpiredError,
|
|
23
|
+
TokenExchangeError,
|
|
24
|
+
TokenRefreshError,
|
|
25
|
+
UserNotFoundError,
|
|
26
|
+
)
|
|
27
|
+
from byoi.models import (
|
|
28
|
+
AuthenticatedUser,
|
|
29
|
+
AuthorizationRequest,
|
|
30
|
+
AuthorizationResponse,
|
|
31
|
+
LinkIdentityRequest,
|
|
32
|
+
LinkedIdentityInfo,
|
|
33
|
+
TokenExchangeRequest,
|
|
34
|
+
TokenExchangeResponse,
|
|
35
|
+
TokenRefreshRequest,
|
|
36
|
+
TokenRefreshResponse,
|
|
37
|
+
UnlinkIdentityRequest,
|
|
38
|
+
UserIdentities,
|
|
39
|
+
)
|
|
40
|
+
from byoi.pkce import generate_nonce, generate_pkce_pair, generate_state
|
|
41
|
+
from byoi.providers import ProviderManager
|
|
42
|
+
from byoi.repositories import (
|
|
43
|
+
AuthStateRepositoryProtocol,
|
|
44
|
+
LinkedIdentityRepositoryProtocol,
|
|
45
|
+
UserRepositoryProtocol,
|
|
46
|
+
)
|
|
47
|
+
from byoi.tokens import TokenValidator
|
|
48
|
+
|
|
49
|
+
if TYPE_CHECKING:
|
|
50
|
+
from types import TracebackType
|
|
51
|
+
|
|
52
|
+
# Names exported by ``from byoi.service import *``.
__all__ = ("AuthService",)

# Module logger; the logger name is spelled out explicitly rather than
# derived from ``__name__``.
logger = logging.getLogger("byoi.service")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class AuthService:
|
|
58
|
+
"""Main authentication service for BYOI.
|
|
59
|
+
|
|
60
|
+
This service provides the core authentication functionality:
|
|
61
|
+
- Initiating OAuth authorization flows with PKCE
|
|
62
|
+
- Exchanging authorization codes for tokens
|
|
63
|
+
- Validating ID tokens
|
|
64
|
+
- Managing identity linking
|
|
65
|
+
|
|
66
|
+
The implementing application is responsible for:
|
|
67
|
+
- Providing repository implementations for data persistence
|
|
68
|
+
- Creating FastAPI routes that use this service
|
|
69
|
+
- Managing user sessions after authentication
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
# Default expiration time for auth state
|
|
73
|
+
DEFAULT_STATE_EXPIRATION_MINUTES = 10
|
|
74
|
+
|
|
75
|
+
def __init__(
|
|
76
|
+
self,
|
|
77
|
+
provider_manager: ProviderManager,
|
|
78
|
+
user_repository: UserRepositoryProtocol,
|
|
79
|
+
identity_repository: LinkedIdentityRepositoryProtocol,
|
|
80
|
+
auth_state_repository: AuthStateRepositoryProtocol,
|
|
81
|
+
state_expiration_minutes: int = DEFAULT_STATE_EXPIRATION_MINUTES,
|
|
82
|
+
http_client: httpx.AsyncClient | None = None,
|
|
83
|
+
http_timeout_seconds: float = 30.0,
|
|
84
|
+
) -> None:
|
|
85
|
+
"""Initialize the auth service.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
provider_manager: Manager for OIDC providers.
|
|
89
|
+
user_repository: Repository for user data.
|
|
90
|
+
identity_repository: Repository for linked identities.
|
|
91
|
+
auth_state_repository: Repository for auth state (PKCE, nonce).
|
|
92
|
+
state_expiration_minutes: How long auth states are valid.
|
|
93
|
+
http_client: Optional HTTP client for making requests.
|
|
94
|
+
http_timeout_seconds: Timeout for HTTP requests in seconds.
|
|
95
|
+
"""
|
|
96
|
+
self._provider_manager = provider_manager
|
|
97
|
+
self._user_repository = user_repository
|
|
98
|
+
self._identity_repository = identity_repository
|
|
99
|
+
self._auth_state_repository = auth_state_repository
|
|
100
|
+
self._state_expiration_minutes = state_expiration_minutes
|
|
101
|
+
self._http_client = http_client
|
|
102
|
+
self._owns_http_client = http_client is None
|
|
103
|
+
self._http_timeout_seconds = http_timeout_seconds
|
|
104
|
+
self._token_validator = TokenValidator(provider_manager)
|
|
105
|
+
|
|
106
|
+
async def __aenter__(self) -> Self:
|
|
107
|
+
"""Enter the async context manager."""
|
|
108
|
+
return self
|
|
109
|
+
|
|
110
|
+
async def __aexit__(
|
|
111
|
+
self,
|
|
112
|
+
exc_type: type[BaseException] | None,
|
|
113
|
+
exc_val: BaseException | None,
|
|
114
|
+
exc_tb: TracebackType | None,
|
|
115
|
+
) -> None:
|
|
116
|
+
"""Exit the async context manager and close the service."""
|
|
117
|
+
await self.close()
|
|
118
|
+
async def _get_http_client(self) -> httpx.AsyncClient:
|
|
119
|
+
"""Get or create the HTTP client."""
|
|
120
|
+
if self._http_client is None:
|
|
121
|
+
self._http_client = httpx.AsyncClient(
|
|
122
|
+
timeout=httpx.Timeout(self._http_timeout_seconds),
|
|
123
|
+
follow_redirects=True,
|
|
124
|
+
)
|
|
125
|
+
return self._http_client
|
|
126
|
+
|
|
127
|
+
# =========================================================================
|
|
128
|
+
# Authorization Flow
|
|
129
|
+
# =========================================================================
|
|
130
|
+
|
|
131
|
+
async def create_authorization_url(
|
|
132
|
+
self,
|
|
133
|
+
request: AuthorizationRequest,
|
|
134
|
+
) -> AuthorizationResponse:
|
|
135
|
+
"""Create an authorization URL for initiating OAuth flow.
|
|
136
|
+
|
|
137
|
+
This generates PKCE challenge, state, and nonce, stores them,
|
|
138
|
+
and returns the URL to redirect the user to.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
request: The authorization request details.
|
|
142
|
+
|
|
143
|
+
Returns:
|
|
144
|
+
Response containing the authorization URL and state.
|
|
145
|
+
|
|
146
|
+
Raises:
|
|
147
|
+
ProviderNotFoundError: If the provider is not registered.
|
|
148
|
+
"""
|
|
149
|
+
# Get provider config and authorization endpoint
|
|
150
|
+
config = self._provider_manager.get_provider(request.provider_name)
|
|
151
|
+
auth_endpoint = await self._provider_manager.get_authorization_endpoint(
|
|
152
|
+
request.provider_name
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
# Generate PKCE, state, and nonce
|
|
156
|
+
pkce = generate_pkce_pair()
|
|
157
|
+
state = generate_state()
|
|
158
|
+
nonce = generate_nonce()
|
|
159
|
+
|
|
160
|
+
# Calculate expiration
|
|
161
|
+
now = datetime.now(timezone.utc)
|
|
162
|
+
expires_at = now + timedelta(minutes=self._state_expiration_minutes)
|
|
163
|
+
|
|
164
|
+
# Store auth state
|
|
165
|
+
await self._auth_state_repository.create(
|
|
166
|
+
state=state,
|
|
167
|
+
code_verifier=pkce.code_verifier,
|
|
168
|
+
nonce=nonce,
|
|
169
|
+
provider_name=request.provider_name,
|
|
170
|
+
redirect_uri=request.redirect_uri,
|
|
171
|
+
expires_at=expires_at,
|
|
172
|
+
client_type=request.client_type.value,
|
|
173
|
+
extra_data=request.extra_data,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
# Build authorization URL parameters
|
|
177
|
+
params = {
|
|
178
|
+
"response_type": "code",
|
|
179
|
+
"client_id": config.client_id,
|
|
180
|
+
"redirect_uri": request.redirect_uri,
|
|
181
|
+
"scope": " ".join(config.scopes),
|
|
182
|
+
"state": state,
|
|
183
|
+
"nonce": nonce,
|
|
184
|
+
"code_challenge": pkce.code_challenge,
|
|
185
|
+
"code_challenge_method": pkce.code_challenge_method,
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
# Add any extra auth params from config
|
|
189
|
+
params.update(config.extra_auth_params)
|
|
190
|
+
|
|
191
|
+
# Build the full authorization URL
|
|
192
|
+
authorization_url = f"{auth_endpoint}?{urlencode(params)}"
|
|
193
|
+
|
|
194
|
+
return AuthorizationResponse(
|
|
195
|
+
authorization_url=authorization_url,
|
|
196
|
+
state=state,
|
|
197
|
+
provider_name=request.provider_name,
|
|
198
|
+
expires_at=expires_at,
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
async def exchange_code(
|
|
202
|
+
self,
|
|
203
|
+
request: TokenExchangeRequest,
|
|
204
|
+
) -> TokenExchangeResponse:
|
|
205
|
+
"""Exchange an authorization code for tokens.
|
|
206
|
+
|
|
207
|
+
This validates the state, performs PKCE verification, exchanges
|
|
208
|
+
the code for tokens, and validates the ID token.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
request: The token exchange request.
|
|
212
|
+
|
|
213
|
+
Returns:
|
|
214
|
+
Response containing tokens and identity information.
|
|
215
|
+
|
|
216
|
+
Raises:
|
|
217
|
+
InvalidStateError: If the state is invalid or not found.
|
|
218
|
+
StateExpiredError: If the state has expired.
|
|
219
|
+
InvalidCodeError: If the authorization code is invalid.
|
|
220
|
+
TokenExchangeError: If the token exchange fails.
|
|
221
|
+
TokenValidationError: If ID token validation fails.
|
|
222
|
+
"""
|
|
223
|
+
# Retrieve and validate auth state
|
|
224
|
+
auth_state = await self._auth_state_repository.get_by_state(request.state)
|
|
225
|
+
if auth_state is None:
|
|
226
|
+
raise InvalidStateError(request.state)
|
|
227
|
+
|
|
228
|
+
# Check expiration
|
|
229
|
+
now = datetime.now(timezone.utc)
|
|
230
|
+
expires_at = auth_state.expires_at
|
|
231
|
+
# Ensure expires_at is timezone-aware for comparison
|
|
232
|
+
if expires_at.tzinfo is None:
|
|
233
|
+
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
|
234
|
+
|
|
235
|
+
if now > expires_at:
|
|
236
|
+
# Clean up expired state
|
|
237
|
+
await self._auth_state_repository.delete(request.state)
|
|
238
|
+
raise StateExpiredError(request.state)
|
|
239
|
+
|
|
240
|
+
# Validate redirect URI matches
|
|
241
|
+
if auth_state.redirect_uri != request.redirect_uri:
|
|
242
|
+
raise InvalidStateError(request.state)
|
|
243
|
+
|
|
244
|
+
# Get provider config and token endpoint
|
|
245
|
+
provider_name = auth_state.provider_name
|
|
246
|
+
config = self._provider_manager.get_provider(provider_name)
|
|
247
|
+
token_endpoint = await self._provider_manager.get_token_endpoint(provider_name)
|
|
248
|
+
|
|
249
|
+
# Exchange code for tokens
|
|
250
|
+
token_response = await self._exchange_code_for_tokens(
|
|
251
|
+
code=request.code,
|
|
252
|
+
redirect_uri=request.redirect_uri,
|
|
253
|
+
code_verifier=auth_state.code_verifier,
|
|
254
|
+
client_id=config.client_id,
|
|
255
|
+
client_secret=config.client_secret,
|
|
256
|
+
token_endpoint=token_endpoint,
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
# Validate ID token
|
|
260
|
+
id_token = token_response.get("id_token")
|
|
261
|
+
if not id_token:
|
|
262
|
+
raise TokenExchangeError("No ID token in response")
|
|
263
|
+
|
|
264
|
+
identity = await self._token_validator.validate(
|
|
265
|
+
id_token=id_token,
|
|
266
|
+
provider_name=provider_name,
|
|
267
|
+
expected_nonce=auth_state.nonce,
|
|
268
|
+
verify_at_hash=False, # Optional verification
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
# Clean up the auth state (one-time use)
|
|
272
|
+
await self._auth_state_repository.delete(request.state)
|
|
273
|
+
|
|
274
|
+
return TokenExchangeResponse(
|
|
275
|
+
identity=identity,
|
|
276
|
+
access_token=token_response["access_token"],
|
|
277
|
+
token_type=token_response.get("token_type", "Bearer"),
|
|
278
|
+
expires_in=token_response.get("expires_in"),
|
|
279
|
+
refresh_token=token_response.get("refresh_token"),
|
|
280
|
+
id_token=id_token,
|
|
281
|
+
)
|
|
282
|
+
|
|
283
|
+
async def _exchange_code_for_tokens(
|
|
284
|
+
self,
|
|
285
|
+
code: str,
|
|
286
|
+
redirect_uri: str,
|
|
287
|
+
code_verifier: str,
|
|
288
|
+
client_id: str,
|
|
289
|
+
client_secret: str | None,
|
|
290
|
+
token_endpoint: str,
|
|
291
|
+
) -> dict[str, Any]:
|
|
292
|
+
"""Exchange an authorization code for tokens.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
code: The authorization code.
|
|
296
|
+
redirect_uri: The redirect URI.
|
|
297
|
+
code_verifier: The PKCE code verifier.
|
|
298
|
+
client_id: The OAuth client ID.
|
|
299
|
+
client_secret: The OAuth client secret (optional).
|
|
300
|
+
token_endpoint: The token endpoint URL.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
The token response as a dictionary.
|
|
304
|
+
|
|
305
|
+
Raises:
|
|
306
|
+
InvalidCodeError: If the code is invalid.
|
|
307
|
+
TokenExchangeError: If the exchange fails.
|
|
308
|
+
"""
|
|
309
|
+
# Build request body
|
|
310
|
+
data = {
|
|
311
|
+
"grant_type": "authorization_code",
|
|
312
|
+
"code": code,
|
|
313
|
+
"redirect_uri": redirect_uri,
|
|
314
|
+
"client_id": client_id,
|
|
315
|
+
"code_verifier": code_verifier,
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
# Add client secret if provided (confidential clients)
|
|
319
|
+
if client_secret:
|
|
320
|
+
data["client_secret"] = client_secret
|
|
321
|
+
|
|
322
|
+
try:
|
|
323
|
+
client = await self._get_http_client()
|
|
324
|
+
response = await client.post(
|
|
325
|
+
token_endpoint,
|
|
326
|
+
data=data,
|
|
327
|
+
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
if response.status_code == 400:
|
|
331
|
+
error_data = response.json()
|
|
332
|
+
error = error_data.get("error", "unknown_error")
|
|
333
|
+
error_description = error_data.get("error_description", "")
|
|
334
|
+
|
|
335
|
+
if error == "invalid_grant":
|
|
336
|
+
raise InvalidCodeError(error_description or "Invalid or expired code")
|
|
337
|
+
|
|
338
|
+
raise TokenExchangeError(
|
|
339
|
+
f"{error}: {error_description}" if error_description else error
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
response.raise_for_status()
|
|
343
|
+
return response.json()
|
|
344
|
+
|
|
345
|
+
except (InvalidCodeError, TokenExchangeError):
|
|
346
|
+
raise
|
|
347
|
+
except httpx.HTTPError as e:
|
|
348
|
+
raise TokenExchangeError(f"HTTP error: {e}") from e
|
|
349
|
+
except Exception as e:
|
|
350
|
+
raise TokenExchangeError(f"Unexpected error: {e}") from e
|
|
351
|
+
|
|
352
|
+
async def refresh_token(
|
|
353
|
+
self,
|
|
354
|
+
request: TokenRefreshRequest,
|
|
355
|
+
) -> TokenRefreshResponse:
|
|
356
|
+
"""Refresh an access token using a refresh token.
|
|
357
|
+
|
|
358
|
+
This method exchanges a refresh token for a new access token.
|
|
359
|
+
Some providers may also return a new refresh token (token rotation).
|
|
360
|
+
|
|
361
|
+
Args:
|
|
362
|
+
request: The token refresh request containing the refresh token.
|
|
363
|
+
|
|
364
|
+
Returns:
|
|
365
|
+
Response containing the new access token and optional refresh token.
|
|
366
|
+
|
|
367
|
+
Raises:
|
|
368
|
+
ProviderNotFoundError: If the provider is not registered.
|
|
369
|
+
TokenRefreshError: If the token refresh fails.
|
|
370
|
+
"""
|
|
371
|
+
logger.debug("Refreshing token for provider=%s", request.provider_name)
|
|
372
|
+
|
|
373
|
+
# Get provider config and token endpoint
|
|
374
|
+
config = self._provider_manager.get_provider(request.provider_name)
|
|
375
|
+
token_endpoint = await self._provider_manager.get_token_endpoint(
|
|
376
|
+
request.provider_name
|
|
377
|
+
)
|
|
378
|
+
|
|
379
|
+
# Build request body
|
|
380
|
+
data: dict[str, str] = {
|
|
381
|
+
"grant_type": "refresh_token",
|
|
382
|
+
"refresh_token": request.refresh_token,
|
|
383
|
+
"client_id": config.client_id,
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
# Add client secret if provided (confidential clients)
|
|
387
|
+
if config.client_secret:
|
|
388
|
+
data["client_secret"] = config.client_secret
|
|
389
|
+
|
|
390
|
+
# Add scope if specified
|
|
391
|
+
if request.scope:
|
|
392
|
+
data["scope"] = request.scope
|
|
393
|
+
|
|
394
|
+
try:
|
|
395
|
+
client = await self._get_http_client()
|
|
396
|
+
response = await client.post(
|
|
397
|
+
token_endpoint,
|
|
398
|
+
data=data,
|
|
399
|
+
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
if response.status_code == 400:
|
|
403
|
+
error_data = response.json()
|
|
404
|
+
error = error_data.get("error", "unknown_error")
|
|
405
|
+
error_description = error_data.get("error_description", "")
|
|
406
|
+
|
|
407
|
+
raise TokenRefreshError(
|
|
408
|
+
f"{error}: {error_description}" if error_description else error
|
|
409
|
+
)
|
|
410
|
+
|
|
411
|
+
response.raise_for_status()
|
|
412
|
+
token_data = response.json()
|
|
413
|
+
|
|
414
|
+
logger.debug(
|
|
415
|
+
"Token refreshed successfully for provider=%s",
|
|
416
|
+
request.provider_name,
|
|
417
|
+
)
|
|
418
|
+
|
|
419
|
+
return TokenRefreshResponse(
|
|
420
|
+
access_token=token_data["access_token"],
|
|
421
|
+
token_type=token_data.get("token_type", "Bearer"),
|
|
422
|
+
expires_in=token_data.get("expires_in"),
|
|
423
|
+
refresh_token=token_data.get("refresh_token"),
|
|
424
|
+
scope=token_data.get("scope"),
|
|
425
|
+
id_token=token_data.get("id_token"),
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
except TokenRefreshError:
|
|
429
|
+
raise
|
|
430
|
+
except httpx.HTTPError as e:
|
|
431
|
+
logger.error("HTTP error during token refresh: %s", e)
|
|
432
|
+
raise TokenRefreshError(f"HTTP error: {e}") from e
|
|
433
|
+
except Exception as e:
|
|
434
|
+
logger.error("Unexpected error during token refresh: %s", e)
|
|
435
|
+
raise TokenRefreshError(f"Unexpected error: {e}") from e
|
|
436
|
+
|
|
437
|
+
# =========================================================================
|
|
438
|
+
# Authentication (Login/Register)
|
|
439
|
+
# =========================================================================
|
|
440
|
+
|
|
441
|
+
async def authenticate(
|
|
442
|
+
self,
|
|
443
|
+
request: TokenExchangeRequest,
|
|
444
|
+
auto_create_user: bool = True,
|
|
445
|
+
) -> AuthenticatedUser:
|
|
446
|
+
"""Authenticate a user using an authorization code.
|
|
447
|
+
|
|
448
|
+
This is the main authentication method. It:
|
|
449
|
+
1. Exchanges the code for tokens
|
|
450
|
+
2. Looks up or creates the user
|
|
451
|
+
3. Links/updates the identity
|
|
452
|
+
|
|
453
|
+
Args:
|
|
454
|
+
request: The token exchange request.
|
|
455
|
+
auto_create_user: Whether to create a new user if not found.
|
|
456
|
+
|
|
457
|
+
Returns:
|
|
458
|
+
The authenticated user information.
|
|
459
|
+
|
|
460
|
+
Raises:
|
|
461
|
+
InvalidStateError: If the state is invalid.
|
|
462
|
+
StateExpiredError: If the state has expired.
|
|
463
|
+
TokenExchangeError: If token exchange fails.
|
|
464
|
+
TokenValidationError: If ID token validation fails.
|
|
465
|
+
UserNotFoundError: If user not found and auto_create_user is False.
|
|
466
|
+
"""
|
|
467
|
+
logger.debug("Authenticating user with state=%s", request.state)
|
|
468
|
+
|
|
469
|
+
# Exchange code for tokens and identity
|
|
470
|
+
token_response = await self.exchange_code(request)
|
|
471
|
+
identity = token_response.identity
|
|
472
|
+
|
|
473
|
+
# Check if this identity is already linked
|
|
474
|
+
existing_link = await self._identity_repository.get_by_provider(
|
|
475
|
+
provider_name=identity.provider_name,
|
|
476
|
+
provider_subject=identity.subject,
|
|
477
|
+
)
|
|
478
|
+
|
|
479
|
+
if existing_link:
|
|
480
|
+
# User exists, update last used time
|
|
481
|
+
await self._identity_repository.update_last_used(
|
|
482
|
+
identity_id=existing_link.id,
|
|
483
|
+
last_used_at=datetime.now(timezone.utc),
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
# Get the user
|
|
487
|
+
user = await self._user_repository.get_by_id(existing_link.user_id)
|
|
488
|
+
if user is None:
|
|
489
|
+
# This shouldn't happen, but handle it gracefully
|
|
490
|
+
logger.error(
|
|
491
|
+
"Linked identity references non-existent user: identity_id=%s, user_id=%s",
|
|
492
|
+
existing_link.id,
|
|
493
|
+
existing_link.user_id,
|
|
494
|
+
)
|
|
495
|
+
raise UserNotFoundError(existing_link.user_id)
|
|
496
|
+
|
|
497
|
+
logger.info(
|
|
498
|
+
"User authenticated: user_id=%s, provider=%s, is_new=False",
|
|
499
|
+
user.id,
|
|
500
|
+
identity.provider_name,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
return AuthenticatedUser(
|
|
504
|
+
user_id=user.id,
|
|
505
|
+
email=user.email,
|
|
506
|
+
is_new_user=False,
|
|
507
|
+
identity=identity,
|
|
508
|
+
linked_identity_id=existing_link.id,
|
|
509
|
+
)
|
|
510
|
+
|
|
511
|
+
# Identity not linked - create new user or fail
|
|
512
|
+
if not auto_create_user:
|
|
513
|
+
logger.warning(
|
|
514
|
+
"User not found and auto_create_user=False: provider=%s, subject=%s",
|
|
515
|
+
identity.provider_name,
|
|
516
|
+
identity.subject,
|
|
517
|
+
)
|
|
518
|
+
raise UserNotFoundError(
|
|
519
|
+
f"No user linked to {identity.provider_name}:{identity.subject}"
|
|
520
|
+
)
|
|
521
|
+
|
|
522
|
+
# Create new user
|
|
523
|
+
logger.debug("Creating new user for provider=%s", identity.provider_name)
|
|
524
|
+
user = await self._user_repository.create(
|
|
525
|
+
email=identity.email if identity.email_verified else None,
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
# Link the identity to the new user
|
|
529
|
+
linked_identity = await self._identity_repository.create(
|
|
530
|
+
user_id=user.id,
|
|
531
|
+
provider_name=identity.provider_name,
|
|
532
|
+
provider_subject=identity.subject,
|
|
533
|
+
email=identity.email,
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
logger.info(
|
|
537
|
+
"New user created and authenticated: user_id=%s, provider=%s, is_new=True",
|
|
538
|
+
user.id,
|
|
539
|
+
identity.provider_name,
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
return AuthenticatedUser(
|
|
543
|
+
user_id=user.id,
|
|
544
|
+
email=user.email,
|
|
545
|
+
is_new_user=True,
|
|
546
|
+
identity=identity,
|
|
547
|
+
linked_identity_id=linked_identity.id,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
# =========================================================================
|
|
551
|
+
# Identity Linking
|
|
552
|
+
# =========================================================================
|
|
553
|
+
|
|
554
|
+
async def link_identity(
|
|
555
|
+
self,
|
|
556
|
+
request: LinkIdentityRequest,
|
|
557
|
+
) -> LinkedIdentityInfo:
|
|
558
|
+
"""Link a new identity to an existing user.
|
|
559
|
+
|
|
560
|
+
This allows users to add additional login methods to their account.
|
|
561
|
+
|
|
562
|
+
Args:
|
|
563
|
+
request: The link identity request.
|
|
564
|
+
|
|
565
|
+
Returns:
|
|
566
|
+
Information about the linked identity.
|
|
567
|
+
|
|
568
|
+
Raises:
|
|
569
|
+
UserNotFoundError: If the user doesn't exist.
|
|
570
|
+
IdentityAlreadyLinkedError: If the identity is already linked.
|
|
571
|
+
InvalidStateError: If the state is invalid.
|
|
572
|
+
TokenExchangeError: If token exchange fails.
|
|
573
|
+
"""
|
|
574
|
+
# Verify user exists
|
|
575
|
+
user = await self._user_repository.get_by_id(request.user_id)
|
|
576
|
+
if user is None:
|
|
577
|
+
raise UserNotFoundError(request.user_id)
|
|
578
|
+
|
|
579
|
+
# Exchange code for tokens
|
|
580
|
+
token_response = await self.exchange_code(
|
|
581
|
+
TokenExchangeRequest(
|
|
582
|
+
code=request.code,
|
|
583
|
+
state=request.state,
|
|
584
|
+
redirect_uri=request.redirect_uri,
|
|
585
|
+
)
|
|
586
|
+
)
|
|
587
|
+
identity = token_response.identity
|
|
588
|
+
|
|
589
|
+
# Check if this identity is already linked to any user
|
|
590
|
+
existing_link = await self._identity_repository.get_by_provider(
|
|
591
|
+
provider_name=identity.provider_name,
|
|
592
|
+
provider_subject=identity.subject,
|
|
593
|
+
)
|
|
594
|
+
|
|
595
|
+
if existing_link:
|
|
596
|
+
raise IdentityAlreadyLinkedError(
|
|
597
|
+
provider_name=identity.provider_name,
|
|
598
|
+
provider_subject=identity.subject,
|
|
599
|
+
existing_user_id=existing_link.user_id,
|
|
600
|
+
)
|
|
601
|
+
|
|
602
|
+
# Create the link
|
|
603
|
+
linked_identity = await self._identity_repository.create(
|
|
604
|
+
user_id=request.user_id,
|
|
605
|
+
provider_name=identity.provider_name,
|
|
606
|
+
provider_subject=identity.subject,
|
|
607
|
+
email=identity.email,
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
# Get provider display name
|
|
611
|
+
provider_info = self._provider_manager.get_provider_info(identity.provider_name)
|
|
612
|
+
|
|
613
|
+
return LinkedIdentityInfo(
|
|
614
|
+
id=linked_identity.id,
|
|
615
|
+
provider_name=linked_identity.provider_name,
|
|
616
|
+
provider_display_name=provider_info.display_name,
|
|
617
|
+
email=linked_identity.email,
|
|
618
|
+
created_at=linked_identity.created_at,
|
|
619
|
+
last_used_at=linked_identity.last_used_at,
|
|
620
|
+
)
|
|
621
|
+
|
|
622
|
+
async def unlink_identity(
|
|
623
|
+
self,
|
|
624
|
+
request: UnlinkIdentityRequest,
|
|
625
|
+
) -> bool:
|
|
626
|
+
"""Unlink an identity from a user.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
request: The unlink identity request.
|
|
630
|
+
|
|
631
|
+
Returns:
|
|
632
|
+
True if the identity was unlinked.
|
|
633
|
+
|
|
634
|
+
Raises:
|
|
635
|
+
UserNotFoundError: If the user doesn't exist.
|
|
636
|
+
IdentityNotFoundError: If the identity doesn't exist.
|
|
637
|
+
CannotUnlinkLastIdentityError: If this is the user's only identity.
|
|
638
|
+
"""
|
|
639
|
+
# Verify user exists
|
|
640
|
+
user = await self._user_repository.get_by_id(request.user_id)
|
|
641
|
+
if user is None:
|
|
642
|
+
raise UserNotFoundError(request.user_id)
|
|
643
|
+
|
|
644
|
+
# Get the identity
|
|
645
|
+
identity = await self._identity_repository.get_by_id(request.identity_id)
|
|
646
|
+
if identity is None:
|
|
647
|
+
raise IdentityNotFoundError(request.identity_id)
|
|
648
|
+
|
|
649
|
+
# Verify the identity belongs to this user
|
|
650
|
+
if identity.user_id != request.user_id:
|
|
651
|
+
raise IdentityNotFoundError(request.identity_id)
|
|
652
|
+
|
|
653
|
+
# Check if this is the user's only identity
|
|
654
|
+
user_identities = await self._identity_repository.get_by_user_id(request.user_id)
|
|
655
|
+
if len(user_identities) <= 1:
|
|
656
|
+
raise CannotUnlinkLastIdentityError(request.user_id)
|
|
657
|
+
|
|
658
|
+
# Delete the identity
|
|
659
|
+
return await self._identity_repository.delete(request.identity_id)
|
|
660
|
+
|
|
661
|
+
async def get_user_identities(self, user_id: str) -> UserIdentities:
|
|
662
|
+
"""Get all linked identities for a user.
|
|
663
|
+
|
|
664
|
+
Args:
|
|
665
|
+
user_id: The user's ID.
|
|
666
|
+
|
|
667
|
+
Returns:
|
|
668
|
+
List of linked identities.
|
|
669
|
+
|
|
670
|
+
Raises:
|
|
671
|
+
UserNotFoundError: If the user doesn't exist.
|
|
672
|
+
"""
|
|
673
|
+
# Verify user exists
|
|
674
|
+
user = await self._user_repository.get_by_id(user_id)
|
|
675
|
+
if user is None:
|
|
676
|
+
raise UserNotFoundError(user_id)
|
|
677
|
+
|
|
678
|
+
# Get all identities
|
|
679
|
+
identities = await self._identity_repository.get_by_user_id(user_id)
|
|
680
|
+
|
|
681
|
+
# Convert to LinkedIdentityInfo
|
|
682
|
+
identity_infos = []
|
|
683
|
+
for identity in identities:
|
|
684
|
+
try:
|
|
685
|
+
provider_info = self._provider_manager.get_provider_info(
|
|
686
|
+
identity.provider_name
|
|
687
|
+
)
|
|
688
|
+
display_name = provider_info.display_name
|
|
689
|
+
except Exception:
|
|
690
|
+
display_name = identity.provider_name
|
|
691
|
+
|
|
692
|
+
identity_infos.append(
|
|
693
|
+
LinkedIdentityInfo(
|
|
694
|
+
id=identity.id,
|
|
695
|
+
provider_name=identity.provider_name,
|
|
696
|
+
provider_display_name=display_name,
|
|
697
|
+
email=identity.email,
|
|
698
|
+
created_at=identity.created_at,
|
|
699
|
+
last_used_at=identity.last_used_at,
|
|
700
|
+
)
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
return UserIdentities(user_id=user_id, identities=identity_infos)
|
|
704
|
+
|
|
705
|
+
# =========================================================================
|
|
706
|
+
# Cleanup
|
|
707
|
+
# =========================================================================
|
|
708
|
+
|
|
709
|
+
async def cleanup_expired_states(self) -> int:
|
|
710
|
+
"""Clean up expired auth states.
|
|
711
|
+
|
|
712
|
+
Should be called periodically to remove stale state data.
|
|
713
|
+
|
|
714
|
+
Returns:
|
|
715
|
+
Number of expired states removed.
|
|
716
|
+
"""
|
|
717
|
+
return await self._auth_state_repository.delete_expired()
|
|
718
|
+
|
|
719
|
+
async def close(self) -> None:
|
|
720
|
+
"""Close the auth service and release resources."""
|
|
721
|
+
if self._owns_http_client and self._http_client is not None:
|
|
722
|
+
await self._http_client.aclose()
|
|
723
|
+
self._http_client = None
|