squirrels 0.5.0rc0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +58 -111
- dateutils/types.py +6 -0
- squirrels/__init__.py +10 -12
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +271 -0
- squirrels/_api_routes/base.py +171 -0
- squirrels/_api_routes/dashboards.py +158 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +265 -0
- squirrels/_api_routes/oauth2.py +298 -0
- squirrels/_api_routes/project.py +252 -0
- squirrels/_api_server.py +245 -781
- squirrels/_arguments/__init__.py +0 -0
- squirrels/{arguments → _arguments}/init_time_args.py +7 -2
- squirrels/{arguments → _arguments}/run_time_args.py +13 -35
- squirrels/_auth.py +720 -212
- squirrels/_command_line.py +81 -41
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +16 -7
- squirrels/_constants.py +29 -9
- squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
- squirrels/_data_sources.py +570 -0
- squirrels/{dataset_result.py → _dataset_types.py} +2 -4
- squirrels/_exceptions.py +9 -37
- squirrels/_initializer.py +83 -59
- squirrels/_logging.py +117 -0
- squirrels/_manifest.py +129 -62
- squirrels/_model_builder.py +10 -52
- squirrels/_model_configs.py +3 -3
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +249 -118
- squirrels/{package_data → _package_data}/base_project/.env +16 -4
- squirrels/{package_data → _package_data}/base_project/.env.example +15 -3
- squirrels/{package_data → _package_data}/base_project/connections.yml +4 -3
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +2 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +48 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
- squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +7 -7
- squirrels/{package_data → _package_data}/base_project/models/sources.yml +5 -6
- squirrels/{package_data → _package_data}/base_project/parameters.yml +32 -45
- squirrels/_package_data/base_project/pyconfigs/connections.py +18 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +31 -22
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
- squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +1 -1
- squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_parameter_configs.py +76 -55
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +53 -45
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +403 -242
- squirrels/_py_module.py +3 -2
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +167 -0
- squirrels/_schemas/query_param_models.py +75 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +48 -18
- squirrels/_seeds.py +1 -1
- squirrels/_sources.py +23 -19
- squirrels/_utils.py +121 -39
- squirrels/_version.py +1 -1
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +2 -81
- squirrels/data_sources.py +14 -563
- squirrels/parameter_options.py +13 -348
- squirrels/parameters.py +14 -1266
- squirrels/types.py +16 -0
- {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/METADATA +42 -30
- squirrels-0.5.1.dist-info/RECORD +98 -0
- squirrels/package_data/base_project/dashboards/dashboard_example.yml +0 -22
- squirrels/package_data/base_project/macros/macros_example.sql +0 -15
- squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -12
- squirrels/package_data/base_project/models/dbviews/dbview_example.yml +0 -26
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
- squirrels/package_data/base_project/pyconfigs/user.py +0 -23
- squirrels/package_data/base_project/squirrels.yml.j2 +0 -71
- squirrels-0.5.0rc0.dist-info/RECORD +0 -70
- /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
- /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
- {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py
CHANGED
|
@@ -1,96 +1,137 @@
|
|
|
1
|
+
from typing import Callable, Any
|
|
1
2
|
from datetime import datetime, timedelta, timezone
|
|
2
|
-
from enum import Enum
|
|
3
3
|
from functools import cached_property
|
|
4
4
|
from jwt.exceptions import InvalidTokenError
|
|
5
5
|
from passlib.context import CryptContext
|
|
6
|
-
from pydantic import
|
|
7
|
-
from pydantic_core import PydanticUndefined
|
|
6
|
+
from pydantic import ValidationError
|
|
8
7
|
from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
|
|
9
|
-
from sqlalchemy import Column, String, Integer, Float, Boolean
|
|
10
8
|
from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
|
|
11
|
-
import jwt,
|
|
9
|
+
import jwt, uuid, secrets, json
|
|
12
10
|
|
|
13
11
|
from ._manifest import PermissionScope
|
|
14
12
|
from ._py_module import PyModule
|
|
15
13
|
from ._exceptions import InvalidInputError, ConfigurationError
|
|
14
|
+
from ._arguments.init_time_args import AuthProviderArgs
|
|
15
|
+
from ._schemas.auth_models import (
|
|
16
|
+
CustomUserFields, AbstractUser, GuestUser, RegisteredUser, ApiKey, UserField, AuthProvider, ProviderConfigs,
|
|
17
|
+
ClientRegistrationRequest, ClientUpdateRequest, ClientDetailsResponse, ClientRegistrationResponse,
|
|
18
|
+
ClientUpdateResponse, TokenResponse
|
|
19
|
+
)
|
|
16
20
|
from . import _utils as u, _constants as c
|
|
17
21
|
|
|
18
22
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
19
23
|
|
|
20
|
-
|
|
21
|
-
disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
|
|
24
|
+
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
|
|
22
25
|
|
|
23
|
-
class BaseUser(BaseModel):
|
|
24
|
-
model_config = ConfigDict(from_attributes=True)
|
|
25
|
-
username: str
|
|
26
|
-
is_admin: bool = False
|
|
27
|
-
|
|
28
|
-
@classmethod
|
|
29
|
-
def dropped_columns(cls):
|
|
30
|
-
return []
|
|
31
|
-
|
|
32
|
-
def __hash__(self):
|
|
33
|
-
return hash(self.username)
|
|
34
|
-
|
|
35
|
-
User = _t.TypeVar('User', bound=BaseUser)
|
|
36
|
-
|
|
37
|
-
class AccessToken(BaseModel):
|
|
38
|
-
model_config = ConfigDict(from_attributes=True)
|
|
39
|
-
token_id: str
|
|
40
|
-
title: str
|
|
41
|
-
username: str
|
|
42
|
-
created_at: datetime
|
|
43
|
-
expires_at: datetime
|
|
44
26
|
|
|
27
|
+
class Authenticator:
|
|
28
|
+
providers: list[ProviderFunctionType] = [] # static variable to stage providers
|
|
45
29
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
enum: list[str] | None
|
|
51
|
-
default: _t.Any | None
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class Authenticator(_t.Generic[User]):
|
|
55
|
-
def __init__(self, logger: u.Logger, base_path: str, env_vars: dict[str, str], *, sa_engine: Engine | None = None, cls: type[User] | None = None):
|
|
30
|
+
def __init__(
|
|
31
|
+
self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
|
|
32
|
+
custom_user_fields_cls: type[CustomUserFields], *, sa_engine: Engine | None = None, external_only: bool = False
|
|
33
|
+
):
|
|
56
34
|
self.logger = logger
|
|
57
|
-
self.env_vars = env_vars
|
|
35
|
+
self.env_vars = auth_args.env_vars
|
|
58
36
|
self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
|
|
37
|
+
self.external_only = external_only
|
|
59
38
|
|
|
60
39
|
# Create a new declarative base for this instance
|
|
61
40
|
self.Base = declarative_base()
|
|
62
41
|
|
|
63
|
-
# Define
|
|
64
|
-
class
|
|
42
|
+
# Define DbUser class for this instance
|
|
43
|
+
class DbUser(self.Base):
|
|
65
44
|
__tablename__ = 'users'
|
|
66
45
|
__table_args__ = {'extend_existing': True}
|
|
67
46
|
username: Mapped[str] = mapped_column(primary_key=True)
|
|
68
|
-
|
|
47
|
+
access_level: Mapped[str] = mapped_column(nullable=False, default="member")
|
|
69
48
|
password_hash: Mapped[str] = mapped_column(nullable=False)
|
|
49
|
+
custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
|
|
70
50
|
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
71
51
|
|
|
72
|
-
# Define
|
|
73
|
-
class
|
|
74
|
-
__tablename__ = '
|
|
52
|
+
# Define DbApiKey class for this instance
|
|
53
|
+
class DbApiKey(self.Base):
|
|
54
|
+
__tablename__ = 'api_keys'
|
|
75
55
|
|
|
76
|
-
|
|
56
|
+
id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
57
|
+
hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
|
|
77
58
|
title: Mapped[str] = mapped_column(nullable=False)
|
|
78
59
|
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
79
60
|
created_at: Mapped[datetime] = mapped_column(nullable=False)
|
|
80
61
|
expires_at: Mapped[datetime] = mapped_column(nullable=False)
|
|
81
|
-
|
|
62
|
+
|
|
63
|
+
def __repr__(self):
|
|
64
|
+
return f"<DbApiKey(id='{self.id}', username='{self.username}')>"
|
|
65
|
+
|
|
66
|
+
# Define DbOAuthClient class for this instance
|
|
67
|
+
class DbOAuthClient(self.Base):
|
|
68
|
+
__tablename__ = 'oauth_clients'
|
|
69
|
+
|
|
70
|
+
client_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
71
|
+
client_secret_hash: Mapped[str] = mapped_column(nullable=False)
|
|
72
|
+
client_name: Mapped[str] = mapped_column(nullable=False)
|
|
73
|
+
redirect_uris: Mapped[str] = mapped_column(nullable=False) # JSON array of allowed redirect URIs
|
|
74
|
+
scope: Mapped[str] = mapped_column(nullable=False, default='read')
|
|
75
|
+
grant_types: Mapped[str] = mapped_column(nullable=False, default='authorization_code,refresh_token')
|
|
76
|
+
response_types: Mapped[str] = mapped_column(nullable=False, default='code')
|
|
77
|
+
client_type: Mapped[str] = mapped_column(nullable=False, default='confidential') # 'confidential' or 'public'
|
|
78
|
+
registration_access_token_hash: Mapped[str] = mapped_column(nullable=False) # Token for client management
|
|
79
|
+
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
80
|
+
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
|
81
|
+
|
|
82
|
+
def __repr__(self):
|
|
83
|
+
return f"<DbOAuthClient(client_id='{self.client_id}', name='{self.client_name}')>"
|
|
84
|
+
|
|
85
|
+
# Define DbAuthorizationCode class for this instance
|
|
86
|
+
class DbAuthorizationCode(self.Base):
|
|
87
|
+
__tablename__ = 'authorization_codes'
|
|
88
|
+
|
|
89
|
+
code: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
90
|
+
client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
|
|
91
|
+
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
92
|
+
redirect_uri: Mapped[str] = mapped_column(nullable=False)
|
|
93
|
+
scope: Mapped[str] = mapped_column(nullable=True)
|
|
94
|
+
code_challenge: Mapped[str] = mapped_column(nullable=False) # PKCE always required
|
|
95
|
+
code_challenge_method: Mapped[str] = mapped_column(nullable=False) # only S256 is supported
|
|
96
|
+
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
97
|
+
expires_at: Mapped[datetime] = mapped_column(nullable=False) # 10 minutes from creation
|
|
98
|
+
used: Mapped[bool] = mapped_column(nullable=False, default=False)
|
|
99
|
+
|
|
100
|
+
def __repr__(self):
|
|
101
|
+
return f"<DbAuthorizationCode(code='{self.code[:8]}...', client_id='{self.client_id}')>"
|
|
102
|
+
|
|
103
|
+
# Define DbOAuthToken class for this instance
|
|
104
|
+
class DbOAuthToken(self.Base):
|
|
105
|
+
__tablename__ = 'oauth_tokens'
|
|
106
|
+
|
|
107
|
+
token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
108
|
+
access_token_hash: Mapped[str] = mapped_column(unique=True, nullable=False)
|
|
109
|
+
refresh_token_hash: Mapped[str] = mapped_column(unique=True, nullable=True) # NULL for client_credentials grants
|
|
110
|
+
client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
|
|
111
|
+
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
112
|
+
scope: Mapped[str] = mapped_column(nullable=True)
|
|
113
|
+
token_type: Mapped[str] = mapped_column(nullable=False, default='Bearer')
|
|
114
|
+
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
115
|
+
access_token_expires_at: Mapped[datetime] = mapped_column(nullable=False) # Uses SQRL_AUTH_TOKEN_EXPIRE_MINUTES
|
|
116
|
+
refresh_token_expires_at: Mapped[datetime] = mapped_column(nullable=True) # 30 days from creation, NULL for client_credentials
|
|
117
|
+
is_revoked: Mapped[bool] = mapped_column(nullable=False, default=False)
|
|
118
|
+
|
|
82
119
|
def __repr__(self):
|
|
83
|
-
return f"<
|
|
120
|
+
return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"
|
|
84
121
|
|
|
85
|
-
self.
|
|
86
|
-
self.
|
|
122
|
+
self.DbApiKey = DbApiKey
|
|
123
|
+
self.DbOAuthClient = DbOAuthClient
|
|
124
|
+
self.DbAuthorizationCode = DbAuthorizationCode
|
|
125
|
+
self.DbOAuthToken = DbOAuthToken
|
|
87
126
|
|
|
88
|
-
self.
|
|
89
|
-
self.DbUser
|
|
127
|
+
self.CustomUserFields = custom_user_fields_cls
|
|
128
|
+
self.DbUser = DbUser
|
|
129
|
+
|
|
130
|
+
self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
|
|
90
131
|
|
|
91
132
|
if sa_engine is None:
|
|
92
|
-
|
|
93
|
-
sqlite_path = u.Path(base_path
|
|
133
|
+
raw_sqlite_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{{project_path}}/{c.TARGET_FOLDER}/{c.DB_FILE}")
|
|
134
|
+
sqlite_path = u.Path(raw_sqlite_path.format(project_path=base_path))
|
|
94
135
|
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
|
95
136
|
self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
|
|
96
137
|
else:
|
|
@@ -106,67 +147,21 @@ class Authenticator(_t.Generic[User]):
|
|
|
106
147
|
|
|
107
148
|
self.Session = sessionmaker(bind=self.engine)
|
|
108
149
|
|
|
109
|
-
self._initialize_db(
|
|
150
|
+
self._initialize_db()
|
|
110
151
|
|
|
111
|
-
def
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
return
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
# Iterate over all fields in the User model
|
|
124
|
-
for field_name, field in self.User.model_fields.items():
|
|
125
|
-
if field_name in reserved_fields:
|
|
126
|
-
continue
|
|
127
|
-
if field_name in disallowed_fields:
|
|
128
|
-
raise ConfigurationError(f"Field name '{field_name}' is disallowed in the User model and cannot be used")
|
|
129
|
-
|
|
130
|
-
field_type = field.annotation
|
|
131
|
-
if _t.get_origin(field_type) in (_t.Union, types.UnionType):
|
|
132
|
-
field_type = _t.get_args(field_type)[0]
|
|
133
|
-
nullable = True
|
|
134
|
-
else:
|
|
135
|
-
nullable = False
|
|
152
|
+
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
|
|
153
|
+
"""Convert a database user to an AbstractUser object"""
|
|
154
|
+
# Deserialize custom_fields JSON and merge with defaults
|
|
155
|
+
custom_fields_json = json.loads(db_user.custom_fields) if db_user.custom_fields else {}
|
|
156
|
+
custom_fields = self.CustomUserFields(**custom_fields_json)
|
|
157
|
+
|
|
158
|
+
return RegisteredUser(
|
|
159
|
+
username=db_user.username,
|
|
160
|
+
access_level=db_user.access_level,
|
|
161
|
+
custom_fields=custom_fields
|
|
162
|
+
)
|
|
136
163
|
|
|
137
|
-
|
|
138
|
-
field_type = str
|
|
139
|
-
|
|
140
|
-
# Map Python types and default values to SQLAlchemy columns
|
|
141
|
-
default_value = field.default
|
|
142
|
-
if default_value is PydanticUndefined:
|
|
143
|
-
raise ConfigurationError(f"No default value found for field '{field_name}' in User model")
|
|
144
|
-
elif not nullable and default_value is None:
|
|
145
|
-
raise ConfigurationError(f"Default value for non-nullable field '{field_name}' was set as None in User model")
|
|
146
|
-
elif default_value is not None and type(default_value) is not field_type:
|
|
147
|
-
raise ConfigurationError(f"Default value for field '{field_name}' does not match field type in User model")
|
|
148
|
-
|
|
149
|
-
if field_type == str:
|
|
150
|
-
col_type = String
|
|
151
|
-
elif field_type == int:
|
|
152
|
-
col_type = Integer
|
|
153
|
-
elif field_type == float:
|
|
154
|
-
col_type = Float
|
|
155
|
-
elif field_type == bool:
|
|
156
|
-
col_type = Boolean
|
|
157
|
-
elif isinstance(field_type, type) and issubclass(field_type, Enum):
|
|
158
|
-
col_type = String
|
|
159
|
-
default_value = default_value.value
|
|
160
|
-
else:
|
|
161
|
-
continue
|
|
162
|
-
|
|
163
|
-
attrs[field_name] = Column(col_type, nullable=nullable, default=default_value) # type: ignore
|
|
164
|
-
|
|
165
|
-
# Create the sqlalchemy model class
|
|
166
|
-
DbUser = type('DbUser', (self.DbBaseUser,), attrs)
|
|
167
|
-
return DbUser
|
|
168
|
-
|
|
169
|
-
def _initialize_db(self, *args): # TODO: Use logger instead of print
|
|
164
|
+
def _initialize_db(self): # TODO: Use logger instead of print
|
|
170
165
|
session = self.Session()
|
|
171
166
|
try:
|
|
172
167
|
# Get existing columns in the database
|
|
@@ -174,14 +169,12 @@ class Authenticator(_t.Generic[User]):
|
|
|
174
169
|
existing_columns = {col['name'] for col in inspector.get_columns('users')}
|
|
175
170
|
|
|
176
171
|
# Get all columns defined in the model
|
|
177
|
-
|
|
178
|
-
model_columns = set(self.DbUser.__table__.columns.keys()) - dropped_columns
|
|
172
|
+
model_columns = set(self.DbUser.__table__.columns.keys())
|
|
179
173
|
|
|
180
174
|
# Find columns that are in the model but not in the database
|
|
181
175
|
new_columns = model_columns - existing_columns
|
|
182
176
|
if new_columns:
|
|
183
177
|
add_columns_msg = f"Adding columns to database: {new_columns}"
|
|
184
|
-
print("NOTE:", add_columns_msg)
|
|
185
178
|
self.logger.info(add_columns_msg)
|
|
186
179
|
|
|
187
180
|
for col_name in new_columns:
|
|
@@ -198,24 +191,6 @@ class Authenticator(_t.Generic[User]):
|
|
|
198
191
|
session.execute(text(alter_stmt))
|
|
199
192
|
|
|
200
193
|
session.commit()
|
|
201
|
-
|
|
202
|
-
# Determine columns to drop
|
|
203
|
-
columns_to_drop = dropped_columns.intersection(existing_columns)
|
|
204
|
-
if columns_to_drop:
|
|
205
|
-
drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
|
|
206
|
-
print("NOTE:", drop_columns_msg)
|
|
207
|
-
self.logger.info(drop_columns_msg)
|
|
208
|
-
|
|
209
|
-
for col_name in columns_to_drop:
|
|
210
|
-
session.execute(text(f"ALTER TABLE users DROP COLUMN {col_name}"))
|
|
211
|
-
|
|
212
|
-
session.commit()
|
|
213
|
-
|
|
214
|
-
# Find columns that are in the database but not in the model
|
|
215
|
-
extra_db_columns = existing_columns - columns_to_drop - model_columns
|
|
216
|
-
if extra_db_columns:
|
|
217
|
-
self.logger.warn(f"The following database columns are not in the User model: {extra_db_columns}\n"
|
|
218
|
-
"If you want to drop these columns, please use the `dropped_columns` class method of the User model.")
|
|
219
194
|
|
|
220
195
|
# Get admin password from environment variable if exists
|
|
221
196
|
admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)
|
|
@@ -225,7 +200,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
225
200
|
password_hash = pwd_context.hash(admin_password)
|
|
226
201
|
admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
|
|
227
202
|
if admin_user is None:
|
|
228
|
-
admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash,
|
|
203
|
+
admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
|
|
229
204
|
session.add(admin_user)
|
|
230
205
|
else:
|
|
231
206
|
admin_user.password_hash = password_hash
|
|
@@ -238,7 +213,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
238
213
|
@cached_property
|
|
239
214
|
def user_fields(self) -> list[UserField]:
|
|
240
215
|
"""
|
|
241
|
-
Get the fields of the
|
|
216
|
+
Get the fields of the CustomUserFields model as a list of dictionaries
|
|
242
217
|
|
|
243
218
|
Each dictionary contains the following keys:
|
|
244
219
|
- name: The name of the field
|
|
@@ -247,11 +222,15 @@ class Authenticator(_t.Generic[User]):
|
|
|
247
222
|
- enum: The possible values of the field (or None if not applicable)
|
|
248
223
|
- default: The default value of the field (or None if field is required)
|
|
249
224
|
"""
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
225
|
+
# Add the base fields
|
|
226
|
+
fields = [
|
|
227
|
+
UserField(name="username", type="string", nullable=False, enum=None, default=None),
|
|
228
|
+
UserField(name="access_level", type="string", nullable=False, enum=["admin", "member"], default="member")
|
|
229
|
+
]
|
|
253
230
|
|
|
254
|
-
|
|
231
|
+
# Add custom fields
|
|
232
|
+
schema = self.CustomUserFields.model_json_schema()
|
|
233
|
+
properties: dict[str, dict[str, Any]] = schema.get("properties", {})
|
|
255
234
|
for field_name, field_schema in properties.items():
|
|
256
235
|
if choices := field_schema.get("anyOf"):
|
|
257
236
|
field_type = choices[0]["type"]
|
|
@@ -268,54 +247,99 @@ class Authenticator(_t.Generic[User]):
|
|
|
268
247
|
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
|
|
269
248
|
session = self.Session()
|
|
270
249
|
|
|
271
|
-
#
|
|
250
|
+
# Separate custom fields from base fields
|
|
251
|
+
access_level = user_fields.get('access_level', 'member')
|
|
252
|
+
password = user_fields.get('password')
|
|
253
|
+
|
|
254
|
+
# Validate access level - cannot add/update users with guest access level
|
|
255
|
+
if access_level == "guest":
|
|
256
|
+
raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")
|
|
257
|
+
|
|
258
|
+
# Extract custom fields (everything except username, access_level, password)
|
|
259
|
+
custom_fields_dict = {k: v for k, v in user_fields.items() if k not in ['username', 'access_level', 'password']}
|
|
260
|
+
|
|
261
|
+
# Validate the custom fields
|
|
272
262
|
try:
|
|
273
|
-
|
|
263
|
+
custom_fields = self.CustomUserFields(**custom_fields_dict)
|
|
264
|
+
custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
|
|
274
265
|
except ValidationError as e:
|
|
275
|
-
raise InvalidInputError(
|
|
266
|
+
raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
|
|
276
267
|
|
|
277
|
-
# Add
|
|
268
|
+
# Add or update user
|
|
278
269
|
try:
|
|
279
270
|
# Check if the user already exists
|
|
280
271
|
existing_user = session.get(self.DbUser, username)
|
|
281
272
|
if existing_user is not None:
|
|
282
273
|
if not update_user:
|
|
283
|
-
raise InvalidInputError(
|
|
274
|
+
raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")
|
|
284
275
|
|
|
285
|
-
if username == c.ADMIN_USERNAME:
|
|
286
|
-
raise InvalidInputError(
|
|
287
|
-
|
|
288
|
-
|
|
276
|
+
if username == c.ADMIN_USERNAME and access_level != "admin":
|
|
277
|
+
raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")
|
|
278
|
+
|
|
279
|
+
# Update existing user
|
|
280
|
+
existing_user.access_level = access_level
|
|
281
|
+
existing_user.custom_fields = custom_fields_json
|
|
289
282
|
else:
|
|
290
283
|
if update_user:
|
|
291
|
-
raise InvalidInputError(
|
|
284
|
+
raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
|
|
292
285
|
|
|
293
|
-
password = user_fields.get('password')
|
|
294
286
|
if password is None:
|
|
295
|
-
raise InvalidInputError(
|
|
287
|
+
raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")
|
|
288
|
+
|
|
296
289
|
password_hash = pwd_context.hash(password)
|
|
297
|
-
new_user = self.DbUser(
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
290
|
+
new_user = self.DbUser(
|
|
291
|
+
username=username,
|
|
292
|
+
access_level=access_level,
|
|
293
|
+
password_hash=password_hash,
|
|
294
|
+
custom_fields=custom_fields_json
|
|
295
|
+
)
|
|
296
|
+
session.add(new_user)
|
|
301
297
|
|
|
302
298
|
# Commit the transaction
|
|
303
299
|
session.commit()
|
|
304
300
|
|
|
305
301
|
finally:
|
|
306
302
|
session.close()
|
|
303
|
+
|
|
304
|
+
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
|
|
305
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
306
|
+
if provider is None:
|
|
307
|
+
raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
|
|
308
|
+
|
|
309
|
+
user = provider.provider_configs.get_user(user_info)
|
|
310
|
+
session = self.Session()
|
|
311
|
+
try:
|
|
312
|
+
# Convert user to database user
|
|
313
|
+
custom_fields_json = user.custom_fields.model_dump_json()
|
|
314
|
+
db_user = self.DbUser(
|
|
315
|
+
username=user.username,
|
|
316
|
+
access_level=user.access_level,
|
|
317
|
+
password_hash="", # By omitting password_hash, it becomes impossible to login with username and password (OAuth only)
|
|
318
|
+
custom_fields=custom_fields_json
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
existing_db_user = session.get(self.DbUser, db_user.username)
|
|
322
|
+
if existing_db_user is None:
|
|
323
|
+
session.add(db_user)
|
|
324
|
+
|
|
325
|
+
session.commit()
|
|
326
|
+
|
|
327
|
+
return self._convert_db_user_to_user(db_user)
|
|
328
|
+
|
|
329
|
+
finally:
|
|
330
|
+
session.close()
|
|
307
331
|
|
|
308
|
-
def get_user(self, username: str, password: str) ->
|
|
332
|
+
def get_user(self, username: str, password: str) -> RegisteredUser:
|
|
309
333
|
session = self.Session()
|
|
310
334
|
try:
|
|
311
335
|
# Query for user by username
|
|
312
336
|
db_user = session.get(self.DbUser, username)
|
|
313
337
|
|
|
314
338
|
if db_user and pwd_context.verify(password, db_user.password_hash):
|
|
315
|
-
user = self.
|
|
316
|
-
return user
|
|
339
|
+
user = self._convert_db_user_to_user(db_user)
|
|
340
|
+
return user
|
|
317
341
|
else:
|
|
318
|
-
raise InvalidInputError(
|
|
342
|
+
raise InvalidInputError(401, "incorrect_username_or_password", f"Incorrect username or password")
|
|
319
343
|
|
|
320
344
|
finally:
|
|
321
345
|
session.close()
|
|
@@ -325,127 +349,611 @@ class Authenticator(_t.Generic[User]):
|
|
|
325
349
|
try:
|
|
326
350
|
db_user = session.get(self.DbUser, username)
|
|
327
351
|
if db_user is None:
|
|
328
|
-
raise InvalidInputError(
|
|
352
|
+
raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")
|
|
329
353
|
|
|
330
|
-
if pwd_context.verify(old_password, db_user.password_hash):
|
|
354
|
+
if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
|
|
331
355
|
db_user.password_hash = pwd_context.hash(new_password)
|
|
332
356
|
session.commit()
|
|
333
357
|
else:
|
|
334
|
-
raise InvalidInputError(
|
|
358
|
+
raise InvalidInputError(401, "incorrect_password", f"Incorrect password")
|
|
335
359
|
finally:
|
|
336
360
|
session.close()
|
|
337
361
|
|
|
338
362
|
def delete_user(self, username: str) -> None:
|
|
339
363
|
if username == c.ADMIN_USERNAME:
|
|
340
|
-
raise InvalidInputError(
|
|
364
|
+
raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")
|
|
341
365
|
|
|
342
366
|
session = self.Session()
|
|
343
367
|
try:
|
|
344
368
|
db_user = session.get(self.DbUser, username)
|
|
345
369
|
if db_user is None:
|
|
346
|
-
raise InvalidInputError(
|
|
370
|
+
raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
|
|
347
371
|
session.delete(db_user)
|
|
348
372
|
session.commit()
|
|
349
373
|
finally:
|
|
350
374
|
session.close()
|
|
351
375
|
|
|
352
|
-
def get_all_users(self) -> list[
|
|
376
|
+
def get_all_users(self) -> list[RegisteredUser]:
|
|
353
377
|
session = self.Session()
|
|
354
378
|
try:
|
|
355
379
|
db_users = session.query(self.DbUser).all()
|
|
356
|
-
return [self.
|
|
380
|
+
return [self._convert_db_user_to_user(user) for user in db_users]
|
|
357
381
|
finally:
|
|
358
382
|
session.close()
|
|
359
383
|
|
|
360
|
-
def create_access_token(self, user:
|
|
384
|
+
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
|
|
385
|
+
"""
|
|
386
|
+
Creates an API key if title is provided. Otherwise, creates a JWT token.
|
|
387
|
+
"""
|
|
361
388
|
created_at = datetime.now(timezone.utc)
|
|
362
389
|
expire_at = created_at + timedelta(minutes=expiry_minutes) if expiry_minutes is not None else datetime.max
|
|
363
|
-
|
|
390
|
+
|
|
391
|
+
if self.secret_key is None:
|
|
392
|
+
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
|
|
393
|
+
|
|
364
394
|
if title is not None:
|
|
365
395
|
session = self.Session()
|
|
366
396
|
try:
|
|
367
|
-
|
|
368
|
-
|
|
397
|
+
token_id = "sqrl-" + secrets.token_urlsafe(16)
|
|
398
|
+
hashed_key = u.hash_string(token_id, salt=self.secret_key)
|
|
399
|
+
api_key = self.DbApiKey(hashed_key=hashed_key, title=title, username=user.username, created_at=created_at, expires_at=expire_at)
|
|
400
|
+
session.add(api_key)
|
|
369
401
|
session.commit()
|
|
370
|
-
token_id = access_token.token_id
|
|
371
402
|
finally:
|
|
372
403
|
session.close()
|
|
404
|
+
else:
|
|
405
|
+
to_encode = {"username": user.username, "exp": expire_at}
|
|
406
|
+
token_id = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
|
|
373
407
|
|
|
374
|
-
|
|
375
|
-
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
|
|
376
|
-
to_encode = {"username": user.username, "token_id": token_id, "exp": expire_at}
|
|
377
|
-
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
|
|
378
|
-
return encoded_jwt, expire_at
|
|
408
|
+
return token_id, expire_at
|
|
379
409
|
|
|
380
|
-
def get_user_from_token(self, token: str | None) ->
|
|
381
|
-
|
|
410
|
+
def get_user_from_token(self, token: str | None) -> RegisteredUser | None:
|
|
411
|
+
"""
|
|
412
|
+
Get a user from an access token (JWT, or API key if token starts with 'sqrl-')
|
|
413
|
+
"""
|
|
414
|
+
if not token:
|
|
382
415
|
return None
|
|
383
416
|
|
|
384
417
|
if self.secret_key is None:
|
|
385
418
|
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")
|
|
386
419
|
|
|
387
|
-
try:
|
|
388
|
-
payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
|
|
389
|
-
except InvalidTokenError:
|
|
390
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
391
|
-
|
|
392
420
|
session = self.Session()
|
|
393
421
|
try:
|
|
394
|
-
if
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
self.
|
|
398
|
-
self.
|
|
422
|
+
if token.startswith("sqrl-"):
|
|
423
|
+
hashed_key = u.hash_string(token, salt=self.secret_key)
|
|
424
|
+
api_key = session.query(self.DbApiKey).filter(
|
|
425
|
+
self.DbApiKey.hashed_key == hashed_key,
|
|
426
|
+
self.DbApiKey.expires_at >= func.now()
|
|
399
427
|
).first()
|
|
400
|
-
if
|
|
401
|
-
raise
|
|
402
|
-
|
|
403
|
-
|
|
428
|
+
if api_key is None:
|
|
429
|
+
raise InvalidTokenError()
|
|
430
|
+
username = api_key.username
|
|
431
|
+
else:
|
|
432
|
+
payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
|
|
433
|
+
username = payload["username"]
|
|
434
|
+
|
|
435
|
+
db_user = session.get(self.DbUser, username)
|
|
404
436
|
if db_user is None:
|
|
405
|
-
raise
|
|
437
|
+
raise InvalidTokenError()
|
|
438
|
+
|
|
439
|
+
user = self._convert_db_user_to_user(db_user)
|
|
440
|
+
|
|
441
|
+
except InvalidTokenError:
|
|
442
|
+
raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
|
|
406
443
|
finally:
|
|
407
444
|
session.close()
|
|
408
445
|
|
|
409
|
-
user
|
|
410
|
-
return user # type: ignore
|
|
446
|
+
return user
|
|
411
447
|
|
|
412
|
-
def
|
|
448
|
+
def get_all_api_keys(self, username: str) -> list[ApiKey]:
|
|
449
|
+
"""
|
|
450
|
+
Get the ID, title, and expiry date of all API keys for a user. Note that the ID is a hash of the API key, not the API key itself.
|
|
451
|
+
"""
|
|
413
452
|
session = self.Session()
|
|
414
453
|
try:
|
|
415
|
-
tokens = session.query(self.
|
|
416
|
-
self.
|
|
417
|
-
self.
|
|
454
|
+
tokens = session.query(self.DbApiKey).filter(
|
|
455
|
+
self.DbApiKey.username == username,
|
|
456
|
+
self.DbApiKey.expires_at >= func.now()
|
|
418
457
|
).all()
|
|
419
458
|
|
|
420
|
-
return [
|
|
459
|
+
return [ApiKey.model_validate(token) for token in tokens]
|
|
421
460
|
finally:
|
|
422
461
|
session.close()
|
|
423
462
|
|
|
424
|
-
def
|
|
463
|
+
def revoke_api_key(self, username: str, api_key_id: str) -> None:
|
|
464
|
+
"""
|
|
465
|
+
Revoke an API key
|
|
466
|
+
"""
|
|
425
467
|
session = self.Session()
|
|
426
468
|
try:
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
self.
|
|
469
|
+
|
|
470
|
+
api_key = session.query(self.DbApiKey).filter(
|
|
471
|
+
self.DbApiKey.username == username,
|
|
472
|
+
self.DbApiKey.id == api_key_id
|
|
430
473
|
).first()
|
|
431
474
|
|
|
432
|
-
if
|
|
433
|
-
raise InvalidInputError(
|
|
475
|
+
if api_key is None:
|
|
476
|
+
raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")
|
|
434
477
|
|
|
435
|
-
session.delete(
|
|
478
|
+
session.delete(api_key)
|
|
436
479
|
session.commit()
|
|
437
480
|
finally:
|
|
438
481
|
session.close()
|
|
439
482
|
|
|
440
|
-
def can_user_access_scope(self, user:
|
|
441
|
-
if user
|
|
483
|
+
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
|
|
484
|
+
if user.access_level == "guest":
|
|
442
485
|
user_level = PermissionScope.PUBLIC
|
|
443
|
-
elif user.
|
|
486
|
+
elif user.access_level == "admin":
|
|
444
487
|
user_level = PermissionScope.PRIVATE
|
|
445
|
-
else:
|
|
488
|
+
else: # member
|
|
446
489
|
user_level = PermissionScope.PROTECTED
|
|
447
490
|
|
|
448
491
|
return user_level.value >= scope.value
|
|
449
492
|
|
|
493
|
+
# OAuth Client Management Methods
|
|
494
|
+
|
|
495
|
+
def generate_secret_and_hash(self) -> tuple[str, str]:
|
|
496
|
+
"""Generate a secure access token and its hash"""
|
|
497
|
+
secret = secrets.token_urlsafe(64)
|
|
498
|
+
secret_hash = pwd_context.hash(secret)
|
|
499
|
+
return secret, secret_hash
|
|
500
|
+
|
|
501
|
+
def _validate_client_registration_request(self, request: ClientRegistrationRequest | ClientUpdateRequest) -> dict:
|
|
502
|
+
updates = {}
|
|
503
|
+
if request.client_name:
|
|
504
|
+
updates['client_name'] = request.client_name
|
|
505
|
+
|
|
506
|
+
# Validate redirect_uris if being updated
|
|
507
|
+
if request.redirect_uris:
|
|
508
|
+
for uri in request.redirect_uris:
|
|
509
|
+
if not self._validate_redirect_uri_format(uri):
|
|
510
|
+
raise InvalidInputError(400, "invalid_redirect_uri", f"Invalid redirect URI format: {uri}")
|
|
511
|
+
updates['redirect_uris'] = json.dumps(request.redirect_uris)
|
|
512
|
+
|
|
513
|
+
# Validate grant_types if being updated
|
|
514
|
+
if request.grant_types:
|
|
515
|
+
if not all(grant_type in c.SUPPORTED_GRANT_TYPES for grant_type in request.grant_types):
|
|
516
|
+
raise InvalidInputError(400, "invalid_grant_types", f"Invalid grant types. Supported grant types are: {c.SUPPORTED_GRANT_TYPES}")
|
|
517
|
+
updates['grant_types'] = ','.join(request.grant_types)
|
|
518
|
+
|
|
519
|
+
# Validate response_types if being updated
|
|
520
|
+
if request.response_types:
|
|
521
|
+
if not all(response_type in c.SUPPORTED_RESPONSE_TYPES for response_type in request.response_types):
|
|
522
|
+
raise InvalidInputError(400, "invalid_response_types", f"Invalid response types. Supported response types are: {c.SUPPORTED_RESPONSE_TYPES}")
|
|
523
|
+
updates['response_types'] = ','.join(request.response_types)
|
|
524
|
+
|
|
525
|
+
# Validate scope if being updated
|
|
526
|
+
if request.scope:
|
|
527
|
+
scopes = request.scope.split()
|
|
528
|
+
if not all(scope in c.SUPPORTED_SCOPES for scope in scopes):
|
|
529
|
+
raise InvalidInputError(400, "invalid_scope", f"Invalid scope. Supported scopes are: {c.SUPPORTED_SCOPES}")
|
|
530
|
+
updates['scope'] = ','.join(scopes)
|
|
531
|
+
|
|
532
|
+
return updates
|
|
533
|
+
|
|
534
|
+
def register_oauth_client(
|
|
535
|
+
self, request: ClientRegistrationRequest, client_management_path_format: str
|
|
536
|
+
) -> ClientRegistrationResponse:
|
|
537
|
+
"""Register a new OAuth client and return client_id, client_secret, and registration_access_token"""
|
|
538
|
+
grant_types = request.grant_types
|
|
539
|
+
if grant_types is None:
|
|
540
|
+
grant_types = ['authorization_code', 'refresh_token']
|
|
541
|
+
|
|
542
|
+
# Validate request
|
|
543
|
+
self._validate_client_registration_request(request)
|
|
544
|
+
|
|
545
|
+
# Generate secure client credentials and registration access token
|
|
546
|
+
client_id = secrets.token_urlsafe(16)
|
|
547
|
+
client_secret, client_secret_hash = self.generate_secret_and_hash()
|
|
548
|
+
|
|
549
|
+
registration_access_token, registration_access_token_hash = self.generate_secret_and_hash()
|
|
550
|
+
registration_client_uri = client_management_path_format.format(client_id=client_id)
|
|
551
|
+
|
|
552
|
+
session = self.Session()
|
|
553
|
+
try:
|
|
554
|
+
oauth_client = self.DbOAuthClient(
|
|
555
|
+
client_id=client_id,
|
|
556
|
+
client_secret_hash=client_secret_hash,
|
|
557
|
+
client_name=request.client_name,
|
|
558
|
+
redirect_uris=json.dumps(request.redirect_uris),
|
|
559
|
+
scope=request.scope,
|
|
560
|
+
grant_types=','.join(grant_types),
|
|
561
|
+
registration_access_token_hash=registration_access_token_hash
|
|
562
|
+
)
|
|
563
|
+
session.add(oauth_client)
|
|
564
|
+
session.commit()
|
|
565
|
+
|
|
566
|
+
return ClientRegistrationResponse(
|
|
567
|
+
client_id=client_id,
|
|
568
|
+
client_secret=client_secret,
|
|
569
|
+
client_name=request.client_name,
|
|
570
|
+
redirect_uris=request.redirect_uris,
|
|
571
|
+
scope=request.scope,
|
|
572
|
+
grant_types=grant_types,
|
|
573
|
+
response_types=request.response_types,
|
|
574
|
+
created_at=datetime.now(timezone.utc),
|
|
575
|
+
is_active=True,
|
|
576
|
+
registration_client_uri=registration_client_uri,
|
|
577
|
+
registration_access_token=registration_access_token,
|
|
578
|
+
)
|
|
579
|
+
|
|
580
|
+
finally:
|
|
581
|
+
session.close()
|
|
582
|
+
|
|
583
|
+
def get_oauth_client_details(self, client_id: str) -> ClientDetailsResponse:
|
|
584
|
+
"""Get OAuth client details with parsed JSON fields"""
|
|
585
|
+
session = self.Session()
|
|
586
|
+
try:
|
|
587
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
588
|
+
if client is None:
|
|
589
|
+
raise InvalidInputError(404, "invalid_client_id", "Client not found")
|
|
590
|
+
if not client.is_active:
|
|
591
|
+
raise InvalidInputError(404, "invalid_client_id", "Client is no longer active")
|
|
592
|
+
|
|
593
|
+
return ClientDetailsResponse(
|
|
594
|
+
client_id=client.client_id,
|
|
595
|
+
client_name=client.client_name,
|
|
596
|
+
redirect_uris=json.loads(client.redirect_uris),
|
|
597
|
+
scope=client.scope,
|
|
598
|
+
grant_types=client.grant_types.split(','),
|
|
599
|
+
response_types=client.response_types.split(','),
|
|
600
|
+
created_at=client.created_at,
|
|
601
|
+
is_active=client.is_active
|
|
602
|
+
)
|
|
603
|
+
finally:
|
|
604
|
+
session.close()
|
|
605
|
+
|
|
606
|
+
def validate_client_credentials(self, client_id: str, client_secret: str) -> bool:
|
|
607
|
+
"""Validate OAuth client credentials"""
|
|
608
|
+
session = self.Session()
|
|
609
|
+
try:
|
|
610
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
611
|
+
if client is None or not client.is_active:
|
|
612
|
+
return False
|
|
613
|
+
return pwd_context.verify(client_secret, client.client_secret_hash)
|
|
614
|
+
finally:
|
|
615
|
+
session.close()
|
|
616
|
+
|
|
617
|
+
def validate_redirect_uri(self, client_id: str, redirect_uri: str) -> bool:
|
|
618
|
+
"""Validate that redirect_uri is registered for the client"""
|
|
619
|
+
session = self.Session()
|
|
620
|
+
try:
|
|
621
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
622
|
+
if client is None or not client.is_active:
|
|
623
|
+
return False
|
|
624
|
+
|
|
625
|
+
registered_uris = json.loads(client.redirect_uris)
|
|
626
|
+
return redirect_uri in registered_uris
|
|
627
|
+
finally:
|
|
628
|
+
session.close()
|
|
629
|
+
|
|
630
|
+
def validate_registration_access_token(self, client_id: str, registration_access_token: str) -> bool:
|
|
631
|
+
"""Validate registration access token for client management operations"""
|
|
632
|
+
session = self.Session()
|
|
633
|
+
try:
|
|
634
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
635
|
+
if client is None:
|
|
636
|
+
return False
|
|
637
|
+
|
|
638
|
+
return pwd_context.verify(registration_access_token, client.registration_access_token_hash)
|
|
639
|
+
finally:
|
|
640
|
+
session.close()
|
|
641
|
+
|
|
642
|
+
def _validate_redirect_uri_format(self, uri: str) -> bool:
|
|
643
|
+
"""Validate redirect URI format for security"""
|
|
644
|
+
# Basic validation - must be https (except localhost) and not contain fragments
|
|
645
|
+
if '#' in uri:
|
|
646
|
+
return False
|
|
647
|
+
|
|
648
|
+
if uri.startswith('http://'):
|
|
649
|
+
# Only allow http for localhost/development
|
|
650
|
+
if not ('localhost' in uri or '127.0.0.1' in uri):
|
|
651
|
+
return False
|
|
652
|
+
elif not uri.startswith('https://'):
|
|
653
|
+
# Custom schemes allowed for mobile apps
|
|
654
|
+
if '://' not in uri:
|
|
655
|
+
return False
|
|
656
|
+
|
|
657
|
+
return True
|
|
658
|
+
|
|
659
|
+
def update_oauth_client_with_token_rotation(self, client_id: str, request: ClientUpdateRequest) -> ClientUpdateResponse:
|
|
660
|
+
"""Update OAuth client metadata and rotate registration access token"""
|
|
661
|
+
|
|
662
|
+
# Build update dictionary, excluding None values
|
|
663
|
+
updates = self._validate_client_registration_request(request)
|
|
664
|
+
|
|
665
|
+
session = self.Session()
|
|
666
|
+
try:
|
|
667
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
668
|
+
if client is None:
|
|
669
|
+
raise InvalidInputError(404, "invalid_client_id", f"OAuth client not found: {client_id}")
|
|
670
|
+
if not client.is_active and not request.is_active:
|
|
671
|
+
raise InvalidInputError(400, "invalid_request", "Cannot update a deactivated client")
|
|
672
|
+
|
|
673
|
+
if request.is_active is not None:
|
|
674
|
+
client.is_active = request.is_active
|
|
675
|
+
|
|
676
|
+
# Generate new registration access token
|
|
677
|
+
new_registration_access_token, new_registration_access_token_hash = self.generate_secret_and_hash()
|
|
678
|
+
|
|
679
|
+
# Update client fields
|
|
680
|
+
for key, value in updates.items():
|
|
681
|
+
if hasattr(client, key):
|
|
682
|
+
setattr(client, key, value)
|
|
683
|
+
|
|
684
|
+
# Update registration access token
|
|
685
|
+
client.registration_access_token_hash = new_registration_access_token_hash
|
|
686
|
+
|
|
687
|
+
session.commit()
|
|
688
|
+
|
|
689
|
+
return ClientUpdateResponse(
|
|
690
|
+
client_id=client.client_id,
|
|
691
|
+
client_name=client.client_name,
|
|
692
|
+
redirect_uris=json.loads(client.redirect_uris),
|
|
693
|
+
scope=client.scope,
|
|
694
|
+
grant_types=client.grant_types.split(','),
|
|
695
|
+
response_types=client.response_types.split(','),
|
|
696
|
+
created_at=client.created_at,
|
|
697
|
+
is_active=client.is_active,
|
|
698
|
+
registration_access_token=new_registration_access_token,
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
finally:
|
|
702
|
+
session.close()
|
|
703
|
+
|
|
704
|
+
def revoke_oauth_client(self, client_id: str) -> None:
|
|
705
|
+
"""Revoke (deactivate) an OAuth client"""
|
|
706
|
+
session = self.Session()
|
|
707
|
+
try:
|
|
708
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
709
|
+
if client is None:
|
|
710
|
+
raise InvalidInputError(404, "client_not_found", f"OAuth client not found: {client_id}")
|
|
711
|
+
|
|
712
|
+
client.is_active = False
|
|
713
|
+
session.commit()
|
|
714
|
+
finally:
|
|
715
|
+
session.close()
|
|
716
|
+
|
|
717
|
+
# Authorization Code Flow Methods
|
|
718
|
+
|
|
719
|
+
def create_authorization_code(
|
|
720
|
+
self, client_id: str, username: str, redirect_uri: str, scope: str,
|
|
721
|
+
code_challenge: str, code_challenge_method: str = "S256"
|
|
722
|
+
) -> str:
|
|
723
|
+
"""Create and store an authorization code for the authorization code flow"""
|
|
724
|
+
|
|
725
|
+
# Validate client_id
|
|
726
|
+
client_details = self.get_oauth_client_details(client_id)
|
|
727
|
+
if client_details is None or not client_details.is_active:
|
|
728
|
+
raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")
|
|
729
|
+
|
|
730
|
+
# Validate redirect_uri
|
|
731
|
+
if not self.validate_redirect_uri(client_id, redirect_uri):
|
|
732
|
+
raise InvalidInputError(400, "invalid_redirect_uri", "Invalid redirect_uri for this client")
|
|
733
|
+
|
|
734
|
+
# Validate PKCE parameters
|
|
735
|
+
if not code_challenge:
|
|
736
|
+
raise InvalidInputError(400, "invalid_request", "code_challenge is required")
|
|
737
|
+
|
|
738
|
+
if code_challenge_method not in ["S256"]:
|
|
739
|
+
raise InvalidInputError(400, "invalid_request", "code_challenge_method must be 'S256'")
|
|
740
|
+
|
|
741
|
+
# Generate authorization code
|
|
742
|
+
code = secrets.token_urlsafe(48)
|
|
743
|
+
|
|
744
|
+
# Set expiration (10 minutes from now)
|
|
745
|
+
created_at = datetime.now(timezone.utc)
|
|
746
|
+
expires_at = created_at + timedelta(minutes=10)
|
|
747
|
+
|
|
748
|
+
session = self.Session()
|
|
749
|
+
try:
|
|
750
|
+
auth_code = self.DbAuthorizationCode(
|
|
751
|
+
code=code,
|
|
752
|
+
client_id=client_id,
|
|
753
|
+
username=username,
|
|
754
|
+
redirect_uri=redirect_uri,
|
|
755
|
+
scope=scope,
|
|
756
|
+
code_challenge=code_challenge,
|
|
757
|
+
code_challenge_method=code_challenge_method,
|
|
758
|
+
created_at=created_at,
|
|
759
|
+
expires_at=expires_at,
|
|
760
|
+
used=False
|
|
761
|
+
)
|
|
762
|
+
session.add(auth_code)
|
|
763
|
+
session.commit()
|
|
764
|
+
|
|
765
|
+
return code
|
|
766
|
+
|
|
767
|
+
finally:
|
|
768
|
+
session.close()
|
|
769
|
+
|
|
770
|
+
def exchange_authorization_code(
|
|
771
|
+
self, code: str, client_id: str, redirect_uri: str, code_verifier: str, access_token_expiry_minutes: int
|
|
772
|
+
) -> TokenResponse:
|
|
773
|
+
"""
|
|
774
|
+
Exchange authorization code for access and refresh tokens
|
|
775
|
+
Returns (access_token, refresh_token)
|
|
776
|
+
"""
|
|
777
|
+
session = self.Session()
|
|
778
|
+
try:
|
|
779
|
+
# Get and validate authorization code
|
|
780
|
+
auth_code = session.query(self.DbAuthorizationCode).filter(
|
|
781
|
+
self.DbAuthorizationCode.code == code,
|
|
782
|
+
self.DbAuthorizationCode.client_id == client_id,
|
|
783
|
+
self.DbAuthorizationCode.expires_at >= func.now(),
|
|
784
|
+
self.DbAuthorizationCode.used == False
|
|
785
|
+
).first()
|
|
786
|
+
|
|
787
|
+
if auth_code is None:
|
|
788
|
+
raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")
|
|
789
|
+
|
|
790
|
+
# Validate redirect URI
|
|
791
|
+
if auth_code.redirect_uri != redirect_uri:
|
|
792
|
+
raise InvalidInputError(400, "invalid_grant", "Redirect URI mismatch")
|
|
793
|
+
|
|
794
|
+
if not u.validate_pkce_challenge(code_verifier, auth_code.code_challenge):
|
|
795
|
+
raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")
|
|
796
|
+
|
|
797
|
+
# Get user
|
|
798
|
+
db_user = session.get(self.DbUser, auth_code.username)
|
|
799
|
+
if db_user is None:
|
|
800
|
+
raise InvalidInputError(400, "invalid_grant", "User not found")
|
|
801
|
+
|
|
802
|
+
# Mark authorization code as used
|
|
803
|
+
auth_code.used = True
|
|
804
|
+
|
|
805
|
+
# Generate tokens
|
|
806
|
+
user_obj = self._convert_db_user_to_user(db_user)
|
|
807
|
+
access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
|
|
808
|
+
access_token_hash = pwd_context.hash(access_token)
|
|
809
|
+
|
|
810
|
+
# Generate refresh token
|
|
811
|
+
refresh_token, refresh_token_hash = self.generate_secret_and_hash()
|
|
812
|
+
refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
|
813
|
+
|
|
814
|
+
oauth_token = self.DbOAuthToken(
|
|
815
|
+
access_token_hash=access_token_hash,
|
|
816
|
+
refresh_token_hash=refresh_token_hash,
|
|
817
|
+
client_id=client_id,
|
|
818
|
+
username=auth_code.username,
|
|
819
|
+
scope=auth_code.scope,
|
|
820
|
+
access_token_expires_at=token_expires_at,
|
|
821
|
+
refresh_token_expires_at=refresh_expires_at
|
|
822
|
+
)
|
|
823
|
+
session.add(oauth_token)
|
|
824
|
+
|
|
825
|
+
session.commit()
|
|
826
|
+
|
|
827
|
+
return TokenResponse(
|
|
828
|
+
access_token=access_token,
|
|
829
|
+
expires_in=access_token_expiry_minutes*60,
|
|
830
|
+
refresh_token=refresh_token
|
|
831
|
+
)
|
|
832
|
+
|
|
833
|
+
finally:
|
|
834
|
+
session.close()
|
|
835
|
+
|
|
836
|
+
def refresh_oauth_access_token(self, refresh_token: str, client_id: str, access_token_expiry_minutes: int) -> TokenResponse:
|
|
837
|
+
"""
|
|
838
|
+
Refresh OAuth access token using refresh token
|
|
839
|
+
Returns (access_token, new_refresh_token)
|
|
840
|
+
"""
|
|
841
|
+
session = self.Session()
|
|
842
|
+
try:
|
|
843
|
+
# Validate client
|
|
844
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
845
|
+
if client is None or not client.is_active:
|
|
846
|
+
raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")
|
|
847
|
+
|
|
848
|
+
# Find active refresh token for this client
|
|
849
|
+
oauth_token = session.query(self.DbOAuthToken).filter(
|
|
850
|
+
self.DbOAuthToken.client_id == client_id,
|
|
851
|
+
self.DbOAuthToken.refresh_token_expires_at >= func.now(),
|
|
852
|
+
self.DbOAuthToken.is_revoked == False
|
|
853
|
+
).first()
|
|
854
|
+
|
|
855
|
+
# Find the token that matches our refresh token
|
|
856
|
+
if oauth_token is None or not pwd_context.verify(refresh_token, oauth_token.refresh_token_hash):
|
|
857
|
+
raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")
|
|
858
|
+
|
|
859
|
+
# Get user
|
|
860
|
+
db_user = session.get(self.DbUser, oauth_token.username)
|
|
861
|
+
if db_user is None:
|
|
862
|
+
raise InvalidInputError(400, "invalid_client", "User not found")
|
|
863
|
+
|
|
864
|
+
# Check secret key is available
|
|
865
|
+
if self.secret_key is None:
|
|
866
|
+
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")
|
|
867
|
+
|
|
868
|
+
# Generate new tokens
|
|
869
|
+
user_obj = self._convert_db_user_to_user(db_user)
|
|
870
|
+
access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
|
|
871
|
+
access_token_hash = pwd_context.hash(access_token)
|
|
872
|
+
|
|
873
|
+
# Generate new refresh token
|
|
874
|
+
new_refresh_token, new_refresh_token_hash = self.generate_secret_and_hash()
|
|
875
|
+
refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
|
876
|
+
|
|
877
|
+
# Revoke old token
|
|
878
|
+
oauth_token.is_revoked = True
|
|
879
|
+
|
|
880
|
+
# Create new token entry
|
|
881
|
+
new_oauth_token = self.DbOAuthToken(
|
|
882
|
+
access_token_hash=access_token_hash,
|
|
883
|
+
refresh_token_hash=new_refresh_token_hash,
|
|
884
|
+
client_id=client_id,
|
|
885
|
+
username=oauth_token.username,
|
|
886
|
+
scope=oauth_token.scope,
|
|
887
|
+
access_token_expires_at=token_expires_at,
|
|
888
|
+
refresh_token_expires_at=refresh_expires_at
|
|
889
|
+
)
|
|
890
|
+
session.add(new_oauth_token)
|
|
891
|
+
|
|
892
|
+
session.commit()
|
|
893
|
+
|
|
894
|
+
return TokenResponse(
|
|
895
|
+
access_token=access_token,
|
|
896
|
+
expires_in=access_token_expiry_minutes*60,
|
|
897
|
+
refresh_token=new_refresh_token
|
|
898
|
+
)
|
|
899
|
+
|
|
900
|
+
finally:
|
|
901
|
+
session.close()
|
|
902
|
+
|
|
903
|
+
def revoke_oauth_token(self, client_id: str, token: str, token_type_hint: str | None = None) -> None:
|
|
904
|
+
"""
|
|
905
|
+
Revoke an OAuth refresh token
|
|
906
|
+
token_type_hint is optional or must be 'refresh_token'. Revoking access token is not supported yet.
|
|
907
|
+
"""
|
|
908
|
+
if token_type_hint and token_type_hint != 'refresh_token':
|
|
909
|
+
raise InvalidInputError(400, "invalid_request", "Only refresh tokens can be revoked")
|
|
910
|
+
|
|
911
|
+
session = self.Session()
|
|
912
|
+
try:
|
|
913
|
+
# Validate client
|
|
914
|
+
client = session.get(self.DbOAuthClient, client_id)
|
|
915
|
+
if client is None or not client.is_active:
|
|
916
|
+
raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")
|
|
917
|
+
|
|
918
|
+
# Get all potentially matching tokens
|
|
919
|
+
oauth_tokens = session.query(self.DbOAuthToken).filter(
|
|
920
|
+
self.DbOAuthToken.client_id == client_id,
|
|
921
|
+
self.DbOAuthToken.refresh_token_expires_at >= func.now(),
|
|
922
|
+
self.DbOAuthToken.is_revoked == False
|
|
923
|
+
).all()
|
|
924
|
+
|
|
925
|
+
# Find the token that matches
|
|
926
|
+
oauth_token = None
|
|
927
|
+
for token_obj in oauth_tokens:
|
|
928
|
+
if pwd_context.verify(token, token_obj.refresh_token_hash):
|
|
929
|
+
oauth_token = token_obj
|
|
930
|
+
break
|
|
931
|
+
|
|
932
|
+
# Revoke token if found (per OAuth spec, always return success)
|
|
933
|
+
if oauth_token:
|
|
934
|
+
oauth_token.is_revoked = True
|
|
935
|
+
session.commit()
|
|
936
|
+
|
|
937
|
+
finally:
|
|
938
|
+
session.close()
|
|
939
|
+
|
|
450
940
|
def close(self) -> None:
|
|
451
941
|
self.engine.dispose()
|
|
942
|
+
|
|
943
|
+
|
|
944
|
+
def provider(name: str, label: str, icon: str):
|
|
945
|
+
"""
|
|
946
|
+
Decorator to register an authentication provider
|
|
947
|
+
|
|
948
|
+
Arguments:
|
|
949
|
+
name: The name of the provider (must be unique, e.g. 'google')
|
|
950
|
+
label: The label of the provider (e.g. 'Google')
|
|
951
|
+
icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')
|
|
952
|
+
"""
|
|
953
|
+
def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
|
|
954
|
+
def wrapper(sqrl: AuthProviderArgs):
|
|
955
|
+
provider_configs = func(sqrl)
|
|
956
|
+
return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)
|
|
957
|
+
Authenticator.providers.append(wrapper)
|
|
958
|
+
return wrapper
|
|
959
|
+
return decorator
|