skrift 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. skrift/__init__.py +1 -0
  2. skrift/__main__.py +17 -0
  3. skrift/admin/__init__.py +11 -0
  4. skrift/admin/controller.py +452 -0
  5. skrift/admin/navigation.py +105 -0
  6. skrift/alembic/env.py +91 -0
  7. skrift/alembic/script.py.mako +26 -0
  8. skrift/alembic/versions/20260120_210154_09b0364dbb7b_initial_schema.py +70 -0
  9. skrift/alembic/versions/20260122_152744_0b7c927d2591_add_roles_and_permissions.py +57 -0
  10. skrift/alembic/versions/20260122_172836_cdf734a5b847_add_sa_orm_sentinel_column.py +31 -0
  11. skrift/alembic/versions/20260122_175637_a9c55348eae7_remove_page_type_column.py +43 -0
  12. skrift/alembic/versions/20260122_200000_add_settings_table.py +38 -0
  13. skrift/alembic.ini +77 -0
  14. skrift/asgi.py +545 -0
  15. skrift/auth/__init__.py +58 -0
  16. skrift/auth/guards.py +130 -0
  17. skrift/auth/roles.py +94 -0
  18. skrift/auth/services.py +184 -0
  19. skrift/cli.py +45 -0
  20. skrift/config.py +192 -0
  21. skrift/controllers/__init__.py +4 -0
  22. skrift/controllers/auth.py +371 -0
  23. skrift/controllers/web.py +67 -0
  24. skrift/db/__init__.py +3 -0
  25. skrift/db/base.py +7 -0
  26. skrift/db/models/__init__.py +6 -0
  27. skrift/db/models/page.py +26 -0
  28. skrift/db/models/role.py +56 -0
  29. skrift/db/models/setting.py +13 -0
  30. skrift/db/models/user.py +36 -0
  31. skrift/db/services/__init__.py +1 -0
  32. skrift/db/services/page_service.py +217 -0
  33. skrift/db/services/setting_service.py +206 -0
  34. skrift/lib/__init__.py +3 -0
  35. skrift/lib/exceptions.py +168 -0
  36. skrift/lib/template.py +108 -0
  37. skrift/setup/__init__.py +14 -0
  38. skrift/setup/config_writer.py +211 -0
  39. skrift/setup/controller.py +751 -0
  40. skrift/setup/middleware.py +89 -0
  41. skrift/setup/providers.py +163 -0
  42. skrift/setup/state.py +134 -0
  43. skrift/static/css/style.css +998 -0
  44. skrift/templates/admin/admin.html +19 -0
  45. skrift/templates/admin/base.html +24 -0
  46. skrift/templates/admin/pages/edit.html +32 -0
  47. skrift/templates/admin/pages/list.html +62 -0
  48. skrift/templates/admin/settings/site.html +32 -0
  49. skrift/templates/admin/users/list.html +58 -0
  50. skrift/templates/admin/users/roles.html +42 -0
  51. skrift/templates/auth/login.html +125 -0
  52. skrift/templates/base.html +52 -0
  53. skrift/templates/error-404.html +19 -0
  54. skrift/templates/error-500.html +19 -0
  55. skrift/templates/error.html +19 -0
  56. skrift/templates/index.html +9 -0
  57. skrift/templates/page.html +26 -0
  58. skrift/templates/setup/admin.html +24 -0
  59. skrift/templates/setup/auth.html +110 -0
  60. skrift/templates/setup/base.html +407 -0
  61. skrift/templates/setup/complete.html +17 -0
  62. skrift/templates/setup/database.html +125 -0
  63. skrift/templates/setup/restart.html +28 -0
  64. skrift/templates/setup/site.html +39 -0
  65. skrift-0.1.0a1.dist-info/METADATA +233 -0
  66. skrift-0.1.0a1.dist-info/RECORD +68 -0
  67. skrift-0.1.0a1.dist-info/WHEEL +4 -0
  68. skrift-0.1.0a1.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,751 @@
1
+ """Setup wizard controller for first-time Skrift configuration."""
2
+
3
+ import base64
4
+ import hashlib
5
+ import secrets
6
+ import subprocess
7
+ from contextlib import asynccontextmanager
8
+ from datetime import UTC, datetime
9
+ from pathlib import Path
10
+ from urllib.parse import urlencode
11
+
12
+ import httpx
13
+ from typing import Annotated
14
+
15
+ from litestar import Controller, Request, get, post
16
+ from litestar.exceptions import HTTPException
17
+ from litestar.params import Parameter
18
+ from litestar.response import Redirect, Template as TemplateResponse
19
+ from sqlalchemy import func, select
20
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
21
+
22
+ from skrift.db.models.role import Role, user_roles
23
+ from skrift.db.models.user import User
24
+ from skrift.db.services import setting_service
25
+ from skrift.db.services.setting_service import (
26
+ SETUP_COMPLETED_AT_KEY,
27
+ get_setting,
28
+ )
29
+ from skrift.setup.config_writer import (
30
+ load_config,
31
+ update_auth_config,
32
+ update_database_config,
33
+ )
34
+ from skrift.setup.providers import get_all_providers, get_provider_info
35
+ from skrift.setup.state import can_connect_to_database, app_yaml_exists, get_database_url_from_yaml
36
+
37
+
38
@asynccontextmanager
async def get_setup_db_session():
    """Yield a database session for setup operations.

    This is used during setup when the SQLAlchemy plugin isn't available, so a
    short-lived engine is built from the URL returned by
    ``get_database_url_from_yaml``.

    On normal exit the session is committed. If the body raises, the session
    is rolled back and the exception re-raised. The engine is disposed only
    after the session context has exited, so the session's connection is
    returned to the pool before the pool is torn down.

    Raises:
        RuntimeError: If no database URL is configured.
    """
    db_url = get_database_url_from_yaml()
    if not db_url:
        raise RuntimeError("Database not configured")

    engine = create_async_engine(db_url)
    try:
        async_session = async_sessionmaker(engine, expire_on_commit=False)
        async with async_session() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise
    finally:
        # Dispose outside the session context so the session is closed first.
        await engine.dispose()
60
+
61
+
62
class SetupController(Controller):
    """Controller for the first-time setup wizard, mounted at ``/setup``.

    Steps: 1) database, 2) authentication providers, 3) site settings,
    4) admin account (created through an OAuth login). The step to resume at is
    stored in the session under ``setup_wizard_step``. POST handlers hand
    messages to the following page through the session keys ``setup_error``
    and ``flash``.
    """

    path = "/setup"

    # Step routes that ``index`` may resume at. Any other session value falls
    # back to the database step instead of being turned into a redirect path.
    _RESUMABLE_STEPS = ("database", "restart", "auth", "site", "admin", "complete")

    async def _check_already_complete(self) -> bool:
        """Defense in depth: return True if the setup-completed setting exists.

        Any failure while checking (no DB configured, connection error, ...)
        is treated as "not complete".
        """
        try:
            async with get_setup_db_session() as db_session:
                value = await get_setting(db_session, SETUP_COMPLETED_AT_KEY)
                return value is not None
        except Exception:
            return False

    @staticmethod
    def _request_base_url(request: Request) -> str:
        """Return ``scheme://host`` as seen by the client.

        ``X-Forwarded-Proto`` may hold a comma-separated list when several
        proxies append to it; only the first entry is used. Falls back to the
        request's own scheme when the header is absent or empty.
        """
        forwarded_proto = request.headers.get("x-forwarded-proto", "")
        scheme = forwarded_proto.split(",")[0].strip() or request.url.scheme
        host = request.headers.get("host", request.url.netloc)
        return f"{scheme}://{host}"

    def _resolve_env_var(self, value: str | int | None) -> str:
        """Resolve an environment variable reference if the value starts with ``$``.

        ``None`` becomes ``""``. Non-string values (e.g. a numeric client id
        parsed from YAML) are converted to ``str`` instead of raising.
        """
        import os

        if value is None:
            return ""
        text = str(value)
        if text.startswith("$"):
            return os.environ.get(text[1:], "")
        return text

    async def _run_migrations(self) -> str | None:
        """Upgrade the configured database to the latest schema revision.

        Runs ``skrift-db upgrade head``. If that executable is not found, runs
        ``alembic upgrade head`` instead. ``subprocess.run`` blocks, so each
        call runs in a worker thread to keep the event loop responsive during
        the (up to 60 second) migration.

        Returns:
            ``None`` on success, otherwise a user-facing error message.
        """
        import asyncio

        def run(command: list[str]) -> subprocess.CompletedProcess:
            return subprocess.run(
                command,
                capture_output=True,
                text=True,
                cwd=Path.cwd(),
                timeout=60,
            )

        try:
            result = await asyncio.to_thread(run, ["skrift-db", "upgrade", "head"])
        except subprocess.TimeoutExpired:
            return "Migration timed out"
        except FileNotFoundError:
            # skrift-db might not be installed yet, try alembic directly
            try:
                result = await asyncio.to_thread(run, ["alembic", "upgrade", "head"])
            except Exception as e:
                return f"Could not run migrations: {e}"

        if result.returncode != 0:
            return f"Migration failed: {result.stderr}"
        return None

    @get("/")
    async def index(self, request: Request) -> Redirect:
        """Redirect to the setup step the visitor should see next."""
        # Without app.yaml or a reachable database, start at the database step.
        if not app_yaml_exists():
            return Redirect(path="/setup/database")

        can_connect, _ = await can_connect_to_database()
        if not can_connect:
            return Redirect(path="/setup/database")

        # Otherwise resume at the step saved in the session (if it is a known one).
        wizard_step = request.session.get("setup_wizard_step", "database")
        if wizard_step not in self._RESUMABLE_STEPS:
            wizard_step = "database"
        return Redirect(path=f"/setup/{wizard_step}")

    @get("/database")
    async def database_step(self, request: Request) -> TemplateResponse:
        """Step 1: render the database configuration form."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        # Pre-fill from the current config; tolerate empty/null YAML sections.
        config = load_config()
        db_config = config.get("db") or {}
        current_url = db_config.get("url") or ""
        db_type = "postgresql" if "postgresql" in current_url else "sqlite"

        return TemplateResponse(
            "setup/database.html",
            context={
                "flash": flash,
                "error": error,
                "step": 1,
                "total_steps": 4,
                "db_type": db_type,
                "current_url": current_url,
            },
        )

    @post("/database")
    async def save_database(self, request: Request) -> Redirect:
        """Save the database configuration, test the connection and migrate.

        On success, advance to the auth step. On any failure, store the
        message in ``setup_error`` and redirect back to the database step.
        """
        form_data = await request.form()
        db_type = form_data.get("db_type", "sqlite")

        try:
            if db_type == "sqlite":
                # An empty field falls back to the same default as a missing one.
                file_path = form_data.get("sqlite_path") or "./app.db"
                use_env = form_data.get("sqlite_path_env") == "on"
                update_database_config(
                    db_type="sqlite",
                    url=file_path,
                    use_env_vars={"url": use_env},
                )
            elif form_data.get("pg_url_env") == "on":
                # PostgreSQL URL taken from an environment variable.
                env_var = form_data.get("pg_url_envvar") or "DATABASE_URL"
                update_database_config(
                    db_type="postgresql",
                    url=env_var,
                    use_env_vars={"url": True},
                )
            else:
                # PostgreSQL from individual connection fields.
                port_raw = str(form_data.get("pg_port") or "5432").strip()
                try:
                    port = int(port_raw)
                except ValueError:
                    port = 0
                if not 0 < port < 65536:
                    request.session["setup_error"] = f"Invalid PostgreSQL port: {port_raw!r}"
                    return Redirect(path="/setup/database")

                update_database_config(
                    db_type="postgresql",
                    host=form_data.get("pg_host") or "localhost",
                    port=port,
                    database=form_data.get("pg_database") or "skrift",
                    username=form_data.get("pg_username") or "postgres",
                    password=form_data.get("pg_password", ""),
                )

            can_connect, error = await can_connect_to_database()
            if not can_connect:
                request.session["setup_error"] = f"Connection failed: {error}"
                return Redirect(path="/setup/database")

            migration_error = await self._run_migrations()
            if migration_error is not None:
                request.session["setup_error"] = migration_error
                return Redirect(path="/setup/database")

            request.session["setup_wizard_step"] = "auth"
            request.session["flash"] = "Database configured successfully!"
            return Redirect(path="/setup/auth")

        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/database")

    @get("/restart")
    async def restart_step(self, request: Request) -> Redirect:
        """Legacy restart route - now redirects to auth since restart is no longer required."""
        request.session["setup_wizard_step"] = "auth"
        return Redirect(path="/setup/auth")

    @get("/auth")
    async def auth_step(self, request: Request) -> TemplateResponse:
        """Step 2: render the authentication provider form."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        config = load_config()
        auth_config = config.get("auth") or {}
        configured_providers = auth_config.get("providers") or {}

        return TemplateResponse(
            "setup/auth.html",
            context={
                "flash": flash,
                "error": error,
                "step": 2,
                "total_steps": 4,
                "redirect_base_url": self._request_base_url(request),
                "providers": get_all_providers(),
                "configured_providers": configured_providers,
            },
        )

    @post("/auth")
    async def save_auth(self, request: Request) -> Redirect:
        """Save the enabled authentication providers; at least one is required."""
        form_data = await request.form()

        # Normalise to the same "scheme://host" shape that auth_step displays
        # (no trailing slash); an empty field falls back to the default.
        redirect_base_url = (
            form_data.get("redirect_base_url", "").strip().rstrip("/")
            or "http://localhost:8000"
        )

        providers: dict[str, dict] = {}
        use_env_vars: dict[str, dict] = {}

        for provider_key, provider_info in get_all_providers().items():
            if form_data.get(f"{provider_key}_enabled") != "on":
                continue

            provider_config = {}
            provider_env_vars = {}
            for field in provider_info.fields:
                field_key = field["key"]
                # Strip surrounding whitespace (e.g. from copy-pasted credentials).
                value = form_data.get(f"{provider_key}_{field_key}", "").strip()
                use_env = form_data.get(f"{provider_key}_{field_key}_env") == "on"

                # Empty optional fields are omitted; required ones are always kept.
                if value or not field.get("optional"):
                    provider_config[field_key] = value
                    provider_env_vars[field_key] = use_env

            if provider_config:
                providers[provider_key] = provider_config
                use_env_vars[provider_key] = provider_env_vars

        if not providers:
            request.session["setup_error"] = "Please configure at least one authentication provider"
            return Redirect(path="/setup/auth")

        try:
            update_auth_config(
                redirect_base_url=redirect_base_url,
                providers=providers,
                use_env_vars=use_env_vars,
            )
        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/auth")

        request.session["setup_wizard_step"] = "site"
        request.session["flash"] = "Authentication configured successfully!"
        return Redirect(path="/setup/site")

    @get("/site")
    async def site_step(self, request: Request) -> TemplateResponse:
        """Step 3: render the site settings form with blank defaults."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        return TemplateResponse(
            "setup/site.html",
            context={
                "flash": flash,
                "error": error,
                "step": 3,
                "total_steps": 4,
                "settings": {
                    "site_name": "",
                    "site_tagline": "",
                    "site_copyright_holder": "",
                    "site_copyright_start_year": datetime.now().year,
                },
            },
        )

    @post("/site")
    async def save_site(self, request: Request) -> Redirect:
        """Save site settings to the database and reload the settings cache."""
        form_data = await request.form()

        try:
            site_name = form_data.get("site_name", "").strip()
            if not site_name:
                request.session["setup_error"] = "Site name is required"
                return Redirect(path="/setup/site")

            values = {
                setting_service.SITE_NAME_KEY: site_name,
                setting_service.SITE_TAGLINE_KEY: form_data.get("site_tagline", "").strip(),
                setting_service.SITE_COPYRIGHT_HOLDER_KEY: form_data.get(
                    "site_copyright_holder", ""
                ).strip(),
                setting_service.SITE_COPYRIGHT_START_YEAR_KEY: form_data.get(
                    "site_copyright_start_year", ""
                ).strip(),
            }

            # The SQLAlchemy plugin isn't available during setup; use a manual session.
            async with get_setup_db_session() as db_session:
                for key, value in values.items():
                    await setting_service.set_setting(db_session, key, value)
                await setting_service.load_site_settings_cache(db_session)

            request.session["setup_wizard_step"] = "admin"
            request.session["flash"] = "Site settings saved!"
            return Redirect(path="/setup/admin")

        except Exception as e:
            request.session["setup_error"] = str(e)
            return Redirect(path="/setup/site")

    @get("/admin")
    async def admin_step(self, request: Request) -> TemplateResponse:
        """Step 4: offer an OAuth login for each configured provider."""
        flash = request.session.pop("flash", None)
        error = request.session.pop("setup_error", None)

        config = load_config()
        auth_config = config.get("auth") or {}
        configured_providers = list((auth_config.get("providers") or {}).keys())

        all_providers = get_all_providers()
        provider_info = {
            key: all_providers[key] for key in configured_providers if key in all_providers
        }

        return TemplateResponse(
            "setup/admin.html",
            context={
                "flash": flash,
                "error": error,
                "step": 4,
                "total_steps": 4,
                "providers": provider_info,
                "configured_providers": configured_providers,
            },
        )

    @get("/oauth/{provider:str}/login")
    async def setup_oauth_login(self, request: Request, provider: str) -> Redirect:
        """Redirect to the OAuth provider to sign in the first admin.

        Stores ``oauth_state``, ``oauth_provider`` and ``oauth_setup`` (plus a
        PKCE ``oauth_code_verifier`` for Twitter) in the session for the
        callback at ``/auth/{provider}/callback``.

        Raises:
            HTTPException: 404 if the provider is not configured or unknown.
        """
        config = load_config()
        auth_config = config.get("auth") or {}
        providers_config = auth_config.get("providers") or {}

        if provider not in providers_config:
            raise HTTPException(status_code=404, detail=f"Provider {provider} not configured")

        provider_info = get_provider_info(provider)
        if not provider_info:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")

        provider_config = providers_config[provider] or {}
        client_id = self._resolve_env_var(provider_config.get("client_id"))

        # CSRF state token, verified by the callback.
        state = secrets.token_urlsafe(32)
        request.session["oauth_state"] = state
        request.session["oauth_provider"] = provider
        request.session["oauth_setup"] = True

        # Use the standard /auth callback URL so it matches what's configured
        # in the OAuth provider console.
        redirect_uri = f"{self._request_base_url(request)}/auth/{provider}/callback"

        scopes = provider_config.get("scopes", provider_info.scopes)
        scope = scopes if isinstance(scopes, str) else " ".join(scopes)

        # PKCE (S256) for Twitter.
        code_challenge = None
        if provider == "twitter":
            code_verifier = secrets.token_urlsafe(64)[:128]
            request.session["oauth_code_verifier"] = code_verifier
            digest = hashlib.sha256(code_verifier.encode()).digest()
            code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")

        auth_url = provider_info.auth_url
        if "{tenant}" in auth_url:
            # Missing, empty, or unresolved tenant falls back to "common".
            tenant = self._resolve_env_var(provider_config.get("tenant_id")) or "common"
            auth_url = auth_url.replace("{tenant}", tenant)

        params = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": scope,
            "state": state,
        }

        if provider == "google":
            params["access_type"] = "offline"
            params["prompt"] = "select_account"
        elif provider == "twitter" and code_challenge:
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
        elif provider == "discord":
            params["prompt"] = "consent"

        return Redirect(path=f"{auth_url}?{urlencode(params)}")

    @get("/complete")
    async def complete(self, request: Request) -> TemplateResponse | Redirect:
        """Render the completion page, or go back to /setup if not actually complete."""
        if not await self._check_already_complete():
            return Redirect(path="/setup")

        request.session.pop("setup_just_completed", None)

        return TemplateResponse(
            "setup/complete.html",
            context={
                "step": 4,
                "total_steps": 4,
            },
        )
490
+
491
+
492
async def mark_setup_complete(db_session: AsyncSession | None = None) -> None:
    """Mark setup as complete by storing the current UTC timestamp.

    The timestamp is ISO-8601 formatted and saved under
    ``SETUP_COMPLETED_AT_KEY`` via ``setting_service.set_setting``.

    Args:
        db_session: Session to write through. This function does not commit
            it itself. If omitted, a setup session is opened with
            ``get_setup_db_session`` and committed when that context exits.
    """
    timestamp = datetime.now(UTC).isoformat()
    if db_session is not None:
        await setting_service.set_setting(db_session, SETUP_COMPLETED_AT_KEY, timestamp)
        return

    async with get_setup_db_session() as session:
        await setting_service.set_setting(session, SETUP_COMPLETED_AT_KEY, timestamp)
504
+
505
+
506
class SetupAuthController(Controller):
    """Auth controller for setup OAuth callbacks.

    This handles OAuth callbacks at /auth/{provider}/callback during setup,
    matching the redirect URI configured in OAuth providers. A successful
    callback creates or updates the user, grants the ``admin`` role and marks
    setup as complete.
    """

    path = "/auth"

    @get("/{provider:str}/callback")
    async def setup_oauth_callback(
        self,
        request: Request,
        provider: str,
        code: str | None = None,
        oauth_state: Annotated[str | None, Parameter(query="state")] = None,
        error: str | None = None,
    ) -> Redirect:
        """Handle the OAuth callback that creates the first admin account.

        Raises:
            HTTPException: 400 for a non-setup flow, state/provider mismatch,
                missing code, failed token exchange or unusable profile;
                403 if setup is already complete; 404 for an unconfigured or
                unknown provider; 500 if no ``admin`` role exists.
        """
        if not request.session.get("oauth_setup"):
            raise HTTPException(status_code=400, detail="Invalid OAuth flow")

        # One-shot values written by SetupController.setup_oauth_login.
        stored_state = request.session.pop("oauth_state", None)
        stored_provider = request.session.pop("oauth_provider", None)
        code_verifier = request.session.pop("oauth_code_verifier", None)

        if error:
            request.session["setup_error"] = f"OAuth error: {error}"
            return Redirect(path="/setup/admin")

        # Verify CSRF state and that the callback is for the provider the
        # login was started with.
        if not oauth_state or oauth_state != stored_state:
            raise HTTPException(status_code=400, detail="Invalid OAuth state")
        if stored_provider != provider:
            raise HTTPException(status_code=400, detail="OAuth provider mismatch")

        if not code:
            raise HTTPException(status_code=400, detail="Missing authorization code")

        # Defense in depth: never hand out admin once setup has finished.
        if await self._check_already_complete():
            raise HTTPException(status_code=403, detail="Setup has already been completed")

        config = load_config()
        auth_config = config.get("auth") or {}
        providers_config = auth_config.get("providers") or {}

        if provider not in providers_config:
            raise HTTPException(status_code=404, detail=f"Provider {provider} not configured")

        provider_info = get_provider_info(provider)
        if not provider_info:
            raise HTTPException(status_code=404, detail=f"Unknown provider: {provider}")

        provider_config = providers_config[provider] or {}
        client_id = self._resolve_env_var(provider_config.get("client_id"))
        client_secret = self._resolve_env_var(provider_config.get("client_secret"))

        # Must match the redirect_uri sent by setup_oauth_login.
        redirect_uri = f"{self._request_base_url(request)}/auth/{provider}/callback"

        token_url = provider_info.token_url
        if "{tenant}" in token_url:
            tenant = self._resolve_env_var(provider_config.get("tenant_id")) or "common"
            token_url = token_url.replace("{tenant}", tenant)

        data = {
            "client_id": client_id,
            "client_secret": client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
        }
        if provider == "twitter" and code_verifier:
            data["code_verifier"] = code_verifier

        headers = {"Accept": "application/json"}
        if provider == "twitter":
            # Twitter: client credentials go in an HTTP Basic header, not the body.
            credentials = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
            headers["Authorization"] = f"Basic {credentials}"
            del data["client_secret"]

        try:
            user_info = await self._fetch_user_info(
                provider, token_url, provider_info.userinfo_url, data, headers
            )
        except httpx.HTTPError as exc:
            # Network-level failure talking to the provider: let the user retry.
            request.session["setup_error"] = (
                f"Could not contact {provider}: {type(exc).__name__}: {exc}"
            )
            return Redirect(path="/setup/admin")

        user_data = self._extract_user_data(provider, user_info)
        if not user_data["oauth_id"]:
            raise HTTPException(status_code=400, detail="Could not determine user ID")

        # Create/update the user, grant admin and mark setup complete in one
        # transaction (committed when the session context exits).
        async with get_setup_db_session() as db_session:
            user = await self._upsert_admin_user(db_session, provider, user_data)
            await mark_setup_complete(db_session)

        request.session.pop("oauth_setup", None)

        request.session["user_id"] = str(user.id)
        request.session["user_name"] = user.name
        request.session["user_email"] = user.email
        request.session["user_picture_url"] = user.picture_url
        request.session["flash"] = "Admin account created successfully!"
        request.session["setup_just_completed"] = True

        # Note: Don't call mark_setup_complete_in_dispatcher() here.
        # The switch happens in /setup/complete after rendering the page.

        return Redirect(path="/setup/complete")

    async def _check_already_complete(self) -> bool:
        """Return True if the setup-completed setting exists; False on any error."""
        try:
            async with get_setup_db_session() as db_session:
                value = await get_setting(db_session, SETUP_COMPLETED_AT_KEY)
                return value is not None
        except Exception:
            return False

    @staticmethod
    def _request_base_url(request: Request) -> str:
        """Return ``scheme://host``, using the first ``X-Forwarded-Proto`` entry if present."""
        forwarded_proto = request.headers.get("x-forwarded-proto", "")
        scheme = forwarded_proto.split(",")[0].strip() or request.url.scheme
        host = request.headers.get("host", request.url.netloc)
        return f"{scheme}://{host}"

    async def _fetch_user_info(
        self,
        provider: str,
        token_url: str,
        userinfo_url: str,
        data: dict,
        headers: dict,
    ) -> dict:
        """Exchange the authorization code for an access token and fetch the profile.

        Raises:
            HTTPException: 400 if the token exchange fails, no access token is
                returned, or the profile request fails.
            httpx.HTTPError: On transport-level errors (connection, timeout, ...).
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(token_url, data=data, headers=headers)
            if response.status_code != 200:
                raise HTTPException(
                    status_code=400, detail=f"Token exchange failed: {response.text}"
                )
            access_token = response.json().get("access_token")
            if not access_token:
                raise HTTPException(status_code=400, detail="No access token received")

            user_headers = {"Authorization": f"Bearer {access_token}"}
            response = await client.get(userinfo_url, headers=user_headers)
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Failed to fetch user info")
            user_info = response.json()

            # When the GitHub profile has no email, look up the primary address.
            if provider == "github" and not user_info.get("email"):
                email_response = await client.get(
                    "https://api.github.com/user/emails", headers=user_headers
                )
                if email_response.status_code == 200:
                    primary_email = next(
                        (e["email"] for e in email_response.json() if e.get("primary")),
                        None,
                    )
                    if primary_email:
                        user_info["email"] = primary_email

        return user_info

    async def _upsert_admin_user(
        self, db_session: AsyncSession, provider: str, user_data: dict
    ) -> User:
        """Create or update the OAuth user and make sure it has the ``admin`` role.

        Raises:
            HTTPException: 500 if no ``admin`` role exists after syncing roles,
                so setup is not marked complete without an admin.
        """
        result = await db_session.execute(
            select(User).where(
                User.oauth_id == user_data["oauth_id"], User.oauth_provider == provider
            )
        )
        user = result.scalar_one_or_none()
        now = datetime.now(UTC)

        if user:
            user.name = user_data["name"]
            if user_data["picture_url"]:
                user.picture_url = user_data["picture_url"]
            user.last_login_at = now
        else:
            user = User(
                oauth_provider=provider,
                oauth_id=user_data["oauth_id"],
                email=user_data["email"],
                name=user_data["name"],
                picture_url=user_data["picture_url"],
                last_login_at=now,
            )
            db_session.add(user)
        # Flush so a newly created user has an id for the role link below.
        await db_session.flush()

        # Ensure roles are synced (they may not exist if DB was created after server start)
        from skrift.auth import sync_roles_to_database

        await sync_roles_to_database(db_session)

        admin_role = await db_session.scalar(select(Role).where(Role.name == "admin"))
        if admin_role is None:
            raise HTTPException(
                status_code=500, detail="Admin role not found; setup was not completed"
            )

        # Assign admin whether the user is new or existing, unless already linked.
        existing = await db_session.execute(
            select(user_roles).where(
                user_roles.c.user_id == user.id,
                user_roles.c.role_id == admin_role.id,
            )
        )
        if not existing.first():
            await db_session.execute(
                user_roles.insert().values(user_id=user.id, role_id=admin_role.id)
            )
        return user

    def _resolve_env_var(self, value: str | int | None) -> str:
        """Resolve an environment variable reference if the value starts with ``$``.

        ``None`` becomes ``""``; non-string values are converted with ``str``.
        """
        import os

        if value is None:
            return ""
        text = str(value)
        if text.startswith("$"):
            return os.environ.get(text[1:], "")
        return text

    @staticmethod
    def _normalize_id(*candidates: object) -> str | None:
        """Return the first candidate that is neither None nor "" as a string, else None.

        Avoids ``str(None)`` turning a missing id into the truthy string "None".
        """
        for candidate in candidates:
            if candidate is not None and candidate != "":
                return str(candidate)
        return None

    def _extract_user_data(self, provider: str, user_info: dict) -> dict:
        """Extract normalized user data from a provider's profile response.

        Returns a dict with keys ``oauth_id`` (str or None), ``email``, ``name``
        and ``picture_url``.
        """
        if provider == "google":
            return {
                "oauth_id": self._normalize_id(user_info.get("id"), user_info.get("sub")),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": user_info.get("picture"),
            }
        elif provider == "github":
            return {
                "oauth_id": self._normalize_id(user_info.get("id")),
                "email": user_info.get("email"),
                "name": user_info.get("name") or user_info.get("login"),
                "picture_url": user_info.get("avatar_url"),
            }
        elif provider == "microsoft":
            return {
                "oauth_id": self._normalize_id(user_info.get("id")),
                "email": user_info.get("mail") or user_info.get("userPrincipalName"),
                "name": user_info.get("displayName"),
                "picture_url": None,
            }
        elif provider == "discord":
            avatar = user_info.get("avatar")
            user_id = user_info.get("id")
            avatar_url = (
                f"https://cdn.discordapp.com/avatars/{user_id}/{avatar}.png"
                if avatar and user_id
                else None
            )
            return {
                "oauth_id": self._normalize_id(user_id),
                "email": user_info.get("email"),
                "name": user_info.get("global_name") or user_info.get("username"),
                "picture_url": avatar_url,
            }
        elif provider == "facebook":
            # ``picture`` may be missing or null; treat both as "no picture".
            picture = (user_info.get("picture") or {}).get("data") or {}
            return {
                "oauth_id": self._normalize_id(user_info.get("id")),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": picture.get("url") if not picture.get("is_silhouette") else None,
            }
        elif provider == "twitter":
            # Profile fields may be nested under "data".
            data = user_info.get("data", user_info)
            return {
                "oauth_id": self._normalize_id(data.get("id")),
                "email": data.get("email"),
                "name": data.get("name") or data.get("username"),
                "picture_url": None,
            }
        else:
            return {
                "oauth_id": self._normalize_id(user_info.get("id"), user_info.get("sub")),
                "email": user_info.get("email"),
                "name": user_info.get("name"),
                "picture_url": user_info.get("picture"),
            }