squirrels 0.5.0rc0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. See the registry's advisory for this release for more details.

Files changed (108) hide show
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +58 -111
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +10 -12
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +271 -0
  8. squirrels/_api_routes/base.py +171 -0
  9. squirrels/_api_routes/dashboards.py +158 -0
  10. squirrels/_api_routes/data_management.py +148 -0
  11. squirrels/_api_routes/datasets.py +265 -0
  12. squirrels/_api_routes/oauth2.py +298 -0
  13. squirrels/_api_routes/project.py +252 -0
  14. squirrels/_api_server.py +245 -781
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/{arguments → _arguments}/init_time_args.py +7 -2
  17. squirrels/{arguments → _arguments}/run_time_args.py +13 -35
  18. squirrels/_auth.py +720 -212
  19. squirrels/_command_line.py +81 -41
  20. squirrels/_compile_prompts.py +147 -0
  21. squirrels/_connection_set.py +16 -7
  22. squirrels/_constants.py +29 -9
  23. squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
  24. squirrels/_data_sources.py +570 -0
  25. squirrels/{dataset_result.py → _dataset_types.py} +2 -4
  26. squirrels/_exceptions.py +9 -37
  27. squirrels/_initializer.py +83 -59
  28. squirrels/_logging.py +117 -0
  29. squirrels/_manifest.py +129 -62
  30. squirrels/_model_builder.py +10 -52
  31. squirrels/_model_configs.py +3 -3
  32. squirrels/_model_queries.py +1 -1
  33. squirrels/_models.py +249 -118
  34. squirrels/{package_data → _package_data}/base_project/.env +16 -4
  35. squirrels/{package_data → _package_data}/base_project/.env.example +15 -3
  36. squirrels/{package_data → _package_data}/base_project/connections.yml +4 -3
  37. squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
  38. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  39. squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +1 -0
  40. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  41. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
  42. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
  43. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +2 -0
  44. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
  45. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
  46. squirrels/_package_data/base_project/models/federates/federate_example.py +48 -0
  47. squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
  48. squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +7 -7
  49. squirrels/{package_data → _package_data}/base_project/models/sources.yml +5 -6
  50. squirrels/{package_data → _package_data}/base_project/parameters.yml +32 -45
  51. squirrels/_package_data/base_project/pyconfigs/connections.py +18 -0
  52. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +31 -22
  53. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  54. squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
  55. squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +1 -1
  56. squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +1 -1
  57. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  58. squirrels/_package_data/templates/dataset_results.html +112 -0
  59. squirrels/_package_data/templates/oauth_login.html +271 -0
  60. squirrels/_package_data/templates/squirrels_studio.html +20 -0
  61. squirrels/_parameter_configs.py +76 -55
  62. squirrels/_parameter_options.py +348 -0
  63. squirrels/_parameter_sets.py +53 -45
  64. squirrels/_parameters.py +1664 -0
  65. squirrels/_project.py +403 -242
  66. squirrels/_py_module.py +3 -2
  67. squirrels/_request_context.py +33 -0
  68. squirrels/_schemas/__init__.py +0 -0
  69. squirrels/_schemas/auth_models.py +167 -0
  70. squirrels/_schemas/query_param_models.py +75 -0
  71. squirrels/{_api_response_models.py → _schemas/response_models.py} +48 -18
  72. squirrels/_seeds.py +1 -1
  73. squirrels/_sources.py +23 -19
  74. squirrels/_utils.py +121 -39
  75. squirrels/_version.py +1 -1
  76. squirrels/arguments.py +7 -0
  77. squirrels/auth.py +4 -0
  78. squirrels/connections.py +3 -0
  79. squirrels/dashboards.py +2 -81
  80. squirrels/data_sources.py +14 -563
  81. squirrels/parameter_options.py +13 -348
  82. squirrels/parameters.py +14 -1266
  83. squirrels/types.py +16 -0
  84. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/METADATA +42 -30
  85. squirrels-0.5.1.dist-info/RECORD +98 -0
  86. squirrels/package_data/base_project/dashboards/dashboard_example.yml +0 -22
  87. squirrels/package_data/base_project/macros/macros_example.sql +0 -15
  88. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -12
  89. squirrels/package_data/base_project/models/dbviews/dbview_example.yml +0 -26
  90. squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
  91. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
  92. squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
  93. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
  94. squirrels/package_data/base_project/pyconfigs/user.py +0 -23
  95. squirrels/package_data/base_project/squirrels.yml.j2 +0 -71
  96. squirrels-0.5.0rc0.dist-info/RECORD +0 -70
  97. /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
  98. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  99. /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
  100. /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
  101. /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
  102. /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
  103. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  104. /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
  105. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
  106. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
  107. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
  108. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py CHANGED
@@ -1,96 +1,137 @@
1
+ from typing import Callable, Any
1
2
  from datetime import datetime, timedelta, timezone
2
- from enum import Enum
3
3
  from functools import cached_property
4
4
  from jwt.exceptions import InvalidTokenError
5
5
  from passlib.context import CryptContext
6
- from pydantic import BaseModel, ConfigDict, ValidationError
7
- from pydantic_core import PydanticUndefined
6
+ from pydantic import ValidationError
8
7
  from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
9
- from sqlalchemy import Column, String, Integer, Float, Boolean
10
8
  from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
11
- import jwt, types, typing as _t, uuid
9
+ import jwt, uuid, secrets, json
12
10
 
13
11
  from ._manifest import PermissionScope
14
12
  from ._py_module import PyModule
15
13
  from ._exceptions import InvalidInputError, ConfigurationError
14
+ from ._arguments.init_time_args import AuthProviderArgs
15
+ from ._schemas.auth_models import (
16
+ CustomUserFields, AbstractUser, GuestUser, RegisteredUser, ApiKey, UserField, AuthProvider, ProviderConfigs,
17
+ ClientRegistrationRequest, ClientUpdateRequest, ClientDetailsResponse, ClientRegistrationResponse,
18
+ ClientUpdateResponse, TokenResponse
19
+ )
16
20
  from . import _utils as u, _constants as c
17
21
 
18
22
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
19
23
 
20
- reserved_fields = ["username", "is_admin"]
21
- disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
24
+ ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
22
25
 
23
- class BaseUser(BaseModel):
24
- model_config = ConfigDict(from_attributes=True)
25
- username: str
26
- is_admin: bool = False
27
-
28
- @classmethod
29
- def dropped_columns(cls):
30
- return []
31
-
32
- def __hash__(self):
33
- return hash(self.username)
34
-
35
- User = _t.TypeVar('User', bound=BaseUser)
36
-
37
- class AccessToken(BaseModel):
38
- model_config = ConfigDict(from_attributes=True)
39
- token_id: str
40
- title: str
41
- username: str
42
- created_at: datetime
43
- expires_at: datetime
44
26
 
27
+ class Authenticator:
28
+ providers: list[ProviderFunctionType] = [] # static variable to stage providers
45
29
 
46
- class UserField(BaseModel):
47
- name: str
48
- type: str
49
- nullable: bool
50
- enum: list[str] | None
51
- default: _t.Any | None
52
-
53
-
54
- class Authenticator(_t.Generic[User]):
55
- def __init__(self, logger: u.Logger, base_path: str, env_vars: dict[str, str], *, sa_engine: Engine | None = None, cls: type[User] | None = None):
30
+ def __init__(
31
+ self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
32
+ custom_user_fields_cls: type[CustomUserFields], *, sa_engine: Engine | None = None, external_only: bool = False
33
+ ):
56
34
  self.logger = logger
57
- self.env_vars = env_vars
35
+ self.env_vars = auth_args.env_vars
58
36
  self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
37
+ self.external_only = external_only
59
38
 
60
39
  # Create a new declarative base for this instance
61
40
  self.Base = declarative_base()
62
41
 
63
- # Define DbBaseUser class for this instance
64
- class DbBaseUser(self.Base):
42
+ # Define DbUser class for this instance
43
+ class DbUser(self.Base):
65
44
  __tablename__ = 'users'
66
45
  __table_args__ = {'extend_existing': True}
67
46
  username: Mapped[str] = mapped_column(primary_key=True)
68
- is_admin: Mapped[bool] = mapped_column(nullable=False, default=False)
47
+ access_level: Mapped[str] = mapped_column(nullable=False, default="member")
69
48
  password_hash: Mapped[str] = mapped_column(nullable=False)
49
+ custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
70
50
  created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
71
51
 
72
- # Define DbAccessToken class for this instance
73
- class DbAccessToken(self.Base):
74
- __tablename__ = 'access_tokens'
52
+ # Define DbApiKey class for this instance
53
+ class DbApiKey(self.Base):
54
+ __tablename__ = 'api_keys'
75
55
 
76
- token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid.uuid4()))
56
+ id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
57
+ hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
77
58
  title: Mapped[str] = mapped_column(nullable=False)
78
59
  username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
79
60
  created_at: Mapped[datetime] = mapped_column(nullable=False)
80
61
  expires_at: Mapped[datetime] = mapped_column(nullable=False)
81
-
62
+
63
+ def __repr__(self):
64
+ return f"<DbApiKey(id='{self.id}', username='{self.username}')>"
65
+
66
+ # Define DbOAuthClient class for this instance
67
+ class DbOAuthClient(self.Base):
68
+ __tablename__ = 'oauth_clients'
69
+
70
+ client_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
71
+ client_secret_hash: Mapped[str] = mapped_column(nullable=False)
72
+ client_name: Mapped[str] = mapped_column(nullable=False)
73
+ redirect_uris: Mapped[str] = mapped_column(nullable=False) # JSON array of allowed redirect URIs
74
+ scope: Mapped[str] = mapped_column(nullable=False, default='read')
75
+ grant_types: Mapped[str] = mapped_column(nullable=False, default='authorization_code,refresh_token')
76
+ response_types: Mapped[str] = mapped_column(nullable=False, default='code')
77
+ client_type: Mapped[str] = mapped_column(nullable=False, default='confidential') # 'confidential' or 'public'
78
+ registration_access_token_hash: Mapped[str] = mapped_column(nullable=False) # Token for client management
79
+ created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
80
+ is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
81
+
82
+ def __repr__(self):
83
+ return f"<DbOAuthClient(client_id='{self.client_id}', name='{self.client_name}')>"
84
+
85
+ # Define DbAuthorizationCode class for this instance
86
+ class DbAuthorizationCode(self.Base):
87
+ __tablename__ = 'authorization_codes'
88
+
89
+ code: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
90
+ client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
91
+ username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
92
+ redirect_uri: Mapped[str] = mapped_column(nullable=False)
93
+ scope: Mapped[str] = mapped_column(nullable=True)
94
+ code_challenge: Mapped[str] = mapped_column(nullable=False) # PKCE always required
95
+ code_challenge_method: Mapped[str] = mapped_column(nullable=False) # only S256 is supported
96
+ created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
97
+ expires_at: Mapped[datetime] = mapped_column(nullable=False) # 10 minutes from creation
98
+ used: Mapped[bool] = mapped_column(nullable=False, default=False)
99
+
100
+ def __repr__(self):
101
+ return f"<DbAuthorizationCode(code='{self.code[:8]}...', client_id='{self.client_id}')>"
102
+
103
+ # Define DbOAuthToken class for this instance
104
+ class DbOAuthToken(self.Base):
105
+ __tablename__ = 'oauth_tokens'
106
+
107
+ token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
108
+ access_token_hash: Mapped[str] = mapped_column(unique=True, nullable=False)
109
+ refresh_token_hash: Mapped[str] = mapped_column(unique=True, nullable=True) # NULL for client_credentials grants
110
+ client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
111
+ username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
112
+ scope: Mapped[str] = mapped_column(nullable=True)
113
+ token_type: Mapped[str] = mapped_column(nullable=False, default='Bearer')
114
+ created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
115
+ access_token_expires_at: Mapped[datetime] = mapped_column(nullable=False) # Uses SQRL_AUTH_TOKEN_EXPIRE_MINUTES
116
+ refresh_token_expires_at: Mapped[datetime] = mapped_column(nullable=True) # 30 days from creation, NULL for client_credentials
117
+ is_revoked: Mapped[bool] = mapped_column(nullable=False, default=False)
118
+
82
119
  def __repr__(self):
83
- return f"<AccessToken(token_id='{self.token_id}', username='{self.username}')>"
120
+ return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"
84
121
 
85
- self.DbBaseUser = DbBaseUser
86
- self.DbAccessToken = DbAccessToken
122
+ self.DbApiKey = DbApiKey
123
+ self.DbOAuthClient = DbOAuthClient
124
+ self.DbAuthorizationCode = DbAuthorizationCode
125
+ self.DbOAuthToken = DbOAuthToken
87
126
 
88
- self.User = self._get_user_model(base_path) if cls is None else cls
89
- self.DbUser: type[DbBaseUser] = self._initialize_db_user_model(self.User)
127
+ self.CustomUserFields = custom_user_fields_cls
128
+ self.DbUser = DbUser
129
+
130
+ self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
90
131
 
91
132
  if sa_engine is None:
92
- sqlite_relative_path = env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
93
- sqlite_path = u.Path(base_path, sqlite_relative_path)
133
+ raw_sqlite_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{{project_path}}/{c.TARGET_FOLDER}/{c.DB_FILE}")
134
+ sqlite_path = u.Path(raw_sqlite_path.format(project_path=base_path))
94
135
  sqlite_path.parent.mkdir(parents=True, exist_ok=True)
95
136
  self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
96
137
  else:
@@ -106,67 +147,21 @@ class Authenticator(_t.Generic[User]):
106
147
 
107
148
  self.Session = sessionmaker(bind=self.engine)
108
149
 
109
- self._initialize_db(self.User, self.DbUser, self.engine, self.Session)
150
+ self._initialize_db()
110
151
 
111
- def _get_user_model(self, base_path: str) -> type[BaseUser]:
112
- user_module_path = u.Path(base_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
113
- user_module = PyModule(user_module_path)
114
- User = user_module.get_func_or_class("User", default_attr=BaseUser)
115
- if not issubclass(User, BaseUser):
116
- raise ConfigurationError(f"User class in '{c.USER_FILE}' must inherit from BaseUser")
117
- return User
118
-
119
- def _initialize_db_user_model(self, *args) -> type:
120
- """Get the user model with any custom attributes defined in user.py"""
121
- attrs = {}
122
-
123
- # Iterate over all fields in the User model
124
- for field_name, field in self.User.model_fields.items():
125
- if field_name in reserved_fields:
126
- continue
127
- if field_name in disallowed_fields:
128
- raise ConfigurationError(f"Field name '{field_name}' is disallowed in the User model and cannot be used")
129
-
130
- field_type = field.annotation
131
- if _t.get_origin(field_type) in (_t.Union, types.UnionType):
132
- field_type = _t.get_args(field_type)[0]
133
- nullable = True
134
- else:
135
- nullable = False
152
+ def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
153
+ """Convert a database user to an AbstractUser object"""
154
+ # Deserialize custom_fields JSON and merge with defaults
155
+ custom_fields_json = json.loads(db_user.custom_fields) if db_user.custom_fields else {}
156
+ custom_fields = self.CustomUserFields(**custom_fields_json)
157
+
158
+ return RegisteredUser(
159
+ username=db_user.username,
160
+ access_level=db_user.access_level,
161
+ custom_fields=custom_fields
162
+ )
136
163
 
137
- if _t.get_origin(field_type) == _t.Literal:
138
- field_type = str
139
-
140
- # Map Python types and default values to SQLAlchemy columns
141
- default_value = field.default
142
- if default_value is PydanticUndefined:
143
- raise ConfigurationError(f"No default value found for field '{field_name}' in User model")
144
- elif not nullable and default_value is None:
145
- raise ConfigurationError(f"Default value for non-nullable field '{field_name}' was set as None in User model")
146
- elif default_value is not None and type(default_value) is not field_type:
147
- raise ConfigurationError(f"Default value for field '{field_name}' does not match field type in User model")
148
-
149
- if field_type == str:
150
- col_type = String
151
- elif field_type == int:
152
- col_type = Integer
153
- elif field_type == float:
154
- col_type = Float
155
- elif field_type == bool:
156
- col_type = Boolean
157
- elif isinstance(field_type, type) and issubclass(field_type, Enum):
158
- col_type = String
159
- default_value = default_value.value
160
- else:
161
- continue
162
-
163
- attrs[field_name] = Column(col_type, nullable=nullable, default=default_value) # type: ignore
164
-
165
- # Create the sqlalchemy model class
166
- DbUser = type('DbUser', (self.DbBaseUser,), attrs)
167
- return DbUser
168
-
169
- def _initialize_db(self, *args): # TODO: Use logger instead of print
164
+ def _initialize_db(self): # TODO: Use logger instead of print
170
165
  session = self.Session()
171
166
  try:
172
167
  # Get existing columns in the database
@@ -174,14 +169,12 @@ class Authenticator(_t.Generic[User]):
174
169
  existing_columns = {col['name'] for col in inspector.get_columns('users')}
175
170
 
176
171
  # Get all columns defined in the model
177
- dropped_columns = set(self.User.dropped_columns())
178
- model_columns = set(self.DbUser.__table__.columns.keys()) - dropped_columns
172
+ model_columns = set(self.DbUser.__table__.columns.keys())
179
173
 
180
174
  # Find columns that are in the model but not in the database
181
175
  new_columns = model_columns - existing_columns
182
176
  if new_columns:
183
177
  add_columns_msg = f"Adding columns to database: {new_columns}"
184
- print("NOTE:", add_columns_msg)
185
178
  self.logger.info(add_columns_msg)
186
179
 
187
180
  for col_name in new_columns:
@@ -198,24 +191,6 @@ class Authenticator(_t.Generic[User]):
198
191
  session.execute(text(alter_stmt))
199
192
 
200
193
  session.commit()
201
-
202
- # Determine columns to drop
203
- columns_to_drop = dropped_columns.intersection(existing_columns)
204
- if columns_to_drop:
205
- drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
206
- print("NOTE:", drop_columns_msg)
207
- self.logger.info(drop_columns_msg)
208
-
209
- for col_name in columns_to_drop:
210
- session.execute(text(f"ALTER TABLE users DROP COLUMN {col_name}"))
211
-
212
- session.commit()
213
-
214
- # Find columns that are in the database but not in the model
215
- extra_db_columns = existing_columns - columns_to_drop - model_columns
216
- if extra_db_columns:
217
- self.logger.warn(f"The following database columns are not in the User model: {extra_db_columns}\n"
218
- "If you want to drop these columns, please use the `dropped_columns` class method of the User model.")
219
194
 
220
195
  # Get admin password from environment variable if exists
221
196
  admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)
@@ -225,7 +200,7 @@ class Authenticator(_t.Generic[User]):
225
200
  password_hash = pwd_context.hash(admin_password)
226
201
  admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
227
202
  if admin_user is None:
228
- admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, is_admin=True)
203
+ admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
229
204
  session.add(admin_user)
230
205
  else:
231
206
  admin_user.password_hash = password_hash
@@ -238,7 +213,7 @@ class Authenticator(_t.Generic[User]):
238
213
  @cached_property
239
214
  def user_fields(self) -> list[UserField]:
240
215
  """
241
- Get the fields of the User model as a list of dictionaries
216
+ Get the fields of the CustomUserFields model as a list of dictionaries
242
217
 
243
218
  Each dictionary contains the following keys:
244
219
  - name: The name of the field
@@ -247,11 +222,15 @@ class Authenticator(_t.Generic[User]):
247
222
  - enum: The possible values of the field (or None if not applicable)
248
223
  - default: The default value of the field (or None if field is required)
249
224
  """
250
- schema = self.User.model_json_schema()
251
-
252
- fields = []
225
+ # Add the base fields
226
+ fields = [
227
+ UserField(name="username", type="string", nullable=False, enum=None, default=None),
228
+ UserField(name="access_level", type="string", nullable=False, enum=["admin", "member"], default="member")
229
+ ]
253
230
 
254
- properties: dict[str, dict[str, _t.Any]] = schema.get("properties", {})
231
+ # Add custom fields
232
+ schema = self.CustomUserFields.model_json_schema()
233
+ properties: dict[str, dict[str, Any]] = schema.get("properties", {})
255
234
  for field_name, field_schema in properties.items():
256
235
  if choices := field_schema.get("anyOf"):
257
236
  field_type = choices[0]["type"]
@@ -268,54 +247,99 @@ class Authenticator(_t.Generic[User]):
268
247
  def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
269
248
  session = self.Session()
270
249
 
271
- # Validate the user data
250
+ # Separate custom fields from base fields
251
+ access_level = user_fields.get('access_level', 'member')
252
+ password = user_fields.get('password')
253
+
254
+ # Validate access level - cannot add/update users with guest access level
255
+ if access_level == "guest":
256
+ raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")
257
+
258
+ # Extract custom fields (everything except username, access_level, password)
259
+ custom_fields_dict = {k: v for k, v in user_fields.items() if k not in ['username', 'access_level', 'password']}
260
+
261
+ # Validate the custom fields
272
262
  try:
273
- user_data = self.User(**user_fields, username=username).model_dump(mode='json')
263
+ custom_fields = self.CustomUserFields(**custom_fields_dict)
264
+ custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
274
265
  except ValidationError as e:
275
- raise InvalidInputError(102, f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
266
+ raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
276
267
 
277
- # Add a new user
268
+ # Add or update user
278
269
  try:
279
270
  # Check if the user already exists
280
271
  existing_user = session.get(self.DbUser, username)
281
272
  if existing_user is not None:
282
273
  if not update_user:
283
- raise InvalidInputError(101, f"User '{username}' already exists")
274
+ raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")
284
275
 
285
- if username == c.ADMIN_USERNAME:
286
- raise InvalidInputError(24, "Changing the admin user is not permitted")
287
- new_user = self.DbUser(password_hash=existing_user.password_hash, **user_data)
288
- session.delete(existing_user)
276
+ if username == c.ADMIN_USERNAME and access_level != "admin":
277
+ raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")
278
+
279
+ # Update existing user
280
+ existing_user.access_level = access_level
281
+ existing_user.custom_fields = custom_fields_json
289
282
  else:
290
283
  if update_user:
291
- raise InvalidInputError(41, f"No user found for username: {username}")
284
+ raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
292
285
 
293
- password = user_fields.get('password')
294
286
  if password is None:
295
- raise InvalidInputError(100, f"Missing required field 'password' when adding a new user")
287
+ raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")
288
+
296
289
  password_hash = pwd_context.hash(password)
297
- new_user = self.DbUser(password_hash=password_hash, **user_data)
298
-
299
- # Add the user to the session
300
- session.add(new_user)
290
+ new_user = self.DbUser(
291
+ username=username,
292
+ access_level=access_level,
293
+ password_hash=password_hash,
294
+ custom_fields=custom_fields_json
295
+ )
296
+ session.add(new_user)
301
297
 
302
298
  # Commit the transaction
303
299
  session.commit()
304
300
 
305
301
  finally:
306
302
  session.close()
303
+
304
+ def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
305
+ provider = next((p for p in self.auth_providers if p.name == provider_name), None)
306
+ if provider is None:
307
+ raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
308
+
309
+ user = provider.provider_configs.get_user(user_info)
310
+ session = self.Session()
311
+ try:
312
+ # Convert user to database user
313
+ custom_fields_json = user.custom_fields.model_dump_json()
314
+ db_user = self.DbUser(
315
+ username=user.username,
316
+ access_level=user.access_level,
317
+ password_hash="", # By omitting password_hash, it becomes impossible to login with username and password (OAuth only)
318
+ custom_fields=custom_fields_json
319
+ )
320
+
321
+ existing_db_user = session.get(self.DbUser, db_user.username)
322
+ if existing_db_user is None:
323
+ session.add(db_user)
324
+
325
+ session.commit()
326
+
327
+ return self._convert_db_user_to_user(db_user)
328
+
329
+ finally:
330
+ session.close()
307
331
 
308
- def get_user(self, username: str, password: str) -> User:
332
+ def get_user(self, username: str, password: str) -> RegisteredUser:
309
333
  session = self.Session()
310
334
  try:
311
335
  # Query for user by username
312
336
  db_user = session.get(self.DbUser, username)
313
337
 
314
338
  if db_user and pwd_context.verify(password, db_user.password_hash):
315
- user = self.User.model_validate(db_user)
316
- return user # type: ignore
339
+ user = self._convert_db_user_to_user(db_user)
340
+ return user
317
341
  else:
318
- raise InvalidInputError(0, f"Username or password not found")
342
+ raise InvalidInputError(401, "incorrect_username_or_password", f"Incorrect username or password")
319
343
 
320
344
  finally:
321
345
  session.close()
@@ -325,127 +349,611 @@ class Authenticator(_t.Generic[User]):
325
349
  try:
326
350
  db_user = session.get(self.DbUser, username)
327
351
  if db_user is None:
328
- raise InvalidInputError(2, f"User not found")
352
+ raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")
329
353
 
330
- if pwd_context.verify(old_password, db_user.password_hash):
354
+ if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
331
355
  db_user.password_hash = pwd_context.hash(new_password)
332
356
  session.commit()
333
357
  else:
334
- raise InvalidInputError(3, f"Incorrect password")
358
+ raise InvalidInputError(401, "incorrect_password", f"Incorrect password")
335
359
  finally:
336
360
  session.close()
337
361
 
338
362
  def delete_user(self, username: str) -> None:
339
363
  if username == c.ADMIN_USERNAME:
340
- raise InvalidInputError(23, "Cannot delete the admin user")
364
+ raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")
341
365
 
342
366
  session = self.Session()
343
367
  try:
344
368
  db_user = session.get(self.DbUser, username)
345
369
  if db_user is None:
346
- raise InvalidInputError(41, f"No user found for username: {username}")
370
+ raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
347
371
  session.delete(db_user)
348
372
  session.commit()
349
373
  finally:
350
374
  session.close()
351
375
 
352
def get_all_users(self) -> list[RegisteredUser]:
    """Return every stored user, each converted from its database row."""
    session = self.Session()
    try:
        rows = session.query(self.DbUser).all()
        return list(map(self._convert_db_user_to_user, rows))
    finally:
        session.close()
359
383
 
360
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
    """
    Issue a credential for ``user`` and return it with its expiry time.

    - With ``title``: generates an API key of the form ``"sqrl-<random>"``. Only its hash
      (salted with the secret key) is stored in a DbApiKey row, together with the title,
      owner and timestamps. The plaintext key is returned.
    - Without ``title``: returns an HS256 JWT carrying the username and an ``exp`` claim.
      Nothing is persisted.

    Args:
        user: Owner of the credential.
        expiry_minutes: Lifetime in minutes; ``None`` means no expiry (``datetime.max``).
        title: Name for an API key; selects the API-key path when given.

    Returns:
        ``(credential, expires_at)``.

    Raises:
        ConfigurationError: If the secret key environment variable is not set.
    """
    created_at = datetime.now(timezone.utc)
    if expiry_minutes is None:
        expire_at = datetime.max
    else:
        expire_at = created_at + timedelta(minutes=expiry_minutes)

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

    if title is None:
        claims = {"username": user.username, "exp": expire_at}
        return jwt.encode(claims, self.secret_key, algorithm="HS256"), expire_at

    session = self.Session()
    try:
        api_key_value = "sqrl-" + secrets.token_urlsafe(16)
        session.add(self.DbApiKey(
            hashed_key=u.hash_string(api_key_value, salt=self.secret_key),
            title=title,
            username=user.username,
            created_at=created_at,
            expires_at=expire_at,
        ))
        session.commit()
    finally:
        session.close()

    return api_key_value, expire_at
379
409
 
380
def get_user_from_token(self, token: str | None) -> RegisteredUser | None:
    """
    Resolve the user behind an access token.

    A token starting with ``"sqrl-"`` is treated as an API key. It is hashed with the
    secret key and looked up among unexpired DbApiKey rows. Any other token is decoded as
    an HS256 JWT whose ``username`` claim names the user.

    Returns:
        The user, or ``None`` when ``token`` is ``None`` or empty.

    Raises:
        ConfigurationError: If the secret key environment variable is not set.
        InvalidInputError: 401 "invalid_authorization_token" when the token is unknown,
            expired, malformed, lacks a string ``username`` claim, or names a user that
            no longer exists.
    """
    if not token:
        return None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    session = self.Session()
    try:
        if token.startswith("sqrl-"):
            hashed_key = u.hash_string(token, salt=self.secret_key)
            api_key = session.query(self.DbApiKey).filter(
                self.DbApiKey.hashed_key == hashed_key,
                self.DbApiKey.expires_at >= func.now()
            ).first()
            if api_key is None:
                raise InvalidTokenError()
            username = api_key.username
        else:
            payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            # A correctly signed JWT without a usable "username" claim used to escape as a
            # KeyError (a server error). Treat it as an invalid token instead.
            username = payload.get("username")
            if not isinstance(username, str):
                raise InvalidTokenError()

        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidTokenError()

        user = self._convert_db_user_to_user(db_user)

    except InvalidTokenError as e:
        raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token") from e
    finally:
        session.close()

    return user
411
447
 
412
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    List the unexpired API keys owned by ``username``.

    Each entry is an ApiKey model built from a stored DbApiKey row. It carries the key's
    ID, title and expiry, never the plaintext key. Per the original note, the ID is derived
    from a hash of the key rather than the key itself.
    """
    session = self.Session()
    try:
        still_valid = (
            self.DbApiKey.username == username,
            self.DbApiKey.expires_at >= func.now(),
        )
        rows = session.query(self.DbApiKey).filter(*still_valid).all()
        return [ApiKey.model_validate(row) for row in rows]
    finally:
        session.close()
423
462
 
424
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete the API key with ID ``api_key_id`` owned by ``username``.

    Raises:
        InvalidInputError: 404 "api_key_not_found" if that user owns no key with that ID.
    """
    session = self.Session()
    try:
        target = (
            session.query(self.DbApiKey)
            .filter(self.DbApiKey.username == username, self.DbApiKey.id == api_key_id)
            .first()
        )
        if target is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")

        session.delete(target)
        session.commit()
    finally:
        session.close()
439
482
 
440
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
    """
    Check whether ``user``'s access level is high enough for ``scope``.

    Access levels map to scopes as follows:

    - ``"guest"`` maps to PUBLIC.
    - ``"admin"`` maps to PRIVATE.
    - Any other level (member) maps to PROTECTED.

    Access is granted when the mapped scope's value is at least the requested scope's value.
    """
    level_to_scope = {
        "guest": PermissionScope.PUBLIC,
        "admin": PermissionScope.PRIVATE,
    }
    granted = level_to_scope.get(user.access_level, PermissionScope.PROTECTED)
    return granted.value >= scope.value
449
492
 
493
+ # OAuth Client Management Methods
494
+
495
def generate_secret_and_hash(self) -> tuple[str, str]:
    """
    Create a random URL-safe secret and its ``pwd_context`` hash.

    Used in this module for client secrets, registration access tokens and refresh tokens.
    Only the hash is meant to be persisted.

    Returns:
        ``(secret, secret_hash)``.
    """
    raw = secrets.token_urlsafe(64)
    return raw, pwd_context.hash(raw)
500
+
501
def _validate_client_registration_request(self, request: ClientRegistrationRequest | ClientUpdateRequest) -> dict:
    """
    Validate the client metadata in a registration or update request.

    Only fields that are set (truthy) on ``request`` are checked.

    Returns:
        A dict of DbOAuthClient column name -> storable value for those fields:
        ``redirect_uris`` as a JSON list, ``grant_types``/``response_types`` comma-joined,
        and ``scope`` space-delimited.

    Raises:
        InvalidInputError: 400 on an invalid redirect URI, grant type, response type or scope.
    """
    updates = {}
    if request.client_name:
        updates['client_name'] = request.client_name

    if request.redirect_uris:
        for uri in request.redirect_uris:
            if not self._validate_redirect_uri_format(uri):
                raise InvalidInputError(400, "invalid_redirect_uri", f"Invalid redirect URI format: {uri}")
        updates['redirect_uris'] = json.dumps(request.redirect_uris)

    if request.grant_types:
        if not all(grant_type in c.SUPPORTED_GRANT_TYPES for grant_type in request.grant_types):
            raise InvalidInputError(400, "invalid_grant_types", f"Invalid grant types. Supported grant types are: {c.SUPPORTED_GRANT_TYPES}")
        updates['grant_types'] = ','.join(request.grant_types)

    if request.response_types:
        if not all(response_type in c.SUPPORTED_RESPONSE_TYPES for response_type in request.response_types):
            raise InvalidInputError(400, "invalid_response_types", f"Invalid response types. Supported response types are: {c.SUPPORTED_RESPONSE_TYPES}")
        updates['response_types'] = ','.join(request.response_types)

    if request.scope:
        scopes = request.scope.split()
        if not all(scope in c.SUPPORTED_SCOPES for scope in scopes):
            raise InvalidInputError(400, "invalid_scope", f"Invalid scope. Supported scopes are: {c.SUPPORTED_SCOPES}")
        # OAuth scope strings are space-delimited, and register_oauth_client stores
        # request.scope as given. Joining with ',' here made updated clients store a
        # different format than newly registered ones.
        updates['scope'] = ' '.join(scopes)

    return updates
533
+
534
def register_oauth_client(
    self, request: ClientRegistrationRequest, client_management_path_format: str
) -> ClientRegistrationResponse:
    """
    Register a new OAuth client.

    Validates the request and generates a client_id, a client secret and a registration
    access token. Only the hashes of the secret and token are stored, then the client row
    is persisted.

    Args:
        request: Registration metadata. ``grant_types`` defaults to
            ``['authorization_code', 'refresh_token']`` when omitted.
        client_management_path_format: Format string with a ``{client_id}`` placeholder,
            used to build ``registration_client_uri``.

    Returns:
        The registration response, including the plaintext ``client_secret`` and
        ``registration_access_token``. Only their hashes are kept, so they cannot be
        retrieved again later.

    Raises:
        InvalidInputError: 400 if the request metadata is invalid.
    """
    grant_types = request.grant_types
    if grant_types is None:
        grant_types = ['authorization_code', 'refresh_token']

    # Validate request (raises on invalid metadata)
    self._validate_client_registration_request(request)

    # Generate secure client credentials and registration access token
    client_id = secrets.token_urlsafe(16)
    client_secret, client_secret_hash = self.generate_secret_and_hash()

    registration_access_token, registration_access_token_hash = self.generate_secret_and_hash()
    registration_client_uri = client_management_path_format.format(client_id=client_id)

    session = self.Session()
    try:
        client_columns = dict(
            client_id=client_id,
            client_secret_hash=client_secret_hash,
            client_name=request.client_name,
            redirect_uris=json.dumps(request.redirect_uris),
            scope=request.scope,
            grant_types=','.join(grant_types),
            registration_access_token_hash=registration_access_token_hash,
        )
        # response_types used to be echoed back in the response but never stored.
        # get_oauth_client_details would then report whatever the column held instead of
        # what the client registered.
        if request.response_types:
            client_columns['response_types'] = ','.join(request.response_types)

        oauth_client = self.DbOAuthClient(**client_columns)
        session.add(oauth_client)
        session.commit()

        return ClientRegistrationResponse(
            client_id=client_id,
            client_secret=client_secret,
            client_name=request.client_name,
            redirect_uris=request.redirect_uris,
            scope=request.scope,
            grant_types=grant_types,
            response_types=request.response_types,
            created_at=datetime.now(timezone.utc),
            is_active=True,
            registration_client_uri=registration_client_uri,
            registration_access_token=registration_access_token,
        )

    finally:
        session.close()
582
+
583
def get_oauth_client_details(self, client_id: str) -> ClientDetailsResponse:
    """
    Look up an active OAuth client and return its metadata.

    Stored encodings are decoded: ``redirect_uris`` from JSON, and ``grant_types`` and
    ``response_types`` from comma-separated strings.

    Raises:
        InvalidInputError: 404 "invalid_client_id" when the client does not exist or has
            been deactivated.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "invalid_client_id", "Client not found")
        if not record.is_active:
            raise InvalidInputError(404, "invalid_client_id", "Client is no longer active")

        uris = json.loads(record.redirect_uris)
        grants = record.grant_types.split(',')
        responses = record.response_types.split(',')
        return ClientDetailsResponse(
            client_id=record.client_id,
            client_name=record.client_name,
            redirect_uris=uris,
            scope=record.scope,
            grant_types=grants,
            response_types=responses,
            created_at=record.created_at,
            is_active=record.is_active,
        )
    finally:
        session.close()
605
+
606
def validate_client_credentials(self, client_id: str, client_secret: str) -> bool:
    """
    Return whether ``client_secret`` matches the stored hash for an existing, active client.

    Unknown or deactivated clients yield ``False`` rather than an error.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is not None and record.is_active:
            return pwd_context.verify(client_secret, record.client_secret_hash)
        return False
    finally:
        session.close()
616
+
617
def validate_redirect_uri(self, client_id: str, redirect_uri: str) -> bool:
    """
    Return whether ``redirect_uri`` is one of the URIs registered for an active client.

    The comparison is an exact match against the stored JSON list. Unknown or deactivated
    clients yield ``False``.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is not None and record.is_active:
            return redirect_uri in json.loads(record.redirect_uris)
        return False
    finally:
        session.close()
629
+
630
def validate_registration_access_token(self, client_id: str, registration_access_token: str) -> bool:
    """
    Return whether ``registration_access_token`` matches the stored hash for ``client_id``.

    Unlike the other client checks, this does not require the client to be active.
    NOTE(review): presumably so a deactivated client can still be managed or reactivated
    through the update endpoint. Confirm.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        return record is not None and pwd_context.verify(
            registration_access_token, record.registration_access_token_hash
        )
    finally:
        session.close()
641
+
642
+ def _validate_redirect_uri_format(self, uri: str) -> bool:
643
+ """Validate redirect URI format for security"""
644
+ # Basic validation - must be https (except localhost) and not contain fragments
645
+ if '#' in uri:
646
+ return False
647
+
648
+ if uri.startswith('http://'):
649
+ # Only allow http for localhost/development
650
+ if not ('localhost' in uri or '127.0.0.1' in uri):
651
+ return False
652
+ elif not uri.startswith('https://'):
653
+ # Custom schemes allowed for mobile apps
654
+ if '://' not in uri:
655
+ return False
656
+
657
+ return True
658
+
659
def update_oauth_client_with_token_rotation(self, client_id: str, request: ClientUpdateRequest) -> ClientUpdateResponse:
    """
    Apply a client-metadata update and issue a fresh registration access token.

    The request is validated first. A deactivated client can only be updated when the same
    request sets ``is_active`` to a truthy value, i.e. reactivates it. Validated fields are
    copied onto the row when it has a matching attribute. The registration access token
    hash is always replaced.

    Returns:
        The updated metadata plus the new plaintext registration access token.

    Raises:
        InvalidInputError: 404 for an unknown client, 400 for a deactivated client that is
            not being reactivated, or 400 for invalid metadata.
    """
    changes = self._validate_client_registration_request(request)

    session = self.Session()
    try:
        client = session.get(self.DbOAuthClient, client_id)
        if client is None:
            raise InvalidInputError(404, "invalid_client_id", f"OAuth client not found: {client_id}")
        if not (client.is_active or request.is_active):
            raise InvalidInputError(400, "invalid_request", "Cannot update a deactivated client")

        if request.is_active is not None:
            client.is_active = request.is_active

        fresh_token, fresh_token_hash = self.generate_secret_and_hash()

        for column, value in changes.items():
            if hasattr(client, column):
                setattr(client, column, value)

        client.registration_access_token_hash = fresh_token_hash
        session.commit()

        return ClientUpdateResponse(
            client_id=client.client_id,
            client_name=client.client_name,
            redirect_uris=json.loads(client.redirect_uris),
            scope=client.scope,
            grant_types=client.grant_types.split(','),
            response_types=client.response_types.split(','),
            created_at=client.created_at,
            is_active=client.is_active,
            registration_access_token=fresh_token,
        )
    finally:
        session.close()
703
+
704
def revoke_oauth_client(self, client_id: str) -> None:
    """
    Deactivate an OAuth client by setting ``is_active`` to False.

    The row itself is kept.

    Raises:
        InvalidInputError: 404 "client_not_found" when the client does not exist.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "client_not_found", f"OAuth client not found: {client_id}")
        record.is_active = False
        session.commit()
    finally:
        session.close()
716
+
717
+ # Authorization Code Flow Methods
718
+
719
def create_authorization_code(
    self, client_id: str, username: str, redirect_uri: str, scope: str,
    code_challenge: str, code_challenge_method: str = "S256"
) -> str:
    """
    Issue an authorization code for the authorization-code + PKCE flow.

    Checks that the client is active, that ``redirect_uri`` is registered for it, and that
    a non-empty ``code_challenge`` using the ``"S256"`` method was supplied. The code is
    stored with a 10-minute lifetime and ``used=False``.

    Returns:
        The generated authorization code.

    Raises:
        InvalidInputError: 400 for an invalid client, redirect URI or PKCE parameters.
            get_oauth_client_details itself raises 404 for an unknown or inactive client.
    """
    # get_oauth_client_details already raises for a missing or inactive client,
    # so this second check is purely defensive.
    details = self.get_oauth_client_details(client_id)
    if details is None or not details.is_active:
        raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

    if not self.validate_redirect_uri(client_id, redirect_uri):
        raise InvalidInputError(400, "invalid_redirect_uri", "Invalid redirect_uri for this client")

    if not code_challenge:
        raise InvalidInputError(400, "invalid_request", "code_challenge is required")
    if code_challenge_method != "S256":
        raise InvalidInputError(400, "invalid_request", "code_challenge_method must be 'S256'")

    code = secrets.token_urlsafe(48)
    issued_at = datetime.now(timezone.utc)
    valid_until = issued_at + timedelta(minutes=10)

    session = self.Session()
    try:
        session.add(self.DbAuthorizationCode(
            code=code,
            client_id=client_id,
            username=username,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            created_at=issued_at,
            expires_at=valid_until,
            used=False,
        ))
        session.commit()
        return code
    finally:
        session.close()
769
+
770
def exchange_authorization_code(
    self, code: str, client_id: str, redirect_uri: str, code_verifier: str, access_token_expiry_minutes: int
) -> TokenResponse:
    """
    Redeem an authorization code for an access token and a refresh token.

    The code must belong to ``client_id``, be unexpired and unused, and have been issued
    for the same ``redirect_uri``; ``code_verifier`` must satisfy its PKCE challenge. On
    success the following happens:

    - The code is marked used.
    - A JWT access token is created via create_access_token.
    - A new refresh token is generated.
    - Hashes of both tokens are stored in a DbOAuthToken row; the refresh token is valid
      for 30 days.

    Returns:
        A TokenResponse with the plaintext ``access_token``, ``expires_in`` (seconds) and
        ``refresh_token``.

    Raises:
        InvalidInputError: 400 "invalid_grant" when any of the checks fails.
    """
    session = self.Session()
    try:
        grant = session.query(self.DbAuthorizationCode).filter(
            self.DbAuthorizationCode.code == code,
            self.DbAuthorizationCode.client_id == client_id,
            self.DbAuthorizationCode.expires_at >= func.now(),
            self.DbAuthorizationCode.used == False,  # SQL expression; must stay `== False`
        ).first()

        if grant is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")
        if grant.redirect_uri != redirect_uri:
            raise InvalidInputError(400, "invalid_grant", "Redirect URI mismatch")
        if not u.validate_pkce_challenge(code_verifier, grant.code_challenge):
            raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")

        db_user = session.get(self.DbUser, grant.username)
        if db_user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        grant.used = True

        owner = self._convert_db_user_to_user(db_user)
        access_token, access_expires_at = self.create_access_token(owner, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        refresh_token, refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        session.add(self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=refresh_token_hash,
            client_id=client_id,
            username=grant.username,
            scope=grant.scope,
            access_token_expires_at=access_expires_at,
            refresh_token_expires_at=refresh_expires_at,
        ))
        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes * 60,
            refresh_token=refresh_token,
        )
    finally:
        session.close()
835
+
836
def refresh_oauth_access_token(self, refresh_token: str, client_id: str, access_token_expiry_minutes: int) -> TokenResponse:
    """
    Rotate a refresh token, issuing a new access token and refresh token.

    The client must exist and be active, and ``refresh_token`` must match one of that
    client's unexpired, unrevoked token rows. The matched row is revoked. A new row is
    stored with hashes of the new access and refresh tokens; the refresh token is valid for
    30 days.

    Returns:
        A TokenResponse with the new plaintext ``access_token``, ``expires_in`` (seconds)
        and ``refresh_token``.

    Raises:
        InvalidInputError: 400 "invalid_client" for an unknown or inactive client;
            400 "invalid_grant" for an unknown, expired or revoked refresh token, or when
            the token's user no longer exists.
        ConfigurationError: If the secret key environment variable is not set.
    """
    session = self.Session()
    try:
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        candidates = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False  # SQL expression; must stay `== False`
        ).all()

        # The stored value is a pwd_context hash, not a lookup key, so every active token
        # for this client has to be checked (as revoke_oauth_token does). The previous
        # `.first()` checked only one row, so refreshing failed whenever the client had
        # more than one active grant (e.g. several users).
        oauth_token = next(
            (t for t in candidates if pwd_context.verify(refresh_token, t.refresh_token_hash)),
            None,
        )
        if oauth_token is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")

        db_user = session.get(self.DbUser, oauth_token.username)
        if db_user is None:
            # The grant refers to a user that no longer exists. That makes the grant
            # invalid, not the client; this matches exchange_authorization_code.
            raise InvalidInputError(400, "invalid_grant", "User not found")

        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")

        user_obj = self._convert_db_user_to_user(db_user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        new_refresh_token, new_refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        # Refresh-token rotation: the old token cannot be used again.
        oauth_token.is_revoked = True

        session.add(self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=new_refresh_token_hash,
            client_id=client_id,
            username=oauth_token.username,
            scope=oauth_token.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        ))
        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes * 60,
            refresh_token=new_refresh_token
        )

    finally:
        session.close()
902
+
903
def revoke_oauth_token(self, client_id: str, token: str, token_type_hint: str | None = None) -> None:
    """
    Revoke one of a client's refresh tokens.

    ``token_type_hint`` may be omitted or must be ``'refresh_token'``; revoking access
    tokens is not supported yet. An unknown token is silently ignored, so the call succeeds
    either way, following the OAuth revocation convention.

    Raises:
        InvalidInputError: 400 "invalid_request" for any other hint, or 400
            "invalid_client" for an unknown or inactive client.
    """
    if token_type_hint and token_type_hint != 'refresh_token':
        raise InvalidInputError(400, "invalid_request", "Only refresh tokens can be revoked")

    session = self.Session()
    try:
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        active_tokens = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False  # SQL expression; must stay `== False`
        ).all()

        match = next(
            (row for row in active_tokens if pwd_context.verify(token, row.refresh_token_hash)),
            None,
        )
        if match:
            match.is_revoked = True
            session.commit()
    finally:
        session.close()
939
+
450
940
  def close(self) -> None:
451
941
  self.engine.dispose()
942
+
943
+
944
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider.

    The decorated function takes an AuthProviderArgs and returns ProviderConfigs. It is
    replaced by a wrapper that calls it and packages the result into an AuthProvider with
    the given name, label and icon. The wrapper is appended to ``Authenticator.providers``
    as soon as the decorator is applied.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')
    """
    import functools

    def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
        # functools.wraps keeps the user's function name and docstring on the registered
        # wrapper and exposes the original as __wrapped__. Without it every provider
        # factory shows up as "wrapper" in tracebacks and introspection.
        @functools.wraps(func)
        def wrapper(sqrl: AuthProviderArgs):
            provider_configs = func(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)
        Authenticator.providers.append(wrapper)
        return wrapper
    return decorator