squirrels 0.1.0__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127) hide show
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +409 -380
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +21 -18
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +337 -0
  8. squirrels/_api_routes/base.py +196 -0
  9. squirrels/_api_routes/dashboards.py +156 -0
  10. squirrels/_api_routes/data_management.py +148 -0
  11. squirrels/_api_routes/datasets.py +220 -0
  12. squirrels/_api_routes/project.py +289 -0
  13. squirrels/_api_server.py +552 -134
  14. squirrels/_arguments/__init__.py +0 -0
  15. squirrels/_arguments/init_time_args.py +83 -0
  16. squirrels/_arguments/run_time_args.py +111 -0
  17. squirrels/_auth.py +777 -0
  18. squirrels/_command_line.py +239 -107
  19. squirrels/_compile_prompts.py +147 -0
  20. squirrels/_connection_set.py +94 -0
  21. squirrels/_constants.py +141 -64
  22. squirrels/_dashboards.py +179 -0
  23. squirrels/_data_sources.py +570 -0
  24. squirrels/_dataset_types.py +91 -0
  25. squirrels/_env_vars.py +209 -0
  26. squirrels/_exceptions.py +29 -0
  27. squirrels/_http_error_responses.py +52 -0
  28. squirrels/_initializer.py +319 -110
  29. squirrels/_logging.py +121 -0
  30. squirrels/_manifest.py +357 -187
  31. squirrels/_mcp_server.py +578 -0
  32. squirrels/_model_builder.py +69 -0
  33. squirrels/_model_configs.py +74 -0
  34. squirrels/_model_queries.py +52 -0
  35. squirrels/_models.py +1201 -0
  36. squirrels/_package_data/base_project/.env +7 -0
  37. squirrels/_package_data/base_project/.env.example +44 -0
  38. squirrels/_package_data/base_project/connections.yml +16 -0
  39. squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
  40. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  41. squirrels/_package_data/base_project/docker/.dockerignore +16 -0
  42. squirrels/_package_data/base_project/docker/Dockerfile +16 -0
  43. squirrels/_package_data/base_project/docker/compose.yml +7 -0
  44. squirrels/_package_data/base_project/duckdb_init.sql +10 -0
  45. squirrels/_package_data/base_project/gitignore +13 -0
  46. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  47. squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
  48. squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
  49. squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
  50. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
  51. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
  52. squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
  53. squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
  54. squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
  55. squirrels/_package_data/base_project/models/sources.yml +38 -0
  56. squirrels/_package_data/base_project/parameters.yml +142 -0
  57. squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
  58. squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
  59. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  60. squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
  61. squirrels/_package_data/base_project/resources/expenses.db +0 -0
  62. squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
  63. squirrels/_package_data/base_project/resources/weather.db +0 -0
  64. squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
  65. squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
  66. squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
  67. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
  68. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  69. squirrels/_package_data/base_project/tmp/.gitignore +2 -0
  70. squirrels/_package_data/templates/login_successful.html +53 -0
  71. squirrels/_package_data/templates/squirrels_studio.html +22 -0
  72. squirrels/_package_loader.py +29 -0
  73. squirrels/_parameter_configs.py +592 -0
  74. squirrels/_parameter_options.py +348 -0
  75. squirrels/_parameter_sets.py +207 -0
  76. squirrels/_parameters.py +1703 -0
  77. squirrels/_project.py +796 -0
  78. squirrels/_py_module.py +122 -0
  79. squirrels/_request_context.py +33 -0
  80. squirrels/_schemas/__init__.py +0 -0
  81. squirrels/_schemas/auth_models.py +83 -0
  82. squirrels/_schemas/query_param_models.py +70 -0
  83. squirrels/_schemas/request_models.py +26 -0
  84. squirrels/_schemas/response_models.py +286 -0
  85. squirrels/_seeds.py +97 -0
  86. squirrels/_sources.py +112 -0
  87. squirrels/_utils.py +540 -149
  88. squirrels/_version.py +1 -3
  89. squirrels/arguments.py +7 -0
  90. squirrels/auth.py +4 -0
  91. squirrels/connections.py +3 -0
  92. squirrels/dashboards.py +3 -0
  93. squirrels/data_sources.py +14 -282
  94. squirrels/parameter_options.py +13 -189
  95. squirrels/parameters.py +14 -801
  96. squirrels/types.py +18 -0
  97. squirrels-0.6.0.post0.dist-info/METADATA +148 -0
  98. squirrels-0.6.0.post0.dist-info/RECORD +101 -0
  99. {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
  100. {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
  101. squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
  102. squirrels/_credentials_manager.py +0 -87
  103. squirrels/_module_loader.py +0 -37
  104. squirrels/_parameter_set.py +0 -151
  105. squirrels/_renderer.py +0 -286
  106. squirrels/_timed_imports.py +0 -37
  107. squirrels/connection_set.py +0 -126
  108. squirrels/package_data/base_project/.gitignore +0 -4
  109. squirrels/package_data/base_project/connections.py +0 -21
  110. squirrels/package_data/base_project/database/sample_database.db +0 -0
  111. squirrels/package_data/base_project/database/seattle_weather.db +0 -0
  112. squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
  113. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
  114. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
  115. squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
  116. squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
  117. squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
  118. squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
  119. squirrels/package_data/base_project/squirrels.yaml +0 -26
  120. squirrels/package_data/static/favicon.ico +0 -0
  121. squirrels/package_data/static/script.js +0 -234
  122. squirrels/package_data/static/style.css +0 -110
  123. squirrels/package_data/templates/index.html +0 -32
  124. squirrels-0.1.0.dist-info/LICENSE +0 -22
  125. squirrels-0.1.0.dist-info/METADATA +0 -67
  126. squirrels-0.1.0.dist-info/RECORD +0 -40
  127. squirrels-0.1.0.dist-info/top_level.txt +0 -1
squirrels/_auth.py ADDED
@@ -0,0 +1,777 @@
1
+ from typing import Callable, Any
2
+ from datetime import datetime, timedelta, timezone
3
+ from functools import cached_property
4
+ from jwt.exceptions import InvalidTokenError
5
+ from passlib.context import CryptContext
6
+ from pydantic import ValidationError
7
+ from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
8
+ from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
9
+ import jwt, uuid, secrets, json, base64, requests
10
+ from jwt import PyJWKClient
11
+
12
+ from ._env_vars import SquirrelsEnvVars
13
+ from ._manifest import PermissionScope
14
+ from ._exceptions import InvalidInputError, ConfigurationError
15
+ from ._arguments.init_time_args import AuthProviderArgs
16
+ from ._schemas.auth_models import (
17
+ CustomUserFields, AbstractUser, RegisteredUser, ApiKey, UserField, UserFieldsModel, AuthProvider, ProviderConfigs
18
+ )
19
+ from ._schemas import response_models as rm
20
+ from . import _utils as u, _constants as c
21
+
22
# Shared passlib context used by this module to hash and verify passwords with bcrypt.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Signature of a provider factory: receives AuthProviderArgs and returns an AuthProvider.
ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
25
+
26
+
27
+ class Authenticator:
28
+ providers: list[ProviderFunctionType] = [] # static variable to stage providers
29
+
30
def __init__(
    self, logger: u.Logger, env_vars: SquirrelsEnvVars, auth_args: AuthProviderArgs,
    provider_functions: list[ProviderFunctionType], custom_user_fields_cls: type[CustomUserFields],
    *,
    sa_engine: Engine | None = None, external_only: bool = False
):
    """
    Set up per-instance ORM models, the auth providers, and (unless external_only) the user database

    Arguments:
        logger: Logger stored on the instance
        env_vars: Environment settings; secret_key, auth_db_file_path and project_path are read here
        auth_args: Arguments passed to every provider function
        provider_functions: Functions that each build one AuthProvider from auth_args
        custom_user_fields_cls: Model class used to validate and describe custom user fields
        sa_engine: Optional SQLAlchemy engine to use instead of creating a SQLite engine from env_vars
        external_only: If True, no engine, tables, or session factory are created
    """
    self.logger = logger
    self.env_vars = env_vars
    self.secret_key = env_vars.secret_key
    self.external_only = external_only
    self.password_requirements = rm.PasswordRequirements()

    # Create a new declarative base for this instance
    self.Base = declarative_base()

    # Define DbUser class for this instance
    class DbUser(self.Base):
        __tablename__ = 'users'
        __table_args__ = {'extend_existing': True}
        username: Mapped[str] = mapped_column(primary_key=True)
        access_level: Mapped[str] = mapped_column(nullable=False, default="member")
        password_hash: Mapped[str] = mapped_column(nullable=False)
        custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
        created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())

    # Define DbApiKey class for this instance
    class DbApiKey(self.Base):
        __tablename__ = 'api_keys'

        id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
        hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
        last_four: Mapped[str] = mapped_column(nullable=False)
        title: Mapped[str] = mapped_column(nullable=False)
        username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
        created_at: Mapped[datetime] = mapped_column(nullable=False)
        expires_at: Mapped[datetime] = mapped_column(nullable=False)

        def __repr__(self):
            return f"<DbApiKey(id='{self.id}', username='{self.username}')>"

    self.CustomUserFields = custom_user_fields_cls
    self.DbUser = DbUser

    self.DbApiKey = DbApiKey

    self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
    self._jwks_clients: dict[str, PyJWKClient] = {}
    self._provider_metadata_cache: dict[str, dict] = {}

    if not self.external_only:
        if sa_engine is None:
            raw_sqlite_path = self.env_vars.auth_db_file_path
            sqlite_path = u.Path(raw_sqlite_path.format(project_path=self.env_vars.project_path))
            sqlite_path.parent.mkdir(parents=True, exist_ok=True)
            self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
        else:
            self.engine = sa_engine

        # Configure SQLite pragmas. PRAGMA is SQLite-specific syntax, so skip it when the
        # caller supplied an engine for a different dialect.
        if self.engine.dialect.name == "sqlite":
            with self.engine.connect() as conn:
                conn.execute(text("PRAGMA journal_mode = WAL"))
                conn.execute(text("PRAGMA synchronous = NORMAL"))
                conn.commit()

        self.Base.metadata.create_all(self.engine)

        self.Session = sessionmaker(bind=self.engine)

        self._initialize_db()
99
+
100
def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
    """Build a RegisteredUser from a database user row"""
    # custom_fields is stored as a JSON string; an empty value means "no overrides",
    # in which case the CustomUserFields model supplies its own defaults.
    stored_json = db_user.custom_fields
    if stored_json:
        field_values = json.loads(stored_json)
    else:
        field_values = {}
    custom_fields = self.CustomUserFields(**field_values)

    return RegisteredUser(
        username=db_user.username,
        access_level=db_user.access_level,
        custom_fields=custom_fields,
    )
111
+
112
def _validate_password_length(self, password: str) -> None:
    """
    Validate that a password meets the configured length requirements and fits within bcrypt's input limit

    bcrypt's limit is 72 bytes (not characters), so in addition to the character-count bounds the
    UTF-8 encoded size is checked; otherwise a password with multi-byte characters could pass the
    character check while exceeding what bcrypt accepts.

    Raises:
        InvalidInputError: 400 "password_too_short" or 400 "password_too_long"
    """
    bcrypt_max_bytes = 72
    min_len = self.password_requirements.min_length
    max_len = min(self.password_requirements.max_length, bcrypt_max_bytes)
    if len(password) < min_len:
        raise InvalidInputError(400, "password_too_short", f"Password must be at least {min_len} characters long")
    if len(password) > max_len:
        raise InvalidInputError(400, "password_too_long", f"Password cannot exceed {max_len} characters")
    if len(password.encode("utf-8")) > bcrypt_max_bytes:
        raise InvalidInputError(400, "password_too_long", f"Password cannot exceed {bcrypt_max_bytes} bytes when UTF-8 encoded")
120
+
121
def _initialize_db(self):
    """
    Add any model columns missing from the existing tables, then sync the admin user

    For each of the users and api_keys tables, columns declared on the model but absent from
    the database are added with ALTER TABLE. If an admin password is configured in the
    environment settings, the admin user is created or its password hash is replaced.
    """
    with self.Session() as session:
        inspector = inspect(self.engine)

        for model in (self.DbUser, self.DbApiKey):
            table = model.__table__
            table_name = model.__tablename__
            columns_in_db = {column_info['name'] for column_info in inspector.get_columns(table_name)}
            missing_columns = set(table.columns.keys()) - columns_in_db
            if not missing_columns:
                continue

            self.logger.info(f"Adding columns to table {table_name}: {missing_columns}")

            for col_name in missing_columns:
                col = table.columns[col_name]
                column_type = col.type.compile(self.engine.dialect)
                nullability = "NULL" if col.nullable else "NOT NULL"

                if col.default is not None and not callable(col.default.arg):
                    # Literal (non-callable) Python-side default: render it into the DDL
                    literal = col.default.arg
                    rendered = f"'{literal}'" if isinstance(literal, str) else literal
                    default_clause = f"DEFAULT {rendered}"
                elif col.nullable:
                    default_clause = ""
                else:
                    # NOT NULL column without a usable default: use an empty string as a placeholder default for SQLite
                    # TODO: Use a different default value (instead of empty string) based on the column type
                    default_clause = "DEFAULT ''"

                session.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {col_name} {column_type} {nullability} {default_clause}"))

            session.commit()

        # Admin password comes from the environment settings, if set
        admin_password = self.env_vars.secret_admin_password
        if admin_password is not None:
            self._validate_password_length(admin_password)
            new_hash = pwd_context.hash(admin_password)
            admin_row = session.get(self.DbUser, c.ADMIN_USERNAME)
            if admin_row is None:
                session.add(self.DbUser(username=c.ADMIN_USERNAME, password_hash=new_hash, access_level="admin"))
            else:
                admin_row.password_hash = new_hash

        session.commit()
171
+
172
@cached_property
def user_fields(self) -> UserFieldsModel:
    """
    Get the username, access level, and custom fields of the user model as UserField descriptions

    Each UserField contains:
    - name: The name of the field
    - label: The JSON schema title of the field (falls back to the name)
    - type: The type of the field
    - nullable: Whether the field is nullable
    - enum: The possible values of the field (or None if not applicable)
    - default: The default value of the field (or None if field is required)
    """

    custom_fields = []
    schema = self.CustomUserFields.model_json_schema()
    properties: dict[str, dict[str, Any]] = schema.get("properties", {})
    for field_name, field_schema in properties.items():
        if choices := field_schema.get("anyOf"):
            # A union is nullable when ANY member is the JSON "null" type, wherever it appears
            # (e.g. `int | str | None` lists null third). The reported type is the first
            # member that is not "null".
            non_null_choices = [choice for choice in choices if choice.get("type") != "null"]
            nullable = len(non_null_choices) < len(choices)
            field_type = non_null_choices[0]["type"] if non_null_choices else "null"
        else:
            field_type = field_schema["type"]
            nullable = False

        field_data = UserField(
            name=field_name, label=field_schema.get("title", field_name), type=field_type,
            nullable=nullable, enum=field_schema.get("enum"), default=field_schema.get("default")
        )
        custom_fields.append(field_data)

    return UserFieldsModel(
        username=UserField(name="username", label="Username / Email", type="string", nullable=False, enum=None, default=None),
        access_level=UserField(name="access_level", label="Access Level", type="string", nullable=False, enum=["admin", "member"], default="member"),
        custom_fields=custom_fields
    )
204
+
205
def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> RegisteredUser:
    """
    Create a new user, or update an existing user when update_user is True

    Reads 'access_level' (default 'member'), 'password', and 'custom_fields' from user_fields.
    On update only the access level and custom fields are changed.

    Raises:
        InvalidInputError: for a 'guest' access level, invalid custom fields, a username that
            already exists (create), a username that does not exist (update), an attempt to make
            the admin user non-admin, or a missing/invalid password on create
    """
    requested_level = user_fields.get('access_level', 'member')
    password = user_fields.get('password')

    # Users can never be stored with the guest access level
    if requested_level == "guest":
        raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")

    # Validate custom fields through the model and serialize them for storage
    raw_custom_fields: dict[str, Any] = user_fields.get('custom_fields', {})
    try:
        validated_fields = self.CustomUserFields(**raw_custom_fields)
        custom_fields_json = json.dumps(validated_fields.model_dump(mode='json'))
    except ValidationError as e:
        first_error = e.errors()[0]
        raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{first_error['loc'][0]}': {first_error['msg']}")

    with self.Session() as session:
        db_user = session.get(self.DbUser, username)
        user_exists = db_user is not None

        if user_exists and not update_user:
            raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")
        if update_user and not user_exists:
            raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

        if user_exists:
            if username == c.ADMIN_USERNAME and requested_level != "admin":
                raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")

            db_user.access_level = requested_level
            db_user.custom_fields = custom_fields_json
        else:
            if password is None:
                raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")

            self._validate_password_length(password)
            db_user = self.DbUser(
                username=username,
                access_level=requested_level,
                password_hash=pwd_context.hash(password),
                custom_fields=custom_fields_json,
            )
            session.add(db_user)

        session.commit()
        return self._convert_db_user_to_user(db_user)
262
+
263
def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
    """
    Return the stored user matching the provider's user info, creating the user on first login

    The provider's get_user maps user_info to a user. If no row exists for that username, one is
    inserted with the provider-supplied access level and custom fields. If a row already exists,
    that stored row is what gets returned, so the result reflects the database rather than the
    provider's defaults.

    Raises:
        InvalidInputError: 404 "auth_provider_not_found" if provider_name is not registered
    """
    provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    user = provider.provider_configs.get_user(user_info)
    with self.Session() as session:
        db_user = session.get(self.DbUser, user.username)
        if db_user is None:
            db_user = self.DbUser(
                username=user.username,
                access_level=user.access_level,
                password_hash="",  # Empty hash: username/password login is not intended for these users (OAuth only)
                custom_fields=user.custom_fields.model_dump_json()
            )
            session.add(db_user)
            session.commit()

        return self._convert_db_user_to_user(db_user)
290
+
291
def get_user(self, username: str, password: str) -> RegisteredUser:
    """
    Authenticate a user by username and password

    Raises:
        InvalidInputError: 401 "incorrect_username_or_password" when the user does not exist,
            has no password set, or the password does not match
    """
    with self.Session() as session:
        db_user = session.get(self.DbUser, username)

        # Users created through an auth provider are stored with an empty password_hash, which is
        # not a valid bcrypt hash. Treat them as a failed login before calling verify, the same way
        # change_password guards on password_hash, instead of letting verify raise on the empty hash.
        if db_user is None or not db_user.password_hash or not pwd_context.verify(password, db_user.password_hash):
            raise InvalidInputError(401, "incorrect_username_or_password", f"Incorrect username or password")

        return self._convert_db_user_to_user(db_user)
305
+
306
def change_password(self, username: str, old_password: str, new_password: str) -> None:
    """
    Replace a user's password after verifying the old one

    Raises:
        InvalidInputError: 401 "user_not_found" if the username is unknown, 401
            "incorrect_password" if the user has no password hash or the old password does not
            match, or a length error from the new password validation
    """
    with self.Session() as session:
        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")

        current_hash = db_user.password_hash
        old_password_ok = bool(current_hash) and pwd_context.verify(old_password, current_hash)
        if not old_password_ok:
            raise InvalidInputError(401, "incorrect_password", f"Incorrect password")

        self._validate_password_length(new_password)
        db_user.password_hash = pwd_context.hash(new_password)
        session.commit()
321
+
322
def delete_user(self, username: str) -> None:
    """
    Delete a user together with the API keys issued to that user

    Raises:
        InvalidInputError: 403 "cannot_delete_admin_user" for the admin username, or 404
            "no_user_found_for_username" if the user does not exist
    """
    if username == c.ADMIN_USERNAME:
        raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")

    with self.Session() as session:
        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")

        # Delete the user's API keys explicitly instead of relying on the ON DELETE CASCADE of
        # api_keys.username: SQLite only enforces foreign keys on connections where
        # PRAGMA foreign_keys is enabled. Leftover keys would otherwise become valid again if the
        # same username were re-created later.
        session.query(self.DbApiKey).filter(self.DbApiKey.username == username).delete(synchronize_session=False)
        session.delete(db_user)
        session.commit()
335
+
336
def get_all_users(self) -> list[RegisteredUser]:
    """Return every user stored in the users table as a RegisteredUser"""
    with self.Session() as session:
        rows = session.query(self.DbUser).all()
        registered_users = []
        for row in rows:
            registered_users.append(self._convert_db_user_to_user(row))
        return registered_users
343
+
344
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
    """
    Issue an access token for the user and return it with its expiry time.

    With a title, an API key ("sqrl-" prefix) is generated and only its salted hash, last four
    characters, and metadata are stored. Without a title, a JWT signed with HS256 is returned.
    When expiry_minutes is None the expiry is datetime.max.
    """
    created_at = datetime.now(timezone.utc)
    if expiry_minutes is None:
        expire_at = datetime.max
    else:
        expire_at = created_at + timedelta(minutes=expiry_minutes)

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

    if title is None:
        claims = {"username": user.username, "exp": expire_at}
        return jwt.encode(claims, self.secret_key, algorithm="HS256"), expire_at

    with self.Session() as session:
        token_id = "sqrl-" + secrets.token_urlsafe(16)
        key_row = self.DbApiKey(
            hashed_key=u.hash_string(token_id, salt=self.secret_key),
            last_four=token_id[-4:],
            title=title,
            username=user.username,
            created_at=created_at,
            expires_at=expire_at,
        )
        session.add(key_row)
        session.commit()

    return token_id, expire_at
373
+
374
def get_user_from_token(self, token: str | None) -> tuple[RegisteredUser | None, float | None]:
    """
    Resolve an access token to its user and expiry.

    Tokens starting with 'sqrl-' are looked up as API keys (expiry is returned as None);
    anything else is decoded as an HS256 JWT, whose "exp" claim is returned as the expiry.
    A missing/empty token yields (None, None).
    """
    if not token:
        return None, None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    with self.Session() as session:
        try:
            if token.startswith("sqrl-"):
                # API keys are stored hashed, so hash the presented key and match unexpired rows
                presented_hash = u.hash_string(token, salt=self.secret_key)
                key_row = (
                    session.query(self.DbApiKey)
                    .filter(self.DbApiKey.hashed_key == presented_hash, self.DbApiKey.expires_at >= func.now())
                    .first()
                )
                if key_row is None:
                    raise InvalidTokenError()
                username, expiry = key_row.username, None
            else:
                claims: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
                username, expiry = claims["username"], claims.get("exp")

            db_user = session.get(self.DbUser, username)
            if db_user is None:
                raise InvalidTokenError()

            return self._convert_db_user_to_user(db_user), expiry

        except InvalidTokenError:
            raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
413
+
414
def _get_jwks_client(self, jwks_uri: str) -> PyJWKClient:
    """Return the PyJWKClient for jwks_uri, creating and caching it on first use"""
    cache = self._jwks_clients
    if jwks_uri in cache:
        return cache[jwks_uri]

    client = PyJWKClient(jwks_uri)
    cache[jwks_uri] = client
    return client
418
+
419
def _get_issuer_from_token(self, token: str) -> str | None:
    """
    Read the "iss" claim from a JWT without verifying it

    Returns None when the token is not three dot-separated parts, the payload cannot be
    decoded/parsed, or the issuer is not a non-empty string.
    """
    try:
        segments = token.split('.')
        if len(segments) != 3:
            return None

        # header.payload.signature -> only the payload is needed
        _, encoded_claims, _ = segments

        # JWTs strip base64url padding; restore the 0..3 missing '=' characters
        missing_padding = -len(encoded_claims) % 4
        raw_claims = base64.urlsafe_b64decode(encoded_claims + '=' * missing_padding)
        claims: dict = json.loads(raw_claims.decode("utf-8"))
        issuer = claims.get("iss")
    except Exception:
        return None

    if isinstance(issuer, str) and issuer:
        return issuer
    return None
437
+
438
+ @staticmethod
439
+ def _is_jwt(token: str) -> bool:
440
+ if not isinstance(token, str) or not token:
441
+ return False
442
+ parts = token.split(".")
443
+ return len(parts) == 3 and all(parts)
444
+
445
def _get_provider_metadata(self, provider_name: str) -> dict:
    """
    Return the provider's server metadata document, fetching it once and caching it by URL

    Raises:
        InvalidInputError: 404 "auth_provider_not_found" if provider_name is not registered
        ConfigurationError: if the metadata cannot be fetched or is not a JSON object
    """
    provider = None
    for candidate in self.auth_providers:
        if candidate.name == provider_name:
            provider = candidate
            break
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    metadata_url = provider.provider_configs.server_metadata_url
    cache = self._provider_metadata_cache

    # Only a non-empty dict counts as a cache hit
    cached = cache.get(metadata_url)
    if isinstance(cached, dict) and cached:
        return cached

    try:
        response = requests.get(metadata_url, timeout=10)
        response.raise_for_status()
        metadata = response.json()
        if not isinstance(metadata, dict):
            raise ValueError("Provider metadata was not a JSON object")
    except Exception as e:
        raise ConfigurationError(f"Failed to fetch metadata for provider '{provider_name}': {str(e)}") from e

    cache[metadata_url] = metadata
    return metadata
466
+
467
def _verify_provider_jwt(
    self,
    provider_name: str,
    token: str,
    *,
    purpose: str,
    expected_nonce: str | None = None,
    verify_aud: bool = True,
) -> dict | None:
    """
    Verify a provider-issued JWT (signature + exp; aud when requested).
    Uses OIDC discovery for jwks_uri and (best-effort) issuer validation.

    Arguments:
        provider_name: Name of a registered auth provider
        token: The candidate JWT
        purpose: Label used in error messages (e.g. "id_token")
        expected_nonce: If given, the token's "nonce" claim must equal it
        verify_aud: If True, the audience must match the provider's client_id

    Returns:
        The decoded claims, or None when the token is not JWT-shaped, its signing key cannot
        be resolved, or decoding/verification fails.

    Raises:
        InvalidInputError: 404 if the provider is unknown; 401 "invalid_provider_token" on an
            issuer or nonce mismatch
        ConfigurationError: if the provider metadata has no jwks_uri
    """
    if not self._is_jwt(token):
        return None

    provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    metadata = self._get_provider_metadata(provider_name)
    jwks_uri = metadata.get("jwks_uri")
    if not isinstance(jwks_uri, str) or not jwks_uri:
        raise ConfigurationError(f"jwks_uri not found in metadata for provider '{provider_name}'")

    signing_algs = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
    if not isinstance(signing_algs, list) or not signing_algs:
        signing_algs = ["RS256"]

    decode_kwargs: dict[str, Any] = {
        "algorithms": signing_algs,
        "options": {
            "verify_aud": bool(verify_aud),
            # We'll validate issuer manually to avoid brittle trailing-slash mismatches.
            "verify_iss": False,
        },
    }
    if verify_aud:
        decode_kwargs["audience"] = provider.provider_configs.client_id

    try:
        # Resolving the signing key is part of verification: a token whose key cannot be found
        # (or fetched) is treated like any other unverifiable token, so callers get None and
        # can fall back to other strategies instead of receiving an unhandled error.
        signing_key = self._get_jwks_client(jwks_uri).get_signing_key_from_jwt(token)
        payload = jwt.decode(token, key=signing_key.key, **decode_kwargs)
    except Exception:
        return None

    if not isinstance(payload, dict):
        return None

    expected_issuer = metadata.get("issuer") or provider.provider_configs.server_url
    token_issuer = payload.get("iss")
    if isinstance(expected_issuer, str) and expected_issuer and isinstance(token_issuer, str) and token_issuer:
        if token_issuer.rstrip("/") != expected_issuer.rstrip("/"):
            raise InvalidInputError(
                401,
                "invalid_provider_token",
                f"Invalid {purpose} issuer for provider '{provider_name}'",
            )

    if expected_nonce is not None:
        nonce_claim = payload.get("nonce")
        if nonce_claim != expected_nonce:
            raise InvalidInputError(
                401,
                "invalid_provider_token",
                f"Invalid {purpose} nonce for provider '{provider_name}'",
            )

    return payload
539
+
540
def get_user_info_from_token_details(self, provider_name: str, token_details: dict, *, expected_nonce: str | None = None) -> dict:
    """
    Extract user information from an OAuth/OIDC token response.

    Sources are tried in this order, returning the first that yields a result:
    1. token_details["user_info"], then token_details["userinfo"] (non-empty dicts)
    2. the verified claims of token_details["id_token"] (audience and nonce checked)
    3. the verified claims of token_details["access_token"] when it is a JWT (audience not checked)
    4. the provider's userinfo endpoint, called with the access token as a bearer token
    5. the provider's introspection endpoint, first with basic auth, then with the client
       credentials in the form data

    Raises:
        InvalidInputError: 400 "invalid_provider_user_info" when nothing yields user info,
            404 if the provider is unknown, 401 "inactive_external_token" when introspection
            reports the token as inactive
    """
    for key in ("user_info", "userinfo"):
        embedded = token_details.get(key)
        if isinstance(embedded, dict) and embedded:
            return embedded

    id_token = token_details.get("id_token")
    if isinstance(id_token, str):
        id_claims = self._verify_provider_jwt(
            provider_name, id_token, purpose="id_token", expected_nonce=expected_nonce, verify_aud=True
        )
        if id_claims:
            return id_claims

    access_token = token_details.get("access_token")
    if isinstance(access_token, str):
        # Some providers issue JWT access tokens. Audience can vary (resource server),
        # so we verify signature/exp and issuer, but skip aud validation.
        access_claims = self._verify_provider_jwt(
            provider_name, access_token, purpose="access_token", expected_nonce=None, verify_aud=False
        )
        if access_claims:
            return access_claims

    if not isinstance(access_token, str) or not access_token:
        raise InvalidInputError(
            400,
            "invalid_provider_user_info",
            f"User information not found in token details for {provider_name}",
        )

    provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    metadata: dict = self._get_provider_metadata(provider_name)

    # Opaque access token: ask the provider's userinfo endpoint first
    userinfo_endpoint = metadata.get("userinfo_endpoint")
    if isinstance(userinfo_endpoint, str) and userinfo_endpoint:
        try:
            response = requests.get(
                userinfo_endpoint,
                headers={"Authorization": f"Bearer {access_token}"},
                timeout=10,
            )
            response.raise_for_status()
            fetched = response.json()
            if isinstance(fetched, dict) and fetched:
                return fetched
        except Exception as e:
            # Fall back to introspection if available
            self.logger.warning(f"Failed to fetch user info from {userinfo_endpoint} for provider {provider_name}: {str(e)}")

    introspection_endpoint = metadata.get("introspection_endpoint")
    if isinstance(introspection_endpoint, str) and introspection_endpoint:
        # Some providers require "client_secret_post" instead of basic auth for introspection,
        # so the second style is attempted only when the first request fails.
        for auth_style in ("basic auth", "post data"):
            try:
                configs = provider.provider_configs
                if auth_style == "basic auth":
                    response = requests.post(
                        introspection_endpoint,
                        data={"token": access_token},
                        auth=(configs.client_id, configs.client_secret),
                        timeout=10,
                    )
                else:
                    response = requests.post(
                        introspection_endpoint,
                        data={
                            "token": access_token,
                            "client_id": configs.client_id,
                            "client_secret": configs.client_secret,
                        },
                        timeout=10,
                    )
                response.raise_for_status()
                token_info = response.json()
                if isinstance(token_info, dict):
                    if token_info.get("active") is False:
                        raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
                    return token_info
                # A successful request with a non-object body is not retried
                break
            except InvalidInputError:
                raise
            except Exception as e:
                self.logger.warning(f"Introspection request failed for provider {provider_name} using {auth_style}: {str(e)}")

    raise InvalidInputError(
        400,
        "invalid_provider_user_info",
        f"User information not found in token details for {provider_name}",
    )
648
+
649
def get_user_from_external_token(self, token: str, provider_name: str | None = None) -> tuple[RegisteredUser, float | None]:
    """
    Resolve an external OAuth token to a user and the token's expiry time.

    Provider selection, in order of precedence:
      1. ``provider_name`` when given (matched against each provider's ``name``).
      2. For a JWT, the provider whose ``server_url`` equals the token's issuer
         (trailing slashes ignored on both sides).
      3. For an opaque token, the sole configured provider, if there is exactly one.

    JWTs are checked with ``_verify_provider_jwt`` (audience verification disabled);
    opaque tokens go through ``get_user_info_from_token_details``.

    Returns:
        A ``(user, expiry)`` tuple. ``expiry`` is the payload's ``exp`` value as a
        float when it is numeric, otherwise ``None``.

    Raises:
        InvalidInputError: 404 ``auth_provider_not_found`` when the named provider does
            not exist; 401 ``auth_provider_not_found`` when no provider matches the JWT
            issuer; 401 ``invalid_external_token`` for every other rejection; an
            ``inactive_external_token`` error from the opaque-token path is re-raised
            unchanged.
    """
    is_jwt = self._is_jwt(token)
    token_issuer: str | None = None
    matched = None

    if provider_name:
        # Explicit provider requested: look it up by name.
        for candidate in self.auth_providers:
            if candidate.name == provider_name:
                matched = candidate
                break
    elif is_jwt:
        token_issuer = self._get_issuer_from_token(token)
        if not token_issuer:
            raise InvalidInputError(401, "invalid_external_token", "Could not extract issuer from token")

        # Match provider by issuer (server_url)
        for candidate in self.auth_providers:
            if candidate.provider_configs.server_url.rstrip("/") == token_issuer.rstrip("/"):
                matched = candidate
                break
    elif len(self.auth_providers) == 1:
        # Opaque external token: with exactly one configured provider, assume it.
        matched = self.auth_providers[0]

    if matched is None:
        if provider_name:
            raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
        if is_jwt:
            raise InvalidInputError(401, "auth_provider_not_found", f"No provider found for issuer: {token_issuer}")
        raise InvalidInputError(401, "invalid_external_token", "Could not determine provider for external token")

    if is_jwt:
        # JWT external token: delegate verification to the provider JWT check, without aud validation.
        try:
            claims = self._verify_provider_jwt(matched.name, token, purpose="external_token", verify_aud=False)
        except InvalidInputError:
            # Keep the external-auth contract stable (avoid leaking provider-specific details).
            raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")
    else:
        # Opaque token: reuse the existing provider userinfo/introspection logic.
        try:
            claims = self.get_user_info_from_token_details(matched.name, {"access_token": token})
        except InvalidInputError as err:
            # Normalize into the external-auth error contract, except for inactive tokens.
            if getattr(err, "error", None) == "inactive_external_token":
                raise
            raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")

    # Both paths must yield a non-empty dict of claims / user info.
    if not (isinstance(claims, dict) and claims):
        raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")

    resolved_user = matched.provider_configs.get_user(claims)
    exp_claim = claims.get("exp")
    expiry: float | None = float(exp_claim) if isinstance(exp_claim, (int, float)) else None
    return resolved_user, expiry
711
+
712
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    Return every API key of ``username`` whose ``expires_at`` is not before the database's current time.

    Each stored row is converted with ``ApiKey.model_validate``. Per the original
    documentation, the ID exposed for a key is a hash of the API key, not the key itself.
    The session opened here is always closed before returning.
    """
    db_session = self.Session()
    try:
        key_table = self.DbApiKey
        unexpired_rows = (
            db_session.query(key_table)
            .filter(
                key_table.username == username,
                key_table.expires_at >= func.now(),
            )
            .all()
        )
        return [ApiKey.model_validate(row) for row in unexpired_rows]
    finally:
        db_session.close()
726
+
727
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete the API key with ID ``api_key_id`` that belongs to ``username``.

    Raises:
        InvalidInputError: 404 ``api_key_not_found`` when no key matches both the
            username and the ID.

    The session opened here is always closed, whether or not the deletion succeeds.
    """
    db_session = self.Session()
    try:
        key_table = self.DbApiKey
        owned_key_query = db_session.query(key_table).filter(
            key_table.username == username,
            key_table.id == api_key_id,
        )
        matching_key = owned_key_query.first()

        if matching_key is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")

        db_session.delete(matching_key)
        db_session.commit()
    finally:
        db_session.close()
746
+
747
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
    """
    Tell whether ``user`` may access resources at the given permission ``scope``.

    The user's access level maps to the highest scope they can reach: "guest" maps to
    PUBLIC, "admin" maps to PRIVATE, and any other level (member) maps to PROTECTED.
    Access is granted when that scope's ``value`` is at least ``scope.value``.
    """
    access_level = user.access_level
    if access_level == "guest":
        highest_scope = PermissionScope.PUBLIC
    elif access_level == "admin":
        highest_scope = PermissionScope.PRIVATE
    else:
        # Everything that is neither guest nor admin is treated as a member.
        highest_scope = PermissionScope.PROTECTED

    return highest_scope.value >= scope.value
756
+
757
def close(self) -> None:
    """
    Dispose of this object's ``engine`` if it has one; otherwise do nothing.
    """
    # A private sentinel distinguishes "attribute missing" from any stored value.
    not_set = object()
    engine = getattr(self, "engine", not_set)
    if engine is not_set:
        return
    engine.dispose()
760
+
761
+
762
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider

    The decorated function receives the ``AuthProviderArgs`` and returns the provider's
    ``ProviderConfigs``. It is replaced by a wrapper that returns a complete
    ``AuthProvider`` (name, label, icon and the configs), and that wrapper is appended to
    ``Authenticator.providers`` at decoration time.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google')
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')

    Returns:
        A decorator that takes the provider-config function and returns the registered wrapper.
    """
    # Function-scope import so this block does not depend on the file's import section.
    import functools

    def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
        # Carry over the identifying metadata so the registered callable keeps the user's
        # function name and docstring in logs and tracebacks. Annotations are deliberately
        # not copied: the wrapper returns an AuthProvider, not the ProviderConfigs that
        # the wrapped function declares.
        @functools.wraps(func, assigned=("__module__", "__name__", "__qualname__", "__doc__"))
        def wrapper(sqrl: AuthProviderArgs):
            provider_configs = func(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)

        Authenticator.providers.append(wrapper)
        return wrapper

    return decorator