caspian-utils 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- casp/__init__.py +0 -0
- casp/auth.py +537 -0
- casp/cache_handler.py +180 -0
- casp/caspian_config.py +441 -0
- casp/component_decorator.py +183 -0
- casp/components_compiler.py +293 -0
- casp/html_attrs.py +93 -0
- casp/layout.py +474 -0
- casp/loading.py +25 -0
- casp/rpc.py +230 -0
- casp/scripts_type.py +21 -0
- casp/state_manager.py +134 -0
- casp/string_helpers.py +18 -0
- casp/tw.py +31 -0
- casp/validate.py +747 -0
- caspian_utils-0.0.12.dist-info/METADATA +214 -0
- caspian_utils-0.0.12.dist-info/RECORD +19 -0
- caspian_utils-0.0.12.dist-info/WHEEL +5 -0
- caspian_utils-0.0.12.dist-info/top_level.txt +1 -0
casp/__init__.py
ADDED
|
File without changes
|
casp/auth.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
import asyncio
import os
import re
import secrets
from contextvars import ContextVar
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from functools import wraps
from typing import Dict, Callable, Any, Optional, Union, List
from urllib.parse import quote, urlencode, urlsplit

import httpx
from fastapi import Request, Response
from fastapi.responses import RedirectResponse
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Holds the request bound to the current context. Because this is a ContextVar,
# each asyncio task / context sees its own value. Written by Auth.set_request()
# and read by Auth.get_request(); None until a request has been bound.
_request_ctx: ContextVar[Optional[Request]] = ContextVar('request', default=None)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class AuthSettings:
    """
    Application auth behaviour, installed once via configure_auth() in main.py.

    Behavioural options are ordinary fields with literal defaults. The
    ``secret_key`` and ``cookie_name`` defaults come from the AUTH_SECRET and
    AUTH_COOKIE_NAME environment variables, read when an instance is created.
    """

    # -- tokens: lifetime string such as "30m", "1h" or "7d" --
    default_token_validity: str = "1h"
    # NOTE(review): not read anywhere in this module.
    token_auto_refresh: bool = False

    # -- roles: key in the user payload that holds the role --
    role_identifier: str = "role"
    # NOTE(review): not read anywhere in this module.
    is_role_based: bool = False

    # -- route protection (Auth compares paths by exact match) --
    is_all_routes_private: bool = True
    private_routes: List[str] = field(default_factory=list)
    public_routes: List[str] = field(default_factory=lambda: ["/"])
    auth_routes: List[str] = field(default_factory=lambda: ["/signin", "/signup"])
    # path -> roles allowed on it, e.g. {"/admin": ["admin", "superadmin"]}
    role_based_routes: Dict[str, List[str]] = field(default_factory=dict)

    # -- redirect targets --
    default_signin_redirect: str = "/dashboard"
    default_signout_redirect: str = "/signin"
    api_auth_prefix: str = "/api/auth"

    # -- optional hooks for custom logic --
    on_sign_in: Optional[Callable[[dict], None]] = None
    on_sign_out: Optional[Callable[[], None]] = None
    on_auth_failure: Optional[Callable[[Request], Response]] = None

    # -- environment-backed values --
    secret_key: str = field(
        default_factory=lambda: os.environ.get(
            "AUTH_SECRET", "default_secret_key_change_me"
        )
    )
    cookie_name: str = field(
        default_factory=lambda: os.environ.get("AUTH_COOKIE_NAME", "auth_cookie")
    )

    def __post_init__(self):
        # Normalise the cookie name: trim, turn whitespace runs into "_", lower-case.
        trimmed = self.cookie_name.strip()
        underscored = re.sub(r"\s+", "_", trimmed)
        self.cookie_name = underscored.lower()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Module-level settings object. It is built from AuthSettings defaults at import
# time and rebound by configure_auth(); Auth.__init__ and get_auth_settings()
# read it.
_settings: AuthSettings = AuthSettings()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def configure_auth(settings: AuthSettings) -> None:
    """
    Configure auth at app startup. Call this before app starts.

    Installs *settings* as the module-level settings and re-points the
    existing ``Auth`` singleton at them. The singleton object itself is kept,
    so the module-level ``auth`` object, and anything that already imported
    it, sees the new configuration.

    Example:
        configure_auth(AuthSettings(
            default_signin_redirect="/app",
            private_routes=["/dashboard", "/settings"],
            role_based_routes={
                "/admin": ["admin", "superadmin"],
                "/reports": ["admin", "manager", "user"],
            },
            on_sign_in=lambda user: print(f"Welcome {user.get('name')}"),
        ))
    """
    global _settings
    _settings = settings

    # Update the singleton in place. Replacing it with a new Auth() would leave
    # every existing reference (e.g. the module-level ``auth`` used by the
    # decorators) bound to the old settings.
    instance = Auth.get_instance()
    instance._settings = settings
    Auth._cookie_name = settings.cookie_name
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_auth_settings() -> AuthSettings:
    """Return the AuthSettings object currently held at module level.

    This is the object most recently passed to configure_auth(), or the
    default instance created at import time.
    """
    current: AuthSettings = _settings
    return current
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class GoogleProvider:
    """Settings for Google OAuth sign-in.

    Any credential not passed explicitly falls back to GOOGLE_CLIENT_ID,
    GOOGLE_CLIENT_SECRET or GOOGLE_REDIRECT_URI from the environment, or to ""
    when that variable is absent.
    """

    client_id: str = field(
        default_factory=lambda: os.environ.get("GOOGLE_CLIENT_ID", "")
    )
    client_secret: str = field(
        default_factory=lambda: os.environ.get("GOOGLE_CLIENT_SECRET", "")
    )
    redirect_uri: str = field(
        default_factory=lambda: os.environ.get("GOOGLE_REDIRECT_URI", "")
    )
    # Token validity handed to Auth.sign_in() after a successful callback.
    max_age: str = "30d"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class GithubProvider:
    """Settings for GitHub OAuth sign-in.

    Credentials not passed explicitly fall back to GITHUB_CLIENT_ID /
    GITHUB_CLIENT_SECRET from the environment, or to "" when absent.
    """

    client_id: str = field(
        default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID", "")
    )
    client_secret: str = field(
        default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET", "")
    )
    # Token validity handed to Auth.sign_in() after a successful callback.
    max_age: str = "30d"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class Auth:
    """Session-backed authentication helper, used as a process-wide singleton.

    The signed-in user's data is stored in the request session (see
    ``_get_session``) under ``PAYLOAD_SESSION_KEY``; the current request is
    read from the module-level ``_request_ctx`` context variable.
    """

    # Keys used inside the session dict.
    PAYLOAD_NAME = "payload_name_8639D"
    PAYLOAD_SESSION_KEY = "payload_session_key_2183A"
    # Session key holding the one-time OAuth ``state`` issued at sign-in.
    OAUTH_STATE_SESSION_KEY = "oauth_state_5F2C1"

    _instance: Optional["Auth"] = None
    _cookie_name: str = ""
    _providers: List[Any] = []

    def __init__(self) -> None:
        # Capture the module-level settings current at construction time.
        self._settings = _settings
        Auth._cookie_name = self._settings.cookie_name

    @classmethod
    def get_instance(cls) -> "Auth":
        """Return the shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def set_request(cls, request: Request):
        """Bind *request* to the current context."""
        _request_ctx.set(request)

    @classmethod
    def get_request(cls) -> Optional[Request]:
        """Return the request bound to the current context, or None."""
        return _request_ctx.get()

    @classmethod
    def set_providers(cls, *providers) -> None:
        """Set OAuth providers for the auth instance."""
        cls._providers = list(providers)

    @classmethod
    def get_providers(cls) -> List[Any]:
        """Return the providers registered with set_providers()."""
        return cls._providers

    @property
    def settings(self) -> AuthSettings:
        """The AuthSettings object this instance reads."""
        return self._settings

    @property
    def cookie_name(self) -> str:
        """Normalized cookie name copied from the settings at construction."""
        return Auth._cookie_name

    def _get_session(self) -> dict:
        """Return the current request's session dict.

        Falls back to a fresh ``{}`` when no request is bound or the request
        has no usable session. Writes to that fallback are not persisted.
        """
        request = self.get_request()
        if not request:
            return {}
        try:
            return request.session
        except (AttributeError, AssertionError, KeyError):
            # NOTE: Starlette's Request.session raises AssertionError (not
            # AttributeError) when SessionMiddleware is not installed, so a
            # hasattr() check alone would let that error escape.
            return {}

    # ====
    # ROUTE CHECKING
    # ====
    def is_public_route(self, path: str) -> bool:
        """Check if path is a public route (exact match)."""
        return path in self._settings.public_routes

    def is_auth_route(self, path: str) -> bool:
        """Check if path is an auth route such as signin/signup (exact match)."""
        return path in self._settings.auth_routes

    def is_private_route(self, path: str) -> bool:
        """Check if path requires authentication.

        With ``is_all_routes_private`` every path except public and auth routes
        is private; otherwise only paths listed in ``private_routes`` are.
        """
        if self._settings.is_all_routes_private:
            return not (self.is_public_route(path) or self.is_auth_route(path))
        return path in self._settings.private_routes

    def get_required_roles(self, path: str) -> Optional[List[str]]:
        """Get required roles for a path (exact match), or None if unrestricted."""
        return self._settings.role_based_routes.get(path)

    # ====
    # CORE AUTH METHODS
    # ====
    def sign_in(
        self,
        data: Union[dict, str, Any],
        token_validity: Optional[str] = None,
        redirect_to: Union[bool, str] = False,
    ) -> Union[str, Response]:
        """Store *data* as the signed-in user's payload in the session.

        Args:
            data: User payload. Dicts are normalized (see _normalize_payload);
                other values are stored as-is.
            token_validity: Lifetime such as "30m", "1h" or "7d". Defaults to
                settings.default_token_validity.
            redirect_to: True redirects (303) to settings.default_signin_redirect;
                a non-empty string redirects (303) there. Otherwise "ok" is returned.

        Raises:
            ValueError: if the validity string is malformed.
        """
        validity = token_validity or self._settings.default_token_validity
        exp_time = self._calculate_expiration(validity)

        data = self._normalize_payload(data)

        payload = {
            self.PAYLOAD_NAME: data,
            "exp": exp_time,
            "iat": datetime.now(timezone.utc).timestamp(),
        }

        session = self._get_session()
        session[self.PAYLOAD_SESSION_KEY] = payload
        # Issue a fresh CSRF token for the newly authenticated session.
        session["csrf_token"] = secrets.token_hex(32)

        # Call hook if configured
        if self._settings.on_sign_in:
            self._settings.on_sign_in(
                data if isinstance(data, dict) else {"value": data}
            )

        if redirect_to is True:
            return RedirectResponse(
                url=self._settings.default_signin_redirect, status_code=303
            )
        if isinstance(redirect_to, str) and redirect_to:
            return RedirectResponse(url=redirect_to, status_code=303)

        return "ok"

    def sign_out(self, redirect_to: Optional[str] = None) -> Optional[Response]:
        """Clear the session, run on_sign_out, and redirect.

        Returns a 303 redirect to *redirect_to* or to
        settings.default_signout_redirect, or None if both are empty.
        """
        # clear() also drops the payload and CSRF token keys.
        self._get_session().clear()

        # Call hook if configured
        if self._settings.on_sign_out:
            self._settings.on_sign_out()

        target = redirect_to or self._settings.default_signout_redirect
        if target:
            return RedirectResponse(url=target, status_code=303)
        return None

    def is_authenticated(self) -> bool:
        """Return True if the session holds an unexpired payload with user data.

        Expired, unparseable or empty payloads are removed from the session.
        """
        session = self._get_session()
        payload = session.get(self.PAYLOAD_SESSION_KEY)
        if not isinstance(payload, dict):
            return False

        exp = payload.get("exp")
        if exp is not None:
            try:
                now_ts = datetime.now(timezone.utc).timestamp()
                # "not >=" rather than "<" so a NaN expiry also counts as expired.
                expired = not float(exp) >= now_ts
            except (TypeError, ValueError, OverflowError):
                expired = True  # unparseable expiry: treat the payload as invalid
            if expired:
                session.pop(self.PAYLOAD_SESSION_KEY, None)
                return False

        if payload.get(self.PAYLOAD_NAME) is None:
            session.pop(self.PAYLOAD_SESSION_KEY, None)
            return False

        return True

    def get_payload(self) -> Optional[Dict[str, Any]]:
        """Return the stored user data.

        Dict data is returned as-is, any other non-None value is wrapped as
        ``{"value": data}``, and None is returned when nothing is stored.
        Expiry is not checked here; see is_authenticated().
        """
        session = self._get_session()
        payload = session.get(self.PAYLOAD_SESSION_KEY)
        if not isinstance(payload, dict):
            return None

        data = payload.get(self.PAYLOAD_NAME)
        if isinstance(data, dict):
            return data
        if data is not None:
            return {"value": data}
        return None

    def check_role(self, user: Any, allowed_roles: List[str]) -> bool:
        """Check if user has one of the allowed roles.

        For dict users the role is read from ``settings.role_identifier``;
        any other truthy value is compared as ``str(user)``.
        """
        if isinstance(user, dict):
            user_role = user.get(self._settings.role_identifier, "")
        else:
            user_role = str(user) if user else ""
        return user_role in allowed_roles

    # ====
    # OAUTH PROVIDERS
    # ====
    def auth_providers(self, *providers) -> Optional[Response]:
        """Handle OAuth provider signin/callback routes.

        A GET whose path contains "signin" plus "github" or "google" redirects
        to that provider's authorize page and issues a one-time ``state``.
        A GET whose path contains "callback" plus a provider name and a
        ``code`` query parameter is completed only if its ``state`` matches
        the one issued. Returns None when nothing was handled or the
        callback failed.
        """
        request = self.get_request()
        if not request:
            return None

        path_parts = request.url.path.strip("/").split("/")

        # Handle signin redirects
        if request.method == "GET" and "signin" in path_parts:
            for provider in providers:
                if isinstance(provider, GithubProvider) and "github" in path_parts:
                    return RedirectResponse(url=self._github_authorize_url(provider))
                if isinstance(provider, GoogleProvider) and "google" in path_parts:
                    return RedirectResponse(url=self._google_authorize_url(provider))

        # Handle callbacks
        code = request.query_params.get("code")
        if request.method == "GET" and "callback" in path_parts and code:
            state = request.query_params.get("state")
            if "github" in path_parts:
                provider = self._find_provider(providers, GithubProvider)
                if provider:
                    # Reject callbacks not started by this session (OAuth login CSRF).
                    if not self._consume_oauth_state(state):
                        return None
                    return self._github_callback(provider, code)
            if "google" in path_parts:
                provider = self._find_provider(providers, GoogleProvider)
                if provider:
                    if not self._consume_oauth_state(state):
                        return None
                    return self._google_callback(provider, code)

        return None

    def _github_callback(self, provider: GithubProvider, code: str) -> Optional[Response]:
        """Exchange a GitHub authorization *code*, sign the user in, and redirect.

        Returns a 303 redirect to settings.default_signin_redirect on success,
        or None on any failure.
        """
        # NOTE(review): these httpx calls are synchronous; if this runs inside
        # an async request handler it blocks the event loop while waiting.
        try:
            token_resp = httpx.post(
                "https://github.com/login/oauth/access_token",
                data={
                    "client_id": provider.client_id,
                    "client_secret": provider.client_secret,
                    "code": code,
                },
                headers={"Accept": "application/json"},
                timeout=20,
            )
            access_token = token_resp.json().get("access_token")
            if not access_token:
                return None

            headers = {
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/json",
            }

            emails = httpx.get(
                "https://api.github.com/user/emails", headers=headers, timeout=20
            ).json()
            if not isinstance(emails, list):
                emails = []
            primary_email = next(
                (
                    e.get("email")
                    for e in emails
                    if isinstance(e, dict) and e.get("primary") and e.get("verified")
                ),
                None,
            )

            user_info = httpx.get(
                "https://api.github.com/user", headers=headers, timeout=20
            ).json()
            if not isinstance(user_info, dict):
                user_info = {}

            user_id = user_info.get("id")
            user_data = {
                "name": user_info.get("login"),
                "email": primary_email,
                "image": user_info.get("avatar_url"),
                "provider": "github",
                "provider_id": str(user_id) if user_id is not None else None,
            }

            self.sign_in(user_data, provider.max_age)
            return RedirectResponse(
                url=self._settings.default_signin_redirect, status_code=303
            )
        except Exception:
            # Deliberate best-effort: network, JSON or validity errors all
            # mean "callback not handled" to the caller.
            return None

    def _google_callback(self, provider: GoogleProvider, code: str) -> Optional[Response]:
        """Exchange a Google authorization *code*, sign the user in, and redirect.

        Returns a 303 redirect to settings.default_signin_redirect on success,
        or None on any failure.
        """
        # NOTE(review): synchronous httpx calls; see _github_callback.
        try:
            token_resp = httpx.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "client_id": provider.client_id,
                    "client_secret": provider.client_secret,
                    "code": code,
                    "grant_type": "authorization_code",
                    "redirect_uri": provider.redirect_uri,
                },
                timeout=20,
            )
            access_token = token_resp.json().get("access_token")
            if not access_token:
                return None

            user_info = httpx.get(
                "https://www.googleapis.com/oauth2/v1/userinfo",
                headers={"Authorization": f"Bearer {access_token}"},
                timeout=20,
            ).json()
            if not isinstance(user_info, dict):
                user_info = {}

            user_id = user_info.get("id")
            user_data = {
                "name": user_info.get("name"),
                "email": user_info.get("email"),
                "image": user_info.get("picture"),
                "provider": "google",
                "provider_id": str(user_id) if user_id is not None else None,
            }

            self.sign_in(user_data, provider.max_age)
            return RedirectResponse(
                url=self._settings.default_signin_redirect, status_code=303
            )
        except Exception:
            # Deliberate best-effort, as in _github_callback.
            return None

    # ====
    # HELPERS
    # ====
    def _normalize_payload(self, data: Any) -> Any:
        """Return *data* with a top-level "role" derived from "userRole" if missing.

        Non-dict values, and dicts that already have "role" (or have no
        usable "userRole"), are returned unchanged. Otherwise a new dict is
        returned, so the caller's object is not mutated. A dict ``userRole``
        supplies the first truthy of name/slug/role/value, else its id; a
        string ``userRole`` is used directly.
        """
        if not isinstance(data, dict) or "role" in data:
            return data

        ur = data.get("userRole")
        if isinstance(ur, dict):
            role = (
                ur.get("name") or ur.get("slug") or ur.get("role")
                or ur.get("value") or ur.get("id")
            )
        elif isinstance(ur, str):
            role = ur
        else:
            return data

        return {**data, "role": role}

    def _calculate_expiration(self, duration: str) -> int:
        """Return the UTC epoch second that is *duration* from now.

        *duration* is "<digits><unit>" with unit s, m, h or d (e.g. "30m", "7d").

        Raises:
            ValueError: if *duration* does not match that format exactly.
        """
        # fullmatch: unlike "^...$" with re.match, it rejects a trailing newline.
        match = re.fullmatch(r"(\d+)([smhd])", duration)
        if not match:
            raise ValueError(f"Invalid duration format: {duration}")

        value, unit = int(match.group(1)), match.group(2)
        delta = {
            "s": timedelta(seconds=value),
            "m": timedelta(minutes=value),
            "h": timedelta(hours=value),
            "d": timedelta(days=value),
        }[unit]

        return int((datetime.now(timezone.utc) + delta).timestamp())

    def _find_provider(self, providers, provider_type):
        """Return the first provider that is an instance of *provider_type*, or None."""
        return next((p for p in providers if isinstance(p, provider_type)), None)

    def _issue_oauth_state(self) -> str:
        """Generate a one-time OAuth ``state`` value and remember it in the session."""
        state = secrets.token_urlsafe(32)
        self._get_session()[self.OAUTH_STATE_SESSION_KEY] = state
        return state

    def _consume_oauth_state(self, received: Optional[str]) -> bool:
        """Pop the stored OAuth state and return True only if *received* matches it."""
        expected = self._get_session().pop(self.OAUTH_STATE_SESSION_KEY, None)
        if not expected or not received:
            return False
        # Compare bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(
            str(expected).encode("utf-8"), str(received).encode("utf-8")
        )

    def _github_authorize_url(self, provider: GithubProvider) -> str:
        """Build the GitHub authorize URL with properly encoded query parameters."""
        query = urlencode(
            {
                "scope": "user:email read:user",
                "client_id": provider.client_id,
                "state": self._issue_oauth_state(),
            },
            quote_via=quote,
        )
        return f"https://github.com/login/oauth/authorize?{query}"

    def _google_authorize_url(self, provider: GoogleProvider) -> str:
        """Build the Google authorize URL with properly encoded query parameters."""
        query = urlencode(
            {
                "scope": "email profile",
                "response_type": "code",
                "client_id": provider.client_id,
                "redirect_uri": provider.redirect_uri,
                "state": self._issue_oauth_state(),
            },
            quote_via=quote,
        )
        return f"https://accounts.google.com/o/oauth2/v2/auth?{query}"
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
# Singleton instance, created at import time via Auth.get_instance(); the
# decorators and get_csrf_token helpers below use this object.
auth = Auth.get_instance()
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# ====
# BACKWARDS COMPATIBILITY
# ====
def _live_setting(attr: str) -> property:
    """Build a read-only property returning ``_settings.<attr>`` at access time."""
    return property(lambda _owner: getattr(_settings, attr))


class _AuthConfigMeta(type):
    """Metaclass so the legacy constants also resolve on the AuthConfig class itself.

    A plain ``property`` defined on a class only works through instances;
    accessed on the class it returns the property object. Defining the same
    properties here makes ``AuthConfig.PUBLIC_ROUTES`` return the live value.
    """

    PUBLIC_ROUTES = _live_setting("public_routes")
    PRIVATE_ROUTES = _live_setting("private_routes")
    AUTH_ROUTES = _live_setting("auth_routes")
    IS_ALL_ROUTES_PRIVATE = _live_setting("is_all_routes_private")
    DEFAULT_SIGNIN_REDIRECT = _live_setting("default_signin_redirect")
    DEFAULT_SIGNOUT_REDIRECT = _live_setting("default_signout_redirect")


class AuthConfig(metaclass=_AuthConfigMeta):
    """Backwards compatibility alias.

    Both ``AuthConfig.PUBLIC_ROUTES`` (class access, via the metaclass) and
    ``AuthConfig().PUBLIC_ROUTES`` (instance access) return the value from
    the current module-level settings.
    """

    PUBLIC_ROUTES = _live_setting("public_routes")
    PRIVATE_ROUTES = _live_setting("private_routes")
    AUTH_ROUTES = _live_setting("auth_routes")
    IS_ALL_ROUTES_PRIVATE = _live_setting("is_all_routes_private")
    DEFAULT_SIGNIN_REDIRECT = _live_setting("default_signin_redirect")
    DEFAULT_SIGNOUT_REDIRECT = _live_setting("default_signout_redirect")

    @staticmethod
    def check_auth_role(user: Any, allowed_roles: List[str]) -> bool:
        """Delegate to the shared ``auth`` singleton's check_role()."""
        return auth.check_role(user, allowed_roles)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
# ====
# DECORATORS
# ====
def require_auth(redirect_to: Optional[str] = None):
    """Decorator to require authentication for a route.

    Unauthenticated requests are passed to ``settings.on_auth_failure`` when
    it is configured and a request is bound. Otherwise they get a 303
    redirect to *redirect_to* (default "/signin"), with the current path in
    a percent-encoded ``next`` query parameter. Works for sync and async
    handlers.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = Auth.get_request()
            if not auth.is_authenticated():
                if auth.settings.on_auth_failure and request:
                    return auth.settings.on_auth_failure(request)
                target = redirect_to or "/signin"
                next_url = request.url.path if request else "/"
                # Encode so '&', '+', '#' etc. in the path survive as one value;
                # append with '&' if the target already has a query string.
                sep = "&" if "?" in target else "?"
                return RedirectResponse(
                    url=f"{target}{sep}next={quote(next_url, safe='/')}",
                    status_code=303,
                )
            if asyncio.iscoroutinefunction(func):
                return await func(*args, **kwargs)
            return func(*args, **kwargs)
        return wrapper
    return decorator
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def require_role(*roles: str, redirect_to: str = "/unauthorized"):
    """Decorator to require specific roles for a route. Roles are strings.

    Unauthenticated requests get a 303 redirect to "/signin", with the
    current path in a percent-encoded ``next`` parameter. Authenticated
    users whose role (see Auth.check_role) is not among *roles* get a 303
    redirect to *redirect_to*. Works for sync and async handlers.
    """
    allowed_roles = list(roles)

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = Auth.get_request()
            if not auth.is_authenticated():
                path = request.url.path if request else "/"
                # Encode so '&', '+', '#' etc. in the path survive as one value.
                return RedirectResponse(
                    url=f"/signin?next={quote(path, safe='/')}", status_code=303
                )

            user = auth.get_payload()
            if not auth.check_role(user, allowed_roles):
                return RedirectResponse(url=redirect_to, status_code=303)

            if asyncio.iscoroutinefunction(func):
                return await func(*args, **kwargs)
            return func(*args, **kwargs)
        return wrapper
    return decorator
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _is_safe_redirect_target(url: Any) -> bool:
    """Return True if *url* is a same-origin relative path that is safe to redirect to.

    Only strings starting with a single "/" are accepted. Scheme-relative
    ("//host") and backslash ("/\\host") forms, control characters, and
    anything that parses with a scheme or network location are rejected.
    """
    if not isinstance(url, str) or not url.startswith("/"):
        return False
    if url.startswith(("//", "/\\")):
        return False
    # Browsers drop tab/CR/LF inside URLs, so "/\t/evil.com" can become "//evil.com".
    if any(ord(ch) < 0x20 or ch == "\x7f" for ch in url):
        return False
    parts = urlsplit(url)
    return not parts.scheme and not parts.netloc


def guest_only(redirect_to: Optional[str] = None):
    """Decorator for routes that should only be accessible to non-authenticated users.

    Authenticated users get a 303 redirect to the request's ``next`` query
    parameter when it is a safe relative path. Otherwise they go to
    *redirect_to*, or to settings.default_signin_redirect if that is unset.
    Works for sync and async handlers.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = Auth.get_request()
            if auth.is_authenticated():
                target = redirect_to or auth.settings.default_signin_redirect
                requested = request.query_params.get("next") if request else None
                # ``next`` is user-controlled: following it blindly is an open redirect.
                next_url = requested if _is_safe_redirect_target(requested) else target
                return RedirectResponse(url=next_url, status_code=303)
            if asyncio.iscoroutinefunction(func):
                return await func(*args, **kwargs)
            return func(*args, **kwargs)
        return wrapper
    return decorator
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def get_csrf_token() -> str:
    """Get or generate CSRF token from session.

    Returns the session's "csrf_token", first storing a new
    ``secrets.token_hex(32)`` value if none is set. Returns "" when no
    request is bound or the request has no usable session.
    """
    request = Auth.get_request()
    if not request:
        return ""
    try:
        session = request.session
    except (AttributeError, AssertionError, KeyError):
        # NOTE: Starlette's Request.session raises AssertionError (not
        # AttributeError) when SessionMiddleware is not installed, so a
        # hasattr() check alone would let that error escape.
        return ""

    csrf_token = session.get("csrf_token")
    if not csrf_token:
        csrf_token = secrets.token_hex(32)
        session["csrf_token"] = csrf_token
    return csrf_token
|