squirrels 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (125)
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +58 -111
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +13 -11
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +271 -0
  8. squirrels/_api_routes/base.py +165 -0
  9. squirrels/_api_routes/dashboards.py +150 -0
  10. squirrels/_api_routes/data_management.py +145 -0
  11. squirrels/_api_routes/datasets.py +257 -0
  12. squirrels/_api_routes/oauth2.py +298 -0
  13. squirrels/_api_routes/project.py +252 -0
  14. squirrels/_api_server.py +256 -450
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/_arguments/init_time_args.py +108 -0
  17. squirrels/_arguments/run_time_args.py +147 -0
  18. squirrels/_auth.py +960 -0
  19. squirrels/_command_line.py +126 -45
  20. squirrels/_compile_prompts.py +147 -0
  21. squirrels/_connection_set.py +48 -26
  22. squirrels/_constants.py +68 -38
  23. squirrels/_dashboards.py +160 -0
  24. squirrels/_data_sources.py +570 -0
  25. squirrels/_dataset_types.py +84 -0
  26. squirrels/_exceptions.py +29 -0
  27. squirrels/_initializer.py +177 -80
  28. squirrels/_logging.py +115 -0
  29. squirrels/_manifest.py +208 -79
  30. squirrels/_model_builder.py +69 -0
  31. squirrels/_model_configs.py +74 -0
  32. squirrels/_model_queries.py +52 -0
  33. squirrels/_models.py +926 -367
  34. squirrels/_package_data/base_project/.env +42 -0
  35. squirrels/_package_data/base_project/.env.example +42 -0
  36. squirrels/_package_data/base_project/assets/expenses.db +0 -0
  37. squirrels/_package_data/base_project/connections.yml +16 -0
  38. squirrels/_package_data/base_project/dashboards/dashboard_example.py +34 -0
  39. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  40. squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +5 -2
  41. squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +3 -3
  42. squirrels/{package_data → _package_data}/base_project/docker/compose.yml +1 -1
  43. squirrels/_package_data/base_project/duckdb_init.sql +10 -0
  44. squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +3 -2
  45. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  46. squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
  47. squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
  48. squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
  49. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +12 -0
  50. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +26 -0
  51. squirrels/_package_data/base_project/models/federates/federate_example.py +37 -0
  52. squirrels/_package_data/base_project/models/federates/federate_example.sql +19 -0
  53. squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
  54. squirrels/_package_data/base_project/models/sources.yml +38 -0
  55. squirrels/{package_data → _package_data}/base_project/parameters.yml +56 -40
  56. squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
  57. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +21 -40
  58. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  59. squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
  60. squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
  61. squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
  62. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
  63. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  64. squirrels/_package_data/templates/dataset_results.html +112 -0
  65. squirrels/_package_data/templates/oauth_login.html +271 -0
  66. squirrels/_package_data/templates/squirrels_studio.html +20 -0
  67. squirrels/_package_loader.py +8 -4
  68. squirrels/_parameter_configs.py +104 -103
  69. squirrels/_parameter_options.py +348 -0
  70. squirrels/_parameter_sets.py +57 -47
  71. squirrels/_parameters.py +1664 -0
  72. squirrels/_project.py +721 -0
  73. squirrels/_py_module.py +7 -5
  74. squirrels/_schemas/__init__.py +0 -0
  75. squirrels/_schemas/auth_models.py +167 -0
  76. squirrels/_schemas/query_param_models.py +75 -0
  77. squirrels/{_api_response_models.py → _schemas/response_models.py} +126 -47
  78. squirrels/_seeds.py +35 -16
  79. squirrels/_sources.py +110 -0
  80. squirrels/_utils.py +248 -73
  81. squirrels/_version.py +1 -1
  82. squirrels/arguments.py +7 -0
  83. squirrels/auth.py +4 -0
  84. squirrels/connections.py +3 -0
  85. squirrels/dashboards.py +2 -81
  86. squirrels/data_sources.py +14 -631
  87. squirrels/parameter_options.py +13 -348
  88. squirrels/parameters.py +14 -1266
  89. squirrels/types.py +16 -0
  90. squirrels-0.5.0.dist-info/METADATA +113 -0
  91. squirrels-0.5.0.dist-info/RECORD +97 -0
  92. {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info}/WHEEL +1 -1
  93. squirrels-0.5.0.dist-info/entry_points.txt +3 -0
  94. {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info/licenses}/LICENSE +1 -1
  95. squirrels/_authenticator.py +0 -85
  96. squirrels/_dashboards_io.py +0 -61
  97. squirrels/_environcfg.py +0 -84
  98. squirrels/arguments/init_time_args.py +0 -40
  99. squirrels/arguments/run_time_args.py +0 -208
  100. squirrels/package_data/assets/favicon.ico +0 -0
  101. squirrels/package_data/assets/index.css +0 -1
  102. squirrels/package_data/assets/index.js +0 -58
  103. squirrels/package_data/base_project/assets/expenses.db +0 -0
  104. squirrels/package_data/base_project/connections.yml +0 -7
  105. squirrels/package_data/base_project/dashboards/dashboard_example.py +0 -32
  106. squirrels/package_data/base_project/dashboards.yml +0 -10
  107. squirrels/package_data/base_project/env.yml +0 -29
  108. squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
  109. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -22
  110. squirrels/package_data/base_project/models/federates/federate_example.py +0 -21
  111. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -3
  112. squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
  113. squirrels/package_data/base_project/pyconfigs/connections.py +0 -19
  114. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -95
  115. squirrels/package_data/base_project/seeds/seed_subcategories.csv +0 -15
  116. squirrels/package_data/base_project/squirrels.yml.j2 +0 -94
  117. squirrels/package_data/templates/index.html +0 -18
  118. squirrels/project.py +0 -378
  119. squirrels/user_base.py +0 -55
  120. squirrels-0.4.1.dist-info/METADATA +0 -117
  121. squirrels-0.4.1.dist-info/RECORD +0 -60
  122. squirrels-0.4.1.dist-info/entry_points.txt +0 -4
  123. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  124. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  125. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
squirrels/_auth.py ADDED
@@ -0,0 +1,960 @@
1
+ from typing import Callable, Any
2
+ from datetime import datetime, timedelta, timezone
3
+ from functools import cached_property
4
+ from jwt.exceptions import InvalidTokenError
5
+ from passlib.context import CryptContext
6
+ from pydantic import ValidationError
7
+ from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
8
+ from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
9
+ import jwt, uuid, secrets, json
10
+
11
+ from ._manifest import PermissionScope
12
+ from ._py_module import PyModule
13
+ from ._exceptions import InvalidInputError, ConfigurationError
14
+ from ._arguments.init_time_args import AuthProviderArgs
15
+ from ._schemas.auth_models import (
16
+ CustomUserFields, AbstractUser, GuestUser, RegisteredUser, ApiKey, UserField, AuthProvider, ProviderConfigs,
17
+ ClientRegistrationRequest, ClientUpdateRequest, ClientDetailsResponse, ClientRegistrationResponse,
18
+ ClientUpdateResponse, TokenResponse
19
+ )
20
+ from . import _utils as u, _constants as c
21
+
22
# Shared passlib context. It hashes and verifies user passwords as well as OAuth client
# secrets and registration tokens. bcrypt is the only configured scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# An auth-provider factory: it receives the init-time auth arguments and builds an AuthProvider
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
25
+
26
+
27
+ class Authenticator:
28
+ providers: list[ProviderFunctionType] = [] # static variable to stage providers
29
+
30
def __init__(
    self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
    custom_user_fields_cls: type[CustomUserFields], *, sa_engine: Engine | None = None, external_only: bool = False
):
    """
    Set up the per-instance ORM models, the database engine, and the configured auth providers.

    Args:
        logger: Logger used for auth-related messages.
        base_path: Project path substituted for ``{project_path}`` in the SQLite file path.
        auth_args: Init-time arguments. Its ``env_vars`` supply the secret key, the optional
            auth DB file path, and the optional admin password.
        provider_functions: Factories called with ``auth_args`` to build ``self.auth_providers``.
        custom_user_fields_cls: Pydantic model describing the project's custom user fields.
        sa_engine: Optional pre-built SQLAlchemy engine. When omitted, a SQLite file engine is created.
        external_only: Stored on the instance as ``self.external_only``.
    """
    self.logger = logger
    self.env_vars = auth_args.env_vars
    self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
    self.external_only = external_only

    # A fresh declarative base per instance keeps each Authenticator's table metadata separate
    self.Base = declarative_base()

    class DbUser(self.Base):
        __tablename__ = 'users'
        __table_args__ = {'extend_existing': True}
        username: Mapped[str] = mapped_column(primary_key=True)
        access_level: Mapped[str] = mapped_column(nullable=False, default="member")
        password_hash: Mapped[str] = mapped_column(nullable=False)
        custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())

    class DbApiKey(self.Base):
        __tablename__ = 'api_keys'

        id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
        title: Mapped[str] = mapped_column(nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        created_at: Mapped[datetime] = mapped_column(nullable=False)
        expires_at: Mapped[datetime] = mapped_column(nullable=False)

        def __repr__(self):
            return f"<DbApiKey(id='{self.id}', username='{self.username}')>"

    class DbOAuthClient(self.Base):
        __tablename__ = 'oauth_clients'

        client_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        client_secret_hash: Mapped[str] = mapped_column(nullable=False)
        client_name: Mapped[str] = mapped_column(nullable=False)
        redirect_uris: Mapped[str] = mapped_column(nullable=False)  # JSON array of allowed redirect URIs
        scope: Mapped[str] = mapped_column(nullable=False, default='read')
        grant_types: Mapped[str] = mapped_column(nullable=False, default='authorization_code,refresh_token')
        response_types: Mapped[str] = mapped_column(nullable=False, default='code')
        client_type: Mapped[str] = mapped_column(nullable=False, default='confidential')  # 'confidential' or 'public'
        registration_access_token_hash: Mapped[str] = mapped_column(nullable=False)  # Token for client management
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
        is_active: Mapped[bool] = mapped_column(nullable=False, default=True)

        def __repr__(self):
            return f"<DbOAuthClient(client_id='{self.client_id}', name='{self.client_name}')>"

    class DbAuthorizationCode(self.Base):
        __tablename__ = 'authorization_codes'

        code: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        redirect_uri: Mapped[str] = mapped_column(nullable=False)
        scope: Mapped[str] = mapped_column(nullable=True)
        code_challenge: Mapped[str] = mapped_column(nullable=False)  # PKCE always required
        code_challenge_method: Mapped[str] = mapped_column(nullable=False)  # only S256 is supported
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
        expires_at: Mapped[datetime] = mapped_column(nullable=False)  # 10 minutes from creation
        used: Mapped[bool] = mapped_column(nullable=False, default=False)

        def __repr__(self):
            return f"<DbAuthorizationCode(code='{self.code[:8]}...', client_id='{self.client_id}')>"

    class DbOAuthToken(self.Base):
        __tablename__ = 'oauth_tokens'

        token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        access_token_hash: Mapped[str] = mapped_column(unique=True, nullable=False)
        refresh_token_hash: Mapped[str] = mapped_column(unique=True, nullable=True)  # NULL for client_credentials grants
        client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        scope: Mapped[str] = mapped_column(nullable=True)
        token_type: Mapped[str] = mapped_column(nullable=False, default='Bearer')
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
        access_token_expires_at: Mapped[datetime] = mapped_column(nullable=False)  # Uses SQRL_AUTH_TOKEN_EXPIRE_MINUTES
        refresh_token_expires_at: Mapped[datetime] = mapped_column(nullable=True)  # 30 days from creation, NULL for client_credentials
        is_revoked: Mapped[bool] = mapped_column(nullable=False, default=False)

        def __repr__(self):
            return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"

    self.DbApiKey = DbApiKey
    self.DbOAuthClient = DbOAuthClient
    self.DbAuthorizationCode = DbAuthorizationCode
    self.DbOAuthToken = DbOAuthToken

    self.CustomUserFields = custom_user_fields_cls
    self.DbUser = DbUser

    self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]

    if sa_engine is None:
        raw_sqlite_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{{project_path}}/{c.TARGET_FOLDER}/{c.DB_FILE}")
        sqlite_path = u.Path(raw_sqlite_path.format(project_path=base_path))
        sqlite_path.parent.mkdir(parents=True, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
    else:
        self.engine = sa_engine

    # These PRAGMAs are SQLite-only. A caller-supplied sa_engine may use another
    # dialect that would reject them, so apply them only for SQLite.
    if self.engine.dialect.name == "sqlite":
        with self.engine.connect() as conn:
            conn.execute(text("PRAGMA journal_mode = WAL"))
            conn.execute(text("PRAGMA synchronous = NORMAL"))
            conn.commit()

    self.Base.metadata.create_all(self.engine)

    self.Session = sessionmaker(bind=self.engine)

    self._initialize_db()
151
+
152
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
    """
    Build a RegisteredUser from a DbUser row.

    The row's ``custom_fields`` column holds a JSON object. An empty or NULL value is
    treated as ``{}``. The parsed object is passed through the CustomUserFields model,
    whose own defaults fill any missing keys.
    """
    raw_fields = db_user.custom_fields
    parsed_fields = json.loads(raw_fields) if raw_fields else {}
    validated_fields = self.CustomUserFields(**parsed_fields)
    return RegisteredUser(
        username=db_user.username,
        access_level=db_user.access_level,
        custom_fields=validated_fields,
    )
163
+
164
def _initialize_db(self):  # TODO: Use logger instead of print
    """
    Apply additive schema updates to the ``users`` table and seed the admin account.

    - Columns declared on DbUser but missing from the existing ``users`` table are added
      one at a time with ``ALTER TABLE ... ADD COLUMN``.
    - If the SQRL_SECRET_ADMIN_PASSWORD env var is set, the admin user (c.ADMIN_USERNAME)
      is created with that password and the "admin" access level. If the admin user
      already exists, its password hash is reset to that password.
    """
    session = self.Session()
    try:
        inspector = inspect(self.engine)
        existing_columns = {col['name'] for col in inspector.get_columns('users')}
        model_columns = set(self.DbUser.__table__.columns.keys())

        # Columns present in the model but not yet in the database
        new_columns = model_columns - existing_columns
        if new_columns:
            add_columns_msg = f"Adding columns to database: {new_columns}"
            print("NOTE -", add_columns_msg)
            self.logger.info(add_columns_msg)

        for col_name in new_columns:
            col = self.DbUser.__table__.columns[col_name]
            column_type = col.type.compile(self.engine.dialect)
            nullable = "NULL" if col.nullable else "NOT NULL"
            default = ""
            # Only plain scalar defaults can be written into DDL. A callable or SQL-expression
            # default would otherwise be rendered as its Python repr and yield invalid SQL.
            if col.default is not None and col.default.is_scalar:
                default_arg = col.default.arg
                if isinstance(default_arg, str):
                    # Double embedded single quotes so the SQL string literal stays well-formed
                    escaped_arg = default_arg.replace("'", "''")
                    default_val = f"'{escaped_arg}'"
                else:
                    default_val = default_arg
                default = f"DEFAULT {default_val}"
            # NOTE(review): SQLite rejects ADD COLUMN ... NOT NULL without a non-NULL default,
            # so a required column with no scalar default will still fail here.

            alter_stmt = f"ALTER TABLE users ADD COLUMN {col_name} {column_type} {nullable} {default}"
            session.execute(text(alter_stmt))

        session.commit()

        admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)

        # When an admin password is configured, create the admin user or reset its password
        if admin_password is not None:
            password_hash = pwd_context.hash(admin_password)
            admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
            if admin_user is None:
                admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
                session.add(admin_user)
            else:
                admin_user.password_hash = password_hash

            session.commit()

    finally:
        session.close()
213
+
214
@cached_property
def user_fields(self) -> list[UserField]:
    """
    Describe the user fields (built-in plus custom) as a list of UserField objects.

    Each UserField has:
    - name: The name of the field
    - type: The JSON-schema type of the field (e.g. "string", "integer")
    - nullable: Whether the field accepts null
    - enum: The allowed values of the field (or None if not restricted)
    - default: The default value of the field (or None if the schema gives none)

    The first two entries are the built-in ``username`` and ``access_level`` fields. The
    rest come from the JSON schema of the CustomUserFields model. An ``anyOf`` property is
    nullable if any member is ``{"type": "null"}``, regardless of member order. Its type
    and enum come from the first non-null member. ``$ref`` and single-element ``allOf``
    wrappers are resolved through the schema's ``$defs``.
    """
    fields = [
        UserField(name="username", type="string", nullable=False, enum=None, default=None),
        UserField(name="access_level", type="string", nullable=False, enum=["admin", "member"], default="member"),
    ]

    schema = self.CustomUserFields.model_json_schema()
    definitions: dict[str, dict[str, Any]] = schema.get("$defs", {})

    def resolve(sub_schema: dict[str, Any]) -> dict[str, Any]:
        # Unwrap a single-element allOf, then follow a local "#/$defs/<Name>" reference
        all_of = sub_schema.get("allOf")
        if isinstance(all_of, list) and len(all_of) == 1:
            sub_schema = all_of[0]
        ref = sub_schema.get("$ref")
        if isinstance(ref, str):
            return definitions.get(ref.rsplit("/", 1)[-1], sub_schema)
        return sub_schema

    properties: dict[str, dict[str, Any]] = schema.get("properties", {})
    for field_name, field_schema in properties.items():
        choices = field_schema.get("anyOf")
        if choices:
            resolved_choices = [resolve(choice) for choice in choices]
            nullable = any(choice.get("type") == "null" for choice in resolved_choices)
            non_null = [choice for choice in resolved_choices if choice.get("type") != "null"]
            type_schema = non_null[0] if non_null else resolved_choices[0]
        else:
            nullable = False
            type_schema = resolve(field_schema)

        fields.append(UserField(
            name=field_name,
            type=type_schema["type"],
            nullable=nullable,
            # The enum of e.g. Optional[Literal[...]] lives inside the anyOf member
            enum=type_schema.get("enum", field_schema.get("enum")),
            # Defaults are attached to the property itself, not to the resolved type schema
            default=field_schema.get("default"),
        ))

    return fields
247
+
248
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
    """
    Create a new user, or update an existing one when ``update_user`` is True.

    Args:
        username: The user's unique name (primary key).
        user_fields: Mapping that may contain:
            - 'access_level': defaults to 'member'.
            - 'password': required when creating a user and ignored when updating.
            - custom fields, validated by CustomUserFields.
            A 'username' key in the mapping is ignored.
        update_user: If True, the user must already exist and is updated. If False,
            the user must not exist and is created.

    Raises:
        InvalidInputError: 400 for a 'guest' access level, invalid custom fields, an
            existing username on create, or a missing password on create. 403 when
            setting the admin user to a non-admin level. 404 when updating a user that
            does not exist.
    """
    access_level = user_fields.get('access_level', 'member')
    password = user_fields.get('password')

    if access_level == "guest":
        raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")

    # Everything other than the base fields is treated as a custom field
    custom_fields_dict = {k: v for k, v in user_fields.items() if k not in ['username', 'access_level', 'password']}

    try:
        custom_fields = self.CustomUserFields(**custom_fields_dict)
        custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
    except ValidationError as e:
        first_error = e.errors()[0]
        location = first_error['loc']
        if location:
            message = f"Invalid user field '{location[0]}': {first_error['msg']}"
        else:
            # Model-level validators report an empty 'loc'. Indexing it would raise IndexError.
            message = f"Invalid user data: {first_error['msg']}"
        raise InvalidInputError(400, "invalid_user_data", message) from e

    # Open the session only after validation, so the early raises above cannot leak it
    session = self.Session()
    try:
        existing_user = session.get(self.DbUser, username)
        if existing_user is not None:
            if not update_user:
                raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")

            if username == c.ADMIN_USERNAME and access_level != "admin":
                raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")

            existing_user.access_level = access_level
            existing_user.custom_fields = custom_fields_json
        else:
            if update_user:
                raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

            if password is None:
                raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")

            password_hash = pwd_context.hash(password)
            new_user = self.DbUser(
                username=username,
                access_level=access_level,
                password_hash=password_hash,
                custom_fields=custom_fields_json
            )
            session.add(new_user)

        session.commit()

    finally:
        session.close()
304
+
305
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
    """
    Resolve a user through a named auth provider, persisting it the first time it is seen.

    The provider's ``get_user`` maps ``user_info`` to a user. If no DbUser with that
    username exists yet, one is inserted with an empty password hash, so it cannot log
    in with a username/password. The returned RegisteredUser is built from the
    provider-derived values.

    Raises:
        InvalidInputError: 404 if no configured provider is named ``provider_name``.
    """
    provider = None
    for candidate in self.auth_providers:
        if candidate.name == provider_name:
            provider = candidate
            break
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    user = provider.provider_configs.get_user(user_info)
    session = self.Session()
    try:
        candidate_row = self.DbUser(
            username=user.username,
            access_level=user.access_level,
            password_hash="",  # empty hash: username/password login is impossible (OAuth only)
            custom_fields=user.custom_fields.model_dump_json(),
        )

        # NOTE(review): when the user already exists, the stored row is left untouched and the
        # returned user reflects the provider's data, not the DB row. Confirm this is intended.
        if session.get(self.DbUser, candidate_row.username) is None:
            session.add(candidate_row)

        session.commit()
        return self._convert_db_user_to_user(candidate_row)
    finally:
        session.close()
332
+
333
def get_user(self, username: str, password: str) -> RegisteredUser:
    """
    Authenticate a username/password pair and return the matching user.

    Users created through an auth provider are stored with an empty password hash. They
    are rejected here just like an unknown user, rather than passing an empty hash to
    ``pwd_context.verify``, which cannot identify it and would raise. This matches the
    guard used by ``change_password``.

    Raises:
        InvalidInputError: 401 if the user does not exist, has no password, or the
            password is wrong.
    """
    session = self.Session()
    try:
        db_user = session.get(self.DbUser, username)

        if db_user is not None and db_user.password_hash and pwd_context.verify(password, db_user.password_hash):
            return self._convert_db_user_to_user(db_user)

        raise InvalidInputError(401, "incorrect_username_or_password", "Incorrect username or password")

    finally:
        session.close()
347
+
348
def change_password(self, username: str, old_password: str, new_password: str) -> None:
    """
    Replace a user's password after verifying the current one.

    Raises:
        InvalidInputError: 401 if the user does not exist. Also 401 if the user has no
            password hash (e.g. a provider-only account) or ``old_password`` does not match.
    """
    session = self.Session()
    try:
        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")

        current_hash = db_user.password_hash
        if not (current_hash and pwd_context.verify(old_password, current_hash)):
            raise InvalidInputError(401, "incorrect_password", f"Incorrect password")

        db_user.password_hash = pwd_context.hash(new_password)
        session.commit()
    finally:
        session.close()
362
+
363
def delete_user(self, username: str) -> None:
    """
    Permanently remove a user account.

    Raises:
        InvalidInputError: 403 for the admin user, 404 if no such user exists.
    """
    if username == c.ADMIN_USERNAME:
        raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")

    session = self.Session()
    try:
        doomed_user = session.get(self.DbUser, username)
        if doomed_user is None:
            raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

        # NOTE(review): dependent tables declare ondelete='CASCADE', but SQLite only enforces
        # foreign keys when PRAGMA foreign_keys=ON is set per connection. The visible setup
        # code sets only journal_mode/synchronous; confirm dependent rows get cleaned up.
        session.delete(doomed_user)
        session.commit()
    finally:
        session.close()
376
+
377
def get_all_users(self) -> list[RegisteredUser]:
    """Return every stored user, converted to RegisteredUser objects."""
    session = self.Session()
    try:
        rows = session.query(self.DbUser).all()
        users = list(map(self._convert_db_user_to_user, rows))
    finally:
        session.close()
    return users
384
+
385
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
    """
    Create an API key (when ``title`` is given) or a JWT access token for ``user``.

    Args:
        user: The user the credential is issued to.
        expiry_minutes: Lifetime in minutes. None means the credential does not expire;
            its expiry is then the latest representable UTC datetime.
        title: If provided, a persistent "sqrl-..." API key with this title is created.
            Only a salted hash of the key is stored; the plaintext key is returned.

    Returns:
        A tuple of (token string, timezone-aware UTC expiry datetime).

    Raises:
        ConfigurationError: If the SQRL_SECRET_KEY environment variable is not set.
    """
    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

    created_at = datetime.now(timezone.utc)
    if expiry_minutes is None:
        # Keep the "never expires" sentinel timezone-aware, so callers always get the
        # same kind of datetime as in the finite-expiry case (no naive/aware mixing).
        expire_at = datetime.max.replace(tzinfo=timezone.utc)
    else:
        expire_at = created_at + timedelta(minutes=expiry_minutes)

    if title is None:
        # Stateless JWT: username and expiry are embedded and signed with the secret key
        token_id = jwt.encode({"username": user.username, "exp": expire_at}, self.secret_key, algorithm="HS256")
        return token_id, expire_at

    token_id = "sqrl-" + uuid.uuid4().hex
    hashed_key = u.hash_string(token_id, salt=self.secret_key)
    session = self.Session()
    try:
        api_key = self.DbApiKey(hashed_key=hashed_key, title=title, username=user.username, created_at=created_at, expires_at=expire_at)
        session.add(api_key)
        session.commit()
    finally:
        session.close()

    return token_id, expire_at
410
+
411
def get_user_from_token(self, token: str | None) -> RegisteredUser | None:
    """
    Resolve the user behind an access token.

    Tokens starting with "sqrl-" are API keys. They are looked up by salted hash, and
    expired keys are ignored. Any other token is decoded as an HS256 JWT signed with
    SQRL_SECRET_KEY.

    Returns:
        The RegisteredUser, or None when ``token`` is None or empty.

    Raises:
        ConfigurationError: If SQRL_SECRET_KEY is not set.
        InvalidInputError: 401 in any of these cases:
            - the token is invalid or expired;
            - the token has no username claim;
            - the token refers to a user that no longer exists.
    """
    if not token:
        return None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    session = self.Session()
    try:
        if token.startswith("sqrl-"):
            hashed_key = u.hash_string(token, salt=self.secret_key)
            api_key = session.query(self.DbApiKey).filter(
                self.DbApiKey.hashed_key == hashed_key,
                self.DbApiKey.expires_at >= func.now()
            ).first()
            if api_key is None:
                raise InvalidTokenError()
            username = api_key.username
        else:
            payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            username = payload.get("username")
            # A validly signed token without a username claim is unusable. Treat it as
            # invalid instead of letting a KeyError surface as a server error.
            if username is None:
                raise InvalidTokenError()

        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidTokenError()

        user = self._convert_db_user_to_user(db_user)

    except InvalidTokenError:
        raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
    finally:
        session.close()

    return user
448
+
449
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    List the not-yet-expired API keys owned by ``username``.

    Each DbApiKey row is validated into an ApiKey model, and that schema decides which
    attributes are exposed. The row id is a random uuid hex, not derived from the key.
    Only a salted hash of the key itself is stored (``hashed_key``).
    """
    session = self.Session()
    try:
        owned_by_user = self.DbApiKey.username == username
        not_expired = self.DbApiKey.expires_at >= func.now()
        rows = session.query(self.DbApiKey).filter(owned_by_user, not_expired).all()
        return [ApiKey.model_validate(row) for row in rows]
    finally:
        session.close()
463
+
464
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete one of ``username``'s API keys by its row id.

    Raises:
        InvalidInputError: 404 if no key with that id belongs to ``username``.
    """
    session = self.Session()
    try:
        matches = session.query(self.DbApiKey).filter(
            self.DbApiKey.username == username,
            self.DbApiKey.id == api_key_id,
        )
        target = matches.first()
        if target is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")

        session.delete(target)
        session.commit()
    finally:
        session.close()
483
+
484
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
    """
    Check whether ``user``'s access level is high enough for ``scope``.

    Guests rank as PUBLIC and admins as PRIVATE. Every other access level (members)
    ranks as PROTECTED. Access is granted when that rank's value is at least the
    requested scope's value.
    """
    rank_by_access_level = {
        "guest": PermissionScope.PUBLIC,
        "admin": PermissionScope.PRIVATE,
    }
    user_rank = rank_by_access_level.get(user.access_level, PermissionScope.PROTECTED)
    return user_rank.value >= scope.value
493
+
494
+ # OAuth Client Management Methods
495
+
496
def generate_secret_and_hash(self) -> tuple[str, str]:
    """
    Create a random URL-safe secret and its passlib hash.

    The secret is ``secrets.token_urlsafe(64)`` and the hash comes from ``pwd_context``.
    Callers in this file persist only the hash and hand the plaintext back to the client.
    """
    plaintext = secrets.token_urlsafe(64)
    hashed = pwd_context.hash(plaintext)
    return plaintext, hashed
501
+
502
+ def _validate_client_registration_request(self, request: ClientRegistrationRequest | ClientUpdateRequest) -> dict:
503
+ updates = {}
504
+ if request.client_name:
505
+ updates['client_name'] = request.client_name
506
+
507
+ # Validate redirect_uris if being updated
508
+ if request.redirect_uris:
509
+ for uri in request.redirect_uris:
510
+ if not self._validate_redirect_uri_format(uri):
511
+ raise InvalidInputError(400, "invalid_redirect_uri", f"Invalid redirect URI format: {uri}")
512
+ updates['redirect_uris'] = json.dumps(request.redirect_uris)
513
+
514
+ # Validate grant_types if being updated
515
+ if request.grant_types:
516
+ if not all(grant_type in c.SUPPORTED_GRANT_TYPES for grant_type in request.grant_types):
517
+ raise InvalidInputError(400, "invalid_grant_types", f"Invalid grant types. Supported grant types are: {c.SUPPORTED_GRANT_TYPES}")
518
+ updates['grant_types'] = ','.join(request.grant_types)
519
+
520
+ # Validate response_types if being updated
521
+ if request.response_types:
522
+ if not all(response_type in c.SUPPORTED_RESPONSE_TYPES for response_type in request.response_types):
523
+ raise InvalidInputError(400, "invalid_response_types", f"Invalid response types. Supported response types are: {c.SUPPORTED_RESPONSE_TYPES}")
524
+ updates['response_types'] = ','.join(request.response_types)
525
+
526
+ # Validate scope if being updated
527
+ if request.scope:
528
+ scopes = request.scope.split()
529
+ if not all(scope in c.SUPPORTED_SCOPES for scope in scopes):
530
+ raise InvalidInputError(400, "invalid_scope", f"Invalid scope. Supported scopes are: {c.SUPPORTED_SCOPES}")
531
+ updates['scope'] = ','.join(scopes)
532
+
533
+ return updates
534
+
535
def register_oauth_client(
    self, request: ClientRegistrationRequest, client_management_path_format: str
) -> ClientRegistrationResponse:
    """
    Register a new OAuth client.

    Args:
        request: Registration request, validated with ``_validate_client_registration_request``.
            Missing grant_types / response_types / scope fall back to the DB column defaults
            (authorization_code+refresh_token, code, read).
        client_management_path_format: Format string with a ``{client_id}`` placeholder,
            used to build ``registration_client_uri``.

    Returns:
        ClientRegistrationResponse containing the plaintext client_secret and
        registration_access_token. Only their hashes are stored.

    Raises:
        InvalidInputError: Propagated from request validation.
    """
    grant_types = request.grant_types
    if grant_types is None:
        grant_types = ['authorization_code', 'refresh_token']

    # Resolve defaults here so the persisted row and the returned response agree.
    # Previously response_types was never stored, and a None scope was stored as 'read'
    # but returned as None.
    response_types = request.response_types
    if response_types is None:
        response_types = ['code']
    scope = request.scope if request.scope is not None else 'read'

    self._validate_client_registration_request(request)

    client_id = secrets.token_urlsafe(16)
    client_secret, client_secret_hash = self.generate_secret_and_hash()

    registration_access_token, registration_access_token_hash = self.generate_secret_and_hash()
    registration_client_uri = client_management_path_format.format(client_id=client_id)

    session = self.Session()
    try:
        oauth_client = self.DbOAuthClient(
            client_id=client_id,
            client_secret_hash=client_secret_hash,
            client_name=request.client_name,
            redirect_uris=json.dumps(request.redirect_uris),
            scope=scope,
            grant_types=','.join(grant_types),
            response_types=','.join(response_types),
            registration_access_token_hash=registration_access_token_hash
        )
        session.add(oauth_client)
        session.commit()

        return ClientRegistrationResponse(
            client_id=client_id,
            client_secret=client_secret,
            client_name=request.client_name,
            redirect_uris=request.redirect_uris,
            scope=scope,
            grant_types=grant_types,
            response_types=response_types,
            created_at=datetime.now(timezone.utc),
            is_active=True,
            registration_client_uri=registration_client_uri,
            registration_access_token=registration_access_token,
        )

    finally:
        session.close()
583
+
584
def get_oauth_client_details(self, client_id: str) -> ClientDetailsResponse:
    """Look up an active OAuth client and return its metadata.

    JSON/comma-encoded columns are decoded: redirect_uris from JSON, and
    grant_types / response_types split on commas.

    Arguments:
        client_id: Identifier of the client to fetch.

    Returns:
        A ClientDetailsResponse for the client.

    Raises:
        InvalidInputError: 404 when the client does not exist or is inactive.
    """
    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None:
            raise InvalidInputError(404, "invalid_client_id", "Client not found")
        if not db_client.is_active:
            raise InvalidInputError(404, "invalid_client_id", "Client is no longer active")

        details = {
            'client_id': db_client.client_id,
            'client_name': db_client.client_name,
            'redirect_uris': json.loads(db_client.redirect_uris),
            'scope': db_client.scope,
            'grant_types': db_client.grant_types.split(','),
            'response_types': db_client.response_types.split(','),
            'created_at': db_client.created_at,
            'is_active': db_client.is_active,
        }
        return ClientDetailsResponse(**details)
    finally:
        session.close()
606
+
607
def validate_client_credentials(self, client_id: str, client_secret: str) -> bool:
    """Check a client_id / client_secret pair.

    Returns:
        True only when the client exists, is active, and ``client_secret``
        verifies against the stored hash; False otherwise.
    """
    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        usable = db_client is not None and db_client.is_active
        if not usable:
            return False
        return pwd_context.verify(client_secret, db_client.client_secret_hash)
    finally:
        session.close()
617
+
618
def validate_redirect_uri(self, client_id: str, redirect_uri: str) -> bool:
    """Check that ``redirect_uri`` is one of the client's registered redirect URIs.

    The comparison is an exact membership test against the JSON list stored
    for the client. Missing or inactive clients always yield False.
    """
    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is not None and db_client.is_active:
            allowed_uris = json.loads(db_client.redirect_uris)
            return redirect_uri in allowed_uris
        return False
    finally:
        session.close()
630
+
631
def validate_registration_access_token(self, client_id: str, registration_access_token: str) -> bool:
    """Check the registration access token used for client management operations.

    Unlike the other client checks, this does not look at ``is_active``: any
    existing client is accepted. Presumably this lets a deactivated client be
    reactivated through the update endpoint (to confirm against the routes).

    Returns:
        True when the client exists and the token verifies against its stored
        hash; False otherwise.
    """
    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is not None:
            return pwd_context.verify(registration_access_token, db_client.registration_access_token_hash)
        return False
    finally:
        session.close()
642
+
643
+ def _validate_redirect_uri_format(self, uri: str) -> bool:
644
+ """Validate redirect URI format for security"""
645
+ # Basic validation - must be https (except localhost) and not contain fragments
646
+ if '#' in uri:
647
+ return False
648
+
649
+ if uri.startswith('http://'):
650
+ # Only allow http for localhost/development
651
+ if not ('localhost' in uri or '127.0.0.1' in uri):
652
+ return False
653
+ elif not uri.startswith('https://'):
654
+ # Custom schemes allowed for mobile apps
655
+ if '://' not in uri:
656
+ return False
657
+
658
+ return True
659
+
660
def update_oauth_client_with_token_rotation(self, client_id: str, request: ClientUpdateRequest) -> ClientUpdateResponse:
    """Update an OAuth client's metadata and issue a new registration access token.

    The request is validated first. A deactivated client can only be modified by
    a request whose ``is_active`` is truthy (i.e. one that reactivates it). Every
    successful call replaces the stored registration access token hash, so the
    previous token stops verifying.

    Arguments:
        client_id: Identifier of the client to update.
        request: The fields to change; ``is_active`` may be None to leave it as is.

    Returns:
        A ClientUpdateResponse with the updated metadata and the new plaintext
        registration access token.

    Raises:
        InvalidInputError: From request validation; 404 if the client does not
            exist; 400 if the client is deactivated and the request does not
            reactivate it.
    """
    # The validator returns a dict of column name -> new value for provided fields
    updates = self._validate_client_registration_request(request)

    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None:
            raise InvalidInputError(404, "invalid_client_id", f"OAuth client not found: {client_id}")
        if not (db_client.is_active or request.is_active):
            raise InvalidInputError(400, "invalid_request", "Cannot update a deactivated client")

        if request.is_active is not None:
            db_client.is_active = request.is_active

        rotated_token, rotated_token_hash = self.generate_secret_and_hash()

        # Only copy fields the ORM object actually has
        for field_name, new_value in updates.items():
            if hasattr(db_client, field_name):
                setattr(db_client, field_name, new_value)

        db_client.registration_access_token_hash = rotated_token_hash
        session.commit()

        return ClientUpdateResponse(
            client_id=db_client.client_id,
            client_name=db_client.client_name,
            redirect_uris=json.loads(db_client.redirect_uris),
            scope=db_client.scope,
            grant_types=db_client.grant_types.split(','),
            response_types=db_client.response_types.split(','),
            created_at=db_client.created_at,
            is_active=db_client.is_active,
            registration_access_token=rotated_token,
        )
    finally:
        session.close()
704
+
705
def revoke_oauth_client(self, client_id: str) -> None:
    """Deactivate an OAuth client.

    The client row is kept; only ``is_active`` is set to False and committed.

    Raises:
        InvalidInputError: 404 if no client exists with ``client_id``.
    """
    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is not None:
            db_client.is_active = False
            session.commit()
            return
        raise InvalidInputError(404, "client_not_found", f"OAuth client not found: {client_id}")
    finally:
        session.close()
717
+
718
+ # Authorization Code Flow Methods
719
+
720
def create_authorization_code(
    self, client_id: str, username: str, redirect_uri: str, scope: str,
    code_challenge: str, code_challenge_method: str = "S256"
) -> str:
    """Create and store an authorization code for the authorization code flow (PKCE).

    Arguments:
        client_id: Client requesting the code; must exist and be active.
        username: User the code is issued for.
        redirect_uri: Must be one of the client's registered redirect URIs.
        scope: Scope recorded with the code.
        code_challenge: PKCE code challenge; must be non-empty.
        code_challenge_method: PKCE method; only "S256" is accepted.

    Returns:
        The generated authorization code. It is stored unused and expires
        10 minutes after creation.

    Raises:
        InvalidInputError: 404 (from get_oauth_client_details) if the client is
            missing or inactive; 400 if the redirect_uri is not registered, the
            code_challenge is empty, or the method is not S256.
    """
    # get_oauth_client_details raises InvalidInputError itself when the client
    # is missing or inactive (it never returns None), so no result check is needed.
    self.get_oauth_client_details(client_id)

    # Validate redirect_uri
    if not self.validate_redirect_uri(client_id, redirect_uri):
        raise InvalidInputError(400, "invalid_redirect_uri", "Invalid redirect_uri for this client")

    # Validate PKCE parameters
    if not code_challenge:
        raise InvalidInputError(400, "invalid_request", "code_challenge is required")

    if code_challenge_method != "S256":
        raise InvalidInputError(400, "invalid_request", "code_challenge_method must be 'S256'")

    # Generate authorization code
    code = secrets.token_urlsafe(48)

    # Set expiration (10 minutes from now)
    created_at = datetime.now(timezone.utc)
    expires_at = created_at + timedelta(minutes=10)

    session = self.Session()
    try:
        auth_code = self.DbAuthorizationCode(
            code=code,
            client_id=client_id,
            username=username,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            created_at=created_at,
            expires_at=expires_at,
            used=False
        )
        session.add(auth_code)
        session.commit()
    finally:
        session.close()

    return code
770
+
771
def exchange_authorization_code(
    self, code: str, client_id: str, redirect_uri: str, code_verifier: str, access_token_expiry_minutes: int
) -> TokenResponse:
    """
    Exchange an authorization code for an access token and a refresh token.

    The client must exist and be active. The code must belong to ``client_id``,
    be unexpired and unused, and have been issued for the same ``redirect_uri``.
    ``code_verifier`` must satisfy the stored PKCE challenge. On success the code
    is marked as used and a token row is stored holding hashes of the new access
    and refresh tokens. The refresh token expires after 30 days.

    Arguments:
        code: The authorization code being redeemed.
        client_id: Client presenting the code.
        redirect_uri: Must equal the redirect_uri the code was issued for.
        code_verifier: PKCE verifier matching the code's challenge.
        access_token_expiry_minutes: Lifetime of the new access token.

    Returns:
        TokenResponse with the access token, its lifetime in seconds
        (``access_token_expiry_minutes * 60``) and the refresh token.

    Raises:
        InvalidInputError: 400 ``invalid_client`` if the client is missing or
            inactive; 400 ``invalid_grant`` if the code is invalid, expired or
            used, the redirect URI mismatches, the verifier fails, or the user
            no longer exists.
    """
    session = self.Session()
    try:
        # A client deactivated after the code was issued must not obtain tokens
        # (same client check as refresh_oauth_access_token / revoke_oauth_token)
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Get and validate authorization code: matching client, unexpired, unused.
        # `== False` is intentional: it builds a SQL expression.
        auth_code = session.query(self.DbAuthorizationCode).filter(
            self.DbAuthorizationCode.code == code,
            self.DbAuthorizationCode.client_id == client_id,
            self.DbAuthorizationCode.expires_at >= func.now(),
            self.DbAuthorizationCode.used == False
        ).first()

        if auth_code is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")

        # Validate redirect URI
        if auth_code.redirect_uri != redirect_uri:
            raise InvalidInputError(400, "invalid_grant", "Redirect URI mismatch")

        if not u.validate_pkce_challenge(code_verifier, auth_code.code_challenge):
            raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")

        # Get user
        db_user = session.get(self.DbUser, auth_code.username)
        if db_user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Mark authorization code as used (persisted by the commit below)
        auth_code.used = True

        # Generate tokens
        user_obj = self._convert_db_user_to_user(db_user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate refresh token
        refresh_token, refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=refresh_token_hash,
            client_id=client_id,
            username=auth_code.username,
            scope=auth_code.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=refresh_token
        )

    finally:
        session.close()
836
+
837
def refresh_oauth_access_token(self, refresh_token: str, client_id: str, access_token_expiry_minutes: int) -> TokenResponse:
    """
    Refresh an OAuth access token using a refresh token, rotating the refresh token.

    The presented refresh token must match a non-revoked, unexpired token row
    for an active client. That row is revoked and a new row is stored with
    hashes of a fresh access token and a fresh refresh token. The new refresh
    token expires after 30 days.

    Arguments:
        refresh_token: The plaintext refresh token presented by the client.
        client_id: Client presenting the token.
        access_token_expiry_minutes: Lifetime of the new access token.

    Returns:
        TokenResponse with the new access token, its lifetime in seconds
        (``access_token_expiry_minutes * 60``) and the new refresh token.

    Raises:
        InvalidInputError: 400 ``invalid_client`` if the client is missing or
            inactive; 400 ``invalid_grant`` if the refresh token is invalid or
            expired, or its user no longer exists.
        ConfigurationError: If no secret key is configured.
    """
    session = self.Session()
    try:
        # Validate client
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # A client can hold several active refresh tokens at once (e.g. one per
        # exchanged authorization code). Only hashes are stored, so every
        # candidate must be verified; checking just the first row would reject
        # valid tokens. `== False` is intentional: it builds a SQL expression.
        candidates = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False
        ).all()
        oauth_token = next(
            (candidate for candidate in candidates
             if pwd_context.verify(refresh_token, candidate.refresh_token_hash)),
            None,
        )
        if oauth_token is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")

        # Get user (invalid_grant, consistent with exchange_authorization_code)
        db_user = session.get(self.DbUser, oauth_token.username)
        if db_user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Check secret key is available
        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")

        # Generate new tokens
        user_obj = self._convert_db_user_to_user(db_user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate new refresh token
        new_refresh_token, new_refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=30)

        # Revoke old token (refresh token rotation)
        oauth_token.is_revoked = True

        # Create new token entry
        new_oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=new_refresh_token_hash,
            client_id=client_id,
            username=oauth_token.username,
            scope=oauth_token.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(new_oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=new_refresh_token
        )

    finally:
        session.close()
903
+
904
def revoke_oauth_token(self, client_id: str, token: str, token_type_hint: str | None = None) -> None:
    """
    Revoke one of a client's refresh tokens.

    Only refresh tokens are supported: ``token_type_hint`` may be omitted or must
    be 'refresh_token'. Revoking access tokens is not supported yet. A token that
    matches nothing (unknown, expired or already revoked) is silently ignored
    rather than reported as an error.

    Raises:
        InvalidInputError: 400 if ``token_type_hint`` names another token type,
            or if the client is missing or inactive.
    """
    if token_type_hint and token_type_hint != 'refresh_token':
        raise InvalidInputError(400, "invalid_request", "Only refresh tokens can be revoked")

    session = self.Session()
    try:
        db_client = session.get(self.DbOAuthClient, client_id)
        if db_client is None or not db_client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Only hashes are stored, so each live token of this client is verified in turn
        live_tokens = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False
        ).all()
        matched = next(
            (candidate for candidate in live_tokens
             if pwd_context.verify(token, candidate.refresh_token_hash)),
            None,
        )

        if matched:
            matched.is_revoked = True
            session.commit()
    finally:
        session.close()
940
+
941
def close(self) -> None:
    """Dispose of the database engine held in ``self.engine``."""
    engine = self.engine
    engine.dispose()
943
+
944
+
945
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider

    The decorated function receives an AuthProviderArgs object and returns the
    provider's ProviderConfigs. It is replaced by a wrapper that builds an
    AuthProvider from those configs. The wrapper is appended to
    ``Authenticator.providers`` as soon as the decorator is applied, not when
    the wrapper is called.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')

    Returns:
        The decorator. Applying it returns the registered wrapper, which keeps
        the decorated function's name and docstring.
    """
    from functools import wraps

    def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
        @wraps(func)
        def wrapper(sqrl: AuthProviderArgs):
            provider_configs = func(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)
        Authenticator.providers.append(wrapper)
        return wrapper
    return decorator