svc-infra 0.1.593__py3-none-any.whl → 0.1.594__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of svc-infra might be problematic.

Files changed (32)
  1. svc_infra/apf_payments/provider/aiydan.py +28 -2
  2. svc_infra/apf_payments/service.py +113 -20
  3. svc_infra/api/fastapi/apf_payments/router.py +3 -1
  4. svc_infra/api/fastapi/auth/add.py +10 -0
  5. svc_infra/api/fastapi/auth/gaurd.py +67 -5
  6. svc_infra/api/fastapi/auth/routers/oauth_router.py +76 -36
  7. svc_infra/api/fastapi/auth/routers/session_router.py +63 -0
  8. svc_infra/api/fastapi/auth/settings.py +2 -0
  9. svc_infra/api/fastapi/db/sql/users.py +13 -1
  10. svc_infra/api/fastapi/dependencies/ratelimit.py +66 -0
  11. svc_infra/api/fastapi/middleware/ratelimit.py +26 -11
  12. svc_infra/api/fastapi/middleware/ratelimit_store.py +30 -0
  13. svc_infra/api/fastapi/middleware/request_size_limit.py +36 -0
  14. svc_infra/api/fastapi/setup.py +2 -1
  15. svc_infra/obs/metrics/__init__.py +53 -0
  16. svc_infra/obs/metrics.py +52 -0
  17. svc_infra/security/audit.py +130 -0
  18. svc_infra/security/audit_service.py +73 -0
  19. svc_infra/security/headers.py +39 -0
  20. svc_infra/security/hibp.py +91 -0
  21. svc_infra/security/jwt_rotation.py +53 -0
  22. svc_infra/security/lockout.py +96 -0
  23. svc_infra/security/models.py +245 -0
  24. svc_infra/security/org_invites.py +128 -0
  25. svc_infra/security/passwords.py +77 -0
  26. svc_infra/security/permissions.py +148 -0
  27. svc_infra/security/session.py +89 -0
  28. svc_infra/security/signed_cookies.py +80 -0
  29. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/METADATA +1 -1
  30. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/RECORD +32 -15
  31. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/WHEEL +0 -0
  32. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/entry_points.txt +0 -0
svc_infra/api/fastapi/auth/routers/session_router.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import List
+
+from fastapi import APIRouter, HTTPException
+from sqlalchemy import select
+
+from svc_infra.api.fastapi.auth.security import Identity
+from svc_infra.api.fastapi.db.sql.session import SqlSessionDep
+from svc_infra.security.models import AuthSession
+from svc_infra.security.permissions import RequirePermission
+
+
+def build_session_router() -> APIRouter:
+    router = APIRouter(prefix="/sessions", tags=["sessions"])
+
+    @router.get(
+        "/me", response_model=list[dict], dependencies=[RequirePermission("security.session.list")]
+    )
+    async def list_my_sessions(identity: Identity, session: SqlSessionDep) -> List[dict]:
+        stmt = select(AuthSession).where(AuthSession.user_id == identity.user.id)
+        rows = (await session.execute(stmt)).scalars().all()
+        return [
+            {
+                "id": str(r.id),
+                "user_agent": r.user_agent,
+                "ip_hash": r.ip_hash,
+                "revoked": bool(r.revoked_at),
+                "last_seen_at": r.last_seen_at.isoformat() if r.last_seen_at else None,
+                "created_at": r.created_at.isoformat() if r.created_at else None,
+            }
+            for r in rows
+        ]
+
+    @router.post(
+        "/{session_id}/revoke",
+        status_code=204,
+        dependencies=[RequirePermission("security.session.revoke")],
+    )
+    async def revoke_session(session_id: str, identity: Identity, db: SqlSessionDep):
+        # Load session and ensure it belongs to the user (non-admin users cannot revoke others)
+        s = await db.get(AuthSession, session_id)
+        if not s:
+            raise HTTPException(404, "session_not_found")
+        # Basic ownership check; could extend for admin bypass later
+        if s.user_id != identity.user.id:
+            raise HTTPException(403, "forbidden")
+        if s.revoked_at:
+            return  # already revoked
+        s.revoked_at = datetime.now(timezone.utc)
+        s.revoke_reason = "user_revoked"
+        # Revoke all refresh tokens for this session
+        for rt in s.refresh_tokens:
+            if not rt.revoked_at:
+                rt.revoked_at = s.revoked_at
+                rt.revoke_reason = "session_revoked"
+        await db.flush()
+
+    return router
+
+
+__all__ = ["build_session_router"]
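Not part of the diff: a minimal sketch of how the new router might be mounted on a FastAPI app, assuming the app already wires the Identity and SqlSessionDep dependencies used above.

# Hypothetical wiring example (illustrative only).
from fastapi import FastAPI

from svc_infra.api.fastapi.auth.routers.session_router import build_session_router

app = FastAPI()
# Exposes GET /sessions/me and POST /sessions/{session_id}/revoke.
app.include_router(build_session_router())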
svc_infra/api/fastapi/auth/settings.py
@@ -18,6 +18,8 @@ class OIDCProvider(BaseModel):
 class JWTSettings(BaseModel):
     secret: SecretStr
     lifetime_seconds: int = 60 * 60 * 24 * 7
+    # Optional older secrets accepted for verification during rotation window
+    old_secrets: List[SecretStr] = Field(default_factory=list)
 
 
 class PasswordClient(BaseModel):
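Illustrative only: with the new field, a rotation window might be configured like this (constructing JWTSettings directly is an assumption; real deployments likely load it from env or a settings file).

from pydantic import SecretStr

from svc_infra.api.fastapi.auth.settings import JWTSettings

# New tokens are signed with `secret`; tokens signed with any entry in
# `old_secrets` stay verifiable until the rotation window closes.
jwt = JWTSettings(
    secret=SecretStr("new-signing-secret"),
    old_secrets=[SecretStr("previous-signing-secret")],
)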
svc_infra/api/fastapi/db/sql/users.py
@@ -12,6 +12,7 @@ from svc_infra.api.fastapi.auth.settings import get_auth_settings
 from svc_infra.api.fastapi.dual.dualize import dualize_public, dualize_user
 from svc_infra.api.fastapi.dual.router import DualAPIRouter
 from svc_infra.app.env import CURRENT_ENVIRONMENT, DEV_ENV, LOCAL_ENV
+from svc_infra.security.jwt_rotation import RotatingJWTStrategy
 
 from ...auth.security import auth_login_path
 from ...auth.sender import get_sender
@@ -94,7 +95,18 @@ def get_fastapi_users(
         lifetime = getattr(jwt_block, "lifetime_seconds", None) if jwt_block else None
         if not isinstance(lifetime, int) or lifetime <= 0:
             lifetime = 3600
-        return JWTStrategy(secret=secret, lifetime_seconds=lifetime)
+        old = []
+        if jwt_block and getattr(jwt_block, "old_secrets", None):
+            old = [s.get_secret_value() for s in jwt_block.old_secrets or []]
+        audience = "fastapi-users:auth"
+        if old:
+            return RotatingJWTStrategy(
+                secret=secret,
+                lifetime_seconds=lifetime,
+                old_secrets=old,
+                token_audience=audience,
+            )
+        return JWTStrategy(secret=secret, lifetime_seconds=lifetime, token_audience=audience)
 
     bearer_transport = BearerTransport(tokenUrl=auth_login_path)
     auth_backend = AuthenticationBackend(
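The RotatingJWTStrategy itself lives in svc_infra/security/jwt_rotation.py (+53, not shown in this excerpt). Conceptually, a rotating verifier signs with the current secret and, on verification, falls back to the older secrets; a rough, hypothetical sketch of that idea with PyJWT, which is not the package's actual implementation:

import jwt  # PyJWT; illustrative only


def decode_with_rotation(token: str, secret: str, old_secrets: list[str], audience: str) -> dict:
    """Try the current secret first, then each older secret from the rotation window."""
    last_error: Exception | None = None
    for candidate in [secret, *old_secrets]:
        try:
            return jwt.decode(token, candidate, algorithms=["HS256"], audience=audience)
        except jwt.InvalidTokenError as exc:
            last_error = exc
    raise last_error or jwt.InvalidTokenError("no secret matched")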
svc_infra/api/fastapi/dependencies/ratelimit.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import time
+from typing import Callable
+
+from fastapi import HTTPException
+from starlette.requests import Request
+
+from svc_infra.api.fastapi.middleware.ratelimit_store import InMemoryRateLimitStore, RateLimitStore
+from svc_infra.obs.metrics import emit_rate_limited
+
+
+class RateLimiter:
+    def __init__(
+        self,
+        *,
+        limit: int,
+        window: int = 60,
+        key_fn: Callable = lambda r: "global",
+        store: RateLimitStore | None = None,
+    ):
+        self.limit = limit
+        self.window = window
+        self.key_fn = key_fn
+        self.store = store or InMemoryRateLimitStore(limit=limit)
+
+    async def __call__(self, request: Request):
+        key = self.key_fn(request)
+        count, limit, reset = self.store.incr(str(key), self.window)
+        if count > limit:
+            retry = max(0, reset - int(time.time()))
+            try:
+                emit_rate_limited(str(key), limit, retry)
+            except Exception:
+                pass
+            raise HTTPException(
+                status_code=429, detail="Rate limit exceeded", headers={"Retry-After": str(retry)}
+            )
+
+
+__all__ = ["RateLimiter"]
+
+
+def rate_limiter(
+    *,
+    limit: int,
+    window: int = 60,
+    key_fn: Callable = lambda r: "global",
+    store: RateLimitStore | None = None,
+):
+    store_ = store or InMemoryRateLimitStore(limit=limit)
+
+    async def dep(request: Request):
+        key = key_fn(request)
+        count, lim, reset = store_.incr(str(key), window)
+        if count > lim:
+            retry = max(0, reset - int(time.time()))
+            try:
+                emit_rate_limited(str(key), lim, retry)
+            except Exception:
+                pass
+            raise HTTPException(
+                status_code=429, detail="Rate limit exceeded", headers={"Retry-After": str(retry)}
+            )
+
+    return dep
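A usage sketch for the dependency-style limiter above (the route and key function are made up for illustration):

from fastapi import Depends, FastAPI

from svc_infra.api.fastapi.dependencies.ratelimit import rate_limiter

app = FastAPI()

# 10 requests per 60-second window, keyed per client IP, for this route only.
login_limit = rate_limiter(limit=10, window=60, key_fn=lambda r: r.client.host if r.client else "unknown")


@app.post("/auth/login", dependencies=[Depends(login_limit)])
async def login() -> dict:
    return {"ok": True}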
svc_infra/api/fastapi/middleware/ratelimit.py
@@ -3,25 +3,41 @@ import time
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import JSONResponse
 
+from svc_infra.obs.metrics import emit_rate_limited
+
+from .ratelimit_store import InMemoryRateLimitStore, RateLimitStore
+
 
 class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, limit: int = 120, window: int = 60, key_fn=None):
+    def __init__(
+        self,
+        app,
+        limit: int = 120,
+        window: int = 60,
+        key_fn=None,
+        store: RateLimitStore | None = None,
+    ):
         super().__init__(app)
         self.limit, self.window = limit, window
         self.key_fn = key_fn or (lambda r: r.headers.get("X-API-Key") or r.client.host)
-        self.buckets = {}  # replace with Redis in prod
+        self.store = store or InMemoryRateLimitStore(limit=limit)
 
     async def dispatch(self, request, call_next):
         key = self.key_fn(request)
         now = int(time.time())
-        win = now - (now % self.window)
-        bucket = self.buckets.setdefault((key, win), 0)
+        # Increment counter in store
+        count, limit, reset = self.store.incr(str(key), self.window)
+        remaining = max(0, limit - count)
 
-        remaining = self.limit - bucket
-        reset = win + self.window
+        if remaining < 0:  # defensive clamp
+            remaining = 0
 
-        if remaining <= 0:
+        if count > limit:
             retry = max(0, reset - now)
+            try:
+                emit_rate_limited(str(key), limit, retry)
+            except Exception:
+                pass
             return JSONResponse(
                 status_code=429,
                 content={
@@ -31,16 +47,15 @@ class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
                     "code": "RATE_LIMITED",
                 },
                 headers={
-                    "X-RateLimit-Limit": str(self.limit),
+                    "X-RateLimit-Limit": str(limit),
                     "X-RateLimit-Remaining": "0",
                     "X-RateLimit-Reset": str(reset),
                     "Retry-After": str(retry),
                 },
             )
 
-        self.buckets[(key, win)] = bucket + 1
         resp = await call_next(request)
-        resp.headers.setdefault("X-RateLimit-Limit", str(self.limit))
-        resp.headers.setdefault("X-RateLimit-Remaining", str(self.limit - (bucket + 1)))
+        resp.headers.setdefault("X-RateLimit-Limit", str(limit))
+        resp.headers.setdefault("X-RateLimit-Remaining", str(remaining))
         resp.headers.setdefault("X-RateLimit-Reset", str(reset))
         return resp
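Illustrative wiring (not from the diff): the middleware keeps its previous defaults, so existing installs keep working while a shared store can now be injected.

from fastapi import FastAPI

from svc_infra.api.fastapi.middleware.ratelimit import SimpleRateLimitMiddleware

app = FastAPI()
# 120 requests per 60-second window, keyed by X-API-Key or client IP by default.
app.add_middleware(SimpleRateLimitMiddleware, limit=120, window=60)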
svc_infra/api/fastapi/middleware/ratelimit_store.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import time
+from typing import Protocol, Tuple
+
+
+class RateLimitStore(Protocol):
+    def incr(self, key: str, window: int) -> Tuple[int, int, int]:
+        """Increment and return (count, limit, resetEpoch).
+
+        Implementations should manage per-window buckets. The 'limit' is stored configuration.
+        """
+        ...
+
+
+class InMemoryRateLimitStore:
+    def __init__(self, limit: int = 120):
+        self.limit = limit
+        self._buckets: dict[tuple[str, int], int] = {}
+
+    def incr(self, key: str, window: int) -> Tuple[int, int, int]:
+        now = int(time.time())
+        win = now - (now % window)
+        count = self._buckets.get((key, win), 0) + 1
+        self._buckets[(key, win)] = count
+        reset = win + window
+        return count, self.limit, reset
+
+
+__all__ = ["RateLimitStore", "InMemoryRateLimitStore"]
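Because RateLimitStore is a Protocol, a shared backend can be dropped in without touching the middleware or the dependency helpers. A rough sketch of a Redis-backed fixed-window store, assuming the synchronous redis-py client; this is not shipped by svc-infra:

import time

import redis  # assumed external dependency


class RedisRateLimitStore:
    """Fixed-window counter satisfying the RateLimitStore protocol."""

    def __init__(self, client: redis.Redis, limit: int = 120, prefix: str = "rl:"):
        self.client = client
        self.limit = limit
        self.prefix = prefix

    def incr(self, key: str, window: int) -> tuple[int, int, int]:
        now = int(time.time())
        win = now - (now % window)
        bucket = f"{self.prefix}{key}:{win}"
        count = int(self.client.incr(bucket))
        # Expire the bucket at the end of its window so stale keys clean themselves up.
        self.client.expireat(bucket, win + window)
        return count, self.limit, win + window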
svc_infra/api/fastapi/middleware/request_size_limit.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import JSONResponse
+
+from svc_infra.obs.metrics import emit_suspect_payload
+
+
+class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, max_bytes: int = 1_000_000):
+        super().__init__(app)
+        self.max_bytes = max_bytes
+
+    async def dispatch(self, request, call_next):
+        length = request.headers.get("content-length")
+        try:
+            size = int(length) if length is not None else None
+        except Exception:
+            size = None
+        if size is not None and size > self.max_bytes:
+            try:
+                emit_suspect_payload(
+                    getattr(request, "url", None).path if hasattr(request, "url") else None, size
+                )
+            except Exception:
+                pass
+            return JSONResponse(
+                status_code=413,
+                content={
+                    "title": "Payload Too Large",
+                    "status": 413,
+                    "detail": "Request body exceeds allowed size.",
+                    "code": "PAYLOAD_TOO_LARGE",
+                },
+            )
+        return await call_next(request)
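Again for illustration only: the check applies to the declared Content-Length header, and the middleware would typically be installed alongside the rate limiter.

from fastapi import FastAPI

from svc_infra.api.fastapi.middleware.request_size_limit import RequestSizeLimitMiddleware

app = FastAPI()
# Reject bodies whose declared Content-Length exceeds roughly 1 MB.
app.add_middleware(RequestSizeLimitMiddleware, max_bytes=1_000_000)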
svc_infra/api/fastapi/setup.py
@@ -61,7 +61,8 @@ def _setup_cors(app: FastAPI, public_cors_origins: list[str] | str | None = None
     elif isinstance(public_cors_origins, str):
         origins = [o.strip() for o in public_cors_origins.split(",") if o and o.strip()]
     else:
-        fallback = os.getenv("CORS_ALLOW_ORIGINS", "http://localhost:3000")
+        # Strict by default: no CORS unless explicitly configured via env or parameter.
+        fallback = os.getenv("CORS_ALLOW_ORIGINS", "")
        origins = [o.strip() for o in fallback.split(",") if o and o.strip()]
 
     if not origins:
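Since http://localhost:3000 is no longer an implicit default, deployments that relied on it now have to opt in explicitly; a hypothetical example of the env-based configuration read above:

import os

# Must be set before the CORS setup runs; comma-separated list of allowed origins.
os.environ.setdefault("CORS_ALLOW_ORIGINS", "https://app.example.com,http://localhost:3000")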
svc_infra/obs/metrics/__init__.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+"""
+Metrics package public API.
+
+Provides lightweight, overridable hooks for abuse heuristics so callers can
+plug in logging or a metrics backend without a hard dependency.
+"""
+
+from typing import Callable, Optional
+
+# Function variables so applications/tests can replace them at runtime.
+on_rate_limit_exceeded: Callable[[str, int, int], None] | None = None
+"""
+Called when a request is rate-limited.
+Args:
+    key: identifier used for rate limiting (e.g., API key or IP)
+    limit: configured limit for the window
+    retry_after: seconds until next allowed attempt
+"""
+
+on_suspect_payload: Callable[[Optional[str], int], None] | None = None
+"""
+Called when a request exceeds the configured size limit.
+Args:
+    path: request path if available
+    size: reported content-length
+"""
+
+
+def emit_rate_limited(key: str, limit: int, retry_after: int) -> None:
+    if on_rate_limit_exceeded:
+        try:
+            on_rate_limit_exceeded(key, limit, retry_after)
+        except Exception:
+            # Never break request flow on metrics exceptions
+            pass
+
+
+def emit_suspect_payload(path: Optional[str], size: int) -> None:
+    if on_suspect_payload:
+        try:
+            on_suspect_payload(path, size)
+        except Exception:
+            pass
+
+
+__all__ = [
+    "emit_rate_limited",
+    "emit_suspect_payload",
+    "on_rate_limit_exceeded",
+    "on_suspect_payload",
+]
svc_infra/obs/metrics.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+"""
+Lightweight metrics hooks for abuse heuristics. Intentionally minimal to avoid pulling
+full metrics stacks; these are no-ops by default but can be swapped in tests or wired
+to a metrics backend by overriding the functions.
+"""
+
+from typing import Callable, Optional
+
+# Function variables so applications/tests can replace them at runtime.
+on_rate_limit_exceeded: Callable[[str, int, int], None] | None = None
+"""
+Called when a request is rate-limited.
+Args:
+    key: identifier used for rate limiting (e.g., API key or IP)
+    limit: configured limit for the window
+    retry_after: seconds until next allowed attempt
+"""
+
+on_suspect_payload: Callable[[Optional[str], int], None] | None = None
+"""
+Called when a request exceeds the configured size limit.
+Args:
+    path: request path if available
+    size: reported content-length
+"""
+
+
+def emit_rate_limited(key: str, limit: int, retry_after: int) -> None:
+    if on_rate_limit_exceeded:
+        try:
+            on_rate_limit_exceeded(key, limit, retry_after)
+        except Exception:
+            # Never break request flow on metrics exceptions
+            pass
+
+
+def emit_suspect_payload(path: Optional[str], size: int) -> None:
+    if on_suspect_payload:
+        try:
+            on_suspect_payload(path, size)
+        except Exception:
+            pass
+
+
+__all__ = [
+    "emit_rate_limited",
+    "emit_suspect_payload",
+    "on_rate_limit_exceeded",
+    "on_suspect_payload",
+]
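The package __init__ and the sibling metrics module above expose the same hook surface. A sketch of wiring the rate-limit hook to a logger at application startup (the logger name and the wiring point are assumptions):

import logging

from svc_infra.obs import metrics

logger = logging.getLogger("svc_infra.abuse")


def _log_rate_limited(key: str, limit: int, retry_after: int) -> None:
    logger.warning("rate limited key=%s limit=%d retry_after=%ds", key, limit, retry_after)


# Replace the module-level hook; emit_rate_limited() calls it and swallows any exception it raises.
metrics.on_rate_limit_exceeded = _log_rate_limited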
svc_infra/security/audit.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+"""Audit log append & chain verification utilities.
+
+Provides helpers to append a new AuditLog entry maintaining a hash-chain
+integrity model and to verify an existing sequence for tampering.
+
+Design notes:
+- Each event stores prev_hash (previous event's hash or 64 zeros for genesis).
+- Hash = sha256(prev_hash + canonical_json_payload).
+- Verification recomputes expected hash for each event and compares.
+- If a middle event is altered, that event and all subsequent events will
+  fail verification (because their prev_hash links break transitively).
+"""
+
+from datetime import datetime, timezone
+from typing import Any, List, Optional, Sequence, Tuple
+
+try:  # SQLAlchemy may not be present in minimal test context
+    from sqlalchemy import select
+    from sqlalchemy.ext.asyncio import AsyncSession
+except Exception:  # pragma: no cover
+    AsyncSession = Any  # type: ignore
+    select = None  # type: ignore
+
+from svc_infra.security.models import AuditLog, compute_audit_hash
+
+
+async def append_audit_event(
+    db: Any,
+    *,
+    actor_id=None,
+    tenant_id: Optional[str] = None,
+    event_type: str,
+    resource_ref: Optional[str] = None,
+    metadata: dict | None = None,
+    ts: Optional[datetime] = None,
+    prev_event: Optional[AuditLog] = None,
+) -> AuditLog:
+    """Append an audit event returning the persisted row.
+
+    If prev_event is not supplied, it attempts to fetch the latest event for
+    the tenant (or global chain when tenant_id is None).
+    """
+    metadata = metadata or {}
+    ts = ts or datetime.now(timezone.utc)
+
+    prev_hash: Optional[str] = None
+    if prev_event is not None:
+        prev_hash = prev_event.hash
+    elif select is not None and hasattr(db, "execute"):  # attempt DB lookup for previous event
+        try:
+            stmt = (
+                select(AuditLog)
+                .where(AuditLog.tenant_id == tenant_id)
+                .order_by(AuditLog.id.desc())
+                .limit(1)
+            )
+            result = await db.execute(stmt)  # type: ignore[attr-defined]
+            prev = result.scalars().first()
+            if prev:
+                prev_hash = prev.hash
+        except Exception:  # pragma: no cover - defensive for minimal fakes
+            pass
+
+    new_hash = compute_audit_hash(
+        prev_hash,
+        ts=ts,
+        actor_id=actor_id,
+        tenant_id=tenant_id,
+        event_type=event_type,
+        resource_ref=resource_ref,
+        metadata=metadata,
+    )
+
+    row = AuditLog(
+        ts=ts,
+        actor_id=actor_id,
+        tenant_id=tenant_id,
+        event_type=event_type,
+        resource_ref=resource_ref,
+        event_metadata=metadata,
+        prev_hash=prev_hash or "0" * 64,
+        hash=new_hash,
+    )
+    if hasattr(db, "add"):
+        try:
+            db.add(row)  # type: ignore[attr-defined]
+        except Exception:  # pragma: no cover - minimal shim safety
+            pass
+    if hasattr(db, "flush"):
+        try:
+            await db.flush()  # type: ignore[attr-defined]
+        except Exception:  # pragma: no cover
+            pass
+    return row
+
+
+def verify_audit_chain(events: Sequence[AuditLog]) -> Tuple[bool, List[int]]:
+    """Verify a sequence of audit events.
+
+    Returns (ok, broken_indices). If any event's hash doesn't match the recomputed
+    expected hash (based on previous event), its index is recorded. All events are
+    checked so callers can analyze extent of tampering.
+    """
+    broken: List[int] = []
+    prev_hash = "0" * 64
+    for idx, ev in enumerate(events):
+        expected = compute_audit_hash(
+            prev_hash if ev.prev_hash == prev_hash else ev.prev_hash,
+            ts=ev.ts,
+            actor_id=ev.actor_id,
+            tenant_id=ev.tenant_id,
+            event_type=ev.event_type,
+            resource_ref=ev.resource_ref,
+            metadata=ev.event_metadata,
+        )
+        # prev_hash stored should equal previous event hash (or zeros for genesis)
+        if (idx == 0 and ev.prev_hash != "0" * 64) or (
+            idx > 0 and ev.prev_hash != events[idx - 1].hash
+        ):
+            broken.append(idx)
+        if ev.hash != expected:
+            broken.append(idx)
+        prev_hash = ev.hash
+    ok = not broken
+    return ok, sorted(set(broken))
+
+
+__all__ = ["append_audit_event", "verify_audit_chain"]
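A hedged end-to-end sketch of the chain model above, using a minimal in-memory stand-in for the DB session (the FakeDB class is invented for illustration; real callers would pass an AsyncSession):

import asyncio

from svc_infra.security.audit import append_audit_event, verify_audit_chain


class FakeDB:
    """Bare-minimum object accepted by append_audit_event (only .add and .flush)."""

    def __init__(self):
        self.added = []

    def add(self, row):
        self.added.append(row)

    async def flush(self):
        pass


async def main() -> None:
    db = FakeDB()
    first = await append_audit_event(db, event_type="user.login", tenant_id="t1")
    second = await append_audit_event(db, event_type="user.logout", tenant_id="t1", prev_event=first)
    print(verify_audit_chain([first, second]))
    second.event_type = "user.delete"  # simulate tampering with a persisted event
    print(verify_audit_chain([first, second]))  # the altered index is reported and ok becomes False


asyncio.run(main())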
svc_infra/security/audit_service.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from typing import Any, List, Optional, Sequence, Tuple
+
+try:  # optional SQLAlchemy import for environments without SA
+    from sqlalchemy import select
+except Exception:  # pragma: no cover
+    select = None  # type: ignore
+
+from .audit import append_audit_event, verify_audit_chain
+from .models import AuditLog
+
+
+async def append_event(
+    db: Any,
+    *,
+    actor_id=None,
+    tenant_id: Optional[str] = None,
+    event_type: str,
+    resource_ref: Optional[str] = None,
+    metadata: dict | None = None,
+    prev_event: Optional[AuditLog] = None,
+) -> AuditLog:
+    """Append an AuditLog event using the shared append utility.
+
+    If prev_event is not provided, attempts to look up the last event for the tenant.
+    """
+    return await append_audit_event(
+        db,
+        actor_id=actor_id,
+        tenant_id=tenant_id,
+        event_type=event_type,
+        resource_ref=resource_ref,
+        metadata=metadata,
+        prev_event=prev_event,
+    )
+
+
+async def verify_chain_for_tenant(
+    db: Any, *, tenant_id: Optional[str] = None
+) -> Tuple[bool, List[int]]:
+    """Fetch all AuditLog events for a tenant and verify hash-chain integrity.
+
+    Falls back to inspecting an in-memory 'added' list when SQLAlchemy is not available,
+    to simplify unit tests with fake DBs.
+    """
+    events: Sequence[AuditLog] = []
+    if select is not None and hasattr(db, "execute"):
+        try:
+            stmt = select(AuditLog)
+            if tenant_id is not None:
+                stmt = stmt.where(AuditLog.tenant_id == tenant_id)
+            stmt = stmt.order_by(AuditLog.id.asc())
+            result = await db.execute(stmt)  # type: ignore[attr-defined]
+            events = list(result.scalars().all())
+        except Exception:  # pragma: no cover
+            events = []
+    elif hasattr(db, "added"):
+        try:
+            pool = getattr(db, "added")
+            # Preserve insertion order for in-memory fake DBs where primary keys may be None
+            events = [
+                e
+                for e in pool
+                if isinstance(e, AuditLog) and (tenant_id is None or e.tenant_id == tenant_id)
+            ]
+        except Exception:  # pragma: no cover
+            events = []
+
+    return verify_audit_chain(list(events))
+
+
+__all__ = ["append_event", "verify_chain_for_tenant"]
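For completeness, a short illustrative helper around the tenant-scoped verifier; the session argument is assumed to be an async SQLAlchemy session or a fake with an .added list, as the code above supports.

from typing import Any

from svc_infra.security.audit_service import verify_chain_for_tenant


async def assert_tenant_chain_intact(session: Any, tenant_id: str) -> None:
    ok, broken_indices = await verify_chain_for_tenant(session, tenant_id=tenant_id)
    if not ok:
        raise RuntimeError(f"audit chain tampered at indices {broken_indices}")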
svc_infra/security/headers.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+SECURE_DEFAULTS = {
+    "Strict-Transport-Security": "max-age=63072000; includeSubDomains; preload",
+    "X-Content-Type-Options": "nosniff",
+    "X-Frame-Options": "DENY",
+    "Referrer-Policy": "strict-origin-when-cross-origin",
+    "X-XSS-Protection": "0",
+    # CSP kept minimal; allow config override
+    "Content-Security-Policy": "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'self'",
+}
+
+
+class SecurityHeadersMiddleware:
+    def __init__(self, app, overrides: dict[str, str] | None = None):
+        self.app = app
+        self.overrides = overrides or {}
+
+    async def __call__(self, scope, receive, send):
+        if scope.get("type") != "http":
+            await self.app(scope, receive, send)
+            return
+
+        async def _send(message):
+            if message.get("type") == "http.response.start":
+                headers = message.setdefault("headers", [])
+                existing = {k.decode(): v.decode() for k, v in headers}
+                merged = {**SECURE_DEFAULTS, **existing, **self.overrides}
+                # rebuild headers list
+                new_headers = []
+                for k, v in merged.items():
+                    new_headers.append((k.encode(), v.encode()))
+                message["headers"] = new_headers
+            await send(message)
+
+        await self.app(scope, receive, _send)
+
+
+__all__ = ["SecurityHeadersMiddleware", "SECURE_DEFAULTS"]
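Illustrative wiring (not part of the diff): the class is plain ASGI middleware, so it can be added directly to a FastAPI/Starlette app, with per-app overrides for individual headers.

from fastapi import FastAPI

from svc_infra.security.headers import SecurityHeadersMiddleware

app = FastAPI()
app.add_middleware(
    SecurityHeadersMiddleware,
    overrides={"Content-Security-Policy": "default-src 'self'"},
)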