regstack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. regstack/__init__.py +5 -0
  2. regstack/app.py +150 -0
  3. regstack/auth/__init__.py +21 -0
  4. regstack/auth/clock.py +29 -0
  5. regstack/auth/dependencies.py +102 -0
  6. regstack/auth/jwt.py +145 -0
  7. regstack/auth/lockout.py +59 -0
  8. regstack/auth/mfa.py +29 -0
  9. regstack/auth/password.py +20 -0
  10. regstack/auth/tokens.py +19 -0
  11. regstack/cli/__init__.py +0 -0
  12. regstack/cli/__main__.py +27 -0
  13. regstack/cli/_runtime.py +39 -0
  14. regstack/cli/admin.py +45 -0
  15. regstack/cli/doctor.py +186 -0
  16. regstack/cli/init.py +236 -0
  17. regstack/config/__init__.py +4 -0
  18. regstack/config/loader.py +114 -0
  19. regstack/config/schema.py +148 -0
  20. regstack/config/secrets.py +22 -0
  21. regstack/db/__init__.py +17 -0
  22. regstack/db/client.py +26 -0
  23. regstack/db/indexes.py +70 -0
  24. regstack/db/repositories/__init__.py +0 -0
  25. regstack/db/repositories/blacklist_repo.py +28 -0
  26. regstack/db/repositories/login_attempt_repo.py +27 -0
  27. regstack/db/repositories/mfa_code_repo.py +99 -0
  28. regstack/db/repositories/pending_repo.py +76 -0
  29. regstack/db/repositories/user_repo.py +169 -0
  30. regstack/email/__init__.py +12 -0
  31. regstack/email/base.py +23 -0
  32. regstack/email/composer.py +142 -0
  33. regstack/email/console.py +28 -0
  34. regstack/email/factory.py +23 -0
  35. regstack/email/ses.py +47 -0
  36. regstack/email/smtp.py +46 -0
  37. regstack/email/templates/email_change.html +15 -0
  38. regstack/email/templates/email_change.subject.txt +1 -0
  39. regstack/email/templates/email_change.txt +7 -0
  40. regstack/email/templates/password_reset.html +15 -0
  41. regstack/email/templates/password_reset.subject.txt +1 -0
  42. regstack/email/templates/password_reset.txt +7 -0
  43. regstack/email/templates/sms_login_mfa.txt +1 -0
  44. regstack/email/templates/sms_phone_setup.txt +1 -0
  45. regstack/email/templates/verification.html +15 -0
  46. regstack/email/templates/verification.subject.txt +1 -0
  47. regstack/email/templates/verification.txt +7 -0
  48. regstack/hooks/__init__.py +3 -0
  49. regstack/hooks/events.py +59 -0
  50. regstack/models/__init__.py +15 -0
  51. regstack/models/_objectid.py +30 -0
  52. regstack/models/login_attempt.py +31 -0
  53. regstack/models/mfa_code.py +40 -0
  54. regstack/models/pending_registration.py +38 -0
  55. regstack/models/user.py +104 -0
  56. regstack/routers/__init__.py +37 -0
  57. regstack/routers/_schemas.py +34 -0
  58. regstack/routers/account.py +274 -0
  59. regstack/routers/admin.py +187 -0
  60. regstack/routers/login.py +223 -0
  61. regstack/routers/logout.py +39 -0
  62. regstack/routers/password.py +114 -0
  63. regstack/routers/phone.py +242 -0
  64. regstack/routers/register.py +99 -0
  65. regstack/routers/verify.py +116 -0
  66. regstack/sms/__init__.py +5 -0
  67. regstack/sms/base.py +24 -0
  68. regstack/sms/factory.py +23 -0
  69. regstack/sms/null.py +26 -0
  70. regstack/sms/sns.py +42 -0
  71. regstack/sms/twilio.py +49 -0
  72. regstack/ui/__init__.py +3 -0
  73. regstack/ui/pages.py +148 -0
  74. regstack/ui/static/css/core.css +204 -0
  75. regstack/ui/static/css/theme.css +43 -0
  76. regstack/ui/static/js/regstack.js +411 -0
  77. regstack/ui/templates/auth/email_change_confirm.html +10 -0
  78. regstack/ui/templates/auth/forgot.html +14 -0
  79. regstack/ui/templates/auth/login.html +24 -0
  80. regstack/ui/templates/auth/me.html +110 -0
  81. regstack/ui/templates/auth/mfa_confirm.html +14 -0
  82. regstack/ui/templates/auth/register.html +23 -0
  83. regstack/ui/templates/auth/reset.html +13 -0
  84. regstack/ui/templates/auth/verify.html +10 -0
  85. regstack/ui/templates/base.html +46 -0
  86. regstack/version.py +1 -0
  87. regstack-0.1.0.dist-info/METADATA +209 -0
  88. regstack-0.1.0.dist-info/RECORD +92 -0
  89. regstack-0.1.0.dist-info/WHEEL +4 -0
  90. regstack-0.1.0.dist-info/entry_points.txt +2 -0
  91. regstack-0.1.0.dist-info/licenses/LICENSE +202 -0
  92. regstack-0.1.0.dist-info/licenses/NOTICE +5 -0
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Annotated, Literal
5
+
6
+ from pydantic import AnyHttpUrl, BaseModel, EmailStr, Field, SecretStr, field_validator
7
+ from pydantic_settings import BaseSettings, SettingsConfigDict
8
+
9
# Allowed values for ``EmailConfig.backend``.
EmailBackend = Literal[
    "console",
    "smtp",
    "ses",
]
# Allowed values for ``SmsConfig.backend``.
SmsBackend = Literal[
    "null",
    "sns",
    "twilio",
]
# HMAC-SHA2 JWT signing algorithms accepted by ``RegStackConfig.jwt_algorithm``.
JwtAlgorithm = Literal[
    "HS256",
    "HS384",
    "HS512",
]
# How tokens travel between client and server (``RegStackConfig.transport``).
TokenTransport = Literal[
    "bearer",
    "cookie",
]
13
+
14
+
15
class EmailConfig(BaseModel):
    """Outbound email settings.

    ``backend`` selects the transport ("console", "smtp" or "ses"). The
    ``smtp_*`` and ``ses_*`` groups are presumably consulted only by the
    matching backend (TODO confirm in ``regstack.email.factory``).
    """

    backend: EmailBackend = "console"
    # Sender identity used on outgoing mail.
    from_address: EmailStr = "noreply@example.com"
    from_name: str = "RegStack"

    # SMTP transport.
    smtp_host: str | None = None
    # Bounded to the valid TCP port range (like the other numeric settings in
    # RegStackConfig) so an out-of-range port is rejected when the config is
    # loaded. 587 is the standard mail-submission port.
    smtp_port: Annotated[int, Field(ge=1, le=65535)] = 587
    smtp_starttls: bool = True
    smtp_username: str | None = None
    # SecretStr keeps the password masked in reprs and logs.
    smtp_password: SecretStr | None = None

    # Amazon SES transport.
    ses_region: str = "eu-west-1"
    ses_profile: str | None = None
28
+
29
+
30
class SmsConfig(BaseModel):
    """Outbound SMS settings.

    ``backend`` selects the provider ("null", "sns" or "twilio"); the
    ``sns_*`` and ``twilio_*`` fields are presumably read only by the matching
    backend (TODO confirm in ``regstack.sms.factory``).
    """

    backend: SmsBackend = "null"
    # Sender number / ID; None leaves it unset (how each backend treats None
    # is not visible here — confirm).
    from_number: str | None = None

    # AWS region for the SNS backend.
    sns_region: str = "eu-west-1"

    # Twilio credentials; the auth token is a SecretStr so it stays masked in
    # reprs and logs.
    twilio_account_sid: str | None = None
    twilio_auth_token: SecretStr | None = None
38
+
39
+
40
class RegStackConfig(BaseSettings):
    """Top-level configuration for an embedded regstack instance.

    Loading order (highest priority first):
    1. Programmatic kwargs.
    2. Environment variables (``REGSTACK_*``, nested via ``__``).
    3. TOML file at ``$REGSTACK_CONFIG`` or ``./regstack.toml``.
    4. Field defaults defined here.

    This class declares the fields, their defaults and their bounds;
    :meth:`load` is a convenience wrapper around
    ``regstack.config.loader.load_config``.
    """

    model_config = SettingsConfigDict(
        env_prefix="REGSTACK_",
        # e.g. REGSTACK_EMAIL__BACKEND populates ``email.backend``.
        env_nested_delimiter="__",
        # Unknown keys are ignored rather than rejected.
        extra="ignore",
        populate_by_name=True,
    )

    # Identity / hosting
    app_name: str = "RegStack"
    base_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000")
    cookie_domain: str | None = None
    behind_proxy: bool = False

    # Database
    # The connection URL can embed credentials, hence SecretStr.
    mongodb_url: SecretStr = SecretStr("mongodb://localhost:27017")
    mongodb_database: str = "regstack"
    # Collection names used by the repositories and by install_indexes.
    user_collection: str = "users"
    pending_collection: str = "pending_registrations"
    blacklist_collection: str = "token_blacklist"
    login_attempt_collection: str = "login_attempts"
    mfa_code_collection: str = "mfa_codes"

    # JWT
    # Empty by default; see _warn_empty_secret for why that is accepted here.
    jwt_secret: SecretStr = Field(default_factory=lambda: SecretStr(""))
    jwt_algorithm: JwtAlgorithm = "HS256"
    # Between 1 minute and 30 days; default 2 hours.
    jwt_ttl_seconds: Annotated[int, Field(ge=60, le=60 * 60 * 24 * 30)] = 7200
    jwt_audience: str | None = None
    transport: TokenTransport = "bearer"

    # Verification & password-reset & email-change token lifetimes
    # (seconds, each at least 60).
    verification_token_ttl_seconds: Annotated[int, Field(ge=60)] = 60 * 60 * 24  # 24 h
    password_reset_token_ttl_seconds: Annotated[int, Field(ge=60)] = 60 * 30  # 30 min
    email_change_token_ttl_seconds: Annotated[int, Field(ge=60)] = 60 * 60  # 1 h

    # SMS / 2FA
    sms_code_length: Annotated[int, Field(ge=4, le=10)] = 6
    # 5 min default, bounded to 30 s .. 30 min.
    sms_code_ttl_seconds: Annotated[int, Field(ge=30, le=60 * 30)] = 300
    sms_code_max_attempts: Annotated[int, Field(ge=1, le=20)] = 5
    # 10 min default, bounded to 1 .. 30 min.
    mfa_pending_token_ttl_seconds: Annotated[int, Field(ge=60, le=60 * 30)] = 600

    # Feature flags
    require_verification: bool = True
    allow_registration: bool = True
    enable_password_reset: bool = True
    enable_account_deletion: bool = True
    enable_admin_router: bool = False
    enable_ui_router: bool = False
    enable_sms_2fa: bool = False
    enable_oauth: bool = False  # reserved; no providers ship in v1

    # Login lockout (M2: count failed attempts per email in a sliding window)
    rate_limit_disabled: bool = False
    login_lockout_threshold: Annotated[int, Field(ge=1)] = 5
    # Also used as the TTL of the login-attempt index in regstack.db.indexes.
    login_lockout_window_seconds: Annotated[int, Field(ge=10)] = 900

    # Reserved for future route-level rate limiting (slowapi-style).
    login_max_per_minute: Annotated[int, Field(ge=1)] = 5
    login_max_per_hour: Annotated[int, Field(ge=1)] = 20

    # Sub-configs
    email: EmailConfig = Field(default_factory=EmailConfig)
    sms: SmsConfig = Field(default_factory=SmsConfig)

    # Branding / theming
    brand_logo_url: str | None = None
    brand_tagline: str | None = None
    extra_template_dirs: list[Path] = Field(default_factory=list)
    extra_static_dirs: list[Path] = Field(default_factory=list)

    # SSR ui_router URLs
    api_prefix: str = "/api/auth"
    ui_prefix: str = "/account"
    static_prefix: str = "/regstack-static"
    theme_css_url: str | None = None  # if set, loaded AFTER bundled defaults

    @field_validator("jwt_secret")
    @classmethod
    def _warn_empty_secret(cls, v: SecretStr) -> SecretStr:
        # NOTE(review): despite its name this validator is a pass-through —
        # it neither warns nor rejects; it only marks where the check would go.
        # An empty secret is allowed at construction time so defaults remain
        # usable in tests; production callers should populate it explicitly
        # or via the wizard. Validation that *requires* it lives at the
        # RegStack façade boundary so test fixtures can opt out.
        return v

    @classmethod
    def load(
        cls,
        toml_path: Path | str | None = None,
        secrets_env_path: Path | str | None = None,
        **overrides: object,
    ) -> RegStackConfig:
        """Convenience constructor delegating to ``regstack.config.loader.load_config``.

        Args:
            toml_path: TOML config file to read; forwarded unchanged.
            secrets_env_path: Forwarded unchanged (presumably a dotenv-style
                secrets file — confirm in the loader).
            **overrides: Field values forwarded to ``load_config``.

        Returns:
            Whatever ``load_config`` builds. NOTE(review): ``cls`` is not
            forwarded, so calling ``load`` on a subclass does not necessarily
            produce an instance of that subclass — verify against the loader.
        """
        # Imported lazily; presumably avoids a circular import with the
        # loader module (TODO confirm).
        from regstack.config.loader import load_config

        return load_config(
            toml_path=toml_path,
            secrets_env_path=secrets_env_path,
            **overrides,
        )
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import hmac
5
+ import secrets
6
+
7
+
8
+ def derive_secret(master: str | bytes, purpose: str) -> bytes:
9
+ """Derive a purpose-specific secret from the master JWT secret.
10
+
11
+ Uses HMAC-SHA256 so every subsystem (verification tokens, password reset
12
+ tokens, refresh tokens, etc.) signs with a different key. Compromising one
13
+ derived key does not compromise the master.
14
+ """
15
+ if isinstance(master, str):
16
+ master = master.encode("utf-8")
17
+ return hmac.new(master, purpose.encode("utf-8"), hashlib.sha256).digest()
18
+
19
+
20
def generate_secret(num_bytes: int = 64) -> str:
    """Create a fresh random master key suitable for JWT signing.

    Draws *num_bytes* random bytes from the :mod:`secrets` CSPRNG and returns
    them as URL-safe base64 text (about 1.3 characters per byte).
    """
    secret = secrets.token_urlsafe(num_bytes)
    return secret
@@ -0,0 +1,17 @@
1
+ from regstack.db.indexes import install_indexes
2
+ from regstack.db.repositories.blacklist_repo import BlacklistRepo
3
+ from regstack.db.repositories.login_attempt_repo import LoginAttemptRepo
4
+ from regstack.db.repositories.mfa_code_repo import MfaCodeRepo, MfaVerifyOutcome, MfaVerifyResult
5
+ from regstack.db.repositories.pending_repo import PendingRepo
6
+ from regstack.db.repositories.user_repo import UserRepo
7
+
8
# Public API of ``regstack.db``: the repositories, the MFA verification
# result types and the index installer imported above.
__all__ = [
    "BlacklistRepo",
    "LoginAttemptRepo",
    "MfaCodeRepo",
    "MfaVerifyOutcome",
    "MfaVerifyResult",
    "PendingRepo",
    "UserRepo",
    "install_indexes",
]
regstack/db/client.py ADDED
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from pymongo import AsyncMongoClient
6
+
7
+ if TYPE_CHECKING:
8
+ from pymongo.asynchronous.database import AsyncDatabase
9
+
10
+ from regstack.config.schema import RegStackConfig
11
+
12
+
13
def make_client(config: RegStackConfig) -> AsyncMongoClient:
    """Create the async MongoDB client configured the way regstack expects.

    The connection string is unwrapped from ``config.mongodb_url`` (a
    ``SecretStr``). ``tz_aware=True`` makes BSON datetimes come back as
    UTC-aware values; the JWT and bulk-revocation comparisons assume aware
    datetimes throughout.
    """
    connection_string = config.mongodb_url.get_secret_value()
    return AsyncMongoClient(connection_string, tz_aware=True)
23
+
24
+
25
def get_database(client: AsyncMongoClient, config: RegStackConfig) -> AsyncDatabase:
    """Return the database named by ``config.mongodb_database`` on *client*."""
    database_name = config.mongodb_database
    return client[database_name]
regstack/db/indexes.py ADDED
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ from pymongo import ASCENDING, IndexModel
7
+
8
+ if TYPE_CHECKING:
9
+ from pymongo.asynchronous.database import AsyncDatabase
10
+
11
+ from regstack.config.schema import RegStackConfig
12
+
13
# Module logger; install_indexes reports completion through it.
log = logging.getLogger(__name__)
14
+
15
+
16
async def install_indexes(db: AsyncDatabase, config: RegStackConfig) -> None:
    """Create the indexes regstack relies on. Safe to call repeatedly.

    ``create_indexes`` does nothing for an index that already exists with the
    same name and options. The login-attempt TTL is the one index whose options
    follow configuration (``login_lockout_window_seconds``). MongoDB refuses to
    re-create an index under the same name with different options, so that TTL
    is reconciled in place by :func:`_install_login_attempt_indexes`.

    Args:
        db: Database holding the regstack collections.
        config: Supplies the collection names and the lockout window.
    """
    users = db[config.user_collection]
    await users.create_indexes(
        [IndexModel([("email", ASCENDING)], unique=True, name="email_unique")]
    )

    blacklist = db[config.blacklist_collection]
    # TTL on `exp` lets MongoDB reap revoked tokens when they would have
    # expired anyway. expireAfterSeconds=0 means "delete when the date is
    # in the past" — the value at `exp` is the deletion deadline.
    await blacklist.create_indexes(
        [
            IndexModel([("jti", ASCENDING)], unique=True, name="jti_unique"),
            IndexModel([("exp", ASCENDING)], expireAfterSeconds=0, name="exp_ttl"),
        ]
    )

    pending = db[config.pending_collection]
    await pending.create_indexes(
        [
            IndexModel([("email", ASCENDING)], unique=True, name="pending_email_unique"),
            IndexModel([("token_hash", ASCENDING)], unique=True, name="pending_token_unique"),
            IndexModel([("expires_at", ASCENDING)], expireAfterSeconds=0, name="pending_ttl"),
        ]
    )

    await _install_login_attempt_indexes(db, config)

    mfa = db[config.mfa_code_collection]
    await mfa.create_indexes(
        [
            IndexModel(
                [("user_id", ASCENDING), ("kind", ASCENDING)],
                unique=True,
                name="user_kind_unique",
            ),
            IndexModel([("expires_at", ASCENDING)], expireAfterSeconds=0, name="mfa_ttl"),
        ]
    )

    log.info("regstack indexes installed on database %s", db.name)


async def _install_login_attempt_indexes(db: AsyncDatabase, config: RegStackConfig) -> None:
    """Create the login-attempt indexes, updating the TTL if the window changed.

    Rows expire ``login_lockout_window_seconds`` after ``when``, so changing
    the lockout window also changes how long failed attempts are kept.
    """
    from pymongo.errors import OperationFailure

    # MongoDB "IndexOptionsConflict": an index with this name already exists
    # but with different options (here: a different expireAfterSeconds).
    index_options_conflict = 85

    ttl_seconds = config.login_lockout_window_seconds
    attempts = db[config.login_attempt_collection]
    models = [
        IndexModel([("email", ASCENDING), ("when", ASCENDING)], name="email_when"),
        IndexModel(
            [("when", ASCENDING)],
            expireAfterSeconds=ttl_seconds,
            name="when_ttl",
        ),
    ]
    try:
        await attempts.create_indexes(models)
    except OperationFailure as exc:
        if exc.code != index_options_conflict:
            raise
        log.info(
            "updating when_ttl on %s to expireAfterSeconds=%s",
            config.login_attempt_collection,
            ttl_seconds,
        )
        # collMod is MongoDB's supported way to change a TTL index's
        # expireAfterSeconds without dropping and rebuilding it.
        await db.command(
            {
                "collMod": config.login_attempt_collection,
                "index": {"name": "when_ttl", "expireAfterSeconds": ttl_seconds},
            }
        )
        # Re-run so any index skipped by the failed batch gets created; the
        # TTL index now matches and is a no-op.
        await attempts.create_indexes(models)
File without changes
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ from datetime import datetime
5
+ from typing import TYPE_CHECKING
6
+
7
+ from pymongo.errors import DuplicateKeyError
8
+
9
+ if TYPE_CHECKING:
10
+ from pymongo.asynchronous.database import AsyncDatabase
11
+
12
+
13
class BlacklistRepo:
    """Revocation list for individual tokens, keyed by their ``jti``.

    Each row also stores the token's ``exp``. The TTL index on that field
    lets MongoDB remove the row once the token would have expired on its own.
    """

    def __init__(self, db: AsyncDatabase, collection_name: str) -> None:
        self._collection = db[collection_name]

    async def revoke(self, jti: str, exp: datetime) -> None:
        """Mark *jti* revoked until *exp*; revoking the same jti twice is a no-op."""
        entry = {"jti": jti, "exp": exp}
        try:
            await self._collection.insert_one(entry)
        except DuplicateKeyError:
            # The unique jti index already holds this token — nothing to do.
            return

    async def is_revoked(self, jti: str) -> bool:
        """Return True when a revocation row exists for *jti*."""
        match = await self._collection.find_one({"jti": jti}, projection={"_id": 1})
        return match is not None
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import UTC, datetime, timedelta
4
+ from typing import TYPE_CHECKING
5
+
6
+ from regstack.models.login_attempt import LoginAttempt
7
+
8
+ if TYPE_CHECKING:
9
+ from pymongo.asynchronous.database import AsyncDatabase
10
+
11
+
12
class LoginAttemptRepo:
    """Storage for failed login attempts, queried by the lockout check."""

    def __init__(self, db: AsyncDatabase, collection_name: str) -> None:
        self._collection = db[collection_name]

    async def record_failure(
        self, email: str, *, when: datetime | None = None, ip: str | None = None
    ) -> None:
        """Persist one failed attempt for *email*, timestamped *when* (default: now, UTC)."""
        timestamp = when or datetime.now(UTC)
        row = LoginAttempt(email=email, when=timestamp, ip=ip)
        await self._collection.insert_one(row.to_mongo())

    async def count_recent(self, email: str, *, window: timedelta, now: datetime) -> int:
        """Count attempts for *email* whose ``when`` is at or after ``now - window``."""
        query = {"email": email, "when": {"$gte": now - window}}
        return await self._collection.count_documents(query)

    async def clear(self, email: str) -> None:
        """Remove every recorded attempt for *email*."""
        await self._collection.delete_many({"email": email})
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from datetime import datetime
5
+ from enum import StrEnum
6
+ from typing import TYPE_CHECKING
7
+
8
+ from regstack.auth.tokens import hash_token
9
+ from regstack.models.mfa_code import MfaCode, MfaKind
10
+
11
+ if TYPE_CHECKING:
12
+ from pymongo.asynchronous.database import AsyncDatabase
13
+
14
+ from regstack.auth.clock import Clock
15
+
16
+
17
class MfaVerifyOutcome(StrEnum):
    """Result categories reported by ``MfaCodeRepo.verify``."""

    # The submitted code matched; the stored code is deleted.
    OK = "ok"
    # The code did not match and attempts remain.
    WRONG = "wrong"
    # The stored code's ``expires_at`` has passed; the code is deleted.
    EXPIRED = "expired"
    # The attempt limit was reached; the code is deleted.
    LOCKED = "locked"
    # No outstanding code exists for the ``(user_id, kind)`` pair.
    MISSING = "missing"
23
+
24
+
25
@dataclass(slots=True, frozen=True)
class MfaVerifyResult:
    """Immutable outcome of one ``MfaCodeRepo.verify`` call."""

    outcome: MfaVerifyOutcome
    # Set for WRONG outcomes (how many guesses are left); left at 0 otherwise.
    attempts_remaining: int = 0
29
+
30
+
31
class MfaCodeRepo:
    """Storage for one-time SMS codes, one outstanding code per ``(user_id, kind)``.

    Each document holds the hashed code (``code_hash``), an ``expires_at``
    deadline, an ``attempts`` counter and a ``max_attempts`` cap.
    """

    def __init__(self, db: AsyncDatabase, collection_name: str, *, clock: Clock) -> None:
        self._collection = db[collection_name]
        self._clock = clock

    async def put(self, code: MfaCode) -> None:
        """Upsert by ``(user_id, kind)`` — re-issuing a code overwrites
        any previous outstanding code, so old SMS messages stop working
        as soon as a new one is sent.
        """
        doc = code.to_mongo()
        await self._collection.find_one_and_replace(
            {"user_id": code.user_id, "kind": code.kind},
            doc,
            upsert=True,
        )

    async def verify(
        self,
        *,
        user_id: str,
        kind: MfaKind,
        raw_code: str,
    ) -> MfaVerifyResult:
        """Check *raw_code* against the outstanding code for ``(user_id, kind)``.

        Every call consumes one attempt atomically *before* the comparison.
        With a read-then-write counter, concurrent guesses would all observe
        the same ``attempts`` value and together exceed ``max_attempts``.

        Returns:
            MISSING if no code exists (or another request already redeemed
            it); LOCKED once the attempt cap is hit (the code is deleted);
            EXPIRED past ``expires_at`` (deleted); WRONG with the number of
            attempts left; OK on a match (the code is deleted so it cannot be
            reused).
        """
        import hmac

        # One server-side round trip: increment and read back the new count
        # (return_document=True is pymongo's ReturnDocument.AFTER).
        doc = await self._collection.find_one_and_update(
            {"user_id": user_id, "kind": kind},
            {"$inc": {"attempts": 1}},
            return_document=True,
        )
        if doc is None:
            return MfaVerifyResult(MfaVerifyOutcome.MISSING)

        # `attempts_used` already includes this call.
        attempts_used = int(doc.get("attempts", 1))
        max_attempts = int(doc.get("max_attempts", 5))
        if attempts_used > max_attempts:
            await self._collection.delete_one({"_id": doc["_id"]})
            return MfaVerifyResult(MfaVerifyOutcome.LOCKED)

        if doc["expires_at"] <= self._clock.now():
            await self._collection.delete_one({"_id": doc["_id"]})
            return MfaVerifyResult(MfaVerifyOutcome.EXPIRED)

        # Constant-time comparison of the stored and submitted hashes.
        matches = hmac.compare_digest(
            str(doc["code_hash"]).encode("utf-8"),
            hash_token(raw_code).encode("utf-8"),
        )
        if not matches:
            remaining = max(max_attempts - attempts_used, 0)
            if remaining == 0:
                await self._collection.delete_one({"_id": doc["_id"]})
                return MfaVerifyResult(MfaVerifyOutcome.LOCKED)
            return MfaVerifyResult(MfaVerifyOutcome.WRONG, attempts_remaining=remaining)

        # Consume the code. If a concurrent request already deleted it, do not
        # report success a second time.
        deleted = await self._collection.delete_one({"_id": doc["_id"]})
        if deleted.deleted_count == 0:
            return MfaVerifyResult(MfaVerifyOutcome.MISSING)
        return MfaVerifyResult(MfaVerifyOutcome.OK)

    async def delete(self, *, user_id: str, kind: MfaKind | None = None) -> None:
        """Delete the user's codes, optionally limited to one *kind*."""
        query: dict[str, object] = {"user_id": user_id}
        if kind is not None:
            query["kind"] = kind
        await self._collection.delete_many(query)

    async def find(self, *, user_id: str, kind: MfaKind) -> dict[str, object] | None:
        """Return the raw stored document for ``(user_id, kind)``, or None."""
        return await self._collection.find_one({"user_id": user_id, "kind": kind})

    @staticmethod
    def make_code_hash(raw_code: str) -> str:
        """Hash *raw_code* the same way ``verify`` does."""
        return hash_token(raw_code)
94
+
95
+
96
def now_plus_seconds(clock: Clock, seconds: int) -> datetime:
    """Return ``clock.now()`` shifted forward by *seconds*."""
    from datetime import timedelta

    start = clock.now()
    return start + timedelta(seconds=seconds)
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import UTC, datetime
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from bson import ObjectId
7
+ from pymongo.errors import DuplicateKeyError
8
+
9
+ from regstack.models.pending_registration import PendingRegistration
10
+
11
+ if TYPE_CHECKING:
12
+ from pymongo.asynchronous.database import AsyncDatabase
13
+
14
+
15
class PendingAlreadyExistsError(Exception):
    """A pending registration with this email already exists.

    Raised by ``PendingRepo.create`` when the insert hits a DuplicateKeyError;
    the single exception argument is the email. The pending collection also
    has a unique index on ``token_hash``, so a token-hash collision would
    surface here as well.
    """
17
+
18
+
19
class PendingRepo:
    """Data access for registrations awaiting email verification.

    Rows are unique per ``email`` and per ``token_hash`` and expire via a TTL
    on ``expires_at``.
    """

    def __init__(self, db: AsyncDatabase, collection_name: str) -> None:
        self._collection = db[collection_name]

    async def upsert(self, pending: PendingRegistration) -> PendingRegistration:
        """Create or overwrite the pending registration for ``pending.email``.

        Overwriting on resend makes the newest token the only valid one, so
        links from earlier emails stop working immediately. The stored ``_id``
        is copied back onto *pending*.
        """
        replacement = pending.to_mongo()
        stored = await self._collection.find_one_and_replace(
            {"email": pending.email},
            replacement,
            upsert=True,
            return_document=True,
        )
        if stored is not None and "_id" in stored:
            pending.id = str(stored["_id"])
        return pending

    async def create(self, pending: PendingRegistration) -> PendingRegistration:
        """Insert *pending*; raise PendingAlreadyExistsError on a duplicate key."""
        try:
            outcome = await self._collection.insert_one(pending.to_mongo())
        except DuplicateKeyError as exc:
            raise PendingAlreadyExistsError(pending.email) from exc
        else:
            pending.id = str(outcome.inserted_id)
            return pending

    async def find_by_token_hash(self, token_hash: str) -> PendingRegistration | None:
        """Look a pending registration up by the hash of its verification token."""
        return self._hydrate(await self._collection.find_one({"token_hash": token_hash}))

    async def find_by_email(self, email: str) -> PendingRegistration | None:
        """Look a pending registration up by email."""
        return self._hydrate(await self._collection.find_one({"email": email}))

    async def delete_by_id(self, pending_id: str) -> None:
        """Delete by id; silently ignore ids that are not valid ObjectIds."""
        if ObjectId.is_valid(pending_id):
            await self._collection.delete_one({"_id": ObjectId(pending_id)})

    async def delete_by_email(self, email: str) -> None:
        """Delete the pending registration for *email*, if any."""
        await self._collection.delete_one({"email": email})

    async def purge_expired(self, now: datetime | None = None) -> int:
        """Manual reaper for callers that don't trust the TTL background sweep.

        Deletes rows whose ``expires_at`` is strictly before *now* (default:
        current UTC time) and returns how many were removed.
        """
        cutoff = now if now else datetime.now(UTC)
        outcome = await self._collection.delete_many({"expires_at": {"$lt": cutoff}})
        return int(outcome.deleted_count)

    @staticmethod
    def _hydrate(doc: dict[str, Any] | None) -> PendingRegistration | None:
        """Turn a raw document into a model, stringifying an ObjectId ``_id``."""
        if doc is not None:
            raw_id = doc.get("_id")
            if isinstance(raw_id, ObjectId):
                doc["_id"] = str(raw_id)
            return PendingRegistration.model_validate(doc)
        return None
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from bson import ObjectId
7
+ from pymongo.errors import DuplicateKeyError
8
+
9
+ from regstack.auth.clock import Clock, SystemClock
10
+ from regstack.models.user import BaseUser
11
+
12
+ if TYPE_CHECKING:
13
+ from pymongo.asynchronous.database import AsyncDatabase
14
+
15
+
16
class UserAlreadyExistsError(Exception):
    """Raised when an attempt is made to insert a user with a duplicate email.

    ``UserRepo.create`` and ``UserRepo.update_email`` raise it, chained from
    the driver's DuplicateKeyError, with the conflicting email as its only
    argument.
    """
18
+
19
+
20
+ def _bulk_revoke_cutoff(now: datetime) -> datetime:
21
+ """The cutoff timestamp recorded on the user document. Stored at full
22
+ microsecond precision; the JWT ``iat`` claim is also emitted as a float
23
+ so the ``iat < cutoff`` comparison is exact and a fresh login completing
24
+ even microseconds after a password / email change is recognised as
25
+ later-than-cutoff.
26
+ """
27
+ return now
28
+
29
+
30
class UserRepo:
    """MongoDB data access for user accounts.

    Ids cross this interface as hex strings and become ``ObjectId`` only for
    queries. Every mutating method stamps ``updated_at`` from the injected
    clock (``SystemClock`` when none is supplied).
    """

    def __init__(
        self,
        db: AsyncDatabase,
        collection_name: str,
        *,
        clock: Clock | None = None,
    ) -> None:
        self._collection = db[collection_name]
        self._clock: Clock = clock if clock else SystemClock()

    async def _set_fields(
        self,
        user_id: str,
        fields: dict[str, Any],
        *,
        now: datetime | None = None,
    ) -> None:
        """``$set`` *fields* plus ``updated_at`` on one user.

        ``updated_at`` is *now* when given, so callers can share one timestamp
        across several fields; otherwise the clock is read afresh.
        """
        selector = {"_id": ObjectId(user_id)}
        stamp = self._clock.now() if now is None else now
        await self._collection.update_one(selector, {"$set": {**fields, "updated_at": stamp}})

    async def create(self, user: BaseUser) -> BaseUser:
        """Insert *user*, set its ``id``; raise UserAlreadyExistsError on a duplicate email."""
        payload = user.to_mongo()
        try:
            outcome = await self._collection.insert_one(payload)
        except DuplicateKeyError as exc:
            raise UserAlreadyExistsError(user.email) from exc
        user.id = str(outcome.inserted_id)
        return user

    async def get_by_email(self, email: str) -> BaseUser | None:
        """Fetch a user by email, or None."""
        return self._hydrate(await self._collection.find_one({"email": email}))

    async def get_by_id(self, user_id: str) -> BaseUser | None:
        """Fetch a user by id; None for unknown or malformed ids."""
        if ObjectId.is_valid(user_id):
            return self._hydrate(await self._collection.find_one({"_id": ObjectId(user_id)}))
        return None

    async def set_last_login(self, user_id: str, when: datetime) -> None:
        await self._set_fields(user_id, {"last_login": when})

    async def set_tokens_invalidated_after(self, user_id: str, when: datetime) -> None:
        await self._set_fields(user_id, {"tokens_invalidated_after": _bulk_revoke_cutoff(when)})

    async def update_password(self, user_id: str, hashed_password: str) -> None:
        """Store a new password hash and revoke every token issued before now."""
        now = self._clock.now()
        await self._set_fields(
            user_id,
            {
                "hashed_password": hashed_password,
                "tokens_invalidated_after": _bulk_revoke_cutoff(now),
            },
            now=now,
        )

    async def set_active(self, user_id: str, *, is_active: bool) -> None:
        await self._set_fields(user_id, {"is_active": is_active})

    async def set_superuser(self, user_id: str, *, is_superuser: bool) -> None:
        await self._set_fields(user_id, {"is_superuser": is_superuser})

    async def set_full_name(self, user_id: str, full_name: str | None) -> None:
        await self._set_fields(user_id, {"full_name": full_name})

    async def set_phone(self, user_id: str, phone_number: str | None) -> None:
        await self._set_fields(user_id, {"phone_number": phone_number})

    async def set_mfa_enabled(self, user_id: str, *, is_mfa_enabled: bool) -> None:
        await self._set_fields(user_id, {"is_mfa_enabled": is_mfa_enabled})

    async def update_email(self, user_id: str, new_email: str) -> None:
        """Swap the user's email in one update and revoke all earlier tokens.

        Raises UserAlreadyExistsError if *new_email* belongs to another user.
        """
        now = self._clock.now()
        changes = {
            "email": new_email,
            "tokens_invalidated_after": _bulk_revoke_cutoff(now),
        }
        try:
            await self._set_fields(user_id, changes, now=now)
        except DuplicateKeyError as exc:
            raise UserAlreadyExistsError(new_email) from exc

    async def delete(self, user_id: str) -> bool:
        """Delete a user; True if a document was removed."""
        if not ObjectId.is_valid(user_id):
            return False
        outcome = await self._collection.delete_one({"_id": ObjectId(user_id)})
        return outcome.deleted_count > 0

    async def count(self, *, filter_: dict[str, Any] | None = None) -> int:
        query = filter_ if filter_ else {}
        return await self._collection.count_documents(query)

    async def list_paged(
        self,
        *,
        skip: int = 0,
        limit: int = 50,
        sort: tuple[str, int] = ("created_at", -1),
    ) -> list[BaseUser]:
        """Return one page of users ordered by *sort* (field, direction)."""
        cursor = self._collection.find().sort([sort]).skip(skip).limit(limit)
        hydrated = [self._hydrate(doc) async for doc in cursor]
        return [user for user in hydrated if user is not None]

    @staticmethod
    def _hydrate(doc: dict[str, Any] | None) -> BaseUser | None:
        """Turn a raw document into a BaseUser, stringifying an ObjectId ``_id``."""
        if doc is not None:
            raw_id = doc.get("_id")
            if isinstance(raw_id, ObjectId):
                doc["_id"] = str(raw_id)
            return BaseUser.model_validate(doc)
        return None