skrift 0.1.0a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skrift/__init__.py +1 -0
- skrift/__main__.py +12 -0
- skrift/admin/__init__.py +11 -0
- skrift/admin/controller.py +452 -0
- skrift/admin/navigation.py +105 -0
- skrift/alembic/env.py +92 -0
- skrift/alembic/script.py.mako +26 -0
- skrift/alembic/versions/20260120_210154_09b0364dbb7b_initial_schema.py +70 -0
- skrift/alembic/versions/20260122_152744_0b7c927d2591_add_roles_and_permissions.py +57 -0
- skrift/alembic/versions/20260122_172836_cdf734a5b847_add_sa_orm_sentinel_column.py +31 -0
- skrift/alembic/versions/20260122_175637_a9c55348eae7_remove_page_type_column.py +43 -0
- skrift/alembic/versions/20260122_200000_add_settings_table.py +38 -0
- skrift/alembic/versions/20260129_add_oauth_accounts.py +141 -0
- skrift/alembic/versions/20260129_add_provider_metadata.py +29 -0
- skrift/alembic.ini +77 -0
- skrift/asgi.py +670 -0
- skrift/auth/__init__.py +58 -0
- skrift/auth/guards.py +130 -0
- skrift/auth/roles.py +129 -0
- skrift/auth/services.py +184 -0
- skrift/cli.py +143 -0
- skrift/config.py +259 -0
- skrift/controllers/__init__.py +4 -0
- skrift/controllers/auth.py +595 -0
- skrift/controllers/web.py +67 -0
- skrift/db/__init__.py +3 -0
- skrift/db/base.py +7 -0
- skrift/db/models/__init__.py +7 -0
- skrift/db/models/oauth_account.py +50 -0
- skrift/db/models/page.py +26 -0
- skrift/db/models/role.py +56 -0
- skrift/db/models/setting.py +13 -0
- skrift/db/models/user.py +36 -0
- skrift/db/services/__init__.py +1 -0
- skrift/db/services/oauth_service.py +195 -0
- skrift/db/services/page_service.py +217 -0
- skrift/db/services/setting_service.py +206 -0
- skrift/lib/__init__.py +3 -0
- skrift/lib/exceptions.py +168 -0
- skrift/lib/template.py +108 -0
- skrift/setup/__init__.py +14 -0
- skrift/setup/config_writer.py +213 -0
- skrift/setup/controller.py +888 -0
- skrift/setup/middleware.py +89 -0
- skrift/setup/providers.py +214 -0
- skrift/setup/state.py +315 -0
- skrift/static/css/style.css +1003 -0
- skrift/templates/admin/admin.html +19 -0
- skrift/templates/admin/base.html +24 -0
- skrift/templates/admin/pages/edit.html +32 -0
- skrift/templates/admin/pages/list.html +62 -0
- skrift/templates/admin/settings/site.html +32 -0
- skrift/templates/admin/users/list.html +58 -0
- skrift/templates/admin/users/roles.html +42 -0
- skrift/templates/auth/dummy_login.html +102 -0
- skrift/templates/auth/login.html +139 -0
- skrift/templates/base.html +52 -0
- skrift/templates/error-404.html +19 -0
- skrift/templates/error-500.html +19 -0
- skrift/templates/error.html +19 -0
- skrift/templates/index.html +9 -0
- skrift/templates/page.html +26 -0
- skrift/templates/setup/admin.html +24 -0
- skrift/templates/setup/auth.html +110 -0
- skrift/templates/setup/base.html +407 -0
- skrift/templates/setup/complete.html +17 -0
- skrift/templates/setup/configuring.html +158 -0
- skrift/templates/setup/database.html +125 -0
- skrift/templates/setup/restart.html +28 -0
- skrift/templates/setup/site.html +39 -0
- skrift-0.1.0a12.dist-info/METADATA +235 -0
- skrift-0.1.0a12.dist-info/RECORD +74 -0
- skrift-0.1.0a12.dist-info/WHEEL +4 -0
- skrift-0.1.0a12.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
"""Authentication controller for OAuth login flows.
|
|
2
|
+
|
|
3
|
+
Supports multiple OAuth providers: Google, GitHub, Microsoft, Discord, Facebook, X (Twitter).
|
|
4
|
+
Also supports a development-only "dummy" provider for testing.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import base64
|
|
8
|
+
import fnmatch
|
|
9
|
+
import hashlib
|
|
10
|
+
import secrets
|
|
11
|
+
from datetime import UTC, datetime
|
|
12
|
+
from typing import Annotated
|
|
13
|
+
from urllib.parse import urlencode, urlparse
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
from litestar import Controller, Request, get, post
|
|
17
|
+
from litestar.exceptions import HTTPException, NotFoundException
|
|
18
|
+
from litestar.params import Parameter
|
|
19
|
+
from litestar.response import Redirect, Template as TemplateResponse
|
|
20
|
+
from sqlalchemy import select
|
|
21
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
22
|
+
from sqlalchemy.orm import selectinload
|
|
23
|
+
|
|
24
|
+
from skrift.config import get_settings
|
|
25
|
+
from skrift.db.models.oauth_account import OAuthAccount
|
|
26
|
+
from skrift.db.models.user import User
|
|
27
|
+
from skrift.setup.providers import DUMMY_PROVIDER_KEY, OAUTH_PROVIDERS, get_provider_info
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _is_safe_redirect_url(url: str, allowed_domains: list[str]) -> bool:
    """Check if URL is safe to redirect to.

    Relative paths such as "/dashboard" are always allowed. Protocol-relative
    "//host" URLs are never allowed. Absolute URLs must use http/https, and
    their host must match one of ``allowed_domains``.

    Supports wildcard patterns using fnmatch-style matching:
    - "*.example.com" matches any subdomain of example.com
    - "app-*.example.com" matches app-foo.example.com, app-bar.example.com, etc.
    - "example.com" (no wildcards) matches example.com and all subdomains

    URLs containing a backslash or an ASCII control character are always
    rejected. Browsers treat "\\" as "/" and silently drop tabs and newlines,
    so values like "/\\evil.com", "/\\t/evil.com" or
    "https://evil.com\\@good.com" lead off-site even though ``urlparse``
    sees a relative path or a "good.com" host.

    Args:
        url: Candidate redirect target (typically user-supplied).
        allowed_domains: Host patterns that absolute URLs may point to.

    Returns:
        True if ``url`` may be used as a redirect target, otherwise False.
    """
    if "\\" in url or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in url):
        return False

    # Relative paths are always safe, but not protocol-relative //domain.com
    if url.startswith("/"):
        return not url.startswith("//")

    try:
        parsed = urlparse(url)
        # .hostname drops any "user:pass@" prefix and the port, and lowercases.
        # Splitting netloc on ":" instead would read "good.com:x@evil.com" as
        # host "good.com".
        host = parsed.hostname
    except ValueError:
        # e.g. malformed IPv6 brackets
        return False

    # Must be an absolute http(s) URL with a real host
    if parsed.scheme not in ("http", "https") or not parsed.netloc or not host:
        return False

    for raw_pattern in allowed_domains:
        pattern = raw_pattern.lower()
        if not pattern:
            # An empty entry would otherwise match every host ending in "."
            continue
        if "*" in pattern or "?" in pattern:
            if fnmatch.fnmatch(host, pattern):
                return True
        elif host == pattern or host.endswith(f".{pattern}"):
            # No wildcards: exact match or subdomain match
            return True

    return False
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_safe_redirect_url(request: Request, allowed_domains: list[str], default: str = "/") -> str:
    """Consume the post-login destination stored under ``auth_next`` in the session.

    The key is always removed from the session. Its value is returned only
    when it is non-empty and passes ``_is_safe_redirect_url``; in every other
    case ``default`` is returned.
    """
    candidate = request.session.pop("auth_next", None)
    if not candidate:
        return default
    return candidate if _is_safe_redirect_url(candidate, allowed_domains) else default
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_auth_url(provider: str, settings, state: str, code_challenge: str | None = None) -> str:
|
|
81
|
+
"""Build the OAuth authorization URL for a provider."""
|
|
82
|
+
provider_info = get_provider_info(provider)
|
|
83
|
+
if not provider_info:
|
|
84
|
+
raise ValueError(f"Unknown provider: {provider}")
|
|
85
|
+
|
|
86
|
+
provider_config = settings.auth.providers.get(provider)
|
|
87
|
+
if not provider_config:
|
|
88
|
+
raise ValueError(f"Provider {provider} not configured")
|
|
89
|
+
|
|
90
|
+
# Build auth URL (handle Microsoft tenant placeholder)
|
|
91
|
+
auth_url = provider_info.auth_url
|
|
92
|
+
if "{tenant}" in auth_url:
|
|
93
|
+
tenant = getattr(provider_config, "tenant_id", None) or "common"
|
|
94
|
+
auth_url = auth_url.replace("{tenant}", tenant)
|
|
95
|
+
|
|
96
|
+
params = {
|
|
97
|
+
"client_id": provider_config.client_id,
|
|
98
|
+
"redirect_uri": settings.auth.get_redirect_uri(provider),
|
|
99
|
+
"response_type": "code",
|
|
100
|
+
"scope": " ".join(provider_config.scopes),
|
|
101
|
+
"state": state,
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
# Provider-specific parameters
|
|
105
|
+
if provider == "google":
|
|
106
|
+
params["access_type"] = "offline"
|
|
107
|
+
params["prompt"] = "select_account"
|
|
108
|
+
elif provider == "twitter":
|
|
109
|
+
# Twitter requires PKCE
|
|
110
|
+
if code_challenge:
|
|
111
|
+
params["code_challenge"] = code_challenge
|
|
112
|
+
params["code_challenge_method"] = "S256"
|
|
113
|
+
elif provider == "discord":
|
|
114
|
+
params["prompt"] = "consent"
|
|
115
|
+
|
|
116
|
+
return f"{auth_url}?{urlencode(params)}"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
async def exchange_code_for_token(
    provider: str, settings, code: str, code_verifier: str | None = None
) -> dict:
    """Exchange authorization code for access token.

    Args:
        provider: Key of the OAuth provider.
        settings: Application settings holding the provider's client credentials.
        code: Authorization code received on the callback.
        code_verifier: PKCE verifier; only sent for Twitter.

    Returns:
        The token endpoint's decoded JSON response (callers read "access_token").

    Raises:
        ValueError: If the provider is unknown or not configured.
        HTTPException: (400) If the token endpoint answers with a non-200
            status, or with a 200 body carrying "error" and no "access_token".
    """
    provider_info = get_provider_info(provider)
    if not provider_info:
        raise ValueError(f"Unknown provider: {provider}")

    provider_config = settings.auth.providers.get(provider)
    if not provider_config:
        raise ValueError(f"Provider {provider} not configured")

    # Build token URL (handle Microsoft tenant placeholder)
    token_url = provider_info.token_url
    if "{tenant}" in token_url:
        tenant = getattr(provider_config, "tenant_id", None) or "common"
        token_url = token_url.replace("{tenant}", tenant)

    data = {
        "client_id": provider_config.client_id,
        "client_secret": provider_config.client_secret,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": settings.auth.get_redirect_uri(provider),
    }

    # JSON is requested from every provider (GitHub in particular needs this header).
    headers = {"Accept": "application/json"}

    if provider == "twitter":
        # Twitter requires the PKCE code_verifier and authenticates the client
        # with HTTP Basic instead of a client_secret form field.
        if code_verifier:
            data["code_verifier"] = code_verifier
        credentials = base64.b64encode(
            f"{provider_config.client_id}:{provider_config.client_secret}".encode()
        ).decode()
        headers["Authorization"] = f"Basic {credentials}"
        del data["client_secret"]

    async with httpx.AsyncClient() as client:
        response = await client.post(token_url, data=data, headers=headers)

        if response.status_code != 200:
            raise HTTPException(
                status_code=400,
                detail=f"Failed to exchange code for tokens: {response.text}",
            )

        tokens = response.json()

    # Some providers (e.g. GitHub) reject a bad code with HTTP 200 and an
    # "error" field; report that the same way as a non-200 response.
    if isinstance(tokens, dict) and tokens.get("error") and not tokens.get("access_token"):
        raise HTTPException(
            status_code=400,
            detail=f"Failed to exchange code for tokens: {response.text}",
        )

    return tokens
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
async def fetch_user_info(provider: str, access_token: str) -> dict:
    """Fetch user information from the OAuth provider.

    Args:
        provider: Key of the OAuth provider.
        access_token: Bearer token obtained from the token exchange.

    Returns:
        The provider's user-info JSON, with two adjustments:
        - GitHub: "email" is filled from the verified primary address when
          the profile omits it.
        - Twitter: the payload is flattened to
          {"id", "name", "username", "email": None}.

    Raises:
        ValueError: If the provider is unknown.
        HTTPException: (400) If the user-info endpoint returns a non-200 status.
    """
    provider_info = get_provider_info(provider)
    if not provider_info:
        raise ValueError(f"Unknown provider: {provider}")

    headers = {"Authorization": f"Bearer {access_token}"}

    async with httpx.AsyncClient() as client:
        response = await client.get(provider_info.userinfo_url, headers=headers)

        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Failed to fetch user info")

        user_info = response.json()

        # GitHub requires separate email fetch if email is private
        if provider == "github" and not user_info.get("email"):
            email_response = await client.get(
                "https://api.github.com/user/emails", headers=headers
            )
            if email_response.status_code == 200:
                emails = email_response.json()
                # Only a *verified* primary address is accepted. oauth_callback
                # links a login to an existing user with the same email, so an
                # unverified address must not be trusted. The isinstance checks
                # tolerate an unexpected (non-list) payload.
                primary_email = next(
                    (
                        entry.get("email")
                        for entry in (emails if isinstance(emails, list) else [])
                        if isinstance(entry, dict)
                        and entry.get("primary")
                        and entry.get("verified")
                    ),
                    None,
                )
                if primary_email:
                    user_info["email"] = primary_email

    # Twitter has different structure
    if provider == "twitter":
        data = user_info.get("data", {})
        user_info = {
            "id": data.get("id"),
            "name": data.get("name"),
            "username": data.get("username"),
            "email": None,  # Twitter OAuth 2.0 doesn't provide email by default
        }

    return user_info
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def extract_user_data(provider: str, user_info: dict) -> dict:
|
|
218
|
+
"""Extract normalized user data from provider-specific response."""
|
|
219
|
+
if provider == "google":
|
|
220
|
+
return {
|
|
221
|
+
"oauth_id": user_info.get("id"),
|
|
222
|
+
"email": user_info.get("email"),
|
|
223
|
+
"name": user_info.get("name"),
|
|
224
|
+
"picture_url": user_info.get("picture"),
|
|
225
|
+
}
|
|
226
|
+
elif provider == "github":
|
|
227
|
+
return {
|
|
228
|
+
"oauth_id": str(user_info.get("id")),
|
|
229
|
+
"email": user_info.get("email"),
|
|
230
|
+
"name": user_info.get("name") or user_info.get("login"),
|
|
231
|
+
"picture_url": user_info.get("avatar_url"),
|
|
232
|
+
}
|
|
233
|
+
elif provider == "microsoft":
|
|
234
|
+
return {
|
|
235
|
+
"oauth_id": user_info.get("id"),
|
|
236
|
+
"email": user_info.get("mail") or user_info.get("userPrincipalName"),
|
|
237
|
+
"name": user_info.get("displayName"),
|
|
238
|
+
"picture_url": None, # Microsoft Graph requires separate call for photo
|
|
239
|
+
}
|
|
240
|
+
elif provider == "discord":
|
|
241
|
+
avatar = user_info.get("avatar")
|
|
242
|
+
user_id = user_info.get("id")
|
|
243
|
+
avatar_url = None
|
|
244
|
+
if avatar and user_id:
|
|
245
|
+
avatar_url = f"https://cdn.discordapp.com/avatars/{user_id}/{avatar}.png"
|
|
246
|
+
return {
|
|
247
|
+
"oauth_id": user_id,
|
|
248
|
+
"email": user_info.get("email"),
|
|
249
|
+
"name": user_info.get("global_name") or user_info.get("username"),
|
|
250
|
+
"picture_url": avatar_url,
|
|
251
|
+
}
|
|
252
|
+
elif provider == "facebook":
|
|
253
|
+
picture = user_info.get("picture", {}).get("data", {})
|
|
254
|
+
return {
|
|
255
|
+
"oauth_id": user_info.get("id"),
|
|
256
|
+
"email": user_info.get("email"),
|
|
257
|
+
"name": user_info.get("name"),
|
|
258
|
+
"picture_url": picture.get("url") if not picture.get("is_silhouette") else None,
|
|
259
|
+
}
|
|
260
|
+
elif provider == "twitter":
|
|
261
|
+
return {
|
|
262
|
+
"oauth_id": user_info.get("id"),
|
|
263
|
+
"email": user_info.get("email"),
|
|
264
|
+
"name": user_info.get("name") or user_info.get("username"),
|
|
265
|
+
"picture_url": None,
|
|
266
|
+
}
|
|
267
|
+
else:
|
|
268
|
+
return {
|
|
269
|
+
"oauth_id": str(user_info.get("id", user_info.get("sub"))),
|
|
270
|
+
"email": user_info.get("email"),
|
|
271
|
+
"name": user_info.get("name"),
|
|
272
|
+
"picture_url": user_info.get("picture"),
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class AuthController(Controller):
    """OAuth login/callback routes, the dummy (local form) login, and logout.

    Mounted at ``/auth``. A successful login records the user's id, name,
    email and picture URL in the session, then redirects to the validated
    ``next`` URL (default "/").
    """

    path = "/auth"

    @get("/{provider:str}/login")
    async def oauth_login(
        self,
        request: Request,
        provider: str,
        next_url: Annotated[str | None, Parameter(query="next")] = None,
    ) -> Redirect | TemplateResponse:
        """Redirect to OAuth provider consent screen, or show dummy login form.

        Raises:
            NotFoundException: If ``provider`` is unknown or not configured.
        """
        settings = get_settings()

        if not get_provider_info(provider):
            raise NotFoundException(f"Unknown provider: {provider}")
        if provider not in settings.auth.providers:
            raise NotFoundException(f"Provider {provider} not configured")

        # Remember where to go after login; targets failing validation are ignored.
        if next_url and _is_safe_redirect_url(next_url, settings.auth.allowed_redirect_domains):
            request.session["auth_next"] = next_url

        # Dummy provider shows local login form instead of redirecting to OAuth
        if provider == DUMMY_PROVIDER_KEY:
            flash = request.session.pop("flash", None)
            return TemplateResponse("auth/dummy_login.html", context={"flash": flash})

        # CSRF state, verified in oauth_callback. The provider is stored as well
        # so the callback can reject a state that was issued for another provider.
        state = secrets.token_urlsafe(32)
        request.session["oauth_state"] = state
        request.session["oauth_provider"] = provider

        code_challenge = None
        if provider == "twitter":
            # PKCE (S256): the verifier stays in the session; only its hash is sent.
            # The slice caps the verifier at 128 characters.
            code_verifier = secrets.token_urlsafe(64)[:128]
            request.session["oauth_code_verifier"] = code_verifier
            digest = hashlib.sha256(code_verifier.encode()).digest()
            code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")

        return Redirect(path=get_auth_url(provider, settings, state, code_challenge))

    @get("/{provider:str}/callback")
    async def oauth_callback(
        self,
        request: Request,
        db_session: AsyncSession,
        provider: str,
        code: str | None = None,
        oauth_state: Annotated[str | None, Parameter(query="state")] = None,
        error: str | None = None,
    ) -> Redirect:
        """Handle OAuth callback from provider.

        Raises:
            NotFoundException: If ``provider`` is unknown.
            HTTPException: (400) On a state mismatch, a missing code, a missing
                access token, or a user-info payload without an id.
        """
        settings = get_settings()

        if not get_provider_info(provider):
            raise NotFoundException(f"Unknown provider: {provider}")

        # Check for OAuth errors
        if error:
            request.session["flash"] = f"OAuth error: {error}"
            return Redirect(path="/auth/login")

        # Verify CSRF state (single use: it is popped) and, when recorded, that
        # it was issued for this same provider.
        stored_state = request.session.pop("oauth_state", None)
        stored_provider = request.session.get("oauth_provider")
        if (
            not oauth_state
            or oauth_state != stored_state
            or (stored_provider is not None and stored_provider != provider)
        ):
            raise HTTPException(status_code=400, detail="Invalid OAuth state")

        if not code:
            raise HTTPException(status_code=400, detail="Missing authorization code")

        # Get PKCE verifier if present (for Twitter)
        code_verifier = request.session.pop("oauth_code_verifier", None)

        tokens = await exchange_code_for_token(provider, settings, code, code_verifier)
        access_token = tokens.get("access_token")
        if not access_token:
            raise HTTPException(status_code=400, detail="No access token received")

        user_info = await fetch_user_info(provider, access_token)
        user_data = extract_user_data(provider, user_info)

        oauth_id = user_data["oauth_id"]
        if not oauth_id:
            raise HTTPException(status_code=400, detail="Could not determine user ID")

        user = await self._find_or_create_user(
            db_session,
            provider=provider,
            provider_account_id=oauth_id,
            email=user_data["email"],
            name=user_data["name"],
            picture_url=user_data["picture_url"],
            metadata=user_info,
        )
        self._start_user_session(request, user)

        return Redirect(path=_get_safe_redirect_url(request, settings.auth.allowed_redirect_domains))

    @get("/login")
    async def login_page(
        self,
        request: Request,
        next_url: Annotated[str | None, Parameter(query="next")] = None,
    ) -> TemplateResponse:
        """Show login page with available providers."""
        flash = request.session.pop("flash", None)
        settings = get_settings()

        # Store next URL in session if provided and valid
        if next_url and _is_safe_redirect_url(next_url, settings.auth.allowed_redirect_domains):
            request.session["auth_next"] = next_url

        # Configured providers with known metadata; the dummy provider is
        # rendered separately (has_dummy) rather than in this list.
        providers = {
            key: OAUTH_PROVIDERS[key]
            for key in settings.auth.providers.keys()
            if key in OAUTH_PROVIDERS and key != DUMMY_PROVIDER_KEY
        }
        has_dummy = DUMMY_PROVIDER_KEY in settings.auth.providers

        return TemplateResponse(
            "auth/login.html",
            context={
                "flash": flash,
                "providers": providers,
                "has_dummy": has_dummy,
            },
        )

    @post("/dummy-login")
    async def dummy_login_submit(
        self,
        request: Request,
        db_session: AsyncSession,
    ) -> Redirect:
        """Process dummy login form submission.

        Raises:
            NotFoundException: If the dummy provider is not configured.
        """
        settings = get_settings()

        if DUMMY_PROVIDER_KEY not in settings.auth.providers:
            raise NotFoundException("Dummy provider not configured")

        form_data = await request.form()
        email = form_data.get("email", "").strip()
        name = form_data.get("name", "").strip()

        if not email:
            request.session["flash"] = "Email is required"
            return Redirect(path="/auth/dummy/login")

        # Default name to email username if not provided
        if not name:
            name = email.split("@")[0]

        # Deterministic account id so the same email always maps to the same account
        oauth_id = f"dummy_{hashlib.sha256(email.encode()).hexdigest()[:16]}"
        dummy_metadata = {"id": oauth_id, "email": email, "name": name}

        user = await self._find_or_create_user(
            db_session,
            provider=DUMMY_PROVIDER_KEY,
            provider_account_id=oauth_id,
            email=email,
            name=name,
            picture_url=None,
            metadata=dummy_metadata,
            overwrite_email=True,
        )
        self._start_user_session(request, user)

        return Redirect(path=_get_safe_redirect_url(request, settings.auth.allowed_redirect_domains))

    @get("/logout")
    async def logout(self, request: Request) -> Redirect:
        """Clear session and redirect to home."""
        request.session.clear()
        return Redirect(path="/")

    async def _find_or_create_user(
        self,
        db_session: AsyncSession,
        *,
        provider: str,
        provider_account_id: str,
        email: str | None,
        name: str | None,
        picture_url: str | None,
        metadata: dict,
        overwrite_email: bool = False,
    ) -> User:
        """Resolve (or create) the local user for an external identity and commit.

        1. If an OAuthAccount for (provider, provider_account_id) exists, refresh
           the user's profile and the account's email/metadata. With
           ``overwrite_email`` the user's email is replaced as well.
        2. Otherwise, if a User with ``email`` exists, link a new OAuthAccount to it.
        3. Otherwise, create a new User plus OAuthAccount.
        """
        now = datetime.now(UTC)

        result = await db_session.execute(
            select(OAuthAccount)
            .options(selectinload(OAuthAccount.user))
            .where(
                OAuthAccount.provider == provider,
                OAuthAccount.provider_account_id == provider_account_id,
            )
        )
        oauth_account = result.scalar_one_or_none()

        if oauth_account is not None:
            user = oauth_account.user
            self._refresh_profile(user, name=name, picture_url=picture_url, now=now)
            if overwrite_email:
                user.email = email
            if email:
                oauth_account.provider_email = email
            oauth_account.provider_metadata = metadata
        else:
            user = None
            if email:
                # NOTE(review): linking by email trusts the provider's email claim.
                # Payloads whose email is not verified (e.g. the Microsoft
                # mail/userPrincipalName fields read in extract_user_data) could
                # let someone attach their login to another person's account —
                # confirm whether linking should require a verified email.
                result = await db_session.execute(select(User).where(User.email == email))
                user = result.scalar_one_or_none()

            if user is not None:
                self._refresh_profile(user, name=name, picture_url=picture_url, now=now)
            else:
                user = User(
                    email=email,
                    name=name,
                    picture_url=picture_url,
                    last_login_at=now,
                )
                db_session.add(user)
                # Flush so user.id is available for the OAuthAccount foreign key.
                await db_session.flush()

            db_session.add(
                OAuthAccount(
                    provider=provider,
                    provider_account_id=provider_account_id,
                    provider_email=email,
                    provider_metadata=metadata,
                    user_id=user.id,
                )
            )

        await db_session.commit()
        return user

    def _refresh_profile(
        self, user: User, *, name: str | None, picture_url: str | None, now: datetime
    ) -> None:
        """Update name and last-login time; keep the old picture when none is supplied."""
        user.name = name
        if picture_url:
            user.picture_url = picture_url
        user.last_login_at = now

    def _start_user_session(self, request: Request, user: User) -> None:
        """Store the logged-in user's details in the session and queue a success flash."""
        request.session["user_id"] = str(user.id)
        request.session["user_name"] = user.name
        request.session["user_email"] = user.email
        request.session["user_picture_url"] = user.picture_url
        request.session["flash"] = "Successfully logged in!"
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from litestar import Controller, Request, get
|
|
5
|
+
from litestar.exceptions import NotFoundException
|
|
6
|
+
from litestar.response import Template as TemplateResponse
|
|
7
|
+
from sqlalchemy import select
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
from skrift.db.models.user import User
|
|
11
|
+
from skrift.db.services import page_service
|
|
12
|
+
from skrift.lib.template import Template
|
|
13
|
+
|
|
14
|
+
# Directory passed to Template.render() by WebController.view_page.
# NOTE(review): for skrift/controllers/web.py, three .parent hops reach the
# directory that *contains* the skrift package, so this is "<that dir>/templates"
# rather than the bundled skrift/templates. Confirm this is the intended lookup
# location (e.g. a project-level override directory).
TEMPLATE_DIR = Path(__file__).parent.parent.parent / "templates"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WebController(Controller):
    """Public site routes: the home page and database-backed pages."""

    path = "/"

    async def _get_user_context(
        self, request: Request, db_session: AsyncSession
    ) -> dict:
        """Get user data for template context if logged in.

        Returns ``{"user": <User>}`` when the session's ``user_id`` resolves to
        a user. Otherwise returns ``{"user": None}``: no id in the session, an
        id that is not a valid UUID, or no matching user row.
        """
        raw_user_id = request.session.get("user_id")
        if not raw_user_id:
            return {"user": None}

        try:
            user_uuid = UUID(str(raw_user_id))
        except ValueError:
            # A tampered or corrupted session value must not turn every page view into a 500.
            return {"user": None}

        result = await db_session.execute(select(User).where(User.id == user_uuid))
        return {"user": result.scalar_one_or_none()}

    @get("/")
    async def index(
        self, request: Request, db_session: AsyncSession
    ) -> TemplateResponse:
        """Home page."""
        user_ctx = await self._get_user_context(request, db_session)
        flash = request.session.pop("flash", None)

        return TemplateResponse(
            "index.html",
            context={"flash": flash, **user_ctx},
        )

    @get("/{path:path}")
    async def view_page(
        self, request: Request, db_session: AsyncSession, path: str
    ) -> TemplateResponse:
        """View a page by path with WP-like template resolution.

        ``path`` is split on "/" into slugs, with empty segments dropped. The
        joined slugs are the page lookup key, and the slugs are also passed to
        Template for template selection.

        Raises:
            NotFoundException: If page_service returns no page for the slug.
        """
        user_ctx = await self._get_user_context(request, db_session)
        flash = request.session.pop("flash", None)

        # Split path into slugs (e.g., "services/web" -> ["services", "web"])
        slugs = [s for s in path.split("/") if s]
        page_slug = "/".join(slugs)

        # published_only is lifted only when the session resolves to an existing
        # user (a stale user_id no longer counts).
        # NOTE(review): any user account qualifies, including ones created
        # automatically on first OAuth login — consider requiring an admin
        # role/permission here instead.
        page = await page_service.get_page_by_slug(
            db_session, page_slug, published_only=user_ctx["user"] is None
        )
        if not page:
            raise NotFoundException(f"Page '{path}' not found")

        template = Template("page", *slugs, context={"path": path, "slugs": slugs, "page": page})
        return template.render(TEMPLATE_DIR, flash=flash, **user_ctx)
|
skrift/db/__init__.py
ADDED
skrift/db/base.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
from skrift.db.models.oauth_account import OAuthAccount
|
|
2
|
+
from skrift.db.models.page import Page
|
|
3
|
+
from skrift.db.models.role import Role, RolePermission, user_roles
|
|
4
|
+
from skrift.db.models.setting import Setting
|
|
5
|
+
from skrift.db.models.user import User
|
|
6
|
+
|
|
7
|
+
# Explicit public API: the model names imported above, re-exported for `from ... import *`.
__all__ = ["OAuthAccount", "Page", "Role", "RolePermission", "Setting", "User", "user_roles"]
|