svc-infra 0.1.595__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of svc-infra might be problematic. Click here for more details.
- svc_infra/__init__.py +58 -2
- svc_infra/apf_payments/models.py +68 -38
- svc_infra/apf_payments/provider/__init__.py +2 -2
- svc_infra/apf_payments/provider/aiydan.py +39 -23
- svc_infra/apf_payments/provider/base.py +8 -3
- svc_infra/apf_payments/provider/registry.py +3 -5
- svc_infra/apf_payments/provider/stripe.py +74 -52
- svc_infra/apf_payments/schemas.py +84 -83
- svc_infra/apf_payments/service.py +27 -16
- svc_infra/apf_payments/settings.py +12 -11
- svc_infra/api/__init__.py +61 -0
- svc_infra/api/fastapi/__init__.py +34 -0
- svc_infra/api/fastapi/admin/__init__.py +3 -0
- svc_infra/api/fastapi/admin/add.py +240 -0
- svc_infra/api/fastapi/apf_payments/router.py +94 -73
- svc_infra/api/fastapi/apf_payments/setup.py +10 -9
- svc_infra/api/fastapi/auth/__init__.py +65 -0
- svc_infra/api/fastapi/auth/_cookies.py +1 -3
- svc_infra/api/fastapi/auth/add.py +14 -15
- svc_infra/api/fastapi/auth/gaurd.py +32 -20
- svc_infra/api/fastapi/auth/mfa/models.py +3 -4
- svc_infra/api/fastapi/auth/mfa/pre_auth.py +13 -9
- svc_infra/api/fastapi/auth/mfa/router.py +9 -8
- svc_infra/api/fastapi/auth/mfa/security.py +4 -7
- svc_infra/api/fastapi/auth/mfa/utils.py +5 -3
- svc_infra/api/fastapi/auth/policy.py +0 -1
- svc_infra/api/fastapi/auth/providers.py +3 -3
- svc_infra/api/fastapi/auth/routers/apikey_router.py +19 -21
- svc_infra/api/fastapi/auth/routers/oauth_router.py +98 -52
- svc_infra/api/fastapi/auth/routers/session_router.py +6 -5
- svc_infra/api/fastapi/auth/security.py +25 -15
- svc_infra/api/fastapi/auth/sender.py +5 -0
- svc_infra/api/fastapi/auth/settings.py +18 -19
- svc_infra/api/fastapi/auth/state.py +5 -4
- svc_infra/api/fastapi/auth/ws_security.py +275 -0
- svc_infra/api/fastapi/billing/router.py +71 -0
- svc_infra/api/fastapi/billing/setup.py +19 -0
- svc_infra/api/fastapi/cache/add.py +9 -5
- svc_infra/api/fastapi/db/__init__.py +5 -1
- svc_infra/api/fastapi/db/http.py +10 -9
- svc_infra/api/fastapi/db/nosql/__init__.py +39 -1
- svc_infra/api/fastapi/db/nosql/mongo/add.py +35 -30
- svc_infra/api/fastapi/db/nosql/mongo/crud_router.py +39 -21
- svc_infra/api/fastapi/db/sql/__init__.py +5 -1
- svc_infra/api/fastapi/db/sql/add.py +62 -25
- svc_infra/api/fastapi/db/sql/crud_router.py +205 -30
- svc_infra/api/fastapi/db/sql/session.py +19 -2
- svc_infra/api/fastapi/db/sql/users.py +18 -9
- svc_infra/api/fastapi/dependencies/ratelimit.py +76 -14
- svc_infra/api/fastapi/docs/add.py +163 -0
- svc_infra/api/fastapi/docs/landing.py +6 -6
- svc_infra/api/fastapi/docs/scoped.py +75 -36
- svc_infra/api/fastapi/dual/__init__.py +12 -2
- svc_infra/api/fastapi/dual/dualize.py +2 -2
- svc_infra/api/fastapi/dual/protected.py +123 -10
- svc_infra/api/fastapi/dual/public.py +25 -0
- svc_infra/api/fastapi/dual/router.py +18 -8
- svc_infra/api/fastapi/dx.py +33 -2
- svc_infra/api/fastapi/ease.py +59 -7
- svc_infra/api/fastapi/http/concurrency.py +2 -1
- svc_infra/api/fastapi/http/conditional.py +2 -2
- svc_infra/api/fastapi/middleware/debug.py +4 -1
- svc_infra/api/fastapi/middleware/errors/exceptions.py +2 -5
- svc_infra/api/fastapi/middleware/errors/handlers.py +50 -10
- svc_infra/api/fastapi/middleware/graceful_shutdown.py +95 -0
- svc_infra/api/fastapi/middleware/idempotency.py +190 -68
- svc_infra/api/fastapi/middleware/idempotency_store.py +187 -0
- svc_infra/api/fastapi/middleware/optimistic_lock.py +39 -0
- svc_infra/api/fastapi/middleware/ratelimit.py +125 -28
- svc_infra/api/fastapi/middleware/ratelimit_store.py +45 -13
- svc_infra/api/fastapi/middleware/request_id.py +24 -10
- svc_infra/api/fastapi/middleware/request_size_limit.py +3 -3
- svc_infra/api/fastapi/middleware/timeout.py +176 -0
- svc_infra/api/fastapi/object_router.py +1060 -0
- svc_infra/api/fastapi/openapi/apply.py +4 -3
- svc_infra/api/fastapi/openapi/conventions.py +13 -6
- svc_infra/api/fastapi/openapi/mutators.py +144 -17
- svc_infra/api/fastapi/openapi/pipeline.py +2 -2
- svc_infra/api/fastapi/openapi/responses.py +4 -6
- svc_infra/api/fastapi/openapi/security.py +1 -1
- svc_infra/api/fastapi/ops/add.py +73 -0
- svc_infra/api/fastapi/pagination.py +47 -32
- svc_infra/api/fastapi/routers/__init__.py +16 -10
- svc_infra/api/fastapi/routers/ping.py +1 -0
- svc_infra/api/fastapi/setup.py +167 -54
- svc_infra/api/fastapi/tenancy/add.py +20 -0
- svc_infra/api/fastapi/tenancy/context.py +113 -0
- svc_infra/api/fastapi/versioned.py +102 -0
- svc_infra/app/README.md +5 -5
- svc_infra/app/__init__.py +3 -1
- svc_infra/app/env.py +70 -4
- svc_infra/app/logging/add.py +10 -2
- svc_infra/app/logging/filter.py +1 -1
- svc_infra/app/logging/formats.py +13 -5
- svc_infra/app/root.py +3 -3
- svc_infra/billing/__init__.py +40 -0
- svc_infra/billing/async_service.py +167 -0
- svc_infra/billing/jobs.py +231 -0
- svc_infra/billing/models.py +146 -0
- svc_infra/billing/quotas.py +101 -0
- svc_infra/billing/schemas.py +34 -0
- svc_infra/bundled_docs/README.md +5 -0
- svc_infra/bundled_docs/__init__.py +1 -0
- svc_infra/bundled_docs/getting-started.md +6 -0
- svc_infra/cache/__init__.py +21 -5
- svc_infra/cache/add.py +167 -0
- svc_infra/cache/backend.py +9 -7
- svc_infra/cache/decorators.py +75 -20
- svc_infra/cache/demo.py +2 -2
- svc_infra/cache/keys.py +26 -6
- svc_infra/cache/recache.py +26 -27
- svc_infra/cache/resources.py +6 -5
- svc_infra/cache/tags.py +19 -44
- svc_infra/cache/ttl.py +2 -3
- svc_infra/cache/utils.py +4 -3
- svc_infra/cli/__init__.py +44 -8
- svc_infra/cli/__main__.py +4 -0
- svc_infra/cli/cmds/__init__.py +39 -2
- svc_infra/cli/cmds/db/nosql/mongo/mongo_cmds.py +18 -14
- svc_infra/cli/cmds/db/nosql/mongo/mongo_scaffold_cmds.py +9 -10
- svc_infra/cli/cmds/db/ops_cmds.py +267 -0
- svc_infra/cli/cmds/db/sql/alembic_cmds.py +97 -29
- svc_infra/cli/cmds/db/sql/sql_export_cmds.py +80 -0
- svc_infra/cli/cmds/db/sql/sql_scaffold_cmds.py +13 -13
- svc_infra/cli/cmds/docs/docs_cmds.py +139 -0
- svc_infra/cli/cmds/dx/__init__.py +12 -0
- svc_infra/cli/cmds/dx/dx_cmds.py +110 -0
- svc_infra/cli/cmds/health/__init__.py +179 -0
- svc_infra/cli/cmds/health/health_cmds.py +8 -0
- svc_infra/cli/cmds/help.py +4 -0
- svc_infra/cli/cmds/jobs/__init__.py +1 -0
- svc_infra/cli/cmds/jobs/jobs_cmds.py +42 -0
- svc_infra/cli/cmds/obs/obs_cmds.py +31 -13
- svc_infra/cli/cmds/sdk/__init__.py +0 -0
- svc_infra/cli/cmds/sdk/sdk_cmds.py +102 -0
- svc_infra/cli/foundation/runner.py +4 -5
- svc_infra/cli/foundation/typer_bootstrap.py +1 -2
- svc_infra/data/__init__.py +83 -0
- svc_infra/data/add.py +61 -0
- svc_infra/data/backup.py +56 -0
- svc_infra/data/erasure.py +46 -0
- svc_infra/data/fixtures.py +42 -0
- svc_infra/data/retention.py +56 -0
- svc_infra/db/__init__.py +15 -0
- svc_infra/db/crud_schema.py +14 -13
- svc_infra/db/inbox.py +67 -0
- svc_infra/db/nosql/__init__.py +2 -0
- svc_infra/db/nosql/constants.py +1 -1
- svc_infra/db/nosql/core.py +19 -5
- svc_infra/db/nosql/indexes.py +12 -9
- svc_infra/db/nosql/management.py +4 -4
- svc_infra/db/nosql/mongo/README.md +13 -13
- svc_infra/db/nosql/mongo/client.py +21 -4
- svc_infra/db/nosql/mongo/settings.py +1 -1
- svc_infra/db/nosql/repository.py +46 -27
- svc_infra/db/nosql/resource.py +28 -16
- svc_infra/db/nosql/scaffold.py +14 -12
- svc_infra/db/nosql/service.py +2 -1
- svc_infra/db/nosql/service_with_hooks.py +4 -3
- svc_infra/db/nosql/utils.py +4 -4
- svc_infra/db/ops.py +380 -0
- svc_infra/db/outbox.py +105 -0
- svc_infra/db/sql/apikey.py +34 -15
- svc_infra/db/sql/authref.py +8 -6
- svc_infra/db/sql/constants.py +5 -1
- svc_infra/db/sql/core.py +13 -13
- svc_infra/db/sql/management.py +5 -6
- svc_infra/db/sql/repository.py +92 -26
- svc_infra/db/sql/resource.py +18 -12
- svc_infra/db/sql/scaffold.py +11 -11
- svc_infra/db/sql/service.py +2 -1
- svc_infra/db/sql/service_with_hooks.py +4 -3
- svc_infra/db/sql/templates/models_schemas/auth/models.py.tmpl +7 -56
- svc_infra/db/sql/templates/setup/env_async.py.tmpl +34 -12
- svc_infra/db/sql/templates/setup/env_sync.py.tmpl +29 -7
- svc_infra/db/sql/tenant.py +80 -0
- svc_infra/db/sql/uniq.py +8 -7
- svc_infra/db/sql/uniq_hooks.py +12 -11
- svc_infra/db/sql/utils.py +105 -47
- svc_infra/db/sql/versioning.py +14 -0
- svc_infra/db/utils.py +3 -3
- svc_infra/deploy/__init__.py +531 -0
- svc_infra/documents/__init__.py +100 -0
- svc_infra/documents/add.py +263 -0
- svc_infra/documents/ease.py +233 -0
- svc_infra/documents/models.py +114 -0
- svc_infra/documents/storage.py +262 -0
- svc_infra/dx/__init__.py +58 -0
- svc_infra/dx/add.py +63 -0
- svc_infra/dx/changelog.py +74 -0
- svc_infra/dx/checks.py +68 -0
- svc_infra/exceptions.py +141 -0
- svc_infra/health/__init__.py +863 -0
- svc_infra/http/__init__.py +13 -0
- svc_infra/http/client.py +101 -0
- svc_infra/jobs/__init__.py +79 -0
- svc_infra/jobs/builtins/outbox_processor.py +38 -0
- svc_infra/jobs/builtins/webhook_delivery.py +93 -0
- svc_infra/jobs/easy.py +33 -0
- svc_infra/jobs/loader.py +49 -0
- svc_infra/jobs/queue.py +106 -0
- svc_infra/jobs/redis_queue.py +242 -0
- svc_infra/jobs/runner.py +75 -0
- svc_infra/jobs/scheduler.py +53 -0
- svc_infra/jobs/worker.py +40 -0
- svc_infra/loaders/__init__.py +186 -0
- svc_infra/loaders/base.py +143 -0
- svc_infra/loaders/github.py +309 -0
- svc_infra/loaders/models.py +147 -0
- svc_infra/loaders/url.py +229 -0
- svc_infra/logging/__init__.py +375 -0
- svc_infra/mcp/__init__.py +82 -0
- svc_infra/mcp/svc_infra_mcp.py +91 -33
- svc_infra/obs/README.md +2 -0
- svc_infra/obs/add.py +68 -11
- svc_infra/obs/cloud_dash.py +2 -1
- svc_infra/obs/grafana/dashboards/http-overview.json +45 -0
- svc_infra/obs/metrics/__init__.py +6 -7
- svc_infra/obs/metrics/asgi.py +8 -7
- svc_infra/obs/metrics/base.py +13 -13
- svc_infra/obs/metrics/http.py +3 -3
- svc_infra/obs/metrics/sqlalchemy.py +14 -13
- svc_infra/obs/metrics.py +9 -8
- svc_infra/resilience/__init__.py +44 -0
- svc_infra/resilience/circuit_breaker.py +328 -0
- svc_infra/resilience/retry.py +289 -0
- svc_infra/security/__init__.py +167 -0
- svc_infra/security/add.py +213 -0
- svc_infra/security/audit.py +97 -18
- svc_infra/security/audit_service.py +10 -9
- svc_infra/security/headers.py +15 -2
- svc_infra/security/hibp.py +14 -7
- svc_infra/security/jwt_rotation.py +78 -29
- svc_infra/security/lockout.py +23 -16
- svc_infra/security/models.py +77 -44
- svc_infra/security/oauth_models.py +73 -0
- svc_infra/security/org_invites.py +12 -12
- svc_infra/security/passwords.py +3 -3
- svc_infra/security/permissions.py +31 -7
- svc_infra/security/session.py +7 -8
- svc_infra/security/signed_cookies.py +26 -6
- svc_infra/storage/__init__.py +93 -0
- svc_infra/storage/add.py +250 -0
- svc_infra/storage/backends/__init__.py +11 -0
- svc_infra/storage/backends/local.py +331 -0
- svc_infra/storage/backends/memory.py +213 -0
- svc_infra/storage/backends/s3.py +334 -0
- svc_infra/storage/base.py +239 -0
- svc_infra/storage/easy.py +181 -0
- svc_infra/storage/settings.py +193 -0
- svc_infra/testing/__init__.py +682 -0
- svc_infra/utils.py +170 -5
- svc_infra/webhooks/__init__.py +69 -0
- svc_infra/webhooks/add.py +327 -0
- svc_infra/webhooks/encryption.py +115 -0
- svc_infra/webhooks/fastapi.py +37 -0
- svc_infra/webhooks/router.py +55 -0
- svc_infra/webhooks/service.py +69 -0
- svc_infra/webhooks/signing.py +34 -0
- svc_infra/websocket/__init__.py +79 -0
- svc_infra/websocket/add.py +139 -0
- svc_infra/websocket/client.py +283 -0
- svc_infra/websocket/config.py +57 -0
- svc_infra/websocket/easy.py +76 -0
- svc_infra/websocket/exceptions.py +61 -0
- svc_infra/websocket/manager.py +343 -0
- svc_infra/websocket/models.py +49 -0
- svc_infra-1.1.0.dist-info/LICENSE +21 -0
- svc_infra-1.1.0.dist-info/METADATA +362 -0
- svc_infra-1.1.0.dist-info/RECORD +364 -0
- svc_infra-0.1.595.dist-info/METADATA +0 -80
- svc_infra-0.1.595.dist-info/RECORD +0 -253
- {svc_infra-0.1.595.dist-info → svc_infra-1.1.0.dist-info}/WHEEL +0 -0
- {svc_infra-0.1.595.dist-info → svc_infra-1.1.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
3
2
|
|
|
4
3
|
from pydantic import BaseModel
|
|
5
4
|
|
|
@@ -54,9 +53,9 @@ class MFAProof(BaseModel):
|
|
|
54
53
|
|
|
55
54
|
|
|
56
55
|
class DisableAccountIn(BaseModel):
|
|
57
|
-
reason:
|
|
58
|
-
mfa:
|
|
56
|
+
reason: str | None = None
|
|
57
|
+
mfa: MFAProof | None = None
|
|
59
58
|
|
|
60
59
|
|
|
61
60
|
class DeleteAccountIn(BaseModel):
|
|
62
|
-
mfa:
|
|
61
|
+
mfa: MFAProof | None = None
|
|
@@ -1,18 +1,22 @@
|
|
|
1
|
-
from datetime import
|
|
1
|
+
from datetime import UTC, datetime
|
|
2
2
|
|
|
3
3
|
from svc_infra.api.fastapi.auth.settings import get_auth_settings
|
|
4
|
+
from svc_infra.app.env import require_secret
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
def get_mfa_pre_jwt_writer():
|
|
7
8
|
st = get_auth_settings()
|
|
8
9
|
jwt_block = getattr(st, "jwt", None)
|
|
9
10
|
|
|
10
|
-
# Force to plain string
|
|
11
|
-
|
|
12
|
-
jwt_block.secret.get_secret_value()
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
11
|
+
# Force to plain string - use require_secret to ensure it's set in production
|
|
12
|
+
if jwt_block and getattr(jwt_block, "secret", None):
|
|
13
|
+
secret = jwt_block.secret.get_secret_value()
|
|
14
|
+
else:
|
|
15
|
+
secret = require_secret(
|
|
16
|
+
None,
|
|
17
|
+
"JWT_SECRET (via auth settings jwt.secret for MFA)",
|
|
18
|
+
dev_default="dev-only-mfa-jwt-secret-not-for-production",
|
|
19
|
+
)
|
|
16
20
|
secret = str(secret)
|
|
17
21
|
|
|
18
22
|
lifetime = int(getattr(st, "mfa_pre_token_lifetime_seconds", 300))
|
|
@@ -25,9 +29,9 @@ def get_mfa_pre_jwt_writer():
|
|
|
25
29
|
async def write(self, user):
|
|
26
30
|
from fastapi_users.jwt import generate_jwt
|
|
27
31
|
|
|
28
|
-
now = datetime.now(
|
|
32
|
+
now = datetime.now(UTC)
|
|
29
33
|
payload = {
|
|
30
|
-
"sub": str(
|
|
34
|
+
"sub": str(user.id),
|
|
31
35
|
"aud": ["fastapi-users:mfa"],
|
|
32
36
|
"iat": int(now.timestamp()),
|
|
33
37
|
"exp": int(now.timestamp()) + self.lifetime,
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from datetime import
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
from typing import Any, cast
|
|
4
5
|
|
|
5
6
|
import pyotp
|
|
6
7
|
from fastapi import APIRouter, Body, Depends, HTTPException, Request, status
|
|
@@ -79,7 +80,7 @@ def mfa_router(
|
|
|
79
80
|
raise HTTPException(401, "Invalid token")
|
|
80
81
|
|
|
81
82
|
# IMPORTANT: rehydrate into *your* session
|
|
82
|
-
db_user = await session.get(user_model, user.id)
|
|
83
|
+
db_user = await cast("Any", session).get(user_model, user.id)
|
|
83
84
|
if not db_user:
|
|
84
85
|
raise HTTPException(401, "Invalid token")
|
|
85
86
|
|
|
@@ -126,7 +127,7 @@ def mfa_router(
|
|
|
126
127
|
|
|
127
128
|
# RELOAD from DB to avoid stale state
|
|
128
129
|
user = (
|
|
129
|
-
await session.execute(select(user_model).where(user_model.id == user.id))
|
|
130
|
+
await session.execute(select(user_model).where(user_model.id == user.id)) # type: ignore[attr-defined]
|
|
130
131
|
).scalar_one()
|
|
131
132
|
|
|
132
133
|
if not getattr(user, "mfa_secret", None):
|
|
@@ -141,7 +142,7 @@ def mfa_router(
|
|
|
141
142
|
|
|
142
143
|
user.mfa_recovery = [_hash(c) for c in codes]
|
|
143
144
|
user.mfa_enabled = True
|
|
144
|
-
user.mfa_confirmed_at = datetime.now(
|
|
145
|
+
user.mfa_confirmed_at = datetime.now(UTC)
|
|
145
146
|
await session.commit()
|
|
146
147
|
|
|
147
148
|
return RecoveryCodesOut(codes=codes)
|
|
@@ -198,7 +199,7 @@ def mfa_router(
|
|
|
198
199
|
raise HTTPException(401, "Invalid pre-auth token")
|
|
199
200
|
|
|
200
201
|
# 2) load user
|
|
201
|
-
user = await session.get(user_model, uid)
|
|
202
|
+
user = await cast("Any", session).get(user_model, uid)
|
|
202
203
|
if not user:
|
|
203
204
|
raise HTTPException(401, "Invalid pre-auth token")
|
|
204
205
|
|
|
@@ -242,7 +243,7 @@ def mfa_router(
|
|
|
242
243
|
raise HTTPException(400, "Invalid code")
|
|
243
244
|
|
|
244
245
|
# NEW: set last_login on successful MFA
|
|
245
|
-
user.last_login = datetime.now(
|
|
246
|
+
user.last_login = datetime.now(UTC)
|
|
246
247
|
await session.commit()
|
|
247
248
|
|
|
248
249
|
# 4) mint normal JWT and set cookie
|
|
@@ -271,7 +272,7 @@ def mfa_router(
|
|
|
271
272
|
raise HTTPException(401, "Invalid pre-auth token")
|
|
272
273
|
|
|
273
274
|
# 1b) Load user to get their email
|
|
274
|
-
user = await session.get(user_model, uid)
|
|
275
|
+
user = await cast("Any", session).get(user_model, uid)
|
|
275
276
|
if not user or not getattr(user, "email", None):
|
|
276
277
|
# (optionally also check user.mfa_enabled here)
|
|
277
278
|
raise HTTPException(401, "Invalid pre-auth token")
|
|
@@ -326,7 +327,7 @@ def mfa_router(
|
|
|
326
327
|
# Email OTP is always offered in your flow at verify-time
|
|
327
328
|
methods.append("email")
|
|
328
329
|
|
|
329
|
-
def _mask(email: str) -> str:
|
|
330
|
+
def _mask(email: str) -> str | None:
|
|
330
331
|
if not email or "@" not in email:
|
|
331
332
|
return None
|
|
332
333
|
name, domain = email.split("@", 1)
|
|
@@ -1,21 +1,18 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
from fastapi import Body, Depends, HTTPException, Query
|
|
4
2
|
|
|
5
3
|
from svc_infra.api.fastapi.auth.security import Identity
|
|
6
4
|
from svc_infra.api.fastapi.db.sql.session import SqlSessionDep
|
|
7
5
|
|
|
8
|
-
from .
|
|
9
|
-
from .verify import verify_mfa_for_user
|
|
6
|
+
from .verify import MFAProof, verify_mfa_for_user
|
|
10
7
|
|
|
11
8
|
|
|
12
9
|
def RequireMFAIfEnabled(body_field: str = "mfa"):
|
|
13
10
|
async def _dep(
|
|
14
11
|
p: Identity,
|
|
15
12
|
sess: SqlSessionDep,
|
|
16
|
-
mfa:
|
|
17
|
-
mfa_code:
|
|
18
|
-
mfa_pre_token:
|
|
13
|
+
mfa: MFAProof | None = Body(None, embed=True, alias=body_field),
|
|
14
|
+
mfa_code: str | None = Query(None, alias="mfa_code"),
|
|
15
|
+
mfa_pre_token: str | None = Query(None, alias="mfa_pre_token"),
|
|
19
16
|
):
|
|
20
17
|
proof = mfa or (
|
|
21
18
|
MFAProof(code=mfa_code, pre_token=mfa_pre_token)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import hashlib
|
|
3
3
|
import os
|
|
4
|
+
from datetime import UTC
|
|
4
5
|
|
|
5
6
|
import pyotp
|
|
6
7
|
|
|
@@ -29,7 +30,8 @@ def _gen_recovery_codes(n: int, length: int) -> list[str]:
|
|
|
29
30
|
def _gen_numeric_code(n: int = 6) -> str:
|
|
30
31
|
import random
|
|
31
32
|
|
|
32
|
-
|
|
33
|
+
code = "".join(str(random.randrange(10)) for _ in range(n))
|
|
34
|
+
return code
|
|
33
35
|
|
|
34
36
|
|
|
35
37
|
def _hash(s: str) -> str:
|
|
@@ -37,6 +39,6 @@ def _hash(s: str) -> str:
|
|
|
37
39
|
|
|
38
40
|
|
|
39
41
|
def _now_utc_ts() -> int:
|
|
40
|
-
from datetime import datetime
|
|
42
|
+
from datetime import datetime
|
|
41
43
|
|
|
42
|
-
return int(datetime.now(
|
|
44
|
+
return int(datetime.now(UTC).timestamp())
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
def providers_from_settings(settings: Any) ->
|
|
4
|
+
def providers_from_settings(settings: Any) -> dict[str, dict[str, Any]]:
|
|
5
5
|
"""
|
|
6
6
|
Returns a registry of providers:
|
|
7
7
|
{
|
|
@@ -20,7 +20,7 @@ def providers_from_settings(settings: Any) -> Dict[str, Dict[str, Any]]:
|
|
|
20
20
|
}
|
|
21
21
|
}
|
|
22
22
|
"""
|
|
23
|
-
reg:
|
|
23
|
+
reg: dict[str, dict[str, Any]] = {}
|
|
24
24
|
|
|
25
25
|
# Google (OIDC)
|
|
26
26
|
if getattr(settings, "google_client_id", None) and getattr(
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from datetime import datetime, timedelta
|
|
4
|
-
from typing import
|
|
3
|
+
from datetime import UTC, datetime, timedelta
|
|
4
|
+
from typing import Any, cast
|
|
5
5
|
from uuid import UUID
|
|
6
6
|
|
|
7
7
|
from fastapi import HTTPException, Query
|
|
@@ -23,21 +23,21 @@ from svc_infra.db.sql.apikey import get_apikey_model
|
|
|
23
23
|
|
|
24
24
|
class ApiKeyCreateIn(BaseModel):
|
|
25
25
|
name: str
|
|
26
|
-
user_id:
|
|
27
|
-
scopes:
|
|
28
|
-
ttl_hours:
|
|
26
|
+
user_id: str | None = None
|
|
27
|
+
scopes: list[str] = Field(default_factory=list)
|
|
28
|
+
ttl_hours: int | None = 24 * 365 # default 1y
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class ApiKeyOut(BaseModel):
|
|
32
32
|
id: str
|
|
33
33
|
name: str
|
|
34
|
-
user_id:
|
|
35
|
-
key:
|
|
34
|
+
user_id: str | None
|
|
35
|
+
key: str | None = None
|
|
36
36
|
key_prefix: str
|
|
37
|
-
scopes:
|
|
37
|
+
scopes: list[str]
|
|
38
38
|
active: bool
|
|
39
|
-
expires_at:
|
|
40
|
-
last_used_at:
|
|
39
|
+
expires_at: datetime | None
|
|
40
|
+
last_used_at: datetime | None
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
def _to_uuid(val):
|
|
@@ -56,17 +56,15 @@ def apikey_router():
|
|
|
56
56
|
description="Create a new API key. The plaintext key is shown only once, at creation time.",
|
|
57
57
|
)
|
|
58
58
|
async def create_key(sess: SqlSessionDep, payload: ApiKeyCreateIn, p: Identity):
|
|
59
|
-
caller_id: UUID =
|
|
59
|
+
caller_id: UUID = p.user.id
|
|
60
60
|
owner_id: UUID = _to_uuid(payload.user_id) if payload.user_id else caller_id
|
|
61
61
|
|
|
62
62
|
if owner_id != caller_id and not getattr(p.user, "is_superuser", False):
|
|
63
63
|
raise HTTPException(403, "forbidden")
|
|
64
64
|
|
|
65
|
-
plaintext, prefix, hashed = ApiKey.make_secret()
|
|
65
|
+
plaintext, prefix, hashed = ApiKey.make_secret() # type: ignore[attr-defined]
|
|
66
66
|
expires = (
|
|
67
|
-
(datetime.now(
|
|
68
|
-
if payload.ttl_hours
|
|
69
|
-
else None
|
|
67
|
+
(datetime.now(UTC) + timedelta(hours=payload.ttl_hours)) if payload.ttl_hours else None
|
|
70
68
|
)
|
|
71
69
|
|
|
72
70
|
row = ApiKey(
|
|
@@ -98,9 +96,9 @@ def apikey_router():
|
|
|
98
96
|
description="List API keys. Non-superusers see only their own keys.",
|
|
99
97
|
)
|
|
100
98
|
async def list_keys(sess: SqlSessionDep, p: Identity):
|
|
101
|
-
q = select(ApiKey)
|
|
99
|
+
q: Any = select(ApiKey)
|
|
102
100
|
if not getattr(p.user, "is_superuser", False):
|
|
103
|
-
q = q.where(ApiKey.user_id == p.user.id)
|
|
101
|
+
q = q.where(ApiKey.user_id == p.user.id) # type: ignore[attr-defined]
|
|
104
102
|
rows = (await sess.execute(q)).scalars().all()
|
|
105
103
|
return [
|
|
106
104
|
ApiKeyOut(
|
|
@@ -124,11 +122,11 @@ def apikey_router():
|
|
|
124
122
|
description="Revoke an API key",
|
|
125
123
|
)
|
|
126
124
|
async def revoke_key(key_id: str, sess: SqlSessionDep, p: Identity):
|
|
127
|
-
row = await sess.get(ApiKey, key_id)
|
|
125
|
+
row = await cast("Any", sess).get(ApiKey, key_id)
|
|
128
126
|
if not row:
|
|
129
127
|
raise HTTPException(404, "not_found")
|
|
130
128
|
|
|
131
|
-
caller_id: UUID =
|
|
129
|
+
caller_id: UUID = p.user.id
|
|
132
130
|
if not (getattr(p.user, "is_superuser", False) or row.user_id == caller_id):
|
|
133
131
|
raise HTTPException(403, "forbidden")
|
|
134
132
|
|
|
@@ -148,11 +146,11 @@ def apikey_router():
|
|
|
148
146
|
p: Identity,
|
|
149
147
|
force: bool = Query(False, description="Allow deleting an active key if True"),
|
|
150
148
|
):
|
|
151
|
-
row = await sess.get(ApiKey, key_id)
|
|
149
|
+
row = await cast("Any", sess).get(ApiKey, key_id)
|
|
152
150
|
if not row:
|
|
153
151
|
return # 204
|
|
154
152
|
|
|
155
|
-
caller_id: UUID =
|
|
153
|
+
caller_id: UUID = p.user.id
|
|
156
154
|
if not (getattr(p.user, "is_superuser", False) or row.user_id == caller_id):
|
|
157
155
|
raise HTTPException(403, "forbidden")
|
|
158
156
|
|
|
@@ -3,16 +3,16 @@ from __future__ import annotations
|
|
|
3
3
|
import base64
|
|
4
4
|
import hashlib
|
|
5
5
|
import secrets
|
|
6
|
-
from datetime import datetime, timedelta
|
|
7
|
-
from typing import Any,
|
|
6
|
+
from datetime import UTC, datetime, timedelta
|
|
7
|
+
from typing import Any, Literal, cast
|
|
8
8
|
from urllib.parse import urlencode, urlparse
|
|
9
9
|
|
|
10
10
|
import jwt
|
|
11
11
|
from authlib.integrations.base_client.errors import OAuthError
|
|
12
12
|
from authlib.integrations.starlette_client import OAuth
|
|
13
|
-
from fastapi import APIRouter, HTTPException, Request
|
|
13
|
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
14
14
|
from fastapi.responses import RedirectResponse
|
|
15
|
-
from fastapi_users.authentication import AuthenticationBackend
|
|
15
|
+
from fastapi_users.authentication import AuthenticationBackend, Strategy
|
|
16
16
|
from fastapi_users.password import PasswordHelper
|
|
17
17
|
from sqlalchemy import select
|
|
18
18
|
from starlette import status
|
|
@@ -20,7 +20,10 @@ from starlette.responses import Response
|
|
|
20
20
|
|
|
21
21
|
from svc_infra.api.fastapi.auth.mfa.pre_auth import get_mfa_pre_jwt_writer
|
|
22
22
|
from svc_infra.api.fastapi.auth.policy import AuthPolicy, DefaultAuthPolicy
|
|
23
|
-
from svc_infra.api.fastapi.auth.settings import
|
|
23
|
+
from svc_infra.api.fastapi.auth.settings import (
|
|
24
|
+
get_auth_settings,
|
|
25
|
+
parse_redirect_allow_hosts,
|
|
26
|
+
)
|
|
24
27
|
from svc_infra.api.fastapi.db.sql.session import SqlSessionDep
|
|
25
28
|
from svc_infra.api.fastapi.dual.public import public_router
|
|
26
29
|
from svc_infra.api.fastapi.paths.auth import (
|
|
@@ -28,6 +31,7 @@ from svc_infra.api.fastapi.paths.auth import (
|
|
|
28
31
|
OAUTH_LOGIN_PATH,
|
|
29
32
|
OAUTH_REFRESH_PATH,
|
|
30
33
|
)
|
|
34
|
+
from svc_infra.app.env import require_secret
|
|
31
35
|
from svc_infra.security.models import RefreshToken
|
|
32
36
|
from svc_infra.security.session import issue_session_and_refresh, rotate_session_refresh
|
|
33
37
|
|
|
@@ -45,9 +49,12 @@ def _validate_redirect(url: str, allow_hosts: list[str], *, require_https: bool)
|
|
|
45
49
|
p = urlparse(url)
|
|
46
50
|
if not p.netloc:
|
|
47
51
|
return
|
|
48
|
-
|
|
52
|
+
if not p.hostname:
|
|
53
|
+
raise HTTPException(400, "redirect_not_allowed")
|
|
54
|
+
hostname = p.hostname
|
|
55
|
+
host_port = hostname.lower() + (f":{p.port}" if p.port else "")
|
|
49
56
|
allowed = {h.lower() for h in allow_hosts}
|
|
50
|
-
if host_port not in allowed and
|
|
57
|
+
if host_port not in allowed and hostname.lower() not in allowed:
|
|
51
58
|
raise HTTPException(400, "redirect_not_allowed")
|
|
52
59
|
if require_https and p.scheme != "https":
|
|
53
60
|
raise HTTPException(400, "https_required")
|
|
@@ -62,13 +69,13 @@ def _coerce_expires_at(token: dict | None) -> datetime | None:
|
|
|
62
69
|
v = float(token["expires_at"])
|
|
63
70
|
if v > 1e12: # ms -> s
|
|
64
71
|
v /= 1000.0
|
|
65
|
-
return datetime.fromtimestamp(v, tz=
|
|
72
|
+
return datetime.fromtimestamp(v, tz=UTC)
|
|
66
73
|
except Exception:
|
|
67
74
|
pass
|
|
68
75
|
if token.get("expires_in") is not None:
|
|
69
76
|
try:
|
|
70
77
|
secs = int(token["expires_in"])
|
|
71
|
-
return datetime.now(
|
|
78
|
+
return datetime.now(UTC) + timedelta(seconds=secs)
|
|
72
79
|
except Exception:
|
|
73
80
|
pass
|
|
74
81
|
return None
|
|
@@ -88,7 +95,7 @@ def _cookie_domain(st):
|
|
|
88
95
|
return d or None
|
|
89
96
|
|
|
90
97
|
|
|
91
|
-
def _register_oauth_providers(oauth: OAuth, providers:
|
|
98
|
+
def _register_oauth_providers(oauth: OAuth, providers: dict[str, dict[str, Any]]) -> None:
|
|
92
99
|
"""Register all OAuth providers with the OAuth client."""
|
|
93
100
|
for name, cfg in providers.items():
|
|
94
101
|
kind = cfg.get("kind")
|
|
@@ -263,14 +270,17 @@ async def _find_or_create_user(session, user_model, email: str, full_name: str |
|
|
|
263
270
|
is_verified=True,
|
|
264
271
|
)
|
|
265
272
|
|
|
266
|
-
# Set hashed password for OAuth users
|
|
273
|
+
# Set hashed password for OAuth users - use cryptographically random password
|
|
274
|
+
# OAuth users authenticate via provider, not password, so this is never used
|
|
275
|
+
# but must be unpredictable to prevent password-based login attacks
|
|
276
|
+
random_password = secrets.token_urlsafe(32)
|
|
267
277
|
if hasattr(user, "hashed_password"):
|
|
268
|
-
user.hashed_password = PasswordHelper().hash(
|
|
278
|
+
user.hashed_password = PasswordHelper().hash(random_password)
|
|
269
279
|
elif hasattr(user, "password_hash"):
|
|
270
|
-
user.password_hash = PasswordHelper().hash(
|
|
280
|
+
user.password_hash = PasswordHelper().hash(random_password)
|
|
271
281
|
|
|
272
282
|
if full_name and hasattr(user, "full_name"):
|
|
273
|
-
|
|
283
|
+
user.full_name = full_name
|
|
274
284
|
|
|
275
285
|
session.add(user)
|
|
276
286
|
await session.flush() # ensure user.id exists
|
|
@@ -335,11 +345,11 @@ async def _update_provider_account(
|
|
|
335
345
|
expires_at = _coerce_expires_at(tok)
|
|
336
346
|
|
|
337
347
|
if not link:
|
|
338
|
-
values =
|
|
339
|
-
user_id
|
|
340
|
-
provider
|
|
341
|
-
provider_account_id
|
|
342
|
-
|
|
348
|
+
values = {
|
|
349
|
+
"user_id": user.id,
|
|
350
|
+
"provider": provider,
|
|
351
|
+
"provider_account_id": provider_user_id,
|
|
352
|
+
}
|
|
343
353
|
if hasattr(provider_account_model, "access_token"):
|
|
344
354
|
values["access_token"] = access_token
|
|
345
355
|
if hasattr(provider_account_model, "refresh_token"):
|
|
@@ -373,9 +383,8 @@ async def _update_provider_account(
|
|
|
373
383
|
def _determine_final_redirect_url(request: Request, provider: str, post_login_redirect: str) -> str:
|
|
374
384
|
"""Determine the final redirect URL after successful authentication."""
|
|
375
385
|
st = get_auth_settings()
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
)
|
|
386
|
+
# Prioritize the parameter passed to the router over settings
|
|
387
|
+
redirect_url = str(post_login_redirect or getattr(st, "post_login_redirect", "/"))
|
|
379
388
|
allow_hosts = parse_redirect_allow_hosts(getattr(st, "redirect_allow_hosts_raw", None))
|
|
380
389
|
require_https = bool(getattr(st, "session_cookie_secure", False))
|
|
381
390
|
|
|
@@ -437,7 +446,13 @@ async def _process_user_authentication(
|
|
|
437
446
|
|
|
438
447
|
# Ensure provider link exists
|
|
439
448
|
await _update_provider_account(
|
|
440
|
-
session,
|
|
449
|
+
session,
|
|
450
|
+
provider_account_model,
|
|
451
|
+
user,
|
|
452
|
+
provider,
|
|
453
|
+
provider_user_id,
|
|
454
|
+
token,
|
|
455
|
+
raw_claims,
|
|
441
456
|
)
|
|
442
457
|
|
|
443
458
|
return user
|
|
@@ -446,11 +461,16 @@ async def _process_user_authentication(
|
|
|
446
461
|
async def _validate_and_decode_jwt_token(raw_token: str) -> str:
|
|
447
462
|
"""Validate and decode JWT token to extract user ID."""
|
|
448
463
|
st = get_auth_settings()
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
464
|
+
jwt_settings = getattr(st, "jwt", None)
|
|
465
|
+
jwt_secret = getattr(jwt_settings, "secret", None) if jwt_settings is not None else None
|
|
466
|
+
if jwt_secret:
|
|
467
|
+
secret = jwt_secret.get_secret_value()
|
|
468
|
+
else:
|
|
469
|
+
secret = require_secret(
|
|
470
|
+
None,
|
|
471
|
+
"JWT_SECRET (via auth settings jwt.secret for token validation)",
|
|
472
|
+
dev_default="dev-only-jwt-validation-secret-not-for-production",
|
|
473
|
+
)
|
|
454
474
|
|
|
455
475
|
try:
|
|
456
476
|
payload = jwt.decode(
|
|
@@ -462,24 +482,25 @@ async def _validate_and_decode_jwt_token(raw_token: str) -> str:
|
|
|
462
482
|
user_id = payload.get("sub")
|
|
463
483
|
if not user_id:
|
|
464
484
|
raise HTTPException(401, "invalid_token")
|
|
465
|
-
return user_id
|
|
485
|
+
return cast("str", user_id)
|
|
466
486
|
except Exception:
|
|
467
487
|
raise HTTPException(401, "invalid_token")
|
|
468
488
|
|
|
469
489
|
|
|
470
490
|
async def _set_cookie_on_response(
|
|
471
491
|
resp: Response,
|
|
472
|
-
|
|
492
|
+
strategy: Strategy[Any, Any],
|
|
473
493
|
user: Any,
|
|
474
494
|
*,
|
|
475
495
|
refresh_raw: str,
|
|
476
496
|
) -> None:
|
|
477
497
|
"""Set authentication (JWT) and refresh cookies on response."""
|
|
478
498
|
st = get_auth_settings()
|
|
479
|
-
strategy = auth_backend.get_strategy()
|
|
480
499
|
jwt_token = await strategy.write_token(user)
|
|
481
500
|
|
|
482
|
-
same_site_lit = cast(
|
|
501
|
+
same_site_lit = cast(
|
|
502
|
+
"Literal['lax', 'strict', 'none']", str(st.session_cookie_samesite).lower()
|
|
503
|
+
)
|
|
483
504
|
if same_site_lit == "none" and not bool(st.session_cookie_secure):
|
|
484
505
|
raise HTTPException(500, "session_cookie_samesite=None requires session_cookie_secure=True")
|
|
485
506
|
|
|
@@ -529,7 +550,7 @@ async def _handle_mfa_redirect(
|
|
|
529
550
|
def oauth_router_with_backend(
|
|
530
551
|
user_model: type,
|
|
531
552
|
auth_backend: AuthenticationBackend,
|
|
532
|
-
providers:
|
|
553
|
+
providers: dict[str, dict[str, Any]],
|
|
533
554
|
post_login_redirect: str = "/",
|
|
534
555
|
provider_account_model: type | None = None,
|
|
535
556
|
auth_policy: AuthPolicy | None = None,
|
|
@@ -547,7 +568,7 @@ def oauth_router_with_backend(
|
|
|
547
568
|
def _create_oauth_router(
|
|
548
569
|
user_model: type,
|
|
549
570
|
auth_backend: AuthenticationBackend,
|
|
550
|
-
providers:
|
|
571
|
+
providers: dict[str, dict[str, Any]],
|
|
551
572
|
post_login_redirect: str = "/",
|
|
552
573
|
provider_account_model: type | None = None,
|
|
553
574
|
auth_policy: AuthPolicy | None = None,
|
|
@@ -600,7 +621,12 @@ def _create_oauth_router(
|
|
|
600
621
|
responses={302: {"description": "Redirect to app (or MFA redirect)."}},
|
|
601
622
|
description="OAuth callback endpoint.",
|
|
602
623
|
)
|
|
603
|
-
async def oauth_callback(
|
|
624
|
+
async def oauth_callback(
|
|
625
|
+
request: Request,
|
|
626
|
+
provider: str,
|
|
627
|
+
session: SqlSessionDep,
|
|
628
|
+
strategy: Strategy[Any, Any] = Depends(auth_backend.get_strategy),
|
|
629
|
+
):
|
|
604
630
|
"""Handle OAuth callback and complete authentication."""
|
|
605
631
|
# Handle provider-side errors up front
|
|
606
632
|
if err := request.query_params.get("error"):
|
|
@@ -621,14 +647,20 @@ def _create_oauth_router(
|
|
|
621
647
|
|
|
622
648
|
# Extract user information from provider
|
|
623
649
|
cfg = providers.get(provider, {})
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
650
|
+
(
|
|
651
|
+
email,
|
|
652
|
+
full_name,
|
|
653
|
+
provider_user_id,
|
|
654
|
+
email_verified,
|
|
655
|
+
raw_claims,
|
|
656
|
+
) = await _extract_user_info_from_provider(request, client, token, provider, cfg, nonce)
|
|
627
657
|
|
|
628
658
|
if email_verified is False:
|
|
629
659
|
raise HTTPException(400, "unverified_email")
|
|
630
660
|
if not email:
|
|
631
661
|
raise HTTPException(400, "No email from provider")
|
|
662
|
+
if not provider_user_id:
|
|
663
|
+
raise HTTPException(400, "No user ID from provider")
|
|
632
664
|
|
|
633
665
|
# Process user authentication
|
|
634
666
|
user = await _process_user_authentication(
|
|
@@ -657,7 +689,7 @@ def _create_oauth_router(
|
|
|
657
689
|
return mfa_response
|
|
658
690
|
|
|
659
691
|
# NEW: set last_login only when we are actually logging in now
|
|
660
|
-
user.last_login = datetime.now(
|
|
692
|
+
user.last_login = datetime.now(UTC)
|
|
661
693
|
await session.commit()
|
|
662
694
|
|
|
663
695
|
# Create session + initial refresh token
|
|
@@ -669,9 +701,24 @@ def _create_oauth_router(
|
|
|
669
701
|
ip_hash=None,
|
|
670
702
|
)
|
|
671
703
|
|
|
672
|
-
#
|
|
704
|
+
# Generate JWT token for the response
|
|
705
|
+
jwt_token = await strategy.write_token(user)
|
|
706
|
+
|
|
707
|
+
# If redirecting to a different origin, append token as URL fragment for frontend to extract
|
|
708
|
+
# This handles cross-port scenarios like localhost:8000 -> localhost:3000
|
|
709
|
+
parsed_redirect = urlparse(redirect_url)
|
|
710
|
+
request_origin = f"{request.url.scheme}://{request.url.netloc}"
|
|
711
|
+
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
|
712
|
+
|
|
713
|
+
if redirect_origin and redirect_origin != request_origin:
|
|
714
|
+
# Cross-origin redirect: append token as URL fragment
|
|
715
|
+
# Fragment is not sent to server, only accessible to client-side JS
|
|
716
|
+
separator = "#" if not parsed_redirect.fragment else "&"
|
|
717
|
+
redirect_url = f"{redirect_url}{separator}access_token={jwt_token}"
|
|
718
|
+
|
|
719
|
+
# Create response with auth + refresh cookies (for same-origin requests)
|
|
673
720
|
resp = RedirectResponse(url=redirect_url, status_code=status.HTTP_302_FOUND)
|
|
674
|
-
await _set_cookie_on_response(resp,
|
|
721
|
+
await _set_cookie_on_response(resp, strategy, user, refresh_raw=raw_refresh)
|
|
675
722
|
|
|
676
723
|
# Clean up session state
|
|
677
724
|
_clean_oauth_session_state(request, provider)
|
|
@@ -691,7 +738,11 @@ def _create_oauth_router(
|
|
|
691
738
|
responses={204: {"description": "Cookie refreshed"}},
|
|
692
739
|
description="Refresh authentication token.",
|
|
693
740
|
)
|
|
694
|
-
async def refresh(
|
|
741
|
+
async def refresh(
|
|
742
|
+
request: Request,
|
|
743
|
+
session: SqlSessionDep,
|
|
744
|
+
strategy: Strategy[Any, Any] = Depends(auth_backend.get_strategy),
|
|
745
|
+
):
|
|
695
746
|
"""Refresh authentication token."""
|
|
696
747
|
st = get_auth_settings()
|
|
697
748
|
|
|
@@ -705,7 +756,7 @@ def _create_oauth_router(
|
|
|
705
756
|
user_id = await _validate_and_decode_jwt_token(raw_auth)
|
|
706
757
|
|
|
707
758
|
# Load user
|
|
708
|
-
user = await session.get(user_model, user_id)
|
|
759
|
+
user = await cast("Any", session).get(user_model, user_id)
|
|
709
760
|
if not user:
|
|
710
761
|
raise HTTPException(401, "invalid_token")
|
|
711
762
|
|
|
@@ -733,22 +784,17 @@ def _create_oauth_router(
|
|
|
733
784
|
if (
|
|
734
785
|
not found
|
|
735
786
|
or found.revoked_at
|
|
736
|
-
or (found.expires_at and found.expires_at < datetime.now(
|
|
787
|
+
or (found.expires_at and found.expires_at < datetime.now(UTC))
|
|
737
788
|
):
|
|
738
789
|
raise HTTPException(401, "invalid_refresh_token")
|
|
739
790
|
|
|
740
791
|
# Rotate refresh token
|
|
741
|
-
|
|
742
|
-
new_raw, _new_rt = await rotate_session_refresh(session, current=found)
|
|
743
|
-
except ValueError:
|
|
744
|
-
# Token expired between validation and rotation; treat as invalid
|
|
745
|
-
raise HTTPException(401, "invalid_refresh_token") from None
|
|
792
|
+
new_raw, _new_rt = await rotate_session_refresh(session, current=found)
|
|
746
793
|
|
|
747
794
|
# Write response (204) with new cookies
|
|
748
795
|
resp = Response(status_code=status.HTTP_204_NO_CONTENT)
|
|
749
|
-
await _set_cookie_on_response(resp,
|
|
750
|
-
|
|
751
|
-
# Dead code removed: MFA branch handled earlier in login flow, refresh returns 204 above.
|
|
796
|
+
await _set_cookie_on_response(resp, strategy, user, refresh_raw=new_raw)
|
|
797
|
+
# Policy hook: trigger after successful rotation; suppress hook errors
|
|
752
798
|
if hasattr(policy, "on_token_refresh"):
|
|
753
799
|
try:
|
|
754
800
|
await policy.on_token_refresh(user)
|