skrift 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skrift/__init__.py +1 -0
- skrift/__main__.py +17 -0
- skrift/admin/__init__.py +11 -0
- skrift/admin/controller.py +452 -0
- skrift/admin/navigation.py +105 -0
- skrift/alembic/env.py +91 -0
- skrift/alembic/script.py.mako +26 -0
- skrift/alembic/versions/20260120_210154_09b0364dbb7b_initial_schema.py +70 -0
- skrift/alembic/versions/20260122_152744_0b7c927d2591_add_roles_and_permissions.py +57 -0
- skrift/alembic/versions/20260122_172836_cdf734a5b847_add_sa_orm_sentinel_column.py +31 -0
- skrift/alembic/versions/20260122_175637_a9c55348eae7_remove_page_type_column.py +43 -0
- skrift/alembic/versions/20260122_200000_add_settings_table.py +38 -0
- skrift/alembic.ini +77 -0
- skrift/asgi.py +545 -0
- skrift/auth/__init__.py +58 -0
- skrift/auth/guards.py +130 -0
- skrift/auth/roles.py +94 -0
- skrift/auth/services.py +184 -0
- skrift/cli.py +45 -0
- skrift/config.py +192 -0
- skrift/controllers/__init__.py +4 -0
- skrift/controllers/auth.py +371 -0
- skrift/controllers/web.py +67 -0
- skrift/db/__init__.py +3 -0
- skrift/db/base.py +7 -0
- skrift/db/models/__init__.py +6 -0
- skrift/db/models/page.py +26 -0
- skrift/db/models/role.py +56 -0
- skrift/db/models/setting.py +13 -0
- skrift/db/models/user.py +36 -0
- skrift/db/services/__init__.py +1 -0
- skrift/db/services/page_service.py +217 -0
- skrift/db/services/setting_service.py +206 -0
- skrift/lib/__init__.py +3 -0
- skrift/lib/exceptions.py +168 -0
- skrift/lib/template.py +108 -0
- skrift/setup/__init__.py +14 -0
- skrift/setup/config_writer.py +211 -0
- skrift/setup/controller.py +751 -0
- skrift/setup/middleware.py +89 -0
- skrift/setup/providers.py +163 -0
- skrift/setup/state.py +134 -0
- skrift/static/css/style.css +998 -0
- skrift/templates/admin/admin.html +19 -0
- skrift/templates/admin/base.html +24 -0
- skrift/templates/admin/pages/edit.html +32 -0
- skrift/templates/admin/pages/list.html +62 -0
- skrift/templates/admin/settings/site.html +32 -0
- skrift/templates/admin/users/list.html +58 -0
- skrift/templates/admin/users/roles.html +42 -0
- skrift/templates/auth/login.html +125 -0
- skrift/templates/base.html +52 -0
- skrift/templates/error-404.html +19 -0
- skrift/templates/error-500.html +19 -0
- skrift/templates/error.html +19 -0
- skrift/templates/index.html +9 -0
- skrift/templates/page.html +26 -0
- skrift/templates/setup/admin.html +24 -0
- skrift/templates/setup/auth.html +110 -0
- skrift/templates/setup/base.html +407 -0
- skrift/templates/setup/complete.html +17 -0
- skrift/templates/setup/database.html +125 -0
- skrift/templates/setup/restart.html +28 -0
- skrift/templates/setup/site.html +39 -0
- skrift-0.1.0a1.dist-info/METADATA +233 -0
- skrift-0.1.0a1.dist-info/RECORD +68 -0
- skrift-0.1.0a1.dist-info/WHEEL +4 -0
- skrift-0.1.0a1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,751 @@
|
|
|
1
|
+
"""Setup wizard controller for first-time Skrift configuration."""
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import hashlib
|
|
5
|
+
import secrets
|
|
6
|
+
import subprocess
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from datetime import UTC, datetime
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from urllib.parse import urlencode
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
from typing import Annotated
|
|
14
|
+
|
|
15
|
+
from litestar import Controller, Request, get, post
|
|
16
|
+
from litestar.exceptions import HTTPException
|
|
17
|
+
from litestar.params import Parameter
|
|
18
|
+
from litestar.response import Redirect, Template as TemplateResponse
|
|
19
|
+
from sqlalchemy import func, select
|
|
20
|
+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
21
|
+
|
|
22
|
+
from skrift.db.models.role import Role, user_roles
|
|
23
|
+
from skrift.db.models.user import User
|
|
24
|
+
from skrift.db.services import setting_service
|
|
25
|
+
from skrift.db.services.setting_service import (
|
|
26
|
+
SETUP_COMPLETED_AT_KEY,
|
|
27
|
+
get_setting,
|
|
28
|
+
)
|
|
29
|
+
from skrift.setup.config_writer import (
|
|
30
|
+
load_config,
|
|
31
|
+
update_auth_config,
|
|
32
|
+
update_database_config,
|
|
33
|
+
)
|
|
34
|
+
from skrift.setup.providers import get_all_providers, get_provider_info
|
|
35
|
+
from skrift.setup.state import can_connect_to_database, app_yaml_exists, get_database_url_from_yaml
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@asynccontextmanager
async def get_setup_db_session():
    """Yield a short-lived database session for use during the setup wizard.

    The SQLAlchemy plugin is not available while the wizard runs, so this builds
    a throwaway async engine from the database URL stored in app.yaml.

    The session is committed when the ``async with`` body finishes normally and
    rolled back (then the exception re-raised) when the body raises. The engine
    is disposed only after the session has been closed.

    Raises:
        RuntimeError: If app.yaml does not provide a database URL.
    """
    db_url = get_database_url_from_yaml()
    if not db_url:
        raise RuntimeError("Database not configured")

    engine = create_async_engine(db_url)
    try:
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise
    finally:
        # Dispose outside the session context so the session has released its
        # connection before the pool is torn down; this also runs if opening
        # the session itself fails.
        await engine.dispose()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class SetupController(Controller):
    """Controller for the first-time setup wizard mounted at ``/setup``.

    The wizard has four steps: database, authentication providers, site
    settings and admin account creation. Progress is remembered in the
    ``setup_wizard_step`` session key. Messages for the next page are passed
    through the ``flash`` and ``setup_error`` session keys.
    """

    path = "/setup"

    async def _check_already_complete(self) -> bool:
        """Defense in depth: return True if the setup-completed setting exists.

        Any error (e.g. database not configured or unreachable) is deliberately
        treated as "not complete".
        """
        try:
            async with get_setup_db_session() as db_session:
                value = await get_setting(db_session, SETUP_COMPLETED_AT_KEY)
                return value is not None
        except Exception:
            return False

    @staticmethod
    def _request_base_url(request: Request) -> str:
        """Return ``scheme://host`` for the request as seen by the client.

        Honours ``X-Forwarded-Proto``, taking the first value when a proxy chain
        supplies several. Uses the ``Host`` header and falls back to the request
        URL for whichever header is missing or empty.
        """
        forwarded_proto = request.headers.get("x-forwarded-proto", "")
        scheme = forwarded_proto.split(",")[0].strip() or request.url.scheme
        host = request.headers.get("host") or request.url.netloc
        return f"{scheme}://{host}"

    def _resolve_env_var(self, value: str) -> str:
        """Resolve ``$NAME`` to the value of environment variable NAME.

        Any other string is returned unchanged. An unset variable resolves to
        ``""``. A non-string config value is converted with ``str()``, and
        ``None`` becomes ``""``.
        """
        import os

        if not isinstance(value, str):
            return "" if value is None else str(value)
        if value.startswith("$"):
            return os.environ.get(value[1:], "")
        return value

    def _resolve_tenant(self, provider_config: dict) -> str:
        """Return the tenant to substitute into ``{tenant}`` provider URLs.

        Missing, empty or unresolvable values fall back to ``"common"``.
        """
        tenant = provider_config.get("tenant_id") or "common"
        if isinstance(tenant, str) and tenant.startswith("$"):
            tenant = self._resolve_env_var(tenant) or "common"
        return str(tenant)

    @get("/")
    async def index(self, request: Request) -> Redirect:
        """Redirect to the wizard step the visitor should see next.

        If app.yaml does not exist or the database cannot be reached, the
        wizard restarts at the database step. Otherwise it resumes at the step
        recorded in the session.
        """
        if not app_yaml_exists():
            return Redirect(path="/setup/database")

        can_connect, _ = await can_connect_to_database()
        if not can_connect:
            return Redirect(path="/setup/database")

        wizard_step = request.session.get("setup_wizard_step", "database")
        # Only redirect to known steps so an unexpected session value cannot
        # send the visitor to an arbitrary path under /setup.
        if wizard_step not in ("database", "auth", "site", "admin"):
            wizard_step = "database"
        return Redirect(path=f"/setup/{wizard_step}")

    @get("/database")
    async def database_step(self, request: Request) -> TemplateResponse:
        """Step 1: render the database form, pre-filled from app.yaml if present."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        # `or` guards against sections/keys that are present but null in YAML
        db_config = load_config().get("db") or {}
        current_url = db_config.get("url") or ""
        db_type = "postgresql" if "postgresql" in current_url else "sqlite"

        return TemplateResponse(
            "setup/database.html",
            context={
                "flash": flash,
                "error": error,
                "step": 1,
                "total_steps": 4,
                "db_type": db_type,
                "current_url": current_url,
            },
        )

    def _write_database_config(self, form_data) -> None:
        """Persist the database settings submitted on step 1 to app.yaml.

        Any ``db_type`` other than ``"sqlite"`` is treated as PostgreSQL.
        """
        if form_data.get("db_type", "sqlite") == "sqlite":
            update_database_config(
                db_type="sqlite",
                # An emptied path field falls back to the default file
                url=form_data.get("sqlite_path") or "./app.db",
                use_env_vars={"url": form_data.get("sqlite_path_env") == "on"},
            )
        elif form_data.get("pg_url_env") == "on":
            # PostgreSQL URL taken from the named environment variable
            update_database_config(
                db_type="postgresql",
                url=form_data.get("pg_url_envvar") or "DATABASE_URL",
                use_env_vars={"url": True},
            )
        else:
            # PostgreSQL configured from individual connection fields
            update_database_config(
                db_type="postgresql",
                host=form_data.get("pg_host", "localhost"),
                # An empty port field would make int("") fail; use the default
                port=int(form_data.get("pg_port") or 5432),
                database=form_data.get("pg_database", "skrift"),
                username=form_data.get("pg_username", "postgres"),
                password=form_data.get("pg_password", ""),
            )

    async def _run_migrations(self) -> str | None:
        """Upgrade the database schema to the latest revision.

        Runs ``skrift-db upgrade head``. If that executable cannot be found, it
        falls back to ``alembic upgrade head``. Each command runs in a worker
        thread so its up-to-60-second run does not block the event loop.

        Returns:
            None on success, otherwise a user-facing error message.
        """
        import asyncio

        cwd = Path.cwd()

        def run(command: list[str]) -> subprocess.CompletedProcess[str]:
            return subprocess.run(
                command,
                capture_output=True,
                text=True,
                cwd=cwd,
                timeout=60,
            )

        try:
            result = await asyncio.to_thread(run, ["skrift-db", "upgrade", "head"])
        except subprocess.TimeoutExpired:
            return "Migration timed out"
        except FileNotFoundError:
            # skrift-db might not be installed yet, try alembic directly
            try:
                result = await asyncio.to_thread(run, ["alembic", "upgrade", "head"])
            except Exception as e:
                return f"Could not run migrations: {e}"

        if result.returncode != 0:
            return f"Migration failed: {result.stderr}"
        return None

    @post("/database")
    async def save_database(self, request: Request) -> Redirect:
        """Save the database configuration, verify the connection and migrate.

        On any failure the message is stored in ``setup_error`` and the visitor
        is sent back to the database step. On success the wizard advances to
        the auth step.
        """
        form_data = await request.form()

        try:
            self._write_database_config(form_data)

            can_connect, error = await can_connect_to_database()
            if not can_connect:
                request.session["setup_error"] = f"Connection failed: {error}"
                return Redirect(path="/setup/database")

            migration_error = await self._run_migrations()
            if migration_error is not None:
                request.session["setup_error"] = migration_error
                return Redirect(path="/setup/database")
        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/database")

        request.session["setup_wizard_step"] = "auth"
        request.session["flash"] = "Database configured successfully!"
        return Redirect(path="/setup/auth")

    @get("/restart")
    async def restart_step(self, request: Request) -> Redirect:
        """Legacy restart route - now redirects to auth since restart is no longer required."""
        request.session["setup_wizard_step"] = "auth"
        return Redirect(path="/setup/auth")

    @get("/auth")
    async def auth_step(self, request: Request) -> TemplateResponse:
        """Step 2: render the authentication-provider form."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        auth_config = load_config().get("auth") or {}
        configured_providers = auth_config.get("providers") or {}

        return TemplateResponse(
            "setup/auth.html",
            context={
                "flash": flash,
                "error": error,
                "step": 2,
                "total_steps": 4,
                "redirect_base_url": self._request_base_url(request),
                "providers": get_all_providers(),
                "configured_providers": configured_providers,
            },
        )

    @post("/auth")
    async def save_auth(self, request: Request) -> Redirect:
        """Save the enabled OAuth providers and their credentials to app.yaml.

        For each provider whose ``<key>_enabled`` checkbox is on, every field is
        read from ``<key>_<field>``, and its environment-variable flag from
        ``<key>_<field>_env``. Optional fields left blank are omitted. At least
        one provider must end up configured.
        """
        form_data = await request.form()

        # An emptied field falls back to the same default as a missing one
        redirect_base_url = form_data.get("redirect_base_url") or "http://localhost:8000"

        providers = {}
        use_env_vars = {}
        for provider_key, provider_info in get_all_providers().items():
            if form_data.get(f"{provider_key}_enabled") != "on":
                continue

            provider_config = {}
            provider_env_vars = {}
            for field in provider_info.fields:
                field_key = field["key"]
                value = form_data.get(f"{provider_key}_{field_key}", "")
                if value or not field.get("optional"):
                    provider_config[field_key] = value
                    provider_env_vars[field_key] = (
                        form_data.get(f"{provider_key}_{field_key}_env") == "on"
                    )

            if provider_config:
                providers[provider_key] = provider_config
                use_env_vars[provider_key] = provider_env_vars

        if not providers:
            request.session["setup_error"] = "Please configure at least one authentication provider"
            return Redirect(path="/setup/auth")

        try:
            update_auth_config(
                redirect_base_url=redirect_base_url,
                providers=providers,
                use_env_vars=use_env_vars,
            )
        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/auth")

        request.session["setup_wizard_step"] = "site"
        request.session["flash"] = "Authentication configured successfully!"
        return Redirect(path="/setup/site")

    @get("/site")
    async def site_step(self, request: Request) -> TemplateResponse:
        """Step 3: render the site-settings form with blank defaults."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        return TemplateResponse(
            "setup/site.html",
            context={
                "flash": flash,
                "error": error,
                "step": 3,
                "total_steps": 4,
                "settings": {
                    "site_name": "",
                    "site_tagline": "",
                    "site_copyright_holder": "",
                    "site_copyright_start_year": datetime.now().year,
                },
            },
        )

    @post("/site")
    async def save_site(self, request: Request) -> Redirect:
        """Save the site settings to the database and reload the settings cache.

        The site name is required. Errors are reported through ``setup_error``
        on the site step.
        """
        form_data = await request.form()

        try:
            site_name = form_data.get("site_name", "").strip()
            if not site_name:
                request.session["setup_error"] = "Site name is required"
                return Redirect(path="/setup/site")

            settings_to_save = (
                (setting_service.SITE_NAME_KEY, site_name),
                (setting_service.SITE_TAGLINE_KEY, form_data.get("site_tagline", "").strip()),
                (
                    setting_service.SITE_COPYRIGHT_HOLDER_KEY,
                    form_data.get("site_copyright_holder", "").strip(),
                ),
                (
                    setting_service.SITE_COPYRIGHT_START_YEAR_KEY,
                    form_data.get("site_copyright_start_year", "").strip(),
                ),
            )

            # The SQLAlchemy plugin is unavailable during setup, so use a manual session
            async with get_setup_db_session() as db_session:
                for key, value in settings_to_save:
                    await setting_service.set_setting(db_session, key, value)
                await setting_service.load_site_settings_cache(db_session)
        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/site")

        request.session["setup_wizard_step"] = "admin"
        request.session["flash"] = "Site settings saved!"
        return Redirect(path="/setup/admin")

    @get("/admin")
    async def admin_step(self, request: Request) -> TemplateResponse:
        """Step 4: render the admin-account page listing the configured providers."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        auth_config = load_config().get("auth") or {}
        configured_providers = list((auth_config.get("providers") or {}).keys())

        # Display info only for configured providers that are also known providers
        all_providers = get_all_providers()
        provider_info = {
            key: all_providers[key] for key in configured_providers if key in all_providers
        }

        return TemplateResponse(
            "setup/admin.html",
            context={
                "flash": flash,
                "error": error,
                "step": 4,
                "total_steps": 4,
                "providers": provider_info,
                "configured_providers": configured_providers,
            },
        )

    @get("/oauth/{provider:str}/login")
    async def setup_oauth_login(self, request: Request, provider: str) -> Redirect:
        """Start the OAuth flow used to create the first admin account.

        Stores a CSRF ``state`` token, the provider name and the ``oauth_setup``
        marker in the session, then redirects to the provider's authorization
        URL. The redirect URI points at the standard
        ``/auth/{provider}/callback`` route, which matches the URL configured
        in the OAuth provider console.

        Raises:
            HTTPException: 404 if the provider is not configured in app.yaml or
                is not a known provider.
        """
        auth_config = load_config().get("auth") or {}
        providers_config = auth_config.get("providers") or {}

        if provider not in providers_config:
            raise HTTPException(status_code=404, detail=f"Provider {provider} not configured")

        provider_info = get_provider_info(provider)
        if not provider_info:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")

        provider_config = providers_config[provider]

        # Generate CSRF state token
        state = secrets.token_urlsafe(32)
        request.session["oauth_state"] = state
        request.session["oauth_provider"] = provider
        request.session["oauth_setup"] = True

        scopes = provider_config.get("scopes") or provider_info.scopes
        params = {
            "client_id": self._resolve_env_var(provider_config.get("client_id", "")),
            "redirect_uri": f"{self._request_base_url(request)}/auth/{provider}/callback",
            "response_type": "code",
            # Accept scopes configured either as a list or as one space-separated string
            "scope": scopes if isinstance(scopes, str) else " ".join(scopes),
            "state": state,
        }

        if provider == "google":
            params["access_type"] = "offline"
            params["prompt"] = "select_account"
        elif provider == "twitter":
            # PKCE (S256): the verifier stays in the session for the callback
            code_verifier = secrets.token_urlsafe(64)[:128]
            request.session["oauth_code_verifier"] = code_verifier
            digest = hashlib.sha256(code_verifier.encode()).digest()
            params["code_challenge"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
            params["code_challenge_method"] = "S256"
        elif provider == "discord":
            params["prompt"] = "consent"

        auth_url = provider_info.auth_url
        if "{tenant}" in auth_url:
            auth_url = auth_url.replace("{tenant}", self._resolve_tenant(provider_config))

        return Redirect(path=f"{auth_url}?{urlencode(params)}")

    @get("/complete")
    async def complete(self, request: Request) -> TemplateResponse | Redirect:
        """Render the setup-complete page, or go back to /setup if setup is not done."""
        if not await self._check_already_complete():
            return Redirect(path="/setup")

        # Clear the one-shot flag set by the OAuth callback, if present
        request.session.pop("setup_just_completed", None)

        return TemplateResponse(
            "setup/complete.html",
            context={
                "step": 4,
                "total_steps": 4,
            },
        )
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
async def mark_setup_complete(db_session: AsyncSession | None = None) -> None:
    """Record the current UTC time (ISO 8601) as the setup-completed setting.

    Args:
        db_session: Session to write through. It is not committed here. When
            omitted, a temporary setup session is opened and committed on exit.
    """
    timestamp = datetime.now(UTC).isoformat()
    # Explicit None check: the decision must not depend on the session's truthiness
    if db_session is not None:
        await setting_service.set_setting(db_session, SETUP_COMPLETED_AT_KEY, timestamp)
        return

    async with get_setup_db_session() as session:
        await setting_service.set_setting(session, SETUP_COMPLETED_AT_KEY, timestamp)
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
class SetupAuthController(Controller):
    """Auth controller for setup OAuth callbacks.

    This handles OAuth callbacks at /auth/{provider}/callback during setup,
    matching the redirect URI configured in OAuth providers.
    """

    path = "/auth"

    @get("/{provider:str}/callback")
    async def setup_oauth_callback(
        self,
        request: Request,
        provider: str,
        code: str | None = None,
        oauth_state: Annotated[str | None, Parameter(query="state")] = None,
        error: str | None = None,
    ) -> Redirect:
        """Finish the setup OAuth flow and make the signed-in user an admin.

        The request must belong to a flow started by the setup login route. That
        requires the session ``oauth_setup`` flag, a matching state token and the
        same provider. The handler then:

        - exchanges the code for an access token;
        - fetches the profile;
        - creates or updates the ``User`` and grants it the ``admin`` role;
        - records the setup-completed timestamp;
        - logs the user in once the transaction has been committed.

        Raises:
            HTTPException: 400 for an invalid flow, state, provider, missing
                code or failed provider request; 404 if the provider is not
                configured or unknown.
        """
        if not request.session.get("oauth_setup"):
            # Not a setup flow, return error
            raise HTTPException(status_code=400, detail="Invalid OAuth flow")

        if error:
            request.session["setup_error"] = f"OAuth error: {error}"
            return Redirect(path="/setup/admin")

        # The state token is single-use: remove it before comparing
        stored_state = request.session.pop("oauth_state", None)
        if not self._state_matches(oauth_state, stored_state):
            raise HTTPException(status_code=400, detail="Invalid OAuth state")

        # The state was issued for one provider; reject callbacks for another
        if request.session.get("oauth_provider") != provider:
            raise HTTPException(status_code=400, detail="OAuth provider mismatch")

        if not code:
            raise HTTPException(status_code=400, detail="Missing authorization code")

        auth_config = load_config().get("auth") or {}
        providers_config = auth_config.get("providers") or {}
        if provider not in providers_config:
            raise HTTPException(status_code=404, detail=f"Provider {provider} not configured")

        provider_info = get_provider_info(provider)
        if not provider_info:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")
        provider_config = providers_config[provider]

        # PKCE verifier stored by the login route (only set for Twitter)
        code_verifier = request.session.pop("oauth_code_verifier", None)
        redirect_uri = f"{self._request_base_url(request)}/auth/{provider}/callback"

        try:
            async with httpx.AsyncClient() as client:
                access_token = await self._exchange_code(
                    client,
                    provider,
                    provider_info,
                    provider_config,
                    code,
                    redirect_uri,
                    code_verifier,
                )
                user_info = await self._fetch_user_info(
                    client, provider, provider_info, access_token
                )
        except httpx.HTTPError as exc:
            # Transport-level failures (timeouts, DNS, connection errors)
            raise HTTPException(status_code=400, detail=f"OAuth request failed: {exc}") from exc

        user_data = self._extract_user_data(provider, user_info)
        oauth_id = user_data["oauth_id"]
        if not oauth_id:
            raise HTTPException(status_code=400, detail="Could not determine user ID")

        # Create/update the admin user and mark setup complete in one transaction
        async with get_setup_db_session() as db_session:
            user = await self._upsert_admin_user(db_session, provider, oauth_id, user_data)
            timestamp = datetime.now(UTC).isoformat()
            await setting_service.set_setting(db_session, SETUP_COMPLETED_AT_KEY, timestamp)

        # Only log the user in after the commit above has succeeded. The session
        # factory uses expire_on_commit=False, so the attributes stay readable.
        request.session.pop("oauth_setup", None)
        request.session["user_id"] = str(user.id)
        request.session["user_name"] = user.name
        request.session["user_email"] = user.email
        request.session["user_picture_url"] = user.picture_url
        request.session["flash"] = "Admin account created successfully!"
        request.session["setup_just_completed"] = True

        # Note: Don't call mark_setup_complete_in_dispatcher() here.
        # The switch happens in /setup/complete after rendering the page.

        return Redirect(path="/setup/complete")

    @staticmethod
    def _state_matches(received: str | None, expected: str | None) -> bool:
        """Compare the returned OAuth state with the stored one in constant time.

        Both values are compared as UTF-8 bytes, so non-ASCII input cannot make
        ``compare_digest`` raise.
        """
        if not received or not expected:
            return False
        return secrets.compare_digest(received.encode(), str(expected).encode())

    @staticmethod
    def _request_base_url(request: Request) -> str:
        """Return ``scheme://host`` for the request as seen by the client.

        Honours ``X-Forwarded-Proto``, taking the first value when a proxy chain
        supplies several. Uses the ``Host`` header and falls back to the request
        URL for whichever header is missing or empty.
        """
        forwarded_proto = request.headers.get("x-forwarded-proto", "")
        scheme = forwarded_proto.split(",")[0].strip() or request.url.scheme
        host = request.headers.get("host") or request.url.netloc
        return f"{scheme}://{host}"

    def _resolve_env_var(self, value: str) -> str:
        """Resolve ``$NAME`` to the value of environment variable NAME.

        Any other string is returned unchanged. An unset variable resolves to
        ``""``. A non-string config value is converted with ``str()``, and
        ``None`` becomes ``""``.
        """
        import os

        if not isinstance(value, str):
            return "" if value is None else str(value)
        if value.startswith("$"):
            return os.environ.get(value[1:], "")
        return value

    def _resolve_tenant(self, provider_config: dict) -> str:
        """Return the tenant to substitute into ``{tenant}`` provider URLs.

        Missing, empty or unresolvable values fall back to ``"common"``.
        """
        tenant = provider_config.get("tenant_id") or "common"
        if isinstance(tenant, str) and tenant.startswith("$"):
            tenant = self._resolve_env_var(tenant) or "common"
        return str(tenant)

    async def _exchange_code(
        self,
        client: httpx.AsyncClient,
        provider: str,
        provider_info,
        provider_config: dict,
        code: str,
        redirect_uri: str,
        code_verifier: str | None,
    ) -> str:
        """Exchange an authorization code for an access token.

        Twitter receives the PKCE verifier and HTTP Basic client credentials
        instead of ``client_secret`` in the form body.

        Raises:
            HTTPException: 400 if the token endpoint does not answer 200 or the
                response lacks an ``access_token``.
        """
        client_id = self._resolve_env_var(provider_config.get("client_id", ""))
        client_secret = self._resolve_env_var(provider_config.get("client_secret", ""))

        token_url = provider_info.token_url
        if "{tenant}" in token_url:
            token_url = token_url.replace("{tenant}", self._resolve_tenant(provider_config))

        data = {
            "client_id": client_id,
            "client_secret": client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
        }
        headers = {"Accept": "application/json"}

        if provider == "twitter":
            if code_verifier:
                data["code_verifier"] = code_verifier
            credentials = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
            headers["Authorization"] = f"Basic {credentials}"
            del data["client_secret"]

        response = await client.post(token_url, data=data, headers=headers)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail=f"Token exchange failed: {response.text}")

        access_token = response.json().get("access_token")
        if not access_token:
            raise HTTPException(status_code=400, detail="No access token received")
        return access_token

    async def _fetch_user_info(
        self,
        client: httpx.AsyncClient,
        provider: str,
        provider_info,
        access_token: str,
    ) -> dict:
        """Fetch the raw user-info payload with the access token.

        When a GitHub payload has no email, the primary address from GitHub's
        ``/user/emails`` endpoint is filled in if that call succeeds.

        Raises:
            HTTPException: 400 if the user-info endpoint does not answer 200.
        """
        user_headers = {"Authorization": f"Bearer {access_token}"}
        response = await client.get(provider_info.userinfo_url, headers=user_headers)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Failed to fetch user info")
        user_info = response.json()

        if provider == "github" and not user_info.get("email"):
            email_response = await client.get(
                "https://api.github.com/user/emails", headers=user_headers
            )
            if email_response.status_code == 200:
                primary_email = next(
                    (e.get("email") for e in email_response.json() if e.get("primary")),
                    None,
                )
                if primary_email:
                    user_info["email"] = primary_email

        return user_info

    async def _upsert_admin_user(
        self,
        db_session: AsyncSession,
        provider: str,
        oauth_id: str,
        user_data: dict,
    ) -> User:
        """Create or update the user for this OAuth identity and grant it admin.

        An existing user gets its name, last-login time and (if provided)
        picture refreshed. Roles are synced first, and the ``admin`` role link
        is inserted only if it is not already present.
        """
        result = await db_session.execute(
            select(User).where(User.oauth_id == oauth_id, User.oauth_provider == provider)
        )
        user = result.scalar_one_or_none()
        now = datetime.now(UTC)

        if user:
            user.name = user_data["name"]
            if user_data["picture_url"]:
                user.picture_url = user_data["picture_url"]
            user.last_login_at = now
        else:
            user = User(
                oauth_provider=provider,
                oauth_id=oauth_id,
                email=user_data["email"],
                name=user_data["name"],
                picture_url=user_data["picture_url"],
                last_login_at=now,
            )
            db_session.add(user)
            # Flush so the new user has an id for the role link below
            await db_session.flush()

        # Ensure roles are synced (they may not exist if DB was created after server start)
        from skrift.auth import sync_roles_to_database

        await sync_roles_to_database(db_session)

        # Always assign admin role during setup (whether user is new or existing)
        admin_role = await db_session.scalar(select(Role).where(Role.name == "admin"))
        if admin_role:
            existing = await db_session.execute(
                select(user_roles).where(
                    user_roles.c.user_id == user.id,
                    user_roles.c.role_id == admin_role.id,
                )
            )
            if not existing.first():
                await db_session.execute(
                    user_roles.insert().values(user_id=user.id, role_id=admin_role.id)
                )

        return user

    def _extract_user_data(self, provider: str, user_info: dict) -> dict:
        """Normalize a provider's user-info payload into a common shape.

        Returns:
            A dict with keys ``oauth_id``, ``email``, ``name`` and
            ``picture_url``. ``oauth_id`` is a non-empty string, or None when
            the payload carries no usable id.
        """
        if provider == "google":
            data = {
                "oauth_id": user_info.get("id") or user_info.get("sub"),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": user_info.get("picture"),
            }
        elif provider == "github":
            data = {
                "oauth_id": user_info.get("id"),
                "email": user_info.get("email"),
                "name": user_info.get("name") or user_info.get("login"),
                "picture_url": user_info.get("avatar_url"),
            }
        elif provider == "microsoft":
            data = {
                "oauth_id": user_info.get("id"),
                "email": user_info.get("mail") or user_info.get("userPrincipalName"),
                "name": user_info.get("displayName"),
                "picture_url": None,
            }
        elif provider == "discord":
            avatar = user_info.get("avatar")
            user_id = user_info.get("id")
            avatar_url = (
                f"https://cdn.discordapp.com/avatars/{user_id}/{avatar}.png"
                if avatar and user_id
                else None
            )
            data = {
                "oauth_id": user_id,
                "email": user_info.get("email"),
                "name": user_info.get("global_name") or user_info.get("username"),
                "picture_url": avatar_url,
            }
        elif provider == "facebook":
            # Tolerate a null or missing picture object
            picture = (user_info.get("picture") or {}).get("data") or {}
            data = {
                "oauth_id": user_info.get("id"),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": picture.get("url") if not picture.get("is_silhouette") else None,
            }
        elif provider == "twitter":
            payload = user_info.get("data", user_info)
            data = {
                "oauth_id": payload.get("id"),
                "email": payload.get("email"),
                "name": payload.get("name") or payload.get("username"),
                "picture_url": None,
            }
        else:
            data = {
                "oauth_id": user_info.get("id", user_info.get("sub")),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": user_info.get("picture"),
            }

        # Coerce ids to strings, but keep a missing id as None. str(None) would
        # give the truthy "None" and defeat the caller's "no user ID" check.
        oauth_id = data["oauth_id"]
        data["oauth_id"] = None if oauth_id in (None, "") else str(oauth_id)
        return data
|