squirrels 0.5.0b4__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- squirrels/__init__.py +2 -0
- squirrels/_api_routes/auth.py +83 -74
- squirrels/_api_routes/base.py +58 -41
- squirrels/_api_routes/dashboards.py +37 -21
- squirrels/_api_routes/data_management.py +72 -27
- squirrels/_api_routes/datasets.py +107 -84
- squirrels/_api_routes/oauth2.py +11 -13
- squirrels/_api_routes/project.py +71 -33
- squirrels/_api_server.py +130 -63
- squirrels/_arguments/run_time_args.py +9 -9
- squirrels/_auth.py +117 -162
- squirrels/_command_line.py +68 -32
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +11 -2
- squirrels/_constants.py +22 -8
- squirrels/_data_sources.py +38 -32
- squirrels/_dataset_types.py +2 -4
- squirrels/_initializer.py +1 -1
- squirrels/_logging.py +117 -0
- squirrels/_manifest.py +125 -58
- squirrels/_model_builder.py +10 -54
- squirrels/_models.py +224 -108
- squirrels/_package_data/base_project/.env +15 -4
- squirrels/_package_data/base_project/.env.example +14 -3
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +2 -2
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +4 -4
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +22 -15
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +1 -1
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +5 -1
- squirrels/_package_data/base_project/pyconfigs/context.py +23 -12
- squirrels/_package_data/base_project/pyconfigs/parameters.py +68 -33
- squirrels/_package_data/base_project/pyconfigs/user.py +11 -18
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +18 -28
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +8 -10
- squirrels/_project.py +351 -234
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/auth_models.py +32 -9
- squirrels/_schemas/query_param_models.py +9 -1
- squirrels/_schemas/response_models.py +36 -10
- squirrels/_seeds.py +1 -1
- squirrels/_sources.py +23 -19
- squirrels/_utils.py +83 -35
- squirrels/_version.py +1 -1
- squirrels/arguments.py +5 -0
- squirrels/auth.py +4 -1
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +6 -1
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/METADATA +28 -13
- squirrels-0.5.1.dist-info/RECORD +98 -0
- squirrels-0.5.0b4.dist-info/RECORD +0 -94
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py
CHANGED
|
@@ -1,57 +1,52 @@
|
|
|
1
|
-
from typing import Callable
|
|
1
|
+
from typing import Callable, Any
|
|
2
2
|
from datetime import datetime, timedelta, timezone
|
|
3
|
-
from enum import Enum
|
|
4
3
|
from functools import cached_property
|
|
5
4
|
from jwt.exceptions import InvalidTokenError
|
|
6
5
|
from passlib.context import CryptContext
|
|
7
6
|
from pydantic import ValidationError
|
|
8
|
-
from pydantic_core import PydanticUndefined
|
|
9
7
|
from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
|
|
10
|
-
from sqlalchemy import Column, String, Integer, Float, Boolean
|
|
11
8
|
from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
|
|
12
|
-
import jwt,
|
|
9
|
+
import jwt, uuid, secrets, json
|
|
13
10
|
|
|
14
11
|
from ._manifest import PermissionScope
|
|
15
12
|
from ._py_module import PyModule
|
|
16
13
|
from ._exceptions import InvalidInputError, ConfigurationError
|
|
17
14
|
from ._arguments.init_time_args import AuthProviderArgs
|
|
18
15
|
from ._schemas.auth_models import (
|
|
19
|
-
|
|
20
|
-
ClientDetailsResponse, ClientRegistrationResponse,
|
|
16
|
+
CustomUserFields, AbstractUser, GuestUser, RegisteredUser, ApiKey, UserField, AuthProvider, ProviderConfigs,
|
|
17
|
+
ClientRegistrationRequest, ClientUpdateRequest, ClientDetailsResponse, ClientRegistrationResponse,
|
|
18
|
+
ClientUpdateResponse, TokenResponse
|
|
21
19
|
)
|
|
22
20
|
from . import _utils as u, _constants as c
|
|
23
21
|
|
|
24
22
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
25
23
|
|
|
26
|
-
reserved_fields = ["username", "is_admin"]
|
|
27
|
-
disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
|
|
28
|
-
|
|
29
|
-
User = _t.TypeVar('User', bound=BaseUser)
|
|
30
|
-
|
|
31
24
|
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
|
|
32
25
|
|
|
33
26
|
|
|
34
|
-
class Authenticator
|
|
27
|
+
class Authenticator:
|
|
35
28
|
providers: list[ProviderFunctionType] = [] # static variable to stage providers
|
|
36
29
|
|
|
37
30
|
def __init__(
|
|
38
31
|
self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
|
|
39
|
-
|
|
32
|
+
custom_user_fields_cls: type[CustomUserFields], *, sa_engine: Engine | None = None, external_only: bool = False
|
|
40
33
|
):
|
|
41
34
|
self.logger = logger
|
|
42
35
|
self.env_vars = auth_args.env_vars
|
|
43
36
|
self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
|
|
37
|
+
self.external_only = external_only
|
|
44
38
|
|
|
45
39
|
# Create a new declarative base for this instance
|
|
46
40
|
self.Base = declarative_base()
|
|
47
41
|
|
|
48
|
-
# Define
|
|
49
|
-
class
|
|
42
|
+
# Define DbUser class for this instance
|
|
43
|
+
class DbUser(self.Base):
|
|
50
44
|
__tablename__ = 'users'
|
|
51
45
|
__table_args__ = {'extend_existing': True}
|
|
52
46
|
username: Mapped[str] = mapped_column(primary_key=True)
|
|
53
|
-
|
|
47
|
+
access_level: Mapped[str] = mapped_column(nullable=False, default="member")
|
|
54
48
|
password_hash: Mapped[str] = mapped_column(nullable=False)
|
|
49
|
+
custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
|
|
55
50
|
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
56
51
|
|
|
57
52
|
# Define DbApiKey class for this instance
|
|
@@ -124,20 +119,19 @@ class Authenticator(_t.Generic[User]):
|
|
|
124
119
|
def __repr__(self):
|
|
125
120
|
return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"
|
|
126
121
|
|
|
127
|
-
self.DbBaseUser = DbBaseUser
|
|
128
122
|
self.DbApiKey = DbApiKey
|
|
129
123
|
self.DbOAuthClient = DbOAuthClient
|
|
130
124
|
self.DbAuthorizationCode = DbAuthorizationCode
|
|
131
125
|
self.DbOAuthToken = DbOAuthToken
|
|
132
126
|
|
|
133
|
-
self.
|
|
134
|
-
self.DbUser
|
|
127
|
+
self.CustomUserFields = custom_user_fields_cls
|
|
128
|
+
self.DbUser = DbUser
|
|
135
129
|
|
|
136
130
|
self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
|
|
137
131
|
|
|
138
132
|
if sa_engine is None:
|
|
139
|
-
|
|
140
|
-
sqlite_path = u.Path(base_path
|
|
133
|
+
raw_sqlite_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{{project_path}}/{c.TARGET_FOLDER}/{c.DB_FILE}")
|
|
134
|
+
sqlite_path = u.Path(raw_sqlite_path.format(project_path=base_path))
|
|
141
135
|
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
|
142
136
|
self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
|
|
143
137
|
else:
|
|
@@ -153,67 +147,21 @@ class Authenticator(_t.Generic[User]):
|
|
|
153
147
|
|
|
154
148
|
self.Session = sessionmaker(bind=self.engine)
|
|
155
149
|
|
|
156
|
-
self._initialize_db(
|
|
150
|
+
self._initialize_db()
|
|
157
151
|
|
|
158
|
-
def
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
return
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
# Iterate over all fields in the User model
|
|
171
|
-
for field_name, field in self.User.model_fields.items():
|
|
172
|
-
if field_name in reserved_fields:
|
|
173
|
-
continue
|
|
174
|
-
if field_name in disallowed_fields:
|
|
175
|
-
raise ConfigurationError(f"Field name '{field_name}' is disallowed in the User model and cannot be used")
|
|
176
|
-
|
|
177
|
-
field_type = field.annotation
|
|
178
|
-
if _t.get_origin(field_type) in (_t.Union, types.UnionType):
|
|
179
|
-
field_type = _t.get_args(field_type)[0]
|
|
180
|
-
nullable = True
|
|
181
|
-
else:
|
|
182
|
-
nullable = False
|
|
152
|
+
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
|
|
153
|
+
"""Convert a database user to an AbstractUser object"""
|
|
154
|
+
# Deserialize custom_fields JSON and merge with defaults
|
|
155
|
+
custom_fields_json = json.loads(db_user.custom_fields) if db_user.custom_fields else {}
|
|
156
|
+
custom_fields = self.CustomUserFields(**custom_fields_json)
|
|
157
|
+
|
|
158
|
+
return RegisteredUser(
|
|
159
|
+
username=db_user.username,
|
|
160
|
+
access_level=db_user.access_level,
|
|
161
|
+
custom_fields=custom_fields
|
|
162
|
+
)
|
|
183
163
|
|
|
184
|
-
|
|
185
|
-
field_type = str
|
|
186
|
-
|
|
187
|
-
# Map Python types and default values to SQLAlchemy columns
|
|
188
|
-
default_value = field.default
|
|
189
|
-
if default_value is PydanticUndefined:
|
|
190
|
-
raise ConfigurationError(f"No default value found for field '{field_name}' in User model")
|
|
191
|
-
elif not nullable and default_value is None:
|
|
192
|
-
raise ConfigurationError(f"Default value for non-nullable field '{field_name}' was set as None in User model")
|
|
193
|
-
elif default_value is not None and type(default_value) is not field_type:
|
|
194
|
-
raise ConfigurationError(f"Default value for field '{field_name}' does not match field type in User model")
|
|
195
|
-
|
|
196
|
-
if field_type == str:
|
|
197
|
-
col_type = String
|
|
198
|
-
elif field_type == int:
|
|
199
|
-
col_type = Integer
|
|
200
|
-
elif field_type == float:
|
|
201
|
-
col_type = Float
|
|
202
|
-
elif field_type == bool:
|
|
203
|
-
col_type = Boolean
|
|
204
|
-
elif isinstance(field_type, type) and issubclass(field_type, Enum):
|
|
205
|
-
col_type = String
|
|
206
|
-
default_value = default_value.value
|
|
207
|
-
else:
|
|
208
|
-
continue
|
|
209
|
-
|
|
210
|
-
attrs[field_name] = Column(col_type, nullable=nullable, default=default_value) # type: ignore
|
|
211
|
-
|
|
212
|
-
# Create the sqlalchemy model class
|
|
213
|
-
DbUser = type('DbUser', (self.DbBaseUser,), attrs)
|
|
214
|
-
return DbUser
|
|
215
|
-
|
|
216
|
-
def _initialize_db(self, *args): # TODO: Use logger instead of print
|
|
164
|
+
def _initialize_db(self): # TODO: Use logger instead of print
|
|
217
165
|
session = self.Session()
|
|
218
166
|
try:
|
|
219
167
|
# Get existing columns in the database
|
|
@@ -221,14 +169,12 @@ class Authenticator(_t.Generic[User]):
|
|
|
221
169
|
existing_columns = {col['name'] for col in inspector.get_columns('users')}
|
|
222
170
|
|
|
223
171
|
# Get all columns defined in the model
|
|
224
|
-
|
|
225
|
-
model_columns = set(self.DbUser.__table__.columns.keys()) - dropped_columns
|
|
172
|
+
model_columns = set(self.DbUser.__table__.columns.keys())
|
|
226
173
|
|
|
227
174
|
# Find columns that are in the model but not in the database
|
|
228
175
|
new_columns = model_columns - existing_columns
|
|
229
176
|
if new_columns:
|
|
230
177
|
add_columns_msg = f"Adding columns to database: {new_columns}"
|
|
231
|
-
print("NOTE -", add_columns_msg)
|
|
232
178
|
self.logger.info(add_columns_msg)
|
|
233
179
|
|
|
234
180
|
for col_name in new_columns:
|
|
@@ -245,24 +191,6 @@ class Authenticator(_t.Generic[User]):
|
|
|
245
191
|
session.execute(text(alter_stmt))
|
|
246
192
|
|
|
247
193
|
session.commit()
|
|
248
|
-
|
|
249
|
-
# Determine columns to drop
|
|
250
|
-
columns_to_drop = dropped_columns.intersection(existing_columns)
|
|
251
|
-
if columns_to_drop:
|
|
252
|
-
drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
|
|
253
|
-
print("NOTE -", drop_columns_msg)
|
|
254
|
-
self.logger.info(drop_columns_msg)
|
|
255
|
-
|
|
256
|
-
for col_name in columns_to_drop:
|
|
257
|
-
session.execute(text(f"ALTER TABLE users DROP COLUMN {col_name}"))
|
|
258
|
-
|
|
259
|
-
session.commit()
|
|
260
|
-
|
|
261
|
-
# Find columns that are in the database but not in the model
|
|
262
|
-
extra_db_columns = existing_columns - columns_to_drop - model_columns
|
|
263
|
-
if extra_db_columns:
|
|
264
|
-
self.logger.warning(f"The following database columns are not in the User model: {extra_db_columns}\n"
|
|
265
|
-
"If you want to drop these columns, please use the `dropped_columns` class method of the User model.")
|
|
266
194
|
|
|
267
195
|
# Get admin password from environment variable if exists
|
|
268
196
|
admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)
|
|
@@ -272,7 +200,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
272
200
|
password_hash = pwd_context.hash(admin_password)
|
|
273
201
|
admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
|
|
274
202
|
if admin_user is None:
|
|
275
|
-
admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash,
|
|
203
|
+
admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
|
|
276
204
|
session.add(admin_user)
|
|
277
205
|
else:
|
|
278
206
|
admin_user.password_hash = password_hash
|
|
@@ -285,7 +213,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
285
213
|
@cached_property
|
|
286
214
|
def user_fields(self) -> list[UserField]:
|
|
287
215
|
"""
|
|
288
|
-
Get the fields of the
|
|
216
|
+
Get the fields of the CustomUserFields model as a list of dictionaries
|
|
289
217
|
|
|
290
218
|
Each dictionary contains the following keys:
|
|
291
219
|
- name: The name of the field
|
|
@@ -294,11 +222,15 @@ class Authenticator(_t.Generic[User]):
|
|
|
294
222
|
- enum: The possible values of the field (or None if not applicable)
|
|
295
223
|
- default: The default value of the field (or None if field is required)
|
|
296
224
|
"""
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
225
|
+
# Add the base fields
|
|
226
|
+
fields = [
|
|
227
|
+
UserField(name="username", type="string", nullable=False, enum=None, default=None),
|
|
228
|
+
UserField(name="access_level", type="string", nullable=False, enum=["admin", "member"], default="member")
|
|
229
|
+
]
|
|
230
|
+
|
|
231
|
+
# Add custom fields
|
|
232
|
+
schema = self.CustomUserFields.model_json_schema()
|
|
233
|
+
properties: dict[str, dict[str, Any]] = schema.get("properties", {})
|
|
302
234
|
for field_name, field_schema in properties.items():
|
|
303
235
|
if choices := field_schema.get("anyOf"):
|
|
304
236
|
field_type = choices[0]["type"]
|
|
@@ -315,37 +247,53 @@ class Authenticator(_t.Generic[User]):
|
|
|
315
247
|
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
|
|
316
248
|
session = self.Session()
|
|
317
249
|
|
|
318
|
-
#
|
|
250
|
+
# Separate custom fields from base fields
|
|
251
|
+
access_level = user_fields.get('access_level', 'member')
|
|
252
|
+
password = user_fields.get('password')
|
|
253
|
+
|
|
254
|
+
# Validate access level - cannot add/update users with guest access level
|
|
255
|
+
if access_level == "guest":
|
|
256
|
+
raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")
|
|
257
|
+
|
|
258
|
+
# Extract custom fields (everything except username, access_level, password)
|
|
259
|
+
custom_fields_dict = {k: v for k, v in user_fields.items() if k not in ['username', 'access_level', 'password']}
|
|
260
|
+
|
|
261
|
+
# Validate the custom fields
|
|
319
262
|
try:
|
|
320
|
-
|
|
263
|
+
custom_fields = self.CustomUserFields(**custom_fields_dict)
|
|
264
|
+
custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
|
|
321
265
|
except ValidationError as e:
|
|
322
|
-
raise InvalidInputError(400, "
|
|
266
|
+
raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
|
|
323
267
|
|
|
324
|
-
# Add
|
|
268
|
+
# Add or update user
|
|
325
269
|
try:
|
|
326
270
|
# Check if the user already exists
|
|
327
271
|
existing_user = session.get(self.DbUser, username)
|
|
328
272
|
if existing_user is not None:
|
|
329
273
|
if not update_user:
|
|
330
|
-
raise InvalidInputError(400, "
|
|
274
|
+
raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")
|
|
331
275
|
|
|
332
|
-
if username == c.ADMIN_USERNAME and
|
|
333
|
-
raise InvalidInputError(403, "
|
|
276
|
+
if username == c.ADMIN_USERNAME and access_level != "admin":
|
|
277
|
+
raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")
|
|
334
278
|
|
|
335
|
-
|
|
336
|
-
|
|
279
|
+
# Update existing user
|
|
280
|
+
existing_user.access_level = access_level
|
|
281
|
+
existing_user.custom_fields = custom_fields_json
|
|
337
282
|
else:
|
|
338
283
|
if update_user:
|
|
339
|
-
raise InvalidInputError(404, "
|
|
284
|
+
raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
|
|
340
285
|
|
|
341
|
-
password = user_fields.get('password')
|
|
342
286
|
if password is None:
|
|
343
|
-
raise InvalidInputError(400, "
|
|
287
|
+
raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")
|
|
288
|
+
|
|
344
289
|
password_hash = pwd_context.hash(password)
|
|
345
|
-
new_user = self.DbUser(
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
290
|
+
new_user = self.DbUser(
|
|
291
|
+
username=username,
|
|
292
|
+
access_level=access_level,
|
|
293
|
+
password_hash=password_hash,
|
|
294
|
+
custom_fields=custom_fields_json
|
|
295
|
+
)
|
|
296
|
+
session.add(new_user)
|
|
349
297
|
|
|
350
298
|
# Commit the transaction
|
|
351
299
|
session.commit()
|
|
@@ -353,39 +301,45 @@ class Authenticator(_t.Generic[User]):
|
|
|
353
301
|
finally:
|
|
354
302
|
session.close()
|
|
355
303
|
|
|
356
|
-
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) ->
|
|
304
|
+
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
|
|
357
305
|
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
358
306
|
if provider is None:
|
|
359
|
-
raise InvalidInputError(404, "
|
|
307
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
360
308
|
|
|
361
309
|
user = provider.provider_configs.get_user(user_info)
|
|
362
310
|
session = self.Session()
|
|
363
311
|
try:
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
312
|
+
# Convert user to database user
|
|
313
|
+
custom_fields_json = user.custom_fields.model_dump_json()
|
|
314
|
+
db_user = self.DbUser(
|
|
315
|
+
username=user.username,
|
|
316
|
+
access_level=user.access_level,
|
|
317
|
+
password_hash="", # By omitting password_hash, it becomes impossible to login with username and password (OAuth only)
|
|
318
|
+
custom_fields=custom_fields_json
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
existing_db_user = session.get(self.DbUser, db_user.username)
|
|
322
|
+
if existing_db_user is None:
|
|
370
323
|
session.add(db_user)
|
|
371
|
-
session.commit()
|
|
372
324
|
|
|
373
|
-
|
|
325
|
+
session.commit()
|
|
326
|
+
|
|
327
|
+
return self._convert_db_user_to_user(db_user)
|
|
374
328
|
|
|
375
329
|
finally:
|
|
376
330
|
session.close()
|
|
377
331
|
|
|
378
|
-
def get_user(self, username: str, password: str) ->
|
|
332
|
+
def get_user(self, username: str, password: str) -> RegisteredUser:
|
|
379
333
|
session = self.Session()
|
|
380
334
|
try:
|
|
381
335
|
# Query for user by username
|
|
382
336
|
db_user = session.get(self.DbUser, username)
|
|
383
337
|
|
|
384
338
|
if db_user and pwd_context.verify(password, db_user.password_hash):
|
|
385
|
-
user = self.
|
|
386
|
-
return user
|
|
339
|
+
user = self._convert_db_user_to_user(db_user)
|
|
340
|
+
return user
|
|
387
341
|
else:
|
|
388
|
-
raise InvalidInputError(401, "
|
|
342
|
+
raise InvalidInputError(401, "incorrect_username_or_password", f"Incorrect username or password")
|
|
389
343
|
|
|
390
344
|
finally:
|
|
391
345
|
session.close()
|
|
@@ -395,39 +349,39 @@ class Authenticator(_t.Generic[User]):
|
|
|
395
349
|
try:
|
|
396
350
|
db_user = session.get(self.DbUser, username)
|
|
397
351
|
if db_user is None:
|
|
398
|
-
raise InvalidInputError(401, "
|
|
352
|
+
raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")
|
|
399
353
|
|
|
400
354
|
if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
|
|
401
355
|
db_user.password_hash = pwd_context.hash(new_password)
|
|
402
356
|
session.commit()
|
|
403
357
|
else:
|
|
404
|
-
raise InvalidInputError(401, "
|
|
358
|
+
raise InvalidInputError(401, "incorrect_password", f"Incorrect password")
|
|
405
359
|
finally:
|
|
406
360
|
session.close()
|
|
407
361
|
|
|
408
362
|
def delete_user(self, username: str) -> None:
|
|
409
363
|
if username == c.ADMIN_USERNAME:
|
|
410
|
-
raise InvalidInputError(403, "
|
|
364
|
+
raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")
|
|
411
365
|
|
|
412
366
|
session = self.Session()
|
|
413
367
|
try:
|
|
414
368
|
db_user = session.get(self.DbUser, username)
|
|
415
369
|
if db_user is None:
|
|
416
|
-
raise InvalidInputError(404, "
|
|
370
|
+
raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
|
|
417
371
|
session.delete(db_user)
|
|
418
372
|
session.commit()
|
|
419
373
|
finally:
|
|
420
374
|
session.close()
|
|
421
375
|
|
|
422
|
-
def get_all_users(self) -> list:
|
|
376
|
+
def get_all_users(self) -> list[RegisteredUser]:
|
|
423
377
|
session = self.Session()
|
|
424
378
|
try:
|
|
425
379
|
db_users = session.query(self.DbUser).all()
|
|
426
|
-
return [self.
|
|
380
|
+
return [self._convert_db_user_to_user(user) for user in db_users]
|
|
427
381
|
finally:
|
|
428
382
|
session.close()
|
|
429
383
|
|
|
430
|
-
def create_access_token(self, user:
|
|
384
|
+
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
|
|
431
385
|
"""
|
|
432
386
|
Creates an API key if title is provided. Otherwise, creates a JWT token.
|
|
433
387
|
"""
|
|
@@ -440,7 +394,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
440
394
|
if title is not None:
|
|
441
395
|
session = self.Session()
|
|
442
396
|
try:
|
|
443
|
-
token_id = "sqrl-" +
|
|
397
|
+
token_id = "sqrl-" + secrets.token_urlsafe(16)
|
|
444
398
|
hashed_key = u.hash_string(token_id, salt=self.secret_key)
|
|
445
399
|
api_key = self.DbApiKey(hashed_key=hashed_key, title=title, username=user.username, created_at=created_at, expires_at=expire_at)
|
|
446
400
|
session.add(api_key)
|
|
@@ -453,7 +407,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
453
407
|
|
|
454
408
|
return token_id, expire_at
|
|
455
409
|
|
|
456
|
-
def get_user_from_token(self, token: str | None) ->
|
|
410
|
+
def get_user_from_token(self, token: str | None) -> RegisteredUser | None:
|
|
457
411
|
"""
|
|
458
412
|
Get a user from an access token (JWT, or API key if token starts with 'sqrl-')
|
|
459
413
|
"""
|
|
@@ -481,14 +435,15 @@ class Authenticator(_t.Generic[User]):
|
|
|
481
435
|
db_user = session.get(self.DbUser, username)
|
|
482
436
|
if db_user is None:
|
|
483
437
|
raise InvalidTokenError()
|
|
438
|
+
|
|
439
|
+
user = self._convert_db_user_to_user(db_user)
|
|
484
440
|
|
|
485
441
|
except InvalidTokenError:
|
|
486
|
-
raise InvalidInputError(401, "
|
|
442
|
+
raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
|
|
487
443
|
finally:
|
|
488
444
|
session.close()
|
|
489
445
|
|
|
490
|
-
user
|
|
491
|
-
return user # type: ignore
|
|
446
|
+
return user
|
|
492
447
|
|
|
493
448
|
def get_all_api_keys(self, username: str) -> list[ApiKey]:
|
|
494
449
|
"""
|
|
@@ -518,19 +473,19 @@ class Authenticator(_t.Generic[User]):
|
|
|
518
473
|
).first()
|
|
519
474
|
|
|
520
475
|
if api_key is None:
|
|
521
|
-
raise InvalidInputError(404, "
|
|
476
|
+
raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")
|
|
522
477
|
|
|
523
478
|
session.delete(api_key)
|
|
524
479
|
session.commit()
|
|
525
480
|
finally:
|
|
526
481
|
session.close()
|
|
527
482
|
|
|
528
|
-
def can_user_access_scope(self, user:
|
|
529
|
-
if user
|
|
483
|
+
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
|
|
484
|
+
if user.access_level == "guest":
|
|
530
485
|
user_level = PermissionScope.PUBLIC
|
|
531
|
-
elif user.
|
|
486
|
+
elif user.access_level == "admin":
|
|
532
487
|
user_level = PermissionScope.PRIVATE
|
|
533
|
-
else:
|
|
488
|
+
else: # member
|
|
534
489
|
user_level = PermissionScope.PROTECTED
|
|
535
490
|
|
|
536
491
|
return user_level.value >= scope.value
|
|
@@ -840,15 +795,15 @@ class Authenticator(_t.Generic[User]):
|
|
|
840
795
|
raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")
|
|
841
796
|
|
|
842
797
|
# Get user
|
|
843
|
-
|
|
844
|
-
if
|
|
798
|
+
db_user = session.get(self.DbUser, auth_code.username)
|
|
799
|
+
if db_user is None:
|
|
845
800
|
raise InvalidInputError(400, "invalid_grant", "User not found")
|
|
846
801
|
|
|
847
802
|
# Mark authorization code as used
|
|
848
803
|
auth_code.used = True
|
|
849
804
|
|
|
850
805
|
# Generate tokens
|
|
851
|
-
user_obj = self.
|
|
806
|
+
user_obj = self._convert_db_user_to_user(db_user)
|
|
852
807
|
access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
|
|
853
808
|
access_token_hash = pwd_context.hash(access_token)
|
|
854
809
|
|
|
@@ -902,8 +857,8 @@ class Authenticator(_t.Generic[User]):
|
|
|
902
857
|
raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")
|
|
903
858
|
|
|
904
859
|
# Get user
|
|
905
|
-
|
|
906
|
-
if
|
|
860
|
+
db_user = session.get(self.DbUser, oauth_token.username)
|
|
861
|
+
if db_user is None:
|
|
907
862
|
raise InvalidInputError(400, "invalid_client", "User not found")
|
|
908
863
|
|
|
909
864
|
# Check secret key is available
|
|
@@ -911,7 +866,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
911
866
|
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")
|
|
912
867
|
|
|
913
868
|
# Generate new tokens
|
|
914
|
-
user_obj = self.
|
|
869
|
+
user_obj = self._convert_db_user_to_user(db_user)
|
|
915
870
|
access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
|
|
916
871
|
access_token_hash = pwd_context.hash(access_token)
|
|
917
872
|
|