skrift 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skrift/alembic/env.py +2 -1
- skrift/alembic/versions/20260129_add_oauth_accounts.py +134 -0
- skrift/alembic.ini +2 -2
- skrift/asgi.py +19 -11
- skrift/cli.py +22 -13
- skrift/config.py +59 -5
- skrift/controllers/auth.py +168 -22
- skrift/db/models/__init__.py +2 -1
- skrift/db/models/oauth_account.py +37 -0
- skrift/db/models/user.py +5 -5
- skrift/setup/config_writer.py +4 -2
- skrift/setup/controller.py +209 -72
- skrift/setup/providers.py +53 -2
- skrift/setup/state.py +185 -4
- skrift/static/css/style.css +3 -3
- skrift/templates/auth/dummy_login.html +102 -0
- skrift/templates/auth/login.html +14 -0
- skrift/templates/setup/configuring.html +158 -0
- {skrift-0.1.0a1.dist-info → skrift-0.1.0a3.dist-info}/METADATA +3 -1
- {skrift-0.1.0a1.dist-info → skrift-0.1.0a3.dist-info}/RECORD +22 -18
- {skrift-0.1.0a1.dist-info → skrift-0.1.0a3.dist-info}/WHEEL +0 -0
- {skrift-0.1.0a1.dist-info → skrift-0.1.0a3.dist-info}/entry_points.txt +0 -0
skrift/alembic/env.py
CHANGED
|
@@ -12,6 +12,7 @@ from skrift.config import get_settings
|
|
|
12
12
|
from skrift.db.base import Base
|
|
13
13
|
|
|
14
14
|
# Import all models to ensure they're registered with Base.metadata
|
|
15
|
+
from skrift.db.models.oauth_account import OAuthAccount # noqa: F401
|
|
15
16
|
from skrift.db.models.user import User # noqa: F401
|
|
16
17
|
from skrift.db.models.page import Page # noqa: F401
|
|
17
18
|
from skrift.db.models.role import Role, RolePermission # noqa: F401
|
|
@@ -31,7 +32,7 @@ def get_url() -> str:
|
|
|
31
32
|
"""Get database URL from settings or alembic.ini."""
|
|
32
33
|
try:
|
|
33
34
|
settings = get_settings()
|
|
34
|
-
return settings.
|
|
35
|
+
return settings.db.url
|
|
35
36
|
except Exception:
|
|
36
37
|
# Fall back to alembic.ini config if settings can't be loaded
|
|
37
38
|
return config.get_main_option("sqlalchemy.url", "")
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""add oauth_accounts table
|
|
2
|
+
|
|
3
|
+
Revision ID: 1a2b3c4d5e6f
|
|
4
|
+
Revises: 8f3a5c2d1e0b
|
|
5
|
+
Create Date: 2026-01-29 10:00:00.000000
|
|
6
|
+
|
|
7
|
+
This migration:
|
|
8
|
+
1. Creates the oauth_accounts table to store multiple OAuth identities per user
|
|
9
|
+
2. Migrates existing oauth data from users table to oauth_accounts
|
|
10
|
+
3. Makes users.email nullable (for providers like Twitter that don't provide email)
|
|
11
|
+
4. Removes oauth_provider and oauth_id columns from users table
|
|
12
|
+
"""
|
|
13
|
+
from typing import Sequence, Union
|
|
14
|
+
|
|
15
|
+
from alembic import op
|
|
16
|
+
import sqlalchemy as sa
|
|
17
|
+
from advanced_alchemy.types import GUID, DateTimeUTC
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# revision identifiers, used by Alembic.
|
|
21
|
+
revision: str = '1a2b3c4d5e6f'
|
|
22
|
+
down_revision: Union[str, None] = '8f3a5c2d1e0b'
|
|
23
|
+
branch_labels: Union[str, Sequence[str], None] = None
|
|
24
|
+
depends_on: Union[str, Sequence[str], None] = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def upgrade() -> None:
|
|
28
|
+
# Step 1: Create oauth_accounts table
|
|
29
|
+
op.create_table(
|
|
30
|
+
'oauth_accounts',
|
|
31
|
+
sa.Column('id', GUID(length=16), nullable=False),
|
|
32
|
+
sa.Column('created_at', DateTimeUTC(timezone=True), nullable=False),
|
|
33
|
+
sa.Column('updated_at', DateTimeUTC(timezone=True), nullable=False),
|
|
34
|
+
sa.Column('sa_orm_sentinel', sa.Integer(), nullable=True),
|
|
35
|
+
sa.Column('provider', sa.String(length=50), nullable=False),
|
|
36
|
+
sa.Column('provider_account_id', sa.String(length=255), nullable=False),
|
|
37
|
+
sa.Column('provider_email', sa.String(length=255), nullable=True),
|
|
38
|
+
sa.Column('user_id', GUID(length=16), nullable=False),
|
|
39
|
+
sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth_accounts')),
|
|
40
|
+
sa.ForeignKeyConstraint(
|
|
41
|
+
['user_id'], ['users.id'],
|
|
42
|
+
name=op.f('fk_oauth_accounts_user_id_users'),
|
|
43
|
+
ondelete='CASCADE'
|
|
44
|
+
),
|
|
45
|
+
sa.UniqueConstraint(
|
|
46
|
+
'provider', 'provider_account_id',
|
|
47
|
+
name='uq_oauth_provider_account'
|
|
48
|
+
),
|
|
49
|
+
)
|
|
50
|
+
op.create_index(
|
|
51
|
+
op.f('ix_oauth_accounts_user_id'),
|
|
52
|
+
'oauth_accounts',
|
|
53
|
+
['user_id'],
|
|
54
|
+
unique=False
|
|
55
|
+
)
|
|
56
|
+
op.create_index(
|
|
57
|
+
op.f('ix_oauth_accounts_provider_account'),
|
|
58
|
+
'oauth_accounts',
|
|
59
|
+
['provider', 'provider_account_id'],
|
|
60
|
+
unique=True
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Step 2: Migrate existing data from users to oauth_accounts
|
|
64
|
+
# Generate binary UUIDs (16 bytes) for new records and copy oauth data
|
|
65
|
+
conn = op.get_bind()
|
|
66
|
+
conn.execute(sa.text("""
|
|
67
|
+
INSERT INTO oauth_accounts (id, created_at, updated_at, provider, provider_account_id, provider_email, user_id)
|
|
68
|
+
SELECT
|
|
69
|
+
randomblob(16),
|
|
70
|
+
created_at,
|
|
71
|
+
updated_at,
|
|
72
|
+
oauth_provider,
|
|
73
|
+
oauth_id,
|
|
74
|
+
email,
|
|
75
|
+
id
|
|
76
|
+
FROM users
|
|
77
|
+
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
|
|
78
|
+
"""))
|
|
79
|
+
|
|
80
|
+
# Step 3: Make email nullable on users table
|
|
81
|
+
# SQLite doesn't support ALTER COLUMN, so we need to recreate the table
|
|
82
|
+
# For SQLite, we'll use batch_alter_table
|
|
83
|
+
with op.batch_alter_table('users', schema=None) as batch_op:
|
|
84
|
+
# Drop the unique constraint on oauth_id
|
|
85
|
+
batch_op.drop_constraint('uq_users_oauth_id', type_='unique')
|
|
86
|
+
# Drop the oauth columns
|
|
87
|
+
batch_op.drop_column('oauth_provider')
|
|
88
|
+
batch_op.drop_column('oauth_id')
|
|
89
|
+
# Make email nullable - this requires recreating the column in SQLite
|
|
90
|
+
batch_op.alter_column('email',
|
|
91
|
+
existing_type=sa.String(length=255),
|
|
92
|
+
nullable=True)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def downgrade() -> None:
|
|
96
|
+
# Step 1: Add back oauth columns to users table
|
|
97
|
+
with op.batch_alter_table('users', schema=None) as batch_op:
|
|
98
|
+
batch_op.add_column(sa.Column('oauth_provider', sa.String(length=50), nullable=True))
|
|
99
|
+
batch_op.add_column(sa.Column('oauth_id', sa.String(length=255), nullable=True))
|
|
100
|
+
batch_op.alter_column('email',
|
|
101
|
+
existing_type=sa.String(length=255),
|
|
102
|
+
nullable=False)
|
|
103
|
+
|
|
104
|
+
# Step 2: Migrate data back from oauth_accounts to users
|
|
105
|
+
# Only migrate the first oauth account per user
|
|
106
|
+
conn = op.get_bind()
|
|
107
|
+
conn.execute(sa.text("""
|
|
108
|
+
UPDATE users
|
|
109
|
+
SET oauth_provider = (
|
|
110
|
+
SELECT provider FROM oauth_accounts
|
|
111
|
+
WHERE oauth_accounts.user_id = users.id
|
|
112
|
+
LIMIT 1
|
|
113
|
+
),
|
|
114
|
+
oauth_id = (
|
|
115
|
+
SELECT provider_account_id FROM oauth_accounts
|
|
116
|
+
WHERE oauth_accounts.user_id = users.id
|
|
117
|
+
LIMIT 1
|
|
118
|
+
)
|
|
119
|
+
"""))
|
|
120
|
+
|
|
121
|
+
# Step 3: Make oauth columns non-nullable and add unique constraint
|
|
122
|
+
with op.batch_alter_table('users', schema=None) as batch_op:
|
|
123
|
+
batch_op.alter_column('oauth_provider',
|
|
124
|
+
existing_type=sa.String(length=50),
|
|
125
|
+
nullable=False)
|
|
126
|
+
batch_op.alter_column('oauth_id',
|
|
127
|
+
existing_type=sa.String(length=255),
|
|
128
|
+
nullable=False)
|
|
129
|
+
batch_op.create_unique_constraint('uq_users_oauth_id', ['oauth_id'])
|
|
130
|
+
|
|
131
|
+
# Step 4: Drop oauth_accounts table
|
|
132
|
+
op.drop_index(op.f('ix_oauth_accounts_provider_account'), table_name='oauth_accounts')
|
|
133
|
+
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts')
|
|
134
|
+
op.drop_table('oauth_accounts')
|
skrift/alembic.ini
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
# Alembic Configuration File
|
|
2
2
|
|
|
3
3
|
[alembic]
|
|
4
|
-
# Path to migration scripts
|
|
5
|
-
script_location = alembic
|
|
4
|
+
# Path to migration scripts (relative to this file)
|
|
5
|
+
script_location = %(here)s/alembic
|
|
6
6
|
|
|
7
7
|
# Template used to generate migration files
|
|
8
8
|
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d%%(second).2d_%%(rev)s_%%(slug)s
|
skrift/asgi.py
CHANGED
|
@@ -29,7 +29,7 @@ from litestar.static_files import create_static_files_router
|
|
|
29
29
|
from litestar.template import TemplateConfig
|
|
30
30
|
from litestar.types import ASGIApp, Receive, Scope, Send
|
|
31
31
|
|
|
32
|
-
from skrift.config import get_settings, is_config_valid
|
|
32
|
+
from skrift.config import get_config_path, get_settings, is_config_valid
|
|
33
33
|
from skrift.db.base import Base
|
|
34
34
|
from skrift.db.services.setting_service import (
|
|
35
35
|
load_site_settings_cache,
|
|
@@ -45,7 +45,7 @@ from skrift.lib.exceptions import http_exception_handler, internal_server_error_
|
|
|
45
45
|
|
|
46
46
|
def load_controllers() -> list:
|
|
47
47
|
"""Load controllers from app.yaml configuration."""
|
|
48
|
-
config_path =
|
|
48
|
+
config_path = get_config_path()
|
|
49
49
|
|
|
50
50
|
if not config_path.exists():
|
|
51
51
|
return []
|
|
@@ -196,12 +196,7 @@ class AppDispatcher:
|
|
|
196
196
|
await self.setup_app(scope, receive, send)
|
|
197
197
|
return
|
|
198
198
|
|
|
199
|
-
#
|
|
200
|
-
if path.startswith("/auth"):
|
|
201
|
-
await self.setup_app(scope, receive, send)
|
|
202
|
-
return
|
|
203
|
-
|
|
204
|
-
# Non-setup path: check if setup is complete in DB
|
|
199
|
+
# Check if setup is complete in DB
|
|
205
200
|
if await self._is_setup_complete_in_db():
|
|
206
201
|
# Setup complete - try to get/create main app
|
|
207
202
|
main_app = await self._get_or_create_main_app()
|
|
@@ -215,8 +210,13 @@ class AppDispatcher:
|
|
|
215
210
|
f"Setup complete but cannot start application: {self._main_app_error}"
|
|
216
211
|
)
|
|
217
212
|
else:
|
|
218
|
-
# Setup not complete
|
|
219
|
-
|
|
213
|
+
# Setup not complete
|
|
214
|
+
# Route /auth/* to setup app for OAuth callbacks during setup
|
|
215
|
+
if path.startswith("/auth"):
|
|
216
|
+
await self.setup_app(scope, receive, send)
|
|
217
|
+
else:
|
|
218
|
+
# Redirect other paths to /setup
|
|
219
|
+
await self._redirect(send, "/setup")
|
|
220
220
|
|
|
221
221
|
async def _is_setup_complete_in_db(self) -> bool:
|
|
222
222
|
"""Check if setup is complete in the database."""
|
|
@@ -270,6 +270,10 @@ def create_app() -> Litestar:
|
|
|
270
270
|
This app has all routes for normal operation. It is used by the dispatcher
|
|
271
271
|
after setup is complete.
|
|
272
272
|
"""
|
|
273
|
+
# CRITICAL: Check for dummy auth in production BEFORE anything else
|
|
274
|
+
from skrift.setup.providers import validate_no_dummy_auth_in_production
|
|
275
|
+
validate_no_dummy_auth_in_production()
|
|
276
|
+
|
|
273
277
|
settings = get_settings()
|
|
274
278
|
|
|
275
279
|
# Load controllers from app.yaml
|
|
@@ -404,7 +408,7 @@ def create_setup_app() -> Litestar:
|
|
|
404
408
|
|
|
405
409
|
# Also try to get the raw db URL from config (before env var resolution)
|
|
406
410
|
if not db_url:
|
|
407
|
-
config_path =
|
|
411
|
+
config_path = get_config_path()
|
|
408
412
|
if config_path.exists():
|
|
409
413
|
try:
|
|
410
414
|
with open(config_path, "r") as f:
|
|
@@ -493,6 +497,10 @@ def create_dispatcher() -> ASGIApp:
|
|
|
493
497
|
This is the main entry point. The dispatcher handles routing between
|
|
494
498
|
setup and main apps, with lazy creation of the main app after setup completes.
|
|
495
499
|
"""
|
|
500
|
+
# CRITICAL: Check for dummy auth in production BEFORE anything else
|
|
501
|
+
from skrift.setup.providers import validate_no_dummy_auth_in_production
|
|
502
|
+
validate_no_dummy_auth_in_production()
|
|
503
|
+
|
|
496
504
|
global _dispatcher
|
|
497
505
|
from skrift.setup.state import get_database_url_from_yaml
|
|
498
506
|
|
skrift/cli.py
CHANGED
|
@@ -19,24 +19,33 @@ def db() -> None:
|
|
|
19
19
|
"""
|
|
20
20
|
from alembic.config import main as alembic_main
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
import os
|
|
23
|
+
|
|
24
|
+
# Always run from the project root (where app.yaml and .env are)
|
|
25
|
+
# This ensures database paths like ./app.db resolve correctly
|
|
26
|
+
project_root = Path.cwd()
|
|
27
|
+
if not (project_root / "app.yaml").exists():
|
|
28
|
+
# If not in project root, try parent directory
|
|
29
|
+
project_root = Path(__file__).parent.parent
|
|
30
|
+
os.chdir(project_root)
|
|
31
|
+
|
|
32
|
+
# Find alembic.ini - check project root first, then skrift package directory
|
|
33
|
+
alembic_ini = project_root / "alembic.ini"
|
|
26
34
|
if not alembic_ini.exists():
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
if alembic_ini.exists():
|
|
32
|
-
# Change to the directory containing alembic.ini
|
|
33
|
-
import os
|
|
34
|
-
os.chdir(module_dir)
|
|
35
|
-
else:
|
|
35
|
+
skrift_dir = Path(__file__).parent
|
|
36
|
+
alembic_ini = skrift_dir / "alembic.ini"
|
|
37
|
+
|
|
38
|
+
if not alembic_ini.exists():
|
|
36
39
|
print("Error: Could not find alembic.ini", file=sys.stderr)
|
|
37
40
|
print("Make sure you're running from the project root directory.", file=sys.stderr)
|
|
38
41
|
sys.exit(1)
|
|
39
42
|
|
|
43
|
+
# Build argv with config path at the beginning (before any subcommand)
|
|
44
|
+
# Original argv: ['skrift-db', 'upgrade', 'head']
|
|
45
|
+
# New argv: ['skrift-db', '-c', '/path/to/alembic.ini', 'upgrade', 'head']
|
|
46
|
+
new_argv = [sys.argv[0], "-c", str(alembic_ini)] + sys.argv[1:]
|
|
47
|
+
sys.argv = new_argv
|
|
48
|
+
|
|
40
49
|
# Pass through all CLI arguments to Alembic
|
|
41
50
|
sys.exit(alembic_main(sys.argv[1:]))
|
|
42
51
|
|
skrift/config.py
CHANGED
|
@@ -16,6 +16,31 @@ load_dotenv(_env_file)
|
|
|
16
16
|
# Pattern to match $VAR_NAME environment variable references
|
|
17
17
|
ENV_VAR_PATTERN = re.compile(r"\$([A-Z_][A-Z0-9_]*)")
|
|
18
18
|
|
|
19
|
+
# Environment configuration
|
|
20
|
+
SKRIFT_ENV = "SKRIFT_ENV"
|
|
21
|
+
DEFAULT_ENVIRONMENT = "production"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_environment() -> str:
|
|
25
|
+
"""Get the current environment name, normalized to lowercase.
|
|
26
|
+
|
|
27
|
+
Reads from SKRIFT_ENV environment variable. Defaults to "production".
|
|
28
|
+
"""
|
|
29
|
+
env = os.environ.get(SKRIFT_ENV, DEFAULT_ENVIRONMENT)
|
|
30
|
+
return env.lower().strip()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_config_path() -> Path:
|
|
34
|
+
"""Get the path to the environment-specific config file.
|
|
35
|
+
|
|
36
|
+
Production -> app.yaml
|
|
37
|
+
Other envs -> app.{env}.yaml (e.g., app.dev.yaml)
|
|
38
|
+
"""
|
|
39
|
+
env = get_environment()
|
|
40
|
+
if env == "production":
|
|
41
|
+
return Path.cwd() / "app.yaml"
|
|
42
|
+
return Path.cwd() / f"app.{env}.yaml"
|
|
43
|
+
|
|
19
44
|
|
|
20
45
|
def interpolate_env_vars(value, strict: bool = True):
|
|
21
46
|
"""Recursively replace $VAR_NAME with os.environ values.
|
|
@@ -54,10 +79,10 @@ def load_app_config(interpolate: bool = True, strict: bool = True) -> dict:
|
|
|
54
79
|
Returns:
|
|
55
80
|
Parsed configuration dictionary
|
|
56
81
|
"""
|
|
57
|
-
config_path =
|
|
82
|
+
config_path = get_config_path()
|
|
58
83
|
|
|
59
84
|
if not config_path.exists():
|
|
60
|
-
raise FileNotFoundError(f"
|
|
85
|
+
raise FileNotFoundError(f"{config_path.name} not found at {config_path}")
|
|
61
86
|
|
|
62
87
|
with open(config_path, "r") as f:
|
|
63
88
|
config = yaml.safe_load(f)
|
|
@@ -69,7 +94,7 @@ def load_app_config(interpolate: bool = True, strict: bool = True) -> dict:
|
|
|
69
94
|
|
|
70
95
|
def load_raw_app_config() -> dict | None:
|
|
71
96
|
"""Load app.yaml without any processing. Returns None if file doesn't exist."""
|
|
72
|
-
config_path =
|
|
97
|
+
config_path = get_config_path()
|
|
73
98
|
|
|
74
99
|
if not config_path.exists():
|
|
75
100
|
return None
|
|
@@ -98,11 +123,40 @@ class OAuthProviderConfig(BaseModel):
|
|
|
98
123
|
tenant_id: str | None = None
|
|
99
124
|
|
|
100
125
|
|
|
126
|
+
class DummyProviderConfig(BaseModel):
|
|
127
|
+
"""Dummy provider configuration (no credentials required)."""
|
|
128
|
+
|
|
129
|
+
pass
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Union type for provider configs - dummy has no required fields
|
|
133
|
+
ProviderConfig = OAuthProviderConfig | DummyProviderConfig
|
|
134
|
+
|
|
135
|
+
|
|
101
136
|
class AuthConfig(BaseModel):
|
|
102
137
|
"""Authentication configuration."""
|
|
103
138
|
|
|
104
139
|
redirect_base_url: str = "http://localhost:8000"
|
|
105
|
-
providers: dict[str,
|
|
140
|
+
providers: dict[str, ProviderConfig] = {}
|
|
141
|
+
|
|
142
|
+
@classmethod
|
|
143
|
+
def _parse_provider(cls, name: str, config: dict) -> ProviderConfig:
|
|
144
|
+
"""Parse a provider config, using the appropriate model based on provider name."""
|
|
145
|
+
if name == "dummy":
|
|
146
|
+
return DummyProviderConfig(**config)
|
|
147
|
+
return OAuthProviderConfig(**config)
|
|
148
|
+
|
|
149
|
+
def __init__(self, **data):
|
|
150
|
+
# Convert raw provider dicts to appropriate config objects
|
|
151
|
+
if "providers" in data and isinstance(data["providers"], dict):
|
|
152
|
+
parsed_providers = {}
|
|
153
|
+
for name, config in data["providers"].items():
|
|
154
|
+
if isinstance(config, dict):
|
|
155
|
+
parsed_providers[name] = self._parse_provider(name, config)
|
|
156
|
+
else:
|
|
157
|
+
parsed_providers[name] = config
|
|
158
|
+
data["providers"] = parsed_providers
|
|
159
|
+
super().__init__(**data)
|
|
106
160
|
|
|
107
161
|
def get_redirect_uri(self, provider: str) -> str:
|
|
108
162
|
"""Get the OAuth callback URL for a provider."""
|
|
@@ -137,7 +191,7 @@ def is_config_valid() -> tuple[bool, str | None]:
|
|
|
137
191
|
try:
|
|
138
192
|
config = load_raw_app_config()
|
|
139
193
|
if config is None:
|
|
140
|
-
return False, "
|
|
194
|
+
return False, f"{get_config_path().name} not found"
|
|
141
195
|
|
|
142
196
|
# Check database URL
|
|
143
197
|
db_config = config.get("db", {})
|
skrift/controllers/auth.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Authentication controller for OAuth login flows.
|
|
2
2
|
|
|
3
3
|
Supports multiple OAuth providers: Google, GitHub, Microsoft, Discord, Facebook, X (Twitter).
|
|
4
|
+
Also supports a development-only "dummy" provider for testing.
|
|
4
5
|
"""
|
|
5
6
|
|
|
6
7
|
import base64
|
|
@@ -11,16 +12,18 @@ from typing import Annotated
|
|
|
11
12
|
from urllib.parse import urlencode
|
|
12
13
|
|
|
13
14
|
import httpx
|
|
14
|
-
from litestar import Controller, Request, get
|
|
15
|
+
from litestar import Controller, Request, get, post
|
|
15
16
|
from litestar.exceptions import HTTPException, NotFoundException
|
|
16
17
|
from litestar.params import Parameter
|
|
17
18
|
from litestar.response import Redirect, Template as TemplateResponse
|
|
18
19
|
from sqlalchemy import select
|
|
19
20
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
21
|
+
from sqlalchemy.orm import selectinload
|
|
20
22
|
|
|
21
23
|
from skrift.config import get_settings
|
|
24
|
+
from skrift.db.models.oauth_account import OAuthAccount
|
|
22
25
|
from skrift.db.models.user import User
|
|
23
|
-
from skrift.setup.providers import OAUTH_PROVIDERS, get_provider_info
|
|
26
|
+
from skrift.setup.providers import DUMMY_PROVIDER_KEY, OAUTH_PROVIDERS, get_provider_info
|
|
24
27
|
|
|
25
28
|
|
|
26
29
|
def get_auth_url(provider: str, settings, state: str, code_challenge: str | None = None) -> str:
|
|
@@ -227,8 +230,8 @@ class AuthController(Controller):
|
|
|
227
230
|
self,
|
|
228
231
|
request: Request,
|
|
229
232
|
provider: str,
|
|
230
|
-
) -> Redirect:
|
|
231
|
-
"""Redirect to OAuth provider consent screen."""
|
|
233
|
+
) -> Redirect | TemplateResponse:
|
|
234
|
+
"""Redirect to OAuth provider consent screen, or show dummy login form."""
|
|
232
235
|
settings = get_settings()
|
|
233
236
|
provider_info = get_provider_info(provider)
|
|
234
237
|
|
|
@@ -238,6 +241,14 @@ class AuthController(Controller):
|
|
|
238
241
|
if provider not in settings.auth.providers:
|
|
239
242
|
raise NotFoundException(f"Provider {provider} not configured")
|
|
240
243
|
|
|
244
|
+
# Dummy provider shows local login form instead of redirecting to OAuth
|
|
245
|
+
if provider == DUMMY_PROVIDER_KEY:
|
|
246
|
+
flash = request.session.pop("flash", None)
|
|
247
|
+
return TemplateResponse(
|
|
248
|
+
"auth/dummy_login.html",
|
|
249
|
+
context={"flash": flash},
|
|
250
|
+
)
|
|
251
|
+
|
|
241
252
|
# Generate CSRF state token
|
|
242
253
|
state = secrets.token_urlsafe(32)
|
|
243
254
|
request.session["oauth_state"] = state
|
|
@@ -306,30 +317,67 @@ class AuthController(Controller):
|
|
|
306
317
|
if not oauth_id:
|
|
307
318
|
raise HTTPException(status_code=400, detail="Could not determine user ID")
|
|
308
319
|
|
|
309
|
-
|
|
320
|
+
email = user_data["email"]
|
|
321
|
+
|
|
322
|
+
# Step 1: Check if OAuth account already exists
|
|
310
323
|
result = await db_session.execute(
|
|
311
|
-
select(
|
|
324
|
+
select(OAuthAccount)
|
|
325
|
+
.options(selectinload(OAuthAccount.user))
|
|
326
|
+
.where(OAuthAccount.provider == provider, OAuthAccount.provider_account_id == oauth_id)
|
|
312
327
|
)
|
|
313
|
-
|
|
328
|
+
oauth_account = result.scalar_one_or_none()
|
|
314
329
|
|
|
315
|
-
if
|
|
316
|
-
#
|
|
330
|
+
if oauth_account:
|
|
331
|
+
# Existing OAuth account - update user profile
|
|
332
|
+
user = oauth_account.user
|
|
317
333
|
user.name = user_data["name"]
|
|
318
334
|
if user_data["picture_url"]:
|
|
319
335
|
user.picture_url = user_data["picture_url"]
|
|
320
336
|
user.last_login_at = datetime.now(UTC)
|
|
337
|
+
# Update provider email if changed
|
|
338
|
+
if email:
|
|
339
|
+
oauth_account.provider_email = email
|
|
321
340
|
else:
|
|
322
|
-
#
|
|
323
|
-
user =
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
341
|
+
# Step 2: Check if a user with this email already exists
|
|
342
|
+
user = None
|
|
343
|
+
if email:
|
|
344
|
+
result = await db_session.execute(
|
|
345
|
+
select(User).where(User.email == email)
|
|
346
|
+
)
|
|
347
|
+
user = result.scalar_one_or_none()
|
|
348
|
+
|
|
349
|
+
if user:
|
|
350
|
+
# Link new OAuth account to existing user
|
|
351
|
+
oauth_account = OAuthAccount(
|
|
352
|
+
provider=provider,
|
|
353
|
+
provider_account_id=oauth_id,
|
|
354
|
+
provider_email=email,
|
|
355
|
+
user_id=user.id,
|
|
356
|
+
)
|
|
357
|
+
db_session.add(oauth_account)
|
|
358
|
+
# Update user profile
|
|
359
|
+
user.name = user_data["name"]
|
|
360
|
+
if user_data["picture_url"]:
|
|
361
|
+
user.picture_url = user_data["picture_url"]
|
|
362
|
+
user.last_login_at = datetime.now(UTC)
|
|
363
|
+
else:
|
|
364
|
+
# Step 3: Create new user + OAuth account
|
|
365
|
+
user = User(
|
|
366
|
+
email=email,
|
|
367
|
+
name=user_data["name"],
|
|
368
|
+
picture_url=user_data["picture_url"],
|
|
369
|
+
last_login_at=datetime.now(UTC),
|
|
370
|
+
)
|
|
371
|
+
db_session.add(user)
|
|
372
|
+
await db_session.flush()
|
|
373
|
+
|
|
374
|
+
oauth_account = OAuthAccount(
|
|
375
|
+
provider=provider,
|
|
376
|
+
provider_account_id=oauth_id,
|
|
377
|
+
provider_email=email,
|
|
378
|
+
user_id=user.id,
|
|
379
|
+
)
|
|
380
|
+
db_session.add(oauth_account)
|
|
333
381
|
|
|
334
382
|
await db_session.commit()
|
|
335
383
|
|
|
@@ -348,22 +396,120 @@ class AuthController(Controller):
|
|
|
348
396
|
flash = request.session.pop("flash", None)
|
|
349
397
|
settings = get_settings()
|
|
350
398
|
|
|
351
|
-
# Get configured providers
|
|
399
|
+
# Get configured providers (excluding dummy from main list)
|
|
352
400
|
configured_providers = list(settings.auth.providers.keys())
|
|
353
401
|
providers = {
|
|
354
402
|
key: OAUTH_PROVIDERS[key]
|
|
355
403
|
for key in configured_providers
|
|
356
|
-
if key in OAUTH_PROVIDERS
|
|
404
|
+
if key in OAUTH_PROVIDERS and key != DUMMY_PROVIDER_KEY
|
|
357
405
|
}
|
|
358
406
|
|
|
407
|
+
# Check if dummy provider is configured
|
|
408
|
+
has_dummy = DUMMY_PROVIDER_KEY in settings.auth.providers
|
|
409
|
+
|
|
359
410
|
return TemplateResponse(
|
|
360
411
|
"auth/login.html",
|
|
361
412
|
context={
|
|
362
413
|
"flash": flash,
|
|
363
414
|
"providers": providers,
|
|
415
|
+
"has_dummy": has_dummy,
|
|
364
416
|
},
|
|
365
417
|
)
|
|
366
418
|
|
|
419
|
+
@post("/dummy-login")
|
|
420
|
+
async def dummy_login_submit(
|
|
421
|
+
self,
|
|
422
|
+
request: Request,
|
|
423
|
+
db_session: AsyncSession,
|
|
424
|
+
) -> Redirect:
|
|
425
|
+
"""Process dummy login form submission."""
|
|
426
|
+
settings = get_settings()
|
|
427
|
+
|
|
428
|
+
if DUMMY_PROVIDER_KEY not in settings.auth.providers:
|
|
429
|
+
raise NotFoundException("Dummy provider not configured")
|
|
430
|
+
|
|
431
|
+
# Parse form data from request
|
|
432
|
+
form_data = await request.form()
|
|
433
|
+
email = form_data.get("email", "").strip()
|
|
434
|
+
name = form_data.get("name", "").strip()
|
|
435
|
+
|
|
436
|
+
if not email:
|
|
437
|
+
request.session["flash"] = "Email is required"
|
|
438
|
+
return Redirect(path="/auth/dummy/login")
|
|
439
|
+
|
|
440
|
+
# Default name to email username if not provided
|
|
441
|
+
if not name:
|
|
442
|
+
name = email.split("@")[0]
|
|
443
|
+
|
|
444
|
+
# Generate deterministic oauth_id from email
|
|
445
|
+
oauth_id = f"dummy_{hashlib.sha256(email.encode()).hexdigest()[:16]}"
|
|
446
|
+
|
|
447
|
+
# Step 1: Check if OAuth account already exists
|
|
448
|
+
result = await db_session.execute(
|
|
449
|
+
select(OAuthAccount)
|
|
450
|
+
.options(selectinload(OAuthAccount.user))
|
|
451
|
+
.where(
|
|
452
|
+
OAuthAccount.provider == DUMMY_PROVIDER_KEY,
|
|
453
|
+
OAuthAccount.provider_account_id == oauth_id,
|
|
454
|
+
)
|
|
455
|
+
)
|
|
456
|
+
oauth_account = result.scalar_one_or_none()
|
|
457
|
+
|
|
458
|
+
if oauth_account:
|
|
459
|
+
# Existing OAuth account - update user profile
|
|
460
|
+
user = oauth_account.user
|
|
461
|
+
user.name = name
|
|
462
|
+
user.email = email
|
|
463
|
+
user.last_login_at = datetime.now(UTC)
|
|
464
|
+
oauth_account.provider_email = email
|
|
465
|
+
else:
|
|
466
|
+
# Step 2: Check if a user with this email already exists
|
|
467
|
+
result = await db_session.execute(
|
|
468
|
+
select(User).where(User.email == email)
|
|
469
|
+
)
|
|
470
|
+
user = result.scalar_one_or_none()
|
|
471
|
+
|
|
472
|
+
if user:
|
|
473
|
+
# Link new OAuth account to existing user
|
|
474
|
+
oauth_account = OAuthAccount(
|
|
475
|
+
provider=DUMMY_PROVIDER_KEY,
|
|
476
|
+
provider_account_id=oauth_id,
|
|
477
|
+
provider_email=email,
|
|
478
|
+
user_id=user.id,
|
|
479
|
+
)
|
|
480
|
+
db_session.add(oauth_account)
|
|
481
|
+
# Update user profile
|
|
482
|
+
user.name = name
|
|
483
|
+
user.last_login_at = datetime.now(UTC)
|
|
484
|
+
else:
|
|
485
|
+
# Step 3: Create new user + OAuth account
|
|
486
|
+
user = User(
|
|
487
|
+
email=email,
|
|
488
|
+
name=name,
|
|
489
|
+
last_login_at=datetime.now(UTC),
|
|
490
|
+
)
|
|
491
|
+
db_session.add(user)
|
|
492
|
+
await db_session.flush()
|
|
493
|
+
|
|
494
|
+
oauth_account = OAuthAccount(
|
|
495
|
+
provider=DUMMY_PROVIDER_KEY,
|
|
496
|
+
provider_account_id=oauth_id,
|
|
497
|
+
provider_email=email,
|
|
498
|
+
user_id=user.id,
|
|
499
|
+
)
|
|
500
|
+
db_session.add(oauth_account)
|
|
501
|
+
|
|
502
|
+
await db_session.commit()
|
|
503
|
+
|
|
504
|
+
# Set session with user info
|
|
505
|
+
request.session["user_id"] = str(user.id)
|
|
506
|
+
request.session["user_name"] = user.name
|
|
507
|
+
request.session["user_email"] = user.email
|
|
508
|
+
request.session["user_picture_url"] = user.picture_url
|
|
509
|
+
request.session["flash"] = "Successfully logged in!"
|
|
510
|
+
|
|
511
|
+
return Redirect(path="/")
|
|
512
|
+
|
|
367
513
|
@get("/logout")
|
|
368
514
|
async def logout(self, request: Request) -> Redirect:
|
|
369
515
|
"""Clear session and redirect to home."""
|
skrift/db/models/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
from skrift.db.models.oauth_account import OAuthAccount
|
|
1
2
|
from skrift.db.models.page import Page
|
|
2
3
|
from skrift.db.models.role import Role, RolePermission, user_roles
|
|
3
4
|
from skrift.db.models.setting import Setting
|
|
4
5
|
from skrift.db.models.user import User
|
|
5
6
|
|
|
6
|
-
__all__ = ["Page", "Role", "RolePermission", "Setting", "User", "user_roles"]
|
|
7
|
+
__all__ = ["OAuthAccount", "Page", "Role", "RolePermission", "Setting", "User", "user_roles"]
|