sweatstack 0.59.0__py3-none-any.whl → 0.61.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sweatstack/client.py +67 -27
- sweatstack/fastapi/__init__.py +82 -0
- sweatstack/fastapi/config.py +223 -0
- sweatstack/fastapi/dependencies.py +293 -0
- sweatstack/fastapi/models.py +109 -0
- sweatstack/fastapi/routes.py +312 -0
- sweatstack/fastapi/session.py +102 -0
- {sweatstack-0.59.0.dist-info → sweatstack-0.61.0.dist-info}/METADATA +4 -1
- {sweatstack-0.59.0.dist-info → sweatstack-0.61.0.dist-info}/RECORD +11 -5
- {sweatstack-0.59.0.dist-info → sweatstack-0.61.0.dist-info}/WHEEL +0 -0
- {sweatstack-0.59.0.dist-info → sweatstack-0.61.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""FastAPI dependencies for authentication."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Annotated, NoReturn
|
|
9
|
+
from urllib.parse import quote
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
from fastapi import Depends, HTTPException, Request, Response
|
|
13
|
+
|
|
14
|
+
from ..client import Client
|
|
15
|
+
from ..constants import DEFAULT_URL
|
|
16
|
+
from ..utils import decode_jwt_body
|
|
17
|
+
from .config import get_config
|
|
18
|
+
from .models import SessionData, TokenSet, extract_user_id
|
|
19
|
+
from .session import (
|
|
20
|
+
SESSION_COOKIE_NAME,
|
|
21
|
+
clear_session_cookie,
|
|
22
|
+
decrypt_session,
|
|
23
|
+
set_session_cookie,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# Safety margin applied to an access token's "exp" claim: a token is treated
# as expiring (and therefore refreshed) once fewer than this many seconds of
# validity remain, rather than being used right up to its expiry time.
TOKEN_EXPIRY_MARGIN = 5 # seconds
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(slots=True)
|
|
32
|
+
class SweatStackUser:
|
|
33
|
+
"""Authenticated SweatStack user.
|
|
34
|
+
|
|
35
|
+
Attributes:
|
|
36
|
+
client: An authenticated Client instance for API calls.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
client: Client
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def user_id(self) -> str:
|
|
43
|
+
"""The user ID this client acts as."""
|
|
44
|
+
return extract_user_id(self.client.api_key)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# ---------------------------------------------------------------------------
|
|
48
|
+
# Token refresh
|
|
49
|
+
# ---------------------------------------------------------------------------
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _is_token_expiring(token: str) -> bool:
    """Report whether ``token`` expires within TOKEN_EXPIRY_MARGIN seconds.

    Any failure to read the ``exp`` claim (undecodable token, missing or
    non-numeric claim) counts as "expiring", so callers will try a refresh.
    """
    try:
        expires_at = decode_jwt_body(token)["exp"]
        deadline = expires_at - TOKEN_EXPIRY_MARGIN
        return time.time() > deadline
    except Exception:
        return True
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _refresh_access_token(
    refresh_token: str,
    client_id: str,
    client_secret: str,
    tz: str,
) -> str:
    """Trade ``refresh_token`` for a fresh access token at the OAuth endpoint.

    Args:
        refresh_token: Refresh token stored in the session.
        client_id: OAuth client ID of this application.
        client_secret: OAuth client secret, as a plain string.
        tz: Timezone forwarded to the token endpoint.

    Returns:
        The ``access_token`` field of the token endpoint's JSON reply.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status.
    """
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": client_id,
        "client_secret": client_secret,
        "tz": tz,
    }
    token_url = f"{DEFAULT_URL}/api/v1/oauth/token"
    reply = httpx.post(token_url, data=form)
    reply.raise_for_status()
    payload = reply.json()
    return payload["access_token"]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _refresh_tokens_if_needed(tokens: TokenSet) -> TokenSet | None:
    """Refresh ``tokens`` when their access token is expired or about to be.

    Args:
        tokens: The token pair currently stored in the session.

    Returns:
        A new TokenSet with a fresh access token (same refresh token and user
        ID), or None when the current access token is still valid for more
        than TOKEN_EXPIRY_MARGIN seconds.

    Raises:
        httpx.HTTPStatusError: If the token endpoint rejects the refresh.
        Errors from the HTTP request itself or from get_config() propagate
        unchanged.
    """
    if not _is_token_expiring(tokens.access_token):
        return None

    # Carry the user's timezone claim over into the new token.
    # _is_token_expiring also reports True for tokens it cannot decode, so
    # decoding can fail here; in that case (or when the claim is empty) fall
    # back to UTC instead of abandoning a refresh the refresh token may still
    # be able to satisfy.
    try:
        tz = decode_jwt_body(tokens.access_token).get("tz") or "UTC"
    except Exception:
        tz = "UTC"

    config = get_config()
    new_access_token = _refresh_access_token(
        refresh_token=tokens.refresh_token,
        client_id=config.client_id,
        client_secret=config.client_secret.get_secret_value(),
        tz=tz,
    )

    return TokenSet(
        access_token=new_access_token,
        refresh_token=tokens.refresh_token,
        user_id=tokens.user_id,
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ---------------------------------------------------------------------------
|
|
109
|
+
# Session helpers
|
|
110
|
+
# ---------------------------------------------------------------------------
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _raise_unauthenticated(request: Request) -> NoReturn:
    """Abort the current request because no usable session exists.

    With ``redirect_unauthenticated`` enabled, raises a 303 pointing at the
    login route, passing the current path (plus query string) as ``next``.
    Otherwise raises a plain 401.
    """
    config = get_config()
    if not config.redirect_unauthenticated:
        raise HTTPException(status_code=401, detail="Not authenticated")

    url = request.url
    next_url = f"{url.path}?{url.query}" if url.query else url.path
    login_url = f"{config.auth_route_prefix}/login?next={quote(next_url)}"
    raise HTTPException(status_code=303, headers={"Location": login_url})
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _get_session_or_raise(request: Request) -> SessionData:
    """Return the request's session, aborting when it is missing or invalid.

    Uses the same decoding rules as _get_session_or_none; a None result is
    turned into an unauthenticated response via _raise_unauthenticated.
    """
    session = _get_session_or_none(request)
    if session is None:
        _raise_unauthenticated(request)
    return session
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _get_session_or_none(request: Request) -> SessionData | None:
    """Decode the session cookie, returning None if absent or malformed.

    A cookie that decrypts to an empty value, or whose contents are missing
    required keys or have the wrong types, yields None.
    """
    cookie = request.cookies.get(SESSION_COOKIE_NAME)
    payload = decrypt_session(cookie)
    if payload:
        try:
            return SessionData.from_dict(payload)
        except (KeyError, TypeError):
            return None
    return None
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# ---------------------------------------------------------------------------
|
|
150
|
+
# Core dependency logic
|
|
151
|
+
# ---------------------------------------------------------------------------
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _session_clearing_headers() -> dict[str, str] | None:
    """Return the Set-Cookie header clear_session_cookie emits, if any.

    A raised HTTPException is rendered by FastAPI's exception handler into a
    new response, so cookies set on a dependency's ``Response`` are lost in
    that case. This captures the header so it can travel on the exception.
    Only the first Set-Cookie header is kept because HTTPException headers
    are a plain mapping.
    """
    scratch = Response()
    clear_session_cookie(scratch)
    set_cookie = scratch.headers.get("set-cookie")
    return {"set-cookie": set_cookie} if set_cookie else None


def _create_user(
    session: SessionData,
    response: Response,
    *,
    use_delegated: bool,
) -> SweatStackUser:
    """Build a SweatStackUser from ``session``, refreshing tokens if needed.

    Args:
        session: Decoded session cookie contents.
        response: Dependency-scoped response; an updated (refreshed) or
            cleared session cookie is written to it.
        use_delegated: Act as the delegated user when the session holds
            delegated tokens; otherwise, or when none are stored, act as the
            principal.

    Returns:
        A SweatStackUser whose client uses the selected access token.

    Raises:
        HTTPException: 401 "Session expired" when refreshing the tokens fails.
            The exception carries a Set-Cookie header that clears the session
            cookie, so the stale session is not retried on every request.

    NOTE(review): FastAPI copies headers from a dependency's ``Response``
    parameter onto the final response only when the endpoint does not return
    a ``Response`` object itself; endpoints returning e.g. an HTMLResponse
    may therefore not persist refreshed tokens — confirm for your routes.
    """
    config = get_config()

    # Select which tokens to act with.
    is_delegated = bool(use_delegated and session.delegated)
    tokens = session.delegated if is_delegated else session.principal

    try:
        refreshed = _refresh_tokens_if_needed(tokens)
    except Exception:
        logger.exception("Token refresh failed for user %s", tokens.user_id)
        # Effective when the exception is caught and the request continues
        # (the optional dependencies do this).
        clear_session_cookie(response)
        # Effective when the exception propagates: the error response is
        # built from the exception, not from ``response``.
        raise HTTPException(
            status_code=401,
            detail="Session expired",
            headers=_session_clearing_headers(),
        )

    if refreshed is not None:
        # Swap the refreshed pair into the session and persist it right away.
        if is_delegated:
            session = SessionData(principal=session.principal, delegated=refreshed)
        else:
            session = SessionData(principal=refreshed, delegated=session.delegated)
        tokens = refreshed
        set_session_cookie(response, session.to_dict())

    return SweatStackUser(
        client=Client(
            api_key=tokens.access_token,
            refresh_token=tokens.refresh_token,
            client_id=config.client_id,
            client_secret=config.client_secret,
        )
    )
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# ---------------------------------------------------------------------------
|
|
199
|
+
# Dependency functions
|
|
200
|
+
# ---------------------------------------------------------------------------
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _require_authenticated_user(
    request: Request,
    response: Response,
) -> SweatStackUser:
    """FastAPI dependency resolving to the logged-in (principal) user.

    Any delegated user selection in the session is ignored. Requests without
    a valid session are rejected by _get_session_or_raise.
    """
    return _create_user(
        _get_session_or_raise(request),
        response,
        use_delegated=False,
    )
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _require_selected_user(
    request: Request,
    response: Response,
) -> SweatStackUser:
    """FastAPI dependency resolving to the currently selected user.

    That is the delegated user when one is stored in the session, else the
    principal. Requests without a valid session are rejected by
    _get_session_or_raise.
    """
    return _create_user(
        _get_session_or_raise(request),
        response,
        use_delegated=True,
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _optional_authenticated_user(
    request: Request,
    response: Response,
) -> SweatStackUser | None:
    """FastAPI dependency resolving to the principal user, or None.

    None is returned when there is no valid session or when building the
    user raises an HTTPException (e.g. a failed token refresh).
    """
    session = _get_session_or_none(request)
    if session is None:
        return None
    try:
        user = _create_user(session, response, use_delegated=False)
    except HTTPException:
        user = None
    return user
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _optional_selected_user(
    request: Request,
    response: Response,
) -> SweatStackUser | None:
    """FastAPI dependency resolving to the selected user, or None.

    The selected user is the delegated one when present, else the principal.
    None is returned when there is no valid session or when building the
    user raises an HTTPException (e.g. a failed token refresh).
    """
    session = _get_session_or_none(request)
    if session is None:
        return None
    try:
        user = _create_user(session, response, use_delegated=True)
    except HTTPException:
        user = None
    return user
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
# ---------------------------------------------------------------------------
|
|
250
|
+
# Public type aliases
|
|
251
|
+
# ---------------------------------------------------------------------------
|
|
252
|
+
|
|
253
|
+
AuthenticatedUser = Annotated[
    SweatStackUser,
    Depends(dependency=_require_authenticated_user),
]
"""Dependency resolving to the principal (logged-in) user.

Any user selected via delegation is ignored. Requests without a valid session
get a 401, or a redirect to the login route when ``redirect_unauthenticated``
is configured.

Example:
    @app.get("/my-athletes")
    def get_athletes(user: AuthenticatedUser):
        return user.client.get_users()
"""

SelectedUser = Annotated[
    SweatStackUser,
    Depends(dependency=_require_selected_user),
]
"""Dependency resolving to the currently selected user.

This is the delegated user when one has been selected, otherwise the
principal user.

Example:
    @app.get("/activities")
    def get_activities(user: SelectedUser):
        return user.client.get_activities()
"""

OptionalUser = Annotated[
    SweatStackUser | None,
    Depends(dependency=_optional_authenticated_user),
]
"""Dependency resolving to the principal user, or None when not authenticated.

Example:
    @app.get("/")
    def home(user: OptionalUser):
        if user:
            return {"logged_in": True, "user_id": user.user_id}
        return {"logged_in": False}
"""

OptionalSelectedUser = Annotated[
    SweatStackUser | None,
    Depends(dependency=_optional_selected_user),
]
"""Dependency resolving to the selected user, or None when not authenticated.

Example:
    @app.get("/public-profile")
    def profile(user: OptionalSelectedUser):
        if user:
            return user.client.get_user()
        return {"message": "Not logged in"}
"""
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Data models for FastAPI session management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import SecretStr
|
|
9
|
+
|
|
10
|
+
from ..utils import decode_jwt_body
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True, slots=True)
|
|
14
|
+
class TokenSet:
|
|
15
|
+
"""Immutable token pair with user ID.
|
|
16
|
+
|
|
17
|
+
This represents either principal or delegated tokens stored in the session.
|
|
18
|
+
The frozen=True ensures tokens can't be accidentally modified.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
access_token: str
|
|
22
|
+
refresh_token: str
|
|
23
|
+
user_id: str
|
|
24
|
+
|
|
25
|
+
def to_dict(self) -> dict[str, str]:
|
|
26
|
+
"""Serialize to dictionary for session storage."""
|
|
27
|
+
return {
|
|
28
|
+
"access_token": self.access_token,
|
|
29
|
+
"refresh_token": self.refresh_token,
|
|
30
|
+
"user_id": self.user_id,
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
@classmethod
|
|
34
|
+
def from_dict(cls, data: dict[str, Any]) -> TokenSet:
|
|
35
|
+
"""Deserialize from dictionary."""
|
|
36
|
+
return cls(
|
|
37
|
+
access_token=data["access_token"],
|
|
38
|
+
refresh_token=data["refresh_token"],
|
|
39
|
+
user_id=data["user_id"],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(slots=True)
class SessionData:
    """Typed view of the decrypted session cookie.

    ``principal`` holds the tokens of the user who logged in; ``delegated``
    optionally holds tokens for another user selected through delegation.
    ``from_dict`` reads both the current nested layout and the legacy flat
    layout (top-level access_token/refresh_token/user_id).
    """

    principal: TokenSet
    delegated: TokenSet | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize for cookie storage; ``delegated`` is omitted when unset."""
        result: dict[str, Any] = {"principal": self.principal.to_dict()}
        if self.delegated:
            result.update(delegated=self.delegated.to_dict())
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SessionData:
        """Deserialize either the nested or the legacy flat session layout."""
        if "principal" not in data:
            # Legacy cookie: a single flat token set, promoted to principal
            # with no delegation.
            return cls(principal=TokenSet.from_dict(data), delegated=None)

        principal = TokenSet.from_dict(data["principal"])
        raw_delegated = data.get("delegated")
        delegated = TokenSet.from_dict(raw_delegated) if raw_delegated else None
        return cls(principal=principal, delegated=delegated)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def extract_user_id(jwt_token: str | SecretStr) -> str:
    """Extract the user ID (``sub`` claim) from a JWT.

    The signature is not validated - the token was already validated by the
    API when it was issued.

    Args:
        jwt_token: The JWT access token (str or SecretStr).

    Returns:
        The value of the token's ``sub`` claim.

    Raises:
        ValueError: If the token cannot be decoded, its payload is not a JSON
            object, or the ``sub`` claim is missing or empty.
    """
    token_str = jwt_token.get_secret_value() if isinstance(jwt_token, SecretStr) else jwt_token
    try:
        payload = decode_jwt_body(token_str)
        # .get raises AttributeError when the payload is not a JSON object.
        user_id = payload.get("sub")
    except (IndexError, KeyError, TypeError, AttributeError, ValueError) as e:
        # ValueError covers base64 (binascii.Error), JSON and UTF-8 decoding
        # failures; all are normalized to the documented error.
        raise ValueError(f"Malformed JWT token: {e}") from e
    if not user_id:
        raise ValueError("Token missing 'sub' claim")
    return user_id
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
"""OAuth routes for the FastAPI plugin."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import secrets
|
|
9
|
+
from urllib.parse import urlencode, urlparse
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
|
|
13
|
+
from fastapi.responses import RedirectResponse
|
|
14
|
+
|
|
15
|
+
from ..constants import DEFAULT_URL
|
|
16
|
+
from ..utils import decode_jwt_body
|
|
17
|
+
from .config import get_config
|
|
18
|
+
from .models import SessionData, TokenSet
|
|
19
|
+
from .session import (
|
|
20
|
+
SESSION_COOKIE_NAME,
|
|
21
|
+
STATE_COOKIE_NAME,
|
|
22
|
+
clear_session_cookie,
|
|
23
|
+
clear_state_cookie,
|
|
24
|
+
decrypt_session,
|
|
25
|
+
set_session_cookie,
|
|
26
|
+
set_state_cookie,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def validate_redirect(url: str | None) -> str | None:
|
|
33
|
+
"""Validate that a redirect URL is a safe relative path.
|
|
34
|
+
|
|
35
|
+
Returns the URL if valid, None otherwise.
|
|
36
|
+
"""
|
|
37
|
+
if url and url.startswith("/") and not url.startswith("//"):
|
|
38
|
+
return url
|
|
39
|
+
return None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _is_same_origin(referer: str | None, app_url: str) -> bool:
|
|
43
|
+
"""Check if a referer URL is from the same origin as the app."""
|
|
44
|
+
if not referer:
|
|
45
|
+
return False
|
|
46
|
+
try:
|
|
47
|
+
ref_parsed = urlparse(referer)
|
|
48
|
+
app_parsed = urlparse(app_url)
|
|
49
|
+
return (
|
|
50
|
+
ref_parsed.scheme == app_parsed.scheme
|
|
51
|
+
and ref_parsed.netloc == app_parsed.netloc
|
|
52
|
+
)
|
|
53
|
+
except Exception:
|
|
54
|
+
return False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _get_redirect_url(request: Request, next_param: str | None) -> str:
    """Choose where to send the browser after changing the selected user.

    Candidates, in priority order: the ``next`` query parameter, the path and
    query of a same-origin Referer header, then "/". A candidate is used only
    if validate_redirect accepts it.
    """
    explicit = validate_redirect(next_param)
    if explicit:
        return explicit

    config = get_config()
    referer = request.headers.get("referer")
    if _is_same_origin(referer, config.app_url):
        # Keep only path + query so the redirect stays relative.
        parsed = urlparse(referer)
        candidate = f"{parsed.path}?{parsed.query}" if parsed.query else parsed.path
        from_referer = validate_redirect(candidate)
        if from_referer:
            return from_referer

    return "/"
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _get_session_data(request: Request) -> SessionData | None:
    """Decode the session cookie, returning None if absent or malformed."""
    payload = decrypt_session(request.cookies.get(SESSION_COOKIE_NAME))
    if not payload:
        return None
    try:
        session = SessionData.from_dict(payload)
    except (KeyError, TypeError):
        session = None
    return session
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _fetch_delegated_token(principal_tokens: TokenSet, target_user_id: str) -> TokenSet:
    """Obtain tokens for ``target_user_id`` using the principal's access token.

    Raises:
        HTTPException: 403 when access to the user is denied, 404 when the
            user does not exist.
        httpx.HTTPStatusError: For any other error status from the API.
    """
    # Called for its check only: get_config() raises when not configured.
    get_config()

    known_errors = {
        403: "You don't have access to this user",
        404: "User not found",
    }

    reply = httpx.post(
        f"{DEFAULT_URL}/api/v1/oauth/delegated-token",
        headers={"Authorization": f"Bearer {principal_tokens.access_token}"},
        json={"sub": target_user_id},
    )
    if reply.status_code in known_errors:
        raise HTTPException(
            status_code=reply.status_code,
            detail=known_errors[reply.status_code],
        )
    reply.raise_for_status()

    body = reply.json()
    return TokenSet(
        access_token=body["access_token"],
        refresh_token=body["refresh_token"],
        user_id=target_user_id,
    )
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def create_state(next_url: str | None) -> str:
|
|
119
|
+
"""Create an OAuth state value with nonce and optional redirect."""
|
|
120
|
+
nonce = secrets.token_urlsafe(32)
|
|
121
|
+
state_data = {"nonce": nonce}
|
|
122
|
+
if next_url:
|
|
123
|
+
state_data["next"] = next_url
|
|
124
|
+
return base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def parse_state(state: str) -> dict:
    """Decode a value produced by create_state.

    Returns:
        The decoded JSON object, or an empty dict when ``state`` cannot be
        decoded or does not contain a JSON object. Callers can therefore
        always use ``.get`` on the result.
    """
    try:
        decoded = json.loads(base64.urlsafe_b64decode(state.encode()))
    except Exception:
        # Deliberately best-effort: the state round-trips through the browser.
        return {}
    return decoded if isinstance(decoded, dict) else {}
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def create_router() -> APIRouter:
    """Create the auth router.

    Routes, relative to the prefix the router is mounted under:
        GET  /login                  Start the OAuth flow.
        GET  /callback               Finish the OAuth flow and create the session.
        POST /logout                 Clear the session.
        POST /select-user/{user_id}  Act as another user (delegation).
        POST /select-self            Stop acting as another user.
    """
    router = APIRouter()

    def _invalid_callback(message: str) -> Response:
        """400 response that also discards the single-use state cookie."""
        rejected = Response(content=message, status_code=400)
        clear_state_cookie(rejected)
        return rejected

    @router.get("/login")
    def login(request: Request, next: str | None = None) -> Response:
        """Redirect to SweatStack OAuth authorization.

        ``next`` is carried through the OAuth state only if it is a safe
        relative path.
        """
        config = get_config()

        state = create_state(validate_redirect(next))

        params = {
            "client_id": config.client_id,
            "redirect_uri": config.redirect_uri,
            "scope": " ".join(config.scopes),
            "state": state,
            "prompt": "none",
        }
        auth_url = f"{DEFAULT_URL}/oauth/authorize?{urlencode(params)}"

        response = RedirectResponse(url=auth_url, status_code=302)
        set_state_cookie(response, state)
        return response

    @router.get("/callback")
    def callback(
        request: Request,
        code: str | None = None,
        state: str | None = None,
        error: str | None = None,
    ) -> Response:
        """Handle the OAuth callback from SweatStack.

        Verifies ``state`` against the state cookie, exchanges ``code`` for
        tokens and stores them in the session cookie. The state cookie is
        cleared on every outcome. Provider errors and token-exchange failures
        redirect to "/" without creating a session; an invalid state or a
        missing code yields a 400.
        """
        config = get_config()
        state_cookie = request.cookies.get(STATE_COOKIE_NAME)

        # Fallback redirect for failures that should not show an error page.
        failure = RedirectResponse(url="/", status_code=302)
        clear_state_cookie(failure)

        if error:
            logger.info("OAuth authorization returned error: %r", error)
            return failure

        # CSRF protection: the echoed state must match the cookie we set.
        if not state or not state_cookie or state != state_cookie:
            return _invalid_callback("Invalid state")

        if not code:
            return _invalid_callback("Missing authorization code")

        try:
            token_response = httpx.post(
                f"{DEFAULT_URL}/api/v1/oauth/token",
                data={
                    "grant_type": "authorization_code",
                    "client_id": config.client_id,
                    "client_secret": config.client_secret.get_secret_value(),
                    "code": code,
                    "redirect_uri": config.redirect_uri,
                },
            )
            token_response.raise_for_status()
            tokens = token_response.json()
        except Exception:
            logger.warning("OAuth token exchange failed", exc_info=True)
            return failure

        if not isinstance(tokens, dict):
            logger.warning("OAuth token endpoint returned a non-object body")
            return failure

        access_token = tokens.get("access_token")
        # May be None if the token endpoint omitted it.
        refresh_token = tokens.get("refresh_token")
        if not access_token:
            return failure

        # The user ID is the JWT "sub" claim.
        try:
            user_id = decode_jwt_body(access_token).get("sub")
        except Exception:
            return failure
        if not user_id:
            return failure

        # Store the session in the same nested layout the rest of the plugin
        # writes (SessionData.to_dict), not the legacy flat layout.
        session = SessionData(
            principal=TokenSet(
                access_token=access_token,
                refresh_token=refresh_token,
                user_id=user_id,
            ),
        )

        # The state round-trips through the browser, so re-validate the
        # redirect target instead of trusting what login() put there.
        state_data = parse_state(state)
        next_url = state_data.get("next") if isinstance(state_data, dict) else None
        redirect_url = (validate_redirect(next_url) if isinstance(next_url, str) else None) or "/"

        response = RedirectResponse(url=redirect_url, status_code=302)
        clear_state_cookie(response)
        set_session_cookie(response, session.to_dict())
        return response

    @router.post("/logout")
    def logout() -> Response:
        """Clear the session and redirect to /."""
        response = RedirectResponse(url="/", status_code=302)
        clear_session_cookie(response)
        return response

    @router.post("/select-user/{user_id}")
    def select_user(request: Request, user_id: str, next: str | None = None) -> Response:
        """Switch to viewing as another user.

        Fetches a delegated token for the target user and stores it in the
        session. Redirects to the ?next= parameter if safe, else the Referer
        (if same-origin), else /.
        """
        session = _get_session_data(request)
        if not session:
            raise HTTPException(status_code=401, detail="Not authenticated")

        try:
            delegated_tokens = _fetch_delegated_token(session.principal, user_id)
        except httpx.HTTPStatusError as e:
            logger.warning("Failed to fetch delegated token for user %s: %s", user_id, e)
            raise HTTPException(
                status_code=403, detail="You don't have access to this user"
            ) from e
        except httpx.RequestError as e:
            # Network-level failure: report an upstream error instead of an
            # unhandled 500.
            logger.warning("Could not reach SweatStack for user %s: %s", user_id, e)
            raise HTTPException(status_code=502, detail="Could not reach SweatStack") from e

        updated_session = SessionData(
            principal=session.principal,
            delegated=delegated_tokens,
        )

        redirect_url = _get_redirect_url(request, next)
        response = RedirectResponse(url=redirect_url, status_code=303)
        set_session_cookie(response, updated_session.to_dict())
        return response

    @router.post("/select-self")
    def select_self(request: Request, next: str | None = None) -> Response:
        """Switch back to viewing as yourself (clear delegation).

        Removes the delegated tokens from the session. Redirects to the
        ?next= parameter if safe, else the Referer (if same-origin), else /.
        """
        session = _get_session_data(request)
        if not session:
            raise HTTPException(status_code=401, detail="Not authenticated")

        updated_session = SessionData(
            principal=session.principal,
            delegated=None,
        )

        redirect_url = _get_redirect_url(request, next)
        response = RedirectResponse(url=redirect_url, status_code=303)
        set_session_cookie(response, updated_session.to_dict())
        return response

    return router
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def instrument(app: FastAPI) -> None:
    """Mount the SweatStack auth routes on a FastAPI application.

    The routes are registered under the configured ``auth_route_prefix``.

    Args:
        app: The FastAPI application to instrument.

    Raises:
        RuntimeError: If configure() has not been called.
    """
    prefix = get_config().auth_route_prefix  # get_config() raises if unconfigured
    app.include_router(create_router(), prefix=prefix)
|