squirrels 0.5.0b2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -460
- dateutils/_enums.py +25 -0
- dateutils/_implementation.py +409 -0
- dateutils/types.py +6 -0
- squirrels/__init__.py +9 -13
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +262 -0
- squirrels/_api_routes/base.py +154 -0
- squirrels/_api_routes/dashboards.py +142 -0
- squirrels/_api_routes/data_management.py +103 -0
- squirrels/_api_routes/datasets.py +242 -0
- squirrels/_api_routes/oauth2.py +300 -0
- squirrels/_api_routes/project.py +214 -0
- squirrels/_api_server.py +145 -748
- squirrels/_arguments/__init__.py +0 -0
- squirrels/{arguments → _arguments}/init_time_args.py +7 -2
- squirrels/{arguments → _arguments}/run_time_args.py +4 -26
- squirrels/_auth.py +646 -93
- squirrels/_connection_set.py +5 -5
- squirrels/_constants.py +7 -1
- squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
- squirrels/_data_sources.py +564 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_initializer.py +31 -26
- squirrels/_manifest.py +5 -5
- squirrels/_model_builder.py +1 -1
- squirrels/_model_configs.py +2 -2
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +40 -27
- squirrels/{package_data → _package_data}/base_project/.env +1 -0
- squirrels/{package_data → _package_data}/base_project/.env.example +1 -0
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.yml +2 -2
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
- squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.sql +1 -1
- squirrels/_package_data/base_project/models/federates/federate_example.py +41 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +25 -0
- squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +6 -6
- squirrels/{package_data → _package_data}/base_project/parameters.yml +9 -8
- squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +14 -16
- squirrels/_package_data/base_project/pyconfigs/parameters.py +106 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +51 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_parameter_configs.py +35 -35
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +47 -37
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +76 -32
- squirrels/_py_module.py +3 -2
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +144 -0
- squirrels/_schemas/query_param_models.py +67 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +12 -8
- squirrels/_utils.py +38 -4
- squirrels/arguments.py +2 -0
- squirrels/auth.py +1 -0
- squirrels/connections.py +1 -0
- squirrels/dashboards.py +1 -82
- squirrels/data_sources.py +8 -563
- squirrels/parameter_options.py +8 -348
- squirrels/parameters.py +9 -1266
- squirrels/types.py +11 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/METADATA +4 -1
- squirrels-0.5.0b4.dist-info/RECORD +94 -0
- squirrels/package_data/base_project/macros/macros_example.sql +0 -15
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
- squirrels/package_data/base_project/pyconfigs/user.py +0 -23
- squirrels-0.5.0b2.dist-info/RECORD +0 -70
- /squirrels/{dataset_result.py → _dataset_types.py} +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/connections.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +0 -0
- /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/sources.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/squirrels.yml.j2 +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py
CHANGED
|
@@ -1,18 +1,24 @@
|
|
|
1
|
+
from typing import Callable
|
|
1
2
|
from datetime import datetime, timedelta, timezone
|
|
2
3
|
from enum import Enum
|
|
3
4
|
from functools import cached_property
|
|
4
5
|
from jwt.exceptions import InvalidTokenError
|
|
5
6
|
from passlib.context import CryptContext
|
|
6
|
-
from pydantic import
|
|
7
|
+
from pydantic import ValidationError
|
|
7
8
|
from pydantic_core import PydanticUndefined
|
|
8
9
|
from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
|
|
9
10
|
from sqlalchemy import Column, String, Integer, Float, Boolean
|
|
10
11
|
from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
|
|
11
|
-
import jwt, types, typing as _t, uuid
|
|
12
|
+
import jwt, types, typing as _t, uuid, secrets, json
|
|
12
13
|
|
|
13
14
|
from ._manifest import PermissionScope
|
|
14
15
|
from ._py_module import PyModule
|
|
15
16
|
from ._exceptions import InvalidInputError, ConfigurationError
|
|
17
|
+
from ._arguments.init_time_args import AuthProviderArgs
|
|
18
|
+
from ._schemas.auth_models import (
|
|
19
|
+
BaseUser, ApiKey, UserField, AuthProvider, ProviderConfigs, ClientRegistrationRequest, ClientUpdateRequest,
|
|
20
|
+
ClientDetailsResponse, ClientRegistrationResponse, ClientUpdateResponse, TokenResponse
|
|
21
|
+
)
|
|
16
22
|
from . import _utils as u, _constants as c
|
|
17
23
|
|
|
18
24
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
@@ -20,41 +26,20 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
|
20
26
|
reserved_fields = ["username", "is_admin"]
|
|
21
27
|
disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
|
|
22
28
|
|
|
23
|
-
class BaseUser(BaseModel):
|
|
24
|
-
model_config = ConfigDict(from_attributes=True)
|
|
25
|
-
username: str
|
|
26
|
-
is_admin: bool = False
|
|
27
|
-
|
|
28
|
-
@classmethod
|
|
29
|
-
def dropped_columns(cls):
|
|
30
|
-
return []
|
|
31
|
-
|
|
32
|
-
def __hash__(self):
|
|
33
|
-
return hash(self.username)
|
|
34
|
-
|
|
35
29
|
User = _t.TypeVar('User', bound=BaseUser)
|
|
36
30
|
|
|
37
|
-
|
|
38
|
-
model_config = ConfigDict(from_attributes=True)
|
|
39
|
-
token_id: str
|
|
40
|
-
title: str
|
|
41
|
-
username: str
|
|
42
|
-
created_at: datetime
|
|
43
|
-
expires_at: datetime
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
class UserField(BaseModel):
|
|
47
|
-
name: str
|
|
48
|
-
type: str
|
|
49
|
-
nullable: bool
|
|
50
|
-
enum: list[str] | None
|
|
51
|
-
default: _t.Any | None
|
|
31
|
+
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
|
|
52
32
|
|
|
53
33
|
|
|
54
34
|
class Authenticator(_t.Generic[User]):
|
|
55
|
-
|
|
35
|
+
providers: list[ProviderFunctionType] = [] # static variable to stage providers
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
|
|
39
|
+
user_cls: type[User], *, sa_engine: Engine | None = None
|
|
40
|
+
):
|
|
56
41
|
self.logger = logger
|
|
57
|
-
self.env_vars = env_vars
|
|
42
|
+
self.env_vars = auth_args.env_vars
|
|
58
43
|
self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
|
|
59
44
|
|
|
60
45
|
# Create a new declarative base for this instance
|
|
@@ -69,27 +54,89 @@ class Authenticator(_t.Generic[User]):
|
|
|
69
54
|
password_hash: Mapped[str] = mapped_column(nullable=False)
|
|
70
55
|
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
71
56
|
|
|
72
|
-
# Define
|
|
73
|
-
class
|
|
74
|
-
__tablename__ = '
|
|
57
|
+
# Define DbApiKey class for this instance
|
|
58
|
+
class DbApiKey(self.Base):
|
|
59
|
+
__tablename__ = 'api_keys'
|
|
75
60
|
|
|
76
|
-
|
|
61
|
+
id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
62
|
+
hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
|
|
77
63
|
title: Mapped[str] = mapped_column(nullable=False)
|
|
78
64
|
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
79
65
|
created_at: Mapped[datetime] = mapped_column(nullable=False)
|
|
80
66
|
expires_at: Mapped[datetime] = mapped_column(nullable=False)
|
|
81
|
-
|
|
67
|
+
|
|
68
|
+
def __repr__(self):
|
|
69
|
+
return f"<DbApiKey(id='{self.id}', username='{self.username}')>"
|
|
70
|
+
|
|
71
|
+
# Define DbOAuthClient class for this instance
|
|
72
|
+
class DbOAuthClient(self.Base):
|
|
73
|
+
__tablename__ = 'oauth_clients'
|
|
74
|
+
|
|
75
|
+
client_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
76
|
+
client_secret_hash: Mapped[str] = mapped_column(nullable=False)
|
|
77
|
+
client_name: Mapped[str] = mapped_column(nullable=False)
|
|
78
|
+
redirect_uris: Mapped[str] = mapped_column(nullable=False) # JSON array of allowed redirect URIs
|
|
79
|
+
scope: Mapped[str] = mapped_column(nullable=False, default='read')
|
|
80
|
+
grant_types: Mapped[str] = mapped_column(nullable=False, default='authorization_code,refresh_token')
|
|
81
|
+
response_types: Mapped[str] = mapped_column(nullable=False, default='code')
|
|
82
|
+
client_type: Mapped[str] = mapped_column(nullable=False, default='confidential') # 'confidential' or 'public'
|
|
83
|
+
registration_access_token_hash: Mapped[str] = mapped_column(nullable=False) # Token for client management
|
|
84
|
+
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
85
|
+
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
|
86
|
+
|
|
82
87
|
def __repr__(self):
|
|
83
|
-
return f"<
|
|
88
|
+
return f"<DbOAuthClient(client_id='{self.client_id}', name='{self.client_name}')>"
|
|
89
|
+
|
|
90
|
+
# Define DbAuthorizationCode class for this instance
|
|
91
|
+
class DbAuthorizationCode(self.Base):
|
|
92
|
+
__tablename__ = 'authorization_codes'
|
|
93
|
+
|
|
94
|
+
code: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
95
|
+
client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
|
|
96
|
+
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
97
|
+
redirect_uri: Mapped[str] = mapped_column(nullable=False)
|
|
98
|
+
scope: Mapped[str] = mapped_column(nullable=True)
|
|
99
|
+
code_challenge: Mapped[str] = mapped_column(nullable=False) # PKCE always required
|
|
100
|
+
code_challenge_method: Mapped[str] = mapped_column(nullable=False) # only S256 is supported
|
|
101
|
+
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
102
|
+
expires_at: Mapped[datetime] = mapped_column(nullable=False) # 10 minutes from creation
|
|
103
|
+
used: Mapped[bool] = mapped_column(nullable=False, default=False)
|
|
104
|
+
|
|
105
|
+
def __repr__(self):
|
|
106
|
+
return f"<DbAuthorizationCode(code='{self.code[:8]}...', client_id='{self.client_id}')>"
|
|
107
|
+
|
|
108
|
+
# Define DbOAuthToken class for this instance
|
|
109
|
+
class DbOAuthToken(self.Base):
|
|
110
|
+
__tablename__ = 'oauth_tokens'
|
|
111
|
+
|
|
112
|
+
token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
|
|
113
|
+
access_token_hash: Mapped[str] = mapped_column(unique=True, nullable=False)
|
|
114
|
+
refresh_token_hash: Mapped[str] = mapped_column(unique=True, nullable=True) # NULL for client_credentials grants
|
|
115
|
+
client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
|
|
116
|
+
username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
|
|
117
|
+
scope: Mapped[str] = mapped_column(nullable=True)
|
|
118
|
+
token_type: Mapped[str] = mapped_column(nullable=False, default='Bearer')
|
|
119
|
+
created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
|
|
120
|
+
access_token_expires_at: Mapped[datetime] = mapped_column(nullable=False) # Uses SQRL_AUTH_TOKEN_EXPIRE_MINUTES
|
|
121
|
+
refresh_token_expires_at: Mapped[datetime] = mapped_column(nullable=True) # 30 days from creation, NULL for client_credentials
|
|
122
|
+
is_revoked: Mapped[bool] = mapped_column(nullable=False, default=False)
|
|
123
|
+
|
|
124
|
+
def __repr__(self):
|
|
125
|
+
return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"
|
|
84
126
|
|
|
85
127
|
self.DbBaseUser = DbBaseUser
|
|
86
|
-
self.
|
|
128
|
+
self.DbApiKey = DbApiKey
|
|
129
|
+
self.DbOAuthClient = DbOAuthClient
|
|
130
|
+
self.DbAuthorizationCode = DbAuthorizationCode
|
|
131
|
+
self.DbOAuthToken = DbOAuthToken
|
|
87
132
|
|
|
88
|
-
self.User =
|
|
133
|
+
self.User = user_cls
|
|
89
134
|
self.DbUser: type[DbBaseUser] = self._initialize_db_user_model(self.User)
|
|
135
|
+
|
|
136
|
+
self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
|
|
90
137
|
|
|
91
138
|
if sa_engine is None:
|
|
92
|
-
sqlite_relative_path = env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
|
|
139
|
+
sqlite_relative_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
|
|
93
140
|
sqlite_path = u.Path(base_path, sqlite_relative_path)
|
|
94
141
|
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
|
|
95
142
|
self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
|
|
@@ -181,7 +228,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
181
228
|
new_columns = model_columns - existing_columns
|
|
182
229
|
if new_columns:
|
|
183
230
|
add_columns_msg = f"Adding columns to database: {new_columns}"
|
|
184
|
-
print("NOTE
|
|
231
|
+
print("NOTE -", add_columns_msg)
|
|
185
232
|
self.logger.info(add_columns_msg)
|
|
186
233
|
|
|
187
234
|
for col_name in new_columns:
|
|
@@ -203,7 +250,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
203
250
|
columns_to_drop = dropped_columns.intersection(existing_columns)
|
|
204
251
|
if columns_to_drop:
|
|
205
252
|
drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
|
|
206
|
-
print("NOTE
|
|
253
|
+
print("NOTE -", drop_columns_msg)
|
|
207
254
|
self.logger.info(drop_columns_msg)
|
|
208
255
|
|
|
209
256
|
for col_name in columns_to_drop:
|
|
@@ -272,7 +319,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
272
319
|
try:
|
|
273
320
|
user_data = self.User(**user_fields, username=username).model_dump(mode='json')
|
|
274
321
|
except ValidationError as e:
|
|
275
|
-
raise InvalidInputError(
|
|
322
|
+
raise InvalidInputError(400, "Invalid user data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
|
|
276
323
|
|
|
277
324
|
# Add a new user
|
|
278
325
|
try:
|
|
@@ -280,19 +327,20 @@ class Authenticator(_t.Generic[User]):
|
|
|
280
327
|
existing_user = session.get(self.DbUser, username)
|
|
281
328
|
if existing_user is not None:
|
|
282
329
|
if not update_user:
|
|
283
|
-
raise InvalidInputError(
|
|
330
|
+
raise InvalidInputError(400, "Username already exists", f"User '{username}' already exists")
|
|
331
|
+
|
|
332
|
+
if username == c.ADMIN_USERNAME and user_data.get("is_admin") is False:
|
|
333
|
+
raise InvalidInputError(403, "Non-admin 'admin' user not permitted", "Setting the admin user to non-admin is not permitted")
|
|
284
334
|
|
|
285
|
-
if username == c.ADMIN_USERNAME:
|
|
286
|
-
raise InvalidInputError(24, "Changing the admin user is not permitted")
|
|
287
335
|
new_user = self.DbUser(password_hash=existing_user.password_hash, **user_data)
|
|
288
336
|
session.delete(existing_user)
|
|
289
337
|
else:
|
|
290
338
|
if update_user:
|
|
291
|
-
raise InvalidInputError(
|
|
339
|
+
raise InvalidInputError(404, "No user found for username", f"No user found for username: {username}")
|
|
292
340
|
|
|
293
341
|
password = user_fields.get('password')
|
|
294
342
|
if password is None:
|
|
295
|
-
raise InvalidInputError(
|
|
343
|
+
raise InvalidInputError(400, "Missing required field 'password'", f"Missing required field 'password' when adding a new user")
|
|
296
344
|
password_hash = pwd_context.hash(password)
|
|
297
345
|
new_user = self.DbUser(password_hash=password_hash, **user_data)
|
|
298
346
|
|
|
@@ -304,6 +352,28 @@ class Authenticator(_t.Generic[User]):
|
|
|
304
352
|
|
|
305
353
|
finally:
|
|
306
354
|
session.close()
|
|
355
|
+
|
|
356
|
+
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> User:
|
|
357
|
+
provider = next((p for p in self.auth_providers if p.name == provider_name), None)
|
|
358
|
+
if provider is None:
|
|
359
|
+
raise InvalidInputError(404, "Provider not found", f"Provider '{provider_name}' not found")
|
|
360
|
+
|
|
361
|
+
user = provider.provider_configs.get_user(user_info)
|
|
362
|
+
session = self.Session()
|
|
363
|
+
try:
|
|
364
|
+
db_user = session.get(self.DbUser, user.username)
|
|
365
|
+
if db_user is None:
|
|
366
|
+
# Create new user
|
|
367
|
+
user_data = user.model_dump()
|
|
368
|
+
password_hash = "" # No hash makes it impossible to login with username and password
|
|
369
|
+
db_user = self.DbUser(password_hash=password_hash, **user_data)
|
|
370
|
+
session.add(db_user)
|
|
371
|
+
session.commit()
|
|
372
|
+
|
|
373
|
+
return self.User.model_validate(db_user)
|
|
374
|
+
|
|
375
|
+
finally:
|
|
376
|
+
session.close()
|
|
307
377
|
|
|
308
378
|
def get_user(self, username: str, password: str) -> User:
|
|
309
379
|
session = self.Session()
|
|
@@ -315,7 +385,7 @@ class Authenticator(_t.Generic[User]):
|
|
|
315
385
|
user = self.User.model_validate(db_user)
|
|
316
386
|
return user # type: ignore
|
|
317
387
|
else:
|
|
318
|
-
raise InvalidInputError(
|
|
388
|
+
raise InvalidInputError(401, "Incorrect username or password", f"Incorrect username or password")
|
|
319
389
|
|
|
320
390
|
finally:
|
|
321
391
|
session.close()
|
|
@@ -325,114 +395,132 @@ class Authenticator(_t.Generic[User]):
|
|
|
325
395
|
try:
|
|
326
396
|
db_user = session.get(self.DbUser, username)
|
|
327
397
|
if db_user is None:
|
|
328
|
-
raise InvalidInputError(
|
|
398
|
+
raise InvalidInputError(401, "User not found", f"Username '{username}' not found for password change")
|
|
329
399
|
|
|
330
|
-
if pwd_context.verify(old_password, db_user.password_hash):
|
|
400
|
+
if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
|
|
331
401
|
db_user.password_hash = pwd_context.hash(new_password)
|
|
332
402
|
session.commit()
|
|
333
403
|
else:
|
|
334
|
-
raise InvalidInputError(
|
|
404
|
+
raise InvalidInputError(401, "Incorrect password", f"Incorrect password")
|
|
335
405
|
finally:
|
|
336
406
|
session.close()
|
|
337
407
|
|
|
338
408
|
def delete_user(self, username: str) -> None:
|
|
339
409
|
if username == c.ADMIN_USERNAME:
|
|
340
|
-
raise InvalidInputError(
|
|
410
|
+
raise InvalidInputError(403, "Cannot delete admin user", "Cannot delete the admin user")
|
|
341
411
|
|
|
342
412
|
session = self.Session()
|
|
343
413
|
try:
|
|
344
414
|
db_user = session.get(self.DbUser, username)
|
|
345
415
|
if db_user is None:
|
|
346
|
-
raise InvalidInputError(
|
|
416
|
+
raise InvalidInputError(404, "No user found for username", f"No user found for username: {username}")
|
|
347
417
|
session.delete(db_user)
|
|
348
418
|
session.commit()
|
|
349
419
|
finally:
|
|
350
420
|
session.close()
|
|
351
421
|
|
|
352
|
-
def get_all_users(self) -> list
|
|
422
|
+
def get_all_users(self) -> list:
|
|
353
423
|
session = self.Session()
|
|
354
424
|
try:
|
|
355
425
|
db_users = session.query(self.DbUser).all()
|
|
356
|
-
return [self.User.model_validate(user) for user in db_users]
|
|
426
|
+
return [self.User.model_validate(user) for user in db_users]
|
|
357
427
|
finally:
|
|
358
428
|
session.close()
|
|
359
429
|
|
|
360
430
|
def create_access_token(self, user: User, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
|
|
431
|
+
"""
|
|
432
|
+
Creates an API key if title is provided. Otherwise, creates a JWT token.
|
|
433
|
+
"""
|
|
361
434
|
created_at = datetime.now(timezone.utc)
|
|
362
435
|
expire_at = created_at + timedelta(minutes=expiry_minutes) if expiry_minutes is not None else datetime.max
|
|
363
|
-
|
|
436
|
+
|
|
437
|
+
if self.secret_key is None:
|
|
438
|
+
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
|
|
439
|
+
|
|
364
440
|
if title is not None:
|
|
365
441
|
session = self.Session()
|
|
366
442
|
try:
|
|
367
|
-
|
|
368
|
-
|
|
443
|
+
token_id = "sqrl-" + uuid.uuid4().hex
|
|
444
|
+
hashed_key = u.hash_string(token_id, salt=self.secret_key)
|
|
445
|
+
api_key = self.DbApiKey(hashed_key=hashed_key, title=title, username=user.username, created_at=created_at, expires_at=expire_at)
|
|
446
|
+
session.add(api_key)
|
|
369
447
|
session.commit()
|
|
370
|
-
token_id = access_token.token_id
|
|
371
448
|
finally:
|
|
372
449
|
session.close()
|
|
450
|
+
else:
|
|
451
|
+
to_encode = {"username": user.username, "exp": expire_at}
|
|
452
|
+
token_id = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
|
|
373
453
|
|
|
374
|
-
|
|
375
|
-
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
|
|
376
|
-
to_encode = {"username": user.username, "token_id": token_id, "exp": expire_at}
|
|
377
|
-
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
|
|
378
|
-
return encoded_jwt, expire_at
|
|
454
|
+
return token_id, expire_at
|
|
379
455
|
|
|
380
456
|
def get_user_from_token(self, token: str | None) -> User | None:
|
|
381
|
-
|
|
457
|
+
"""
|
|
458
|
+
Get a user from an access token (JWT, or API key if token starts with 'sqrl-')
|
|
459
|
+
"""
|
|
460
|
+
if not token:
|
|
382
461
|
return None
|
|
383
462
|
|
|
384
463
|
if self.secret_key is None:
|
|
385
464
|
raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")
|
|
386
465
|
|
|
387
|
-
try:
|
|
388
|
-
payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
|
|
389
|
-
except InvalidTokenError:
|
|
390
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
391
|
-
|
|
392
466
|
session = self.Session()
|
|
393
467
|
try:
|
|
394
|
-
if
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
self.
|
|
398
|
-
self.
|
|
468
|
+
if token.startswith("sqrl-"):
|
|
469
|
+
hashed_key = u.hash_string(token, salt=self.secret_key)
|
|
470
|
+
api_key = session.query(self.DbApiKey).filter(
|
|
471
|
+
self.DbApiKey.hashed_key == hashed_key,
|
|
472
|
+
self.DbApiKey.expires_at >= func.now()
|
|
399
473
|
).first()
|
|
400
|
-
if
|
|
401
|
-
raise
|
|
402
|
-
|
|
403
|
-
|
|
474
|
+
if api_key is None:
|
|
475
|
+
raise InvalidTokenError()
|
|
476
|
+
username = api_key.username
|
|
477
|
+
else:
|
|
478
|
+
payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
|
|
479
|
+
username = payload["username"]
|
|
480
|
+
|
|
481
|
+
db_user = session.get(self.DbUser, username)
|
|
404
482
|
if db_user is None:
|
|
405
|
-
raise
|
|
483
|
+
raise InvalidTokenError()
|
|
484
|
+
|
|
485
|
+
except InvalidTokenError:
|
|
486
|
+
raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
|
|
406
487
|
finally:
|
|
407
488
|
session.close()
|
|
408
489
|
|
|
409
490
|
user = self.User.model_validate(db_user)
|
|
410
491
|
return user # type: ignore
|
|
411
492
|
|
|
412
|
-
def
|
|
493
|
+
def get_all_api_keys(self, username: str) -> list[ApiKey]:
|
|
494
|
+
"""
|
|
495
|
+
Get the ID, title, and expiry date of all API keys for a user. Note that the ID is a hash of the API key, not the API key itself.
|
|
496
|
+
"""
|
|
413
497
|
session = self.Session()
|
|
414
498
|
try:
|
|
415
|
-
tokens = session.query(self.
|
|
416
|
-
self.
|
|
417
|
-
self.
|
|
499
|
+
tokens = session.query(self.DbApiKey).filter(
|
|
500
|
+
self.DbApiKey.username == username,
|
|
501
|
+
self.DbApiKey.expires_at >= func.now()
|
|
418
502
|
).all()
|
|
419
503
|
|
|
420
|
-
return [
|
|
504
|
+
return [ApiKey.model_validate(token) for token in tokens]
|
|
421
505
|
finally:
|
|
422
506
|
session.close()
|
|
423
507
|
|
|
424
|
-
def
|
|
508
|
+
def revoke_api_key(self, username: str, api_key_id: str) -> None:
|
|
509
|
+
"""
|
|
510
|
+
Revoke an API key
|
|
511
|
+
"""
|
|
425
512
|
session = self.Session()
|
|
426
513
|
try:
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
self.
|
|
514
|
+
|
|
515
|
+
api_key = session.query(self.DbApiKey).filter(
|
|
516
|
+
self.DbApiKey.username == username,
|
|
517
|
+
self.DbApiKey.id == api_key_id
|
|
430
518
|
).first()
|
|
431
519
|
|
|
432
|
-
if
|
|
433
|
-
raise InvalidInputError(
|
|
520
|
+
if api_key is None:
|
|
521
|
+
raise InvalidInputError(404, "API key not found", f"The API key could not be found: {api_key_id}")
|
|
434
522
|
|
|
435
|
-
session.delete(
|
|
523
|
+
session.delete(api_key)
|
|
436
524
|
session.commit()
|
|
437
525
|
finally:
|
|
438
526
|
session.close()
|
|
@@ -447,5 +535,470 @@ class Authenticator(_t.Generic[User]):
|
|
|
447
535
|
|
|
448
536
|
return user_level.value >= scope.value
|
|
449
537
|
|
|
538
|
+
# OAuth Client Management Methods
|
|
539
|
+
|
|
540
|
+
def generate_secret_and_hash(self) -> tuple[str, str]:
|
|
541
|
+
"""Generate a secure access token and its hash"""
|
|
542
|
+
secret = secrets.token_urlsafe(64)
|
|
543
|
+
secret_hash = pwd_context.hash(secret)
|
|
544
|
+
return secret, secret_hash
|
|
545
|
+
|
|
546
|
+
def _validate_client_registration_request(self, request: ClientRegistrationRequest | ClientUpdateRequest) -> dict:
|
|
547
|
+
updates = {}
|
|
548
|
+
if request.client_name:
|
|
549
|
+
updates['client_name'] = request.client_name
|
|
550
|
+
|
|
551
|
+
# Validate redirect_uris if being updated
|
|
552
|
+
if request.redirect_uris:
|
|
553
|
+
for uri in request.redirect_uris:
|
|
554
|
+
if not self._validate_redirect_uri_format(uri):
|
|
555
|
+
raise InvalidInputError(400, "invalid_redirect_uri", f"Invalid redirect URI format: {uri}")
|
|
556
|
+
updates['redirect_uris'] = json.dumps(request.redirect_uris)
|
|
557
|
+
|
|
558
|
+
# Validate grant_types if being updated
|
|
559
|
+
if request.grant_types:
|
|
560
|
+
if not all(grant_type in c.SUPPORTED_GRANT_TYPES for grant_type in request.grant_types):
|
|
561
|
+
raise InvalidInputError(400, "invalid_grant_types", f"Invalid grant types. Supported grant types are: {c.SUPPORTED_GRANT_TYPES}")
|
|
562
|
+
updates['grant_types'] = ','.join(request.grant_types)
|
|
563
|
+
|
|
564
|
+
# Validate response_types if being updated
|
|
565
|
+
if request.response_types:
|
|
566
|
+
if not all(response_type in c.SUPPORTED_RESPONSE_TYPES for response_type in request.response_types):
|
|
567
|
+
raise InvalidInputError(400, "invalid_response_types", f"Invalid response types. Supported response types are: {c.SUPPORTED_RESPONSE_TYPES}")
|
|
568
|
+
updates['response_types'] = ','.join(request.response_types)
|
|
569
|
+
|
|
570
|
+
# Validate scope if being updated
|
|
571
|
+
if request.scope:
|
|
572
|
+
scopes = request.scope.split()
|
|
573
|
+
if not all(scope in c.SUPPORTED_SCOPES for scope in scopes):
|
|
574
|
+
raise InvalidInputError(400, "invalid_scope", f"Invalid scope. Supported scopes are: {c.SUPPORTED_SCOPES}")
|
|
575
|
+
updates['scope'] = ','.join(scopes)
|
|
576
|
+
|
|
577
|
+
return updates
|
|
578
|
+
|
|
579
|
+
def register_oauth_client(
|
|
580
|
+
self, request: ClientRegistrationRequest, client_management_path_format: str
|
|
581
|
+
) -> ClientRegistrationResponse:
|
|
582
|
+
"""Register a new OAuth client and return client_id, client_secret, and registration_access_token"""
|
|
583
|
+
grant_types = request.grant_types
|
|
584
|
+
if grant_types is None:
|
|
585
|
+
grant_types = ['authorization_code', 'refresh_token']
|
|
586
|
+
|
|
587
|
+
# Validate request
|
|
588
|
+
self._validate_client_registration_request(request)
|
|
589
|
+
|
|
590
|
+
# Generate secure client credentials and registration access token
|
|
591
|
+
client_id = secrets.token_urlsafe(16)
|
|
592
|
+
client_secret, client_secret_hash = self.generate_secret_and_hash()
|
|
593
|
+
|
|
594
|
+
registration_access_token, registration_access_token_hash = self.generate_secret_and_hash()
|
|
595
|
+
registration_client_uri = client_management_path_format.format(client_id=client_id)
|
|
596
|
+
|
|
597
|
+
session = self.Session()
|
|
598
|
+
try:
|
|
599
|
+
oauth_client = self.DbOAuthClient(
|
|
600
|
+
client_id=client_id,
|
|
601
|
+
client_secret_hash=client_secret_hash,
|
|
602
|
+
client_name=request.client_name,
|
|
603
|
+
redirect_uris=json.dumps(request.redirect_uris),
|
|
604
|
+
scope=request.scope,
|
|
605
|
+
grant_types=','.join(grant_types),
|
|
606
|
+
registration_access_token_hash=registration_access_token_hash
|
|
607
|
+
)
|
|
608
|
+
session.add(oauth_client)
|
|
609
|
+
session.commit()
|
|
610
|
+
|
|
611
|
+
return ClientRegistrationResponse(
|
|
612
|
+
client_id=client_id,
|
|
613
|
+
client_secret=client_secret,
|
|
614
|
+
client_name=request.client_name,
|
|
615
|
+
redirect_uris=request.redirect_uris,
|
|
616
|
+
scope=request.scope,
|
|
617
|
+
grant_types=grant_types,
|
|
618
|
+
response_types=request.response_types,
|
|
619
|
+
created_at=datetime.now(timezone.utc),
|
|
620
|
+
is_active=True,
|
|
621
|
+
registration_client_uri=registration_client_uri,
|
|
622
|
+
registration_access_token=registration_access_token,
|
|
623
|
+
)
|
|
624
|
+
|
|
625
|
+
finally:
|
|
626
|
+
session.close()
|
|
627
|
+
|
|
628
|
+
def get_oauth_client_details(self, client_id: str) -> ClientDetailsResponse:
    """
    Look up a registered, active OAuth client and return its metadata.

    The stored redirect URIs (JSON-encoded) are decoded into a list, and the stored grant types and
    response types (comma-separated strings) are split into lists. Neither the client secret nor the
    registration access token is included in the result.

    Arguments:
        client_id: The identifier of the OAuth client

    Returns:
        A ClientDetailsResponse describing the client

    Raises:
        InvalidInputError: 404 "invalid_client_id" if the client does not exist or is no longer active
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "invalid_client_id", "Client not found")
        if not record.is_active:
            raise InvalidInputError(404, "invalid_client_id", "Client is no longer active")

        registered_uris = json.loads(record.redirect_uris)
        grant_list = record.grant_types.split(',')
        response_list = record.response_types.split(',')

        return ClientDetailsResponse(
            client_id=record.client_id,
            client_name=record.client_name,
            redirect_uris=registered_uris,
            scope=record.scope,
            grant_types=grant_list,
            response_types=response_list,
            created_at=record.created_at,
            is_active=record.is_active,
        )
    finally:
        session.close()
|
|
650
|
+
|
|
651
|
+
def validate_client_credentials(self, client_id: str, client_secret: str) -> bool:
    """
    Check an OAuth client_id / client_secret pair.

    Arguments:
        client_id: The identifier of the OAuth client
        client_secret: The plaintext secret presented by the client

    Returns:
        True only if the client exists, is active, and the secret verifies against the stored hash;
        False otherwise
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        usable = record is not None and record.is_active
        if not usable:
            return False
        return pwd_context.verify(client_secret, record.client_secret_hash)
    finally:
        session.close()
|
|
661
|
+
|
|
662
|
+
def validate_redirect_uri(self, client_id: str, redirect_uri: str) -> bool:
    """
    Check whether redirect_uri is one of the URIs registered for an active client.

    The comparison is an exact string match against the client's stored (JSON-encoded) redirect URI list.

    Arguments:
        client_id: The identifier of the OAuth client
        redirect_uri: The redirect URI to check

    Returns:
        True if the client exists, is active, and has redirect_uri registered; False otherwise
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is not None and record.is_active:
            return redirect_uri in json.loads(record.redirect_uris)
        return False
    finally:
        session.close()
|
|
674
|
+
|
|
675
|
+
def validate_registration_access_token(self, client_id: str, registration_access_token: str) -> bool:
    """
    Check the registration access token that authorizes management operations on an OAuth client.

    Arguments:
        client_id: The identifier of the OAuth client
        registration_access_token: The plaintext registration access token presented by the caller

    Returns:
        True if the client exists and the token verifies against its stored hash; False otherwise
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        # NOTE(review): is_active is not checked here, so a deactivated client's token still validates.
        # This looks intentional (update_oauth_client_with_token_rotation can re-activate a client) - confirm.
        return record is not None and pwd_context.verify(
            registration_access_token, record.registration_access_token_hash
        )
    finally:
        session.close()
|
|
686
|
+
|
|
687
|
+
def _validate_redirect_uri_format(self, uri: str) -> bool:
    """
    Check that a redirect URI has a format that is safe to register for an OAuth client.

    Rules:
      - URIs containing a fragment ('#') are rejected.
      - 'https' URIs must name a host.
      - Plain 'http' is only allowed for loopback hosts ('localhost', '127.0.0.1', '::1'), for development.
      - Any other scheme (e.g. a custom scheme for mobile apps, such as 'myapp://callback') is allowed,
        provided the URI has the '<scheme>://' form with a non-empty scheme.

    The scheme and host are taken from a real URL parse. Substring tests are unsafe:
    "'localhost' in uri" accepts hosts like 'localhost.evil.com' or 'evil.com/?x=localhost'.
    A case-sensitive prefix test would let 'HTTP://evil.com' through as a "custom scheme".

    Arguments:
        uri: The redirect URI supplied by the client

    Returns:
        True if the URI passes the checks above, False otherwise
    """
    from urllib.parse import urlparse

    if '#' in uri:
        return False

    try:
        parsed = urlparse(uri)
        hostname = parsed.hostname  # lower-cased, without port, userinfo, or IPv6 brackets
    except ValueError:
        # Malformed authority component, e.g. an unbalanced IPv6 bracket
        return False

    scheme = parsed.scheme  # urlparse already lower-cases the scheme
    if scheme == 'http':
        return hostname in ('localhost', '127.0.0.1', '::1')
    if scheme == 'https':
        return bool(hostname)

    # Custom schemes are allowed for native/mobile apps, but must use the '<scheme>://' form
    return bool(scheme) and uri.lower().startswith(f"{scheme}://")
|
|
703
|
+
|
|
704
|
+
def update_oauth_client_with_token_rotation(self, client_id: str, request: ClientUpdateRequest) -> ClientUpdateResponse:
    """
    Apply metadata updates to an OAuth client and issue it a fresh registration access token.

    The request is validated first. A deactivated client can only be updated when the request sets
    is_active to a truthy value, i.e. re-activates it. Every successful update replaces the stored
    registration access token hash, so the previously issued registration access token stops validating.

    Arguments:
        client_id: The identifier of the OAuth client to update
        request: The requested metadata changes

    Returns:
        A ClientUpdateResponse with the client's current metadata and the new plaintext registration access token

    Raises:
        InvalidInputError: 404 "invalid_client_id" if the client does not exist; 400 "invalid_request" if the
            client is deactivated and the request does not re-activate it
    """
    # Validate the request; the result is the dict of client fields to apply
    # (per the original comment, None values are excluded)
    field_updates = self._validate_client_registration_request(request)

    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "invalid_client_id", f"OAuth client not found: {client_id}")
        if not (record.is_active or request.is_active):
            raise InvalidInputError(400, "invalid_request", "Cannot update a deactivated client")

        if request.is_active is not None:
            record.is_active = request.is_active

        rotated_token, rotated_token_hash = self.generate_secret_and_hash()

        # Only attributes that exist on the client model are applied; unknown keys are ignored
        for field_name, new_value in field_updates.items():
            if not hasattr(record, field_name):
                continue
            setattr(record, field_name, new_value)

        record.registration_access_token_hash = rotated_token_hash
        session.commit()

        return ClientUpdateResponse(
            client_id=record.client_id,
            client_name=record.client_name,
            redirect_uris=json.loads(record.redirect_uris),
            scope=record.scope,
            grant_types=record.grant_types.split(','),
            response_types=record.response_types.split(','),
            created_at=record.created_at,
            is_active=record.is_active,
            registration_access_token=rotated_token,
        )
    finally:
        session.close()
|
|
748
|
+
|
|
749
|
+
def revoke_oauth_client(self, client_id: str) -> None:
    """
    Deactivate an OAuth client.

    The client row is kept; only its is_active flag is set to False and committed.
    Existing token rows are not modified by this method.

    Arguments:
        client_id: The identifier of the OAuth client to deactivate

    Raises:
        InvalidInputError: 404 "client_not_found" if the client does not exist
    """
    session = self.Session()
    try:
        target = session.get(self.DbOAuthClient, client_id)
        if target is None:
            raise InvalidInputError(404, "client_not_found", f"OAuth client not found: {client_id}")
        target.is_active = False
        session.commit()
    finally:
        session.close()
|
|
761
|
+
|
|
762
|
+
# Authorization Code Flow Methods
|
|
763
|
+
|
|
764
|
+
def create_authorization_code(
    self, client_id: str, username: str, redirect_uri: str, scope: str,
    code_challenge: str, code_challenge_method: str = "S256"
) -> str:
    """
    Create and store an authorization code for the authorization code flow (PKCE required).

    The code is a random URL-safe token that expires 10 minutes after creation. It is stored unused,
    together with the PKCE challenge that the later token exchange must satisfy.

    Arguments:
        client_id: The OAuth client requesting authorization
        username: The user who granted authorization
        redirect_uri: Must be one of the client's registered redirect URIs (exact match)
        scope: The granted scope, stored with the code
        code_challenge: The PKCE code challenge (required)
        code_challenge_method: The PKCE method; only "S256" is supported

    Returns:
        The newly created authorization code

    Raises:
        InvalidInputError: 400 "invalid_client" if the client is unknown or inactive; 400
            "invalid_redirect_uri" if redirect_uri is not registered for the client; 400 "invalid_request"
            if the PKCE parameters are missing or unsupported
    """
    session = self.Session()
    try:
        # Validate client_id. The client is read directly here: get_oauth_client_details raises a 404
        # "invalid_client_id" before a None/inactive check could run, so the intended 400 "invalid_client"
        # error would never be produced. A single read also serves the redirect_uri check below.
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Validate redirect_uri against the client's registered URIs (exact match)
        if redirect_uri not in json.loads(client.redirect_uris):
            raise InvalidInputError(400, "invalid_redirect_uri", "Invalid redirect_uri for this client")

        # Validate PKCE parameters
        if not code_challenge:
            raise InvalidInputError(400, "invalid_request", "code_challenge is required")
        if code_challenge_method not in ["S256"]:
            raise InvalidInputError(400, "invalid_request", "code_challenge_method must be 'S256'")

        # Generate authorization code, valid for 10 minutes
        code = secrets.token_urlsafe(48)
        created_at = datetime.now(timezone.utc)
        expires_at = created_at + timedelta(minutes=10)

        auth_code = self.DbAuthorizationCode(
            code=code,
            client_id=client_id,
            username=username,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            created_at=created_at,
            expires_at=expires_at,
            used=False
        )
        session.add(auth_code)
        session.commit()

        return code

    finally:
        session.close()
|
|
814
|
+
|
|
815
|
+
def exchange_authorization_code(
    self, code: str, client_id: str, redirect_uri: str, code_verifier: str, access_token_expiry_minutes: int
) -> TokenResponse:
    """
    Exchange an authorization code (PKCE flow) for an access token and a refresh token.

    The code must belong to this client, be unexpired and unused, have been issued for the same
    redirect_uri, and satisfy its stored PKCE challenge. On success the code is marked used, and a new
    token row (hashed access and refresh tokens) is stored.

    Arguments:
        code: The authorization code previously issued by create_authorization_code
        client_id: The OAuth client redeeming the code
        redirect_uri: Must exactly match the redirect_uri stored with the code
        code_verifier: The PKCE verifier checked against the stored code_challenge
        access_token_expiry_minutes: Lifetime of the new access token in minutes

    Returns:
        A TokenResponse with the access token, its lifetime in seconds (expires_in), and a refresh token
        valid for 30 days

    Raises:
        InvalidInputError: 400 "invalid_client" if the client is unknown or inactive; 400 "invalid_grant" if
            the code is invalid, expired, already used, mismatched, fails PKCE, or its user no longer exists
    """
    session = self.Session()
    try:
        # Validate client. A deactivated (revoked) client must not be able to redeem codes issued before it
        # was revoked. refresh_oauth_access_token and revoke_oauth_token apply the same check.
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Get and validate authorization code: must belong to this client, be unexpired, and be unused
        auth_code = session.query(self.DbAuthorizationCode).filter(
            self.DbAuthorizationCode.code == code,
            self.DbAuthorizationCode.client_id == client_id,
            self.DbAuthorizationCode.expires_at >= func.now(),
            self.DbAuthorizationCode.used == False  # noqa: E712 (SQL expression, not a Python comparison)
        ).first()

        if auth_code is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")

        # Validate redirect URI
        if auth_code.redirect_uri != redirect_uri:
            raise InvalidInputError(400, "invalid_grant", "Redirect URI mismatch")

        # Validate PKCE verifier against the stored challenge
        if not u.validate_pkce_challenge(code_verifier, auth_code.code_challenge):
            raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")

        # Get user
        user = session.get(self.DbUser, auth_code.username)
        if user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Mark the code as used with a conditional UPDATE rather than a plain attribute assignment.
        # The SELECT above and a later flush are separate statements, so two concurrent exchanges could
        # both see used == False and both obtain tokens. With "WHERE used = false", at most one of them
        # changes the row; the other gets a row count of 0 and is rejected.
        claimed_rows = session.query(self.DbAuthorizationCode).filter(
            self.DbAuthorizationCode.code == code,
            self.DbAuthorizationCode.used == False  # noqa: E712
        ).update({self.DbAuthorizationCode.used: True}, synchronize_session=False)
        if claimed_rows == 0:
            raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")

        # Generate access token
        user_obj = self.User.model_validate(user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate refresh token (valid for 30 days)
        refresh_token, refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=refresh_token_hash,
            client_id=client_id,
            username=auth_code.username,
            scope=auth_code.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(oauth_token)

        # Commits the "used" flag and the new token row together
        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=refresh_token
        )

    finally:
        session.close()
|
|
880
|
+
|
|
881
|
+
def refresh_oauth_access_token(self, refresh_token: str, client_id: str, access_token_expiry_minutes: int) -> TokenResponse:
    """
    Exchange a refresh token for a new access token and a new (rotated) refresh token.

    The presented refresh token's row is marked revoked and a new token row is created, so each refresh
    token can be used only once.

    Arguments:
        refresh_token: The plaintext refresh token previously issued to this client
        client_id: The OAuth client presenting the refresh token
        access_token_expiry_minutes: Lifetime of the new access token in minutes

    Returns:
        A TokenResponse with the new access token, its lifetime in seconds (expires_in), and the new refresh
        token (valid for 30 days)

    Raises:
        InvalidInputError: 400 "invalid_client" if the client is unknown or inactive; 400 "invalid_grant" if the
            refresh token matches no unexpired, unrevoked token of this client, or its user no longer exists
        ConfigurationError: if the secret key environment variable is not set
    """
    session = self.Session()
    try:
        # Validate client
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # A client can hold several live tokens (one per user/authorization). Stored values are hashes that
        # are checked with pwd_context.verify rather than looked up by equality, so every candidate must be
        # verified. Checking only the first row would reject valid refresh tokens whenever the client has
        # more than one live token.
        # NOTE(review): each verify may be deliberately slow; this is O(number of live tokens) - same as
        # revoke_oauth_token.
        candidates = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False  # noqa: E712 (SQL expression, not a Python comparison)
        ).all()
        oauth_token = next(
            (candidate for candidate in candidates
             if pwd_context.verify(refresh_token, candidate.refresh_token_hash)),
            None,
        )
        if oauth_token is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")

        # Get user. A missing user invalidates the grant, not the client (consistent with exchange_authorization_code)
        user = session.get(self.DbUser, oauth_token.username)
        if user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Check secret key is available
        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")

        # Generate new access token
        user_obj = self.User.model_validate(user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate new refresh token (valid for 30 days)
        new_refresh_token, new_refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        # Rotate: revoke the presented token and create its replacement in the same transaction
        oauth_token.is_revoked = True
        new_oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=new_refresh_token_hash,
            client_id=client_id,
            username=oauth_token.username,
            scope=oauth_token.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(new_oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=new_refresh_token
        )

    finally:
        session.close()
|
|
947
|
+
|
|
948
|
+
def revoke_oauth_token(self, client_id: str, token: str, token_type_hint: str | None = None) -> None:
    """
    Revoke a refresh token issued to the given client.

    Only refresh tokens are supported; revoking access tokens is not implemented yet. If no unexpired,
    unrevoked token of this client matches, the call still succeeds silently (per the OAuth revocation spec).

    Arguments:
        client_id: The OAuth client that owns the token
        token: The plaintext refresh token to revoke
        token_type_hint: Optional; if given, it must be 'refresh_token'

    Raises:
        InvalidInputError: 400 "invalid_request" for any other non-empty token_type_hint; 400 "invalid_client"
            if the client is unknown or inactive
    """
    if token_type_hint and token_type_hint != 'refresh_token':
        raise InvalidInputError(400, "invalid_request", "Only refresh tokens can be revoked")

    session = self.Session()
    try:
        owner = session.get(self.DbOAuthClient, client_id)
        if owner is None or not owner.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        live_tokens = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False  # noqa: E712 (SQL expression, not a Python comparison)
        ).all()

        # Stored values are hashes, so each candidate is checked with pwd_context.verify; stop at the first match
        matched = next(
            (candidate for candidate in live_tokens if pwd_context.verify(token, candidate.refresh_token_hash)),
            None,
        )

        if matched is not None:
            matched.is_revoked = True
            session.commit()
    finally:
        session.close()
|
|
984
|
+
|
|
450
985
|
def close(self) -> None:
    """Shut down the authenticator's database engine by disposing of its connection pool."""
    engine = self.engine
    engine.dispose()
|
|
987
|
+
|
|
988
|
+
|
|
989
|
+
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider

    The decorated function takes the AuthProviderArgs and returns the ProviderConfigs. It is wrapped so
    that calling the wrapper returns an AuthProvider built from those configs and the given name, label
    and icon. The wrapper is appended to Authenticator.providers when the decorator is applied, and is
    returned in place of the decorated function.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')
    """
    from functools import wraps

    def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
        # functools.wraps keeps the decorated function's __name__/__doc__ on the registered wrapper,
        # so entries in Authenticator.providers stay identifiable when inspecting or debugging.
        @wraps(func)
        def wrapper(sqrl: AuthProviderArgs):
            provider_configs = func(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)

        # NOTE(review): uniqueness of `name` is documented but not enforced here.
        Authenticator.providers.append(wrapper)
        return wrapper

    return decorator
|