svc-infra 0.1.593__py3-none-any.whl → 0.1.594__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of svc-infra might be problematic. Click here for more details.

Files changed (32) hide show
  1. svc_infra/apf_payments/provider/aiydan.py +28 -2
  2. svc_infra/apf_payments/service.py +113 -20
  3. svc_infra/api/fastapi/apf_payments/router.py +3 -1
  4. svc_infra/api/fastapi/auth/add.py +10 -0
  5. svc_infra/api/fastapi/auth/gaurd.py +67 -5
  6. svc_infra/api/fastapi/auth/routers/oauth_router.py +76 -36
  7. svc_infra/api/fastapi/auth/routers/session_router.py +63 -0
  8. svc_infra/api/fastapi/auth/settings.py +2 -0
  9. svc_infra/api/fastapi/db/sql/users.py +13 -1
  10. svc_infra/api/fastapi/dependencies/ratelimit.py +66 -0
  11. svc_infra/api/fastapi/middleware/ratelimit.py +26 -11
  12. svc_infra/api/fastapi/middleware/ratelimit_store.py +30 -0
  13. svc_infra/api/fastapi/middleware/request_size_limit.py +36 -0
  14. svc_infra/api/fastapi/setup.py +2 -1
  15. svc_infra/obs/metrics/__init__.py +53 -0
  16. svc_infra/obs/metrics.py +52 -0
  17. svc_infra/security/audit.py +130 -0
  18. svc_infra/security/audit_service.py +73 -0
  19. svc_infra/security/headers.py +39 -0
  20. svc_infra/security/hibp.py +91 -0
  21. svc_infra/security/jwt_rotation.py +53 -0
  22. svc_infra/security/lockout.py +96 -0
  23. svc_infra/security/models.py +245 -0
  24. svc_infra/security/org_invites.py +128 -0
  25. svc_infra/security/passwords.py +77 -0
  26. svc_infra/security/permissions.py +148 -0
  27. svc_infra/security/session.py +89 -0
  28. svc_infra/security/signed_cookies.py +80 -0
  29. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/METADATA +1 -1
  30. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/RECORD +32 -15
  31. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/WHEEL +0 -0
  32. {svc_infra-0.1.593.dist-info → svc_infra-0.1.594.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Dict, Optional
7
+
8
+ import httpx
9
+
10
+
11
def sha1_hex(data: str) -> str:
    """Return the upper-case hex SHA-1 digest of *data* (UTF-8 encoded).

    SHA-1 is used only to build the HaveIBeenPwned k-anonymity lookup key, not
    as a security primitive. Passing ``usedforsecurity=False`` keeps it usable
    on FIPS-restricted OpenSSL builds, where plain ``hashlib.sha1`` may raise.
    """
    return hashlib.sha1(data.encode("utf-8"), usedforsecurity=False).hexdigest().upper()
13
+
14
+
15
@dataclass
class CacheEntry:
    """One cached range-API response and the time it stops being valid."""

    body: str  # raw response text as returned by the range endpoint
    expires_at: float  # absolute expiry, in time.time() epoch seconds
19
+
20
+
21
class HIBPClient:
    """Minimal HaveIBeenPwned range API client with a bounded in-memory cache.

    - Uses the k-anonymity range query: only the first 5 chars of the SHA-1
      hash are sent; the response lists the matching hash suffixes.
    - Caches prefix responses for ``ttl_seconds`` to avoid repeated network
      calls. At most ``max_cache_entries`` prefixes are kept. Stale entries
      are dropped on access, and the oldest entries are evicted when the
      cache is full.
    - Synchronous implementation to allow use in sync validators.
    - Owns an ``httpx.Client``. Call :meth:`close` (or use the client as a
      context manager) to release its connection pool.
    """

    def __init__(
        self,
        *,
        base_url: str = "https://api.pwnedpasswords.com",
        ttl_seconds: int = 3600,
        timeout: float = 5.0,
        user_agent: str = "svc-infra/hibp",
        max_cache_entries: int = 10_000,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.ttl_seconds = ttl_seconds
        self.timeout = timeout
        self.user_agent = user_agent
        self.max_cache_entries = max_cache_entries
        self._cache: Dict[str, CacheEntry] = {}
        self._http = httpx.Client(timeout=self.timeout, headers={"User-Agent": self.user_agent})

    # -- lifecycle -------------------------------------------------------------

    def close(self) -> None:
        """Close the underlying HTTP client and its connection pool."""
        self._http.close()

    def __enter__(self) -> HIBPClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # -- cache -----------------------------------------------------------------

    def _get_cached(self, prefix: str) -> Optional[str]:
        """Return the cached body for *prefix*, or None if missing or expired."""
        ent = self._cache.get(prefix)
        if ent is None:
            return None
        if ent.expires_at > time.time():
            return ent.body
        # Stale entry: drop it so expired bodies do not accumulate in memory.
        self._cache.pop(prefix, None)
        return None

    def _set_cache(self, prefix: str, body: str) -> None:
        """Store *body* for *prefix*, evicting entries to respect the size cap."""
        now = time.time()
        # Remove first so the re-inserted key moves to the newest position.
        self._cache.pop(prefix, None)
        # Every entry gets the same TTL, so dict insertion order is also expiry
        # order (assuming ttl_seconds is not changed at runtime). The first key
        # is therefore the next to expire. Drop stale entries from the front,
        # then the oldest live ones, until there is room for one more.
        while self._cache:
            oldest_key = next(iter(self._cache))
            if (
                self._cache[oldest_key].expires_at > now
                and len(self._cache) < self.max_cache_entries
            ):
                break
            del self._cache[oldest_key]
        self._cache[prefix] = CacheEntry(body=body, expires_at=now + self.ttl_seconds)

    # -- API -------------------------------------------------------------------

    def range_query(self, prefix: str) -> str:
        """Return the range response for a 5-char hash *prefix* (cached).

        Raises whatever ``httpx`` raises on transport errors or non-2xx
        responses (via ``raise_for_status``).
        """
        # Normalise so "abcde" and "ABCDE" share one cache entry.
        prefix = prefix.upper()
        cached = self._get_cached(prefix)
        if cached is not None:
            return cached
        url = f"{self.base_url}/range/{prefix}"
        resp = self._http.get(url)
        resp.raise_for_status()
        body = resp.text
        self._set_cache(prefix, body)
        return body

    def is_breached(self, password: str) -> bool:
        """Return True if *password* appears in the HIBP corpus with count > 0.

        Fails open: any error while querying the API yields False, so an HIBP
        outage never blocks users.
        """
        full = sha1_hex(password)
        prefix, suffix = full[:5], full[5:]
        try:
            body = self.range_query(prefix)
        except Exception:
            # Fail-open: if HIBP unavailable, do not block users.
            return False

        for line in body.splitlines():
            # Lines formatted as "SUFFIX:COUNT"
            if not line:
                continue
            parts = line.split(":")
            if len(parts) != 2:
                continue
            sfx = parts[0].strip().upper()
            if sfx == suffix:
                # Count > 0 implies breached. Padded entries carry count 0.
                try:
                    return int(parts[1].strip()) > 0
                except ValueError:
                    # Unparseable count for a matching suffix: err on the safe side.
                    return True
        return False
89
+
90
+
91
# Public API of the HIBP helper module.
__all__ = [
    "HIBPClient",
    "sha1_hex",
]
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable, List, Optional, Union
4
+
5
+ import jwt as pyjwt
6
+ from fastapi_users.authentication.strategy.jwt import JWTStrategy
7
+
8
+
9
class RotatingJWTStrategy(JWTStrategy):
    """JWTStrategy that can verify tokens against multiple secrets.

    Signing uses the primary secret (as in the base class). Verification accepts
    any of the configured secrets, tried in order: ``[primary] + old_secrets``.

    ``read_token`` supports two call shapes:

    * ``read_token(token, audience=None)`` returns the decoded claims ``dict``
      and raises ``ValueError`` when no configured secret validates the token.
    * ``read_token(token, user_manager)`` is the shape fastapi-users'
      authenticator uses. It returns the user, or ``None`` for a missing,
      invalid or unknown-user token, mirroring the base class contract.
    """

    def __init__(
        self,
        *,
        secret: str,
        lifetime_seconds: int,
        old_secrets: Optional[Iterable[str]] = None,
        token_audience: Optional[Union[str, List[str]]] = None,
    ):
        super().__init__(
            secret=secret, lifetime_seconds=lifetime_seconds, token_audience=token_audience
        )
        # Verification order: primary secret first, then retired secrets.
        self._verify_secrets: List[str] = [secret] + list(old_secrets or [])

    def _decode_with_any_secret(self, token, audience):
        """Decode *token* with the first configured secret that validates it.

        Raises:
            ValueError: if no configured secret validates the token.
        """
        # The base class stores the signing algorithm. Fall back to HS256 (the
        # value this class previously hard-coded) if the attribute is absent.
        algorithm = getattr(self, "algorithm", "HS256")
        last_error = None
        for secret in self._verify_secrets:
            try:
                return pyjwt.decode(token, secret, algorithms=[algorithm], audience=audience)
            except pyjwt.PyJWTError as exc:
                last_error = exc
        # If none of the secrets validated the token, raise a generic error
        raise ValueError("Invalid token for all configured secrets") from last_error

    async def _read_user(self, token: Optional[str], user_manager):
        """fastapi-users path: decode with any secret, then load the user."""
        from fastapi_users import exceptions as fu_exceptions

        if token is None:
            return None
        try:
            data = self._decode_with_any_secret(token, self.token_audience)
        except ValueError:
            return None
        user_id = data.get("sub")
        if user_id is None:
            return None
        try:
            return await user_manager.get(user_manager.parse_id(user_id))
        except (fu_exceptions.UserNotExists, fu_exceptions.InvalidID):
            return None

    async def read_token(self, token: Optional[str], audience: object = None):  # type: ignore[override]
        # fastapi-users calls read_token(token, user_manager) positionally.
        # Detect a user manager by duck typing and resolve the user.
        if audience is not None and hasattr(audience, "parse_id") and hasattr(audience, "get"):
            return await self._read_user(token, audience)
        # Claims path. Do NOT delegate to super().read_token: its signature is
        # (token, user_manager), so calling it with audience= always fails.
        eff_aud = audience or self.token_audience
        return self._decode_with_any_secret(token, eff_aud)
51
+
52
+
53
# Public API of the JWT secret-rotation module.
__all__ = [
    "RotatingJWTStrategy",
]
@@ -0,0 +1,96 @@
1
+ from __future__ import annotations
2
+
3
+ import uuid
4
+ from dataclasses import dataclass
5
+ from datetime import datetime, timedelta, timezone
6
+ from typing import Any, Optional, Sequence
7
+
8
+ try:
9
+ from sqlalchemy import select
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+ except Exception: # pragma: no cover - optional import for type hints
12
+ AsyncSession = Any # type: ignore[misc]
13
+ select = None # type: ignore
14
+
15
+ from svc_infra.security.models import FailedAuthAttempt
16
+
17
+
18
@dataclass
class LockoutConfig:
    """Tunables for exponential-backoff login lockout (see compute_lockout)."""

    threshold: int = 5  # failures before cooldown starts
    window_minutes: int = 15  # look-back window for counting failures
    base_cooldown_seconds: int = 30  # initial cooldown once threshold reached
    max_cooldown_seconds: int = 3600  # cap exponential growth at 1 hour
24
+
25
+
26
@dataclass
class LockoutStatus:
    """Result of a lockout evaluation."""

    locked: bool  # True when attempts should currently be rejected
    next_allowed_at: Optional[datetime]  # end of the cooldown; None when not locked
    failure_count: int  # number of failures that were counted
31
+
32
+
33
+ # ---------------- Pure calculation -----------------
34
+
35
+
36
def compute_lockout(
    fail_count: int, *, cfg: LockoutConfig, now: Optional[datetime] = None
) -> LockoutStatus:
    """Pure lockout calculation from a failure count.

    Below ``cfg.threshold`` failures the result is unlocked. Otherwise the
    cooldown is ``base_cooldown_seconds * 2 ** (fail_count - threshold)``,
    capped at ``max_cooldown_seconds``, and ``next_allowed_at`` is
    ``now + cooldown``. ``now`` defaults to the current UTC time.
    """
    reference = datetime.now(timezone.utc) if now is None else now
    if fail_count < cfg.threshold:
        return LockoutStatus(False, None, fail_count)
    # Each failure beyond the threshold doubles the cooldown.
    overshoot = fail_count - cfg.threshold
    cooldown_seconds = min(
        cfg.base_cooldown_seconds * (2**overshoot), cfg.max_cooldown_seconds
    )
    unlock_at = reference + timedelta(seconds=cooldown_seconds)
    return LockoutStatus(True, unlock_at, fail_count)
48
+
49
+
50
+ # ---------------- Persistence helpers (async) ---------------
51
+
52
+
53
async def record_attempt(
    session: AsyncSession,
    *,
    user_id: Optional[uuid.UUID],
    ip_hash: Optional[str],
    success: bool,
) -> None:
    """Add one authentication attempt row to *session* and flush it.

    Only flushes; committing the transaction is left to the caller.
    """
    session.add(FailedAuthAttempt(user_id=user_id, ip_hash=ip_hash, success=success))
    await session.flush()
63
+
64
+
65
async def get_lockout_status(
    session: AsyncSession,
    *,
    user_id: Optional[uuid.UUID],
    ip_hash: Optional[str],
    cfg: Optional[LockoutConfig] = None,
) -> LockoutStatus:
    """Return the lockout status for a user and/or client IP hash.

    Counts failed attempts (``success == False``) newer than
    ``cfg.window_minutes``. The count is filtered by ``user_id`` and/or
    ``ip_hash`` when provided; the filters are combined with AND. NOTE: with
    neither filter, every failure in the window is counted.

    The cooldown from ``compute_lockout`` is measured from the most recent
    failure. The lock therefore lifts once ``next_allowed_at`` has passed.
    """
    # Local import: sqlalchemy is an optional import at module level.
    from sqlalchemy import func

    cfg = cfg or LockoutConfig()
    now = datetime.now(timezone.utc)
    window_start = now - timedelta(minutes=cfg.window_minutes)

    # Aggregate in SQL rather than loading every row just to count it.
    q = select(func.count(FailedAuthAttempt.id), func.max(FailedAuthAttempt.ts)).where(
        FailedAuthAttempt.ts >= window_start,
        FailedAuthAttempt.success == False,  # noqa: E712
    )
    if user_id:
        q = q.where(FailedAuthAttempt.user_id == user_id)
    if ip_hash:
        q = q.where(FailedAuthAttempt.ip_hash == ip_hash)

    count, last_failure_at = (await session.execute(q)).one()
    fail_count = int(count or 0)
    if last_failure_at is None:
        last_failure_at = now
    elif last_failure_at.tzinfo is None:
        # Some backends (e.g. SQLite) return naive datetimes even for
        # DateTime(timezone=True). Attempt timestamps are written in UTC.
        last_failure_at = last_failure_at.replace(tzinfo=timezone.utc)

    # Anchor the cooldown at the latest failure, not at "now". Otherwise
    # next_allowed_at keeps moving and the lock never expires early.
    status = compute_lockout(fail_count, cfg=cfg, now=last_failure_at)
    if status.next_allowed_at is not None and status.next_allowed_at <= now:
        # The cooldown since the latest failure has already elapsed.
        return LockoutStatus(False, None, fail_count)
    return status
88
+
89
+
90
# Public API: config/status types, the pure calculator and async DB helpers.
__all__ = ["LockoutConfig", "LockoutStatus", "compute_lockout", "record_attempt", "get_lockout_status"]
@@ -0,0 +1,245 @@
1
+ from __future__ import annotations
2
+
3
import hashlib
import json
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import (
    JSON,
    Boolean,
    DateTime,
    ForeignKey,
    Index,
    String,
    Text,
    UniqueConstraint,
    text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship

from svc_infra.db.sql.base import ModelBase
from svc_infra.db.sql.types import GUID
14
+
15
+ # ----------------------------- Models -----------------------------------------
16
+
17
+
18
class AuthSession(ModelBase):
    """A user's authenticated session; owns the refresh tokens issued in it."""

    __tablename__ = "auth_sessions"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    # Sessions are removed together with their user (ON DELETE CASCADE).
    user_id: Mapped[uuid.UUID] = mapped_column(
        GUID(), ForeignKey("users.id", ondelete="CASCADE"), index=True
    )
    tenant_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    user_agent: Mapped[Optional[str]] = mapped_column(String(512))
    # Presumably a hash of the client IP rather than the raw address; verify at the writer.
    ip_hash: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    last_seen_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    revoke_reason: Mapped[Optional[str]] = mapped_column(Text)

    # Loaded with SELECT IN; tokens are deleted with (or when orphaned from) the session.
    refresh_tokens: Mapped[list["RefreshToken"]] = relationship(
        back_populates="session", cascade="all, delete-orphan", lazy="selectin"
    )

    # text() renders the SQL function. A plain string server_default would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP'.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
    )
39
+
40
+
41
class RefreshToken(ModelBase):
    """A refresh token belonging to an AuthSession; only its hash is stored."""

    __tablename__ = "refresh_tokens"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    session_id: Mapped[uuid.UUID] = mapped_column(
        GUID(), ForeignKey("auth_sessions.id", ondelete="CASCADE"), index=True
    )
    session: Mapped[AuthSession] = relationship(back_populates="refresh_tokens")

    # SHA-256 hex digest (64 chars) as produced by hash_refresh_token.
    token_hash: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    rotated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    revoke_reason: Mapped[Optional[str]] = mapped_column(Text)
    expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), index=True)

    # text() renders the SQL function. A plain string server_default would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP'.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
    )

    __table_args__ = (UniqueConstraint("token_hash", name="uq_refresh_token_hash"),)
61
+
62
+
63
class RefreshTokenRevocation(ModelBase):
    """A record of a revoked refresh token, keyed by the token's hash."""

    __tablename__ = "refresh_token_revocations"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    # Hash of the revoked token (64 chars; not unique, so repeats are allowed).
    token_hash: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    revoked_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    reason: Mapped[Optional[str]] = mapped_column(Text)
70
+
71
+
72
class FailedAuthAttempt(ModelBase):
    """An authentication attempt row used for lockout accounting.

    Despite the name, a row can record a successful attempt (``success=True``).
    ``success`` defaults to False.
    """

    __tablename__ = "failed_auth_attempts"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    # Nullable: attempts against unknown accounts can be recorded by IP only.
    user_id: Mapped[Optional[uuid.UUID]] = mapped_column(
        GUID(), ForeignKey("users.id", ondelete="CASCADE"), index=True, nullable=True
    )
    ip_hash: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    # Client-side default: current UTC time at insert.
    ts: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)
    )
    success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    # Composite index supporting per-user, time-windowed lookups.
    __table_args__ = (Index("ix_failed_attempt_user_time", "user_id", "ts"),)
86
+
87
+
88
class RolePermission(ModelBase):
    """Grants one permission string to a role; (role, permission) is the primary key."""

    __tablename__ = "role_permissions"

    role: Mapped[str] = mapped_column(String(64), primary_key=True)
    permission: Mapped[str] = mapped_column(String(128), primary_key=True)
93
+
94
+
95
class AuditLog(ModelBase):
    """Append-style audit event with a hash chain (prev_hash -> hash).

    Presumably ``hash`` is produced by compute_audit_hash(prev_hash, ...);
    confirm at the writer.
    """

    __tablename__ = "audit_logs"

    # Integer autoincrement id: gives a total order within the chain index below.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    ts: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
        index=True,
    )
    # Kept (as NULL) when the acting user is deleted.
    actor_id: Mapped[Optional[uuid.UUID]] = mapped_column(
        GUID(), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True
    )
    tenant_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    event_type: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
    resource_ref: Mapped[Optional[str]] = mapped_column(String(255), index=True)
    # Python attribute is event_metadata; the DB column is named "metadata".
    event_metadata: Mapped[dict] = mapped_column("metadata", JSON, default=dict)
    prev_hash: Mapped[Optional[str]] = mapped_column(String(64))
    hash: Mapped[str] = mapped_column(String(64), nullable=False, index=True)

    # Per-tenant ordering index for walking the chain.
    __table_args__ = (Index("ix_audit_chain", "tenant_id", "id"),)
116
+
117
+
118
+ # ------------------------ Org / Teams ----------------------------------------
119
+
120
+
121
class Organization(ModelBase):
    """An organization (top-level grouping of teams and members)."""

    __tablename__ = "organizations"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    slug: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    tenant_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    # text() renders the SQL function. A plain string server_default would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP'.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
    )
131
+
132
+
133
class Team(ModelBase):
    """A team inside an organization; deleted with its organization."""

    __tablename__ = "teams"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    org_id: Mapped[uuid.UUID] = mapped_column(
        GUID(), ForeignKey("organizations.id", ondelete="CASCADE"), index=True
    )
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    # text() renders the SQL function. A plain string server_default would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP'.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
    )
144
+
145
+
146
class OrganizationMembership(ModelBase):
    """A user's role in an organization; at most one row per (org, user)."""

    __tablename__ = "organization_memberships"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    org_id: Mapped[uuid.UUID] = mapped_column(
        GUID(), ForeignKey("organizations.id", ondelete="CASCADE"), index=True
    )
    user_id: Mapped[uuid.UUID] = mapped_column(
        GUID(), ForeignKey("users.id", ondelete="CASCADE"), index=True
    )
    role: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    # text() renders the SQL function. A plain string server_default would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP'.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
    )
    deactivated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))

    __table_args__ = (UniqueConstraint("org_id", "user_id", name="uq_org_user_membership"),)
163
+
164
+
165
class OrganizationInvitation(ModelBase):
    """An emailed invitation to join an organization with a given role.

    Only the hash of the invitation token is stored (``token_hash``).
    """

    __tablename__ = "organization_invitations"

    id: Mapped[uuid.UUID] = mapped_column(GUID(), primary_key=True, default=uuid.uuid4)
    org_id: Mapped[uuid.UUID] = mapped_column(
        GUID(), ForeignKey("organizations.id", ondelete="CASCADE"), index=True
    )
    email: Mapped[str] = mapped_column(String(255), index=True)
    role: Mapped[str] = mapped_column(String(64), nullable=False)
    token_hash: Mapped[str] = mapped_column(String(64), index=True)
    expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), index=True)
    # Inviter; kept as NULL if that user is deleted.
    created_by: Mapped[Optional[uuid.UUID]] = mapped_column(
        GUID(), ForeignKey("users.id", ondelete="SET NULL"), index=True
    )
    # text() renders the SQL function. A plain string server_default would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP'.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=text("CURRENT_TIMESTAMP"), nullable=False
    )
    last_sent_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    resend_count: Mapped[int] = mapped_column(default=0)
    used_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    revoked_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
186
+
187
+
188
+ # ------------------------ Utilities -------------------------------------------
189
+
190
+
191
def generate_refresh_token() -> str:
    """Generate a random opaque refresh token: 64 lower-case hex chars (256 bits).

    Uses the ``secrets`` CSPRNG, which is the standard-library API intended for
    security tokens.
    """
    import secrets

    return secrets.token_hex(32)
194
+
195
+
196
def hash_refresh_token(raw: str) -> str:
    """Return the SHA-256 hex digest (64 chars) of *raw*, UTF-8 encoded."""
    digest = hashlib.sha256()
    digest.update(raw.encode("utf-8"))
    return digest.hexdigest()
198
+
199
+
200
def compute_audit_hash(
    prev_hash: Optional[str],
    *,
    ts: datetime,
    actor_id: Optional[uuid.UUID],
    tenant_id: Optional[str],
    event_type: str,
    resource_ref: Optional[str],
    metadata: dict,
) -> str:
    """Chain an audit event onto *prev_hash* and return the SHA-256 hex digest.

    The event is serialised as compact JSON with sorted keys, so the output is
    independent of dict ordering. It is then appended to the previous hash
    (64 zeros for the first link) and hashed.
    """
    event = {}
    event["ts"] = ts.isoformat()
    event["actor_id"] = str(actor_id) if actor_id else None
    event["tenant_id"] = tenant_id
    event["event_type"] = event_type
    event["resource_ref"] = resource_ref
    event["metadata"] = metadata
    canonical = json.dumps(event, sort_keys=True, separators=(",", ":"))
    chain_seed = prev_hash if prev_hash else "0" * 64
    return hashlib.sha256((chain_seed + canonical).encode()).hexdigest()
222
+
223
+
224
def rotate_refresh_token(
    current_hash: str, *, ttl_minutes: int = 10080
) -> tuple[str, str, datetime]:
    """Mint a replacement refresh token.

    Returns ``(new_raw, new_hash, expires_at)``. ``expires_at`` is
    ``ttl_minutes`` from now in UTC; the default of 10080 minutes is 7 days.

    NOTE(review): ``current_hash`` is accepted but not used here. Revoking or
    marking the old token is left to the caller.
    """
    fresh_raw = generate_refresh_token()
    fresh_hash = hash_refresh_token(fresh_raw)
    lifetime = timedelta(minutes=ttl_minutes)
    return fresh_raw, fresh_hash, datetime.now(timezone.utc) + lifetime
232
+
233
+
234
# Public API. The org/team models are included because other security modules
# (e.g. invitation helpers) import them from here.
__all__ = [
    "AuthSession",
    "RefreshToken",
    "RefreshTokenRevocation",
    "FailedAuthAttempt",
    "RolePermission",
    "AuditLog",
    "Organization",
    "Team",
    "OrganizationMembership",
    "OrganizationInvitation",
    "generate_refresh_token",
    "hash_refresh_token",
    "compute_audit_hash",
    "rotate_refresh_token",
]
@@ -0,0 +1,128 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import uuid
5
+ from datetime import datetime, timedelta, timezone
6
+ from typing import Any, Optional
7
+
8
+ try:
9
+ from sqlalchemy import select
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+ except Exception: # pragma: no cover
12
+ AsyncSession = object # type: ignore
13
+ select = None # type: ignore
14
+
15
+ from .models import OrganizationInvitation, OrganizationMembership
16
+
17
+
18
+ def _hash_token(raw: str) -> str:
19
+ return hashlib.sha256(raw.encode()).hexdigest()
20
+
21
+
22
+ def _new_token() -> str:
23
+ return uuid.uuid4().hex + uuid.uuid4().hex
24
+
25
+
26
async def issue_invitation(
    db: Any,
    *,
    org_id: uuid.UUID,
    email: str,
    role: str,
    created_by: Optional[uuid.UUID] = None,
    ttl_hours: int = 72,
) -> tuple[str, OrganizationInvitation]:
    """Create a new invitation; revoke any existing active invites for the same email+org.

    The email is normalised (stripped, lower-cased) both when it is stored and
    when earlier invitations are looked up. Re-inviting ``" Foo@X.com"``
    therefore revokes an active invite previously issued to ``"foo@x.com"``.

    ``db`` may be an async SQLAlchemy-style session (anything with
    ``execute``) or a fake exposing an ``added`` list of objects.

    Returns:
        ``(raw_token, invitation)``. Only the SHA-256 hash of ``raw_token`` is
        stored on the invitation; the raw value is what must reach the invitee.
    """
    normalized_email = email.lower().strip()
    now = datetime.now(timezone.utc)

    # Revoke existing active (unused, unrevoked) invites for this org+email.
    if select is not None and hasattr(db, "execute"):
        try:
            result = await db.execute(
                select(OrganizationInvitation).where(
                    OrganizationInvitation.org_id == org_id,
                    OrganizationInvitation.email == normalized_email,
                    OrganizationInvitation.used_at.is_(None),
                    OrganizationInvitation.revoked_at.is_(None),
                )
            )
            for r in result.scalars().all():
                r.revoked_at = now
        except Exception:  # pragma: no cover
            # Best-effort, as before: a failing lookup does not block issuing.
            # NOTE(review): earlier invites then stay active.
            pass
    elif hasattr(db, "added"):
        # FakeDB path: revoke in-memory invites
        for r in list(db.added):
            if (
                isinstance(r, OrganizationInvitation)
                and r.org_id == org_id
                and r.email == normalized_email
                and r.used_at is None
                and r.revoked_at is None
            ):
                r.revoked_at = now

    raw = _new_token()
    inv = OrganizationInvitation(
        org_id=org_id,
        email=normalized_email,
        role=role,
        token_hash=_hash_token(raw),
        expires_at=now + timedelta(hours=ttl_hours),
        created_by=created_by,
        last_sent_at=now,
        resend_count=0,
    )
    if hasattr(db, "add"):
        db.add(inv)
    if hasattr(db, "flush"):
        await db.flush()
    return raw, inv
88
+
89
+
90
async def resend_invitation(db: Any, *, invitation: OrganizationInvitation) -> str:
    """Replace *invitation*'s token with a new one and return the new raw token.

    Also stamps ``last_sent_at`` (UTC now), increments ``resend_count``
    (treating None as 0) and flushes when ``db`` supports it.

    NOTE(review): ``expires_at`` is not extended and revoked/used state is not
    checked. A resent expired or revoked invite cannot be accepted.
    """
    fresh = _new_token()
    invitation.token_hash = _hash_token(fresh)
    invitation.last_sent_at = datetime.now(timezone.utc)
    previous_count = invitation.resend_count or 0
    invitation.resend_count = previous_count + 1
    if hasattr(db, "flush"):
        await db.flush()
    return fresh
98
+
99
+
100
async def accept_invitation(
    db: Any,
    *,
    invitation: OrganizationInvitation,
    user_id: uuid.UUID,
) -> OrganizationMembership:
    """Consume *invitation* and create an OrganizationMembership for *user_id*.

    Token verification is assumed to have been done by the caller; this
    function only checks the invitation's state.

    Raises:
        ValueError("invitation_unusable"): the invite was revoked or already used.
        ValueError("invitation_expired"): ``expires_at`` is in the past.
    """
    now = datetime.now(timezone.utc)
    if invitation.revoked_at or invitation.used_at:
        raise ValueError("invitation_unusable")

    expires_at = invitation.expires_at
    if expires_at is not None and expires_at.tzinfo is None:
        # Some backends (e.g. SQLite) return naive datetimes even for
        # DateTime(timezone=True). Invitations are written with UTC times, so
        # assume UTC instead of crashing on a naive/aware comparison.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if expires_at is not None and expires_at < now:
        raise ValueError("invitation_expired")

    # mark used
    invitation.used_at = now

    # create membership (upsert-like enforced by DB unique constraint)
    mem = OrganizationMembership(org_id=invitation.org_id, user_id=user_id, role=invitation.role)
    if hasattr(db, "add"):
        db.add(mem)
    if hasattr(db, "flush"):
        await db.flush()
    return mem
122
+
123
+
124
# Public API: issue, resend and accept organization invitations.
__all__ = ["issue_invitation", "resend_invitation", "accept_invitation"]
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import Callable, Iterable, Optional
6
+
7
# Blocklist used when PasswordPolicy.forbid_common is set (matched as substrings
# of the lower-cased password).
COMMON_PASSWORDS = {
    "password",
    "123456",
    "qwerty",
    "letmein",
    "admin",
}

# Global switch for the breached-password check; False means the check is enabled.
HIBP_DISABLED = False
10
+
11
+
12
@dataclass
class PasswordPolicy:
    """Rules applied by validate_password; every check can be toggled."""

    min_length: int = 12  # minimum number of characters
    require_upper: bool = True  # at least one A-Z
    require_lower: bool = True  # at least one a-z
    require_digit: bool = True  # at least one 0-9
    require_symbol: bool = True  # at least one character from the symbol class
    forbid_common: bool = True  # reject passwords containing a COMMON_PASSWORDS entry
    # Consult the configured breached-password checker (skipped when none is
    # configured or HIBP_DISABLED is set).
    forbid_breached: bool = True
    # Symbol character class; same pattern as the module-level SYMBOL regex.
    symbols_regex: str = r"[!@#$%^&*()_+=\-{}\[\]:;,.?/]"
22
+
23
+
24
class PasswordValidationError(Exception):
    """Raised when a password violates the policy; ``reasons`` lists rule codes."""

    def __init__(self, reasons: Iterable[str]):
        message = "Password validation failed"
        super().__init__(message)
        self.reasons = [*reasons]
28
+
29
+
30
# Precompiled character-class checks used by validate_password.
UPPER = re.compile(r"[A-Z]")
LOWER = re.compile(r"[a-z]")
DIGIT = re.compile(r"[0-9]")
SYMBOL = re.compile(r"[!@#$%^&*()_+=\-{}\[\]:;,.?/]")

# A breached-password check: receives the plaintext, returns True if breached.
BreachedChecker = Callable[[str], bool]

# Process-wide checker installed via configure_breached_checker (None = no check).
_breached_checker: Optional[BreachedChecker] = None
40
+
41
+
42
def configure_breached_checker(checker: Optional[BreachedChecker]) -> None:
    """Install *checker* as the process-wide breached-password check; None removes it."""
    globals()["_breached_checker"] = checker
45
+
46
+
47
def validate_password(pw: str, policy: PasswordPolicy | None = None) -> None:
    """Validate *pw* against *policy* (defaults to ``PasswordPolicy()``).

    All rules are evaluated so the caller gets every failure at once.

    Raises:
        PasswordValidationError: ``reasons`` contains any of
            ``min_length(N)``, ``missing_upper``, ``missing_lower``,
            ``missing_digit``, ``missing_symbol``, ``common_password``,
            ``breached_password``.
    """
    policy = policy or PasswordPolicy()
    reasons: list[str] = []
    if len(pw) < policy.min_length:
        reasons.append(f"min_length({policy.min_length})")
    if policy.require_upper and not UPPER.search(pw):
        reasons.append("missing_upper")
    if policy.require_lower and not LOWER.search(pw):
        reasons.append("missing_lower")
    if policy.require_digit and not DIGIT.search(pw):
        reasons.append("missing_digit")
    # Honour the policy's own symbol class. Previously the module-level SYMBOL
    # regex was always used, silently ignoring a customised symbols_regex.
    # (re caches compiled patterns, so this stays cheap.)
    if policy.require_symbol and not re.search(policy.symbols_regex, pw):
        reasons.append("missing_symbol")
    if policy.forbid_common:
        lowered = pw.lower()
        # Substring match also covers an exact match with a common password.
        if any(term in lowered for term in COMMON_PASSWORDS):
            reasons.append("common_password")
    if policy.forbid_breached and not HIBP_DISABLED:
        if _breached_checker and _breached_checker(pw):
            reasons.append("breached_password")
    if reasons:
        raise PasswordValidationError(reasons)
70
+
71
+
72
# Public API of the password-policy module.
__all__ = ["PasswordPolicy", "validate_password", "PasswordValidationError", "configure_breached_checker"]