vibetuner 2.6.1__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of vibetuner might be problematic. Click here for more details.
- vibetuner/__init__.py +2 -0
- vibetuner/__main__.py +4 -0
- vibetuner/cli/__init__.py +68 -0
- vibetuner/cli/run.py +161 -0
- vibetuner/config.py +128 -0
- vibetuner/context.py +25 -0
- vibetuner/frontend/AGENTS.md +113 -0
- vibetuner/frontend/CLAUDE.md +113 -0
- vibetuner/frontend/__init__.py +94 -0
- vibetuner/frontend/context.py +10 -0
- vibetuner/frontend/deps.py +41 -0
- vibetuner/frontend/email.py +45 -0
- vibetuner/frontend/hotreload.py +13 -0
- vibetuner/frontend/lifespan.py +26 -0
- vibetuner/frontend/middleware.py +151 -0
- vibetuner/frontend/oauth.py +196 -0
- vibetuner/frontend/routes/__init__.py +12 -0
- vibetuner/frontend/routes/auth.py +150 -0
- vibetuner/frontend/routes/debug.py +414 -0
- vibetuner/frontend/routes/health.py +33 -0
- vibetuner/frontend/routes/language.py +43 -0
- vibetuner/frontend/routes/meta.py +55 -0
- vibetuner/frontend/routes/user.py +94 -0
- vibetuner/frontend/templates.py +176 -0
- vibetuner/logging.py +87 -0
- vibetuner/models/AGENTS.md +165 -0
- vibetuner/models/CLAUDE.md +165 -0
- vibetuner/models/__init__.py +14 -0
- vibetuner/models/blob.py +89 -0
- vibetuner/models/email_verification.py +84 -0
- vibetuner/models/mixins.py +76 -0
- vibetuner/models/oauth.py +57 -0
- vibetuner/models/registry.py +15 -0
- vibetuner/models/types.py +16 -0
- vibetuner/models/user.py +91 -0
- vibetuner/mongo.py +18 -0
- vibetuner/paths.py +112 -0
- vibetuner/services/AGENTS.md +104 -0
- vibetuner/services/CLAUDE.md +104 -0
- vibetuner/services/__init__.py +0 -0
- vibetuner/services/blob.py +175 -0
- vibetuner/services/email.py +50 -0
- vibetuner/tasks/AGENTS.md +98 -0
- vibetuner/tasks/CLAUDE.md +98 -0
- vibetuner/tasks/__init__.py +2 -0
- vibetuner/tasks/context.py +34 -0
- vibetuner/tasks/worker.py +18 -0
- vibetuner/templates/email/AGENTS.md +48 -0
- vibetuner/templates/email/CLAUDE.md +48 -0
- vibetuner/templates/email/default/magic_link.html.jinja +16 -0
- vibetuner/templates/email/default/magic_link.txt.jinja +5 -0
- vibetuner/templates/frontend/AGENTS.md +74 -0
- vibetuner/templates/frontend/CLAUDE.md +74 -0
- vibetuner/templates/frontend/base/favicons.html.jinja +1 -0
- vibetuner/templates/frontend/base/footer.html.jinja +3 -0
- vibetuner/templates/frontend/base/header.html.jinja +0 -0
- vibetuner/templates/frontend/base/opengraph.html.jinja +7 -0
- vibetuner/templates/frontend/base/skeleton.html.jinja +42 -0
- vibetuner/templates/frontend/debug/collections.html.jinja +103 -0
- vibetuner/templates/frontend/debug/components/debug_nav.html.jinja +55 -0
- vibetuner/templates/frontend/debug/index.html.jinja +83 -0
- vibetuner/templates/frontend/debug/info.html.jinja +256 -0
- vibetuner/templates/frontend/debug/users.html.jinja +137 -0
- vibetuner/templates/frontend/debug/version.html.jinja +53 -0
- vibetuner/templates/frontend/email/magic_link.txt.jinja +5 -0
- vibetuner/templates/frontend/email_sent.html.jinja +82 -0
- vibetuner/templates/frontend/index.html.jinja +19 -0
- vibetuner/templates/frontend/lang/select.html.jinja +4 -0
- vibetuner/templates/frontend/login.html.jinja +84 -0
- vibetuner/templates/frontend/meta/browserconfig.xml.jinja +10 -0
- vibetuner/templates/frontend/meta/robots.txt.jinja +3 -0
- vibetuner/templates/frontend/meta/site.webmanifest.jinja +7 -0
- vibetuner/templates/frontend/meta/sitemap.xml.jinja +6 -0
- vibetuner/templates/frontend/user/edit.html.jinja +85 -0
- vibetuner/templates/frontend/user/profile.html.jinja +156 -0
- vibetuner/templates/markdown/.placeholder +0 -0
- vibetuner/templates/markdown/AGENTS.md +29 -0
- vibetuner/templates/markdown/CLAUDE.md +29 -0
- vibetuner/templates.py +152 -0
- vibetuner/time.py +57 -0
- vibetuner/versioning.py +8 -0
- {vibetuner-2.6.1.dist-info → vibetuner-2.7.0.dist-info}/METADATA +2 -1
- vibetuner-2.7.0.dist-info/RECORD +84 -0
- vibetuner-2.6.1.dist-info/RECORD +0 -4
- {vibetuner-2.6.1.dist-info → vibetuner-2.7.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, Depends as Depends, FastAPI, Request
|
|
4
|
+
from fastapi.responses import HTMLResponse, RedirectResponse
|
|
5
|
+
from fastapi.staticfiles import StaticFiles
|
|
6
|
+
|
|
7
|
+
from vibetuner import paths
|
|
8
|
+
|
|
9
|
+
from .deps import LangDep as LangDep, MagicCookieDep as MagicCookieDep
|
|
10
|
+
from .lifespan import ctx, lifespan
|
|
11
|
+
from .middleware import middlewares
|
|
12
|
+
from .routes import auth, debug, health, language, meta, user
|
|
13
|
+
from .templates import render_template
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Routers contributed by the host application through ``register_router``.
_registered_routers: list[APIRouter] = []


def register_router(router: APIRouter) -> None:
    """Queue ``router`` for inclusion in the frontend ``app``.

    The queue is drained exactly once, by the module-level loop that runs
    while this module is being imported, so a router has to be registered
    before that point (e.g. from the optional ``app.frontend`` modules
    imported just below) for it to be mounted.
    """
    queue = _registered_routers
    queue.append(router)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _import_optional_app_module(module_name: str) -> None:
    """Import an optional host-application module for its side effects.

    The modules imported here are presumably where the host app calls
    ``register_oauth_provider`` / ``register_router``. Each module is imported
    independently, so a missing ``app.frontend.oauth`` no longer prevents
    ``app.frontend.routes`` from loading. Only the absence of the module itself
    (or one of its parent packages) is ignored. Errors raised while importing a
    module that does exist propagate, instead of silently dropping its routes.
    """
    import importlib

    try:
        importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        missing = exc.name or ""
        # e.g. "app" or "app.frontend.oauth" missing -> fine; "requests" missing
        # from inside the user's module -> a real error worth surfacing.
        is_target_or_parent = bool(missing) and (
            module_name == missing or module_name.startswith(missing + ".")
        )
        if not is_target_or_parent:
            raise


for _optional_module in ("app.frontend.oauth", "app.frontend.routes"):
    _import_optional_app_module(_optional_module)
del _optional_module
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Dependencies applied to every route of the frontend app (empty by default).
dependencies: list[Any] = [
    # Add any dependencies that should be available globally
]

app = FastAPI(
    lifespan=lifespan,
    middleware=middlewares,
    dependencies=dependencies,
    debug=ctx.DEBUG,
    # The OpenAPI schema and both interactive doc UIs are switched off.
    openapi_url=None,
    docs_url=None,
    redoc_url=None,
)
|
|
43
|
+
|
|
44
|
+
# Static files: css/img/js live under a prefix containing ``ctx.v_hash``
# (presumably for cache-busting); favicons keep a fixed, unversioned path.
for _asset_kind in ("css", "img", "js"):
    app.mount(
        f"/static/v{ctx.v_hash}/{_asset_kind}",
        StaticFiles(directory=getattr(paths, _asset_kind)),
        name=_asset_kind,
    )
del _asset_kind

app.mount("/static/favicons", StaticFiles(directory=paths.favicons), name="favicons")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@app.get("/static/v{v_hash}/css/{subpath:path}", response_class=RedirectResponse)
@app.get("/static/css/{subpath:path}", response_class=RedirectResponse)
def css_redirect(request: Request, subpath: str):
    """Redirect unversioned or other-hash CSS URLs to the current ``css`` mount.

    The mount for the current hash is registered before this route, so it
    handles those requests itself. The returned path is used as the
    ``RedirectResponse`` target.
    """
    target = request.url_for("css", path=subpath)
    return target.path
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@app.get("/static/v{v_hash}/img/{subpath:path}", response_class=RedirectResponse)
@app.get("/static/img/{subpath:path}", response_class=RedirectResponse)
def img_redirect(request: Request, subpath: str):
    """Redirect unversioned or other-hash image URLs to the current ``img`` mount.

    The mount for the current hash is registered before this route, so it
    handles those requests itself. The returned path is used as the
    ``RedirectResponse`` target.
    """
    target = request.url_for("img", path=subpath)
    return target.path
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@app.get("/static/v{v_hash}/js/{subpath:path}", response_class=RedirectResponse)
@app.get("/static/js/{subpath:path}", response_class=RedirectResponse)
def js_redirect(request: Request, subpath: str):
    """Redirect unversioned or other-hash JS URLs to the current ``js`` mount.

    The mount for the current hash is registered before this route, so it
    handles those requests itself. The returned path is used as the
    ``RedirectResponse`` target.
    """
    target = request.url_for("js", path=subpath)
    return target.path
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
if ctx.DEBUG:
    # Development only: expose the arel live-reload websocket endpoint.
    from .hotreload import hotreload

    app.add_websocket_route("/hot-reload", route=hotreload, name="hot-reload")  # type: ignore
|
|
78
|
+
|
|
79
|
+
# Built-in frontend routers, included in this fixed order (routes are tried
# in registration order, so order matters).
for _builtin in (meta, auth, user, language):
    app.include_router(_builtin.router)
del _builtin

# Routers queued through ``register_router``; drained once, here, at import time.
for router in _registered_routers:
    app.include_router(router)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@app.get("/", name="homepage", response_class=HTMLResponse)
def default_index(request: Request) -> HTMLResponse:
    """Render the default landing page from ``index.html.jinja``.

    Added after the host application's registered routers, so an application
    router that also serves ``/`` is matched before this fallback.
    """
    template_name = "index.html.jinja"
    return render_template(template_name, request)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# Debug and health-check routers are included last.
for _late in (debug, health):
    app.include_router(_late.router)
del _late
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from typing import Annotated, Optional
|
|
2
|
+
|
|
3
|
+
from fastapi import Depends, HTTPException, Request
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
async def require_htmx(request: Request) -> None:
    """Dependency rejecting (HTTP 400) requests that are not HTMX requests.

    Relies on ``request.state.htmx`` (set by the HTMX middleware) being truthy
    for HTMX requests.
    """
    is_htmx_request = bool(request.state.htmx)
    if is_htmx_request:
        return
    raise HTTPException(status_code=400, detail="HTMX header not found")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
async def enforce_lang(request: Request, lang: Optional[str] = None):
    """Ensure the ``lang`` parameter matches the negotiated request language.

    Returns ``request.state.language`` when ``lang`` equals it. Otherwise the
    client is redirected (307) to the same route with ``lang`` replaced by the
    negotiated language and all other path parameters kept.

    Raises:
        HTTPException: status 307 with a ``Location`` header for the canonical URL.
    """
    language = request.state.language
    if lang is not None and lang == language:
        return language

    # ``url_for`` resolves routes by their *name*, which only equals the endpoint's
    # ``__name__`` when the route was registered without an explicit ``name=``.
    # Prefer the matched route's name; fall back to the endpoint name otherwise.
    matched_route = request.scope.get("route")
    route_name = getattr(matched_route, "name", None) or request.scope["endpoint"].__name__
    redirect_url = request.url_for(
        route_name,
        **{**request.path_params, "lang": language},
    ).path
    raise HTTPException(
        status_code=307,
        detail=f"Redirecting to canonical language: {language}",
        headers={"Location": redirect_url},
    )
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Annotated dependency: a parameter typed ``LangDep`` runs ``enforce_lang`` and
# receives the canonical language string.
LangDep = Annotated[str, Depends(dependency=enforce_lang)]

# Name of the cookie checked by ``require_magic_cookie``.
MAGIC_COOKIE_NAME = "magic_access"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def require_magic_cookie(request: Request) -> None:
    """Dependency to check if the magic access cookie is present.

    Raises HTTP 403 unless the ``MAGIC_COOKIE_NAME`` cookie exists and equals
    ``"granted"``.

    NOTE(review): the cookie is compared with a fixed literal and is not signed,
    so any client can set it on its own; treat this as light gating only.
    """
    cookie_value = request.cookies.get(MAGIC_COOKIE_NAME)
    if cookie_value != "granted":
        raise HTTPException(status_code=403, detail="Access forbidden")


MagicCookieDep = Depends(require_magic_cookie)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from pydantic import EmailStr
|
|
2
|
+
from starlette_babel import gettext_lazy as _
|
|
3
|
+
|
|
4
|
+
from vibetuner.config import settings
|
|
5
|
+
from vibetuner.services.email import SESEmailService
|
|
6
|
+
|
|
7
|
+
from .templates import render_static_template
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
async def send_magic_link_email(
    ses_service: SESEmailService,
    lang: str,
    to_address: EmailStr,
    login_url: str,
) -> None:
    """Render the magic-link email (HTML and plain text) and send it via SES.

    Both bodies are rendered from the ``email`` template namespace in ``lang``
    with the login URL and the configured project name. The subject is the
    translated "Sign in to {project_name}" string.
    """
    project_name = settings.project.project_name

    def _render(template_name):
        # A fresh context dict for every render.
        return render_static_template(
            template_name,
            namespace="email",
            lang=lang,
            context={"login_url": str(login_url), "project_name": project_name},
        )

    html_body = _render("magic_link.html")
    text_body = _render("magic_link.txt")
    subject = _("Sign in to {project_name}").format(
        project_name=settings.project.project_name
    )

    await ses_service.send_email(
        subject=subject,
        html_body=html_body,
        text_body=text_body,
        to_address=to_address,
    )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import arel
|
|
2
|
+
|
|
3
|
+
from vibetuner.paths import css as css_path, js as js_path, templates as templates_path
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Development live-reload watcher over the JS, CSS and template directories;
# reconnect_interval=2 (presumably seconds) is passed to arel unchanged.
hotreload = arel.HotReload(
    paths=[arel.Path(str(watched)) for watched in (js_path, css_path, templates_path)],
    reconnect_interval=2,
)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from contextlib import asynccontextmanager
|
|
2
|
+
|
|
3
|
+
from fastapi import FastAPI
|
|
4
|
+
|
|
5
|
+
from vibetuner.mongo import init_models
|
|
6
|
+
|
|
7
|
+
from .context import ctx
|
|
8
|
+
from .hotreload import hotreload
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start dev hot-reload, initialise models, clean up.

    In DEBUG mode the hot-reload watcher is started first. Its shutdown runs in
    ``finally``, so it also happens when ``init_models`` or a startup/shutdown
    hook raises, or when the app is torn down with an error.
    """
    if ctx.DEBUG:
        await hotreload.startup()

    try:
        await init_models()
        # Add below anything that should happen before startup

        # Until here
        yield

        # Add below anything that should happen before shutdown

        # Until here
    finally:
        if ctx.DEBUG:
            await hotreload.shutdown()
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from fastapi import Request, Response
|
|
2
|
+
from fastapi.middleware import Middleware
|
|
3
|
+
from fastapi.requests import HTTPConnection
|
|
4
|
+
from starlette.authentication import AuthCredentials, AuthenticationBackend
|
|
5
|
+
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
6
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
7
|
+
from starlette.middleware.sessions import SessionMiddleware
|
|
8
|
+
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|
9
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
10
|
+
from starlette_babel import (
|
|
11
|
+
LocaleFromCookie,
|
|
12
|
+
LocaleFromQuery,
|
|
13
|
+
LocaleMiddleware,
|
|
14
|
+
get_translator,
|
|
15
|
+
)
|
|
16
|
+
from starlette_htmx.middleware import HtmxMiddleware # type: ignore[import-untyped]
|
|
17
|
+
|
|
18
|
+
from vibetuner.config import settings
|
|
19
|
+
from vibetuner.paths import locales as locales_path
|
|
20
|
+
|
|
21
|
+
from .context import ctx
|
|
22
|
+
from .oauth import WebUser
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def locale_selector(conn: HTTPConnection) -> str | None:
    """Pick the locale from a leading two-letter URL path segment.

    ``/en/about`` yields ``"en"``. The first segment must be exactly two
    lowercase alphabetic characters; anything else yields ``None``.
    """
    first_segment, _sep, _rest = conn.scope.get("path", "").strip("/").partition("/")
    looks_like_code = (
        len(first_segment) == 2
        and first_segment.islower()
        and first_segment.isalpha()
    )
    return first_segment if looks_like_code else None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def user_preference_selector(conn: HTTPConnection) -> str | None:
    """Pick the locale from the signed-in user's language preference in the session.

    Reads ``session["user"]["settings"]["language"]``. It returns the value
    lowercased when it is a two-character string, and ``None`` otherwise. Only
    the session is consulted, so no database query is made. In ``middlewares``
    this selector comes after the ``?l=`` query and URL-prefix selectors and
    before the cookie selector.

    Malformed session data (e.g. a non-dict ``user`` or ``settings`` value)
    yields ``None`` instead of raising inside the middleware.

    NOTE(review): ``WebUser``, validated from the same ``session["user"]``, has
    a top-level ``language`` field rather than ``settings.language``. Confirm
    which shape ``UserModel.session_dict`` actually produces.
    """
    from collections.abc import Mapping

    def _child(container: object, key: str) -> object:
        # Step one level into nested session data, tolerating wrong types.
        return container.get(key) if isinstance(container, Mapping) else None

    user_settings = _child(_child(conn.scope.get("session"), "user"), "settings")
    language = _child(user_settings, "language")
    if isinstance(language, str) and len(language) == 2:
        return language.lower()
    return None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Translator from starlette_babel; extra translations are loaded from the
# locales directory only when that directory is present.
shared_translator = get_translator()
if locales_path.is_dir():  # is_dir() is also False for a missing path
    shared_translator.load_from_directories([locales_path])
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class AdjustLangCookieMiddleware(BaseHTTPMiddleware):
    """Keep the ``language`` cookie in sync with ``request.state.language``.

    After the downstream app has produced a response, the cookie is (re)set for
    3600 seconds whenever it is missing, empty, or differs from the request's
    language.
    """

    async def dispatch(self, request: Request, call_next):
        response: Response = await call_next(request)

        cookie = request.cookies.get("language")
        already_in_sync = bool(cookie) and cookie == request.state.language
        if not already_in_sync:
            response.set_cookie(
                key="language", value=request.state.language, max_age=3600
            )
        return response
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ForwardedProtocolMiddleware:
    """Pure ASGI middleware setting ``scope["scheme"]`` from ``X-Forwarded-Proto``.

    Adapted from uvicorn's proxy-headers middleware. Lifespan events pass
    through untouched. For other connections, a header value of
    http/https/ws/wss replaces the scheme, and websocket connections get the
    ``ws``/``wss`` form. The header is accepted from any client; there is no
    trusted-proxy check.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    # Based on https://github.com/encode/uvicorn/blob/master/uvicorn/middleware/proxy_headers.py
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "lifespan":
            raw_proto = dict(scope["headers"]).get(b"x-forwarded-proto")
            if raw_proto is not None:
                proto = raw_proto.decode("latin1").strip()
                if proto in {"http", "https", "ws", "wss"}:
                    is_websocket = scope["type"] == "websocket"
                    scope["scheme"] = proto.replace("http", "ws") if is_websocket else proto
        return await self.app(scope, receive, send)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class AuthBackend(AuthenticationBackend):
    """Authenticate requests from the ``user`` dict stored in the session."""

    async def authenticate(
        self,
        conn: HTTPConnection,
    ) -> tuple[AuthCredentials, WebUser] | None:
        user_data = conn.session.get("user")
        if not user_data:
            return None

        try:
            web_user = WebUser.model_validate(user_data)
        except Exception:
            # Clear corrupted session data and continue unauthenticated
            conn.session.pop("user", None)
            return None

        return AuthCredentials(["authenticated"]), web_user
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Until this line
# Starlette wraps the app with the first entry outermost, so order matters:
# SessionMiddleware must come before LocaleMiddleware (user_preference_selector
# reads scope["session"]) and AuthenticationMiddleware (AuthBackend reads
# conn.session).
middlewares: list[Middleware] = [
    # NOTE(review): no allowed_hosts is passed, so Starlette's default applies
    # (believed to be a wildcard). Confirm that is intended.
    Middleware(TrustedHostMiddleware),
    Middleware(ForwardedProtocolMiddleware),
    Middleware(HtmxMiddleware),
    Middleware(
        SessionMiddleware,
        secret_key=settings.session_key.get_secret_value(),
    ),
    Middleware(
        LocaleMiddleware,
        locales=list(ctx.supported_languages),
        default_locale=ctx.default_language,
        # Consulted in list order (presumably the first match wins):
        # ?l=xx query, /xx/ path prefix, session preference, cookie.
        selectors=[
            LocaleFromQuery(query_param="l"),
            locale_selector,
            user_preference_selector,
            LocaleFromCookie(),
        ],
    ),
    Middleware(AdjustLangCookieMiddleware),
    Middleware(AuthenticationMiddleware, backend=AuthBackend()),
    # Add your middleware below this line
]

# EOF
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from authlib.integrations.base_client.errors import OAuthError
|
|
4
|
+
from authlib.integrations.starlette_client import OAuth # ty: ignore[unresolved-import]
|
|
5
|
+
from fastapi import Request
|
|
6
|
+
from fastapi.responses import RedirectResponse
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
from pydantic_extra_types.language_code import LanguageAlpha2
|
|
9
|
+
from starlette.authentication import BaseUser
|
|
10
|
+
|
|
11
|
+
from vibetuner.frontend.routes import get_homepage_url
|
|
12
|
+
from vibetuner.models.oauth import OAuthAccountModel, OauthProviderModel
|
|
13
|
+
from vibetuner.models.user import UserModel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Fallback avatar URL for users without a picture. It uses the unversioned
# "/static/img/..." prefix, which the frontend redirects to the current
# version-hashed image mount. The previous "/statics/..." path matched no
# mount or route.
DEFAULT_AVATAR_IMAGE = "/static/img/user-avatar.png"

# OAuth providers registered through ``register_oauth_provider``, keyed by name.
_PROVIDERS: dict[str, OauthProviderModel] = {}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def register_oauth_provider(name: str, provider: OauthProviderModel) -> None:
    """Register ``provider`` under ``name`` with the shared authlib ``oauth`` client.

    Records the provider and the userinfo key used as its account identifier,
    merges ``provider.config`` into ``_oauth_config``, then calls
    ``oauth.register`` with ``overwrite=True``. ``provider.params`` take
    precedence over the ``client_kwargs`` entry when both define it.
    """
    _PROVIDERS[name] = provider
    PROVIDER_IDENTIFIERS[name] = provider.identifier
    _oauth_config.update(**provider.config)

    options = dict(client_kwargs=provider.client_kwargs)
    options.update(provider.params)
    oauth.register(name, overwrite=True, **options)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class WebUser(BaseUser, BaseModel):
    """Authenticated user exposed as ``request.user``.

    Validated by ``AuthBackend`` from the ``user`` dict stored in the session.
    Implements Starlette's ``BaseUser`` interface: ``is_authenticated``,
    ``display_name`` and ``identity``.
    """

    id: str
    name: str
    email: str
    picture: Optional[str] = Field(
        default=DEFAULT_AVATAR_IMAGE,
        description="URL to the user's avatar image",
    )
    language: Optional[LanguageAlpha2] = Field(
        default=None,
        description="Preferred language for the user",
    )

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.name

    @property
    def identity(self) -> str:
        # Starlette's BaseUser.identity raises NotImplementedError by default;
        # the user id is the stable identity of a session user.
        return self.id
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class Config:
    """Minimal mutable key/value store handed to authlib's ``OAuth``.

    ``register_oauth_provider`` merges each provider's ``config`` into the
    module-level instance. Besides ``get``/``update`` it supports ``in`` and
    ``[]`` lookups, so it also works when used as a read-only mapping. It
    deliberately does not define ``__len__``, so an empty instance stays truthy
    as before.
    """

    def __init__(self, **kwargs):
        self._data = kwargs

    def get(self, key, default=None):
        return self._data.get(key, default)

    def update(self, **kwargs):
        self._data.update(kwargs)

    def __contains__(self, key) -> bool:
        # NOTE(review): authlib's Starlette integration may wrap non-Starlette
        # configs in ``starlette.config.Config(environ=...)``, which looks keys
        # up with ``in`` and ``[]``. Supporting both keeps that path working.
        return key in self._data

    def __getitem__(self, key):
        return self._data[key]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Per-provider userinfo key used as the stable account identifier.
PROVIDER_IDENTIFIERS: dict[str, str] = {}

# Shared authlib OAuth registry over a mutable Config instance;
# register_oauth_provider() fills both in.
_oauth_config = Config()
oauth = OAuth(_oauth_config)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_oauth_providers() -> list[str]:
    """Return the names of all registered OAuth providers, in insertion order."""
    return [provider_name for provider_name in _PROVIDERS]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
async def _handle_user_account(
    provider: str, identifier: str, email: str, name: str, picture: str
) -> UserModel:
    """Resolve the local user for an OAuth login, linking or creating as needed.

    - Known OAuth account: return the user with ``email``, or raise ``OAuthError``
      if there is none.
    - Unknown OAuth account, existing user with ``email``: link it to that user.
    - Otherwise: create a new user together with the OAuth account.

    NOTE(review): for an already-linked OAuth account the user is found by the
    email the provider reports *now*, not through the stored link. If that
    email changed at the provider, the login fails or may resolve to a
    different local user. Consider resolving via the OAuth link instead.
    """
    linked_oauth = await OAuthAccountModel.get_by_provider_and_id(
        provider=provider,
        provider_user_id=identifier,
    )
    user = await UserModel.get_by_email(email)

    if linked_oauth:
        if not user:
            raise OAuthError("No account linked to this OAuth account")
        return user

    if user:
        await _link_oauth_account(user, provider, identifier, email, name, picture)
        return user

    return await _create_new_user_with_oauth(provider, identifier, email, name, picture)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
async def _link_oauth_account(
    account: UserModel,
    provider: str,
    identifier: str,
    email: str,
    name: str,
    picture: str,
) -> None:
    """Persist a new OAuth account record and attach it to ``account``.

    The OAuth record is inserted first, then appended to
    ``account.oauth_accounts`` and the user is saved.
    """
    link = OAuthAccountModel(
        provider=provider,
        provider_user_id=identifier,
        email=email,
        name=name,
        picture=picture,
    )
    await link.insert()

    account.oauth_accounts.append(link)
    await account.save()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
async def _create_new_user_with_oauth(
    provider: str, identifier: str, email: str, name: str, picture: str
) -> UserModel:
    """Create a user together with its first OAuth account and return the user.

    The OAuth record is inserted before the user that references it.
    """
    profile = {"email": email, "name": name, "picture": picture}

    oauth_account = OAuthAccountModel(
        provider=provider,
        provider_user_id=identifier,
        **profile,
    )
    await oauth_account.insert()

    account = UserModel(**profile, oauth_accounts=[oauth_account])
    await account.insert()
    return account
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _safe_next_url(candidate: str | None, fallback: str) -> str:
    """Return ``candidate`` if it is a same-site absolute path, else ``fallback``.

    Prevents open redirects after login. Only local paths such as
    ``/dashboard?x=1`` are kept. The following are rejected:

    - absolute URLs (``https://evil.example``);
    - scheme-relative ``//host`` forms;
    - backslash variants such as ``/\\host``, which browsers may treat as ``//host``;
    - values containing control characters (browsers strip tabs and newlines,
      which can turn ``/\\t/host`` into ``//host``).
    """
    from urllib.parse import urlsplit

    if not candidate or not candidate.startswith("/") or candidate.startswith("//"):
        return fallback
    if "\\" in candidate or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate):
        return fallback
    parts = urlsplit(candidate)
    if parts.scheme or parts.netloc:
        return fallback
    return candidate


def _create_auth_login_handler(provider_name: str):
    """Build the login endpoint that starts the OAuth flow for ``provider_name``.

    The returned coroutine stores the post-login destination in
    ``session["next_url"]``: the ``next`` query value when it is a safe local
    path, otherwise the homepage. It then redirects to the provider's
    authorization page, passing the current language as ``hl``. If no OAuth
    client is registered, it redirects to the homepage.
    """

    async def auth_login(request: Request, next: str | None = None):
        redirect_uri = request.url_for(f"auth_with_{provider_name}")
        homepage = get_homepage_url(request)
        # ``next`` comes from the query string, i.e. untrusted input.
        request.session["next_url"] = _safe_next_url(next, homepage)
        client = oauth.create_client(provider_name)
        if not client:
            return RedirectResponse(url=homepage)

        return await client.authorize_redirect(
            request, redirect_uri, hl=request.state.language
        )

    return auth_login
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _create_auth_handler(provider_name: str):
    """Build the OAuth callback endpoint for ``provider_name``.

    The returned coroutine does the following:

    1. Exchanges the authorization response for a token and reads its
       ``userinfo`` claims.
    2. Finds, links or creates the local user.
    3. Stores ``account.session_dict`` in ``session["user"]``.
    4. Returns the URL to redirect to: the saved ``next_url``, or the homepage.

    Any ``OAuthError`` returns the homepage URL. This includes missing
    userinfo, a missing identifier or email, and an explicit
    ``email_verified: false`` claim.
    """

    async def auth_handler(request: Request):
        """Handle OAuth authentication flow."""
        try:
            client = oauth.create_client(provider_name)
            if not client:
                return get_homepage_url(request)

            token = await client.authorize_access_token(request)
            userinfo = token.get("userinfo")
            if not userinfo:
                raise OAuthError("No userinfo found in token")

            identifier = userinfo.get(PROVIDER_IDENTIFIERS[provider_name])
            email = userinfo.get("email")
            # Accounts are matched by provider id and by email, so both are
            # required; otherwise None would flow into lookups and new users.
            if identifier in (None, "") or not email:
                raise OAuthError("Incomplete userinfo returned by provider")
            # Email-based account linking is only safe for addresses the
            # provider vouches for; reject an explicit unverified claim.
            if userinfo.get("email_verified") is False:
                raise OAuthError("Email address not verified by provider")

            account = await _handle_user_account(
                provider_name,
                identifier,
                email,
                userinfo.get("name"),
                userinfo.get("picture"),
            )

            request.session["user"] = account.session_dict
            return request.session.pop("next_url", get_homepage_url(request))
        except OAuthError:
            return get_homepage_url(request)

    return auth_handler
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from fastapi import Request
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_homepage_url(request: Request, path_only: bool = True) -> str:
    """Get homepage URL for the current language.

    First tries ``url_for("homepage", lang=request.state.language)``. If that
    fails for any reason (e.g. the homepage route has no ``lang`` parameter),
    it falls back to the plain ``homepage`` route. Returns only the path by
    default, or the full URL string when ``path_only`` is false.
    """
    try:
        homepage = request.url_for("homepage", lang=request.state.language)
    except Exception:
        homepage = request.url_for("homepage")

    if not path_only:
        return str(homepage)
    return homepage.path
|