squirrels 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +58 -111
- dateutils/types.py +6 -0
- squirrels/__init__.py +13 -11
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +271 -0
- squirrels/_api_routes/base.py +165 -0
- squirrels/_api_routes/dashboards.py +150 -0
- squirrels/_api_routes/data_management.py +145 -0
- squirrels/_api_routes/datasets.py +257 -0
- squirrels/_api_routes/oauth2.py +298 -0
- squirrels/_api_routes/project.py +252 -0
- squirrels/_api_server.py +256 -450
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +108 -0
- squirrels/_arguments/run_time_args.py +147 -0
- squirrels/_auth.py +960 -0
- squirrels/_command_line.py +126 -45
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +48 -26
- squirrels/_constants.py +68 -38
- squirrels/_dashboards.py +160 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +84 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_initializer.py +177 -80
- squirrels/_logging.py +115 -0
- squirrels/_manifest.py +208 -79
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +926 -367
- squirrels/_package_data/base_project/.env +42 -0
- squirrels/_package_data/base_project/.env.example +42 -0
- squirrels/_package_data/base_project/assets/expenses.db +0 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +34 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +5 -2
- squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +3 -3
- squirrels/{package_data → _package_data}/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +3 -2
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +12 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +26 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +37 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +19 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/{package_data → _package_data}/base_project/parameters.yml +56 -40
- squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +21 -40
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_package_loader.py +8 -4
- squirrels/_parameter_configs.py +104 -103
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +57 -47
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +721 -0
- squirrels/_py_module.py +7 -5
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +167 -0
- squirrels/_schemas/query_param_models.py +75 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +126 -47
- squirrels/_seeds.py +35 -16
- squirrels/_sources.py +110 -0
- squirrels/_utils.py +248 -73
- squirrels/_version.py +1 -1
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +2 -81
- squirrels/data_sources.py +14 -631
- squirrels/parameter_options.py +13 -348
- squirrels/parameters.py +14 -1266
- squirrels/types.py +16 -0
- squirrels-0.5.0.dist-info/METADATA +113 -0
- squirrels-0.5.0.dist-info/RECORD +97 -0
- {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info}/WHEEL +1 -1
- squirrels-0.5.0.dist-info/entry_points.txt +3 -0
- {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info/licenses}/LICENSE +1 -1
- squirrels/_authenticator.py +0 -85
- squirrels/_dashboards_io.py +0 -61
- squirrels/_environcfg.py +0 -84
- squirrels/arguments/init_time_args.py +0 -40
- squirrels/arguments/run_time_args.py +0 -208
- squirrels/package_data/assets/favicon.ico +0 -0
- squirrels/package_data/assets/index.css +0 -1
- squirrels/package_data/assets/index.js +0 -58
- squirrels/package_data/base_project/assets/expenses.db +0 -0
- squirrels/package_data/base_project/connections.yml +0 -7
- squirrels/package_data/base_project/dashboards/dashboard_example.py +0 -32
- squirrels/package_data/base_project/dashboards.yml +0 -10
- squirrels/package_data/base_project/env.yml +0 -29
- squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
- squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -22
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -21
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -3
- squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -19
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -95
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +0 -15
- squirrels/package_data/base_project/squirrels.yml.j2 +0 -94
- squirrels/package_data/templates/index.html +0 -18
- squirrels/project.py +0 -378
- squirrels/user_base.py +0 -55
- squirrels-0.4.1.dist-info/METADATA +0 -117
- squirrels-0.4.1.dist-info/RECORD +0 -60
- squirrels-0.4.1.dist-info/entry_points.txt +0 -4
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
squirrels/_auth.py
ADDED
|
@@ -0,0 +1,960 @@
|
|
|
1
|
+
from typing import Callable, Any
|
|
2
|
+
from datetime import datetime, timedelta, timezone
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from jwt.exceptions import InvalidTokenError
|
|
5
|
+
from passlib.context import CryptContext
|
|
6
|
+
from pydantic import ValidationError
|
|
7
|
+
from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
|
|
8
|
+
from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
|
|
9
|
+
import jwt, uuid, secrets, json
|
|
10
|
+
|
|
11
|
+
from ._manifest import PermissionScope
|
|
12
|
+
from ._py_module import PyModule
|
|
13
|
+
from ._exceptions import InvalidInputError, ConfigurationError
|
|
14
|
+
from ._arguments.init_time_args import AuthProviderArgs
|
|
15
|
+
from ._schemas.auth_models import (
|
|
16
|
+
CustomUserFields, AbstractUser, GuestUser, RegisteredUser, ApiKey, UserField, AuthProvider, ProviderConfigs,
|
|
17
|
+
ClientRegistrationRequest, ClientUpdateRequest, ClientDetailsResponse, ClientRegistrationResponse,
|
|
18
|
+
ClientUpdateResponse, TokenResponse
|
|
19
|
+
)
|
|
20
|
+
from . import _utils as u, _constants as c
|
|
21
|
+
|
|
22
|
+
# Module-wide passlib hashing context (bcrypt). It hashes user passwords and the secrets
# produced by Authenticator.generate_secret_and_hash; deprecated="auto" marks every scheme
# other than the default one as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Type of a provider factory: called with the init-time AuthProviderArgs, returns an AuthProvider.
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Authenticator:
|
|
28
|
+
providers: list[ProviderFunctionType] = [] # static variable to stage providers
|
|
29
|
+
|
|
30
|
+
def __init__(
    self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
    custom_user_fields_cls: type[CustomUserFields], *, sa_engine: Engine | None = None, external_only: bool = False
):
    """
    Build the per-instance ORM models, instantiate the auth providers, and prepare the auth database.

    Args:
        logger: Logger stored on the instance.
        base_path: Value substituted for "{project_path}" in the SQLite file path (only used when
            sa_engine is None).
        auth_args: Passed to every provider function; its env_vars are read for SQRL_SECRET_KEY,
            SQRL_AUTH_DB_FILE_PATH and the admin password.
        provider_functions: Callables that each build one AuthProvider from auth_args.
        custom_user_fields_cls: Pydantic model for project-specific user fields (stored as JSON text).
        sa_engine: Optional SQLAlchemy engine. If None, a SQLite file engine is created, and the
            file's parent directory is created if missing.
        external_only: Stored on the instance as-is.
    """
    self.logger = logger
    self.env_vars = auth_args.env_vars
    self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
    self.external_only = external_only

    # Create a new declarative base for this instance (its own MetaData for the tables below)
    self.Base = declarative_base()

    # Define DbUser class for this instance
    class DbUser(self.Base):
        __tablename__ = 'users'
        __table_args__ = {'extend_existing': True}
        username: Mapped[str] = mapped_column(primary_key=True)
        access_level: Mapped[str] = mapped_column(nullable=False, default="member")
        password_hash: Mapped[str] = mapped_column(nullable=False)
        custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())

    # Define DbApiKey class for this instance
    class DbApiKey(self.Base):
        __tablename__ = 'api_keys'

        id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
        title: Mapped[str] = mapped_column(nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        created_at: Mapped[datetime] = mapped_column(nullable=False)
        expires_at: Mapped[datetime] = mapped_column(nullable=False)

        def __repr__(self):
            return f"<DbApiKey(id='{self.id}', username='{self.username}')>"

    # Define DbOAuthClient class for this instance
    class DbOAuthClient(self.Base):
        __tablename__ = 'oauth_clients'

        client_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        client_secret_hash: Mapped[str] = mapped_column(nullable=False)
        client_name: Mapped[str] = mapped_column(nullable=False)
        redirect_uris: Mapped[str] = mapped_column(nullable=False) # JSON array of allowed redirect URIs
        scope: Mapped[str] = mapped_column(nullable=False, default='read')
        grant_types: Mapped[str] = mapped_column(nullable=False, default='authorization_code,refresh_token')
        response_types: Mapped[str] = mapped_column(nullable=False, default='code')
        client_type: Mapped[str] = mapped_column(nullable=False, default='confidential') # 'confidential' or 'public'
        registration_access_token_hash: Mapped[str] = mapped_column(nullable=False) # Token for client management
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
        is_active: Mapped[bool] = mapped_column(nullable=False, default=True)

        def __repr__(self):
            return f"<DbOAuthClient(client_id='{self.client_id}', name='{self.client_name}')>"

    # Define DbAuthorizationCode class for this instance
    class DbAuthorizationCode(self.Base):
        __tablename__ = 'authorization_codes'

        code: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        redirect_uri: Mapped[str] = mapped_column(nullable=False)
        scope: Mapped[str] = mapped_column(nullable=True)
        code_challenge: Mapped[str] = mapped_column(nullable=False) # PKCE always required
        code_challenge_method: Mapped[str] = mapped_column(nullable=False) # only S256 is supported
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
        expires_at: Mapped[datetime] = mapped_column(nullable=False) # 10 minutes from creation
        used: Mapped[bool] = mapped_column(nullable=False, default=False)

        def __repr__(self):
            return f"<DbAuthorizationCode(code='{self.code[:8]}...', client_id='{self.client_id}')>"

    # Define DbOAuthToken class for this instance
    class DbOAuthToken(self.Base):
        __tablename__ = 'oauth_tokens'

        token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        access_token_hash: Mapped[str] = mapped_column(unique=True, nullable=False)
        refresh_token_hash: Mapped[str] = mapped_column(unique=True, nullable=True) # NULL for client_credentials grants
        client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        scope: Mapped[str] = mapped_column(nullable=True)
        token_type: Mapped[str] = mapped_column(nullable=False, default='Bearer')
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
        access_token_expires_at: Mapped[datetime] = mapped_column(nullable=False) # Uses SQRL_AUTH_TOKEN_EXPIRE_MINUTES
        refresh_token_expires_at: Mapped[datetime] = mapped_column(nullable=True) # 30 days from creation, NULL for client_credentials
        is_revoked: Mapped[bool] = mapped_column(nullable=False, default=False)

        def __repr__(self):
            return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"

    self.DbApiKey = DbApiKey
    self.DbOAuthClient = DbOAuthClient
    self.DbAuthorizationCode = DbAuthorizationCode
    self.DbOAuthToken = DbOAuthToken

    self.CustomUserFields = custom_user_fields_cls
    self.DbUser = DbUser

    self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]

    if sa_engine is None:
        raw_sqlite_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{{project_path}}/{c.TARGET_FOLDER}/{c.DB_FILE}")
        sqlite_path = u.Path(raw_sqlite_path.format(project_path=base_path))
        sqlite_path.parent.mkdir(parents=True, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
    else:
        self.engine = sa_engine

    # Configure SQLite pragmas. These are SQLite-specific statements and are not valid on other
    # backends, so skip them when a caller-supplied engine targets a different database.
    # NOTE: journal_mode=WAL is persisted in the database file, whereas synchronous only
    # applies to the connection it is executed on.
    if self.engine.dialect.name == "sqlite":
        with self.engine.connect() as conn:
            conn.execute(text("PRAGMA journal_mode = WAL"))
            conn.execute(text("PRAGMA synchronous = NORMAL"))
            conn.commit()

    self.Base.metadata.create_all(self.engine)

    self.Session = sessionmaker(bind=self.engine)

    self._initialize_db()
|
|
151
|
+
|
|
152
|
+
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
    """
    Turn a DbUser row into a RegisteredUser.

    The row's custom_fields column holds JSON text. An empty value is treated as {}. The parsed
    dict is passed through the CustomUserFields model, so model defaults apply to absent keys.
    """
    stored_json = db_user.custom_fields
    if stored_json:
        field_values = json.loads(stored_json)
    else:
        field_values = {}
    validated_fields = self.CustomUserFields(**field_values)

    return RegisteredUser(
        username=db_user.username,
        access_level=db_user.access_level,
        custom_fields=validated_fields,
    )
|
|
163
|
+
|
|
164
|
+
def _initialize_db(self): # TODO: Use logger instead of print
    """
    Bring the users table up to date with the DbUser model and seed/refresh the admin account.

    1. Each column declared on DbUser but missing from the existing "users" table is added with
       ALTER TABLE ... ADD COLUMN (a lightweight migration for databases created by older versions).
    2. If the SQRL_SECRET_ADMIN_PASSWORD env var is set, the admin user is created with that
       password (access level "admin"). If the admin user already exists, its password hash is
       replaced instead.
    """
    session = self.Session()
    try:
        # Get existing columns in the database
        inspector = inspect(self.engine)
        existing_columns = {col['name'] for col in inspector.get_columns('users')}

        # Get all columns defined in the model
        model_columns = set(self.DbUser.__table__.columns.keys())

        # Find columns that are in the model but not in the database
        new_columns = model_columns - existing_columns
        if new_columns:
            add_columns_msg = f"Adding columns to database: {new_columns}"
            print("NOTE -", add_columns_msg)
            self.logger.info(add_columns_msg)

            for col_name in new_columns:
                col = self.DbUser.__table__.columns[col_name]
                column_type = col.type.compile(self.engine.dialect)

                # Only scalar Python-side defaults can be written as a DDL literal. Callable
                # defaults (e.g. lambdas) and server-side SQL expressions cannot be rendered here.
                default = ""
                if col.default is not None and col.default.is_scalar:
                    default_arg = col.default.arg
                    if isinstance(default_arg, str):
                        escaped_arg = default_arg.replace("'", "''")  # standard SQL quote escaping
                        default = f"DEFAULT '{escaped_arg}'"
                    else:
                        default = f"DEFAULT {default_arg}"

                # SQLite refuses "ADD COLUMN ... NOT NULL" unless a non-NULL default is supplied,
                # so the NOT NULL constraint is only emitted when a literal default exists.
                nullable = "NOT NULL" if (not col.nullable and default) else "NULL"

                alter_stmt = f"ALTER TABLE users ADD COLUMN {col_name} {column_type} {nullable} {default}"
                session.execute(text(alter_stmt))

        session.commit()

        # Get admin password from environment variable if exists
        admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)

        # If admin password variable exists, find username "admin". If it does not exist, add it
        if admin_password is not None:
            password_hash = pwd_context.hash(admin_password)
            admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
            if admin_user is None:
                admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
                session.add(admin_user)
            else:
                admin_user.password_hash = password_hash

            session.commit()

    finally:
        session.close()
|
|
213
|
+
|
|
214
|
+
@cached_property
def user_fields(self) -> list[UserField]:
    """
    Describe every user field (built-in and custom) as UserField objects.

    The built-in fields are "username" and "access_level". They are followed by one entry per
    property in the CustomUserFields JSON schema. Each entry carries:
      - name: the field name
      - type: the JSON-schema type of the (non-null) value
      - nullable: whether the schema allows null (e.g. Optional[...] fields)
      - enum: the allowed values, or None if not restricted
      - default: the schema default, or None if the field has none

    Pydantic renders Optional[X] as an "anyOf" of X's schema and {"type": "null"}, and Enum-typed
    fields as a "$ref" into "$defs". Both are unwrapped so the type and enum come from the
    underlying value schema.
    """
    # Add the base fields
    fields = [
        UserField(name="username", type="string", nullable=False, enum=None, default=None),
        UserField(name="access_level", type="string", nullable=False, enum=["admin", "member"], default="member")
    ]

    # Add custom fields
    schema = self.CustomUserFields.model_json_schema()
    definitions: dict[str, dict[str, Any]] = schema.get("$defs", {})

    def resolve(node: dict[str, Any]) -> dict[str, Any]:
        # Follow a local "#/$defs/<Name>" reference; anything else is returned unchanged.
        ref = node.get("$ref")
        if isinstance(ref, str) and ref.startswith("#/$defs/"):
            return definitions.get(ref.removeprefix("#/$defs/"), node)
        return node

    properties: dict[str, dict[str, Any]] = schema.get("properties", {})
    for field_name, field_schema in properties.items():
        if choices := field_schema.get("anyOf"):
            # Don't rely on the null branch being in a particular position.
            non_null_choices = [resolve(choice) for choice in choices if choice.get("type") != "null"]
            nullable = len(non_null_choices) < len(choices)
            value_schema = non_null_choices[0] if non_null_choices else choices[0]
        else:
            nullable = False
            value_schema = resolve(field_schema)

        field_type = value_schema["type"]
        # For Optional[Literal[...]] / Optional[Enum] the enum lives inside the anyOf branch.
        enum = field_schema.get("enum", value_schema.get("enum"))

        field_data = UserField(name=field_name, type=field_type, nullable=nullable, enum=enum, default=field_schema.get("default"))
        fields.append(field_data)

    return fields
|
|
247
|
+
|
|
248
|
+
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
    """
    Create a new user, or update an existing one when update_user is True.

    Args:
        username: Primary key of the user to create or update.
        user_fields: May contain "access_level" (default "member") and "password" (required when
            creating; not used when updating). Every other key except "username" is validated
            as a CustomUserFields value and stored as JSON.
        update_user: If True the user must already exist and is updated; otherwise the user
            must not exist yet and is created.

    Raises:
        InvalidInputError: 400 for a 'guest' access level, invalid custom fields, an existing
            username on create, or a missing password on create. 403 when trying to make the
            admin user non-admin. 404 when updating a user that does not exist.
    """
    # Separate custom fields from base fields
    access_level = user_fields.get('access_level', 'member')
    password = user_fields.get('password')

    # Validate access level - cannot add/update users with guest access level
    if access_level == "guest":
        raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")

    # Extract custom fields (everything except username, access_level, password)
    custom_fields_dict = {k: v for k, v in user_fields.items() if k not in ['username', 'access_level', 'password']}

    # Validate the custom fields
    try:
        custom_fields = self.CustomUserFields(**custom_fields_dict)
        custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
    except ValidationError as e:
        first_error = e.errors()[0]
        location = first_error['loc']
        # Model-level validators report an empty location; don't crash with IndexError on those.
        field_label = location[0] if location else "<model>"
        raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{field_label}': {first_error['msg']}")

    # Open the session only after validation, and inside try/finally, so every path closes it.
    session = self.Session()
    try:
        # Check if the user already exists
        existing_user = session.get(self.DbUser, username)
        if existing_user is not None:
            if not update_user:
                raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")

            if username == c.ADMIN_USERNAME and access_level != "admin":
                raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")

            # Update existing user (password is intentionally left unchanged)
            existing_user.access_level = access_level
            existing_user.custom_fields = custom_fields_json
        else:
            if update_user:
                raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

            if password is None:
                raise InvalidInputError(400, "missing_password", "Missing required field 'password' when adding a new user")

            password_hash = pwd_context.hash(password)
            new_user = self.DbUser(
                username=username,
                access_level=access_level,
                password_hash=password_hash,
                custom_fields=custom_fields_json
            )
            session.add(new_user)

        # Commit the transaction
        session.commit()

    finally:
        session.close()
|
|
304
|
+
|
|
305
|
+
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
    """
    Map an auth provider's user info to a RegisteredUser, persisting it the first time it is seen.

    The named provider's provider_configs.get_user(user_info) result is stored in the users table
    if no row with that username exists yet. It is stored with an empty password hash, so the
    account cannot log in with username and password.

    NOTE(review): the returned user is always built from the provider's data, even when a row
    already exists. Stored values (e.g. a later-changed access_level) are not read back. Confirm
    this is intended.

    Raises:
        InvalidInputError: 404 if no registered provider has the given name.
    """
    for provider in self.auth_providers:
        if provider.name == provider_name:
            break
    else:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    provider_user = provider.provider_configs.get_user(user_info)
    session = self.Session()
    try:
        candidate_row = self.DbUser(
            username=provider_user.username,
            access_level=provider_user.access_level,
            password_hash="",  # empty hash: no username/password login for this account (OAuth only)
            custom_fields=provider_user.custom_fields.model_dump_json(),
        )

        already_stored = session.get(self.DbUser, candidate_row.username) is not None
        if not already_stored:
            session.add(candidate_row)
        session.commit()

        return self._convert_db_user_to_user(candidate_row)
    finally:
        session.close()
|
|
332
|
+
|
|
333
|
+
def get_user(self, username: str, password: str) -> RegisteredUser:
    """
    Authenticate with username and password and return the matching user.

    Raises:
        InvalidInputError: 401 if the user does not exist, the account has no password hash
            (OAuth-only accounts are stored with an empty one), or the password is wrong.
    """
    session = self.Session()
    try:
        # Query for user by username
        db_user = session.get(self.DbUser, username)

        # An empty password_hash marks an OAuth-only account. passlib cannot identify an empty
        # hash and raises instead of returning False, so reject such accounts before verifying
        # (same guard as change_password).
        if db_user is not None and db_user.password_hash and pwd_context.verify(password, db_user.password_hash):
            return self._convert_db_user_to_user(db_user)

        raise InvalidInputError(401, "incorrect_username_or_password", "Incorrect username or password")

    finally:
        session.close()
|
|
347
|
+
|
|
348
|
+
def change_password(self, username: str, old_password: str, new_password: str) -> None:
    """
    Replace a user's password after checking the current one.

    Raises:
        InvalidInputError: 401 if the username is unknown, or if the account has no password
            hash or old_password does not match it.
    """
    session = self.Session()
    try:
        account = session.get(self.DbUser, username)
        if account is None:
            raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")

        current_hash = account.password_hash
        # Accounts with an empty hash (OAuth-only) can never pass this check.
        old_password_ok = bool(current_hash) and pwd_context.verify(old_password, current_hash)
        if not old_password_ok:
            raise InvalidInputError(401, "incorrect_password", "Incorrect password")

        account.password_hash = pwd_context.hash(new_password)
        session.commit()
    finally:
        session.close()
|
|
362
|
+
|
|
363
|
+
def delete_user(self, username: str) -> None:
    """
    Delete a user together with the API keys, OAuth tokens and authorization codes issued to them.

    The child tables declare ON DELETE CASCADE foreign keys, but SQLite only enforces foreign keys
    when "PRAGMA foreign_keys = ON" is set on the connection, and the PRAGMAs set in __init__ do not
    include it. Those rows are therefore removed explicitly. Otherwise they would outlive the user,
    and an API key would resolve again (see get_user_from_token) if the same username were
    re-created later.

    Raises:
        InvalidInputError: 403 for the admin user; 404 if the user does not exist.
    """
    if username == c.ADMIN_USERNAME:
        raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")

    session = self.Session()
    try:
        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

        # Remove dependent rows first. This is harmless if the database also cascades.
        for dependent_model in (self.DbApiKey, self.DbOAuthToken, self.DbAuthorizationCode):
            session.query(dependent_model).filter(
                dependent_model.username == username
            ).delete(synchronize_session=False)

        session.delete(db_user)
        session.commit()
    finally:
        session.close()
|
|
376
|
+
|
|
377
|
+
def get_all_users(self) -> list[RegisteredUser]:
    """Return every row of the users table converted to a RegisteredUser."""
    session = self.Session()
    try:
        rows = session.query(self.DbUser).all()
        return list(map(self._convert_db_user_to_user, rows))
    finally:
        session.close()
|
|
384
|
+
|
|
385
|
+
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
    """
    Issue a credential for the user and return it together with its expiry time.

    If title is given, an API key is created. It is a "sqrl-"-prefixed random token; only its
    salted hash is stored in the api_keys table, so the raw key exists only in the return value.
    Otherwise an HS256-signed JWT with "username" and "exp" claims is returned.

    Args:
        user: The user the credential is issued for.
        expiry_minutes: Lifetime in minutes. None means no expiry (datetime.max, a naive datetime).
        title: Label for an API key; providing it selects API-key mode.

    Raises:
        ConfigurationError: If SQRL_SECRET_KEY is not set.
    """
    issued_at = datetime.now(timezone.utc)
    if expiry_minutes is None:
        expires = datetime.max
    else:
        expires = issued_at + timedelta(minutes=expiry_minutes)

    signing_secret = self.secret_key
    if signing_secret is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

    if title is None:
        claims = {"username": user.username, "exp": expires}
        return jwt.encode(claims, signing_secret, algorithm="HS256"), expires

    session = self.Session()
    try:
        raw_key = "sqrl-" + uuid.uuid4().hex
        session.add(self.DbApiKey(
            hashed_key=u.hash_string(raw_key, salt=signing_secret),
            title=title,
            username=user.username,
            created_at=issued_at,
            expires_at=expires,
        ))
        session.commit()
    finally:
        session.close()

    return raw_key, expires
|
|
410
|
+
|
|
411
|
+
def get_user_from_token(self, token: str | None) -> RegisteredUser | None:
    """
    Resolve the user behind an access token.

    Tokens starting with "sqrl-" are API keys. They are hashed with the secret key and matched
    against unexpired rows of api_keys. Any other token is decoded as an HS256 JWT signed with the
    secret key, and its "username" claim names the user.

    Returns:
        The RegisteredUser, or None if token is None or empty.

    Raises:
        ConfigurationError: If SQRL_SECRET_KEY is not set (and a token was given).
        InvalidInputError: 401 if the token is invalid or expired, has no string "username"
            claim, or refers to a user that no longer exists.
    """
    if not token:
        return None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    session = self.Session()
    try:
        if token.startswith("sqrl-"):
            hashed_key = u.hash_string(token, salt=self.secret_key)
            api_key = session.query(self.DbApiKey).filter(
                self.DbApiKey.hashed_key == hashed_key,
                self.DbApiKey.expires_at >= func.now()
            ).first()
            if api_key is None:
                raise InvalidTokenError()
            username = api_key.username
        else:
            payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            username = payload.get("username")
            # A validly signed JWT without a string "username" claim is an invalid token here,
            # not a server error (previously a KeyError).
            if not isinstance(username, str):
                raise InvalidTokenError()

        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidTokenError()

        return self._convert_db_user_to_user(db_user)

    except InvalidTokenError:
        raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
    finally:
        session.close()
|
|
448
|
+
|
|
449
|
+
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    Return the user's API keys whose expires_at has not passed, as ApiKey models.

    Each model is validated from a stored api_keys row. A row's id is a random identifier
    (uuid4 hex) assigned when the key is created; it is not derived from the key. The raw key
    itself is never stored (only a salted hash of it is), so it is not part of the result.
    """
    session = self.Session()
    try:
        unexpired_rows = session.query(self.DbApiKey).filter(
            self.DbApiKey.username == username,
            self.DbApiKey.expires_at >= func.now(),
        ).all()
        return [ApiKey.model_validate(row) for row in unexpired_rows]
    finally:
        session.close()
|
|
463
|
+
|
|
464
|
+
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete one of the user's API keys by its id.

    The lookup is scoped to username, so a key belonging to another user is reported as not found.

    Raises:
        InvalidInputError: 404 if the user has no key with that id.
    """
    session = self.Session()
    try:
        target = session.query(self.DbApiKey).filter(
            self.DbApiKey.id == api_key_id,
            self.DbApiKey.username == username,
        ).first()

        if target is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")

        session.delete(target)
        session.commit()
    finally:
        session.close()
|
|
483
|
+
|
|
484
|
+
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
    """
    Check whether the user's access level permits the given permission scope.

    "guest" maps to PUBLIC, "admin" to PRIVATE, and any other access level (members) to
    PROTECTED. Access is granted when the mapped scope's value is at least the required one's.
    """
    scope_by_level = {
        "guest": PermissionScope.PUBLIC,
        "admin": PermissionScope.PRIVATE,
    }
    granted_scope = scope_by_level.get(user.access_level, PermissionScope.PROTECTED)
    return granted_scope.value >= scope.value
|
|
493
|
+
|
|
494
|
+
# OAuth Client Management Methods
|
|
495
|
+
|
|
496
|
+
def generate_secret_and_hash(self) -> tuple[str, str]:
    """
    Generate a random URL-safe secret and its bcrypt hash.

    Returns:
        (secret, secret_hash): the plaintext secret and its hash from pwd_context.
    """
    # bcrypt only uses the first 72 bytes of its input (newer bcrypt releases reject longer input
    # instead of truncating). token_urlsafe(64) produced ~86 characters. 48 random bytes encode to
    # 64 characters (384 bits of entropy), which bcrypt hashes in full.
    secret = secrets.token_urlsafe(48)
    secret_hash = pwd_context.hash(secret)
    return secret, secret_hash
|
|
501
|
+
|
|
502
|
+
def _validate_client_registration_request(self, request: ClientRegistrationRequest | ClientUpdateRequest) -> dict:
    """
    Validate client metadata from a registration or update request and return the column values to store.

    Only fields that are set (truthy) on the request are validated and included in the result:
      - client_name: returned as-is
      - redirect_uris: each must pass _validate_redirect_uri_format; returned as a JSON array string
      - grant_types / response_types: each must be supported; returned comma-separated
      - scope: each whitespace-separated scope must be supported; returned space-separated. This matches the
        raw scope string that register_oauth_client stores and the space-delimited OAuth ``scope`` format.

    Returns:
        Dict mapping OAuth client column names to their new values

    Raises:
        InvalidInputError: 400 with "invalid_redirect_uri", "invalid_grant_types", "invalid_response_types",
        or "invalid_scope" when a value is not allowed
    """
    updates = {}
    if request.client_name:
        updates['client_name'] = request.client_name

    # Validate redirect_uris if being updated
    if request.redirect_uris:
        for uri in request.redirect_uris:
            if not self._validate_redirect_uri_format(uri):
                raise InvalidInputError(400, "invalid_redirect_uri", f"Invalid redirect URI format: {uri}")
        updates['redirect_uris'] = json.dumps(request.redirect_uris)

    # Validate grant_types if being updated
    if request.grant_types:
        if not all(grant_type in c.SUPPORTED_GRANT_TYPES for grant_type in request.grant_types):
            raise InvalidInputError(400, "invalid_grant_types", f"Invalid grant types. Supported grant types are: {c.SUPPORTED_GRANT_TYPES}")
        updates['grant_types'] = ','.join(request.grant_types)

    # Validate response_types if being updated
    if request.response_types:
        if not all(response_type in c.SUPPORTED_RESPONSE_TYPES for response_type in request.response_types):
            raise InvalidInputError(400, "invalid_response_types", f"Invalid response types. Supported response types are: {c.SUPPORTED_RESPONSE_TYPES}")
        updates['response_types'] = ','.join(request.response_types)

    # Validate scope if being updated
    if request.scope:
        scopes = request.scope.split()
        if not all(scope in c.SUPPORTED_SCOPES for scope in scopes):
            raise InvalidInputError(400, "invalid_scope", f"Invalid scope. Supported scopes are: {c.SUPPORTED_SCOPES}")
        # Keep the space-delimited format used at registration time (and by OAuth); a comma-joined value
        # would change the scope string returned to clients after every update.
        updates['scope'] = ' '.join(scopes)

    return updates
|
|
534
|
+
|
|
535
|
+
def register_oauth_client(
    self, request: ClientRegistrationRequest, client_management_path_format: str
) -> ClientRegistrationResponse:
    """
    Register a new OAuth client.

    Arguments:
        request: The client registration metadata; validated with _validate_client_registration_request
        client_management_path_format: Format string with a ``{client_id}`` placeholder used to build the
            registration_client_uri returned to the client

    Returns:
        ClientRegistrationResponse containing the new client_id, the plain client_secret and
        registration_access_token (only their hashes are stored), and the registered metadata

    Raises:
        InvalidInputError: if the request metadata fails validation
    """
    grant_types = request.grant_types
    if grant_types is None:
        grant_types = ['authorization_code', 'refresh_token']

    # Validate request and keep the normalized values so validated fields are actually persisted
    validated = self._validate_client_registration_request(request)

    # response_types is validated but was previously never stored, so the DB value could disagree with the
    # response below. Only pass it when supplied, so that the column default still applies otherwise.
    optional_columns = {}
    if 'response_types' in validated:
        optional_columns['response_types'] = validated['response_types']

    # Generate secure client credentials and registration access token
    client_id = secrets.token_urlsafe(16)
    client_secret, client_secret_hash = self.generate_secret_and_hash()

    registration_access_token, registration_access_token_hash = self.generate_secret_and_hash()
    registration_client_uri = client_management_path_format.format(client_id=client_id)

    session = self.Session()
    try:
        oauth_client = self.DbOAuthClient(
            client_id=client_id,
            client_secret_hash=client_secret_hash,
            client_name=request.client_name,
            redirect_uris=json.dumps(request.redirect_uris),
            scope=request.scope,
            grant_types=','.join(grant_types),
            registration_access_token_hash=registration_access_token_hash,
            **optional_columns,
        )
        session.add(oauth_client)
        session.commit()

        return ClientRegistrationResponse(
            client_id=client_id,
            client_secret=client_secret,
            client_name=request.client_name,
            redirect_uris=request.redirect_uris,
            scope=request.scope,
            grant_types=grant_types,
            response_types=request.response_types,
            created_at=datetime.now(timezone.utc),
            is_active=True,
            registration_client_uri=registration_client_uri,
            registration_access_token=registration_access_token,
        )

    finally:
        session.close()
|
|
583
|
+
|
|
584
|
+
def get_oauth_client_details(self, client_id: str) -> ClientDetailsResponse:
    """
    Look up an active OAuth client and return its metadata with stored fields decoded.

    redirect_uris is decoded from its JSON string; grant_types and response_types are split on commas.

    Raises:
        InvalidInputError: 404 "invalid_client_id" if the client does not exist or is no longer active
    """
    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None:
            raise InvalidInputError(404, "invalid_client_id", "Client not found")
        if not db_client.is_active:
            raise InvalidInputError(404, "invalid_client_id", "Client is no longer active")

        decoded_fields = dict(
            client_id=db_client.client_id,
            client_name=db_client.client_name,
            redirect_uris=json.loads(db_client.redirect_uris),
            scope=db_client.scope,
            grant_types=db_client.grant_types.split(','),
            response_types=db_client.response_types.split(','),
            created_at=db_client.created_at,
            is_active=db_client.is_active,
        )
        return ClientDetailsResponse(**decoded_fields)
|
|
606
|
+
|
|
607
|
+
def validate_client_credentials(self, client_id: str, client_secret: str) -> bool:
    """
    Check a client_id / client_secret pair.

    Returns:
        False when the client does not exist or is inactive; otherwise whether client_secret verifies
        against the stored secret hash
    """
    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        client_usable = db_client is not None and db_client.is_active
        if not client_usable:
            return False
        return pwd_context.verify(client_secret, db_client.client_secret_hash)
|
|
617
|
+
|
|
618
|
+
def validate_redirect_uri(self, client_id: str, redirect_uri: str) -> bool:
    """
    Check that redirect_uri exactly matches one of the URIs registered for an active client.

    Returns:
        False when the client does not exist or is inactive; otherwise whether redirect_uri is in the
        client's stored (JSON-encoded) redirect URI list
    """
    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        client_usable = db_client is not None and db_client.is_active
        if not client_usable:
            return False

        allowed_uris = json.loads(db_client.redirect_uris)
        return redirect_uri in allowed_uris
|
|
630
|
+
|
|
631
|
+
def validate_registration_access_token(self, client_id: str, registration_access_token: str) -> bool:
    """
    Check the registration access token used to authorize client management operations.

    Unlike the other client checks, this does not require the client to be active.

    Returns:
        False when the client does not exist; otherwise whether the token verifies against the stored hash
    """
    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is not None:
            return pwd_context.verify(registration_access_token, db_client.registration_access_token_hash)
        return False
|
|
642
|
+
|
|
643
|
+
def _validate_redirect_uri_format(self, uri: str) -> bool:
    """
    Check that a redirect URI has an acceptable format.

    Rules:
      - fragments ('#') are never allowed
      - the URI must contain '://'
      - plain http is only allowed when the host is exactly 'localhost' or '127.0.0.1'
      - https and custom schemes (e.g. for mobile apps) are allowed

    Returns:
        True if the URI is acceptable, False otherwise (including when it cannot be parsed)
    """
    from urllib.parse import urlsplit

    if '#' in uri:
        return False

    # Every accepted form (http, https, custom scheme) requires an explicit "scheme://"
    if '://' not in uri:
        return False

    try:
        parts = urlsplit(uri)
        # Parse the host rather than substring-matching the whole URI: "http://localhost.evil.com" or
        # "http://evil.com/?x=localhost" must not pass. urlsplit lowercases the scheme, so "HTTP://..." cannot
        # slip through as a "custom scheme".
        if parts.scheme == 'http':
            return parts.hostname in ('localhost', '127.0.0.1')
    except ValueError:
        # e.g. malformed IPv6 host brackets
        return False

    return True
|
|
659
|
+
|
|
660
|
+
def update_oauth_client_with_token_rotation(self, client_id: str, request: ClientUpdateRequest) -> ClientUpdateResponse:
    """
    Update an OAuth client's metadata and issue it a new registration access token.

    The request is validated first with _validate_client_registration_request, and only the fields it sets are
    changed. A deactivated client can only be updated by a request that also sets is_active to a truthy value.
    The stored registration access token hash is replaced, so the previous token no longer verifies.

    Raises:
        InvalidInputError: 404 "invalid_client_id" if the client does not exist; 400 "invalid_request" if the
        client is deactivated and the request does not reactivate it; or any metadata validation error
    """
    # Validate first; only the fields present on the request end up in this dict
    column_updates = self._validate_client_registration_request(request)

    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None:
            raise InvalidInputError(404, "invalid_client_id", f"OAuth client not found: {client_id}")
        if not (db_client.is_active or request.is_active):
            raise InvalidInputError(400, "invalid_request", "Cannot update a deactivated client")

        if request.is_active is not None:
            db_client.is_active = request.is_active

        rotated_token, rotated_token_hash = self.generate_secret_and_hash()

        for column, new_value in column_updates.items():
            if hasattr(db_client, column):
                setattr(db_client, column, new_value)

        db_client.registration_access_token_hash = rotated_token_hash

        session.commit()

        response_fields = dict(
            client_id=db_client.client_id,
            client_name=db_client.client_name,
            redirect_uris=json.loads(db_client.redirect_uris),
            scope=db_client.scope,
            grant_types=db_client.grant_types.split(','),
            response_types=db_client.response_types.split(','),
            created_at=db_client.created_at,
            is_active=db_client.is_active,
            registration_access_token=rotated_token,
        )
        return ClientUpdateResponse(**response_fields)
|
|
704
|
+
|
|
705
|
+
def revoke_oauth_client(self, client_id: str) -> None:
    """
    Deactivate an OAuth client by setting is_active to False (the row is kept).

    Raises:
        InvalidInputError: 404 "client_not_found" if the client does not exist
    """
    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None:
            raise InvalidInputError(404, "client_not_found", f"OAuth client not found: {client_id}")

        db_client.is_active = False
        session.commit()
|
|
717
|
+
|
|
718
|
+
# Authorization Code Flow Methods
|
|
719
|
+
|
|
720
|
+
def create_authorization_code(
    self, client_id: str, username: str, redirect_uri: str, scope: str,
    code_challenge: str, code_challenge_method: str = "S256"
) -> str:
    """
    Create and store an authorization code for the authorization code flow with PKCE.

    Arguments:
        client_id: The requesting client; must exist and be active
        username: The user the code is issued for
        redirect_uri: Must be one of the client's registered redirect URIs
        scope: The requested scope, stored with the code
        code_challenge: The PKCE code challenge (required)
        code_challenge_method: The PKCE method; only "S256" is accepted

    Returns:
        The new authorization code, which expires 10 minutes after creation

    Raises:
        InvalidInputError: 404 "invalid_client_id" for an unknown or inactive client (raised by
        get_oauth_client_details); 400 "invalid_redirect_uri" for an unregistered redirect_uri;
        400 "invalid_request" for a missing code_challenge or a code_challenge_method other than "S256"
    """
    # Validate client_id. get_oauth_client_details raises InvalidInputError for unknown or inactive clients
    # and otherwise returns an active client, so no separate None/is_active check is needed.
    self.get_oauth_client_details(client_id)

    # Validate redirect_uri
    if not self.validate_redirect_uri(client_id, redirect_uri):
        raise InvalidInputError(400, "invalid_redirect_uri", "Invalid redirect_uri for this client")

    # Validate PKCE parameters
    if not code_challenge:
        raise InvalidInputError(400, "invalid_request", "code_challenge is required")

    if code_challenge_method != "S256":
        raise InvalidInputError(400, "invalid_request", "code_challenge_method must be 'S256'")

    # Generate authorization code
    code = secrets.token_urlsafe(48)

    # Set expiration (10 minutes from now)
    created_at = datetime.now(timezone.utc)
    expires_at = created_at + timedelta(minutes=10)

    session = self.Session()
    try:
        auth_code = self.DbAuthorizationCode(
            code=code,
            client_id=client_id,
            username=username,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            created_at=created_at,
            expires_at=expires_at,
            used=False
        )
        session.add(auth_code)
        session.commit()

        return code

    finally:
        session.close()
|
|
770
|
+
|
|
771
|
+
def exchange_authorization_code(
    self, code: str, client_id: str, redirect_uri: str, code_verifier: str, access_token_expiry_minutes: int
) -> TokenResponse:
    """
    Exchange an authorization code and its PKCE verifier for an access token and a refresh token.

    Arguments:
        code: The authorization code issued by create_authorization_code
        client_id: The client redeeming the code; must exist, be active, and match the code's client
        redirect_uri: Must equal the redirect_uri stored with the code
        code_verifier: The PKCE verifier, checked against the stored code_challenge
        access_token_expiry_minutes: Lifetime of the new access token in minutes

    Returns:
        A TokenResponse with the access token, its lifetime in seconds, and a refresh token valid for 30 days.
        The authorization code is marked as used.

    Raises:
        InvalidInputError: 400 "invalid_client" if the client is unknown or inactive; 400 "invalid_grant" if the
        code is unknown, expired, or already used, the redirect URI or code verifier does not match, or the user
        no longer exists
    """
    session = self.Session()
    try:
        # Reject unknown or deactivated clients, consistent with the refresh and revocation flows; otherwise a
        # deactivated client could still redeem an outstanding code.
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Get and validate authorization code
        auth_code = session.query(self.DbAuthorizationCode).filter(
            self.DbAuthorizationCode.code == code,
            self.DbAuthorizationCode.client_id == client_id,
            self.DbAuthorizationCode.expires_at >= func.now(),
            self.DbAuthorizationCode.used == False  # noqa: E712 (SQLAlchemy column expression)
        ).first()

        if auth_code is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")

        # Validate redirect URI
        if auth_code.redirect_uri != redirect_uri:
            raise InvalidInputError(400, "invalid_grant", "Redirect URI mismatch")

        if not u.validate_pkce_challenge(code_verifier, auth_code.code_challenge):
            raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")

        # Get user
        db_user = session.get(self.DbUser, auth_code.username)
        if db_user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Mark authorization code as used
        auth_code.used = True

        # Generate tokens
        user_obj = self._convert_db_user_to_user(db_user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate refresh token
        refresh_token, refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=refresh_token_hash,
            client_id=client_id,
            username=auth_code.username,
            scope=auth_code.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=refresh_token
        )

    finally:
        session.close()
|
|
836
|
+
|
|
837
|
+
def refresh_oauth_access_token(self, refresh_token: str, client_id: str, access_token_expiry_minutes: int) -> TokenResponse:
    """
    Exchange a refresh token for a new access token and a new refresh token (refresh token rotation).

    Arguments:
        refresh_token: The plain refresh token presented by the client
        client_id: The client presenting the token; must exist and be active
        access_token_expiry_minutes: Lifetime of the new access token in minutes

    Returns:
        A TokenResponse with the new access token, its lifetime in seconds, and a new refresh token valid for
        30 days. The presented refresh token is revoked.

    Raises:
        InvalidInputError: 400 "invalid_client" if the client is unknown or inactive; 400 "invalid_grant" if the
        refresh token is unknown, expired, or revoked, or its user no longer exists
        ConfigurationError: if self.secret_key is not set
    """
    session = self.Session()
    try:
        # Validate client
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Only hashes of refresh tokens are stored, so the presented token must be verified against every
        # unexpired, unrevoked token of this client. Checking only an arbitrary first row would reject valid
        # tokens whenever the client has more than one active token.
        candidate_tokens = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False  # noqa: E712 (SQLAlchemy column expression)
        ).all()
        oauth_token = next(
            (
                candidate for candidate in candidate_tokens
                if pwd_context.verify(refresh_token, candidate.refresh_token_hash)
            ),
            None,
        )
        if oauth_token is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")

        # Get user. The client itself is valid here; a missing user invalidates the grant, matching the
        # authorization code exchange.
        db_user = session.get(self.DbUser, oauth_token.username)
        if db_user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Check secret key is available
        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")

        # Generate new tokens
        user_obj = self._convert_db_user_to_user(db_user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate new refresh token
        new_refresh_token, new_refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        # Revoke old token
        oauth_token.is_revoked = True

        # Create new token entry
        new_oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=new_refresh_token_hash,
            client_id=client_id,
            username=oauth_token.username,
            scope=oauth_token.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(new_oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=new_refresh_token
        )

    finally:
        session.close()
|
|
903
|
+
|
|
904
|
+
def revoke_oauth_token(self, client_id: str, token: str, token_type_hint: str | None = None) -> None:
    """
    Revoke one of a client's OAuth refresh tokens.

    token_type_hint may be omitted or must be 'refresh_token'; revoking access tokens is not supported yet.
    If no active refresh token of the client matches, nothing happens and no error is raised (per the OAuth
    revocation convention of always reporting success).

    Raises:
        InvalidInputError: 400 "invalid_request" for any other token_type_hint; 400 "invalid_client" if the
        client is unknown or inactive
    """
    if token_type_hint and token_type_hint != 'refresh_token':
        raise InvalidInputError(400, "invalid_request", "Only refresh tokens can be revoked")

    with self.Session() as session:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None or not db_client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Only hashes are stored, so each unexpired, unrevoked token of the client is checked in turn
        active_tokens = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False  # noqa: E712 (SQLAlchemy column expression)
        ).all()
        matching_token = next(
            (candidate for candidate in active_tokens if pwd_context.verify(token, candidate.refresh_token_hash)),
            None,
        )

        if matching_token:
            matching_token.is_revoked = True
            session.commit()
|
|
940
|
+
|
|
941
|
+
def close(self) -> None:
    """Dispose of ``self.engine``, releasing the connections it holds."""
    engine = self.engine
    engine.dispose()
|
|
943
|
+
|
|
944
|
+
|
|
945
|
+
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')

    The decorated function takes an AuthProviderArgs and returns the provider's ProviderConfigs. At decoration
    time, a wrapper is appended to Authenticator.providers. When called, the wrapper builds an AuthProvider from
    the given name, label, icon, and the returned configs. The wrapper is also what the decorator returns.
    """
    import functools

    # The parameter is named provider_fn (not "func") so it does not shadow the module-level SQLAlchemy "func"
    def decorator(provider_fn: Callable[[AuthProviderArgs], ProviderConfigs]):
        # Preserve the decorated function's name, docstring and __wrapped__ for introspection and debugging
        @functools.wraps(provider_fn)
        def wrapper(sqrl: AuthProviderArgs):
            provider_configs = provider_fn(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)
        Authenticator.providers.append(wrapper)
        return wrapper
    return decorator
|