caspian-utils 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
casp/__init__.py ADDED
File without changes
casp/auth.py ADDED
@@ -0,0 +1,537 @@
1
+ from typing import Dict, Callable, Any, Optional, Union, List
2
+ import os
3
+ import re
4
+ import secrets
5
+ from contextvars import ContextVar
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime, timezone, timedelta
8
+ from functools import wraps
9
+ import httpx
10
+ from fastapi import Request, Response
11
+ from fastapi.responses import RedirectResponse
12
+ import asyncio
13
+
14
+
15
# Context-local slot for the request currently being handled. It is None
# until Auth.set_request() stores a request in the current context.
_request_ctx: ContextVar[Optional[Request]] = ContextVar("request", default=None)
17
+
18
+
19
@dataclass
class AuthSettings:
    """Application-wide authentication behaviour.

    Build one at startup (typically in ``main.py``) and pass it to
    ``configure_auth()``. Only ``secret_key`` and ``cookie_name`` come from
    the environment (``AUTH_SECRET`` / ``AUTH_COOKIE_NAME``). Those values are
    read each time an instance is created.
    """

    # -- Token lifetime -------------------------------------------------
    # Duration string "<digits><s|m|h|d>", e.g. "30m" or "7d".
    default_token_validity: str = "1h"
    token_auto_refresh: bool = False

    # -- Roles ----------------------------------------------------------
    # Key of the signed-in payload that holds the user's role.
    role_identifier: str = "role"
    is_role_based: bool = False

    # -- Route access ---------------------------------------------------
    # When True, every route except public/auth routes needs a session.
    is_all_routes_private: bool = True
    private_routes: List[str] = field(default_factory=list)
    public_routes: List[str] = field(default_factory=lambda: ["/"])
    auth_routes: List[str] = field(
        default_factory=lambda: ["/signin", "/signup"])
    # Path -> roles allowed on it, e.g. {"/admin": ["admin", "superadmin"]}
    role_based_routes: Dict[str, List[str]] = field(default_factory=dict)

    # -- Redirect targets -----------------------------------------------
    default_signin_redirect: str = "/dashboard"
    default_signout_redirect: str = "/signin"
    api_auth_prefix: str = "/api/auth"

    # -- Hooks ----------------------------------------------------------
    # Called with the signed-in payload (non-dict payloads are wrapped as
    # {"value": ...}).
    on_sign_in: Optional[Callable[[dict], None]] = None
    on_sign_out: Optional[Callable[[], None]] = None
    on_auth_failure: Optional[Callable[[Request], Response]] = None

    # -- Environment-sourced values -------------------------------------
    # WARNING: the fallback secret is a well-known placeholder; always set
    # AUTH_SECRET outside local development.
    secret_key: str = field(default_factory=lambda: os.environ.get(
        "AUTH_SECRET", "default_secret_key_change_me"))
    cookie_name: str = field(default_factory=lambda: os.environ.get(
        "AUTH_COOKIE_NAME", "auth_cookie"))

    def __post_init__(self):
        # Canonicalise the cookie name: trim the ends, collapse each run of
        # whitespace into a single "_" and lowercase the result.
        self.cookie_name = "_".join(self.cookie_name.split()).lower()
62
+
63
+
64
# Active settings for the whole process; configure_auth() rebinds this name.
_settings: AuthSettings = AuthSettings()
66
+
67
+
68
def configure_auth(settings: AuthSettings) -> None:
    """
    Configure auth at app startup. Call this before app starts.

    Installs *settings* as the active module-level settings. It also applies
    them to the existing shared ``Auth`` instance in place, so every reference
    to that instance sees the new configuration. That includes the
    module-level ``auth`` object and ``from ... import auth`` copies.

    Example:
        configure_auth(AuthSettings(
            default_signin_redirect="/app",
            private_routes=["/dashboard", "/settings"],
            role_based_routes={
                "/admin": ["admin", "superadmin"],
                "/reports": ["admin", "manager", "user"],
            },
            on_sign_in=lambda user: print(f"Welcome {user.get('name')}"),
        ))
    """
    global _settings
    _settings = settings
    # Mutate the existing singleton instead of replacing it. Swapping in a
    # new object would leave references taken at import time (e.g. the
    # module-level ``auth``) holding the previous settings.
    instance = Auth.get_instance()
    instance._settings = settings
    Auth._cookie_name = settings.cookie_name
87
+
88
+
89
def get_auth_settings() -> AuthSettings:
    """Return the ``AuthSettings`` in effect.

    These are the settings last passed to ``configure_auth()``, or the
    defaults if it was never called.
    """
    current: AuthSettings = _settings
    return current
92
+
93
+
94
@dataclass
class GoogleProvider:
    """Google OAuth client configuration.

    Unspecified credentials fall back to the GOOGLE_CLIENT_ID,
    GOOGLE_CLIENT_SECRET and GOOGLE_REDIRECT_URI environment variables, read
    when the instance is created. Each is an empty string when unset.
    ``max_age`` is the session validity used when the user is signed in.
    """

    client_id: str = field(
        default_factory=lambda: os.environ.get("GOOGLE_CLIENT_ID", ""))
    client_secret: str = field(
        default_factory=lambda: os.environ.get("GOOGLE_CLIENT_SECRET", ""))
    redirect_uri: str = field(
        default_factory=lambda: os.environ.get("GOOGLE_REDIRECT_URI", ""))
    max_age: str = "30d"
104
+
105
+
106
@dataclass
class GithubProvider:
    """GitHub OAuth client configuration.

    Unspecified credentials fall back to the GITHUB_CLIENT_ID and
    GITHUB_CLIENT_SECRET environment variables, read when the instance is
    created. Each is an empty string when unset. ``max_age`` is the session
    validity used when the user is signed in.
    """

    client_id: str = field(
        default_factory=lambda: os.environ.get("GITHUB_CLIENT_ID", ""))
    client_secret: str = field(
        default_factory=lambda: os.environ.get("GITHUB_CLIENT_SECRET", ""))
    max_age: str = "30d"
114
+
115
+
116
class Auth:
    """Session-based authentication helper, used as a process-wide singleton.

    Authentication state is kept in ``request.session`` of the request bound
    to the current context with :meth:`set_request`. Behaviour is driven by
    the :class:`AuthSettings` held in ``self._settings``.
    """

    # Keys under which the auth payload is stored. Existing sessions depend
    # on these exact strings.
    PAYLOAD_NAME = "payload_name_8639D"
    PAYLOAD_SESSION_KEY = "payload_session_key_2183A"
    # Session key holding the anti-CSRF ``state`` of an in-flight OAuth
    # sign-in.
    OAUTH_STATE_SESSION_KEY = "oauth_state"

    _instance: Optional["Auth"] = None
    _cookie_name: str = ""
    _providers: List[Any] = []

    def __init__(self) -> None:
        # Snapshot the module-level settings active at construction time.
        self._settings = _settings
        Auth._cookie_name = self._settings.cookie_name

    @classmethod
    def get_instance(cls) -> "Auth":
        """Return the shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def set_request(cls, request: Request):
        """Bind *request* to the current context."""
        _request_ctx.set(request)

    @classmethod
    def get_request(cls) -> Optional[Request]:
        """Return the request bound to the current context, or ``None``."""
        return _request_ctx.get()

    @classmethod
    def set_providers(cls, *providers) -> None:
        """Register default OAuth providers, used by :meth:`auth_providers`
        when it is called without explicit providers."""
        cls._providers = list(providers)

    @classmethod
    def get_providers(cls) -> List[Any]:
        """Return the registered OAuth providers."""
        return cls._providers

    @property
    def settings(self) -> AuthSettings:
        return self._settings

    @property
    def cookie_name(self) -> str:
        return Auth._cookie_name

    def _get_session(self) -> dict:
        """Return the bound request's session mapping.

        Falls back to a fresh, unattached dict when no request is bound or
        the request has no usable session. Writes to that dict are discarded.
        """
        request = self.get_request()
        if not request:
            return {}
        try:
            return request.session
        except (AttributeError, AssertionError, KeyError):
            # Starlette's ``Request.session`` asserts that SessionMiddleware
            # is installed rather than raising AttributeError, so a plain
            # ``hasattr`` check is not a sufficient guard.
            return {}

    # ====
    # ROUTE CHECKING
    # ====
    def is_public_route(self, path: str) -> bool:
        """Check if path is a public route (exact match)."""
        return path in self._settings.public_routes

    def is_auth_route(self, path: str) -> bool:
        """Check if path is an auth route (signin/signup), exact match."""
        return path in self._settings.auth_routes

    def is_private_route(self, path: str) -> bool:
        """Check if path requires authentication.

        With ``is_all_routes_private`` everything except public and auth
        routes is private. Otherwise only ``private_routes`` are.
        """
        if self._settings.is_all_routes_private:
            return not (self.is_public_route(path) or self.is_auth_route(path))
        return path in self._settings.private_routes

    def get_required_roles(self, path: str) -> Optional[List[str]]:
        """Get required roles for a path (exact match), if any."""
        return self._settings.role_based_routes.get(path)

    # ====
    # CORE AUTH METHODS
    # ====
    def sign_in(
        self,
        data: Union[dict, str, Any],
        token_validity: Optional[str] = None,
        redirect_to: Union[bool, str] = False,
    ) -> Union[str, Response]:
        """Store *data* as the signed-in payload in the session.

        Args:
            data: User payload. A dict is normalised (see
                :meth:`_normalize_payload`); other values are stored as-is.
            token_validity: Duration such as ``"1h"``. Defaults to
                ``settings.default_token_validity``.
            redirect_to: ``True`` redirects (303) to
                ``settings.default_signin_redirect``; a non-empty string
                redirects there.

        Returns:
            A ``RedirectResponse`` when a redirect was requested, else ``"ok"``.

        Raises:
            ValueError: if the validity string is malformed.
        """
        validity = token_validity or self._settings.default_token_validity
        exp_time = self._calculate_expiration(validity)

        data = self._normalize_payload(data)

        payload = {
            self.PAYLOAD_NAME: data,
            "exp": exp_time,
            "iat": datetime.now(timezone.utc).timestamp(),
        }

        session = self._get_session()
        session[self.PAYLOAD_SESSION_KEY] = payload
        # Rotate the CSRF token on every sign-in.
        session["csrf_token"] = secrets.token_hex(32)

        if self._settings.on_sign_in:
            self._settings.on_sign_in(
                data if isinstance(data, dict) else {"value": data})

        if redirect_to is True:
            return RedirectResponse(url=self._settings.default_signin_redirect, status_code=303)
        if isinstance(redirect_to, str) and redirect_to:
            return RedirectResponse(url=redirect_to, status_code=303)

        return "ok"

    def sign_out(self, redirect_to: Optional[str] = None) -> Optional[Response]:
        """Clear the whole session and optionally redirect (303).

        Redirects to *redirect_to*, or ``settings.default_signout_redirect``
        when not given. Returns ``None`` when both are empty.
        """
        # clear() drops the auth payload, the CSRF token and any OAuth state.
        self._get_session().clear()

        if self._settings.on_sign_out:
            self._settings.on_sign_out()

        target = redirect_to or self._settings.default_signout_redirect
        if target:
            return RedirectResponse(url=target, status_code=303)
        return None

    def is_authenticated(self) -> bool:
        """Return True if the session holds a non-expired signed-in payload.

        Expired or malformed payloads are removed from the session.
        """
        session = self._get_session()
        payload = session.get(self.PAYLOAD_SESSION_KEY)
        if not isinstance(payload, dict):
            return False

        exp = payload.get("exp")
        if exp is not None:
            try:
                expired = float(exp) < datetime.now(timezone.utc).timestamp()
            except (TypeError, ValueError, OverflowError):
                # An unparseable expiry is treated as expired.
                expired = True
            if expired:
                session.pop(self.PAYLOAD_SESSION_KEY, None)
                return False

        if payload.get(self.PAYLOAD_NAME) is None:
            session.pop(self.PAYLOAD_SESSION_KEY, None)
            return False

        return True

    def get_payload(self) -> Optional[Dict[str, Any]]:
        """Return the signed-in payload.

        Non-dict payloads are wrapped as ``{"value": ...}``. Returns ``None``
        when there is no payload. Expiry is not checked here.
        """
        session = self._get_session()
        payload = session.get(self.PAYLOAD_SESSION_KEY)
        if not isinstance(payload, dict):
            return None

        data = payload.get(self.PAYLOAD_NAME)
        if isinstance(data, dict):
            return data
        if data is not None:
            return {"value": data}
        return None

    def check_role(self, user: Any, allowed_roles: List[str]) -> bool:
        """Check if user has one of the allowed roles.

        For a dict *user* the role is read from
        ``settings.role_identifier``. A list/tuple/set of roles matches when
        any element is allowed. Any other truthy *user* is compared via
        ``str(user)``.
        """
        if isinstance(user, dict):
            user_role = user.get(self._settings.role_identifier, "")
        else:
            user_role = str(user) if user else ""
        if isinstance(user_role, (list, tuple, set, frozenset)):
            return any(role in allowed_roles for role in user_role)
        return user_role in allowed_roles

    # ====
    # OAUTH PROVIDERS
    # ====
    def auth_providers(self, *providers) -> Optional[Response]:
        """Handle OAuth provider signin/callback routes.

        A GET whose path contains ``signin`` plus ``github``/``google``
        redirects to the provider's consent page with a fresh ``state``. A GET
        whose path contains ``callback`` plus ``github``/``google`` and a
        ``code`` query parameter verifies ``state``, completes the flow and
        signs the user in.

        Args:
            *providers: Provider configs. When none are given, the providers
                registered with :meth:`set_providers` are used.

        Returns:
            A redirect response when the request was handled, else ``None``
            (also when the ``state`` check or the callback fails).
        """
        request = self.get_request()
        if not request or request.method != "GET":
            return None

        providers = providers or tuple(self._providers)
        path_parts = request.url.path.strip("/").split("/")

        if "signin" in path_parts:
            for provider in providers:
                if isinstance(provider, GithubProvider) and "github" in path_parts:
                    return self._github_signin_redirect(provider)
                if isinstance(provider, GoogleProvider) and "google" in path_parts:
                    return self._google_signin_redirect(provider)

        code = request.query_params.get("code")
        if "callback" in path_parts and code:
            handlers = (
                ("github", GithubProvider, self._github_callback),
                ("google", GoogleProvider, self._google_callback),
            )
            for name, provider_type, handler in handlers:
                if name not in path_parts:
                    continue
                provider = self._find_provider(providers, provider_type)
                if provider:
                    # Reject callbacks not initiated by this session
                    # (OAuth login CSRF).
                    if not self._consume_oauth_state(request):
                        return None
                    return handler(provider, code)

        return None

    def _github_signin_redirect(self, provider: GithubProvider) -> Response:
        """Redirect to GitHub's authorize page."""
        url = self._build_url(
            "https://github.com/login/oauth/authorize",
            {
                "scope": "user:email read:user",
                "client_id": provider.client_id,
                "state": self._issue_oauth_state(),
            },
        )
        return RedirectResponse(url=url)

    def _google_signin_redirect(self, provider: GoogleProvider) -> Response:
        """Redirect to Google's authorization endpoint."""
        url = self._build_url(
            "https://accounts.google.com/o/oauth2/v2/auth",
            {
                "scope": "email profile",
                "response_type": "code",
                "client_id": provider.client_id,
                "redirect_uri": provider.redirect_uri,
                "state": self._issue_oauth_state(),
            },
        )
        return RedirectResponse(url=url)

    @staticmethod
    def _build_url(base: str, params: Dict[str, str]) -> str:
        """Append *params* to *base* as a percent-encoded query string."""
        from urllib.parse import quote, urlencode

        return f"{base}?{urlencode(params, quote_via=quote)}"

    def _issue_oauth_state(self) -> str:
        """Generate an unguessable OAuth ``state`` and store it in the session."""
        state = secrets.token_urlsafe(32)
        self._get_session()[self.OAUTH_STATE_SESSION_KEY] = state
        return state

    def _consume_oauth_state(self, request: Request) -> bool:
        """Pop the stored OAuth ``state`` and check the callback echoes it.

        The stored value is single-use: it is removed whether or not it
        matches.
        """
        expected = self._get_session().pop(self.OAUTH_STATE_SESSION_KEY, None)
        received = request.query_params.get("state")
        if not isinstance(expected, str) or not received:
            return False
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(expected.encode(), received.encode())

    @staticmethod
    def _json_dict(response: Any) -> Dict[str, Any]:
        """Decode a JSON response body; return {} unless it is an object."""
        data = response.json()
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _log_oauth_failure(provider_name: str) -> None:
        """Log the exception currently being handled as a callback failure."""
        import logging

        logging.getLogger(__name__).warning(
            "%s OAuth callback failed", provider_name, exc_info=True)

    def _github_callback(self, provider: GithubProvider, code: str) -> Optional[Response]:
        """Exchange a GitHub ``code`` for a token, fetch the profile, sign in.

        NOTE: uses blocking ``httpx`` calls (20 s timeouts each); if invoked
        from an async handler these block the event loop.

        Returns a 303 redirect to ``settings.default_signin_redirect`` on
        success, or ``None`` on any failure (failures are logged).
        """
        try:
            token_resp = httpx.post(
                "https://github.com/login/oauth/access_token",
                data={
                    "client_id": provider.client_id,
                    "client_secret": provider.client_secret,
                    "code": code,
                },
                headers={"Accept": "application/json"},
                timeout=20,
            )
            access_token = self._json_dict(token_resp).get("access_token")
            if not access_token:
                return None

            headers = {"Authorization": f"Bearer {access_token}",
                       "Accept": "application/json"}

            emails = httpx.get(
                "https://api.github.com/user/emails", headers=headers, timeout=20).json()
            if not isinstance(emails, list):
                emails = []
            primary_email = next(
                (e.get("email") for e in emails
                 if isinstance(e, dict) and e.get("primary") and e.get("verified")),
                None,
            )

            user_info = self._json_dict(httpx.get(
                "https://api.github.com/user", headers=headers, timeout=20))
            user_id = user_info.get("id")

            user_data = {
                "name": user_info.get("login"),
                "email": primary_email,
                "image": user_info.get("avatar_url"),
                "provider": "github",
                "provider_id": str(user_id) if user_id is not None else None,
            }

            self.sign_in(user_data, provider.max_age)
            return RedirectResponse(url=self._settings.default_signin_redirect, status_code=303)
        except Exception:
            # Best-effort boundary: any network/JSON/sign-in error aborts the
            # flow, but is logged instead of silently discarded.
            self._log_oauth_failure("github")
            return None

    def _google_callback(self, provider: GoogleProvider, code: str) -> Optional[Response]:
        """Exchange a Google ``code`` for a token, fetch the profile, sign in.

        NOTE: uses blocking ``httpx`` calls (20 s timeouts each); if invoked
        from an async handler these block the event loop.

        Returns a 303 redirect to ``settings.default_signin_redirect`` on
        success, or ``None`` on any failure (failures are logged).
        """
        try:
            token_resp = httpx.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "client_id": provider.client_id,
                    "client_secret": provider.client_secret,
                    "code": code,
                    "grant_type": "authorization_code",
                    "redirect_uri": provider.redirect_uri,
                },
                timeout=20,
            )
            access_token = self._json_dict(token_resp).get("access_token")
            if not access_token:
                return None

            user_info = self._json_dict(httpx.get(
                "https://www.googleapis.com/oauth2/v1/userinfo",
                headers={"Authorization": f"Bearer {access_token}"},
                timeout=20,
            ))
            user_id = user_info.get("id")

            user_data = {
                "name": user_info.get("name"),
                "email": user_info.get("email"),
                "image": user_info.get("picture"),
                "provider": "google",
                "provider_id": str(user_id) if user_id is not None else None,
            }

            self.sign_in(user_data, provider.max_age)
            return RedirectResponse(url=self._settings.default_signin_redirect, status_code=303)
        except Exception:
            # Best-effort boundary: see _github_callback.
            self._log_oauth_failure("google")
            return None

    # ====
    # HELPERS
    # ====
    def _normalize_payload(self, data: Any) -> Any:
        """Return *data* with a ``role`` key derived from ``userRole`` if missing.

        Dict payloads are shallow-copied so the caller's dict is not mutated.
        Non-dict values are returned unchanged.
        """
        if not isinstance(data, dict):
            return data

        data = dict(data)
        if "role" not in data:
            ur = data.get("userRole")
            if isinstance(ur, dict):
                data["role"] = (
                    ur.get("name") or ur.get("slug") or ur.get("role")
                    or ur.get("value") or ur.get("id")
                )
            elif isinstance(ur, str):
                data["role"] = ur

        return data

    def _calculate_expiration(self, duration: str) -> int:
        """Return the Unix timestamp (seconds) *duration* from now.

        *duration* is ``<digits><unit>`` with unit ``s``, ``m``, ``h`` or
        ``d`` (e.g. ``"30m"``, ``"7d"``).

        Raises:
            ValueError: if *duration* does not match that format exactly.
        """
        match = re.fullmatch(r"(\d+)([smhd])", duration)
        if not match:
            raise ValueError(f"Invalid duration format: {duration}")

        value, unit = int(match.group(1)), match.group(2)
        # Build only the requested unit: eagerly constructing every unit can
        # overflow timedelta (a large seconds value interpreted as days).
        units = {"s": "seconds", "m": "minutes", "h": "hours", "d": "days"}
        delta = timedelta(**{units[unit]: value})

        return int((datetime.now(timezone.utc) + delta).timestamp())

    def _find_provider(self, providers, provider_type):
        """Return the first provider that is an instance of *provider_type*."""
        return next((p for p in providers if isinstance(p, provider_type)), None)
447
+
448
+
449
# Shared Auth object, created once when this module is imported.
auth = Auth.get_instance()
451
+
452
+
453
# ====
# BACKWARDS COMPATIBILITY
# ====
class AuthConfig:
    """Legacy read-only view over the active ``AuthSettings``.

    Each attribute is a property that reads the module-level settings at
    access time, so it follows the latest ``configure_auth()`` call.
    NOTE(review): these are instance properties. Use
    ``AuthConfig().PUBLIC_ROUTES``; on the class itself they evaluate to
    ``property`` objects. Confirm this matches what legacy callers expect.
    """
    PUBLIC_ROUTES = property(lambda self: _settings.public_routes)
    PRIVATE_ROUTES = property(lambda self: _settings.private_routes)
    AUTH_ROUTES = property(lambda self: _settings.auth_routes)
    IS_ALL_ROUTES_PRIVATE = property(
        lambda self: _settings.is_all_routes_private)
    DEFAULT_SIGNIN_REDIRECT = property(
        lambda self: _settings.default_signin_redirect)
    DEFAULT_SIGNOUT_REDIRECT = property(
        lambda self: _settings.default_signout_redirect)

    @staticmethod
    def check_auth_role(user: Any, allowed_roles: List[str]) -> bool:
        """Return True if *user* holds one of *allowed_roles*."""
        # Look the singleton up per call rather than via the module-level
        # ``auth`` reference bound at import, so the instance currently
        # registered on ``Auth`` (and its current settings) is always used.
        return Auth.get_instance().check_role(user, allowed_roles)
471
+
472
+
473
# ====
# DECORATORS
# ====
def require_auth(redirect_to: Optional[str] = None):
    """Decorator to require authentication for a route.

    Unauthenticated requests are handed to ``settings.on_auth_failure`` when
    it is configured and a request is bound. Otherwise they get a 303 to
    *redirect_to* (default ``"/signin"``) with a percent-encoded ``next``
    parameter carrying the requested path and query string.

    Works with ``async def`` and plain ``def`` handlers; the returned wrapper
    is always a coroutine function.
    """
    import inspect
    from urllib.parse import quote

    def decorator(func):
        # Decide once, at decoration time, whether the handler is async.
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Resolve the singleton per call instead of using the
            # import-time ``auth`` reference.
            current = Auth.get_instance()
            request = Auth.get_request()
            if not current.is_authenticated():
                if current.settings.on_auth_failure and request:
                    return current.settings.on_auth_failure(request)
                target = redirect_to or "/signin"
                if request:
                    next_url = request.url.path
                    if request.url.query:
                        next_url = f"{next_url}?{request.url.query}"
                else:
                    next_url = "/"
                separator = "&" if "?" in target else "?"
                return RedirectResponse(
                    url=f"{target}{separator}next={quote(next_url, safe='/')}",
                    status_code=303,
                )
            if is_async:
                return await func(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator
491
+
492
+
493
def require_role(*roles: str, redirect_to: str = "/unauthorized"):
    """Decorator to require specific roles for a route. Roles are strings.

    Unauthenticated requests get a 303 to ``/signin`` with a percent-encoded
    ``next`` parameter carrying the requested path and query string.
    Authenticated users holding none of *roles* (per ``Auth.check_role``) get
    a 303 to *redirect_to*.
    """
    import inspect
    from urllib.parse import quote

    allowed = list(roles)

    def decorator(func):
        # Decide once, at decoration time, whether the handler is async.
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Resolve the singleton per call instead of using the
            # import-time ``auth`` reference.
            current = Auth.get_instance()
            request = Auth.get_request()
            if not current.is_authenticated():
                if request:
                    next_url = request.url.path
                    if request.url.query:
                        next_url = f"{next_url}?{request.url.query}"
                else:
                    next_url = "/"
                return RedirectResponse(
                    url=f"/signin?next={quote(next_url, safe='/')}",
                    status_code=303,
                )

            if not current.check_role(current.get_payload(), allowed):
                return RedirectResponse(url=redirect_to, status_code=303)

            if is_async:
                return await func(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator
510
+
511
+
512
+ def guest_only(redirect_to: Optional[str] = None):
513
+ """Decorator for routes that should only be accessible to non-authenticated users."""
514
+ def decorator(func):
515
+ @wraps(func)
516
+ async def wrapper(*args, **kwargs):
517
+ request = Auth.get_request()
518
+ if auth.is_authenticated():
519
+ target = redirect_to or auth.settings.default_signin_redirect
520
+ next_url = request.query_params.get(
521
+ "next", target) if request else target
522
+ return RedirectResponse(url=next_url, status_code=303)
523
+ return await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
524
+ return wrapper
525
+ return decorator
526
+
527
+
528
def get_csrf_token() -> str:
    """Get or generate the CSRF token stored in the session.

    Returns ``""`` when no request is bound to the current context or the
    request has no usable session. Starlette's ``Request.session`` asserts
    that SessionMiddleware is installed instead of raising AttributeError, so
    a ``hasattr`` check alone does not cover that case.
    """
    request = Auth.get_request()
    if not request:
        return ""
    try:
        session = request.session
    except (AttributeError, AssertionError, KeyError):
        return ""

    csrf_token = session.get("csrf_token")
    if not csrf_token:
        csrf_token = secrets.token_hex(32)
        session["csrf_token"] = csrf_token
    return csrf_token