squirrels 0.1.0__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +409 -380
- dateutils/types.py +6 -0
- squirrels/__init__.py +21 -18
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +552 -134
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +83 -0
- squirrels/_arguments/run_time_args.py +111 -0
- squirrels/_auth.py +777 -0
- squirrels/_command_line.py +239 -107
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +94 -0
- squirrels/_constants.py +141 -64
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +91 -0
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +319 -110
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +357 -187
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +1201 -0
- squirrels/_package_data/base_project/.env +7 -0
- squirrels/_package_data/base_project/.env.example +44 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/_package_data/base_project/docker/.dockerignore +16 -0
- squirrels/_package_data/base_project/docker/Dockerfile +16 -0
- squirrels/_package_data/base_project/docker/compose.yml +7 -0
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/_package_data/base_project/gitignore +13 -0
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/_package_data/base_project/parameters.yml +142 -0
- squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
- squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
- squirrels/_package_data/base_project/resources/expenses.db +0 -0
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/resources/weather.db +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/base_project/tmp/.gitignore +2 -0
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_package_loader.py +29 -0
- squirrels/_parameter_configs.py +592 -0
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +207 -0
- squirrels/_parameters.py +1703 -0
- squirrels/_project.py +796 -0
- squirrels/_py_module.py +122 -0
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +97 -0
- squirrels/_sources.py +112 -0
- squirrels/_utils.py +540 -149
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +3 -0
- squirrels/data_sources.py +14 -282
- squirrels/parameter_options.py +13 -189
- squirrels/parameters.py +14 -801
- squirrels/types.py +18 -0
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
- squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
- squirrels/_credentials_manager.py +0 -87
- squirrels/_module_loader.py +0 -37
- squirrels/_parameter_set.py +0 -151
- squirrels/_renderer.py +0 -286
- squirrels/_timed_imports.py +0 -37
- squirrels/connection_set.py +0 -126
- squirrels/package_data/base_project/.gitignore +0 -4
- squirrels/package_data/base_project/connections.py +0 -21
- squirrels/package_data/base_project/database/sample_database.db +0 -0
- squirrels/package_data/base_project/database/seattle_weather.db +0 -0
- squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
- squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
- squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
- squirrels/package_data/base_project/squirrels.yaml +0 -26
- squirrels/package_data/static/favicon.ico +0 -0
- squirrels/package_data/static/script.js +0 -234
- squirrels/package_data/static/style.css +0 -110
- squirrels/package_data/templates/index.html +0 -32
- squirrels-0.1.0.dist-info/LICENSE +0 -22
- squirrels-0.1.0.dist-info/METADATA +0 -67
- squirrels-0.1.0.dist-info/RECORD +0 -40
- squirrels-0.1.0.dist-info/top_level.txt +0 -1
squirrels/_auth.py
ADDED
|
@@ -0,0 +1,777 @@
|
|
|
1
|
+
from typing import Callable, Any
|
|
2
|
+
from datetime import datetime, timedelta, timezone
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from jwt.exceptions import InvalidTokenError
|
|
5
|
+
from passlib.context import CryptContext
|
|
6
|
+
from pydantic import ValidationError
|
|
7
|
+
from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
|
|
8
|
+
from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
|
|
9
|
+
import jwt, uuid, secrets, json, base64, requests
|
|
10
|
+
from jwt import PyJWKClient
|
|
11
|
+
|
|
12
|
+
from ._env_vars import SquirrelsEnvVars
|
|
13
|
+
from ._manifest import PermissionScope
|
|
14
|
+
from ._exceptions import InvalidInputError, ConfigurationError
|
|
15
|
+
from ._arguments.init_time_args import AuthProviderArgs
|
|
16
|
+
from ._schemas.auth_models import (
|
|
17
|
+
CustomUserFields, AbstractUser, RegisteredUser, ApiKey, UserField, UserFieldsModel, AuthProvider, ProviderConfigs
|
|
18
|
+
)
|
|
19
|
+
from ._schemas import response_models as rm
|
|
20
|
+
from . import _utils as u, _constants as c
|
|
21
|
+
|
|
22
|
+
# Password hashing context shared by every Authenticator: bcrypt is the only enabled
# scheme, and deprecated="auto" flags any non-default scheme as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Signature of an auth-provider factory: given AuthProviderArgs, build one AuthProvider.
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Authenticator:
|
|
28
|
+
# Class-level (shared by all instances) list of staged auth-provider factory functions.
# NOTE(review): __init__ takes its providers from the `provider_functions` argument and never
# reads this attribute; presumably it is filled elsewhere (e.g. a registration helper) and
# passed in by the caller — confirm.
providers: list[ProviderFunctionType] = [] # static variable to stage providers
|
|
29
|
+
|
|
30
|
+
def __init__(
    self, logger: u.Logger, env_vars: SquirrelsEnvVars, auth_args: AuthProviderArgs,
    provider_functions: list[ProviderFunctionType], custom_user_fields_cls: type[CustomUserFields],
    *,
    sa_engine: Engine | None = None, external_only: bool = False
):
    """
    Set up auth providers, per-instance ORM models and (unless ``external_only``) the auth database.

    Args:
        logger: Logger used for informational messages (e.g. column migrations).
        env_vars: Environment settings; supplies the secret key, the auth DB file path,
            the project path and the optional admin password.
        auth_args: Arguments passed to every provider function.
        provider_functions: Factories that each build one AuthProvider.
        custom_user_fields_cls: Pydantic model class describing the custom user fields.
        sa_engine: Engine for the auth database. When None, a SQLite database is created at
            ``env_vars.auth_db_file_path`` (formatted with ``project_path``).
        external_only: When True, no database engine, tables or ``Session`` factory are
            created, so the methods that use ``self.Session`` are unavailable.
    """
    self.logger = logger
    self.env_vars = env_vars
    self.secret_key = env_vars.secret_key
    self.external_only = external_only
    self.password_requirements = rm.PasswordRequirements()

    # Create a new declarative base for this instance so each Authenticator owns its own metadata
    self.Base = declarative_base()

    # Define DbUser class for this instance
    class DbUser(self.Base):
        __tablename__ = 'users'
        __table_args__ = {'extend_existing': True}
        username: Mapped[str] = mapped_column(primary_key=True)
        access_level: Mapped[str] = mapped_column(nullable=False, default="member")
        password_hash: Mapped[str] = mapped_column(nullable=False)
        custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())

    # Define DbApiKey class for this instance
    class DbApiKey(self.Base):
        __tablename__ = 'api_keys'

        id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
        last_four: Mapped[str] = mapped_column(nullable=False)
        title: Mapped[str] = mapped_column(nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        created_at: Mapped[datetime] = mapped_column(nullable=False)
        expires_at: Mapped[datetime] = mapped_column(nullable=False)

        def __repr__(self):
            return f"<DbApiKey(id='{self.id}', username='{self.username}')>"

    self.CustomUserFields = custom_user_fields_cls
    self.DbUser = DbUser
    self.DbApiKey = DbApiKey

    self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
    self._jwks_clients: dict[str, PyJWKClient] = {}
    self._provider_metadata_cache: dict[str, dict] = {}

    if not self.external_only:
        if sa_engine is None:
            raw_sqlite_path = self.env_vars.auth_db_file_path
            sqlite_path = u.Path(raw_sqlite_path.format(project_path=self.env_vars.project_path))
            sqlite_path.parent.mkdir(parents=True, exist_ok=True)
            self.engine = create_engine(f"sqlite:///{sqlite_path}")
        else:
            self.engine = sa_engine

        # Configure SQLite pragmas. They are SQLite-specific, and an engine passed in via
        # `sa_engine` may use another backend that would reject them.
        # NOTE(review): SQLite's `synchronous` setting is per-connection, so it only applies to
        # the pooled connection used here; journal_mode=WAL is persisted in the database file.
        if self.engine.dialect.name == "sqlite":
            with self.engine.connect() as conn:
                conn.execute(text("PRAGMA journal_mode = WAL"))
                conn.execute(text("PRAGMA synchronous = NORMAL"))
                conn.commit()

        self.Base.metadata.create_all(self.engine)

        self.Session = sessionmaker(bind=self.engine)

        self._initialize_db()
|
|
99
|
+
|
|
100
|
+
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
    """
    Build a RegisteredUser from a row of the ``users`` table.

    The row's ``custom_fields`` column holds JSON text; an empty value is treated as ``{}``.
    The decoded mapping is passed to the CustomUserFields model, so any field missing from the
    stored JSON takes that model's default.
    """
    stored_json = db_user.custom_fields
    field_values = {} if not stored_json else json.loads(stored_json)
    parsed_fields = self.CustomUserFields(**field_values)

    return RegisteredUser(
        username=db_user.username,
        access_level=db_user.access_level,
        custom_fields=parsed_fields,
    )
|
|
111
|
+
|
|
112
|
+
def _validate_password_length(self, password: str) -> None:
|
|
113
|
+
"""Validate that password meets length requirements and does not exceed 72 characters (bcrypt limit)"""
|
|
114
|
+
min_len = self.password_requirements.min_length
|
|
115
|
+
max_len = min(self.password_requirements.max_length, 72)
|
|
116
|
+
if len(password) < min_len:
|
|
117
|
+
raise InvalidInputError(400, "password_too_short", f"Password must be at least {min_len} characters long")
|
|
118
|
+
if len(password) > max_len:
|
|
119
|
+
raise InvalidInputError(400, "password_too_long", f"Password cannot exceed {max_len} characters")
|
|
120
|
+
|
|
121
|
+
def _initialize_db(self):
    """
    Bring the auth database schema up to date and seed the admin user.

    1. For each ORM table (``users``, ``api_keys``), add any model column that the existing
       database table lacks, using ``ALTER TABLE ... ADD COLUMN``. ``create_all`` only creates
       missing tables; it does not add columns to tables that already exist.
    2. If ``env_vars.secret_admin_password`` is set, validate it, then either create the admin
       user (access level "admin") or overwrite the existing admin user's password hash.

    Raises:
        InvalidInputError: if the admin password fails ``_validate_password_length``.
    """
    session = self.Session()
    try:
        # Get existing columns in the database
        inspector = inspect(self.engine)

        # Lightweight additive migration: only columns are added, never altered or dropped
        for db_model in [self.DbUser, self.DbApiKey]:
            table_name = db_model.__tablename__
            existing_columns = {col['name'] for col in inspector.get_columns(table_name)}
            model_columns = set(db_model.__table__.columns.keys())
            new_columns = model_columns - existing_columns

            if new_columns:
                add_columns_msg = f"Adding columns to table {table_name}: {new_columns}"
                self.logger.info(add_columns_msg)

                for col_name in new_columns:
                    col = db_model.__table__.columns[col_name]
                    column_type = col.type.compile(self.engine.dialect)
                    nullable = "NULL" if col.nullable else "NOT NULL"
                    # Only literal Python-side defaults can be rendered into DDL; callable defaults
                    # (such as the uuid lambda on DbApiKey.id) go to the else branch.
                    # NOTE(review): `server_default` is not consulted, so a column like
                    # users.created_at would receive the placeholder DEFAULT '' below.
                    if col.default is not None and not callable(col.default.arg):
                        # NOTE(review): string defaults are wrapped in single quotes without escaping
                        # embedded quotes; fine only while defaults are simple literals defined in code.
                        default_val = f"'{col.default.arg}'" if isinstance(col.default.arg, str) else col.default.arg
                        default = f"DEFAULT {default_val}"
                    else:
                        # If nullable is False and no default is provided, use an empty string as a placeholder default for SQLite
                        # TODO: Use a different default value (instead of empty string) based on the column type
                        default = "DEFAULT ''" if not col.nullable else ""

                    alter_stmt = f"ALTER TABLE {table_name} ADD COLUMN {col_name} {column_type} {nullable} {default}"
                    session.execute(text(alter_stmt))

                session.commit()

        # Get admin password from environment variable if exists
        admin_password = self.env_vars.secret_admin_password

        # On every initialization with the variable set, the admin password is (re)applied
        if admin_password is not None:
            self._validate_password_length(admin_password)
            password_hash = pwd_context.hash(admin_password)
            admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
            if admin_user is None:
                admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
                session.add(admin_user)
            else:
                admin_user.password_hash = password_hash

            session.commit()

    finally:
        session.close()
|
|
171
|
+
|
|
172
|
+
@cached_property
def user_fields(self) -> UserFieldsModel:
    """
    Describe the user fields (built-in and custom) as a UserFieldsModel.

    ``username`` and ``access_level`` are fixed built-in fields. One UserField is derived
    for each property in the JSON schema of the CustomUserFields model, with:
    - name: The name of the field
    - label: The schema "title" of the field (the field name if absent)
    - type: The JSON schema type of the field
    - nullable: Whether the field accepts null (any "null" branch of an "anyOf")
    - enum: The possible values of the field (or None if not applicable)
    - default: The default value of the field (or None if the schema has no default)

    The result is cached per instance.
    """
    schema = self.CustomUserFields.model_json_schema()
    # Enum (and nested model) types are emitted under "$defs" and referenced via "$ref"
    defs: dict[str, dict[str, Any]] = schema.get("$defs", {})
    properties: dict[str, dict[str, Any]] = schema.get("properties", {})

    def resolve_ref(sub_schema: dict[str, Any]) -> dict[str, Any]:
        """Return the "$defs" entry a local "$ref" points to, else the sub-schema itself."""
        ref = sub_schema.get("$ref")
        if isinstance(ref, str) and ref.startswith("#/$defs/"):
            return defs.get(ref.removeprefix("#/$defs/"), sub_schema)
        return sub_schema

    custom_fields = []
    for field_name, field_schema in properties.items():
        if choices := field_schema.get("anyOf"):
            # Optional[X] is rendered as anyOf [X, {"type": "null"}]; do not rely on the null
            # branch being second, and look through "$ref"s to find the concrete type.
            non_null = [resolve_ref(choice) for choice in choices if choice.get("type") != "null"]
            nullable = len(non_null) < len(choices)
            target = non_null[0] if non_null else choices[0]
        else:
            target = resolve_ref(field_schema)
            nullable = False

        field_type = target["type"]
        # Optional[Literal[...]] keeps its "enum" inside the anyOf branch rather than at the top level
        enum = field_schema.get("enum", target.get("enum"))

        field_data = UserField(
            name=field_name, label=field_schema.get("title", field_name), type=field_type,
            nullable=nullable, enum=enum, default=field_schema.get("default"),
        )
        custom_fields.append(field_data)

    return UserFieldsModel(
        username=UserField(name="username", label="Username / Email", type="string", nullable=False, enum=None, default=None),
        access_level=UserField(name="access_level", label="Access Level", type="string", nullable=False, enum=["admin", "member"], default="member"),
        custom_fields=custom_fields
    )
|
|
204
|
+
|
|
205
|
+
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> RegisteredUser:
    """
    Create a user, or update an existing one when ``update_user`` is True.

    Args:
        username: Primary key of the user.
        user_fields: Mapping that may contain "access_level" (default "member"), "password"
            (required when creating) and "custom_fields" (a mapping validated against the
            CustomUserFields model; missing or None means no custom values).
        update_user: When True, the user must already exist, and only its access level and
            custom fields are updated (the password is left unchanged). When False, the user
            must not exist yet.

    Returns:
        The stored user as a RegisteredUser.

    Raises:
        InvalidInputError: 400 for a "guest" access level, invalid custom fields, an already
            existing username on create, or a missing or invalid-length password; 403 when
            demoting the admin user; 404 when updating a user that does not exist.
    """
    # Separate custom fields from base fields
    access_level = user_fields.get('access_level', 'member')
    password = user_fields.get('password')

    # Validate access level - cannot add/update users with guest access level
    if access_level == "guest":
        raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")

    # Extract custom fields; an explicit None is treated like an omitted entry
    custom_fields_data: dict[str, Any] = user_fields.get('custom_fields') or {}

    # Validate the custom fields
    try:
        custom_fields = self.CustomUserFields(**custom_fields_data)
        custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
    except ValidationError as e:
        first_error = e.errors()[0]
        loc = first_error.get('loc') or ()
        # Model-level validators report an empty location, so fall back to a generic name
        field_name = loc[0] if loc else 'custom_fields'
        raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{field_name}': {first_error['msg']}")

    # Add or update user
    session = self.Session()
    try:
        # Check if the user already exists
        db_user = session.get(self.DbUser, username)
        if db_user is not None:
            if not update_user:
                raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")

            if username == c.ADMIN_USERNAME and access_level != "admin":
                raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")

            # Update existing user
            db_user.access_level = access_level
            db_user.custom_fields = custom_fields_json
        else:
            if update_user:
                raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

            if password is None:
                raise InvalidInputError(400, "missing_password", "Missing required field 'password' when adding a new user")

            self._validate_password_length(password)
            password_hash = pwd_context.hash(password)
            db_user = self.DbUser(
                username=username,
                access_level=access_level,
                password_hash=password_hash,
                custom_fields=custom_fields_json
            )
            session.add(db_user)

        # Commit the transaction
        session.commit()
        return self._convert_db_user_to_user(db_user)

    finally:
        session.close()
|
|
262
|
+
|
|
263
|
+
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
    """
    Return the registered user for an auth-provider login, creating it on first login.

    The provider's ``get_user`` maps ``user_info`` to a user. If no user with that username
    exists, it is stored with an empty password hash, so it can never log in with a password.
    If the user already exists, the stored record is returned unchanged. This keeps the result
    consistent with what ``get_user_from_token`` later loads from the database.

    Raises:
        InvalidInputError: 404 if ``provider_name`` is not a configured provider.
    """
    provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    user = provider.provider_configs.get_user(user_info)
    session = self.Session()
    try:
        existing_db_user = session.get(self.DbUser, user.username)
        if existing_db_user is not None:
            return self._convert_db_user_to_user(existing_db_user)

        # Convert user to database user
        db_user = self.DbUser(
            username=user.username,
            access_level=user.access_level,
            password_hash="", # An empty password_hash makes username/password login impossible (OAuth only)
            custom_fields=user.custom_fields.model_dump_json()
        )
        session.add(db_user)
        session.commit()

        return self._convert_db_user_to_user(db_user)

    finally:
        session.close()
|
|
290
|
+
|
|
291
|
+
def get_user(self, username: str, password: str) -> RegisteredUser:
    """
    Authenticate a username/password pair and return the matching user.

    Users created through an auth provider are stored with an empty password hash. passlib
    rejects an empty hash with a ValueError instead of returning False, so such users are
    treated as a failed login here.

    Raises:
        InvalidInputError: 401 "incorrect_username_or_password" if the user does not exist,
            has no password, or the password does not match.
    """
    session = self.Session()
    try:
        # Query for user by username
        db_user = session.get(self.DbUser, username)

        if db_user is None or not db_user.password_hash:
            # Spend roughly the time of a real hash check, so the response time does not
            # reveal whether the username exists.
            pwd_context.dummy_verify()
        elif pwd_context.verify(password, db_user.password_hash):
            return self._convert_db_user_to_user(db_user)

        raise InvalidInputError(401, "incorrect_username_or_password", "Incorrect username or password")

    finally:
        session.close()
|
|
305
|
+
|
|
306
|
+
def change_password(self, username: str, old_password: str, new_password: str) -> None:
    """
    Replace a user's password after checking the current one.

    A user whose stored password hash is empty (as for users created through an auth provider)
    can never pass the old-password check.

    Raises:
        InvalidInputError: 401 if the user does not exist or ``old_password`` is wrong; 400 (from
            ``_validate_password_length``) if ``new_password`` has an invalid length.
    """
    with self.Session() as session:
        account = session.get(self.DbUser, username)
        if account is None:
            raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")

        current_hash = account.password_hash
        if not (current_hash and pwd_context.verify(old_password, current_hash)):
            raise InvalidInputError(401, "incorrect_password", "Incorrect password")

        self._validate_password_length(new_password)
        account.password_hash = pwd_context.hash(new_password)
        session.commit()
|
|
321
|
+
|
|
322
|
+
def delete_user(self, username: str) -> None:
    """
    Delete a user together with all of its API keys.

    ``api_keys.username`` is declared with ON DELETE CASCADE, but SQLite only enforces foreign
    keys when ``PRAGMA foreign_keys = ON`` is set on the connection, and this module never sets
    it. The keys are therefore deleted explicitly. Otherwise orphaned keys would remain and would
    authenticate again if a user with the same username were later re-created.

    Raises:
        InvalidInputError: 403 when deleting the admin user; 404 when no such user exists.
    """
    if username == c.ADMIN_USERNAME:
        raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")

    session = self.Session()
    try:
        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
        session.query(self.DbApiKey).filter(self.DbApiKey.username == username).delete(synchronize_session=False)
        session.delete(db_user)
        session.commit()
    finally:
        session.close()
|
|
335
|
+
|
|
336
|
+
def get_all_users(self) -> list[RegisteredUser]:
    """Return every row of the users table, converted to RegisteredUser objects."""
    with self.Session() as session:
        rows = session.query(self.DbUser).all()
        registered = [self._convert_db_user_to_user(row) for row in rows]
    return registered
|
|
343
|
+
|
|
344
|
+
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
    """
    Issue an access token for ``user`` and return it with its expiry time.

    Without a ``title``, the token is an HS256 JWT signed with the secret key and carrying the
    username and ``exp`` claims. With a ``title``, an API key of the form ``sqrl-<random>`` is
    generated. Only its salted hash, last four characters, title, owner and timestamps are
    stored in the api_keys table.

    Args:
        user: Owner of the token.
        expiry_minutes: Lifetime in minutes, or None for a token that effectively never expires.
        title: Title of the API key; selects API-key mode when given.

    Raises:
        ConfigurationError: if the secret key is not configured.
    """
    issued_at = datetime.now(timezone.utc)
    if expiry_minutes is None:
        # NOTE(review): datetime.max is naive while issued_at is timezone-aware
        expires_at = datetime.max
    else:
        expires_at = issued_at + timedelta(minutes=expiry_minutes)

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

    if title is None:
        claims = {"username": user.username, "exp": expires_at}
        return jwt.encode(claims, self.secret_key, algorithm="HS256"), expires_at

    raw_key = "sqrl-" + secrets.token_urlsafe(16)
    with self.Session() as session:
        session.add(self.DbApiKey(
            hashed_key=u.hash_string(raw_key, salt=self.secret_key),
            last_four=raw_key[-4:],
            title=title,
            username=user.username,
            created_at=issued_at,
            expires_at=expires_at,
        ))
        session.commit()
    return raw_key, expires_at
|
|
373
|
+
|
|
374
|
+
def get_user_from_token(self, token: str | None) -> tuple[RegisteredUser | None, float | None]:
    """
    Resolve an access token to its user and expiry.

    Tokens starting with ``sqrl-`` are API keys: they are hashed with the secret key and matched
    against unexpired rows of the api_keys table, and the returned expiry is None. Any other
    token is decoded as an HS256 JWT, and its ``exp`` claim is returned as the expiry.

    Returns:
        ``(None, None)`` for an empty or missing token, otherwise ``(user, expiry)``.

    Raises:
        ConfigurationError: if the secret key is not configured.
        InvalidInputError: 401 if the token is invalid or expired, or its user no longer exists.
    """
    if not token:
        return None, None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    session = self.Session()
    try:
        if token.startswith("sqrl-"):
            matching_key = (
                session.query(self.DbApiKey)
                .filter(
                    self.DbApiKey.hashed_key == u.hash_string(token, salt=self.secret_key),
                    self.DbApiKey.expires_at >= func.now(),
                )
                .first()
            )
            if matching_key is None:
                raise InvalidTokenError()
            owner, expiry = matching_key.username, None
        else:
            claims: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            owner, expiry = claims["username"], claims.get("exp")

        record = session.get(self.DbUser, owner)
        if record is None:
            raise InvalidTokenError()

        resolved_user = self._convert_db_user_to_user(record)

    except InvalidTokenError:
        raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
    finally:
        session.close()

    return resolved_user, expiry
|
|
413
|
+
|
|
414
|
+
def _get_jwks_client(self, jwks_uri: str) -> PyJWKClient:
    """Return the PyJWKClient for ``jwks_uri``, creating it on first use and reusing it afterwards."""
    try:
        return self._jwks_clients[jwks_uri]
    except KeyError:
        client = PyJWKClient(jwks_uri)
        self._jwks_clients[jwks_uri] = client
        return client
|
|
418
|
+
|
|
419
|
+
def _get_issuer_from_token(self, token: str) -> str | None:
|
|
420
|
+
try:
|
|
421
|
+
# JWT format is header.payload.signature
|
|
422
|
+
parts = token.split('.')
|
|
423
|
+
if len(parts) != 3:
|
|
424
|
+
return None
|
|
425
|
+
|
|
426
|
+
# Base64url decode the payload (JWT uses base64url encoding)
|
|
427
|
+
payload_b64 = parts[1]
|
|
428
|
+
# Add padding if necessary: (-len % 4) yields 0..3
|
|
429
|
+
padding = '=' * ((-len(payload_b64)) % 4)
|
|
430
|
+
payload_json = base64.urlsafe_b64decode(payload_b64 + padding).decode("utf-8")
|
|
431
|
+
payload: dict = json.loads(payload_json)
|
|
432
|
+
|
|
433
|
+
issuer = payload.get("iss")
|
|
434
|
+
return issuer if isinstance(issuer, str) and issuer else None
|
|
435
|
+
except Exception:
|
|
436
|
+
return None
|
|
437
|
+
|
|
438
|
+
@staticmethod
|
|
439
|
+
def _is_jwt(token: str) -> bool:
|
|
440
|
+
if not isinstance(token, str) or not token:
|
|
441
|
+
return False
|
|
442
|
+
parts = token.split(".")
|
|
443
|
+
return len(parts) == 3 and all(parts)
|
|
444
|
+
|
|
445
|
+
def _get_provider_metadata(self, provider_name: str) -> dict:
|
|
446
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
447
|
+
if provider is None:
|
|
448
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
449
|
+
|
|
450
|
+
metadata_url = provider.provider_configs.server_metadata_url
|
|
451
|
+
cached = self._provider_metadata_cache.get(metadata_url)
|
|
452
|
+
if isinstance(cached, dict) and cached:
|
|
453
|
+
return cached
|
|
454
|
+
|
|
455
|
+
try:
|
|
456
|
+
response = requests.get(metadata_url, timeout=10)
|
|
457
|
+
response.raise_for_status()
|
|
458
|
+
metadata = response.json()
|
|
459
|
+
if not isinstance(metadata, dict):
|
|
460
|
+
raise ValueError("Provider metadata was not a JSON object")
|
|
461
|
+
except Exception as e:
|
|
462
|
+
raise ConfigurationError(f"Failed to fetch metadata for provider '{provider_name}': {str(e)}") from e
|
|
463
|
+
|
|
464
|
+
self._provider_metadata_cache[metadata_url] = metadata
|
|
465
|
+
return metadata
|
|
466
|
+
|
|
467
|
+
def _verify_provider_jwt(
    self,
    provider_name: str,
    token: str,
    *,
    purpose: str,
    expected_nonce: str | None = None,
    verify_aud: bool = True,
) -> dict | None:
    """
    Verify a provider-issued JWT and return its claims, or None if it cannot be verified.

    The signing key is looked up in the JWKS whose ``jwks_uri`` comes from the provider's
    discovery metadata. Signature and expiry are always checked; the audience is checked
    against the provider's ``client_id`` only when ``verify_aud`` is True. The issuer is compared
    manually, ignoring trailing slashes, when both the token and the metadata (or the provider's
    ``server_url``) supply one. The nonce is compared when ``expected_nonce`` is given.

    Args:
        provider_name: Name of the configured auth provider.
        token: The candidate JWT.
        purpose: Label used in messages (e.g. "id_token").
        expected_nonce: Nonce the token must carry, if any.
        verify_aud: Whether to enforce the audience claim.

    Returns:
        The decoded claims, or None when ``token`` is not a JWT, no signing key can be resolved
        for it, or it fails signature/expiry/audience validation.

    Raises:
        InvalidInputError: 404 for an unknown provider; 401 on issuer or nonce mismatch.
        ConfigurationError: if metadata cannot be fetched or lacks ``jwks_uri``.
    """
    if not self._is_jwt(token):
        return None

    provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    metadata = self._get_provider_metadata(provider_name)
    jwks_uri = metadata.get("jwks_uri")
    if not isinstance(jwks_uri, str) or not jwks_uri:
        raise ConfigurationError(f"jwks_uri not found in metadata for provider '{provider_name}'")

    # The key comes from the provider's JWKS (public keys), so only asymmetric algorithms apply.
    # "none" and HMAC ("HS*") algorithms are dropped even if the discovery document lists them,
    # as a guard against algorithm-confusion attacks.
    advertised_algs = metadata.get("id_token_signing_alg_values_supported")
    signing_algs = [
        alg for alg in (advertised_algs if isinstance(advertised_algs, list) else [])
        if isinstance(alg, str) and alg != "none" and not alg.startswith("HS")
    ]
    if not signing_algs:
        signing_algs = ["RS256"]

    try:
        jwks_client = self._get_jwks_client(jwks_uri)
        signing_key = jwks_client.get_signing_key_from_jwt(token)
    except jwt.PyJWTError as e:
        # Unknown "kid", malformed header or JWKS retrieval failure. Treat the token as
        # unverifiable, like a failed decode below, so callers can fall back to other sources.
        self.logger.info(f"Could not resolve signing key for {purpose} from provider '{provider_name}': {e}")
        return None

    decode_kwargs: dict[str, Any] = {
        "key": signing_key.key,
        "algorithms": signing_algs,
        "options": {
            "verify_aud": bool(verify_aud),
            # We'll validate issuer manually to avoid brittle trailing-slash mismatches.
            "verify_iss": False,
        },
    }
    if verify_aud:
        decode_kwargs["audience"] = provider.provider_configs.client_id

    try:
        payload = jwt.decode(token, **decode_kwargs)
    except Exception:
        return None

    if not isinstance(payload, dict):
        return None

    expected_issuer = metadata.get("issuer") or provider.provider_configs.server_url
    token_issuer = payload.get("iss")
    if isinstance(expected_issuer, str) and expected_issuer and isinstance(token_issuer, str) and token_issuer:
        if token_issuer.rstrip("/") != expected_issuer.rstrip("/"):
            raise InvalidInputError(
                401,
                "invalid_provider_token",
                f"Invalid {purpose} issuer for provider '{provider_name}'",
            )

    if expected_nonce is not None:
        nonce_claim = payload.get("nonce")
        if nonce_claim != expected_nonce:
            raise InvalidInputError(
                401,
                "invalid_provider_token",
                f"Invalid {purpose} nonce for provider '{provider_name}'",
            )

    return payload
|
|
539
|
+
|
|
540
|
+
def get_user_info_from_token_details(self, provider_name: str, token_details: dict, *, expected_nonce: str | None = None) -> dict:
|
|
541
|
+
"""
|
|
542
|
+
Determine user_info from an OAuth/OIDC token response.
|
|
543
|
+
|
|
544
|
+
Priority:
|
|
545
|
+
- token_details["user_info"] / token_details["userinfo"]
|
|
546
|
+
- verify + decode token_details["id_token"] if it's a JWT
|
|
547
|
+
- verify + decode token_details["access_token"] if it's a JWT
|
|
548
|
+
- for opaque access_token: call userinfo or introspection endpoint from provider metadata
|
|
549
|
+
"""
|
|
550
|
+
for key in ("user_info", "userinfo"):
|
|
551
|
+
user_info = token_details.get(key)
|
|
552
|
+
if isinstance(user_info, dict) and user_info:
|
|
553
|
+
return user_info
|
|
554
|
+
|
|
555
|
+
id_token = token_details.get("id_token")
|
|
556
|
+
if isinstance(id_token, str):
|
|
557
|
+
if payload := self._verify_provider_jwt(
|
|
558
|
+
provider_name, id_token, purpose="id_token", expected_nonce=expected_nonce, verify_aud=True
|
|
559
|
+
):
|
|
560
|
+
return payload
|
|
561
|
+
|
|
562
|
+
access_token = token_details.get("access_token")
|
|
563
|
+
if isinstance(access_token, str):
|
|
564
|
+
# Some providers issue JWT access tokens. Audience can vary (resource server),
|
|
565
|
+
# so we verify signature/exp and issuer, but skip aud validation.
|
|
566
|
+
if payload := self._verify_provider_jwt(
|
|
567
|
+
provider_name, access_token, purpose="access_token", expected_nonce=None, verify_aud=False
|
|
568
|
+
):
|
|
569
|
+
return payload
|
|
570
|
+
|
|
571
|
+
if not isinstance(access_token, str) or not access_token:
|
|
572
|
+
raise InvalidInputError(
|
|
573
|
+
400,
|
|
574
|
+
"invalid_provider_user_info",
|
|
575
|
+
f"User information not found in token details for {provider_name}",
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
579
|
+
if provider is None:
|
|
580
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
581
|
+
|
|
582
|
+
metadata: dict = self._get_provider_metadata(provider_name)
|
|
583
|
+
|
|
584
|
+
userinfo_endpoint = metadata.get("userinfo_endpoint")
|
|
585
|
+
if isinstance(userinfo_endpoint, str) and userinfo_endpoint:
|
|
586
|
+
try:
|
|
587
|
+
response = requests.get(
|
|
588
|
+
userinfo_endpoint,
|
|
589
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
590
|
+
timeout=10,
|
|
591
|
+
)
|
|
592
|
+
response.raise_for_status()
|
|
593
|
+
user_info = response.json()
|
|
594
|
+
if isinstance(user_info, dict) and user_info:
|
|
595
|
+
return user_info
|
|
596
|
+
except Exception as e:
|
|
597
|
+
self.logger.warning(f"Failed to fetch user info from {userinfo_endpoint} for provider {provider_name}: {str(e)}")
|
|
598
|
+
# Fall back to introspection if available
|
|
599
|
+
pass
|
|
600
|
+
|
|
601
|
+
introspection_endpoint = metadata.get("introspection_endpoint")
|
|
602
|
+
if isinstance(introspection_endpoint, str) and introspection_endpoint:
|
|
603
|
+
try:
|
|
604
|
+
response = requests.post(
|
|
605
|
+
introspection_endpoint,
|
|
606
|
+
data={"token": access_token},
|
|
607
|
+
auth=(provider.provider_configs.client_id, provider.provider_configs.client_secret),
|
|
608
|
+
timeout=10,
|
|
609
|
+
)
|
|
610
|
+
response.raise_for_status()
|
|
611
|
+
token_info = response.json()
|
|
612
|
+
if isinstance(token_info, dict):
|
|
613
|
+
if token_info.get("active") is False:
|
|
614
|
+
raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
|
|
615
|
+
return token_info
|
|
616
|
+
except InvalidInputError:
|
|
617
|
+
raise
|
|
618
|
+
except Exception as e:
|
|
619
|
+
self.logger.warning(f"Introspection request failed for provider {provider_name} using basic auth: {str(e)}")
|
|
620
|
+
# Some providers require "client_secret_post" instead of basic auth for introspection.
|
|
621
|
+
try:
|
|
622
|
+
response = requests.post(
|
|
623
|
+
introspection_endpoint,
|
|
624
|
+
data={
|
|
625
|
+
"token": access_token,
|
|
626
|
+
"client_id": provider.provider_configs.client_id,
|
|
627
|
+
"client_secret": provider.provider_configs.client_secret,
|
|
628
|
+
},
|
|
629
|
+
timeout=10,
|
|
630
|
+
)
|
|
631
|
+
response.raise_for_status()
|
|
632
|
+
token_info = response.json()
|
|
633
|
+
if isinstance(token_info, dict):
|
|
634
|
+
if token_info.get("active") is False:
|
|
635
|
+
raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
|
|
636
|
+
return token_info
|
|
637
|
+
except InvalidInputError:
|
|
638
|
+
raise
|
|
639
|
+
except Exception as e2:
|
|
640
|
+
self.logger.warning(f"Introspection request failed for provider {provider_name} using post data: {str(e2)}")
|
|
641
|
+
pass
|
|
642
|
+
|
|
643
|
+
raise InvalidInputError(
|
|
644
|
+
400,
|
|
645
|
+
"invalid_provider_user_info",
|
|
646
|
+
f"User information not found in token details for {provider_name}",
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
def get_user_from_external_token(self, token: str, provider_name: str | None = None) -> tuple[RegisteredUser, float | None]:
    """
    Get a user from an external OAuth/OIDC token issued by one of the configured auth providers.

    The provider is resolved as follows:
    - by ``provider_name`` when it is given;
    - otherwise, for a JWT, by matching the token's issuer against each provider's ``server_url``
      (trailing slashes ignored);
    - otherwise (opaque token), the single configured provider, if exactly one exists.

    JWTs are verified via ``_verify_provider_jwt`` (signature/exp and issuer; audience is NOT checked).
    Opaque tokens are resolved via ``get_user_info_from_token_details`` (userinfo, then introspection).

    Arguments:
        token: The external authorization token (JWT or opaque string).
        provider_name: Optional name of the provider the token must belong to.

    Returns:
        A tuple of (user built by the provider's ``get_user``, the payload's numeric ``exp`` as a float,
        or None when there is no numeric ``exp``).

    Raises:
        InvalidInputError: 404 if ``provider_name`` is unknown; 401 if no provider can be determined,
            the token is invalid, or introspection reports it inactive.
        ConfigurationError: may propagate from ``_verify_provider_jwt`` (e.g. missing jwks_uri).
    """
    issuer: str | None = None
    token_is_jwt = self._is_jwt(token)

    if provider_name:
        provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    elif token_is_jwt:
        issuer = self._get_issuer_from_token(token)
        if not issuer:
            raise InvalidInputError(401, "invalid_external_token", "Could not extract issuer from token")

        # Match provider by issuer (server_url), tolerating trailing-slash differences.
        normalized_issuer = issuer.rstrip("/")
        provider = next(
            (p for p in self.auth_providers if p.provider_configs.server_url.rstrip("/") == normalized_issuer),
            None,
        )
    else:
        # Opaque external token: if there's exactly one configured provider, assume it.
        provider = self.auth_providers[0] if len(self.auth_providers) == 1 else None

    if provider is None:
        if provider_name:
            raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
        if token_is_jwt:
            raise InvalidInputError(401, "auth_provider_not_found", f"No provider found for issuer: {issuer}")
        raise InvalidInputError(401, "invalid_external_token", "Could not determine provider for external token")

    if token_is_jwt:
        # JWT external token: validate signature/exp (+ issuer) via provider JWKS.
        # NOTE(review): verify_aud=False means any token signed by this provider with a matching issuer is
        # accepted, regardless of which client it was issued to. Confirm this is acceptable for the deployment.
        try:
            payload = self._verify_provider_jwt(provider.name, token, purpose="external_token", verify_aud=False)
        except InvalidInputError:
            # Keep the external-auth contract stable (avoid leaking provider-specific details).
            raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token") from None
    else:
        # Opaque token: reuse the existing provider userinfo/introspection logic.
        try:
            payload = self.get_user_info_from_token_details(provider.name, {"access_token": token})
        except InvalidInputError as e:
            # Normalize into the external-auth error contract, but keep the "inactive" signal.
            if getattr(e, "error", None) == "inactive_external_token":
                raise
            raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token") from None

    # _verify_provider_jwt returns None when decoding fails; both paths require a non-empty dict.
    if not isinstance(payload, dict) or not payload:
        raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")

    user = provider.provider_configs.get_user(payload)
    exp = payload.get("exp")
    expiry: float | None = float(exp) if isinstance(exp, (int, float)) else None
    return user, expiry
|
|
711
|
+
|
|
712
|
+
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    Return the API keys of ``username`` whose ``expires_at`` is not earlier than the database's current time.

    Each entry carries the key's ID, title, and expiry date. The ID is a hash of the API key, not the API key itself.
    """
    db_session = self.Session()
    try:
        key_model = self.DbApiKey
        owned_by_user = key_model.username == username
        not_yet_expired = key_model.expires_at >= func.now()
        matching_rows = db_session.query(key_model).filter(owned_by_user, not_yet_expired).all()
        return list(map(ApiKey.model_validate, matching_rows))
    finally:
        # Always release the session, even if the query or validation raises.
        db_session.close()
|
|
726
|
+
|
|
727
|
+
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete the API key with ID ``api_key_id``, provided it belongs to ``username``.

    Raises:
        InvalidInputError: 404 when no key with that ID exists for the user.
    """
    db_session = self.Session()
    try:
        key_model = self.DbApiKey
        matched_key = (
            db_session.query(key_model)
            .filter(key_model.username == username, key_model.id == api_key_id)
            .first()
        )
        if matched_key is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")

        db_session.delete(matched_key)
        db_session.commit()
    finally:
        db_session.close()
|
|
746
|
+
|
|
747
|
+
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
    """
    Return True when the user's access level grants at least the requested ``scope``.

    "guest" maps to PUBLIC, "admin" to PRIVATE, and any other level (members) to PROTECTED.
    Scopes are compared by their ``value``.
    """
    level = user.access_level
    granted = (
        PermissionScope.PUBLIC if level == "guest"
        else PermissionScope.PRIVATE if level == "admin"
        else PermissionScope.PROTECTED  # member (or any unrecognized level)
    )
    return scope.value <= granted.value
|
|
756
|
+
|
|
757
|
+
def close(self) -> None:
    """Dispose of ``self.engine`` if this instance has that attribute; otherwise do nothing."""
    if not hasattr(self, "engine"):
        return
    self.engine.dispose()
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider

    The decorated function receives an AuthProviderArgs and returns ProviderConfigs. It is replaced by
    a function that builds an AuthProvider from those configs. That replacement is appended to
    ``Authenticator.providers`` and returned.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')
    """
    def register(configure: Callable[[AuthProviderArgs], ProviderConfigs]):
        def wrapper(sqrl: AuthProviderArgs):
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=configure(sqrl))

        Authenticator.providers.append(wrapper)
        return wrapper

    return register
|