sweatstack 0.60.0__py3-none-any.whl → 0.62.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,175 @@
1
+ """Data models for FastAPI session management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from datetime import datetime
7
+ from typing import Any, Protocol, runtime_checkable
8
+
9
+ from pydantic import SecretStr
10
+
11
+ from ..utils import decode_jwt_body
12
+
13
+
14
+ @dataclass(frozen=True, slots=True)
15
+ class TokenSet:
16
+ """Immutable token pair with user ID.
17
+
18
+ This represents either principal or delegated tokens stored in the session.
19
+ The frozen=True ensures tokens can't be accidentally modified.
20
+ """
21
+
22
+ access_token: str
23
+ refresh_token: str
24
+ user_id: str
25
+
26
+ def to_dict(self) -> dict[str, str]:
27
+ """Serialize to dictionary for session storage."""
28
+ return {
29
+ "access_token": self.access_token,
30
+ "refresh_token": self.refresh_token,
31
+ "user_id": self.user_id,
32
+ }
33
+
34
+ @classmethod
35
+ def from_dict(cls, data: dict[str, Any]) -> TokenSet:
36
+ """Deserialize from dictionary."""
37
+ return cls(
38
+ access_token=data["access_token"],
39
+ refresh_token=data["refresh_token"],
40
+ user_id=data["user_id"],
41
+ )
42
+
43
+
44
+ @dataclass(slots=True)
45
+ class SessionData:
46
+ """Type-safe wrapper for session data.
47
+
48
+ Handles both the new format (with principal/delegated) and legacy format
49
+ (flat access_token/refresh_token/user_id) for backwards compatibility.
50
+ """
51
+
52
+ principal: TokenSet
53
+ delegated: TokenSet | None = None
54
+
55
+ def to_dict(self) -> dict[str, Any]:
56
+ """Serialize to dictionary for cookie storage."""
57
+ data: dict[str, Any] = {"principal": self.principal.to_dict()}
58
+ if self.delegated:
59
+ data["delegated"] = self.delegated.to_dict()
60
+ return data
61
+
62
+ @classmethod
63
+ def from_dict(cls, data: dict[str, Any]) -> SessionData:
64
+ """Deserialize from dictionary.
65
+
66
+ Handles both new format and legacy format for backwards compatibility.
67
+ """
68
+ # New format: has "principal" key
69
+ if "principal" in data:
70
+ return cls(
71
+ principal=TokenSet.from_dict(data["principal"]),
72
+ delegated=TokenSet.from_dict(data["delegated"]) if data.get("delegated") else None,
73
+ )
74
+
75
+ # Legacy format: flat structure with access_token, refresh_token, user_id
76
+ # Migrate to new format by treating as principal
77
+ return cls(
78
+ principal=TokenSet(
79
+ access_token=data["access_token"],
80
+ refresh_token=data["refresh_token"],
81
+ user_id=data["user_id"],
82
+ ),
83
+ delegated=None,
84
+ )
85
+
86
+
87
def extract_user_id(jwt_token: str | SecretStr) -> str:
    """Extract the user ID (``sub`` claim) from a JWT access token.

    The signature is deliberately not verified: the token was issued, and
    therefore already validated, by the API.

    Args:
        jwt_token: The JWT access token, as a plain ``str`` or a ``SecretStr``.

    Returns:
        The value of the token's ``sub`` claim.

    Raises:
        ValueError: If decoding fails with ``IndexError``/``KeyError``, if the
            decoded payload has no ``get`` method, or if the ``sub`` claim is
            missing, empty, or not a string. NOTE(review): other exceptions
            raised by ``decode_jwt_body`` propagate unchanged - confirm they
            are ``ValueError`` subclasses.
    """
    token_str = jwt_token.get_secret_value() if isinstance(jwt_token, SecretStr) else jwt_token

    try:
        payload = decode_jwt_body(token_str)
    except (IndexError, KeyError) as e:
        raise ValueError(f"Malformed JWT token: {e}") from e

    try:
        user_id = payload.get("sub")
    except AttributeError as e:
        # A payload that is not a mapping (e.g. a list) carries no claims.
        raise ValueError("Malformed JWT token: payload is not an object") from e

    if not user_id:
        raise ValueError("Token missing 'sub' claim")
    if not isinstance(user_id, str):
        # The function promises a str; reject numeric/structured 'sub' values.
        raise ValueError("Malformed JWT token: 'sub' claim is not a string")
    return user_id
111
+
112
+
113
+ # ---------------------------------------------------------------------------
114
+ # Token storage for webhook support
115
+ # ---------------------------------------------------------------------------
116
+
117
+
118
+ @dataclass(frozen=True, slots=True)
119
+ class StoredTokens:
120
+ """Token data for persistent storage.
121
+
122
+ Used by TokenStore implementations to persist tokens for webhook handling.
123
+ The library does NOT encrypt tokens - see Security section in documentation.
124
+
125
+ Attributes:
126
+ user_id: The SweatStack user ID.
127
+ access_token: The OAuth access token.
128
+ refresh_token: The OAuth refresh token.
129
+ expires_at: When the access token expires.
130
+ """
131
+
132
+ user_id: str
133
+ access_token: str
134
+ refresh_token: str
135
+ expires_at: datetime
136
+
137
+ def __repr__(self) -> str:
138
+ """Hide sensitive data in logs."""
139
+ return (
140
+ f"StoredTokens(user_id={self.user_id!r}, "
141
+ f"access_token='***', refresh_token='***', "
142
+ f"expires_at={self.expires_at!r})"
143
+ )
144
+
145
+
146
@runtime_checkable
class TokenStore(Protocol):
    """Structural interface for persisting OAuth tokens between requests.

    Provide an object with these three methods to enable AuthenticatedUser in
    webhook handlers. The library invokes them itself:

    - ``save()`` after the OAuth callback and after a token refresh
    - ``load()`` while handling webhooks
    - ``delete()`` on logout

    Thread Safety:
        Calls may arrive concurrently; implementations must be thread-safe.

    Error Handling:
        - ``save()``: upsert - never raise because ``user_id`` already exists.
        - ``load()``: return ``None`` for an unknown user instead of raising.
        - ``delete()``: idempotent - never raise for an unknown user.
    """

    def save(self, tokens: StoredTokens) -> None:
        """Insert or overwrite the stored tokens for ``tokens.user_id``."""

    def load(self, user_id: str) -> StoredTokens | None:
        """Return the stored tokens for ``user_id``, or ``None`` if absent."""

    def delete(self, user_id: str) -> None:
        """Remove any stored tokens for ``user_id``; a no-op if none exist."""
@@ -1,25 +1,34 @@
1
1
  """OAuth routes for the FastAPI plugin."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import base64
4
6
  import json
7
+ import logging
5
8
  import secrets
6
- from urllib.parse import urlencode
9
+ from urllib.parse import urlencode, urlparse
7
10
 
8
11
  import httpx
9
- from fastapi import APIRouter, FastAPI, Request, Response
12
+ from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
10
13
  from fastapi.responses import RedirectResponse
11
14
 
12
15
  from ..constants import DEFAULT_URL
13
16
  from ..utils import decode_jwt_body
14
17
  from .config import get_config
18
+ from .dependencies import _extract_expiry
19
+ from .models import SessionData, StoredTokens, TokenSet
15
20
  from .session import (
21
+ SESSION_COOKIE_NAME,
16
22
  STATE_COOKIE_NAME,
17
23
  clear_session_cookie,
18
24
  clear_state_cookie,
25
+ decrypt_session,
19
26
  set_session_cookie,
20
27
  set_state_cookie,
21
28
  )
22
29
 
30
+ logger = logging.getLogger(__name__)
31
+
23
32
 
24
33
  def validate_redirect(url: str | None) -> str | None:
25
34
  """Validate that a redirect URL is a safe relative path.
@@ -31,6 +40,82 @@ def validate_redirect(url: str | None) -> str | None:
31
40
  return None
32
41
 
33
42
 
43
+ def _is_same_origin(referer: str | None, app_url: str) -> bool:
44
+ """Check if a referer URL is from the same origin as the app."""
45
+ if not referer:
46
+ return False
47
+ try:
48
+ ref_parsed = urlparse(referer)
49
+ app_parsed = urlparse(app_url)
50
+ return (
51
+ ref_parsed.scheme == app_parsed.scheme
52
+ and ref_parsed.netloc == app_parsed.netloc
53
+ )
54
+ except Exception:
55
+ return False
56
+
57
+
58
def _get_redirect_url(request: Request, next_param: str | None) -> str:
    """Choose where to send the browser after the viewed user changes.

    Candidates are tried in order; the first one accepted by
    ``validate_redirect`` wins:

    1. the explicit ``?next=`` query parameter;
    2. the path (plus query string) of the ``Referer`` header, used only when
       the referer shares the app's origin;
    3. ``/`` as the fallback.
    """
    explicit = validate_redirect(next_param)
    if explicit:
        return explicit

    app_url = get_config().app_url
    referer = request.headers.get("referer")
    if not _is_same_origin(referer, app_url):
        return "/"

    parts = urlparse(referer)
    # Reuse only the local part of the referer; scheme and host are dropped.
    local_path = f"{parts.path}?{parts.query}" if parts.query else parts.path
    return validate_redirect(local_path) or "/"
81
+
82
+
83
def _get_session_data(request: Request) -> SessionData | None:
    """Decode the request's session cookie into ``SessionData``.

    Returns:
        The parsed session, or None when ``decrypt_session`` yields nothing
        (e.g. no cookie) or the payload does not fit either the current or the
        legacy session layout.
    """
    cookie_value = request.cookies.get(SESSION_COOKIE_NAME)
    payload = decrypt_session(cookie_value)
    if payload:
        try:
            return SessionData.from_dict(payload)
        except (KeyError, TypeError):
            # A malformed payload is treated exactly like a missing session.
            return None
    return None
92
+
93
+
94
def _fetch_delegated_token(principal_tokens: TokenSet, target_user_id: str) -> TokenSet:
    """Exchange the principal's access token for tokens scoped to another user.

    Args:
        principal_tokens: Tokens of the logged-in (principal) user; its access
            token authorizes the request.
        target_user_id: User ID to obtain delegated tokens for.

    Returns:
        A ``TokenSet`` for ``target_user_id``.

    Raises:
        HTTPException: 403 if the API denies access to the target user, 404 if
            the API reports the user as not found, 502 if the API cannot be
            reached or returns a body without usable tokens.
        httpx.HTTPStatusError: For any other error status from the API.
    """
    try:
        response = httpx.post(
            f"{DEFAULT_URL}/api/v1/oauth/delegated-token",
            headers={"Authorization": f"Bearer {principal_tokens.access_token}"},
            json={"sub": target_user_id},
        )
    except httpx.RequestError as e:
        # Network/timeout failures would otherwise surface as an opaque 500.
        logger.error("Delegated token request failed: %s", e)
        raise HTTPException(status_code=502, detail="Could not reach SweatStack") from e

    if response.status_code == 403:
        raise HTTPException(status_code=403, detail="You don't have access to this user")
    if response.status_code == 404:
        raise HTTPException(status_code=404, detail="User not found")

    response.raise_for_status()

    try:
        tokens = response.json()
        access_token = tokens["access_token"]
        refresh_token = tokens["refresh_token"]
    except (ValueError, KeyError, TypeError) as e:
        # ValueError covers invalid JSON; KeyError/TypeError a wrong body shape.
        logger.error("Invalid delegated token response: %s", e)
        raise HTTPException(status_code=502, detail="Invalid response from SweatStack") from e

    return TokenSet(
        access_token=access_token,
        refresh_token=refresh_token,
        user_id=target_user_id,
    )
117
+
118
+
34
119
  def create_state(next_url: str | None) -> str:
35
120
  """Create an OAuth state value with nonce and optional redirect."""
36
121
  nonce = secrets.token_urlsafe(32)
@@ -82,28 +167,34 @@ def create_router() -> APIRouter:
82
167
  code: str | None = None,
83
168
  state: str | None = None,
84
169
  error: str | None = None,
170
+ error_description: str | None = None,
85
171
  ) -> Response:
86
172
  """Handle OAuth callback from SweatStack."""
87
173
  config = get_config()
88
174
 
175
+ def error_redirect(error_code: str) -> Response:
176
+ """Redirect to / with error code in query params."""
177
+ response = RedirectResponse(url=f"/?auth_error={error_code}", status_code=302)
178
+ clear_state_cookie(response)
179
+ return response
180
+
89
181
  # Get state cookie
90
182
  state_cookie = request.cookies.get(STATE_COOKIE_NAME)
91
183
 
92
- # Clear state cookie regardless of outcome
93
- response = RedirectResponse(url="/", status_code=302)
94
- clear_state_cookie(response)
95
-
96
- # Handle OAuth errors
184
+ # Handle OAuth errors from provider
97
185
  if error:
98
- return response
186
+ logger.warning("OAuth error from provider: %s - %s", error, error_description)
187
+ return error_redirect(error)
99
188
 
100
189
  # Verify state (CSRF protection)
101
190
  if not state or not state_cookie or state != state_cookie:
102
- return Response(content="Invalid state", status_code=400)
191
+ logger.warning("OAuth state mismatch (possible CSRF)")
192
+ return error_redirect("invalid_state")
103
193
 
104
194
  # Exchange code for tokens
105
195
  if not code:
106
- return Response(content="Missing authorization code", status_code=400)
196
+ logger.warning("OAuth callback missing authorization code")
197
+ return error_redirect("missing_code")
107
198
 
108
199
  try:
109
200
  token_response = httpx.post(
@@ -118,24 +209,31 @@ def create_router() -> APIRouter:
118
209
  )
119
210
  token_response.raise_for_status()
120
211
  tokens = token_response.json()
121
- except Exception:
122
- return response # Redirect to / on token exchange failure
212
+ except httpx.HTTPStatusError as e:
213
+ logger.error("Token exchange failed: %s - %s", e.response.status_code, e.response.text)
214
+ return error_redirect("token_exchange_failed")
215
+ except Exception as e:
216
+ logger.error("Token exchange error: %s", e)
217
+ return error_redirect("token_exchange_failed")
123
218
 
124
219
  access_token = tokens.get("access_token")
125
220
  refresh_token = tokens.get("refresh_token")
126
221
 
127
222
  if not access_token:
128
- return response
223
+ logger.error("Token response missing access_token")
224
+ return error_redirect("invalid_token_response")
129
225
 
130
226
  # Extract user_id from JWT
131
227
  try:
132
228
  token_body = decode_jwt_body(access_token)
133
229
  user_id = token_body.get("sub")
134
- except Exception:
135
- return response
230
+ except Exception as e:
231
+ logger.error("Failed to decode access token: %s", e)
232
+ return error_redirect("invalid_token")
136
233
 
137
234
  if not user_id:
138
- return response
235
+ logger.error("Access token missing 'sub' claim")
236
+ return error_redirect("invalid_token")
139
237
 
140
238
  # Create session
141
239
  session_data = {
@@ -144,6 +242,17 @@ def create_router() -> APIRouter:
144
242
  "user_id": user_id,
145
243
  }
146
244
 
245
+ # Persist tokens to store if configured
246
+ if config.token_store:
247
+ config.token_store.save(
248
+ StoredTokens(
249
+ user_id=user_id,
250
+ access_token=access_token,
251
+ refresh_token=refresh_token,
252
+ expires_at=_extract_expiry(access_token),
253
+ )
254
+ )
255
+
147
256
  # Determine redirect URL from state
148
257
  state_data = parse_state(state)
149
258
  redirect_url = state_data.get("next", "/")
@@ -154,15 +263,107 @@ def create_router() -> APIRouter:
154
263
  return response
155
264
 
156
265
  @router.post("/logout")
157
- def logout() -> Response:
266
+ def logout(request: Request) -> Response:
158
267
  """Clear session and redirect to /."""
268
+ config = get_config()
269
+
270
+ # Delete tokens from store if configured
271
+ if config.token_store:
272
+ session = _get_session_data(request)
273
+ if session:
274
+ config.token_store.delete(session.principal.user_id)
275
+
159
276
  response = RedirectResponse(url="/", status_code=302)
160
277
  clear_session_cookie(response)
161
278
  return response
162
279
 
280
+ @router.post("/select-user/{user_id}")
281
+ def select_user(request: Request, user_id: str, next: str | None = None) -> Response:
282
+ """Switch to viewing as another user.
283
+
284
+ Fetches a delegated token for the target user and stores it in the session.
285
+ Redirects to Referer (if same-origin), ?next= parameter, or /.
286
+ """
287
+ session = _get_session_data(request)
288
+ if not session:
289
+ raise HTTPException(status_code=401, detail="Not authenticated")
290
+
291
+ # Fetch delegated token for the target user
292
+ try:
293
+ delegated_tokens = _fetch_delegated_token(session.principal, user_id)
294
+ except httpx.HTTPStatusError as e:
295
+ logger.warning("Failed to fetch delegated token for user %s: %s", user_id, e)
296
+ raise HTTPException(status_code=403, detail="You don't have access to this user")
297
+
298
+ # Update session with delegated tokens
299
+ updated_session = SessionData(
300
+ principal=session.principal,
301
+ delegated=delegated_tokens,
302
+ )
303
+
304
+ redirect_url = _get_redirect_url(request, next)
305
+ response = RedirectResponse(url=redirect_url, status_code=303)
306
+ set_session_cookie(response, updated_session.to_dict())
307
+ return response
308
+
309
+ @router.post("/select-self")
310
+ def select_self(request: Request, next: str | None = None) -> Response:
311
+ """Switch back to viewing as yourself (clear delegation).
312
+
313
+ Removes the delegated tokens from the session.
314
+ Redirects to Referer (if same-origin), ?next= parameter, or /.
315
+ """
316
+ session = _get_session_data(request)
317
+ if not session:
318
+ raise HTTPException(status_code=401, detail="Not authenticated")
319
+
320
+ # Clear delegation
321
+ updated_session = SessionData(
322
+ principal=session.principal,
323
+ delegated=None,
324
+ )
325
+
326
+ redirect_url = _get_redirect_url(request, next)
327
+ response = RedirectResponse(url=redirect_url, status_code=303)
328
+ set_session_cookie(response, updated_session.to_dict())
329
+ return response
330
+
163
331
  return router
164
332
 
165
333
 
334
def _warn_if_webhook_misconfigured(app: FastAPI) -> None:
    """Fail fast if a route needs WebhookPayload but no webhook secret is set.

    Despite its name, this does not just log: it raises, so the
    misconfiguration surfaces at startup instead of on the first webhook.

    Args:
        app: Application whose registered routes are inspected.

    Raises:
        RuntimeError: If ``webhook_secret`` is not configured and any route's
            dependency tree includes ``_require_webhook_payload``.
    """
    config = get_config()

    if config.webhook_secret:
        return  # Properly configured

    # Import here to avoid circular imports
    from .webhooks import _require_webhook_payload

    for route in app.routes:
        # Skip routes that expose no dependency tree to inspect.
        dependant = getattr(route, "dependant", None)
        if dependant is None:
            continue

        if _uses_dependency(dependant, _require_webhook_payload):
            raise RuntimeError(
                f"Route '{route.path}' uses WebhookPayload but webhook_secret is not configured. "
                "Webhook signature verification will fail at runtime. "
                "Configure with the SWEATSTACK_WEBHOOK_SECRET env variable or configure(webhook_secret='whsec_...')"
            )
355
+
356
+
357
+ def _uses_dependency(dependant, target_callable) -> bool:
358
+ """Check if a dependency tree includes the target callable."""
359
+ for dep in dependant.dependencies:
360
+ if dep.call is target_callable:
361
+ return True
362
+ if hasattr(dep, "dependant") and _uses_dependency(dep.dependant, target_callable):
363
+ return True
364
+ return False
365
+
366
+
166
367
  def instrument(app: FastAPI) -> None:
167
368
  """Add SweatStack auth routes to a FastAPI application.
168
369
 
@@ -175,3 +376,8 @@ def instrument(app: FastAPI) -> None:
175
376
  config = get_config() # This will raise if not configured
176
377
  router = create_router()
177
378
  app.include_router(router, prefix=config.auth_route_prefix)
379
+
380
+ # Validate webhook configuration at startup (after all routes are registered)
381
+ @app.on_event("startup")
382
+ def _check_webhook_config():
383
+ _warn_if_webhook_misconfigured(app)