squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- squirrels/__init__.py +4 -0
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +440 -792
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
- squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
- squirrels/_auth.py +590 -264
- squirrels/_command_line.py +130 -58
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +16 -15
- squirrels/_constants.py +36 -11
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +40 -34
- squirrels/_dataset_types.py +16 -11
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +7 -6
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +155 -77
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +11 -55
- squirrels/_model_configs.py +5 -5
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +276 -143
- squirrels/_package_data/base_project/.env +1 -24
- squirrels/_package_data/base_project/.env.example +31 -17
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
- squirrels/_package_data/base_project/docker/Dockerfile +2 -2
- squirrels/_package_data/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
- squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
- squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
- squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +41 -30
- squirrels/_parameters.py +560 -123
- squirrels/_project.py +487 -277
- squirrels/_py_module.py +71 -10
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +52 -13
- squirrels/_sources.py +29 -23
- squirrels/_utils.py +221 -42
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -2
- squirrels/auth.py +4 -0
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +10 -3
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
- squirrels/_api_response_models.py +0 -190
- squirrels/_dashboard_types.py +0 -82
- squirrels/_dashboards_io.py +0 -79
- squirrels-0.5.0b3.dist-info/METADATA +0 -110
- squirrels-0.5.0b3.dist-info/RECORD +0 -80
- /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
- /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py
CHANGED
|
@@ -1,231 +1,165 @@
|
|
|
1
|
+
from typing import Callable, Any
|
|
1
2
|
from datetime import datetime, timedelta, timezone
|
|
2
|
-
from enum import Enum
|
|
3
3
|
from functools import cached_property
|
|
4
4
|
from jwt.exceptions import InvalidTokenError
|
|
5
5
|
from passlib.context import CryptContext
|
|
6
|
-
from pydantic import
|
|
7
|
-
from pydantic_core import PydanticUndefined
|
|
6
|
+
from pydantic import ValidationError
|
|
8
7
|
from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
|
|
9
|
-
from sqlalchemy import Column, String, Integer, Float, Boolean
|
|
10
8
|
from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
|
|
11
|
-
import jwt,
|
|
9
|
+
import jwt, uuid, secrets, json, base64, requests
|
|
10
|
+
from jwt import PyJWKClient
|
|
12
11
|
|
|
12
|
+
from ._env_vars import SquirrelsEnvVars
|
|
13
13
|
from ._manifest import PermissionScope
|
|
14
|
-
from ._py_module import PyModule
|
|
15
14
|
from ._exceptions import InvalidInputError, ConfigurationError
|
|
15
|
+
from ._arguments.init_time_args import AuthProviderArgs
|
|
16
|
+
from ._schemas.auth_models import (
|
|
17
|
+
CustomUserFields, AbstractUser, RegisteredUser, ApiKey, UserField, UserFieldsModel, AuthProvider, ProviderConfigs
|
|
18
|
+
)
|
|
19
|
+
from ._schemas import response_models as rm
|
|
16
20
|
from . import _utils as u, _constants as c
|
|
17
21
|
|
|
18
22
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
19
23
|
|
|
20
|
-
|
|
21
|
-
disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
|
|
24
|
+
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
|
|
22
25
|
|
|
23
|
-
class BaseUser(BaseModel):
|
|
24
|
-
model_config = ConfigDict(from_attributes=True)
|
|
25
|
-
username: str
|
|
26
|
-
is_admin: bool = False
|
|
27
|
-
|
|
28
|
-
@classmethod
|
|
29
|
-
def dropped_columns(cls):
|
|
30
|
-
return []
|
|
31
|
-
|
|
32
|
-
def __hash__(self):
|
|
33
|
-
return hash(self.username)
|
|
34
|
-
|
|
35
|
-
User = _t.TypeVar('User', bound=BaseUser)
|
|
36
|
-
|
|
37
|
-
class AccessToken(BaseModel):
|
|
38
|
-
model_config = ConfigDict(from_attributes=True)
|
|
39
|
-
token_id: str
|
|
40
|
-
title: str
|
|
41
|
-
username: str
|
|
42
|
-
created_at: datetime
|
|
43
|
-
expires_at: datetime
|
|
44
26
|
|
|
27
|
+
class Authenticator:
|
|
28
|
+
providers: list[ProviderFunctionType] = [] # static variable to stage providers
|
|
45
29
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class Authenticator(_t.Generic[User]):
|
|
55
|
-
def __init__(self, logger: u.Logger, base_path: str, env_vars: dict[str, str], *, sa_engine: Engine | None = None, cls: type[User] | None = None):
|
|
30
|
+
def __init__(
|
|
31
|
+
self, logger: u.Logger, env_vars: SquirrelsEnvVars, auth_args: AuthProviderArgs,
|
|
32
|
+
provider_functions: list[ProviderFunctionType], custom_user_fields_cls: type[CustomUserFields],
|
|
33
|
+
*,
|
|
34
|
+
sa_engine: Engine | None = None, external_only: bool = False
|
|
35
|
+
):
|
|
56
36
|
self.logger = logger
|
|
57
37
|
self.env_vars = env_vars
|
|
58
|
-
self.secret_key =
|
|
38
|
+
self.secret_key = env_vars.secret_key
|
|
39
|
+
self.external_only = external_only
|
|
40
|
+
self.password_requirements = rm.PasswordRequirements()
|
|
59
41
|
|
|
60
42
|
# Create a new declarative base for this instance
|
|
61
43
|
self.Base = declarative_base()
|
|
62
44
|
|
|
63
|
-
# Define
|
|
64
|
-
class
|
|
45
|
+
# Define DbUser class for this instance
|
|
46
|
+
class DbUser(self.Base):
|
|
65
47
|
__tablename__ = 'users'
|
|
66
48
|
__table_args__ = {'extend_existing': True}
|
|
67
49
|
username: Mapped[str] = mapped_column(primary_key=True)
|
|
68
|
-
|
|
50
|
+
access_level: Mapped[str] = mapped_column(nullable=False, default="member")
|
|
69
51
|
password_hash: Mapped[str] = mapped_column(nullable=False)
|
|
52
|
+
custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
|
|
70
53
|
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
71
54
|
|
|
72
|
-
# Define
|
|
73
|
-
class
|
|
74
|
-
__tablename__ = '
|
|
55
|
+
# Define DbApiKey class for this instance
|
|
56
|
+
class DbApiKey(self.Base):
|
|
57
|
+
__tablename__ = 'api_keys'
|
|
75
58
|
|
|
76
|
-
|
|
59
|
+
id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
60
|
+
hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
|
|
61
|
+
last_four: Mapped[str] = mapped_column(nullable=False)
|
|
77
62
|
title: Mapped[str] = mapped_column(nullable=False)
|
|
78
63
|
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
79
64
|
created_at: Mapped[datetime] = mapped_column(nullable=False)
|
|
80
65
|
expires_at: Mapped[datetime] = mapped_column(nullable=False)
|
|
81
|
-
|
|
82
|
-
def __repr__(self):
|
|
83
|
-
return f"<AccessToken(token_id='{self.token_id}', username='{self.username}')>"
|
|
84
66
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
self.User = self._get_user_model(base_path) if cls is None else cls
|
|
89
|
-
self.DbUser: type[DbBaseUser] = self._initialize_db_user_model(self.User)
|
|
67
|
+
def __repr__(self):
|
|
68
|
+
return f"<DbApiKey(id='{self.id}', username='{self.username}')>"
|
|
90
69
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
|
|
96
|
-
else:
|
|
97
|
-
self.engine = sa_engine
|
|
70
|
+
self.CustomUserFields = custom_user_fields_cls
|
|
71
|
+
self.DbUser = DbUser
|
|
72
|
+
|
|
73
|
+
self.DbApiKey = DbApiKey
|
|
98
74
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
conn.execute(text("PRAGMA synchronous = NORMAL"))
|
|
103
|
-
conn.commit()
|
|
75
|
+
self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
|
|
76
|
+
self._jwks_clients: dict[str, PyJWKClient] = {}
|
|
77
|
+
self._provider_metadata_cache: dict[str, dict] = {}
|
|
104
78
|
|
|
105
|
-
self.
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def _get_user_model(self, base_path: str) -> type[BaseUser]:
|
|
112
|
-
user_module_path = u.Path(base_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
|
|
113
|
-
user_module = PyModule(user_module_path)
|
|
114
|
-
User = user_module.get_func_or_class("User", default_attr=BaseUser)
|
|
115
|
-
if not issubclass(User, BaseUser):
|
|
116
|
-
raise ConfigurationError(f"User class in '{c.USER_FILE}' must inherit from BaseUser")
|
|
117
|
-
return User
|
|
118
|
-
|
|
119
|
-
def _initialize_db_user_model(self, *args) -> type:
|
|
120
|
-
"""Get the user model with any custom attributes defined in user.py"""
|
|
121
|
-
attrs = {}
|
|
122
|
-
|
|
123
|
-
# Iterate over all fields in the User model
|
|
124
|
-
for field_name, field in self.User.model_fields.items():
|
|
125
|
-
if field_name in reserved_fields:
|
|
126
|
-
continue
|
|
127
|
-
if field_name in disallowed_fields:
|
|
128
|
-
raise ConfigurationError(f"Field name '{field_name}' is disallowed in the User model and cannot be used")
|
|
129
|
-
|
|
130
|
-
field_type = field.annotation
|
|
131
|
-
if _t.get_origin(field_type) in (_t.Union, types.UnionType):
|
|
132
|
-
field_type = _t.get_args(field_type)[0]
|
|
133
|
-
nullable = True
|
|
79
|
+
if not self.external_only:
|
|
80
|
+
if sa_engine is None:
|
|
81
|
+
raw_sqlite_path = self.env_vars.auth_db_file_path
|
|
82
|
+
sqlite_path = u.Path(raw_sqlite_path.format(project_path=self.env_vars.project_path))
|
|
83
|
+
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
|
84
|
+
self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
|
|
134
85
|
else:
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
if _t.get_origin(field_type) == _t.Literal:
|
|
138
|
-
field_type = str
|
|
139
|
-
|
|
140
|
-
# Map Python types and default values to SQLAlchemy columns
|
|
141
|
-
default_value = field.default
|
|
142
|
-
if default_value is PydanticUndefined:
|
|
143
|
-
raise ConfigurationError(f"No default value found for field '{field_name}' in User model")
|
|
144
|
-
elif not nullable and default_value is None:
|
|
145
|
-
raise ConfigurationError(f"Default value for non-nullable field '{field_name}' was set as None in User model")
|
|
146
|
-
elif default_value is not None and type(default_value) is not field_type:
|
|
147
|
-
raise ConfigurationError(f"Default value for field '{field_name}' does not match field type in User model")
|
|
86
|
+
self.engine = sa_engine
|
|
148
87
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
col_type = Boolean
|
|
157
|
-
elif isinstance(field_type, type) and issubclass(field_type, Enum):
|
|
158
|
-
col_type = String
|
|
159
|
-
default_value = default_value.value
|
|
160
|
-
else:
|
|
161
|
-
continue
|
|
162
|
-
|
|
163
|
-
attrs[field_name] = Column(col_type, nullable=nullable, default=default_value) # type: ignore
|
|
88
|
+
# Configure SQLite pragmas
|
|
89
|
+
with self.engine.connect() as conn:
|
|
90
|
+
conn.execute(text("PRAGMA journal_mode = WAL"))
|
|
91
|
+
conn.execute(text("PRAGMA synchronous = NORMAL"))
|
|
92
|
+
conn.commit()
|
|
93
|
+
|
|
94
|
+
self.Base.metadata.create_all(self.engine)
|
|
164
95
|
|
|
165
|
-
|
|
166
|
-
DbUser = type('DbUser', (self.DbBaseUser,), attrs)
|
|
167
|
-
return DbUser
|
|
96
|
+
self.Session = sessionmaker(bind=self.engine)
|
|
168
97
|
|
|
169
|
-
|
|
98
|
+
self._initialize_db()
|
|
99
|
+
|
|
100
|
+
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
|
|
101
|
+
"""Convert a database user to an AbstractUser object"""
|
|
102
|
+
# Deserialize custom_fields JSON and merge with defaults
|
|
103
|
+
custom_fields_json = json.loads(db_user.custom_fields) if db_user.custom_fields else {}
|
|
104
|
+
custom_fields = self.CustomUserFields(**custom_fields_json)
|
|
105
|
+
|
|
106
|
+
return RegisteredUser(
|
|
107
|
+
username=db_user.username,
|
|
108
|
+
access_level=db_user.access_level,
|
|
109
|
+
custom_fields=custom_fields
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
def _validate_password_length(self, password: str) -> None:
|
|
113
|
+
"""Validate that password meets length requirements and does not exceed 72 characters (bcrypt limit)"""
|
|
114
|
+
min_len = self.password_requirements.min_length
|
|
115
|
+
max_len = min(self.password_requirements.max_length, 72)
|
|
116
|
+
if len(password) < min_len:
|
|
117
|
+
raise InvalidInputError(400, "password_too_short", f"Password must be at least {min_len} characters long")
|
|
118
|
+
if len(password) > max_len:
|
|
119
|
+
raise InvalidInputError(400, "password_too_long", f"Password cannot exceed {max_len} characters")
|
|
120
|
+
|
|
121
|
+
def _initialize_db(self):
|
|
170
122
|
session = self.Session()
|
|
171
123
|
try:
|
|
172
124
|
# Get existing columns in the database
|
|
173
125
|
inspector = inspect(self.engine)
|
|
174
|
-
existing_columns = {col['name'] for col in inspector.get_columns('users')}
|
|
175
|
-
|
|
176
|
-
# Get all columns defined in the model
|
|
177
|
-
dropped_columns = set(self.User.dropped_columns())
|
|
178
|
-
model_columns = set(self.DbUser.__table__.columns.keys()) - dropped_columns
|
|
179
|
-
|
|
180
|
-
# Find columns that are in the model but not in the database
|
|
181
|
-
new_columns = model_columns - existing_columns
|
|
182
|
-
if new_columns:
|
|
183
|
-
add_columns_msg = f"Adding columns to database: {new_columns}"
|
|
184
|
-
print("NOTE:", add_columns_msg)
|
|
185
|
-
self.logger.info(add_columns_msg)
|
|
186
|
-
|
|
187
|
-
for col_name in new_columns:
|
|
188
|
-
col = self.DbUser.__table__.columns[col_name]
|
|
189
|
-
column_type = col.type.compile(self.engine.dialect)
|
|
190
|
-
nullable = "NULL" if col.nullable else "NOT NULL"
|
|
191
|
-
if col.default is not None:
|
|
192
|
-
default_val = f"'{col.default.arg}'" if isinstance(col.default.arg, str) else col.default.arg
|
|
193
|
-
default = f"DEFAULT {default_val}"
|
|
194
|
-
else:
|
|
195
|
-
default = ""
|
|
196
|
-
|
|
197
|
-
alter_stmt = f"ALTER TABLE users ADD COLUMN {col_name} {column_type} {nullable} {default}"
|
|
198
|
-
session.execute(text(alter_stmt))
|
|
199
|
-
|
|
200
|
-
session.commit()
|
|
201
126
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
self.logger.info(drop_columns_msg)
|
|
208
|
-
|
|
209
|
-
for col_name in columns_to_drop:
|
|
210
|
-
session.execute(text(f"ALTER TABLE users DROP COLUMN {col_name}"))
|
|
127
|
+
for db_model in [self.DbUser, self.DbApiKey]:
|
|
128
|
+
table_name = db_model.__tablename__
|
|
129
|
+
existing_columns = {col['name'] for col in inspector.get_columns(table_name)}
|
|
130
|
+
model_columns = set(db_model.__table__.columns.keys())
|
|
131
|
+
new_columns = model_columns - existing_columns
|
|
211
132
|
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
133
|
+
if new_columns:
|
|
134
|
+
add_columns_msg = f"Adding columns to table {table_name}: {new_columns}"
|
|
135
|
+
self.logger.info(add_columns_msg)
|
|
136
|
+
|
|
137
|
+
for col_name in new_columns:
|
|
138
|
+
col = db_model.__table__.columns[col_name]
|
|
139
|
+
column_type = col.type.compile(self.engine.dialect)
|
|
140
|
+
nullable = "NULL" if col.nullable else "NOT NULL"
|
|
141
|
+
if col.default is not None and not callable(col.default.arg):
|
|
142
|
+
default_val = f"'{col.default.arg}'" if isinstance(col.default.arg, str) else col.default.arg
|
|
143
|
+
default = f"DEFAULT {default_val}"
|
|
144
|
+
else:
|
|
145
|
+
# If nullable is False and no default is provided, use an empty string as a placeholder default for SQLite
|
|
146
|
+
# TODO: Use a different default value (instead of empty string) based on the column type
|
|
147
|
+
default = "DEFAULT ''" if not col.nullable else ""
|
|
148
|
+
|
|
149
|
+
alter_stmt = f"ALTER TABLE {table_name} ADD COLUMN {col_name} {column_type} {nullable} {default}"
|
|
150
|
+
session.execute(text(alter_stmt))
|
|
151
|
+
|
|
152
|
+
session.commit()
|
|
219
153
|
|
|
220
154
|
# Get admin password from environment variable if exists
|
|
221
|
-
admin_password = self.env_vars.
|
|
155
|
+
admin_password = self.env_vars.secret_admin_password
|
|
222
156
|
|
|
223
|
-
# If admin password variable exists, find username "admin". If it does not exist, add it
|
|
224
157
|
if admin_password is not None:
|
|
158
|
+
self._validate_password_length(admin_password)
|
|
225
159
|
password_hash = pwd_context.hash(admin_password)
|
|
226
160
|
admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
|
|
227
161
|
if admin_user is None:
|
|
228
|
-
admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash,
|
|
162
|
+
admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
|
|
229
163
|
session.add(admin_user)
|
|
230
164
|
else:
|
|
231
165
|
admin_user.password_hash = password_hash
|
|
@@ -236,9 +170,9 @@ class Authenticator(_t.Generic[User]):
|
|
|
236
170
|
session.close()
|
|
237
171
|
|
|
238
172
|
@cached_property
|
|
239
|
-
def user_fields(self) ->
|
|
173
|
+
def user_fields(self) -> UserFieldsModel:
|
|
240
174
|
"""
|
|
241
|
-
Get the fields of the
|
|
175
|
+
Get the fields of the CustomUserFields model as a list of dictionaries
|
|
242
176
|
|
|
243
177
|
Each dictionary contains the following keys:
|
|
244
178
|
- name: The name of the field
|
|
@@ -247,11 +181,10 @@ class Authenticator(_t.Generic[User]):
|
|
|
247
181
|
- enum: The possible values of the field (or None if not applicable)
|
|
248
182
|
- default: The default value of the field (or None if field is required)
|
|
249
183
|
"""
|
|
250
|
-
schema = self.User.model_json_schema()
|
|
251
|
-
|
|
252
|
-
fields = []
|
|
253
184
|
|
|
254
|
-
|
|
185
|
+
custom_fields = []
|
|
186
|
+
schema = self.CustomUserFields.model_json_schema()
|
|
187
|
+
properties: dict[str, dict[str, Any]] = schema.get("properties", {})
|
|
255
188
|
for field_name, field_schema in properties.items():
|
|
256
189
|
if choices := field_schema.get("anyOf"):
|
|
257
190
|
field_type = choices[0]["type"]
|
|
@@ -260,62 +193,112 @@ class Authenticator(_t.Generic[User]):
|
|
|
260
193
|
field_type = field_schema["type"]
|
|
261
194
|
nullable = False
|
|
262
195
|
|
|
263
|
-
field_data = UserField(name=field_name, type=field_type, nullable=nullable, enum=field_schema.get("enum"), default=field_schema.get("default"))
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
return
|
|
196
|
+
field_data = UserField(name=field_name, label=field_schema.get("title", field_name), type=field_type, nullable=nullable, enum=field_schema.get("enum"), default=field_schema.get("default"))
|
|
197
|
+
custom_fields.append(field_data)
|
|
198
|
+
|
|
199
|
+
return UserFieldsModel(
|
|
200
|
+
username=UserField(name="username", label="Username / Email", type="string", nullable=False, enum=None, default=None),
|
|
201
|
+
access_level=UserField(name="access_level", label="Access Level", type="string", nullable=False, enum=["admin", "member"], default="member"),
|
|
202
|
+
custom_fields=custom_fields
|
|
203
|
+
)
|
|
267
204
|
|
|
268
|
-
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) ->
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
205
|
+
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> RegisteredUser:
|
|
206
|
+
# Separate custom fields from base fields
|
|
207
|
+
access_level = user_fields.get('access_level', 'member')
|
|
208
|
+
password = user_fields.get('password')
|
|
209
|
+
|
|
210
|
+
# Validate access level - cannot add/update users with guest access level
|
|
211
|
+
if access_level == "guest":
|
|
212
|
+
raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")
|
|
213
|
+
|
|
214
|
+
# Extract custom fields
|
|
215
|
+
custom_fields_data: dict[str, Any] = user_fields.get('custom_fields', {})
|
|
216
|
+
|
|
217
|
+
# Validate the custom fields
|
|
272
218
|
try:
|
|
273
|
-
|
|
219
|
+
custom_fields = self.CustomUserFields(**custom_fields_data)
|
|
220
|
+
custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
|
|
274
221
|
except ValidationError as e:
|
|
275
|
-
raise InvalidInputError(
|
|
222
|
+
raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
|
|
276
223
|
|
|
277
|
-
# Add
|
|
224
|
+
# Add or update user
|
|
225
|
+
session = self.Session()
|
|
278
226
|
try:
|
|
279
227
|
# Check if the user already exists
|
|
280
|
-
|
|
281
|
-
if
|
|
228
|
+
db_user = session.get(self.DbUser, username)
|
|
229
|
+
if db_user is not None:
|
|
282
230
|
if not update_user:
|
|
283
|
-
raise InvalidInputError(
|
|
231
|
+
raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")
|
|
232
|
+
|
|
233
|
+
if username == c.ADMIN_USERNAME and access_level != "admin":
|
|
234
|
+
raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")
|
|
284
235
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
session.delete(existing_user)
|
|
236
|
+
# Update existing user
|
|
237
|
+
db_user.access_level = access_level
|
|
238
|
+
db_user.custom_fields = custom_fields_json
|
|
289
239
|
else:
|
|
290
240
|
if update_user:
|
|
291
|
-
raise InvalidInputError(
|
|
241
|
+
raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
|
|
292
242
|
|
|
293
|
-
password = user_fields.get('password')
|
|
294
243
|
if password is None:
|
|
295
|
-
raise InvalidInputError(
|
|
244
|
+
raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")
|
|
245
|
+
|
|
246
|
+
self._validate_password_length(password)
|
|
296
247
|
password_hash = pwd_context.hash(password)
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
248
|
+
db_user = self.DbUser(
|
|
249
|
+
username=username,
|
|
250
|
+
access_level=access_level,
|
|
251
|
+
password_hash=password_hash,
|
|
252
|
+
custom_fields=custom_fields_json
|
|
253
|
+
)
|
|
254
|
+
session.add(db_user)
|
|
301
255
|
|
|
302
256
|
# Commit the transaction
|
|
303
257
|
session.commit()
|
|
258
|
+
return self._convert_db_user_to_user(db_user)
|
|
304
259
|
|
|
305
260
|
finally:
|
|
306
261
|
session.close()
|
|
262
|
+
|
|
263
|
+
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
|
|
264
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
265
|
+
if provider is None:
|
|
266
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
267
|
+
|
|
268
|
+
user = provider.provider_configs.get_user(user_info)
|
|
269
|
+
session = self.Session()
|
|
270
|
+
try:
|
|
271
|
+
# Convert user to database user
|
|
272
|
+
custom_fields_json = user.custom_fields.model_dump_json()
|
|
273
|
+
db_user = self.DbUser(
|
|
274
|
+
username=user.username,
|
|
275
|
+
access_level=user.access_level,
|
|
276
|
+
password_hash="", # By omitting password_hash, it becomes impossible to login with username and password (OAuth only)
|
|
277
|
+
custom_fields=custom_fields_json
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
existing_db_user = session.get(self.DbUser, db_user.username)
|
|
281
|
+
if existing_db_user is None:
|
|
282
|
+
session.add(db_user)
|
|
283
|
+
|
|
284
|
+
session.commit()
|
|
307
285
|
|
|
308
|
-
|
|
286
|
+
return self._convert_db_user_to_user(db_user)
|
|
287
|
+
|
|
288
|
+
finally:
|
|
289
|
+
session.close()
|
|
290
|
+
|
|
291
|
+
def get_user(self, username: str, password: str) -> RegisteredUser:
|
|
309
292
|
session = self.Session()
|
|
310
293
|
try:
|
|
311
294
|
# Query for user by username
|
|
312
295
|
db_user = session.get(self.DbUser, username)
|
|
313
296
|
|
|
314
297
|
if db_user and pwd_context.verify(password, db_user.password_hash):
|
|
315
|
-
user = self.
|
|
316
|
-
return user
|
|
298
|
+
user = self._convert_db_user_to_user(db_user)
|
|
299
|
+
return user
|
|
317
300
|
else:
|
|
318
|
-
raise InvalidInputError(
|
|
301
|
+
raise InvalidInputError(401, "incorrect_username_or_password", f"Incorrect username or password")
|
|
319
302
|
|
|
320
303
|
finally:
|
|
321
304
|
session.close()
|
|
@@ -325,127 +308,470 @@ class Authenticator(_t.Generic[User]):
|
|
|
325
308
|
try:
|
|
326
309
|
db_user = session.get(self.DbUser, username)
|
|
327
310
|
if db_user is None:
|
|
328
|
-
raise InvalidInputError(
|
|
311
|
+
raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")
|
|
329
312
|
|
|
330
|
-
if pwd_context.verify(old_password, db_user.password_hash):
|
|
313
|
+
if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
|
|
314
|
+
self._validate_password_length(new_password)
|
|
331
315
|
db_user.password_hash = pwd_context.hash(new_password)
|
|
332
316
|
session.commit()
|
|
333
317
|
else:
|
|
334
|
-
raise InvalidInputError(
|
|
318
|
+
raise InvalidInputError(401, "incorrect_password", f"Incorrect password")
|
|
335
319
|
finally:
|
|
336
320
|
session.close()
|
|
337
321
|
|
|
338
322
|
def delete_user(self, username: str) -> None:
|
|
339
323
|
if username == c.ADMIN_USERNAME:
|
|
340
|
-
raise InvalidInputError(
|
|
324
|
+
raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")
|
|
341
325
|
|
|
342
326
|
session = self.Session()
|
|
343
327
|
try:
|
|
344
328
|
db_user = session.get(self.DbUser, username)
|
|
345
329
|
if db_user is None:
|
|
346
|
-
raise InvalidInputError(
|
|
330
|
+
raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
|
|
347
331
|
session.delete(db_user)
|
|
348
332
|
session.commit()
|
|
349
333
|
finally:
|
|
350
334
|
session.close()
|
|
351
335
|
|
|
352
|
-
def get_all_users(self) -> list[
|
|
336
|
+
def get_all_users(self) -> list[RegisteredUser]:
|
|
353
337
|
session = self.Session()
|
|
354
338
|
try:
|
|
355
339
|
db_users = session.query(self.DbUser).all()
|
|
356
|
-
return [self.
|
|
340
|
+
return [self._convert_db_user_to_user(user) for user in db_users]
|
|
357
341
|
finally:
|
|
358
342
|
session.close()
|
|
359
343
|
|
|
360
|
-
def create_access_token(self, user:
|
|
344
|
+
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
|
|
345
|
+
"""
|
|
346
|
+
Creates an API key if title is provided. Otherwise, creates a JWT token.
|
|
347
|
+
"""
|
|
361
348
|
created_at = datetime.now(timezone.utc)
|
|
362
349
|
expire_at = created_at + timedelta(minutes=expiry_minutes) if expiry_minutes is not None else datetime.max
|
|
363
|
-
|
|
350
|
+
|
|
351
|
+
if self.secret_key is None:
|
|
352
|
+
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
|
|
353
|
+
|
|
364
354
|
if title is not None:
|
|
365
355
|
session = self.Session()
|
|
366
356
|
try:
|
|
367
|
-
|
|
368
|
-
|
|
357
|
+
token_id = "sqrl-" + secrets.token_urlsafe(16)
|
|
358
|
+
hashed_key = u.hash_string(token_id, salt=self.secret_key)
|
|
359
|
+
last_four = token_id[-4:]
|
|
360
|
+
api_key = self.DbApiKey(
|
|
361
|
+
hashed_key=hashed_key, last_four=last_four, title=title, username=user.username,
|
|
362
|
+
created_at=created_at, expires_at=expire_at
|
|
363
|
+
)
|
|
364
|
+
session.add(api_key)
|
|
369
365
|
session.commit()
|
|
370
|
-
token_id = access_token.token_id
|
|
371
366
|
finally:
|
|
372
367
|
session.close()
|
|
368
|
+
else:
|
|
369
|
+
to_encode = {"username": user.username, "exp": expire_at}
|
|
370
|
+
token_id = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
|
|
373
371
|
|
|
374
|
-
|
|
375
|
-
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
|
|
376
|
-
to_encode = {"username": user.username, "token_id": token_id, "exp": expire_at}
|
|
377
|
-
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
|
|
378
|
-
return encoded_jwt, expire_at
|
|
372
|
+
return token_id, expire_at
|
|
379
373
|
|
|
380
|
-
def get_user_from_token(self, token: str | None) -> tuple[RegisteredUser | None, float | None]:
    """
    Resolve an access token issued by this server to a user and its expiry.

    Tokens starting with "sqrl-" are API keys. The key is hashed with the secret key and
    looked up among stored keys whose ``expires_at`` is not in the past; no expiry value is
    returned for them. Any other token is decoded as an HS256 JWT signed with the secret key.

    Args:
        token: The raw token, or None/empty for an unauthenticated request.

    Returns:
        ``(None, None)`` when no token is given. Otherwise ``(user, expiry)``, where expiry
        is the JWT "exp" claim, or None for API keys.

    Raises:
        ConfigurationError: if no secret key is configured.
        InvalidInputError: 401 if the token is unknown, expired, malformed, lacks a
            username, or refers to a user that no longer exists.
    """
    if not token:
        return None, None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    session = self.Session()
    try:
        if token.startswith("sqrl-"):
            hashed_key = u.hash_string(token, salt=self.secret_key)
            api_key = session.query(self.DbApiKey).filter(
                self.DbApiKey.hashed_key == hashed_key,
                self.DbApiKey.expires_at >= func.now()
            ).first()
            if api_key is None:
                raise InvalidTokenError()
            username = api_key.username
            expiry = None
        else:
            # jwt.decode reports bad signatures / expiry / malformed input via InvalidTokenError subclasses (PyJWT)
            payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            username = payload.get("username")
            if not username:
                # A token signed with our key but without a usable username claim is still
                # invalid; report it as such (401) instead of letting a KeyError escape.
                raise InvalidTokenError()
            expiry = payload.get("exp")

        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidTokenError()

        user = self._convert_db_user_to_user(db_user)

    except InvalidTokenError:
        raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
    finally:
        session.close()

    return user, expiry
def _get_jwks_client(self, jwks_uri: str) -> PyJWKClient:
    """Return the PyJWKClient for ``jwks_uri``, creating it on first use and reusing it afterwards via ``self._jwks_clients``."""
    clients = self._jwks_clients
    if jwks_uri in clients:
        return clients[jwks_uri]
    client = PyJWKClient(jwks_uri)
    clients[jwks_uri] = client
    return client
def _get_issuer_from_token(self, token: str) -> str | None:
|
|
420
|
+
try:
|
|
421
|
+
# JWT format is header.payload.signature
|
|
422
|
+
parts = token.split('.')
|
|
423
|
+
if len(parts) != 3:
|
|
424
|
+
return None
|
|
425
|
+
|
|
426
|
+
# Base64url decode the payload (JWT uses base64url encoding)
|
|
427
|
+
payload_b64 = parts[1]
|
|
428
|
+
# Add padding if necessary: (-len % 4) yields 0..3
|
|
429
|
+
padding = '=' * ((-len(payload_b64)) % 4)
|
|
430
|
+
payload_json = base64.urlsafe_b64decode(payload_b64 + padding).decode("utf-8")
|
|
431
|
+
payload: dict = json.loads(payload_json)
|
|
432
|
+
|
|
433
|
+
issuer = payload.get("iss")
|
|
434
|
+
return issuer if isinstance(issuer, str) and issuer else None
|
|
435
|
+
except Exception:
|
|
436
|
+
return None
|
|
437
|
+
|
|
438
|
+
@staticmethod
|
|
439
|
+
def _is_jwt(token: str) -> bool:
|
|
440
|
+
if not isinstance(token, str) or not token:
|
|
441
|
+
return False
|
|
442
|
+
parts = token.split(".")
|
|
443
|
+
return len(parts) == 3 and all(parts)
|
|
444
|
+
|
|
445
|
+
def _get_provider_metadata(self, provider_name: str) -> dict:
|
|
446
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
447
|
+
if provider is None:
|
|
448
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
449
|
+
|
|
450
|
+
metadata_url = provider.provider_configs.server_metadata_url
|
|
451
|
+
cached = self._provider_metadata_cache.get(metadata_url)
|
|
452
|
+
if isinstance(cached, dict) and cached:
|
|
453
|
+
return cached
|
|
454
|
+
|
|
455
|
+
try:
|
|
456
|
+
response = requests.get(metadata_url, timeout=10)
|
|
457
|
+
response.raise_for_status()
|
|
458
|
+
metadata = response.json()
|
|
459
|
+
if not isinstance(metadata, dict):
|
|
460
|
+
raise ValueError("Provider metadata was not a JSON object")
|
|
461
|
+
except Exception as e:
|
|
462
|
+
raise ConfigurationError(f"Failed to fetch metadata for provider '{provider_name}': {str(e)}") from e
|
|
463
|
+
|
|
464
|
+
self._provider_metadata_cache[metadata_url] = metadata
|
|
465
|
+
return metadata
|
|
466
|
+
|
|
467
|
+
def _verify_provider_jwt(
|
|
468
|
+
self,
|
|
469
|
+
provider_name: str,
|
|
470
|
+
token: str,
|
|
471
|
+
*,
|
|
472
|
+
purpose: str,
|
|
473
|
+
expected_nonce: str | None = None,
|
|
474
|
+
verify_aud: bool = True,
|
|
475
|
+
) -> dict | None:
|
|
476
|
+
"""
|
|
477
|
+
Verify a provider-issued JWT (signature + exp; aud when requested).
|
|
478
|
+
Uses OIDC discovery for jwks_uri and (best-effort) issuer validation.
|
|
479
|
+
"""
|
|
480
|
+
if not self._is_jwt(token):
|
|
481
|
+
return None
|
|
482
|
+
|
|
483
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
484
|
+
if provider is None:
|
|
485
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
486
|
+
|
|
487
|
+
metadata = self._get_provider_metadata(provider_name)
|
|
488
|
+
jwks_uri = metadata.get("jwks_uri")
|
|
489
|
+
if not isinstance(jwks_uri, str) or not jwks_uri:
|
|
490
|
+
raise ConfigurationError(f"jwks_uri not found in metadata for provider '{provider_name}'")
|
|
491
|
+
|
|
492
|
+
signing_algs = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
|
493
|
+
if not isinstance(signing_algs, list) or not signing_algs:
|
|
494
|
+
signing_algs = ["RS256"]
|
|
495
|
+
|
|
496
|
+
jwks_client = self._get_jwks_client(jwks_uri)
|
|
497
|
+
signing_key = jwks_client.get_signing_key_from_jwt(token)
|
|
498
|
+
|
|
499
|
+
decode_kwargs: dict[str, Any] = {
|
|
500
|
+
"key": signing_key.key,
|
|
501
|
+
"algorithms": signing_algs,
|
|
502
|
+
"options": {
|
|
503
|
+
"verify_aud": bool(verify_aud),
|
|
504
|
+
# We'll validate issuer manually to avoid brittle trailing-slash mismatches.
|
|
505
|
+
"verify_iss": False,
|
|
506
|
+
},
|
|
507
|
+
}
|
|
508
|
+
if verify_aud:
|
|
509
|
+
decode_kwargs["audience"] = provider.provider_configs.client_id
|
|
510
|
+
|
|
511
|
+
try:
|
|
512
|
+
payload = jwt.decode(token, **decode_kwargs)
|
|
513
|
+
except Exception:
|
|
514
|
+
return None
|
|
515
|
+
|
|
516
|
+
if not isinstance(payload, dict):
|
|
517
|
+
return None
|
|
518
|
+
|
|
519
|
+
expected_issuer = metadata.get("issuer") or provider.provider_configs.server_url
|
|
520
|
+
token_issuer = payload.get("iss")
|
|
521
|
+
if isinstance(expected_issuer, str) and expected_issuer and isinstance(token_issuer, str) and token_issuer:
|
|
522
|
+
if token_issuer.rstrip("/") != expected_issuer.rstrip("/"):
|
|
523
|
+
raise InvalidInputError(
|
|
524
|
+
401,
|
|
525
|
+
"invalid_provider_token",
|
|
526
|
+
f"Invalid {purpose} issuer for provider '{provider_name}'",
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
if expected_nonce is not None:
|
|
530
|
+
nonce_claim = payload.get("nonce")
|
|
531
|
+
if nonce_claim != expected_nonce:
|
|
532
|
+
raise InvalidInputError(
|
|
533
|
+
401,
|
|
534
|
+
"invalid_provider_token",
|
|
535
|
+
f"Invalid {purpose} nonce for provider '{provider_name}'",
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
return payload
|
|
539
|
+
|
|
540
|
+
def get_user_info_from_token_details(self, provider_name: str, token_details: dict, *, expected_nonce: str | None = None) -> dict:
|
|
541
|
+
"""
|
|
542
|
+
Determine user_info from an OAuth/OIDC token response.
|
|
543
|
+
|
|
544
|
+
Priority:
|
|
545
|
+
- token_details["user_info"] / token_details["userinfo"]
|
|
546
|
+
- verify + decode token_details["id_token"] if it's a JWT
|
|
547
|
+
- verify + decode token_details["access_token"] if it's a JWT
|
|
548
|
+
- for opaque access_token: call userinfo or introspection endpoint from provider metadata
|
|
549
|
+
"""
|
|
550
|
+
for key in ("user_info", "userinfo"):
|
|
551
|
+
user_info = token_details.get(key)
|
|
552
|
+
if isinstance(user_info, dict) and user_info:
|
|
553
|
+
return user_info
|
|
554
|
+
|
|
555
|
+
id_token = token_details.get("id_token")
|
|
556
|
+
if isinstance(id_token, str):
|
|
557
|
+
if payload := self._verify_provider_jwt(
|
|
558
|
+
provider_name, id_token, purpose="id_token", expected_nonce=expected_nonce, verify_aud=True
|
|
559
|
+
):
|
|
560
|
+
return payload
|
|
561
|
+
|
|
562
|
+
access_token = token_details.get("access_token")
|
|
563
|
+
if isinstance(access_token, str):
|
|
564
|
+
# Some providers issue JWT access tokens. Audience can vary (resource server),
|
|
565
|
+
# so we verify signature/exp and issuer, but skip aud validation.
|
|
566
|
+
if payload := self._verify_provider_jwt(
|
|
567
|
+
provider_name, access_token, purpose="access_token", expected_nonce=None, verify_aud=False
|
|
568
|
+
):
|
|
569
|
+
return payload
|
|
570
|
+
|
|
571
|
+
if not isinstance(access_token, str) or not access_token:
|
|
572
|
+
raise InvalidInputError(
|
|
573
|
+
400,
|
|
574
|
+
"invalid_provider_user_info",
|
|
575
|
+
f"User information not found in token details for {provider_name}",
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
579
|
+
if provider is None:
|
|
580
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
581
|
+
|
|
582
|
+
metadata: dict = self._get_provider_metadata(provider_name)
|
|
583
|
+
|
|
584
|
+
userinfo_endpoint = metadata.get("userinfo_endpoint")
|
|
585
|
+
if isinstance(userinfo_endpoint, str) and userinfo_endpoint:
|
|
586
|
+
try:
|
|
587
|
+
response = requests.get(
|
|
588
|
+
userinfo_endpoint,
|
|
589
|
+
headers={"Authorization": f"Bearer {access_token}"},
|
|
590
|
+
timeout=10,
|
|
591
|
+
)
|
|
592
|
+
response.raise_for_status()
|
|
593
|
+
user_info = response.json()
|
|
594
|
+
if isinstance(user_info, dict) and user_info:
|
|
595
|
+
return user_info
|
|
596
|
+
except Exception as e:
|
|
597
|
+
self.logger.warning(f"Failed to fetch user info from {userinfo_endpoint} for provider {provider_name}: {str(e)}")
|
|
598
|
+
# Fall back to introspection if available
|
|
599
|
+
pass
|
|
600
|
+
|
|
601
|
+
introspection_endpoint = metadata.get("introspection_endpoint")
|
|
602
|
+
if isinstance(introspection_endpoint, str) and introspection_endpoint:
|
|
603
|
+
try:
|
|
604
|
+
response = requests.post(
|
|
605
|
+
introspection_endpoint,
|
|
606
|
+
data={"token": access_token},
|
|
607
|
+
auth=(provider.provider_configs.client_id, provider.provider_configs.client_secret),
|
|
608
|
+
timeout=10,
|
|
609
|
+
)
|
|
610
|
+
response.raise_for_status()
|
|
611
|
+
token_info = response.json()
|
|
612
|
+
if isinstance(token_info, dict):
|
|
613
|
+
if token_info.get("active") is False:
|
|
614
|
+
raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
|
|
615
|
+
return token_info
|
|
616
|
+
except InvalidInputError:
|
|
617
|
+
raise
|
|
618
|
+
except Exception as e:
|
|
619
|
+
self.logger.warning(f"Introspection request failed for provider {provider_name} using basic auth: {str(e)}")
|
|
620
|
+
# Some providers require "client_secret_post" instead of basic auth for introspection.
|
|
621
|
+
try:
|
|
622
|
+
response = requests.post(
|
|
623
|
+
introspection_endpoint,
|
|
624
|
+
data={
|
|
625
|
+
"token": access_token,
|
|
626
|
+
"client_id": provider.provider_configs.client_id,
|
|
627
|
+
"client_secret": provider.provider_configs.client_secret,
|
|
628
|
+
},
|
|
629
|
+
timeout=10,
|
|
630
|
+
)
|
|
631
|
+
response.raise_for_status()
|
|
632
|
+
token_info = response.json()
|
|
633
|
+
if isinstance(token_info, dict):
|
|
634
|
+
if token_info.get("active") is False:
|
|
635
|
+
raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
|
|
636
|
+
return token_info
|
|
637
|
+
except InvalidInputError:
|
|
638
|
+
raise
|
|
639
|
+
except Exception as e2:
|
|
640
|
+
self.logger.warning(f"Introspection request failed for provider {provider_name} using post data: {str(e2)}")
|
|
641
|
+
pass
|
|
642
|
+
|
|
643
|
+
raise InvalidInputError(
|
|
644
|
+
400,
|
|
645
|
+
"invalid_provider_user_info",
|
|
646
|
+
f"User information not found in token details for {provider_name}",
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
def get_user_from_external_token(self, token: str, provider_name: str | None = None) -> tuple[RegisteredUser, float | None]:
    """
    Resolve an external (provider-issued) OAuth token to a user and its expiry.

    Provider selection:
      - ``provider_name`` when given;
      - otherwise, for a JWT, the provider whose ``server_url`` matches the token's
        unverified "iss" claim (trailing slashes ignored);
      - otherwise, for an opaque token, the single configured provider, if there is exactly one.

    JWTs are verified against the provider's JWKS (audience not checked). Opaque tokens
    go through the provider's userinfo / introspection endpoints.

    Returns:
        ``(user, expiry)``, where expiry is the numeric "exp" claim as a float, or None.

    Raises:
        InvalidInputError: 404 for an unknown named provider; 401 when the provider cannot
            be determined, or the token is invalid or inactive.
    """
    def _invalid_token():
        return InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")

    is_jwt = self._is_jwt(token)
    issuer: str | None = None

    if provider_name:
        provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    elif is_jwt:
        issuer = self._get_issuer_from_token(token)
        if not issuer:
            raise InvalidInputError(401, "invalid_external_token", "Could not extract issuer from token")
        wanted = issuer.rstrip("/")
        provider = next(
            (p for p in self.auth_providers if p.provider_configs.server_url.rstrip("/") == wanted),
            None,
        )
    elif len(self.auth_providers) == 1:
        # Opaque token with exactly one configured provider: assume that provider.
        provider = self.auth_providers[0]
    else:
        provider = None

    if provider is None:
        if provider_name:
            raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
        if is_jwt:
            raise InvalidInputError(401, "auth_provider_not_found", f"No provider found for issuer: {issuer}")
        raise InvalidInputError(401, "invalid_external_token", "Could not determine provider for external token")

    if is_jwt:
        try:
            payload = self._verify_provider_jwt(provider.name, token, purpose="external_token", verify_aud=False)
        except InvalidInputError:
            # Keep the external-auth error contract stable (no provider-specific details).
            raise _invalid_token()
    else:
        try:
            payload = self.get_user_info_from_token_details(provider.name, {"access_token": token})
        except InvalidInputError as e:
            # An "inactive token" error passes through unchanged; anything else is normalized.
            if getattr(e, "error", None) == "inactive_external_token":
                raise
            raise _invalid_token()

    if not isinstance(payload, dict) or not payload:
        raise _invalid_token()

    user = provider.provider_configs.get_user(payload)
    exp = payload.get("exp")
    expiry = float(exp) if isinstance(exp, (int, float)) else None
    return user, expiry
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    List the API keys owned by ``username`` that have not yet expired.

    A key counts as current while its ``expires_at`` is not earlier than the database's
    ``now()``. Each row is converted with ``ApiKey.model_validate``. The raw key itself is
    not stored (only its hash and last four characters), so it is never returned.
    """
    with self.Session() as session:
        current_keys = (
            session.query(self.DbApiKey)
            .filter(
                self.DbApiKey.username == username,
                self.DbApiKey.expires_at >= func.now(),
            )
            .all()
        )
        return list(map(ApiKey.model_validate, current_keys))
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete the API key with id ``api_key_id`` belonging to ``username``.

    Raises:
        InvalidInputError: 404 if that user has no key with this id.
    """
    with self.Session() as session:
        match = session.query(self.DbApiKey).filter(
            self.DbApiKey.username == username,
            self.DbApiKey.id == api_key_id,
        ).first()
        if match is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")
        session.delete(match)
        session.commit()
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
|
|
748
|
+
if user.access_level == "guest":
|
|
442
749
|
user_level = PermissionScope.PUBLIC
|
|
443
|
-
elif user.
|
|
750
|
+
elif user.access_level == "admin":
|
|
444
751
|
user_level = PermissionScope.PRIVATE
|
|
445
|
-
else:
|
|
752
|
+
else: # member
|
|
446
753
|
user_level = PermissionScope.PROTECTED
|
|
447
754
|
|
|
448
755
|
return user_level.value >= scope.value
|
|
449
|
-
|
|
756
|
+
|
|
450
757
|
def close(self) -> None:
    """Call ``dispose()`` on ``self.engine`` if this instance has an ``engine`` attribute; otherwise do nothing."""
    if not hasattr(self, "engine"):
        # NOTE(review): presumably guards against construction having failed before the engine was set.
        return
    self.engine.dispose()
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider.

    The decorated function takes an ``AuthProviderArgs`` and returns the provider's
    ``ProviderConfigs``. It is replaced by a wrapper that builds an ``AuthProvider`` from
    those configs plus the given name/label/icon. The wrapper is appended to
    ``Authenticator.providers`` when the decorator is applied.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')
    """
    from functools import wraps

    def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
        # Preserve the user's function name/docstring on the registered wrapper.
        @wraps(func)
        def wrapper(sqrl: AuthProviderArgs) -> AuthProvider:
            provider_configs = func(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)
        Authenticator.providers.append(wrapper)
        return wrapper
    return decorator
|