regstack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- regstack/__init__.py +5 -0
- regstack/app.py +150 -0
- regstack/auth/__init__.py +21 -0
- regstack/auth/clock.py +29 -0
- regstack/auth/dependencies.py +102 -0
- regstack/auth/jwt.py +145 -0
- regstack/auth/lockout.py +59 -0
- regstack/auth/mfa.py +29 -0
- regstack/auth/password.py +20 -0
- regstack/auth/tokens.py +19 -0
- regstack/cli/__init__.py +0 -0
- regstack/cli/__main__.py +27 -0
- regstack/cli/_runtime.py +39 -0
- regstack/cli/admin.py +45 -0
- regstack/cli/doctor.py +186 -0
- regstack/cli/init.py +236 -0
- regstack/config/__init__.py +4 -0
- regstack/config/loader.py +114 -0
- regstack/config/schema.py +148 -0
- regstack/config/secrets.py +22 -0
- regstack/db/__init__.py +17 -0
- regstack/db/client.py +26 -0
- regstack/db/indexes.py +70 -0
- regstack/db/repositories/__init__.py +0 -0
- regstack/db/repositories/blacklist_repo.py +28 -0
- regstack/db/repositories/login_attempt_repo.py +27 -0
- regstack/db/repositories/mfa_code_repo.py +99 -0
- regstack/db/repositories/pending_repo.py +76 -0
- regstack/db/repositories/user_repo.py +169 -0
- regstack/email/__init__.py +12 -0
- regstack/email/base.py +23 -0
- regstack/email/composer.py +142 -0
- regstack/email/console.py +28 -0
- regstack/email/factory.py +23 -0
- regstack/email/ses.py +47 -0
- regstack/email/smtp.py +46 -0
- regstack/email/templates/email_change.html +15 -0
- regstack/email/templates/email_change.subject.txt +1 -0
- regstack/email/templates/email_change.txt +7 -0
- regstack/email/templates/password_reset.html +15 -0
- regstack/email/templates/password_reset.subject.txt +1 -0
- regstack/email/templates/password_reset.txt +7 -0
- regstack/email/templates/sms_login_mfa.txt +1 -0
- regstack/email/templates/sms_phone_setup.txt +1 -0
- regstack/email/templates/verification.html +15 -0
- regstack/email/templates/verification.subject.txt +1 -0
- regstack/email/templates/verification.txt +7 -0
- regstack/hooks/__init__.py +3 -0
- regstack/hooks/events.py +59 -0
- regstack/models/__init__.py +15 -0
- regstack/models/_objectid.py +30 -0
- regstack/models/login_attempt.py +31 -0
- regstack/models/mfa_code.py +40 -0
- regstack/models/pending_registration.py +38 -0
- regstack/models/user.py +104 -0
- regstack/routers/__init__.py +37 -0
- regstack/routers/_schemas.py +34 -0
- regstack/routers/account.py +274 -0
- regstack/routers/admin.py +187 -0
- regstack/routers/login.py +223 -0
- regstack/routers/logout.py +39 -0
- regstack/routers/password.py +114 -0
- regstack/routers/phone.py +242 -0
- regstack/routers/register.py +99 -0
- regstack/routers/verify.py +116 -0
- regstack/sms/__init__.py +5 -0
- regstack/sms/base.py +24 -0
- regstack/sms/factory.py +23 -0
- regstack/sms/null.py +26 -0
- regstack/sms/sns.py +42 -0
- regstack/sms/twilio.py +49 -0
- regstack/ui/__init__.py +3 -0
- regstack/ui/pages.py +148 -0
- regstack/ui/static/css/core.css +204 -0
- regstack/ui/static/css/theme.css +43 -0
- regstack/ui/static/js/regstack.js +411 -0
- regstack/ui/templates/auth/email_change_confirm.html +10 -0
- regstack/ui/templates/auth/forgot.html +14 -0
- regstack/ui/templates/auth/login.html +24 -0
- regstack/ui/templates/auth/me.html +110 -0
- regstack/ui/templates/auth/mfa_confirm.html +14 -0
- regstack/ui/templates/auth/register.html +23 -0
- regstack/ui/templates/auth/reset.html +13 -0
- regstack/ui/templates/auth/verify.html +10 -0
- regstack/ui/templates/base.html +46 -0
- regstack/version.py +1 -0
- regstack-0.1.0.dist-info/METADATA +209 -0
- regstack-0.1.0.dist-info/RECORD +92 -0
- regstack-0.1.0.dist-info/WHEEL +4 -0
- regstack-0.1.0.dist-info/entry_points.txt +2 -0
- regstack-0.1.0.dist-info/licenses/LICENSE +202 -0
- regstack-0.1.0.dist-info/licenses/NOTICE +5 -0
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Annotated, Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import AnyHttpUrl, BaseModel, EmailStr, Field, SecretStr, field_validator
|
|
7
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
8
|
+
|
|
9
|
+
# Closed sets of accepted string values. Used as field types on the pydantic
# models below, so any value outside the listed literals fails validation.
EmailBackend = Literal["console", "smtp", "ses"]  # EmailConfig.backend choices
SmsBackend = Literal["null", "sns", "twilio"]  # SmsConfig.backend choices
JwtAlgorithm = Literal["HS256", "HS384", "HS512"]  # HMAC-SHA2 JWT algorithms only
# Presumably how access tokens travel to/from the client (Authorization header
# vs. cookie) — TODO confirm against regstack.auth.dependencies.
TokenTransport = Literal["bearer", "cookie"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EmailConfig(BaseModel):
    """Outbound email settings.

    ``backend`` selects the transport. The ``smtp_*`` fields presumably only
    matter for the ``"smtp"`` backend and the ``ses_*`` fields only for
    ``"ses"`` — NOTE(review): confirm against ``regstack.email.factory``.
    """

    backend: EmailBackend = "console"
    from_address: EmailStr = "noreply@example.com"  # validated as an email address
    from_name: str = "RegStack"

    # SMTP
    smtp_host: str | None = None
    smtp_port: int = 587
    smtp_starttls: bool = True
    smtp_username: str | None = None
    # SecretStr keeps the password masked in reprs and logs.
    smtp_password: SecretStr | None = None

    # Amazon SES
    ses_region: str = "eu-west-1"
    ses_profile: str | None = None  # presumably an AWS credentials profile name — confirm
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SmsConfig(BaseModel):
    """Outbound SMS settings.

    ``backend`` selects the transport; ``"null"`` is the default. The
    ``sns_*`` / ``twilio_*`` fields presumably only matter for their
    respective backends — NOTE(review): confirm against ``regstack.sms.factory``.
    """

    backend: SmsBackend = "null"
    from_number: str | None = None  # sender number; expected format depends on the backend

    # Amazon SNS
    sns_region: str = "eu-west-1"

    # Twilio
    twilio_account_sid: str | None = None
    # SecretStr keeps the auth token masked in reprs and logs.
    twilio_auth_token: SecretStr | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RegStackConfig(BaseSettings):
    """Top-level configuration for an embedded regstack instance.

    Loading order (highest priority first):

    1. Programmatic kwargs.
    2. Environment variables (``REGSTACK_*``, nested via ``__``).
    3. TOML file at ``$REGSTACK_CONFIG`` or ``./regstack.toml``.
    4. Field defaults defined here.

    Nested sub-configs are addressed with the ``__`` delimiter, e.g.
    ``REGSTACK_EMAIL__BACKEND=smtp`` sets ``email.backend``. Unknown keys are
    ignored rather than rejected (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(
        env_prefix="REGSTACK_",
        env_nested_delimiter="__",
        extra="ignore",
        populate_by_name=True,
    )

    # Identity / hosting
    app_name: str = "RegStack"
    base_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000")
    cookie_domain: str | None = None
    behind_proxy: bool = False

    # Database
    # SecretStr keeps credentials embedded in the connection URL out of reprs/logs.
    mongodb_url: SecretStr = SecretStr("mongodb://localhost:27017")
    mongodb_database: str = "regstack"
    user_collection: str = "users"
    pending_collection: str = "pending_registrations"
    blacklist_collection: str = "token_blacklist"
    login_attempt_collection: str = "login_attempts"
    mfa_code_collection: str = "mfa_codes"

    # JWT
    # Empty by default so tests can construct a config; see _warn_empty_secret.
    jwt_secret: SecretStr = Field(default_factory=lambda: SecretStr(""))
    jwt_algorithm: JwtAlgorithm = "HS256"
    # Access-token lifetime: 1 minute to 30 days, default 2 hours.
    jwt_ttl_seconds: Annotated[int, Field(ge=60, le=60 * 60 * 24 * 30)] = 7200
    jwt_audience: str | None = None
    transport: TokenTransport = "bearer"

    # Verification & password-reset & email-change token lifetimes
    verification_token_ttl_seconds: Annotated[int, Field(ge=60)] = 60 * 60 * 24  # 1 day
    password_reset_token_ttl_seconds: Annotated[int, Field(ge=60)] = 60 * 30  # 30 minutes
    email_change_token_ttl_seconds: Annotated[int, Field(ge=60)] = 60 * 60  # 1 hour

    # SMS / 2FA
    sms_code_length: Annotated[int, Field(ge=4, le=10)] = 6
    sms_code_ttl_seconds: Annotated[int, Field(ge=30, le=60 * 30)] = 300  # 30 s .. 30 min
    sms_code_max_attempts: Annotated[int, Field(ge=1, le=20)] = 5
    mfa_pending_token_ttl_seconds: Annotated[int, Field(ge=60, le=60 * 30)] = 600  # 1 .. 30 min

    # Feature flags
    require_verification: bool = True
    allow_registration: bool = True
    enable_password_reset: bool = True
    enable_account_deletion: bool = True
    enable_admin_router: bool = False
    enable_ui_router: bool = False
    enable_sms_2fa: bool = False
    enable_oauth: bool = False  # reserved; no providers ship in v1

    # Login lockout (M2: count failed attempts per email in a sliding window)
    # NOTE(review): presumably turns the lockout check off entirely — confirm in auth/lockout.
    rate_limit_disabled: bool = False
    login_lockout_threshold: Annotated[int, Field(ge=1)] = 5
    # Also used as the expireAfterSeconds of the login-attempt TTL index (db/indexes.py).
    login_lockout_window_seconds: Annotated[int, Field(ge=10)] = 900

    # Reserved for future route-level rate limiting (slowapi-style).
    login_max_per_minute: Annotated[int, Field(ge=1)] = 5
    login_max_per_hour: Annotated[int, Field(ge=1)] = 20

    # Sub-configs
    email: EmailConfig = Field(default_factory=EmailConfig)
    sms: SmsConfig = Field(default_factory=SmsConfig)

    # Branding / theming
    brand_logo_url: str | None = None
    brand_tagline: str | None = None
    extra_template_dirs: list[Path] = Field(default_factory=list)
    extra_static_dirs: list[Path] = Field(default_factory=list)

    # SSR ui_router URLs
    api_prefix: str = "/api/auth"
    ui_prefix: str = "/account"
    static_prefix: str = "/regstack-static"
    theme_css_url: str | None = None  # if set, loaded AFTER bundled defaults

    @field_validator("jwt_secret")
    @classmethod
    def _warn_empty_secret(cls, v: SecretStr) -> SecretStr:
        """Pass ``jwt_secret`` through unchanged.

        NOTE(review): despite the name, this validator neither warns nor
        rejects; it is currently a no-op placeholder.
        """
        # An empty secret is allowed at construction time so defaults remain
        # usable in tests; production callers should populate it explicitly
        # or via the wizard. Validation that *requires* it lives at the
        # RegStack façade boundary so test fixtures can opt out.
        return v

    @classmethod
    def load(
        cls,
        toml_path: Path | str | None = None,
        secrets_env_path: Path | str | None = None,
        **overrides: object,
    ) -> RegStackConfig:
        """Convenience constructor delegating to ``regstack.config.loader.load_config``.

        All arguments are forwarded unchanged; ``overrides`` are passed as
        keyword arguments.
        """
        # Imported lazily, presumably to avoid an import cycle with the loader module.
        from regstack.config.loader import load_config

        return load_config(
            toml_path=toml_path,
            secrets_env_path=secrets_env_path,
            **overrides,
        )
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import hmac
|
|
5
|
+
import secrets
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def derive_secret(master: str | bytes, purpose: str) -> bytes:
|
|
9
|
+
"""Derive a purpose-specific secret from the master JWT secret.
|
|
10
|
+
|
|
11
|
+
Uses HMAC-SHA256 so every subsystem (verification tokens, password reset
|
|
12
|
+
tokens, refresh tokens, etc.) signs with a different key. Compromising one
|
|
13
|
+
derived key does not compromise the master.
|
|
14
|
+
"""
|
|
15
|
+
if isinstance(master, str):
|
|
16
|
+
master = master.encode("utf-8")
|
|
17
|
+
return hmac.new(master, purpose.encode("utf-8"), hashlib.sha256).digest()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def generate_secret(num_bytes: int = 64) -> str:
    """Create a fresh random master key for JWT signing.

    ``num_bytes`` bytes come from the OS CSPRNG (via :mod:`secrets`) and are
    rendered as unpadded URL-safe base64 text.
    """
    token = secrets.token_urlsafe(nbytes=num_bytes)
    return token
|
regstack/db/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from regstack.db.indexes import install_indexes
|
|
2
|
+
from regstack.db.repositories.blacklist_repo import BlacklistRepo
|
|
3
|
+
from regstack.db.repositories.login_attempt_repo import LoginAttemptRepo
|
|
4
|
+
from regstack.db.repositories.mfa_code_repo import MfaCodeRepo, MfaVerifyOutcome, MfaVerifyResult
|
|
5
|
+
from regstack.db.repositories.pending_repo import PendingRepo
|
|
6
|
+
from regstack.db.repositories.user_repo import UserRepo
|
|
7
|
+
|
|
8
|
+
# Public re-exports of the db package: the repository classes, the MFA
# verification result types, and the index installer.
__all__ = [
    "BlacklistRepo",
    "LoginAttemptRepo",
    "MfaCodeRepo",
    "MfaVerifyOutcome",
    "MfaVerifyResult",
    "PendingRepo",
    "UserRepo",
    "install_indexes",
]
|
regstack/db/client.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from pymongo import AsyncMongoClient
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
9
|
+
|
|
10
|
+
from regstack.config.schema import RegStackConfig
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def make_client(config: RegStackConfig) -> AsyncMongoClient:
    """Create the AsyncMongoClient regstack talks to.

    The connection string is unwrapped from ``config.mongodb_url`` only at this
    point. The client is built with ``tz_aware=True`` so BSON datetimes come
    back as UTC-aware values; the JWT and bulk-revocation comparisons rely on
    aware datetimes everywhere.
    """
    uri = config.mongodb_url.get_secret_value()
    return AsyncMongoClient(uri, tz_aware=True)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_database(client: AsyncMongoClient, config: RegStackConfig) -> AsyncDatabase:
    """Return the database named by ``config.mongodb_database`` on ``client``."""
    name = config.mongodb_database
    return client[name]
|
regstack/db/indexes.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from pymongo import ASCENDING, IndexModel
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
10
|
+
|
|
11
|
+
from regstack.config.schema import RegStackConfig
|
|
12
|
+
|
|
13
|
+
# Module logger; install_indexes reports completion through it.
log = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def install_indexes(db: AsyncDatabase, config: RegStackConfig) -> None:
    """Create the indexes regstack relies on. Safe to call repeatedly.

    Indexes created (or confirmed):

    * users: unique ``email``.
    * token blacklist: unique ``jti`` and a TTL index on ``exp``.
    * pending registrations: unique ``email``, unique ``token_hash`` and a
      TTL index on ``expires_at``.
    * login attempts: compound ``(email, when)`` and a TTL index on ``when``
      whose lifetime is ``config.login_lockout_window_seconds``.
    * MFA codes: unique ``(user_id, kind)`` and a TTL index on ``expires_at``.

    The login-attempt TTL comes from configuration, so a later run may ask for
    a different lifetime than the one already on the server. MongoDB rejects
    ``createIndexes`` for an existing index with different options
    (IndexOptionsConflict), so that case is handled by updating the existing
    index in place with ``collMod`` instead of failing startup.

    Args:
        db: Database holding the regstack collections.
        config: Supplies the collection names and the lockout window.

    Raises:
        pymongo.errors.OperationFailure: For any index error other than the
            login-attempt TTL options conflict.
    """
    # Local import: pymongo is already a dependency of this module, and this
    # keeps the file-level import block untouched.
    from pymongo.errors import OperationFailure

    # MongoDB server error code "IndexOptionsConflict": an index with the same
    # name/key exists but with different options (e.g. expireAfterSeconds).
    index_options_conflict = 85

    users = db[config.user_collection]
    await users.create_indexes(
        [IndexModel([("email", ASCENDING)], unique=True, name="email_unique")]
    )

    blacklist = db[config.blacklist_collection]
    # TTL on `exp` lets MongoDB reap revoked tokens when they would have
    # expired anyway. expireAfterSeconds=0 means "delete when the date is
    # in the past" — the value at `exp` is the deletion deadline.
    await blacklist.create_indexes(
        [
            IndexModel([("jti", ASCENDING)], unique=True, name="jti_unique"),
            IndexModel([("exp", ASCENDING)], expireAfterSeconds=0, name="exp_ttl"),
        ]
    )

    pending = db[config.pending_collection]
    await pending.create_indexes(
        [
            IndexModel([("email", ASCENDING)], unique=True, name="pending_email_unique"),
            IndexModel([("token_hash", ASCENDING)], unique=True, name="pending_token_unique"),
            IndexModel([("expires_at", ASCENDING)], expireAfterSeconds=0, name="pending_ttl"),
        ]
    )

    attempts = db[config.login_attempt_collection]
    await attempts.create_indexes(
        [IndexModel([("email", ASCENDING), ("when", ASCENDING)], name="email_when")]
    )
    # Rows survive `login_lockout_window_seconds` after `when`. The TTL comes
    # from config, so it is created in its own call: if an earlier deployment
    # used a different window, the conflict is resolved with collMod below
    # without affecting the email_when index.
    ttl_seconds = config.login_lockout_window_seconds
    try:
        await attempts.create_indexes(
            [
                IndexModel(
                    [("when", ASCENDING)],
                    expireAfterSeconds=ttl_seconds,
                    name="when_ttl",
                ),
            ]
        )
    except OperationFailure as exc:
        if exc.code != index_options_conflict:
            raise
        await db.command(
            {
                "collMod": config.login_attempt_collection,
                "index": {"name": "when_ttl", "expireAfterSeconds": ttl_seconds},
            }
        )
        log.info(
            "regstack updated TTL of %s.when_ttl to %s seconds",
            config.login_attempt_collection,
            ttl_seconds,
        )

    mfa = db[config.mfa_code_collection]
    await mfa.create_indexes(
        [
            IndexModel(
                [("user_id", ASCENDING), ("kind", ASCENDING)],
                unique=True,
                name="user_kind_unique",
            ),
            IndexModel([("expires_at", ASCENDING)], expireAfterSeconds=0, name="mfa_ttl"),
        ]
    )

    log.info("regstack indexes installed on database %s", db.name)
|
|
File without changes
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from pymongo.errors import DuplicateKeyError
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BlacklistRepo:
    """Store of individually revoked JWTs, keyed by their ``jti`` claim.

    Every entry records the token's ``exp``. A TTL index on that field lets
    MongoDB drop the entry once the token would have expired on its own.
    """

    def __init__(self, db: AsyncDatabase, collection_name: str) -> None:
        self._collection = db[collection_name]

    async def revoke(self, jti: str, exp: datetime) -> None:
        """Mark ``jti`` as revoked until ``exp``.

        Calling this twice for the same ``jti`` is harmless. The unique index
        turns the second insert into a DuplicateKeyError, which is ignored.
        """
        entry = {"jti": jti, "exp": exp}
        try:
            await self._collection.insert_one(entry)
        except DuplicateKeyError:
            pass  # already revoked — nothing to do

    async def is_revoked(self, jti: str) -> bool:
        """Return True when a revocation entry exists for ``jti``."""
        hit = await self._collection.find_one({"jti": jti}, projection={"_id": 1})
        if hit is None:
            return False
        return True
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import UTC, datetime, timedelta
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from regstack.models.login_attempt import LoginAttempt
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LoginAttemptRepo:
|
|
13
|
+
def __init__(self, db: AsyncDatabase, collection_name: str) -> None:
|
|
14
|
+
self._collection = db[collection_name]
|
|
15
|
+
|
|
16
|
+
async def record_failure(
|
|
17
|
+
self, email: str, *, when: datetime | None = None, ip: str | None = None
|
|
18
|
+
) -> None:
|
|
19
|
+
attempt = LoginAttempt(email=email, when=when or datetime.now(UTC), ip=ip)
|
|
20
|
+
await self._collection.insert_one(attempt.to_mongo())
|
|
21
|
+
|
|
22
|
+
async def count_recent(self, email: str, *, window: timedelta, now: datetime) -> int:
|
|
23
|
+
cutoff = now - window
|
|
24
|
+
return await self._collection.count_documents({"email": email, "when": {"$gte": cutoff}})
|
|
25
|
+
|
|
26
|
+
async def clear(self, email: str) -> None:
|
|
27
|
+
await self._collection.delete_many({"email": email})
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from enum import StrEnum
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from regstack.auth.tokens import hash_token
|
|
9
|
+
from regstack.models.mfa_code import MfaCode, MfaKind
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
13
|
+
|
|
14
|
+
from regstack.auth.clock import Clock
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MfaVerifyOutcome(StrEnum):
|
|
18
|
+
OK = "ok"
|
|
19
|
+
WRONG = "wrong"
|
|
20
|
+
EXPIRED = "expired"
|
|
21
|
+
LOCKED = "locked"
|
|
22
|
+
MISSING = "missing"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(slots=True, frozen=True)
|
|
26
|
+
class MfaVerifyResult:
|
|
27
|
+
outcome: MfaVerifyOutcome
|
|
28
|
+
attempts_remaining: int = 0
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MfaCodeRepo:
|
|
32
|
+
def __init__(self, db: AsyncDatabase, collection_name: str, *, clock: Clock) -> None:
|
|
33
|
+
self._collection = db[collection_name]
|
|
34
|
+
self._clock = clock
|
|
35
|
+
|
|
36
|
+
async def put(self, code: MfaCode) -> None:
|
|
37
|
+
"""Upsert by ``(user_id, kind)`` — re-issuing a code overwrites
|
|
38
|
+
any previous outstanding code, so old SMS messages stop working
|
|
39
|
+
as soon as a new one is sent.
|
|
40
|
+
"""
|
|
41
|
+
doc = code.to_mongo()
|
|
42
|
+
await self._collection.find_one_and_replace(
|
|
43
|
+
{"user_id": code.user_id, "kind": code.kind},
|
|
44
|
+
doc,
|
|
45
|
+
upsert=True,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
async def verify(
|
|
49
|
+
self,
|
|
50
|
+
*,
|
|
51
|
+
user_id: str,
|
|
52
|
+
kind: MfaKind,
|
|
53
|
+
raw_code: str,
|
|
54
|
+
) -> MfaVerifyResult:
|
|
55
|
+
doc = await self._collection.find_one({"user_id": user_id, "kind": kind})
|
|
56
|
+
if doc is None:
|
|
57
|
+
return MfaVerifyResult(MfaVerifyOutcome.MISSING)
|
|
58
|
+
attempts = int(doc.get("attempts", 0))
|
|
59
|
+
max_attempts = int(doc.get("max_attempts", 5))
|
|
60
|
+
if attempts >= max_attempts:
|
|
61
|
+
await self._collection.delete_one({"_id": doc["_id"]})
|
|
62
|
+
return MfaVerifyResult(MfaVerifyOutcome.LOCKED)
|
|
63
|
+
|
|
64
|
+
if doc["expires_at"] <= self._clock.now():
|
|
65
|
+
await self._collection.delete_one({"_id": doc["_id"]})
|
|
66
|
+
return MfaVerifyResult(MfaVerifyOutcome.EXPIRED)
|
|
67
|
+
|
|
68
|
+
if doc["code_hash"] != hash_token(raw_code):
|
|
69
|
+
new_attempts = attempts + 1
|
|
70
|
+
await self._collection.update_one(
|
|
71
|
+
{"_id": doc["_id"]}, {"$set": {"attempts": new_attempts}}
|
|
72
|
+
)
|
|
73
|
+
remaining = max(max_attempts - new_attempts, 0)
|
|
74
|
+
if remaining == 0:
|
|
75
|
+
await self._collection.delete_one({"_id": doc["_id"]})
|
|
76
|
+
return MfaVerifyResult(MfaVerifyOutcome.LOCKED)
|
|
77
|
+
return MfaVerifyResult(MfaVerifyOutcome.WRONG, attempts_remaining=remaining)
|
|
78
|
+
|
|
79
|
+
await self._collection.delete_one({"_id": doc["_id"]})
|
|
80
|
+
return MfaVerifyResult(MfaVerifyOutcome.OK)
|
|
81
|
+
|
|
82
|
+
async def delete(self, *, user_id: str, kind: MfaKind | None = None) -> None:
|
|
83
|
+
query: dict[str, object] = {"user_id": user_id}
|
|
84
|
+
if kind is not None:
|
|
85
|
+
query["kind"] = kind
|
|
86
|
+
await self._collection.delete_many(query)
|
|
87
|
+
|
|
88
|
+
async def find(self, *, user_id: str, kind: MfaKind) -> dict[str, object] | None:
|
|
89
|
+
return await self._collection.find_one({"user_id": user_id, "kind": kind})
|
|
90
|
+
|
|
91
|
+
@staticmethod
|
|
92
|
+
def make_code_hash(raw_code: str) -> str:
|
|
93
|
+
return hash_token(raw_code)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def now_plus_seconds(clock: Clock, seconds: int) -> datetime:
|
|
97
|
+
from datetime import timedelta
|
|
98
|
+
|
|
99
|
+
return clock.now() + timedelta(seconds=seconds)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from bson import ObjectId
|
|
7
|
+
from pymongo.errors import DuplicateKeyError
|
|
8
|
+
|
|
9
|
+
from regstack.models.pending_registration import PendingRegistration
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class PendingAlreadyExistsError(Exception):
    """Signals that a pending registration is already stored for an email.

    ``PendingRepo.create`` raises it with the conflicting email as the only
    argument, chained from the underlying DuplicateKeyError.
    """
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PendingRepo:
|
|
20
|
+
def __init__(self, db: AsyncDatabase, collection_name: str) -> None:
|
|
21
|
+
self._collection = db[collection_name]
|
|
22
|
+
|
|
23
|
+
async def upsert(self, pending: PendingRegistration) -> PendingRegistration:
|
|
24
|
+
"""Insert or replace the pending registration for this email.
|
|
25
|
+
|
|
26
|
+
Resends overwrite an outstanding row so the most recent token is the
|
|
27
|
+
only valid one — old links stop working as soon as a new one is sent.
|
|
28
|
+
"""
|
|
29
|
+
doc = pending.to_mongo()
|
|
30
|
+
result = await self._collection.find_one_and_replace(
|
|
31
|
+
{"email": pending.email},
|
|
32
|
+
doc,
|
|
33
|
+
upsert=True,
|
|
34
|
+
return_document=True,
|
|
35
|
+
)
|
|
36
|
+
if result is not None and "_id" in result:
|
|
37
|
+
pending.id = str(result["_id"])
|
|
38
|
+
return pending
|
|
39
|
+
|
|
40
|
+
async def create(self, pending: PendingRegistration) -> PendingRegistration:
|
|
41
|
+
try:
|
|
42
|
+
result = await self._collection.insert_one(pending.to_mongo())
|
|
43
|
+
except DuplicateKeyError as exc:
|
|
44
|
+
raise PendingAlreadyExistsError(pending.email) from exc
|
|
45
|
+
pending.id = str(result.inserted_id)
|
|
46
|
+
return pending
|
|
47
|
+
|
|
48
|
+
async def find_by_token_hash(self, token_hash: str) -> PendingRegistration | None:
|
|
49
|
+
doc = await self._collection.find_one({"token_hash": token_hash})
|
|
50
|
+
return self._hydrate(doc)
|
|
51
|
+
|
|
52
|
+
async def find_by_email(self, email: str) -> PendingRegistration | None:
|
|
53
|
+
doc = await self._collection.find_one({"email": email})
|
|
54
|
+
return self._hydrate(doc)
|
|
55
|
+
|
|
56
|
+
async def delete_by_id(self, pending_id: str) -> None:
|
|
57
|
+
if not ObjectId.is_valid(pending_id):
|
|
58
|
+
return
|
|
59
|
+
await self._collection.delete_one({"_id": ObjectId(pending_id)})
|
|
60
|
+
|
|
61
|
+
async def delete_by_email(self, email: str) -> None:
|
|
62
|
+
await self._collection.delete_one({"email": email})
|
|
63
|
+
|
|
64
|
+
async def purge_expired(self, now: datetime | None = None) -> int:
|
|
65
|
+
"""Manual reaper for callers that don't trust the TTL background sweep."""
|
|
66
|
+
cutoff = now or datetime.now(UTC)
|
|
67
|
+
result = await self._collection.delete_many({"expires_at": {"$lt": cutoff}})
|
|
68
|
+
return int(result.deleted_count)
|
|
69
|
+
|
|
70
|
+
@staticmethod
|
|
71
|
+
def _hydrate(doc: dict[str, Any] | None) -> PendingRegistration | None:
|
|
72
|
+
if doc is None:
|
|
73
|
+
return None
|
|
74
|
+
if isinstance(doc.get("_id"), ObjectId):
|
|
75
|
+
doc["_id"] = str(doc["_id"])
|
|
76
|
+
return PendingRegistration.model_validate(doc)
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from bson import ObjectId
|
|
7
|
+
from pymongo.errors import DuplicateKeyError
|
|
8
|
+
|
|
9
|
+
from regstack.auth.clock import Clock, SystemClock
|
|
10
|
+
from regstack.models.user import BaseUser
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from pymongo.asynchronous.database import AsyncDatabase
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UserAlreadyExistsError(Exception):
    """Signals that an insert or email change would duplicate an existing email.

    The offending email is passed as the exception's only argument.
    """
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _bulk_revoke_cutoff(now: datetime) -> datetime:
|
|
21
|
+
"""The cutoff timestamp recorded on the user document. Stored at full
|
|
22
|
+
microsecond precision; the JWT ``iat`` claim is also emitted as a float
|
|
23
|
+
so the ``iat < cutoff`` comparison is exact and a fresh login completing
|
|
24
|
+
even microseconds after a password / email change is recognised as
|
|
25
|
+
later-than-cutoff.
|
|
26
|
+
"""
|
|
27
|
+
return now
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class UserRepo:
    """MongoDB access for user documents.

    Every mutating helper also stamps ``updated_at`` using the injected clock
    (a ``SystemClock`` when none is supplied).
    """

    def __init__(
        self,
        db: AsyncDatabase,
        collection_name: str,
        *,
        clock: Clock | None = None,
    ) -> None:
        self._collection = db[collection_name]
        self._clock: Clock = clock or SystemClock()

    async def _set_fields(
        self,
        user_id: str,
        fields: dict[str, Any],
        *,
        stamp: datetime | None = None,
    ) -> None:
        """``$set`` ``fields`` plus ``updated_at`` on the user with ``user_id``.

        ``updated_at`` is ``stamp`` when given. Otherwise the clock is read
        after ``user_id`` has been converted. Unlike ``get_by_id``/``delete``,
        no ``is_valid`` guard is applied, so a malformed id raises from
        ``ObjectId``.
        """
        selector = {"_id": ObjectId(user_id)}
        changes = {**fields, "updated_at": self._clock.now() if stamp is None else stamp}
        await self._collection.update_one(selector, {"$set": changes})

    async def create(self, user: BaseUser) -> BaseUser:
        """Insert ``user`` and set its ``id``.

        Raises:
            UserAlreadyExistsError: A unique index (email) rejected the insert.
        """
        document = user.to_mongo()
        try:
            inserted = await self._collection.insert_one(document)
        except DuplicateKeyError as exc:
            raise UserAlreadyExistsError(user.email) from exc
        user.id = str(inserted.inserted_id)
        return user

    async def get_by_email(self, email: str) -> BaseUser | None:
        """Fetch a user by exact email match."""
        return self._hydrate(await self._collection.find_one({"email": email}))

    async def get_by_id(self, user_id: str) -> BaseUser | None:
        """Fetch a user by id; malformed ids yield None instead of raising."""
        if ObjectId.is_valid(user_id):
            return self._hydrate(await self._collection.find_one({"_id": ObjectId(user_id)}))
        return None

    async def set_last_login(self, user_id: str, when: datetime) -> None:
        """Record ``when`` as the user's last successful login."""
        await self._set_fields(user_id, {"last_login": when})

    async def set_tokens_invalidated_after(self, user_id: str, when: datetime) -> None:
        """Revoke, in bulk, every token issued before ``when``."""
        await self._set_fields(
            user_id, {"tokens_invalidated_after": _bulk_revoke_cutoff(when)}
        )

    async def update_password(self, user_id: str, hashed_password: str) -> None:
        """Store a new password hash and revoke all previously issued tokens."""
        moment = self._clock.now()
        await self._set_fields(
            user_id,
            {
                "hashed_password": hashed_password,
                "tokens_invalidated_after": _bulk_revoke_cutoff(moment),
            },
            stamp=moment,
        )

    async def set_active(self, user_id: str, *, is_active: bool) -> None:
        """Enable or disable the account."""
        await self._set_fields(user_id, {"is_active": is_active})

    async def set_superuser(self, user_id: str, *, is_superuser: bool) -> None:
        """Grant or remove superuser status."""
        await self._set_fields(user_id, {"is_superuser": is_superuser})

    async def set_full_name(self, user_id: str, full_name: str | None) -> None:
        """Set (or clear, with None) the display name."""
        await self._set_fields(user_id, {"full_name": full_name})

    async def set_phone(self, user_id: str, phone_number: str | None) -> None:
        """Set (or clear, with None) the phone number."""
        await self._set_fields(user_id, {"phone_number": phone_number})

    async def set_mfa_enabled(self, user_id: str, *, is_mfa_enabled: bool) -> None:
        """Turn MFA on or off for the user."""
        await self._set_fields(user_id, {"is_mfa_enabled": is_mfa_enabled})

    async def update_email(self, user_id: str, new_email: str) -> None:
        """Change the user's email in one update and revoke earlier tokens.

        Bumping ``tokens_invalidated_after`` in the same ``$set`` means sessions
        bound to the old email stop working immediately.

        Raises:
            UserAlreadyExistsError: Another user already owns ``new_email``.
        """
        moment = self._clock.now()
        try:
            await self._set_fields(
                user_id,
                {
                    "email": new_email,
                    "tokens_invalidated_after": _bulk_revoke_cutoff(moment),
                },
                stamp=moment,
            )
        except DuplicateKeyError as exc:
            raise UserAlreadyExistsError(new_email) from exc

    async def delete(self, user_id: str) -> bool:
        """Delete the user; True if a document was removed, False otherwise."""
        if not ObjectId.is_valid(user_id):
            return False
        outcome = await self._collection.delete_one({"_id": ObjectId(user_id)})
        return bool(outcome.deleted_count)

    async def count(self, *, filter_: dict[str, Any] | None = None) -> int:
        """Count users matching ``filter_`` (all users when omitted or empty)."""
        query = filter_ if filter_ else {}
        return await self._collection.count_documents(query)

    async def list_paged(
        self,
        *,
        skip: int = 0,
        limit: int = 50,
        sort: tuple[str, int] = ("created_at", -1),
    ) -> list[BaseUser]:
        """Return one page of users ordered by ``sort`` (newest first by default)."""
        cursor = self._collection.find().sort([sort]).skip(skip).limit(limit)
        hydrated = [self._hydrate(doc) async for doc in cursor]
        return [user for user in hydrated if user is not None]

    @staticmethod
    def _hydrate(doc: dict[str, Any] | None) -> BaseUser | None:
        """Turn a raw document into a BaseUser, stringifying an ObjectId ``_id``."""
        if doc is None:
            return None
        raw_id = doc.get("_id")
        if isinstance(raw_id, ObjectId):
            doc["_id"] = str(raw_id)
        return BaseUser.model_validate(doc)
|