skrift 0.1.0a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. skrift/__init__.py +1 -0
  2. skrift/__main__.py +12 -0
  3. skrift/admin/__init__.py +11 -0
  4. skrift/admin/controller.py +452 -0
  5. skrift/admin/navigation.py +105 -0
  6. skrift/alembic/env.py +92 -0
  7. skrift/alembic/script.py.mako +26 -0
  8. skrift/alembic/versions/20260120_210154_09b0364dbb7b_initial_schema.py +70 -0
  9. skrift/alembic/versions/20260122_152744_0b7c927d2591_add_roles_and_permissions.py +57 -0
  10. skrift/alembic/versions/20260122_172836_cdf734a5b847_add_sa_orm_sentinel_column.py +31 -0
  11. skrift/alembic/versions/20260122_175637_a9c55348eae7_remove_page_type_column.py +43 -0
  12. skrift/alembic/versions/20260122_200000_add_settings_table.py +38 -0
  13. skrift/alembic/versions/20260129_add_oauth_accounts.py +141 -0
  14. skrift/alembic/versions/20260129_add_provider_metadata.py +29 -0
  15. skrift/alembic.ini +77 -0
  16. skrift/asgi.py +670 -0
  17. skrift/auth/__init__.py +58 -0
  18. skrift/auth/guards.py +130 -0
  19. skrift/auth/roles.py +129 -0
  20. skrift/auth/services.py +184 -0
  21. skrift/cli.py +143 -0
  22. skrift/config.py +259 -0
  23. skrift/controllers/__init__.py +4 -0
  24. skrift/controllers/auth.py +595 -0
  25. skrift/controllers/web.py +67 -0
  26. skrift/db/__init__.py +3 -0
  27. skrift/db/base.py +7 -0
  28. skrift/db/models/__init__.py +7 -0
  29. skrift/db/models/oauth_account.py +50 -0
  30. skrift/db/models/page.py +26 -0
  31. skrift/db/models/role.py +56 -0
  32. skrift/db/models/setting.py +13 -0
  33. skrift/db/models/user.py +36 -0
  34. skrift/db/services/__init__.py +1 -0
  35. skrift/db/services/oauth_service.py +195 -0
  36. skrift/db/services/page_service.py +217 -0
  37. skrift/db/services/setting_service.py +206 -0
  38. skrift/lib/__init__.py +3 -0
  39. skrift/lib/exceptions.py +168 -0
  40. skrift/lib/template.py +108 -0
  41. skrift/setup/__init__.py +14 -0
  42. skrift/setup/config_writer.py +213 -0
  43. skrift/setup/controller.py +888 -0
  44. skrift/setup/middleware.py +89 -0
  45. skrift/setup/providers.py +214 -0
  46. skrift/setup/state.py +315 -0
  47. skrift/static/css/style.css +1003 -0
  48. skrift/templates/admin/admin.html +19 -0
  49. skrift/templates/admin/base.html +24 -0
  50. skrift/templates/admin/pages/edit.html +32 -0
  51. skrift/templates/admin/pages/list.html +62 -0
  52. skrift/templates/admin/settings/site.html +32 -0
  53. skrift/templates/admin/users/list.html +58 -0
  54. skrift/templates/admin/users/roles.html +42 -0
  55. skrift/templates/auth/dummy_login.html +102 -0
  56. skrift/templates/auth/login.html +139 -0
  57. skrift/templates/base.html +52 -0
  58. skrift/templates/error-404.html +19 -0
  59. skrift/templates/error-500.html +19 -0
  60. skrift/templates/error.html +19 -0
  61. skrift/templates/index.html +9 -0
  62. skrift/templates/page.html +26 -0
  63. skrift/templates/setup/admin.html +24 -0
  64. skrift/templates/setup/auth.html +110 -0
  65. skrift/templates/setup/base.html +407 -0
  66. skrift/templates/setup/complete.html +17 -0
  67. skrift/templates/setup/configuring.html +158 -0
  68. skrift/templates/setup/database.html +125 -0
  69. skrift/templates/setup/restart.html +28 -0
  70. skrift/templates/setup/site.html +39 -0
  71. skrift-0.1.0a12.dist-info/METADATA +235 -0
  72. skrift-0.1.0a12.dist-info/RECORD +74 -0
  73. skrift-0.1.0a12.dist-info/WHEEL +4 -0
  74. skrift-0.1.0a12.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,888 @@
1
+ """Setup wizard controller for first-time Skrift configuration."""
2
+
3
+ import asyncio
4
+ import base64
5
+ import hashlib
6
+ import json
7
+ import secrets
8
+ from collections.abc import AsyncGenerator
9
+ from contextlib import asynccontextmanager
10
+ from datetime import UTC, datetime
11
+ from urllib.parse import urlencode
12
+
13
+ import httpx
14
+ from typing import Annotated
15
+
16
+ from litestar import Controller, Request, get, post
17
+ from litestar.exceptions import HTTPException
18
+ from litestar.params import Parameter
19
+ from litestar.response import Redirect, Template as TemplateResponse
20
+ from litestar.response.sse import ServerSentEvent
21
+ from sqlalchemy import select
22
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
23
+ from sqlalchemy.orm import selectinload
24
+
25
+ from skrift.db.models.oauth_account import OAuthAccount
26
+ from skrift.db.models.role import Role, user_roles
27
+ from skrift.db.models.user import User
28
+ from skrift.db.services import setting_service
29
+ from skrift.db.services.setting_service import (
30
+ SETUP_COMPLETED_AT_KEY,
31
+ get_setting,
32
+ )
33
+ from skrift.setup.config_writer import (
34
+ load_config,
35
+ update_auth_config,
36
+ update_database_config,
37
+ )
38
+ from skrift.setup.providers import get_all_providers, get_provider_info
39
+ from skrift.setup.state import (
40
+ can_connect_to_database,
41
+ get_database_url_from_yaml,
42
+ get_first_incomplete_step,
43
+ is_auth_configured,
44
+ is_site_configured,
45
+ run_migrations_if_needed,
46
+ reset_migrations_flag,
47
+ )
48
+
49
+
50
@asynccontextmanager
async def get_setup_db_session() -> "AsyncGenerator[AsyncSession, None]":
    """Yield a short-lived database session for setup operations.

    Used while the setup wizard runs, before the SQLAlchemy plugin is
    available. A dedicated engine is built from the database URL in the YAML
    config and disposed when the context exits.

    The session is committed when the ``async with`` body finishes without
    raising; otherwise it is rolled back and the exception re-raised.

    Raises:
        RuntimeError: If no database URL is configured.
    """
    db_url = get_database_url_from_yaml()
    if not db_url:
        raise RuntimeError("Database not configured")

    engine = create_async_engine(db_url)
    try:
        session_factory = async_sessionmaker(engine, expire_on_commit=False)
        async with session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise
    finally:
        # Dispose only after the session context has exited, so the session has
        # released its connection before the pool is torn down. Disposal also
        # runs if creating the session factory/session itself fails.
        await engine.dispose()
72
+
73
+
74
class SetupController(Controller):
    """Controller for the setup wizard.

    Steps: database (1), authentication providers (2), site settings (3) and
    admin account creation (4). Messages are carried between requests in the
    ``flash`` and ``setup_error`` session keys, and ``setup_wizard_step``
    records the step the user was last sent to.
    """

    path = "/setup"

    async def _check_already_complete(self) -> bool:
        """Defense in depth: check if setup is already complete.

        Returns True when the setup-completed timestamp is stored in the
        database. Any error (e.g. no database configured yet) yields False.
        """
        try:
            async with get_setup_db_session() as db_session:
                value = await get_setting(db_session, SETUP_COMPLETED_AT_KEY)
                return value is not None
        except Exception:
            return False

    @staticmethod
    def _request_base_url(request: Request) -> str:
        """Return ``scheme://host`` as seen by the client.

        Honours ``X-Forwarded-Proto``. When a proxy chain sends a
        comma-separated list, the first (client-facing) entry is used.
        """
        forwarded_proto = request.headers.get("x-forwarded-proto", "")
        scheme = forwarded_proto.split(",")[0].strip() or request.url.scheme
        host = request.headers.get("host", request.url.netloc)
        return f"{scheme}://{host}"

    @get("/")
    async def index(self, request: Request) -> Redirect:
        """Redirect to the first incomplete setup step.

        Uses smart detection to skip already-configured steps, allowing users
        to resume setup without re-entering existing configuration.
        """
        can_connect, _ = await can_connect_to_database()
        if can_connect:
            # Database is configured - go through configuring page to run migrations
            return Redirect(path="/setup/configuring")

        # Database not configured - go to database step
        request.session["setup_wizard_step"] = "database"
        return Redirect(path="/setup/database")

    @get("/database")
    async def database_step(self, request: Request) -> TemplateResponse | Redirect:
        """Step 1: Database configuration.

        Skips ahead to the configuring page when the database is already
        reachable and no error is pending from a previous submission.
        """
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        can_connect, _ = await can_connect_to_database()
        if can_connect and not error:
            return Redirect(path="/setup/configuring")

        # Pre-fill the form from the existing config, if any.
        config = load_config()
        db_config = config.get("db") or {}
        current_url = db_config.get("url") or ""
        db_type = "postgresql" if "postgresql" in current_url else "sqlite"

        return TemplateResponse(
            "setup/database.html",
            context={
                "flash": flash,
                "error": error,
                "step": 1,
                "total_steps": 4,
                "db_type": db_type,
                "current_url": current_url,
            },
        )

    @post("/database")
    async def save_database(self, request: Request) -> Redirect:
        """Save database configuration, then verify the connection.

        On success redirects to the configuring page, which runs migrations.
        On failure stores ``setup_error`` and returns to the database step.
        """
        form_data = await request.form()
        db_type = form_data.get("db_type", "sqlite")

        try:
            if db_type == "sqlite":
                file_path = form_data.get("sqlite_path", "./app.db")
                use_env = form_data.get("sqlite_path_env") == "on"
                update_database_config(
                    db_type="sqlite",
                    url=file_path,
                    use_env_vars={"url": use_env},
                )
            elif form_data.get("pg_url_env") == "on":
                # PostgreSQL URL read from an environment variable.
                env_var = form_data.get("pg_url_envvar", "DATABASE_URL")
                update_database_config(
                    db_type="postgresql",
                    url=env_var,
                    use_env_vars={"url": True},
                )
            else:
                # PostgreSQL with explicit connection parameters. A blank port
                # field falls back to the default instead of failing on int("").
                raw_port = str(form_data.get("pg_port") or "").strip() or "5432"
                try:
                    port = int(raw_port)
                except ValueError:
                    request.session["setup_error"] = f"Invalid PostgreSQL port: {raw_port!r}"
                    return Redirect(path="/setup/database")

                update_database_config(
                    db_type="postgresql",
                    host=form_data.get("pg_host", "localhost"),
                    port=port,
                    database=form_data.get("pg_database", "skrift"),
                    username=form_data.get("pg_username", "postgres"),
                    password=form_data.get("pg_password", ""),
                )

            can_connect, error = await can_connect_to_database()
            if not can_connect:
                request.session["setup_error"] = f"Connection failed: {error}"
                return Redirect(path="/setup/database")

            # Connection successful - redirect to configuring page to run migrations
            request.session["setup_wizard_step"] = "configuring"
            return Redirect(path="/setup/configuring")

        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/database")

    @get("/restart")
    async def restart_step(self, request: Request) -> Redirect:
        """Legacy restart route - now redirects to auth since restart is no longer required."""
        request.session["setup_wizard_step"] = "auth"
        return Redirect(path="/setup/auth")

    @get("/configuring")
    async def configuring_step(self, request: Request) -> TemplateResponse | Redirect:
        """Database configuration in progress page.

        Shows a loading spinner while migrations run via SSE
        (``/setup/configuring/status``). Redirects back to the database step
        if the database cannot be reached.
        """
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        can_connect, connection_error = await can_connect_to_database()
        if not can_connect:
            request.session["setup_error"] = f"Cannot connect to database: {connection_error}"
            return Redirect(path="/setup/database")

        # Reset migrations flag so they run fresh via SSE
        reset_migrations_flag()

        return TemplateResponse(
            "setup/configuring.html",
            context={
                "flash": flash,
                "error": error,
                "step": 1,
                "total_steps": 4,
            },
        )

    @get("/configuring/status")
    async def configuring_status(self, request: Request) -> ServerSentEvent:
        """SSE endpoint for database configuration status.

        Streams JSON-encoded progress events. The final event has status
        ``"error"`` or ``"complete"``; on completion it also carries
        ``next_step`` (``"auth"``, ``"site"`` or ``"admin"``).
        """
        async def generate_status() -> AsyncGenerator[str, None]:
            yield json.dumps({
                "status": "running",
                "message": "Testing database connection...",
                "detail": "",
            })
            await asyncio.sleep(0.5)

            can_connect, connection_error = await can_connect_to_database()
            if not can_connect:
                yield json.dumps({
                    "status": "error",
                    "message": f"Database connection failed: {connection_error}",
                })
                return

            yield json.dumps({
                "status": "running",
                "message": "Running database migrations...",
                "detail": "This may take a moment",
            })
            await asyncio.sleep(0.3)

            # run_migrations_if_needed() is synchronous; run it in a worker
            # thread so the event loop (and this SSE stream) is not blocked
            # while migrations execute.
            success, error = await asyncio.to_thread(run_migrations_if_needed)
            if not success:
                yield json.dumps({
                    "status": "error",
                    "message": f"Migration failed: {error}",
                })
                return

            yield json.dumps({
                "status": "running",
                "message": "Verifying database schema...",
                "detail": "",
            })
            await asyncio.sleep(0.3)

            if not is_auth_configured():
                next_step = "auth"
            elif await is_site_configured():
                next_step = "admin"
            else:
                next_step = "site"

            yield json.dumps({
                "status": "complete",
                "message": "Database configured successfully!",
                "next_step": next_step,
            })

        return ServerSentEvent(generate_status())

    @get("/auth")
    async def auth_step(self, request: Request) -> TemplateResponse | Redirect:
        """Step 2: Authentication providers.

        Skips to the first incomplete step when auth is already configured
        and no error is pending.
        """
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        if is_auth_configured() and not error:
            next_step = await get_first_incomplete_step()
            if next_step.value != "auth":
                request.session["setup_wizard_step"] = next_step.value
                return Redirect(path=f"/setup/{next_step.value}")

        config = load_config()
        configured_providers = (config.get("auth") or {}).get("providers") or {}

        return TemplateResponse(
            "setup/auth.html",
            context={
                "flash": flash,
                "error": error,
                "step": 2,
                "total_steps": 4,
                "redirect_base_url": self._request_base_url(request),
                "providers": get_all_providers(),
                "configured_providers": configured_providers,
            },
        )

    @post("/auth")
    async def save_auth(self, request: Request) -> Redirect:
        """Save authentication configuration.

        Collects the fields of every provider whose ``<key>_enabled`` checkbox
        is on. Requires at least one provider.
        """
        form_data = await request.form()
        redirect_base_url = form_data.get("redirect_base_url", "http://localhost:8000")

        providers = {}
        use_env_vars = {}

        for provider_key, provider_info in get_all_providers().items():
            if form_data.get(f"{provider_key}_enabled") != "on":
                continue

            provider_config = {}
            provider_env_vars = {}
            for field in provider_info.fields:
                field_key = field["key"]
                # Strip whitespace that often comes along when pasting IDs/secrets.
                value = form_data.get(f"{provider_key}_{field_key}", "").strip()
                use_env = form_data.get(f"{provider_key}_{field_key}_env") == "on"

                # Blank optional fields are omitted; required ones are always kept.
                if value or not field.get("optional"):
                    provider_config[field_key] = value
                    provider_env_vars[field_key] = use_env

            if provider_config:
                providers[provider_key] = provider_config
                use_env_vars[provider_key] = provider_env_vars

        if not providers:
            request.session["setup_error"] = "Please configure at least one authentication provider"
            return Redirect(path="/setup/auth")

        try:
            update_auth_config(
                redirect_base_url=redirect_base_url,
                providers=providers,
                use_env_vars=use_env_vars,
            )

            next_step = await get_first_incomplete_step()
            request.session["setup_wizard_step"] = next_step.value
            request.session["flash"] = "Authentication configured successfully!"
            return Redirect(path=f"/setup/{next_step.value}")

        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/auth")

    @get("/site")
    async def site_step(self, request: Request) -> TemplateResponse | Redirect:
        """Step 3: Site settings.

        Skips to the first incomplete step when site settings already exist
        and no error is pending.
        """
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        if await is_site_configured() and not error:
            next_step = await get_first_incomplete_step()
            if next_step.value != "site":
                request.session["setup_wizard_step"] = next_step.value
                return Redirect(path=f"/setup/{next_step.value}")

        return TemplateResponse(
            "setup/site.html",
            context={
                "flash": flash,
                "error": error,
                "step": 3,
                "total_steps": 4,
                "settings": {
                    "site_name": "",
                    "site_tagline": "",
                    "site_copyright_holder": "",
                    "site_copyright_start_year": datetime.now().year,
                },
            },
        )

    @post("/site")
    async def save_site(self, request: Request) -> Redirect:
        """Save site settings to the database and reload the settings cache."""
        form_data = await request.form()

        try:
            site_name = form_data.get("site_name", "").strip()
            if not site_name:
                request.session["setup_error"] = "Site name is required"
                return Redirect(path="/setup/site")

            values = {
                setting_service.SITE_NAME_KEY: site_name,
                setting_service.SITE_TAGLINE_KEY: form_data.get("site_tagline", "").strip(),
                setting_service.SITE_COPYRIGHT_HOLDER_KEY: form_data.get(
                    "site_copyright_holder", ""
                ).strip(),
                setting_service.SITE_COPYRIGHT_START_YEAR_KEY: form_data.get(
                    "site_copyright_start_year", ""
                ).strip(),
            }

            # The SQLAlchemy plugin is not available during setup, so use a
            # manual session.
            async with get_setup_db_session() as db_session:
                for key, value in values.items():
                    await setting_service.set_setting(db_session, key, value)
                await setting_service.load_site_settings_cache(db_session)

            # Should normally be "admin" at this point.
            next_step = await get_first_incomplete_step()
            request.session["setup_wizard_step"] = next_step.value
            request.session["flash"] = "Site settings saved!"
            return Redirect(path=f"/setup/{next_step.value}")

        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/site")

    @get("/admin")
    async def admin_step(self, request: Request) -> TemplateResponse:
        """Step 4: Create admin account by signing in with a configured provider."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        config = load_config()
        configured_providers = list(((config.get("auth") or {}).get("providers") or {}).keys())

        # Display info only for configured providers that are known.
        all_providers = get_all_providers()
        provider_info = {
            key: all_providers[key] for key in configured_providers if key in all_providers
        }

        return TemplateResponse(
            "setup/admin.html",
            context={
                "flash": flash,
                "error": error,
                "step": 4,
                "total_steps": 4,
                "providers": provider_info,
                "configured_providers": configured_providers,
            },
        )

    @get("/oauth/{provider:str}/login")
    async def setup_oauth_login(self, request: Request, provider: str) -> Redirect:
        """Redirect to OAuth provider for setup admin creation.

        Stores the CSRF ``state`` (and, for Twitter, a PKCE verifier) in the
        session and flags the flow as a setup flow so that
        ``/auth/{provider}/callback`` completes it.

        Raises:
            HTTPException: 404 if the provider is not configured or unknown.
        """
        config = load_config()
        providers_config = (config.get("auth") or {}).get("providers") or {}

        if provider not in providers_config:
            raise HTTPException(status_code=404, detail=f"Provider {provider} not configured")

        provider_info = get_provider_info(provider)
        if not provider_info:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")

        provider_config = providers_config[provider] or {}
        client_id = self._resolve_env_var(provider_config.get("client_id", ""))

        # CSRF state token, verified in the callback.
        state = secrets.token_urlsafe(32)
        request.session["oauth_state"] = state
        request.session["oauth_provider"] = provider
        request.session["oauth_setup"] = True

        # Use the standard /auth callback URL - this matches what's configured
        # in the OAuth provider console.
        redirect_uri = f"{self._request_base_url(request)}/auth/{provider}/callback"

        # Configured scopes override the provider defaults; accept either a
        # list or an already space-separated string.
        scopes = provider_config.get("scopes") or provider_info.scopes
        scope = scopes if isinstance(scopes, str) else " ".join(scopes)

        auth_url = provider_info.auth_url
        if "{tenant}" in auth_url:
            tenant = self._resolve_env_var(provider_config.get("tenant_id") or "common") or "common"
            auth_url = auth_url.replace("{tenant}", tenant)

        code_challenge = None
        if provider == "twitter":
            # PKCE (S256). token_urlsafe(64) yields an 86-character verifier,
            # within RFC 7636's 43-128 character range.
            code_verifier = secrets.token_urlsafe(64)
            request.session["oauth_code_verifier"] = code_verifier
            code_challenge = base64.urlsafe_b64encode(
                hashlib.sha256(code_verifier.encode()).digest()
            ).decode().rstrip("=")
        else:
            # Drop any verifier left over from an abandoned Twitter attempt.
            request.session.pop("oauth_code_verifier", None)

        params = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": scope,
            "state": state,
        }

        if provider == "google":
            params["access_type"] = "offline"
            params["prompt"] = "select_account"
        elif provider == "twitter" and code_challenge:
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
        elif provider == "discord":
            params["prompt"] = "consent"

        return Redirect(path=f"{auth_url}?{urlencode(params)}")

    def _resolve_env_var(self, value: object) -> str:
        """Resolve environment variable reference if value starts with $.

        ``"$NAME"`` becomes the value of environment variable ``NAME``, or
        ``""`` if it is unset. ``None`` becomes ``""``. Other non-string
        values, such as a numeric client ID parsed from YAML, are converted
        with ``str()``. Any other string is returned unchanged.
        """
        import os

        if value is None:
            return ""
        text = value if isinstance(value, str) else str(value)
        if text.startswith("$"):
            return os.environ.get(text[1:], "")
        return text

    @get("/complete")
    async def complete(self, request: Request) -> TemplateResponse | Redirect:
        """Setup complete page; redirects to /setup unless completion is recorded."""
        if not await self._check_already_complete():
            return Redirect(path="/setup")

        # Clear the session flag if present
        request.session.pop("setup_just_completed", None)

        return TemplateResponse(
            "setup/complete.html",
            context={
                "step": 4,
                "total_steps": 4,
            },
        )
590
+
591
+
592
async def mark_setup_complete(db_session: AsyncSession | None = None) -> None:
    """Mark setup as complete by storing the current UTC timestamp.

    Writes ``SETUP_COMPLETED_AT_KEY`` as an ISO-8601 string.

    Args:
        db_session: Session to write with. When given, the caller is
            responsible for committing it. When ``None``, a setup session is
            created and committed by ``get_setup_db_session``.
    """
    timestamp = datetime.now(UTC).isoformat()
    # Explicit None check: the caller's session object must be used whenever
    # one is passed, regardless of how it evaluates in a boolean context.
    if db_session is not None:
        await setting_service.set_setting(db_session, SETUP_COMPLETED_AT_KEY, timestamp)
        return

    async with get_setup_db_session() as session:
        await setting_service.set_setting(session, SETUP_COMPLETED_AT_KEY, timestamp)
604
+
605
+
606
class SetupAuthController(Controller):
    """Auth controller for setup OAuth callbacks.

    This handles OAuth callbacks at /auth/{provider}/callback during setup,
    matching the redirect URI configured in OAuth providers. A successful
    callback creates or links the user, grants the ``admin`` role, and
    records the setup-completed timestamp.
    """

    path = "/auth"

    @get("/{provider:str}/callback")
    async def setup_oauth_callback(
        self,
        request: Request,
        provider: str,
        code: str | None = None,
        oauth_state: Annotated[str | None, Parameter(query="state")] = None,
        error: str | None = None,
    ) -> Redirect:
        """Handle OAuth callback during setup.

        Verifies the CSRF ``state`` stored by the setup login route. It then
        exchanges the authorization code for an access token and fetches the
        provider profile. Finally it creates or links the user, grants admin,
        and marks setup complete.

        Raises:
            HTTPException: 400 for a non-setup flow, bad state, a missing code
                or a failed provider request. 404 for an unconfigured or
                unknown provider.
        """
        # Only flows started by /setup/oauth/{provider}/login are accepted.
        if not request.session.get("oauth_setup"):
            raise HTTPException(status_code=400, detail="Invalid OAuth flow")

        if error:
            request.session["setup_error"] = f"OAuth error: {error}"
            return Redirect(path="/setup/admin")

        # Verify the single-use CSRF state with a constant-time comparison.
        stored_state = request.session.pop("oauth_state", None)
        if (
            not oauth_state
            or not isinstance(stored_state, str)
            or not secrets.compare_digest(oauth_state.encode(), stored_state.encode())
        ):
            raise HTTPException(status_code=400, detail="Invalid OAuth state")

        if not code:
            raise HTTPException(status_code=400, detail="Missing authorization code")

        config = load_config()
        providers_config = (config.get("auth") or {}).get("providers") or {}
        if provider not in providers_config:
            raise HTTPException(status_code=404, detail=f"Provider {provider} not configured")

        provider_info = get_provider_info(provider)
        if not provider_info:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")

        provider_config = providers_config[provider] or {}

        # Must match the redirect_uri sent by the login route exactly.
        redirect_uri = f"{self._request_base_url(request)}/auth/{provider}/callback"
        code_verifier = request.session.pop("oauth_code_verifier", None)

        access_token = await self._exchange_code_for_token(
            provider, provider_info, provider_config, code, redirect_uri, code_verifier
        )
        user_info = await self._fetch_user_info(provider, provider_info, access_token)

        user_data = self._extract_user_data(provider, user_info)
        if not user_data["oauth_id"]:
            raise HTTPException(status_code=400, detail="Could not determine user ID")

        user = await self._upsert_setup_admin(provider, user_data)

        # Clear setup flag and log the new admin in.
        request.session.pop("oauth_setup", None)
        request.session["user_id"] = str(user.id)
        request.session["user_name"] = user.name
        request.session["user_email"] = user.email
        request.session["user_picture_url"] = user.picture_url
        request.session["flash"] = "Admin account created successfully!"
        request.session["setup_just_completed"] = True

        # Note: Don't call mark_setup_complete_in_dispatcher() here.
        # The switch happens in /setup/complete after rendering the page.
        return Redirect(path="/setup/complete")

    @staticmethod
    def _request_base_url(request: Request) -> str:
        """Return ``scheme://host`` as seen by the client.

        Honours ``X-Forwarded-Proto``. When a proxy chain sends a
        comma-separated list, the first (client-facing) entry is used.
        """
        forwarded_proto = request.headers.get("x-forwarded-proto", "")
        scheme = forwarded_proto.split(",")[0].strip() or request.url.scheme
        host = request.headers.get("host", request.url.netloc)
        return f"{scheme}://{host}"

    async def _exchange_code_for_token(
        self,
        provider: str,
        provider_info,
        provider_config: dict,
        code: str,
        redirect_uri: str,
        code_verifier: str | None,
    ) -> str:
        """Exchange an authorization code for an access token.

        Raises:
            HTTPException: 400 if the token endpoint does not answer 200 or
                returns no ``access_token``.
        """
        client_id = self._resolve_env_var(provider_config.get("client_id", ""))
        client_secret = self._resolve_env_var(provider_config.get("client_secret", ""))

        token_url = provider_info.token_url
        if "{tenant}" in token_url:
            tenant = self._resolve_env_var(provider_config.get("tenant_id") or "common") or "common"
            token_url = token_url.replace("{tenant}", tenant)

        data = {
            "client_id": client_id,
            "client_secret": client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
        }
        # Request JSON explicitly; GitHub, for one, answers form-encoded otherwise.
        headers = {"Accept": "application/json"}

        if provider == "twitter":
            if code_verifier:
                data["code_verifier"] = code_verifier
            # Twitter client credentials go in an HTTP Basic header instead of
            # the form body.
            credentials = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
            headers["Authorization"] = f"Basic {credentials}"
            del data["client_secret"]

        async with httpx.AsyncClient() as client:
            response = await client.post(token_url, data=data, headers=headers)
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail=f"Token exchange failed: {response.text}")
            tokens = response.json()

        access_token = tokens.get("access_token")
        if not access_token:
            raise HTTPException(status_code=400, detail="No access token received")
        return access_token

    async def _fetch_user_info(self, provider: str, provider_info, access_token: str) -> dict:
        """Fetch the provider's user profile with *access_token*.

        For GitHub without an email in the profile, falls back to the primary,
        verified address from ``/user/emails``.

        Raises:
            HTTPException: 400 if the userinfo endpoint does not answer 200.
        """
        user_headers = {"Authorization": f"Bearer {access_token}"}
        async with httpx.AsyncClient() as client:
            response = await client.get(provider_info.userinfo_url, headers=user_headers)
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Failed to fetch user info")
            user_info = response.json()

            if provider == "github" and not user_info.get("email"):
                email_response = await client.get(
                    "https://api.github.com/user/emails", headers=user_headers
                )
                if email_response.status_code == 200:
                    # Only trust an address GitHub reports as verified. The
                    # email is used below to link to an existing account.
                    primary_email = next(
                        (
                            entry.get("email")
                            for entry in email_response.json()
                            if entry.get("primary") and entry.get("verified")
                        ),
                        None,
                    )
                    if primary_email:
                        user_info["email"] = primary_email

        return user_info

    @staticmethod
    def _apply_profile(user: "User", user_data: dict, now: datetime) -> None:
        """Refresh *user*'s name, picture (only if one was provided) and last login."""
        user.name = user_data["name"]
        if user_data["picture_url"]:
            user.picture_url = user_data["picture_url"]
        user.last_login_at = now

    async def _upsert_setup_admin(self, provider: str, user_data: dict) -> "User":
        """Create or link the signing-in user, grant admin, and mark setup complete.

        Resolution order:
        1. an existing OAuth account for (provider, oauth_id) -> its user;
        2. an existing user with the same email -> link a new OAuth account;
        3. otherwise create a new user plus OAuth account.

        All writes share one ``get_setup_db_session`` transaction, so the admin
        grant and the setup-completed timestamp are committed together.
        """
        oauth_id = user_data["oauth_id"]
        email = user_data["email"]
        now = datetime.now(UTC)

        async with get_setup_db_session() as db_session:
            result = await db_session.execute(
                select(OAuthAccount)
                .options(selectinload(OAuthAccount.user))
                .where(
                    OAuthAccount.provider == provider,
                    OAuthAccount.provider_account_id == oauth_id,
                )
            )
            oauth_account = result.scalar_one_or_none()

            if oauth_account:
                user = oauth_account.user
                self._apply_profile(user, user_data, now)
                if email:
                    oauth_account.provider_email = email
            else:
                user = None
                if email:
                    result = await db_session.execute(select(User).where(User.email == email))
                    user = result.scalar_one_or_none()

                if user:
                    self._apply_profile(user, user_data, now)
                else:
                    user = User(
                        email=email,
                        name=user_data["name"],
                        picture_url=user_data["picture_url"],
                        last_login_at=now,
                    )
                    db_session.add(user)
                    # Flush so user.id is available for the OAuthAccount FK.
                    await db_session.flush()

                db_session.add(
                    OAuthAccount(
                        provider=provider,
                        provider_account_id=oauth_id,
                        provider_email=email,
                        user_id=user.id,
                    )
                )

            # Ensure roles are synced (they may not exist if DB was created after server start)
            from skrift.auth import sync_roles_to_database

            await sync_roles_to_database(db_session)

            # NOTE(review): whoever completes this flow while the setup routes are
            # mounted is granted admin, whether the user is new or existing.
            admin_role = await db_session.scalar(select(Role).where(Role.name == "admin"))
            if admin_role:
                already_admin = await db_session.execute(
                    select(user_roles).where(
                        user_roles.c.user_id == user.id,
                        user_roles.c.role_id == admin_role.id,
                    )
                )
                if not already_admin.first():
                    await db_session.execute(
                        user_roles.insert().values(user_id=user.id, role_id=admin_role.id)
                    )

            await mark_setup_complete(db_session)

        return user

    def _resolve_env_var(self, value: object) -> str:
        """Resolve environment variable reference if value starts with $.

        ``"$NAME"`` becomes the value of environment variable ``NAME``, or
        ``""`` if it is unset. ``None`` becomes ``""``. Other non-string
        values, such as a numeric client ID parsed from YAML, are converted
        with ``str()``. Any other string is returned unchanged.
        """
        import os

        if value is None:
            return ""
        text = value if isinstance(value, str) else str(value)
        if text.startswith("$"):
            return os.environ.get(text[1:], "")
        return text

    @staticmethod
    def _as_id(*candidates: object) -> str | None:
        """Return the first candidate that is neither None nor "" as a str, else None.

        Prevents a missing provider ID from being stored as the literal string
        ``"None"``, which would slip past the caller's missing-ID check.
        """
        for candidate in candidates:
            if candidate is not None and candidate != "":
                return str(candidate)
        return None

    def _extract_user_data(self, provider: str, user_info: dict) -> dict:
        """Extract normalized user data from provider response.

        Returns a dict with keys ``oauth_id`` (str, or None if the profile has
        no ID), ``email``, ``name`` and ``picture_url``.
        """
        if provider == "google":
            return {
                # The v2 userinfo endpoint returns "id"; OpenID Connect returns "sub".
                "oauth_id": self._as_id(user_info.get("id"), user_info.get("sub")),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": user_info.get("picture"),
            }
        if provider == "github":
            return {
                "oauth_id": self._as_id(user_info.get("id")),
                "email": user_info.get("email"),
                "name": user_info.get("name") or user_info.get("login"),
                "picture_url": user_info.get("avatar_url"),
            }
        if provider == "microsoft":
            return {
                "oauth_id": self._as_id(user_info.get("id")),
                "email": user_info.get("mail") or user_info.get("userPrincipalName"),
                "name": user_info.get("displayName"),
                "picture_url": None,
            }
        if provider == "discord":
            avatar = user_info.get("avatar")
            user_id = user_info.get("id")
            avatar_url = (
                f"https://cdn.discordapp.com/avatars/{user_id}/{avatar}.png"
                if avatar and user_id
                else None
            )
            return {
                "oauth_id": self._as_id(user_id),
                "email": user_info.get("email"),
                "name": user_info.get("global_name") or user_info.get("username"),
                "picture_url": avatar_url,
            }
        if provider == "facebook":
            # "picture" may be absent or null; only a dict with a "data" dict is used.
            picture = user_info.get("picture")
            picture_data = picture.get("data") if isinstance(picture, dict) else None
            if not isinstance(picture_data, dict):
                picture_data = {}
            return {
                "oauth_id": self._as_id(user_info.get("id")),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": None if picture_data.get("is_silhouette") else picture_data.get("url"),
            }
        if provider == "twitter":
            # The profile is nested under "data" when present.
            data = user_info.get("data") or user_info
            return {
                "oauth_id": self._as_id(data.get("id")),
                "email": data.get("email"),
                "name": data.get("name") or data.get("username"),
                "picture_url": None,
            }
        return {
            "oauth_id": self._as_id(user_info.get("id"), user_info.get("sub")),
            "email": user_info.get("email"),
            "name": user_info.get("name"),
            "picture_url": user_info.get("picture"),
        }