squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. squirrels/__init__.py +4 -0
  2. squirrels/_api_routes/__init__.py +5 -0
  3. squirrels/_api_routes/auth.py +337 -0
  4. squirrels/_api_routes/base.py +196 -0
  5. squirrels/_api_routes/dashboards.py +156 -0
  6. squirrels/_api_routes/data_management.py +148 -0
  7. squirrels/_api_routes/datasets.py +220 -0
  8. squirrels/_api_routes/project.py +289 -0
  9. squirrels/_api_server.py +440 -792
  10. squirrels/_arguments/__init__.py +0 -0
  11. squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
  12. squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
  13. squirrels/_auth.py +590 -264
  14. squirrels/_command_line.py +130 -58
  15. squirrels/_compile_prompts.py +147 -0
  16. squirrels/_connection_set.py +16 -15
  17. squirrels/_constants.py +36 -11
  18. squirrels/_dashboards.py +179 -0
  19. squirrels/_data_sources.py +40 -34
  20. squirrels/_dataset_types.py +16 -11
  21. squirrels/_env_vars.py +209 -0
  22. squirrels/_exceptions.py +9 -37
  23. squirrels/_http_error_responses.py +52 -0
  24. squirrels/_initializer.py +7 -6
  25. squirrels/_logging.py +121 -0
  26. squirrels/_manifest.py +155 -77
  27. squirrels/_mcp_server.py +578 -0
  28. squirrels/_model_builder.py +11 -55
  29. squirrels/_model_configs.py +5 -5
  30. squirrels/_model_queries.py +1 -1
  31. squirrels/_models.py +276 -143
  32. squirrels/_package_data/base_project/.env +1 -24
  33. squirrels/_package_data/base_project/.env.example +31 -17
  34. squirrels/_package_data/base_project/connections.yml +4 -3
  35. squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
  36. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
  37. squirrels/_package_data/base_project/docker/Dockerfile +2 -2
  38. squirrels/_package_data/base_project/docker/compose.yml +1 -1
  39. squirrels/_package_data/base_project/duckdb_init.sql +1 -0
  40. squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
  41. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
  42. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
  43. squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
  44. squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
  45. squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
  46. squirrels/_package_data/base_project/models/sources.yml +5 -6
  47. squirrels/_package_data/base_project/parameters.yml +24 -38
  48. squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
  49. squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
  50. squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
  51. squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
  52. squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
  53. squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
  54. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
  55. squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
  56. squirrels/_package_data/templates/login_successful.html +53 -0
  57. squirrels/_package_data/templates/squirrels_studio.html +22 -0
  58. squirrels/_parameter_configs.py +43 -22
  59. squirrels/_parameter_options.py +1 -1
  60. squirrels/_parameter_sets.py +41 -30
  61. squirrels/_parameters.py +560 -123
  62. squirrels/_project.py +487 -277
  63. squirrels/_py_module.py +71 -10
  64. squirrels/_request_context.py +33 -0
  65. squirrels/_schemas/__init__.py +0 -0
  66. squirrels/_schemas/auth_models.py +83 -0
  67. squirrels/_schemas/query_param_models.py +70 -0
  68. squirrels/_schemas/request_models.py +26 -0
  69. squirrels/_schemas/response_models.py +286 -0
  70. squirrels/_seeds.py +52 -13
  71. squirrels/_sources.py +29 -23
  72. squirrels/_utils.py +221 -42
  73. squirrels/_version.py +1 -3
  74. squirrels/arguments.py +7 -2
  75. squirrels/auth.py +4 -0
  76. squirrels/connections.py +2 -0
  77. squirrels/dashboards.py +3 -1
  78. squirrels/data_sources.py +6 -0
  79. squirrels/parameter_options.py +5 -0
  80. squirrels/parameters.py +5 -0
  81. squirrels/types.py +10 -3
  82. squirrels-0.6.0.post0.dist-info/METADATA +148 -0
  83. squirrels-0.6.0.post0.dist-info/RECORD +101 -0
  84. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
  85. squirrels/_api_response_models.py +0 -190
  86. squirrels/_dashboard_types.py +0 -82
  87. squirrels/_dashboards_io.py +0 -79
  88. squirrels-0.5.0b3.dist-info/METADATA +0 -110
  89. squirrels-0.5.0b3.dist-info/RECORD +0 -80
  90. /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
  91. /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
  92. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
  93. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py CHANGED
@@ -1,231 +1,165 @@
1
+ from typing import Callable, Any
1
2
  from datetime import datetime, timedelta, timezone
2
- from enum import Enum
3
3
  from functools import cached_property
4
4
  from jwt.exceptions import InvalidTokenError
5
5
  from passlib.context import CryptContext
6
- from pydantic import BaseModel, ConfigDict, ValidationError
7
- from pydantic_core import PydanticUndefined
6
+ from pydantic import ValidationError
8
7
  from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
9
- from sqlalchemy import Column, String, Integer, Float, Boolean
10
8
  from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
11
- import jwt, types, typing as _t, uuid
9
+ import jwt, uuid, secrets, json, base64, requests
10
+ from jwt import PyJWKClient
12
11
 
12
+ from ._env_vars import SquirrelsEnvVars
13
13
  from ._manifest import PermissionScope
14
- from ._py_module import PyModule
15
14
  from ._exceptions import InvalidInputError, ConfigurationError
15
+ from ._arguments.init_time_args import AuthProviderArgs
16
+ from ._schemas.auth_models import (
17
+ CustomUserFields, AbstractUser, RegisteredUser, ApiKey, UserField, UserFieldsModel, AuthProvider, ProviderConfigs
18
+ )
19
+ from ._schemas import response_models as rm
16
20
  from . import _utils as u, _constants as c
17
21
 
18
22
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
19
23
 
20
- reserved_fields = ["username", "is_admin"]
21
- disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
24
+ ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
22
25
 
23
- class BaseUser(BaseModel):
24
- model_config = ConfigDict(from_attributes=True)
25
- username: str
26
- is_admin: bool = False
27
-
28
- @classmethod
29
- def dropped_columns(cls):
30
- return []
31
-
32
- def __hash__(self):
33
- return hash(self.username)
34
-
35
- User = _t.TypeVar('User', bound=BaseUser)
36
-
37
- class AccessToken(BaseModel):
38
- model_config = ConfigDict(from_attributes=True)
39
- token_id: str
40
- title: str
41
- username: str
42
- created_at: datetime
43
- expires_at: datetime
44
26
 
27
+ class Authenticator:
28
+ providers: list[ProviderFunctionType] = [] # static variable to stage providers
45
29
 
46
- class UserField(BaseModel):
47
- name: str
48
- type: str
49
- nullable: bool
50
- enum: list[str] | None
51
- default: _t.Any | None
52
-
53
-
54
- class Authenticator(_t.Generic[User]):
55
- def __init__(self, logger: u.Logger, base_path: str, env_vars: dict[str, str], *, sa_engine: Engine | None = None, cls: type[User] | None = None):
30
+ def __init__(
31
+ self, logger: u.Logger, env_vars: SquirrelsEnvVars, auth_args: AuthProviderArgs,
32
+ provider_functions: list[ProviderFunctionType], custom_user_fields_cls: type[CustomUserFields],
33
+ *,
34
+ sa_engine: Engine | None = None, external_only: bool = False
35
+ ):
56
36
  self.logger = logger
57
37
  self.env_vars = env_vars
58
- self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
38
+ self.secret_key = env_vars.secret_key
39
+ self.external_only = external_only
40
+ self.password_requirements = rm.PasswordRequirements()
59
41
 
60
42
  # Create a new declarative base for this instance
61
43
  self.Base = declarative_base()
62
44
 
63
- # Define DbBaseUser class for this instance
64
- class DbBaseUser(self.Base):
45
+ # Define DbUser class for this instance
46
+ class DbUser(self.Base):
65
47
  __tablename__ = 'users'
66
48
  __table_args__ = {'extend_existing': True}
67
49
  username: Mapped[str] = mapped_column(primary_key=True)
68
- is_admin: Mapped[bool] = mapped_column(nullable=False, default=False)
50
+ access_level: Mapped[str] = mapped_column(nullable=False, default="member")
69
51
  password_hash: Mapped[str] = mapped_column(nullable=False)
52
+ custom_fields: Mapped[str] = mapped_column(nullable=False, default="{}")
70
53
  created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
71
54
 
72
- # Define DbAccessToken class for this instance
73
- class DbAccessToken(self.Base):
74
- __tablename__ = 'access_tokens'
55
+ # Define DbApiKey class for this instance
56
+ class DbApiKey(self.Base):
57
+ __tablename__ = 'api_keys'
75
58
 
76
- token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid.uuid4()))
59
+ id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
60
+ hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
61
+ last_four: Mapped[str] = mapped_column(nullable=False)
77
62
  title: Mapped[str] = mapped_column(nullable=False)
78
63
  username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
79
64
  created_at: Mapped[datetime] = mapped_column(nullable=False)
80
65
  expires_at: Mapped[datetime] = mapped_column(nullable=False)
81
-
82
- def __repr__(self):
83
- return f"<AccessToken(token_id='{self.token_id}', username='{self.username}')>"
84
66
 
85
- self.DbBaseUser = DbBaseUser
86
- self.DbAccessToken = DbAccessToken
87
-
88
- self.User = self._get_user_model(base_path) if cls is None else cls
89
- self.DbUser: type[DbBaseUser] = self._initialize_db_user_model(self.User)
67
+ def __repr__(self):
68
+ return f"<DbApiKey(id='{self.id}', username='{self.username}')>"
90
69
 
91
- if sa_engine is None:
92
- sqlite_relative_path = env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
93
- sqlite_path = u.Path(base_path, sqlite_relative_path)
94
- sqlite_path.parent.mkdir(parents=True, exist_ok=True)
95
- self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
96
- else:
97
- self.engine = sa_engine
70
+ self.CustomUserFields = custom_user_fields_cls
71
+ self.DbUser = DbUser
72
+
73
+ self.DbApiKey = DbApiKey
98
74
 
99
- # Configure SQLite pragmas
100
- with self.engine.connect() as conn:
101
- conn.execute(text("PRAGMA journal_mode = WAL"))
102
- conn.execute(text("PRAGMA synchronous = NORMAL"))
103
- conn.commit()
75
+ self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
76
+ self._jwks_clients: dict[str, PyJWKClient] = {}
77
+ self._provider_metadata_cache: dict[str, dict] = {}
104
78
 
105
- self.Base.metadata.create_all(self.engine)
106
-
107
- self.Session = sessionmaker(bind=self.engine)
108
-
109
- self._initialize_db(self.User, self.DbUser, self.engine, self.Session)
110
-
111
- def _get_user_model(self, base_path: str) -> type[BaseUser]:
112
- user_module_path = u.Path(base_path, c.PYCONFIGS_FOLDER, c.USER_FILE)
113
- user_module = PyModule(user_module_path)
114
- User = user_module.get_func_or_class("User", default_attr=BaseUser)
115
- if not issubclass(User, BaseUser):
116
- raise ConfigurationError(f"User class in '{c.USER_FILE}' must inherit from BaseUser")
117
- return User
118
-
119
- def _initialize_db_user_model(self, *args) -> type:
120
- """Get the user model with any custom attributes defined in user.py"""
121
- attrs = {}
122
-
123
- # Iterate over all fields in the User model
124
- for field_name, field in self.User.model_fields.items():
125
- if field_name in reserved_fields:
126
- continue
127
- if field_name in disallowed_fields:
128
- raise ConfigurationError(f"Field name '{field_name}' is disallowed in the User model and cannot be used")
129
-
130
- field_type = field.annotation
131
- if _t.get_origin(field_type) in (_t.Union, types.UnionType):
132
- field_type = _t.get_args(field_type)[0]
133
- nullable = True
79
+ if not self.external_only:
80
+ if sa_engine is None:
81
+ raw_sqlite_path = self.env_vars.auth_db_file_path
82
+ sqlite_path = u.Path(raw_sqlite_path.format(project_path=self.env_vars.project_path))
83
+ sqlite_path.parent.mkdir(parents=True, exist_ok=True)
84
+ self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
134
85
  else:
135
- nullable = False
136
-
137
- if _t.get_origin(field_type) == _t.Literal:
138
- field_type = str
139
-
140
- # Map Python types and default values to SQLAlchemy columns
141
- default_value = field.default
142
- if default_value is PydanticUndefined:
143
- raise ConfigurationError(f"No default value found for field '{field_name}' in User model")
144
- elif not nullable and default_value is None:
145
- raise ConfigurationError(f"Default value for non-nullable field '{field_name}' was set as None in User model")
146
- elif default_value is not None and type(default_value) is not field_type:
147
- raise ConfigurationError(f"Default value for field '{field_name}' does not match field type in User model")
86
+ self.engine = sa_engine
148
87
 
149
- if field_type == str:
150
- col_type = String
151
- elif field_type == int:
152
- col_type = Integer
153
- elif field_type == float:
154
- col_type = Float
155
- elif field_type == bool:
156
- col_type = Boolean
157
- elif isinstance(field_type, type) and issubclass(field_type, Enum):
158
- col_type = String
159
- default_value = default_value.value
160
- else:
161
- continue
162
-
163
- attrs[field_name] = Column(col_type, nullable=nullable, default=default_value) # type: ignore
88
+ # Configure SQLite pragmas
89
+ with self.engine.connect() as conn:
90
+ conn.execute(text("PRAGMA journal_mode = WAL"))
91
+ conn.execute(text("PRAGMA synchronous = NORMAL"))
92
+ conn.commit()
93
+
94
+ self.Base.metadata.create_all(self.engine)
164
95
 
165
- # Create the sqlalchemy model class
166
- DbUser = type('DbUser', (self.DbBaseUser,), attrs)
167
- return DbUser
96
+ self.Session = sessionmaker(bind=self.engine)
168
97
 
169
- def _initialize_db(self, *args): # TODO: Use logger instead of print
98
+ self._initialize_db()
99
+
100
+ def _convert_db_user_to_user(self, db_user) -> RegisteredUser:
101
+ """Convert a database user to an AbstractUser object"""
102
+ # Deserialize custom_fields JSON and merge with defaults
103
+ custom_fields_json = json.loads(db_user.custom_fields) if db_user.custom_fields else {}
104
+ custom_fields = self.CustomUserFields(**custom_fields_json)
105
+
106
+ return RegisteredUser(
107
+ username=db_user.username,
108
+ access_level=db_user.access_level,
109
+ custom_fields=custom_fields
110
+ )
111
+
112
+ def _validate_password_length(self, password: str) -> None:
113
+ """Validate that password meets length requirements and does not exceed 72 characters (bcrypt limit)"""
114
+ min_len = self.password_requirements.min_length
115
+ max_len = min(self.password_requirements.max_length, 72)
116
+ if len(password) < min_len:
117
+ raise InvalidInputError(400, "password_too_short", f"Password must be at least {min_len} characters long")
118
+ if len(password) > max_len:
119
+ raise InvalidInputError(400, "password_too_long", f"Password cannot exceed {max_len} characters")
120
+
121
+ def _initialize_db(self):
170
122
  session = self.Session()
171
123
  try:
172
124
  # Get existing columns in the database
173
125
  inspector = inspect(self.engine)
174
- existing_columns = {col['name'] for col in inspector.get_columns('users')}
175
-
176
- # Get all columns defined in the model
177
- dropped_columns = set(self.User.dropped_columns())
178
- model_columns = set(self.DbUser.__table__.columns.keys()) - dropped_columns
179
-
180
- # Find columns that are in the model but not in the database
181
- new_columns = model_columns - existing_columns
182
- if new_columns:
183
- add_columns_msg = f"Adding columns to database: {new_columns}"
184
- print("NOTE:", add_columns_msg)
185
- self.logger.info(add_columns_msg)
186
-
187
- for col_name in new_columns:
188
- col = self.DbUser.__table__.columns[col_name]
189
- column_type = col.type.compile(self.engine.dialect)
190
- nullable = "NULL" if col.nullable else "NOT NULL"
191
- if col.default is not None:
192
- default_val = f"'{col.default.arg}'" if isinstance(col.default.arg, str) else col.default.arg
193
- default = f"DEFAULT {default_val}"
194
- else:
195
- default = ""
196
-
197
- alter_stmt = f"ALTER TABLE users ADD COLUMN {col_name} {column_type} {nullable} {default}"
198
- session.execute(text(alter_stmt))
199
-
200
- session.commit()
201
126
 
202
- # Determine columns to drop
203
- columns_to_drop = dropped_columns.intersection(existing_columns)
204
- if columns_to_drop:
205
- drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
206
- print("NOTE:", drop_columns_msg)
207
- self.logger.info(drop_columns_msg)
208
-
209
- for col_name in columns_to_drop:
210
- session.execute(text(f"ALTER TABLE users DROP COLUMN {col_name}"))
127
+ for db_model in [self.DbUser, self.DbApiKey]:
128
+ table_name = db_model.__tablename__
129
+ existing_columns = {col['name'] for col in inspector.get_columns(table_name)}
130
+ model_columns = set(db_model.__table__.columns.keys())
131
+ new_columns = model_columns - existing_columns
211
132
 
212
- session.commit()
213
-
214
- # Find columns that are in the database but not in the model
215
- extra_db_columns = existing_columns - columns_to_drop - model_columns
216
- if extra_db_columns:
217
- self.logger.warning(f"The following database columns are not in the User model: {extra_db_columns}\n"
218
- "If you want to drop these columns, please use the `dropped_columns` class method of the User model.")
133
+ if new_columns:
134
+ add_columns_msg = f"Adding columns to table {table_name}: {new_columns}"
135
+ self.logger.info(add_columns_msg)
136
+
137
+ for col_name in new_columns:
138
+ col = db_model.__table__.columns[col_name]
139
+ column_type = col.type.compile(self.engine.dialect)
140
+ nullable = "NULL" if col.nullable else "NOT NULL"
141
+ if col.default is not None and not callable(col.default.arg):
142
+ default_val = f"'{col.default.arg}'" if isinstance(col.default.arg, str) else col.default.arg
143
+ default = f"DEFAULT {default_val}"
144
+ else:
145
+ # If nullable is False and no default is provided, use an empty string as a placeholder default for SQLite
146
+ # TODO: Use a different default value (instead of empty string) based on the column type
147
+ default = "DEFAULT ''" if not col.nullable else ""
148
+
149
+ alter_stmt = f"ALTER TABLE {table_name} ADD COLUMN {col_name} {column_type} {nullable} {default}"
150
+ session.execute(text(alter_stmt))
151
+
152
+ session.commit()
219
153
 
220
154
  # Get admin password from environment variable if exists
221
- admin_password = self.env_vars.get(c.SQRL_SECRET_ADMIN_PASSWORD)
155
+ admin_password = self.env_vars.secret_admin_password
222
156
 
223
- # If admin password variable exists, find username "admin". If it does not exist, add it
224
157
  if admin_password is not None:
158
+ self._validate_password_length(admin_password)
225
159
  password_hash = pwd_context.hash(admin_password)
226
160
  admin_user = session.get(self.DbUser, c.ADMIN_USERNAME)
227
161
  if admin_user is None:
228
- admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, is_admin=True)
162
+ admin_user = self.DbUser(username=c.ADMIN_USERNAME, password_hash=password_hash, access_level="admin")
229
163
  session.add(admin_user)
230
164
  else:
231
165
  admin_user.password_hash = password_hash
@@ -236,9 +170,9 @@ class Authenticator(_t.Generic[User]):
236
170
  session.close()
237
171
 
238
172
  @cached_property
239
- def user_fields(self) -> list[UserField]:
173
+ def user_fields(self) -> UserFieldsModel:
240
174
  """
241
- Get the fields of the User model as a list of dictionaries
175
+ Get the fields of the CustomUserFields model as a list of dictionaries
242
176
 
243
177
  Each dictionary contains the following keys:
244
178
  - name: The name of the field
@@ -247,11 +181,10 @@ class Authenticator(_t.Generic[User]):
247
181
  - enum: The possible values of the field (or None if not applicable)
248
182
  - default: The default value of the field (or None if field is required)
249
183
  """
250
- schema = self.User.model_json_schema()
251
-
252
- fields = []
253
184
 
254
- properties: dict[str, dict[str, _t.Any]] = schema.get("properties", {})
185
+ custom_fields = []
186
+ schema = self.CustomUserFields.model_json_schema()
187
+ properties: dict[str, dict[str, Any]] = schema.get("properties", {})
255
188
  for field_name, field_schema in properties.items():
256
189
  if choices := field_schema.get("anyOf"):
257
190
  field_type = choices[0]["type"]
@@ -260,62 +193,112 @@ class Authenticator(_t.Generic[User]):
260
193
  field_type = field_schema["type"]
261
194
  nullable = False
262
195
 
263
- field_data = UserField(name=field_name, type=field_type, nullable=nullable, enum=field_schema.get("enum"), default=field_schema.get("default"))
264
- fields.append(field_data)
265
-
266
- return fields
196
+ field_data = UserField(name=field_name, label=field_schema.get("title", field_name), type=field_type, nullable=nullable, enum=field_schema.get("enum"), default=field_schema.get("default"))
197
+ custom_fields.append(field_data)
198
+
199
+ return UserFieldsModel(
200
+ username=UserField(name="username", label="Username / Email", type="string", nullable=False, enum=None, default=None),
201
+ access_level=UserField(name="access_level", label="Access Level", type="string", nullable=False, enum=["admin", "member"], default="member"),
202
+ custom_fields=custom_fields
203
+ )
267
204
 
268
- def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> None:
269
- session = self.Session()
270
-
271
- # Validate the user data
205
+ def add_user(self, username: str, user_fields: dict, *, update_user: bool = False) -> RegisteredUser:
206
+ # Separate custom fields from base fields
207
+ access_level = user_fields.get('access_level', 'member')
208
+ password = user_fields.get('password')
209
+
210
+ # Validate access level - cannot add/update users with guest access level
211
+ if access_level == "guest":
212
+ raise InvalidInputError(400, "invalid_access_level", "Cannot create or update users with 'guest' access level")
213
+
214
+ # Extract custom fields
215
+ custom_fields_data: dict[str, Any] = user_fields.get('custom_fields', {})
216
+
217
+ # Validate the custom fields
272
218
  try:
273
- user_data = self.User(**user_fields, username=username).model_dump(mode='json')
219
+ custom_fields = self.CustomUserFields(**custom_fields_data)
220
+ custom_fields_json = json.dumps(custom_fields.model_dump(mode='json'))
274
221
  except ValidationError as e:
275
- raise InvalidInputError(102, f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
222
+ raise InvalidInputError(400, "invalid_user_data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
276
223
 
277
- # Add a new user
224
+ # Add or update user
225
+ session = self.Session()
278
226
  try:
279
227
  # Check if the user already exists
280
- existing_user = session.get(self.DbUser, username)
281
- if existing_user is not None:
228
+ db_user = session.get(self.DbUser, username)
229
+ if db_user is not None:
282
230
  if not update_user:
283
- raise InvalidInputError(101, f"User '{username}' already exists")
231
+ raise InvalidInputError(400, "username_already_exists", f"User '{username}' already exists")
232
+
233
+ if username == c.ADMIN_USERNAME and access_level != "admin":
234
+ raise InvalidInputError(403, "admin_cannot_be_non_admin", "Setting the admin user to non-admin is not permitted")
284
235
 
285
- if username == c.ADMIN_USERNAME and user_data.get("is_admin") is False:
286
- raise InvalidInputError(24, "Setting the admin user to non-admin is not permitted")
287
- new_user = self.DbUser(password_hash=existing_user.password_hash, **user_data)
288
- session.delete(existing_user)
236
+ # Update existing user
237
+ db_user.access_level = access_level
238
+ db_user.custom_fields = custom_fields_json
289
239
  else:
290
240
  if update_user:
291
- raise InvalidInputError(41, f"No user found for username: {username}")
241
+ raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
292
242
 
293
- password = user_fields.get('password')
294
243
  if password is None:
295
- raise InvalidInputError(100, f"Missing required field 'password' when adding a new user")
244
+ raise InvalidInputError(400, "missing_password", f"Missing required field 'password' when adding a new user")
245
+
246
+ self._validate_password_length(password)
296
247
  password_hash = pwd_context.hash(password)
297
- new_user = self.DbUser(password_hash=password_hash, **user_data)
298
-
299
- # Add the user to the session
300
- session.add(new_user)
248
+ db_user = self.DbUser(
249
+ username=username,
250
+ access_level=access_level,
251
+ password_hash=password_hash,
252
+ custom_fields=custom_fields_json
253
+ )
254
+ session.add(db_user)
301
255
 
302
256
  # Commit the transaction
303
257
  session.commit()
258
+ return self._convert_db_user_to_user(db_user)
304
259
 
305
260
  finally:
306
261
  session.close()
262
+
263
+ def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> RegisteredUser:
264
+ provider = next((p for p in self.auth_providers if p.name == provider_name), None)
265
+ if provider is None:
266
+ raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
267
+
268
+ user = provider.provider_configs.get_user(user_info)
269
+ session = self.Session()
270
+ try:
271
+ # Convert user to database user
272
+ custom_fields_json = user.custom_fields.model_dump_json()
273
+ db_user = self.DbUser(
274
+ username=user.username,
275
+ access_level=user.access_level,
276
+ password_hash="", # By omitting password_hash, it becomes impossible to login with username and password (OAuth only)
277
+ custom_fields=custom_fields_json
278
+ )
279
+
280
+ existing_db_user = session.get(self.DbUser, db_user.username)
281
+ if existing_db_user is None:
282
+ session.add(db_user)
283
+
284
+ session.commit()
307
285
 
308
- def get_user(self, username: str, password: str) -> User:
286
+ return self._convert_db_user_to_user(db_user)
287
+
288
+ finally:
289
+ session.close()
290
+
291
+ def get_user(self, username: str, password: str) -> RegisteredUser:
309
292
  session = self.Session()
310
293
  try:
311
294
  # Query for user by username
312
295
  db_user = session.get(self.DbUser, username)
313
296
 
314
297
  if db_user and pwd_context.verify(password, db_user.password_hash):
315
- user = self.User.model_validate(db_user)
316
- return user # type: ignore
298
+ user = self._convert_db_user_to_user(db_user)
299
+ return user
317
300
  else:
318
- raise InvalidInputError(0, f"Username or password not found")
301
+ raise InvalidInputError(401, "incorrect_username_or_password", f"Incorrect username or password")
319
302
 
320
303
  finally:
321
304
  session.close()
@@ -325,127 +308,470 @@ class Authenticator(_t.Generic[User]):
325
308
  try:
326
309
  db_user = session.get(self.DbUser, username)
327
310
  if db_user is None:
328
- raise InvalidInputError(2, f"User not found")
311
+ raise InvalidInputError(401, "user_not_found", f"Username '{username}' not found for password change")
329
312
 
330
- if pwd_context.verify(old_password, db_user.password_hash):
313
+ if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
314
+ self._validate_password_length(new_password)
331
315
  db_user.password_hash = pwd_context.hash(new_password)
332
316
  session.commit()
333
317
  else:
334
- raise InvalidInputError(3, f"Incorrect password")
318
+ raise InvalidInputError(401, "incorrect_password", f"Incorrect password")
335
319
  finally:
336
320
  session.close()
337
321
 
338
322
  def delete_user(self, username: str) -> None:
339
323
  if username == c.ADMIN_USERNAME:
340
- raise InvalidInputError(23, "Cannot delete the admin user")
324
+ raise InvalidInputError(403, "cannot_delete_admin_user", "Cannot delete the admin user")
341
325
 
342
326
  session = self.Session()
343
327
  try:
344
328
  db_user = session.get(self.DbUser, username)
345
329
  if db_user is None:
346
- raise InvalidInputError(41, f"No user found for username: {username}")
330
+ raise InvalidInputError(404, "no_user_found_for_username", f"No user found for username: {username}")
347
331
  session.delete(db_user)
348
332
  session.commit()
349
333
  finally:
350
334
  session.close()
351
335
 
352
def get_all_users(self) -> list[RegisteredUser]:
    """Return every row of the user table, each converted with ``_convert_db_user_to_user``."""
    session = self.Session()
    try:
        return list(map(self._convert_db_user_to_user, session.query(self.DbUser).all()))
    finally:
        session.close()
359
343
 
360
def create_access_token(self, user: AbstractUser, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
    """
    Issue a credential for ``user``: an API key when ``title`` is given, otherwise a JWT.

    API key: "sqrl-" followed by random URL-safe text. Only its hash (computed by ``u.hash_string``
    with the secret key as salt) and its last four characters are stored in the API key table,
    together with the title, owner and timestamps. The raw key is returned to the caller.
    JWT: HS256-signed with the secret key and carrying the username and expiry.

    Arguments:
        user: The user the credential belongs to
        expiry_minutes: Lifetime in minutes, or None for no expiry (``datetime.max``)
        title: API key title; omit it to get a JWT instead

    Returns:
        A (token, expiry) tuple

    Raises:
        ConfigurationError: when the secret key environment variable is not set
    """
    created_at = datetime.now(timezone.utc)
    if expiry_minutes is None:
        # NOTE(review): datetime.max is timezone-naive, unlike the UTC-aware computed expiry
        expire_at = datetime.max
    else:
        expire_at = created_at + timedelta(minutes=expiry_minutes)

    secret_key = self.secret_key
    if secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")

    if title is None:
        claims = {"username": user.username, "exp": expire_at}
        return jwt.encode(claims, secret_key, algorithm="HS256"), expire_at

    session = self.Session()
    try:
        raw_key = "sqrl-" + secrets.token_urlsafe(16)
        session.add(
            self.DbApiKey(
                hashed_key=u.hash_string(raw_key, salt=secret_key),
                last_four=raw_key[-4:],
                title=title,
                username=user.username,
                created_at=created_at,
                expires_at=expire_at,
            )
        )
        session.commit()
    finally:
        session.close()

    return raw_key, expire_at
379
373
 
380
def get_user_from_token(self, token: str | None) -> tuple[RegisteredUser | None, float | None]:
    """
    Resolve the user (and expiry) behind one of this server's access tokens.

    Tokens starting with "sqrl-" are API keys. They are hashed with the secret key and must match an
    API key row whose ``expires_at`` is not before the database's current time. Any other token is
    decoded as an HS256 JWT signed with the secret key.

    Arguments:
        token: The access token, or None/empty when the request carried none

    Returns:
        (user, expiry). For an empty token this is (None, None). For API keys the expiry is None;
        for JWTs it is the "exp" claim.

    Raises:
        ConfigurationError: when the secret key environment variable is not set
        InvalidInputError: 401 "invalid_authorization_token" for an unknown/expired API key, a JWT
            that fails decoding, a JWT without a usable "username" claim, or an unknown user
    """
    if not token:
        return None, None

    if self.secret_key is None:
        raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")

    session = self.Session()
    try:
        if token.startswith("sqrl-"):
            hashed_key = u.hash_string(token, salt=self.secret_key)
            api_key = session.query(self.DbApiKey).filter(
                self.DbApiKey.hashed_key == hashed_key,
                self.DbApiKey.expires_at >= func.now()
            ).first()
            if api_key is None:
                raise InvalidTokenError()
            username = api_key.username
            expiry = None
        else:
            payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            # A correctly signed token without a usable "username" claim is still unusable. Treat it
            # as invalid (401) instead of letting a KeyError escape as a server error.
            username = payload.get("username")
            if not isinstance(username, str) or not username:
                raise InvalidTokenError()
            expiry = payload.get("exp")

        db_user = session.get(self.DbUser, username)
        if db_user is None:
            raise InvalidTokenError()

        user = self._convert_db_user_to_user(db_user)

    except InvalidTokenError:
        raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
    finally:
        session.close()

    return user, expiry
411
413
 
412
- def get_all_tokens(self, username: str) -> list[AccessToken]:
414
def _get_jwks_client(self, jwks_uri: str) -> PyJWKClient:
    """Return the PyJWKClient for ``jwks_uri``, creating it on first use and reusing it afterwards."""
    clients = self._jwks_clients
    if jwks_uri in clients:
        return clients[jwks_uri]
    client = PyJWKClient(jwks_uri)
    clients[jwks_uri] = client
    return client
418
+
419
+ def _get_issuer_from_token(self, token: str) -> str | None:
420
+ try:
421
+ # JWT format is header.payload.signature
422
+ parts = token.split('.')
423
+ if len(parts) != 3:
424
+ return None
425
+
426
+ # Base64url decode the payload (JWT uses base64url encoding)
427
+ payload_b64 = parts[1]
428
+ # Add padding if necessary: (-len % 4) yields 0..3
429
+ padding = '=' * ((-len(payload_b64)) % 4)
430
+ payload_json = base64.urlsafe_b64decode(payload_b64 + padding).decode("utf-8")
431
+ payload: dict = json.loads(payload_json)
432
+
433
+ issuer = payload.get("iss")
434
+ return issuer if isinstance(issuer, str) and issuer else None
435
+ except Exception:
436
+ return None
437
+
438
@staticmethod
def _is_jwt(token: str) -> bool:
    """
    Cheap shape check (no decoding): True only for a non-empty string made of exactly three
    non-empty dot-separated segments (header.payload.signature).
    """
    if not (isinstance(token, str) and token):
        return False
    segments = token.split(".")
    return len(segments) == 3 and "" not in segments
444
+
445
def _get_provider_metadata(self, provider_name: str) -> dict:
    """
    Return a provider's OIDC discovery metadata, fetched from its ``server_metadata_url``.

    Results are cached by metadata URL; only a non-empty dict counts as a cache hit.

    Raises:
        InvalidInputError: 404 "auth_provider_not_found" for an unknown provider name
        ConfigurationError: when the metadata cannot be fetched or is not a JSON object
    """
    for auth_provider in self.auth_providers:
        if auth_provider.name == provider_name:
            break
    else:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    metadata_url = auth_provider.provider_configs.server_metadata_url

    cached = self._provider_metadata_cache.get(metadata_url)
    if isinstance(cached, dict) and cached:
        return cached

    try:
        response = requests.get(metadata_url, timeout=10)
        response.raise_for_status()
        fetched = response.json()
        if not isinstance(fetched, dict):
            raise ValueError("Provider metadata was not a JSON object")
    except Exception as e:
        raise ConfigurationError(f"Failed to fetch metadata for provider '{provider_name}': {str(e)}") from e

    self._provider_metadata_cache[metadata_url] = fetched
    return fetched
466
+
467
def _verify_provider_jwt(
    self,
    provider_name: str,
    token: str,
    *,
    purpose: str,
    expected_nonce: str | None = None,
    verify_aud: bool = True,
) -> dict | None:
    """
    Verify a JWT issued by an auth provider and return its claims.

    Checks:
    - signature and expiry, using the provider's JWKS (``jwks_uri`` from OIDC discovery metadata);
    - audience (the provider's client_id), only when ``verify_aud`` is true;
    - issuer, best-effort: only when both the expected issuer and the token's "iss" are present,
      ignoring trailing slashes;
    - the "nonce" claim, when ``expected_nonce`` is given.

    Arguments:
        provider_name: Name of a registered auth provider
        token: The token to verify
        purpose: Label used in error messages (e.g. "id_token")
        expected_nonce: Nonce the token must carry, if any
        verify_aud: Whether to validate the audience

    Returns:
        The decoded claims, or None when the token is not JWT-shaped, its signing key cannot be
        resolved, or decoding/verification fails.

    Raises:
        InvalidInputError: 404 for an unknown provider; 401 "invalid_provider_token" on issuer or
            nonce mismatch
        ConfigurationError: when provider metadata is unavailable or lacks jwks_uri
    """
    if not self._is_jwt(token):
        return None

    provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    if provider is None:
        raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")

    metadata = self._get_provider_metadata(provider_name)
    jwks_uri = metadata.get("jwks_uri")
    if not isinstance(jwks_uri, str) or not jwks_uri:
        raise ConfigurationError(f"jwks_uri not found in metadata for provider '{provider_name}'")

    signing_algs = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
    if not isinstance(signing_algs, list) or not signing_algs:
        signing_algs = ["RS256"]

    # Resolving the key parses the (untrusted) token header and may fetch the JWKS. A malformed
    # header, an unknown "kid" or an unreachable endpoint must be handled like any other
    # verification failure, not escape as an unhandled error.
    try:
        signing_key = self._get_jwks_client(jwks_uri).get_signing_key_from_jwt(token)
    except Exception as e:
        self.logger.warning(f"Could not resolve signing key for {purpose} from provider {provider_name}: {str(e)}")
        return None

    decode_kwargs: dict[str, Any] = {
        "key": signing_key.key,
        "algorithms": signing_algs,
        "options": {
            "verify_aud": bool(verify_aud),
            # Issuer is validated manually below to tolerate trailing-slash differences.
            "verify_iss": False,
        },
    }
    if verify_aud:
        decode_kwargs["audience"] = provider.provider_configs.client_id

    try:
        payload = jwt.decode(token, **decode_kwargs)
    except Exception:
        return None

    if not isinstance(payload, dict):
        return None

    expected_issuer = metadata.get("issuer") or provider.provider_configs.server_url
    token_issuer = payload.get("iss")
    if isinstance(expected_issuer, str) and expected_issuer and isinstance(token_issuer, str) and token_issuer:
        if token_issuer.rstrip("/") != expected_issuer.rstrip("/"):
            raise InvalidInputError(
                401,
                "invalid_provider_token",
                f"Invalid {purpose} issuer for provider '{provider_name}'",
            )

    if expected_nonce is not None:
        nonce_claim = payload.get("nonce")
        if nonce_claim != expected_nonce:
            raise InvalidInputError(
                401,
                "invalid_provider_token",
                f"Invalid {purpose} nonce for provider '{provider_name}'",
            )

    return payload
539
+
540
+ def get_user_info_from_token_details(self, provider_name: str, token_details: dict, *, expected_nonce: str | None = None) -> dict:
541
+ """
542
+ Determine user_info from an OAuth/OIDC token response.
543
+
544
+ Priority:
545
+ - token_details["user_info"] / token_details["userinfo"]
546
+ - verify + decode token_details["id_token"] if it's a JWT
547
+ - verify + decode token_details["access_token"] if it's a JWT
548
+ - for opaque access_token: call userinfo or introspection endpoint from provider metadata
549
+ """
550
+ for key in ("user_info", "userinfo"):
551
+ user_info = token_details.get(key)
552
+ if isinstance(user_info, dict) and user_info:
553
+ return user_info
554
+
555
+ id_token = token_details.get("id_token")
556
+ if isinstance(id_token, str):
557
+ if payload := self._verify_provider_jwt(
558
+ provider_name, id_token, purpose="id_token", expected_nonce=expected_nonce, verify_aud=True
559
+ ):
560
+ return payload
561
+
562
+ access_token = token_details.get("access_token")
563
+ if isinstance(access_token, str):
564
+ # Some providers issue JWT access tokens. Audience can vary (resource server),
565
+ # so we verify signature/exp and issuer, but skip aud validation.
566
+ if payload := self._verify_provider_jwt(
567
+ provider_name, access_token, purpose="access_token", expected_nonce=None, verify_aud=False
568
+ ):
569
+ return payload
570
+
571
+ if not isinstance(access_token, str) or not access_token:
572
+ raise InvalidInputError(
573
+ 400,
574
+ "invalid_provider_user_info",
575
+ f"User information not found in token details for {provider_name}",
576
+ )
577
+
578
+ provider = next((p for p in self.auth_providers if p.name == provider_name), None)
579
+ if provider is None:
580
+ raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
581
+
582
+ metadata: dict = self._get_provider_metadata(provider_name)
583
+
584
+ userinfo_endpoint = metadata.get("userinfo_endpoint")
585
+ if isinstance(userinfo_endpoint, str) and userinfo_endpoint:
586
+ try:
587
+ response = requests.get(
588
+ userinfo_endpoint,
589
+ headers={"Authorization": f"Bearer {access_token}"},
590
+ timeout=10,
591
+ )
592
+ response.raise_for_status()
593
+ user_info = response.json()
594
+ if isinstance(user_info, dict) and user_info:
595
+ return user_info
596
+ except Exception as e:
597
+ self.logger.warning(f"Failed to fetch user info from {userinfo_endpoint} for provider {provider_name}: {str(e)}")
598
+ # Fall back to introspection if available
599
+ pass
600
+
601
+ introspection_endpoint = metadata.get("introspection_endpoint")
602
+ if isinstance(introspection_endpoint, str) and introspection_endpoint:
603
+ try:
604
+ response = requests.post(
605
+ introspection_endpoint,
606
+ data={"token": access_token},
607
+ auth=(provider.provider_configs.client_id, provider.provider_configs.client_secret),
608
+ timeout=10,
609
+ )
610
+ response.raise_for_status()
611
+ token_info = response.json()
612
+ if isinstance(token_info, dict):
613
+ if token_info.get("active") is False:
614
+ raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
615
+ return token_info
616
+ except InvalidInputError:
617
+ raise
618
+ except Exception as e:
619
+ self.logger.warning(f"Introspection request failed for provider {provider_name} using basic auth: {str(e)}")
620
+ # Some providers require "client_secret_post" instead of basic auth for introspection.
621
+ try:
622
+ response = requests.post(
623
+ introspection_endpoint,
624
+ data={
625
+ "token": access_token,
626
+ "client_id": provider.provider_configs.client_id,
627
+ "client_secret": provider.provider_configs.client_secret,
628
+ },
629
+ timeout=10,
630
+ )
631
+ response.raise_for_status()
632
+ token_info = response.json()
633
+ if isinstance(token_info, dict):
634
+ if token_info.get("active") is False:
635
+ raise InvalidInputError(401, "inactive_external_token", "External authorization token is inactive")
636
+ return token_info
637
+ except InvalidInputError:
638
+ raise
639
+ except Exception as e2:
640
+ self.logger.warning(f"Introspection request failed for provider {provider_name} using post data: {str(e2)}")
641
+ pass
642
+
643
+ raise InvalidInputError(
644
+ 400,
645
+ "invalid_provider_user_info",
646
+ f"User information not found in token details for {provider_name}",
647
+ )
648
+
649
def get_user_from_external_token(self, token: str, provider_name: str | None = None) -> tuple[RegisteredUser, float | None]:
    """
    Get a user (and expiry) from an OAuth token issued by an external auth provider.

    Provider selection: ``provider_name`` if given. Otherwise, for a JWT, the provider whose
    ``server_url`` matches the token's unverified "iss" claim (trailing slashes ignored). Otherwise,
    for an opaque token, the only configured provider when there is exactly one.

    JWTs are verified against the provider's JWKS (signature, expiry, issuer; no audience check).
    Opaque tokens are resolved via get_user_info_from_token_details (userinfo / introspection).

    Returns:
        (user, expiry), where expiry is a numeric "exp" claim as a float, else None

    Raises:
        InvalidInputError: 404 for an unknown ``provider_name``; 401 when the provider cannot be
            determined or the token is invalid. An InvalidInputError whose ``error`` attribute is
            "inactive_external_token" is re-raised as is.
    """
    issuer: str | None = None
    token_is_jwt = self._is_jwt(token)

    if provider_name:
        provider = next((p for p in self.auth_providers if p.name == provider_name), None)
    elif token_is_jwt:
        issuer = self._get_issuer_from_token(token)
        if not issuer:
            raise InvalidInputError(401, "invalid_external_token", "Could not extract issuer from token")
        wanted_issuer = issuer.rstrip("/")
        provider = next(
            (p for p in self.auth_providers if p.provider_configs.server_url.rstrip("/") == wanted_issuer),
            None,
        )
    else:
        # Opaque token: only unambiguous when exactly one provider is configured
        provider = self.auth_providers[0] if len(self.auth_providers) == 1 else None

    if provider is None:
        if provider_name:
            raise InvalidInputError(404, "auth_provider_not_found", f"Provider '{provider_name}' not found")
        if token_is_jwt:
            raise InvalidInputError(401, "auth_provider_not_found", f"No provider found for issuer: {issuer}")
        raise InvalidInputError(401, "invalid_external_token", "Could not determine provider for external token")

    if token_is_jwt:
        try:
            claims = self._verify_provider_jwt(provider.name, token, purpose="external_token", verify_aud=False)
        except InvalidInputError:
            # Keep the external-auth error contract stable (no provider-specific details)
            raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")
    else:
        try:
            claims = self.get_user_info_from_token_details(provider.name, {"access_token": token})
        except InvalidInputError as e:
            if getattr(e, "error", None) == "inactive_external_token":
                raise
            raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")

    if not isinstance(claims, dict) or not claims:
        raise InvalidInputError(401, "invalid_external_token", "Invalid external authorization token")

    user = provider.provider_configs.get_user(claims)
    exp = claims.get("exp")
    expiry: float | None = float(exp) if isinstance(exp, (int, float)) else None
    return user, expiry
711
+
712
def get_all_api_keys(self, username: str) -> list[ApiKey]:
    """
    List the API keys owned by ``username`` whose ``expires_at`` is not before the database's
    current time.

    Each ApiKey is validated from a row of the API key table. The raw key itself is never stored
    there, so it cannot be returned.
    """
    session = self.Session()
    try:
        owned_and_unexpired = (
            self.DbApiKey.username == username,
            self.DbApiKey.expires_at >= func.now(),
        )
        rows = session.query(self.DbApiKey).filter(*owned_and_unexpired).all()
        return [ApiKey.model_validate(row) for row in rows]
    finally:
        session.close()
423
726
 
424
def revoke_api_key(self, username: str, api_key_id: str) -> None:
    """
    Delete the API key with id ``api_key_id`` owned by ``username``.

    Raises:
        InvalidInputError: 404 "api_key_not_found" when that user has no key with that id
    """
    session = self.Session()
    try:
        matching_key = (
            session.query(self.DbApiKey)
            .filter(self.DbApiKey.username == username, self.DbApiKey.id == api_key_id)
            .first()
        )
        if matching_key is None:
            raise InvalidInputError(404, "api_key_not_found", f"The API key could not be found: {api_key_id}")

        session.delete(matching_key)
        session.commit()
    finally:
        session.close()
439
746
 
440
def can_user_access_scope(self, user: AbstractUser, scope: PermissionScope) -> bool:
    """
    Tell whether ``user`` may access resources with the given permission scope.

    The user's level is PUBLIC for "guest", PRIVATE for "admin" and PROTECTED for any other access
    level (i.e. members). Access is granted when that level's value is at least ``scope.value``.
    """
    granted = (
        PermissionScope.PUBLIC if user.access_level == "guest"
        else PermissionScope.PRIVATE if user.access_level == "admin"
        else PermissionScope.PROTECTED
    )
    return granted.value >= scope.value
449
-
756
+
450
757
  def close(self) -> None:
451
- self.engine.dispose()
758
+ if hasattr(self, "engine"):
759
+ self.engine.dispose()
760
+
761
+
762
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider

    The decorated function is called with an AuthProviderArgs and returns the provider's
    ProviderConfigs. Registration happens at decoration time: the wrapper is appended to
    Authenticator.providers. Calling the wrapper builds the AuthProvider.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google'; not enforced here)
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')

    Returns:
        A decorator. The decorated name is rebound to the registered wrapper, which keeps the
        original function's name and docstring (functools.wraps).
    """
    import functools

    def decorator(func: Callable[[AuthProviderArgs], ProviderConfigs]):
        @functools.wraps(func)
        def wrapper(sqrl: AuthProviderArgs):
            provider_configs = func(sqrl)
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=provider_configs)
        Authenticator.providers.append(wrapper)
        return wrapper
    return decorator