svc-infra 0.1.593__py3-none-any.whl → 0.1.595__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of svc-infra might be problematic. Click here for more details.
- svc_infra/apf_payments/README.md +26 -0
- svc_infra/apf_payments/provider/aiydan.py +28 -2
- svc_infra/apf_payments/service.py +113 -20
- svc_infra/api/fastapi/apf_payments/router.py +67 -4
- svc_infra/api/fastapi/auth/add.py +10 -0
- svc_infra/api/fastapi/auth/gaurd.py +67 -5
- svc_infra/api/fastapi/auth/routers/oauth_router.py +79 -34
- svc_infra/api/fastapi/auth/routers/session_router.py +63 -0
- svc_infra/api/fastapi/auth/settings.py +2 -0
- svc_infra/api/fastapi/db/sql/users.py +13 -1
- svc_infra/api/fastapi/dependencies/ratelimit.py +66 -0
- svc_infra/api/fastapi/middleware/ratelimit.py +26 -11
- svc_infra/api/fastapi/middleware/ratelimit_store.py +78 -0
- svc_infra/api/fastapi/middleware/request_size_limit.py +36 -0
- svc_infra/api/fastapi/setup.py +2 -1
- svc_infra/obs/metrics/__init__.py +53 -0
- svc_infra/obs/metrics.py +52 -0
- svc_infra/security/audit.py +130 -0
- svc_infra/security/audit_service.py +73 -0
- svc_infra/security/headers.py +39 -0
- svc_infra/security/hibp.py +91 -0
- svc_infra/security/jwt_rotation.py +53 -0
- svc_infra/security/lockout.py +96 -0
- svc_infra/security/models.py +245 -0
- svc_infra/security/org_invites.py +128 -0
- svc_infra/security/passwords.py +77 -0
- svc_infra/security/permissions.py +148 -0
- svc_infra/security/session.py +98 -0
- svc_infra/security/signed_cookies.py +80 -0
- {svc_infra-0.1.593.dist-info → svc_infra-0.1.595.dist-info}/METADATA +1 -1
- {svc_infra-0.1.593.dist-info → svc_infra-0.1.595.dist-info}/RECORD +33 -16
- {svc_infra-0.1.593.dist-info → svc_infra-0.1.595.dist-info}/WHEEL +0 -0
- {svc_infra-0.1.593.dist-info → svc_infra-0.1.595.dist-info}/entry_points.txt +0 -0
|
@@ -28,6 +28,8 @@ from svc_infra.api.fastapi.paths.auth import (
|
|
|
28
28
|
OAUTH_LOGIN_PATH,
|
|
29
29
|
OAUTH_REFRESH_PATH,
|
|
30
30
|
)
|
|
31
|
+
from svc_infra.security.models import RefreshToken
|
|
32
|
+
from svc_infra.security.session import issue_session_and_refresh, rotate_session_refresh
|
|
31
33
|
|
|
32
34
|
|
|
33
35
|
def _gen_pkce_pair() -> tuple[str, str]:
|
|
@@ -466,9 +468,13 @@ async def _validate_and_decode_jwt_token(raw_token: str) -> str:
|
|
|
466
468
|
|
|
467
469
|
|
|
468
470
|
async def _set_cookie_on_response(
|
|
469
|
-
resp: Response,
|
|
471
|
+
resp: Response,
|
|
472
|
+
auth_backend: AuthenticationBackend,
|
|
473
|
+
user: Any,
|
|
474
|
+
*,
|
|
475
|
+
refresh_raw: str,
|
|
470
476
|
) -> None:
|
|
471
|
-
"""Set authentication
|
|
477
|
+
"""Set authentication (JWT) and refresh cookies on response."""
|
|
472
478
|
st = get_auth_settings()
|
|
473
479
|
strategy = auth_backend.get_strategy()
|
|
474
480
|
jwt_token = await strategy.write_token(user)
|
|
@@ -477,6 +483,7 @@ async def _set_cookie_on_response(
|
|
|
477
483
|
if same_site_lit == "none" and not bool(st.session_cookie_secure):
|
|
478
484
|
raise HTTPException(500, "session_cookie_samesite=None requires session_cookie_secure=True")
|
|
479
485
|
|
|
486
|
+
# Access/Auth cookie (short-lived JWT)
|
|
480
487
|
resp.set_cookie(
|
|
481
488
|
key=_cookie_name(st),
|
|
482
489
|
value=jwt_token,
|
|
@@ -488,6 +495,18 @@ async def _set_cookie_on_response(
|
|
|
488
495
|
path="/",
|
|
489
496
|
)
|
|
490
497
|
|
|
498
|
+
# Refresh cookie (opaque token, longer lived)
|
|
499
|
+
resp.set_cookie(
|
|
500
|
+
key=getattr(st, "session_cookie_name", "svc_session"),
|
|
501
|
+
value=refresh_raw,
|
|
502
|
+
max_age=60 * 60 * 24 * 7, # 7 days default
|
|
503
|
+
httponly=True,
|
|
504
|
+
secure=bool(st.session_cookie_secure),
|
|
505
|
+
samesite=same_site_lit,
|
|
506
|
+
domain=_cookie_domain(st),
|
|
507
|
+
path="/",
|
|
508
|
+
)
|
|
509
|
+
|
|
491
510
|
|
|
492
511
|
def _clean_oauth_session_state(request: Request, provider: str) -> None:
|
|
493
512
|
"""Clean up transient OAuth session state."""
|
|
@@ -641,9 +660,18 @@ def _create_oauth_router(
|
|
|
641
660
|
user.last_login = datetime.now(timezone.utc)
|
|
642
661
|
await session.commit()
|
|
643
662
|
|
|
644
|
-
# Create
|
|
663
|
+
# Create session + initial refresh token
|
|
664
|
+
raw_refresh, _rt = await issue_session_and_refresh(
|
|
665
|
+
session,
|
|
666
|
+
user_id=user.id,
|
|
667
|
+
tenant_id=getattr(user, "tenant_id", None),
|
|
668
|
+
user_agent=str(request.headers.get("user-agent", ""))[:512],
|
|
669
|
+
ip_hash=None,
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
# Create response with auth + refresh cookies
|
|
645
673
|
resp = RedirectResponse(url=redirect_url, status_code=status.HTTP_302_FOUND)
|
|
646
|
-
await _set_cookie_on_response(resp, auth_backend, user)
|
|
674
|
+
await _set_cookie_on_response(resp, auth_backend, user, refresh_raw=raw_refresh)
|
|
647
675
|
|
|
648
676
|
# Clean up session state
|
|
649
677
|
_clean_oauth_session_state(request, provider)
|
|
@@ -667,44 +695,60 @@ def _create_oauth_router(
|
|
|
667
695
|
"""Refresh authentication token."""
|
|
668
696
|
st = get_auth_settings()
|
|
669
697
|
|
|
670
|
-
# Read and validate cookie
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
if not
|
|
698
|
+
# Read and validate auth JWT cookie
|
|
699
|
+
name_auth = _cookie_name(st)
|
|
700
|
+
raw_auth = request.cookies.get(name_auth)
|
|
701
|
+
if not raw_auth:
|
|
674
702
|
raise HTTPException(401, "missing_token")
|
|
675
703
|
|
|
676
|
-
# Validate and decode JWT token
|
|
677
|
-
user_id = await _validate_and_decode_jwt_token(
|
|
704
|
+
# Validate and decode JWT token to get user id
|
|
705
|
+
user_id = await _validate_and_decode_jwt_token(raw_auth)
|
|
678
706
|
|
|
679
707
|
# Load user
|
|
680
708
|
user = await session.get(user_model, user_id)
|
|
681
709
|
if not user:
|
|
682
710
|
raise HTTPException(401, "invalid_token")
|
|
683
711
|
|
|
684
|
-
#
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
712
|
+
# Obtain refresh cookie
|
|
713
|
+
refresh_cookie_name = getattr(st, "session_cookie_name", "svc_session")
|
|
714
|
+
raw_refresh = request.cookies.get(refresh_cookie_name)
|
|
715
|
+
if not raw_refresh:
|
|
716
|
+
raise HTTPException(401, "missing_refresh_token")
|
|
717
|
+
|
|
718
|
+
# Lookup refresh token row by hash
|
|
719
|
+
from sqlalchemy import select
|
|
720
|
+
|
|
721
|
+
from svc_infra.security.models import hash_refresh_token
|
|
722
|
+
|
|
723
|
+
token_hash = hash_refresh_token(raw_refresh)
|
|
724
|
+
found: RefreshToken | None = (
|
|
725
|
+
(
|
|
726
|
+
await session.execute(
|
|
727
|
+
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
|
728
|
+
)
|
|
729
|
+
)
|
|
730
|
+
.scalars()
|
|
731
|
+
.first()
|
|
732
|
+
)
|
|
733
|
+
if (
|
|
734
|
+
not found
|
|
735
|
+
or found.revoked_at
|
|
736
|
+
or (found.expires_at and found.expires_at < datetime.now(timezone.utc))
|
|
737
|
+
):
|
|
738
|
+
raise HTTPException(401, "invalid_refresh_token")
|
|
739
|
+
|
|
740
|
+
# Rotate refresh token
|
|
741
|
+
try:
|
|
742
|
+
new_raw, _new_rt = await rotate_session_refresh(session, current=found)
|
|
743
|
+
except ValueError:
|
|
744
|
+
# Token expired between validation and rotation; treat as invalid
|
|
745
|
+
raise HTTPException(401, "invalid_refresh_token") from None
|
|
746
|
+
|
|
747
|
+
# Write response (204) with new cookies
|
|
748
|
+
resp = Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
749
|
+
await _set_cookie_on_response(resp, auth_backend, user, refresh_raw=new_raw)
|
|
750
|
+
|
|
751
|
+
# Dead code removed: MFA branch handled earlier in login flow, refresh returns 204 above.
|
|
708
752
|
if hasattr(policy, "on_token_refresh"):
|
|
709
753
|
try:
|
|
710
754
|
await policy.on_token_refresh(user)
|
|
@@ -713,4 +757,5 @@ def _create_oauth_router(
|
|
|
713
757
|
|
|
714
758
|
return resp
|
|
715
759
|
|
|
760
|
+
# Return router at end of factory
|
|
716
761
|
return router
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, HTTPException
|
|
7
|
+
from sqlalchemy import select
|
|
8
|
+
|
|
9
|
+
from svc_infra.api.fastapi.auth.security import Identity
|
|
10
|
+
from svc_infra.api.fastapi.db.sql.session import SqlSessionDep
|
|
11
|
+
from svc_infra.security.models import AuthSession
|
|
12
|
+
from svc_infra.security.permissions import RequirePermission
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def build_session_router() -> APIRouter:
|
|
16
|
+
router = APIRouter(prefix="/sessions", tags=["sessions"])
|
|
17
|
+
|
|
18
|
+
@router.get(
|
|
19
|
+
"/me", response_model=list[dict], dependencies=[RequirePermission("security.session.list")]
|
|
20
|
+
)
|
|
21
|
+
async def list_my_sessions(identity: Identity, session: SqlSessionDep) -> List[dict]:
|
|
22
|
+
stmt = select(AuthSession).where(AuthSession.user_id == identity.user.id)
|
|
23
|
+
rows = (await session.execute(stmt)).scalars().all()
|
|
24
|
+
return [
|
|
25
|
+
{
|
|
26
|
+
"id": str(r.id),
|
|
27
|
+
"user_agent": r.user_agent,
|
|
28
|
+
"ip_hash": r.ip_hash,
|
|
29
|
+
"revoked": bool(r.revoked_at),
|
|
30
|
+
"last_seen_at": r.last_seen_at.isoformat() if r.last_seen_at else None,
|
|
31
|
+
"created_at": r.created_at.isoformat() if r.created_at else None,
|
|
32
|
+
}
|
|
33
|
+
for r in rows
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
@router.post(
|
|
37
|
+
"/{session_id}/revoke",
|
|
38
|
+
status_code=204,
|
|
39
|
+
dependencies=[RequirePermission("security.session.revoke")],
|
|
40
|
+
)
|
|
41
|
+
async def revoke_session(session_id: str, identity: Identity, db: SqlSessionDep):
|
|
42
|
+
# Load session and ensure it belongs to the user (non-admin users cannot revoke others)
|
|
43
|
+
s = await db.get(AuthSession, session_id)
|
|
44
|
+
if not s:
|
|
45
|
+
raise HTTPException(404, "session_not_found")
|
|
46
|
+
# Basic ownership check; could extend for admin bypass later
|
|
47
|
+
if s.user_id != identity.user.id:
|
|
48
|
+
raise HTTPException(403, "forbidden")
|
|
49
|
+
if s.revoked_at:
|
|
50
|
+
return # already revoked
|
|
51
|
+
s.revoked_at = datetime.now(timezone.utc)
|
|
52
|
+
s.revoke_reason = "user_revoked"
|
|
53
|
+
# Revoke all refresh tokens for this session
|
|
54
|
+
for rt in s.refresh_tokens:
|
|
55
|
+
if not rt.revoked_at:
|
|
56
|
+
rt.revoked_at = s.revoked_at
|
|
57
|
+
rt.revoke_reason = "session_revoked"
|
|
58
|
+
await db.flush()
|
|
59
|
+
|
|
60
|
+
return router
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
__all__ = ["build_session_router"]
|
|
@@ -18,6 +18,8 @@ class OIDCProvider(BaseModel):
|
|
|
18
18
|
class JWTSettings(BaseModel):
|
|
19
19
|
secret: SecretStr
|
|
20
20
|
lifetime_seconds: int = 60 * 60 * 24 * 7
|
|
21
|
+
# Optional older secrets accepted for verification during rotation window
|
|
22
|
+
old_secrets: List[SecretStr] = Field(default_factory=list)
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
class PasswordClient(BaseModel):
|
|
@@ -12,6 +12,7 @@ from svc_infra.api.fastapi.auth.settings import get_auth_settings
|
|
|
12
12
|
from svc_infra.api.fastapi.dual.dualize import dualize_public, dualize_user
|
|
13
13
|
from svc_infra.api.fastapi.dual.router import DualAPIRouter
|
|
14
14
|
from svc_infra.app.env import CURRENT_ENVIRONMENT, DEV_ENV, LOCAL_ENV
|
|
15
|
+
from svc_infra.security.jwt_rotation import RotatingJWTStrategy
|
|
15
16
|
|
|
16
17
|
from ...auth.security import auth_login_path
|
|
17
18
|
from ...auth.sender import get_sender
|
|
@@ -94,7 +95,18 @@ def get_fastapi_users(
|
|
|
94
95
|
lifetime = getattr(jwt_block, "lifetime_seconds", None) if jwt_block else None
|
|
95
96
|
if not isinstance(lifetime, int) or lifetime <= 0:
|
|
96
97
|
lifetime = 3600
|
|
97
|
-
|
|
98
|
+
old = []
|
|
99
|
+
if jwt_block and getattr(jwt_block, "old_secrets", None):
|
|
100
|
+
old = [s.get_secret_value() for s in jwt_block.old_secrets or []]
|
|
101
|
+
audience = "fastapi-users:auth"
|
|
102
|
+
if old:
|
|
103
|
+
return RotatingJWTStrategy(
|
|
104
|
+
secret=secret,
|
|
105
|
+
lifetime_seconds=lifetime,
|
|
106
|
+
old_secrets=old,
|
|
107
|
+
token_audience=audience,
|
|
108
|
+
)
|
|
109
|
+
return JWTStrategy(secret=secret, lifetime_seconds=lifetime, token_audience=audience)
|
|
98
110
|
|
|
99
111
|
bearer_transport = BearerTransport(tokenUrl=auth_login_path)
|
|
100
112
|
auth_backend = AuthenticationBackend(
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from fastapi import HTTPException
|
|
7
|
+
from starlette.requests import Request
|
|
8
|
+
|
|
9
|
+
from svc_infra.api.fastapi.middleware.ratelimit_store import InMemoryRateLimitStore, RateLimitStore
|
|
10
|
+
from svc_infra.obs.metrics import emit_rate_limited
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RateLimiter:
|
|
14
|
+
def __init__(
|
|
15
|
+
self,
|
|
16
|
+
*,
|
|
17
|
+
limit: int,
|
|
18
|
+
window: int = 60,
|
|
19
|
+
key_fn: Callable = lambda r: "global",
|
|
20
|
+
store: RateLimitStore | None = None,
|
|
21
|
+
):
|
|
22
|
+
self.limit = limit
|
|
23
|
+
self.window = window
|
|
24
|
+
self.key_fn = key_fn
|
|
25
|
+
self.store = store or InMemoryRateLimitStore(limit=limit)
|
|
26
|
+
|
|
27
|
+
async def __call__(self, request: Request):
|
|
28
|
+
key = self.key_fn(request)
|
|
29
|
+
count, limit, reset = self.store.incr(str(key), self.window)
|
|
30
|
+
if count > limit:
|
|
31
|
+
retry = max(0, reset - int(time.time()))
|
|
32
|
+
try:
|
|
33
|
+
emit_rate_limited(str(key), limit, retry)
|
|
34
|
+
except Exception:
|
|
35
|
+
pass
|
|
36
|
+
raise HTTPException(
|
|
37
|
+
status_code=429, detail="Rate limit exceeded", headers={"Retry-After": str(retry)}
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
__all__ = ["RateLimiter"]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def rate_limiter(
|
|
45
|
+
*,
|
|
46
|
+
limit: int,
|
|
47
|
+
window: int = 60,
|
|
48
|
+
key_fn: Callable = lambda r: "global",
|
|
49
|
+
store: RateLimitStore | None = None,
|
|
50
|
+
):
|
|
51
|
+
store_ = store or InMemoryRateLimitStore(limit=limit)
|
|
52
|
+
|
|
53
|
+
async def dep(request: Request):
|
|
54
|
+
key = key_fn(request)
|
|
55
|
+
count, lim, reset = store_.incr(str(key), window)
|
|
56
|
+
if count > lim:
|
|
57
|
+
retry = max(0, reset - int(time.time()))
|
|
58
|
+
try:
|
|
59
|
+
emit_rate_limited(str(key), lim, retry)
|
|
60
|
+
except Exception:
|
|
61
|
+
pass
|
|
62
|
+
raise HTTPException(
|
|
63
|
+
status_code=429, detail="Rate limit exceeded", headers={"Retry-After": str(retry)}
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
return dep
|
|
@@ -3,25 +3,41 @@ import time
|
|
|
3
3
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
4
|
from starlette.responses import JSONResponse
|
|
5
5
|
|
|
6
|
+
from svc_infra.obs.metrics import emit_rate_limited
|
|
7
|
+
|
|
8
|
+
from .ratelimit_store import InMemoryRateLimitStore, RateLimitStore
|
|
9
|
+
|
|
6
10
|
|
|
7
11
|
class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
|
|
8
|
-
def __init__(
|
|
12
|
+
def __init__(
|
|
13
|
+
self,
|
|
14
|
+
app,
|
|
15
|
+
limit: int = 120,
|
|
16
|
+
window: int = 60,
|
|
17
|
+
key_fn=None,
|
|
18
|
+
store: RateLimitStore | None = None,
|
|
19
|
+
):
|
|
9
20
|
super().__init__(app)
|
|
10
21
|
self.limit, self.window = limit, window
|
|
11
22
|
self.key_fn = key_fn or (lambda r: r.headers.get("X-API-Key") or r.client.host)
|
|
12
|
-
self.
|
|
23
|
+
self.store = store or InMemoryRateLimitStore(limit=limit)
|
|
13
24
|
|
|
14
25
|
async def dispatch(self, request, call_next):
|
|
15
26
|
key = self.key_fn(request)
|
|
16
27
|
now = int(time.time())
|
|
17
|
-
|
|
18
|
-
|
|
28
|
+
# Increment counter in store
|
|
29
|
+
count, limit, reset = self.store.incr(str(key), self.window)
|
|
30
|
+
remaining = max(0, limit - count)
|
|
19
31
|
|
|
20
|
-
remaining
|
|
21
|
-
|
|
32
|
+
if remaining < 0: # defensive clamp
|
|
33
|
+
remaining = 0
|
|
22
34
|
|
|
23
|
-
if
|
|
35
|
+
if count > limit:
|
|
24
36
|
retry = max(0, reset - now)
|
|
37
|
+
try:
|
|
38
|
+
emit_rate_limited(str(key), limit, retry)
|
|
39
|
+
except Exception:
|
|
40
|
+
pass
|
|
25
41
|
return JSONResponse(
|
|
26
42
|
status_code=429,
|
|
27
43
|
content={
|
|
@@ -31,16 +47,15 @@ class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
|
|
|
31
47
|
"code": "RATE_LIMITED",
|
|
32
48
|
},
|
|
33
49
|
headers={
|
|
34
|
-
"X-RateLimit-Limit": str(
|
|
50
|
+
"X-RateLimit-Limit": str(limit),
|
|
35
51
|
"X-RateLimit-Remaining": "0",
|
|
36
52
|
"X-RateLimit-Reset": str(reset),
|
|
37
53
|
"Retry-After": str(retry),
|
|
38
54
|
},
|
|
39
55
|
)
|
|
40
56
|
|
|
41
|
-
self.buckets[(key, win)] = bucket + 1
|
|
42
57
|
resp = await call_next(request)
|
|
43
|
-
resp.headers.setdefault("X-RateLimit-Limit", str(
|
|
44
|
-
resp.headers.setdefault("X-RateLimit-Remaining", str(
|
|
58
|
+
resp.headers.setdefault("X-RateLimit-Limit", str(limit))
|
|
59
|
+
resp.headers.setdefault("X-RateLimit-Remaining", str(remaining))
|
|
45
60
|
resp.headers.setdefault("X-RateLimit-Reset", str(reset))
|
|
46
61
|
return resp
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import Optional, Protocol, Tuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RateLimitStore(Protocol):
|
|
8
|
+
def incr(self, key: str, window: int) -> Tuple[int, int, int]:
|
|
9
|
+
"""Increment and return (count, limit, resetEpoch).
|
|
10
|
+
|
|
11
|
+
Implementations should manage per-window buckets. The 'limit' is stored configuration.
|
|
12
|
+
"""
|
|
13
|
+
...
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class InMemoryRateLimitStore:
|
|
17
|
+
def __init__(self, limit: int = 120):
|
|
18
|
+
self.limit = limit
|
|
19
|
+
self._buckets: dict[tuple[str, int], int] = {}
|
|
20
|
+
|
|
21
|
+
def incr(self, key: str, window: int) -> Tuple[int, int, int]:
|
|
22
|
+
now = int(time.time())
|
|
23
|
+
win = now - (now % window)
|
|
24
|
+
count = self._buckets.get((key, win), 0) + 1
|
|
25
|
+
self._buckets[(key, win)] = count
|
|
26
|
+
reset = win + window
|
|
27
|
+
return count, self.limit, reset
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RedisRateLimitStore:
|
|
31
|
+
"""Fixed-window counter store using Redis.
|
|
32
|
+
|
|
33
|
+
Keys are of the form: {prefix}:{key}:{windowStart}
|
|
34
|
+
Values are incremented and expire automatically at window end.
|
|
35
|
+
|
|
36
|
+
This implementation uses atomic INCR and EXPIRE semantics. To avoid race conditions
|
|
37
|
+
on first-set expiry, we set expiry when the counter is created.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(
|
|
41
|
+
self,
|
|
42
|
+
redis_client,
|
|
43
|
+
*,
|
|
44
|
+
limit: int = 120,
|
|
45
|
+
prefix: str = "ratelimit",
|
|
46
|
+
clock: Optional[callable] = None,
|
|
47
|
+
):
|
|
48
|
+
self.redis = redis_client
|
|
49
|
+
self.limit = limit
|
|
50
|
+
self.prefix = prefix
|
|
51
|
+
self._clock = clock or time.time
|
|
52
|
+
|
|
53
|
+
def _window_key(self, key: str, window: int) -> tuple[str, int, str]:
|
|
54
|
+
now = int(self._clock())
|
|
55
|
+
win = now - (now % window)
|
|
56
|
+
redis_key = f"{self.prefix}:{key}:{win}"
|
|
57
|
+
return redis_key, win, now
|
|
58
|
+
|
|
59
|
+
def incr(self, key: str, window: int) -> Tuple[int, int, int]:
|
|
60
|
+
rkey, win, now = self._window_key(key, window)
|
|
61
|
+
# Increment; if this is the first time we've seen this window key, set expiry to window end
|
|
62
|
+
pipe = self.redis.pipeline()
|
|
63
|
+
pipe.incr(rkey)
|
|
64
|
+
pipe.ttl(rkey)
|
|
65
|
+
count, ttl = pipe.execute()
|
|
66
|
+
if ttl == -1: # key exists without expire or just created; set expire to end of window
|
|
67
|
+
expire_sec = (win + window) - now
|
|
68
|
+
if expire_sec <= 0:
|
|
69
|
+
expire_sec = window
|
|
70
|
+
try:
|
|
71
|
+
self.redis.expire(rkey, expire_sec)
|
|
72
|
+
except Exception:
|
|
73
|
+
pass
|
|
74
|
+
reset = win + window
|
|
75
|
+
return int(count), self.limit, reset
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
__all__ = ["RateLimitStore", "InMemoryRateLimitStore", "RedisRateLimitStore"]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
4
|
+
from starlette.responses import JSONResponse
|
|
5
|
+
|
|
6
|
+
from svc_infra.obs.metrics import emit_suspect_payload
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
|
|
10
|
+
def __init__(self, app, max_bytes: int = 1_000_000):
|
|
11
|
+
super().__init__(app)
|
|
12
|
+
self.max_bytes = max_bytes
|
|
13
|
+
|
|
14
|
+
async def dispatch(self, request, call_next):
|
|
15
|
+
length = request.headers.get("content-length")
|
|
16
|
+
try:
|
|
17
|
+
size = int(length) if length is not None else None
|
|
18
|
+
except Exception:
|
|
19
|
+
size = None
|
|
20
|
+
if size is not None and size > self.max_bytes:
|
|
21
|
+
try:
|
|
22
|
+
emit_suspect_payload(
|
|
23
|
+
getattr(request, "url", None).path if hasattr(request, "url") else None, size
|
|
24
|
+
)
|
|
25
|
+
except Exception:
|
|
26
|
+
pass
|
|
27
|
+
return JSONResponse(
|
|
28
|
+
status_code=413,
|
|
29
|
+
content={
|
|
30
|
+
"title": "Payload Too Large",
|
|
31
|
+
"status": 413,
|
|
32
|
+
"detail": "Request body exceeds allowed size.",
|
|
33
|
+
"code": "PAYLOAD_TOO_LARGE",
|
|
34
|
+
},
|
|
35
|
+
)
|
|
36
|
+
return await call_next(request)
|
svc_infra/api/fastapi/setup.py
CHANGED
|
@@ -61,7 +61,8 @@ def _setup_cors(app: FastAPI, public_cors_origins: list[str] | str | None = None
|
|
|
61
61
|
elif isinstance(public_cors_origins, str):
|
|
62
62
|
origins = [o.strip() for o in public_cors_origins.split(",") if o and o.strip()]
|
|
63
63
|
else:
|
|
64
|
-
|
|
64
|
+
# Strict by default: no CORS unless explicitly configured via env or parameter.
|
|
65
|
+
fallback = os.getenv("CORS_ALLOW_ORIGINS", "")
|
|
65
66
|
origins = [o.strip() for o in fallback.split(",") if o and o.strip()]
|
|
66
67
|
|
|
67
68
|
if not origins:
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Metrics package public API.
|
|
5
|
+
|
|
6
|
+
Provides lightweight, overridable hooks for abuse heuristics so callers can
|
|
7
|
+
plug in logging or a metrics backend without a hard dependency.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import Callable, Optional
|
|
11
|
+
|
|
12
|
+
# Function variables so applications/tests can replace them at runtime.
|
|
13
|
+
on_rate_limit_exceeded: Callable[[str, int, int], None] | None = None
|
|
14
|
+
"""
|
|
15
|
+
Called when a request is rate-limited.
|
|
16
|
+
Args:
|
|
17
|
+
key: identifier used for rate limiting (e.g., API key or IP)
|
|
18
|
+
limit: configured limit for the window
|
|
19
|
+
retry_after: seconds until next allowed attempt
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
on_suspect_payload: Callable[[Optional[str], int], None] | None = None
|
|
23
|
+
"""
|
|
24
|
+
Called when a request exceeds the configured size limit.
|
|
25
|
+
Args:
|
|
26
|
+
path: request path if available
|
|
27
|
+
size: reported content-length
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def emit_rate_limited(key: str, limit: int, retry_after: int) -> None:
|
|
32
|
+
if on_rate_limit_exceeded:
|
|
33
|
+
try:
|
|
34
|
+
on_rate_limit_exceeded(key, limit, retry_after)
|
|
35
|
+
except Exception:
|
|
36
|
+
# Never break request flow on metrics exceptions
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def emit_suspect_payload(path: Optional[str], size: int) -> None:
|
|
41
|
+
if on_suspect_payload:
|
|
42
|
+
try:
|
|
43
|
+
on_suspect_payload(path, size)
|
|
44
|
+
except Exception:
|
|
45
|
+
pass
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
__all__ = [
|
|
49
|
+
"emit_rate_limited",
|
|
50
|
+
"emit_suspect_payload",
|
|
51
|
+
"on_rate_limit_exceeded",
|
|
52
|
+
"on_suspect_payload",
|
|
53
|
+
]
|
svc_infra/obs/metrics.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Lightweight metrics hooks for abuse heuristics. Intentionally minimal to avoid pulling
|
|
5
|
+
full metrics stacks; these are no-ops by default but can be swapped in tests or wired
|
|
6
|
+
to a metrics backend by overriding the functions.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Callable, Optional
|
|
10
|
+
|
|
11
|
+
# Function variables so applications/tests can replace them at runtime.
|
|
12
|
+
on_rate_limit_exceeded: Callable[[str, int, int], None] | None = None
|
|
13
|
+
"""
|
|
14
|
+
Called when a request is rate-limited.
|
|
15
|
+
Args:
|
|
16
|
+
key: identifier used for rate limiting (e.g., API key or IP)
|
|
17
|
+
limit: configured limit for the window
|
|
18
|
+
retry_after: seconds until next allowed attempt
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
on_suspect_payload: Callable[[Optional[str], int], None] | None = None
|
|
22
|
+
"""
|
|
23
|
+
Called when a request exceeds the configured size limit.
|
|
24
|
+
Args:
|
|
25
|
+
path: request path if available
|
|
26
|
+
size: reported content-length
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def emit_rate_limited(key: str, limit: int, retry_after: int) -> None:
|
|
31
|
+
if on_rate_limit_exceeded:
|
|
32
|
+
try:
|
|
33
|
+
on_rate_limit_exceeded(key, limit, retry_after)
|
|
34
|
+
except Exception:
|
|
35
|
+
# Never break request flow on metrics exceptions
|
|
36
|
+
pass
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def emit_suspect_payload(path: Optional[str], size: int) -> None:
|
|
40
|
+
if on_suspect_payload:
|
|
41
|
+
try:
|
|
42
|
+
on_suspect_payload(path, size)
|
|
43
|
+
except Exception:
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
__all__ = [
|
|
48
|
+
"emit_rate_limited",
|
|
49
|
+
"emit_suspect_payload",
|
|
50
|
+
"on_rate_limit_exceeded",
|
|
51
|
+
"on_suspect_payload",
|
|
52
|
+
]
|