squirrels 0.5.0b2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic; consult the package's page on its registry for more details.

Files changed (96)
  1. dateutils/__init__.py +6 -460
  2. dateutils/_enums.py +25 -0
  3. dateutils/_implementation.py +409 -0
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +9 -13
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +262 -0
  8. squirrels/_api_routes/base.py +154 -0
  9. squirrels/_api_routes/dashboards.py +142 -0
  10. squirrels/_api_routes/data_management.py +103 -0
  11. squirrels/_api_routes/datasets.py +242 -0
  12. squirrels/_api_routes/oauth2.py +300 -0
  13. squirrels/_api_routes/project.py +214 -0
  14. squirrels/_api_server.py +145 -748
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/{arguments → _arguments}/init_time_args.py +7 -2
  17. squirrels/{arguments → _arguments}/run_time_args.py +4 -26
  18. squirrels/_auth.py +646 -93
  19. squirrels/_connection_set.py +5 -5
  20. squirrels/_constants.py +7 -1
  21. squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
  22. squirrels/_data_sources.py +564 -0
  23. squirrels/_exceptions.py +9 -37
  24. squirrels/_initializer.py +31 -26
  25. squirrels/_manifest.py +5 -5
  26. squirrels/_model_builder.py +1 -1
  27. squirrels/_model_configs.py +2 -2
  28. squirrels/_model_queries.py +1 -1
  29. squirrels/_models.py +40 -27
  30. squirrels/{package_data → _package_data}/base_project/.env +1 -0
  31. squirrels/{package_data → _package_data}/base_project/.env.example +1 -0
  32. squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
  33. squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.yml +2 -2
  34. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  35. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
  36. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
  37. squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.sql +1 -1
  38. squirrels/_package_data/base_project/models/federates/federate_example.py +41 -0
  39. squirrels/_package_data/base_project/models/federates/federate_example.sql +25 -0
  40. squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +6 -6
  41. squirrels/{package_data → _package_data}/base_project/parameters.yml +9 -8
  42. squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
  43. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +14 -16
  44. squirrels/_package_data/base_project/pyconfigs/parameters.py +106 -0
  45. squirrels/_package_data/base_project/pyconfigs/user.py +51 -0
  46. squirrels/_package_data/templates/dataset_results.html +112 -0
  47. squirrels/_package_data/templates/oauth_login.html +271 -0
  48. squirrels/_parameter_configs.py +35 -35
  49. squirrels/_parameter_options.py +348 -0
  50. squirrels/_parameter_sets.py +47 -37
  51. squirrels/_parameters.py +1664 -0
  52. squirrels/_project.py +76 -32
  53. squirrels/_py_module.py +3 -2
  54. squirrels/_schemas/__init__.py +0 -0
  55. squirrels/_schemas/auth_models.py +144 -0
  56. squirrels/_schemas/query_param_models.py +67 -0
  57. squirrels/{_api_response_models.py → _schemas/response_models.py} +12 -8
  58. squirrels/_utils.py +38 -4
  59. squirrels/arguments.py +2 -0
  60. squirrels/auth.py +1 -0
  61. squirrels/connections.py +1 -0
  62. squirrels/dashboards.py +1 -82
  63. squirrels/data_sources.py +8 -563
  64. squirrels/parameter_options.py +8 -348
  65. squirrels/parameters.py +9 -1266
  66. squirrels/types.py +11 -0
  67. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/METADATA +4 -1
  68. squirrels-0.5.0b4.dist-info/RECORD +94 -0
  69. squirrels/package_data/base_project/macros/macros_example.sql +0 -15
  70. squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
  71. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
  72. squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
  73. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
  74. squirrels/package_data/base_project/pyconfigs/user.py +0 -23
  75. squirrels-0.5.0b2.dist-info/RECORD +0 -70
  76. /squirrels/{dataset_result.py → _dataset_types.py} +0 -0
  77. /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
  78. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  79. /squirrels/{package_data → _package_data}/base_project/connections.yml +0 -0
  80. /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
  81. /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
  82. /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
  83. /squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +0 -0
  84. /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
  85. /squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +0 -0
  86. /squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.yml +0 -0
  87. /squirrels/{package_data → _package_data}/base_project/models/sources.yml +0 -0
  88. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  89. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +0 -0
  90. /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
  91. /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +0 -0
  92. /squirrels/{package_data → _package_data}/base_project/squirrels.yml.j2 +0 -0
  93. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
  94. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/WHEEL +0 -0
  95. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/entry_points.txt +0 -0
  96. {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
squirrels/_auth.py CHANGED
@@ -1,18 +1,24 @@
1
+ from typing import Callable
1
2
  from datetime import datetime, timedelta, timezone
2
3
  from enum import Enum
3
4
  from functools import cached_property
4
5
  from jwt.exceptions import InvalidTokenError
5
6
  from passlib.context import CryptContext
6
- from pydantic import BaseModel, ConfigDict, ValidationError
7
+ from pydantic import ValidationError
7
8
  from pydantic_core import PydanticUndefined
8
9
  from sqlalchemy import create_engine, Engine, func, inspect, text, ForeignKey
9
10
  from sqlalchemy import Column, String, Integer, Float, Boolean
10
11
  from sqlalchemy.orm import declarative_base, sessionmaker, Mapped, mapped_column
11
- import jwt, types, typing as _t, uuid
12
+ import jwt, types, typing as _t, uuid, secrets, json
12
13
 
13
14
  from ._manifest import PermissionScope
14
15
  from ._py_module import PyModule
15
16
  from ._exceptions import InvalidInputError, ConfigurationError
17
+ from ._arguments.init_time_args import AuthProviderArgs
18
+ from ._schemas.auth_models import (
19
+ BaseUser, ApiKey, UserField, AuthProvider, ProviderConfigs, ClientRegistrationRequest, ClientUpdateRequest,
20
+ ClientDetailsResponse, ClientRegistrationResponse, ClientUpdateResponse, TokenResponse
21
+ )
16
22
  from . import _utils as u, _constants as c
17
23
 
18
24
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -20,41 +26,20 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
20
26
  reserved_fields = ["username", "is_admin"]
21
27
  disallowed_fields = ["password", "password_hash", "created_at", "token_id", "exp"]
22
28
 
23
- class BaseUser(BaseModel):
24
- model_config = ConfigDict(from_attributes=True)
25
- username: str
26
- is_admin: bool = False
27
-
28
- @classmethod
29
- def dropped_columns(cls):
30
- return []
31
-
32
- def __hash__(self):
33
- return hash(self.username)
34
-
35
29
  User = _t.TypeVar('User', bound=BaseUser)
36
30
 
37
- class AccessToken(BaseModel):
38
- model_config = ConfigDict(from_attributes=True)
39
- token_id: str
40
- title: str
41
- username: str
42
- created_at: datetime
43
- expires_at: datetime
44
-
45
-
46
- class UserField(BaseModel):
47
- name: str
48
- type: str
49
- nullable: bool
50
- enum: list[str] | None
51
- default: _t.Any | None
31
+ ProviderFunctionType = Callable[[AuthProviderArgs], AuthProvider]
52
32
 
53
33
 
54
34
  class Authenticator(_t.Generic[User]):
55
- def __init__(self, logger: u.Logger, base_path: str, env_vars: dict[str, str], *, sa_engine: Engine | None = None, cls: type[User] | None = None):
35
+ providers: list[ProviderFunctionType] = [] # static variable to stage providers
36
+
37
+ def __init__(
38
+ self, logger: u.Logger, base_path: str, auth_args: AuthProviderArgs, provider_functions: list[ProviderFunctionType],
39
+ user_cls: type[User], *, sa_engine: Engine | None = None
40
+ ):
56
41
  self.logger = logger
57
- self.env_vars = env_vars
42
+ self.env_vars = auth_args.env_vars
58
43
  self.secret_key = self.env_vars.get(c.SQRL_SECRET_KEY)
59
44
 
60
45
  # Create a new declarative base for this instance
@@ -69,27 +54,89 @@ class Authenticator(_t.Generic[User]):
69
54
  password_hash: Mapped[str] = mapped_column(nullable=False)
70
55
  created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
71
56
 
72
- # Define DbAccessToken class for this instance
73
- class DbAccessToken(self.Base):
74
- __tablename__ = 'access_tokens'
57
+ # Define DbApiKey class for this instance
58
+ class DbApiKey(self.Base):
59
+ __tablename__ = 'api_keys'
75
60
 
76
- token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: str(uuid.uuid4()))
61
+ id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
62
+ hashed_key: Mapped[str] = mapped_column(unique=True, nullable=False)
77
63
  title: Mapped[str] = mapped_column(nullable=False)
78
64
  username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
79
65
  created_at: Mapped[datetime] = mapped_column(nullable=False)
80
66
  expires_at: Mapped[datetime] = mapped_column(nullable=False)
81
-
67
+
68
+ def __repr__(self):
69
+ return f"<DbApiKey(id='{self.id}', username='{self.username}')>"
70
+
71
+ # Define DbOAuthClient class for this instance
72
+ class DbOAuthClient(self.Base):
73
+ __tablename__ = 'oauth_clients'
74
+
75
+ client_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
76
+ client_secret_hash: Mapped[str] = mapped_column(nullable=False)
77
+ client_name: Mapped[str] = mapped_column(nullable=False)
78
+ redirect_uris: Mapped[str] = mapped_column(nullable=False) # JSON array of allowed redirect URIs
79
+ scope: Mapped[str] = mapped_column(nullable=False, default='read')
80
+ grant_types: Mapped[str] = mapped_column(nullable=False, default='authorization_code,refresh_token')
81
+ response_types: Mapped[str] = mapped_column(nullable=False, default='code')
82
+ client_type: Mapped[str] = mapped_column(nullable=False, default='confidential') # 'confidential' or 'public'
83
+ registration_access_token_hash: Mapped[str] = mapped_column(nullable=False) # Token for client management
84
+ created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
85
+ is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
86
+
82
87
  def __repr__(self):
83
- return f"<AccessToken(token_id='{self.token_id}', username='{self.username}')>"
88
+ return f"<DbOAuthClient(client_id='{self.client_id}', name='{self.client_name}')>"
89
+
90
+ # Define DbAuthorizationCode class for this instance
91
+ class DbAuthorizationCode(self.Base):
92
+ __tablename__ = 'authorization_codes'
93
+
94
+ code: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
95
+ client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
96
+ username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
97
+ redirect_uri: Mapped[str] = mapped_column(nullable=False)
98
+ scope: Mapped[str] = mapped_column(nullable=True)
99
+ code_challenge: Mapped[str] = mapped_column(nullable=False) # PKCE always required
100
+ code_challenge_method: Mapped[str] = mapped_column(nullable=False) # only S256 is supported
101
+ created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
102
+ expires_at: Mapped[datetime] = mapped_column(nullable=False) # 10 minutes from creation
103
+ used: Mapped[bool] = mapped_column(nullable=False, default=False)
104
+
105
+ def __repr__(self):
106
+ return f"<DbAuthorizationCode(code='{self.code[:8]}...', client_id='{self.client_id}')>"
107
+
108
+ # Define DbOAuthToken class for this instance
109
+ class DbOAuthToken(self.Base):
110
+ __tablename__ = 'oauth_tokens'
111
+
112
+ token_id: Mapped[str] = mapped_column(primary_key=True, default=lambda: uuid.uuid4().hex)
113
+ access_token_hash: Mapped[str] = mapped_column(unique=True, nullable=False)
114
+ refresh_token_hash: Mapped[str] = mapped_column(unique=True, nullable=True) # NULL for client_credentials grants
115
+ client_id: Mapped[str] = mapped_column(ForeignKey('oauth_clients.client_id', ondelete='CASCADE'), nullable=False)
116
+ username: Mapped[str] = mapped_column(ForeignKey('users.username', ondelete='CASCADE'), nullable=False)
117
+ scope: Mapped[str] = mapped_column(nullable=True)
118
+ token_type: Mapped[str] = mapped_column(nullable=False, default='Bearer')
119
+ created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())
120
+ access_token_expires_at: Mapped[datetime] = mapped_column(nullable=False) # Uses SQRL_AUTH_TOKEN_EXPIRE_MINUTES
121
+ refresh_token_expires_at: Mapped[datetime] = mapped_column(nullable=True) # 30 days from creation, NULL for client_credentials
122
+ is_revoked: Mapped[bool] = mapped_column(nullable=False, default=False)
123
+
124
+ def __repr__(self):
125
+ return f"<DbOAuthToken(token_id='{self.token_id}', client_id='{self.client_id}', username='{self.username}')>"
84
126
 
85
127
  self.DbBaseUser = DbBaseUser
86
- self.DbAccessToken = DbAccessToken
128
+ self.DbApiKey = DbApiKey
129
+ self.DbOAuthClient = DbOAuthClient
130
+ self.DbAuthorizationCode = DbAuthorizationCode
131
+ self.DbOAuthToken = DbOAuthToken
87
132
 
88
- self.User = self._get_user_model(base_path) if cls is None else cls
133
+ self.User = user_cls
89
134
  self.DbUser: type[DbBaseUser] = self._initialize_db_user_model(self.User)
135
+
136
+ self.auth_providers = [provider_function(auth_args) for provider_function in provider_functions]
90
137
 
91
138
  if sa_engine is None:
92
- sqlite_relative_path = env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
139
+ sqlite_relative_path = self.env_vars.get(c.SQRL_AUTH_DB_FILE_PATH, f"{c.TARGET_FOLDER}/{c.DB_FILE}")
93
140
  sqlite_path = u.Path(base_path, sqlite_relative_path)
94
141
  sqlite_path.parent.mkdir(parents=True, exist_ok=True)
95
142
  self.engine = create_engine(f"sqlite:///{str(sqlite_path)}")
@@ -181,7 +228,7 @@ class Authenticator(_t.Generic[User]):
181
228
  new_columns = model_columns - existing_columns
182
229
  if new_columns:
183
230
  add_columns_msg = f"Adding columns to database: {new_columns}"
184
- print("NOTE:", add_columns_msg)
231
+ print("NOTE -", add_columns_msg)
185
232
  self.logger.info(add_columns_msg)
186
233
 
187
234
  for col_name in new_columns:
@@ -203,7 +250,7 @@ class Authenticator(_t.Generic[User]):
203
250
  columns_to_drop = dropped_columns.intersection(existing_columns)
204
251
  if columns_to_drop:
205
252
  drop_columns_msg = f"Dropping columns from database: {columns_to_drop}"
206
- print("NOTE:", drop_columns_msg)
253
+ print("NOTE -", drop_columns_msg)
207
254
  self.logger.info(drop_columns_msg)
208
255
 
209
256
  for col_name in columns_to_drop:
@@ -272,7 +319,7 @@ class Authenticator(_t.Generic[User]):
272
319
  try:
273
320
  user_data = self.User(**user_fields, username=username).model_dump(mode='json')
274
321
  except ValidationError as e:
275
- raise InvalidInputError(102, f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
322
+ raise InvalidInputError(400, "Invalid user data", f"Invalid user field '{e.errors()[0]['loc'][0]}': {e.errors()[0]['msg']}")
276
323
 
277
324
  # Add a new user
278
325
  try:
@@ -280,19 +327,20 @@ class Authenticator(_t.Generic[User]):
280
327
  existing_user = session.get(self.DbUser, username)
281
328
  if existing_user is not None:
282
329
  if not update_user:
283
- raise InvalidInputError(101, f"User '{username}' already exists")
330
+ raise InvalidInputError(400, "Username already exists", f"User '{username}' already exists")
331
+
332
+ if username == c.ADMIN_USERNAME and user_data.get("is_admin") is False:
333
+ raise InvalidInputError(403, "Non-admin 'admin' user not permitted", "Setting the admin user to non-admin is not permitted")
284
334
 
285
- if username == c.ADMIN_USERNAME:
286
- raise InvalidInputError(24, "Changing the admin user is not permitted")
287
335
  new_user = self.DbUser(password_hash=existing_user.password_hash, **user_data)
288
336
  session.delete(existing_user)
289
337
  else:
290
338
  if update_user:
291
- raise InvalidInputError(41, f"No user found for username: {username}")
339
+ raise InvalidInputError(404, "No user found for username", f"No user found for username: {username}")
292
340
 
293
341
  password = user_fields.get('password')
294
342
  if password is None:
295
- raise InvalidInputError(100, f"Missing required field 'password' when adding a new user")
343
+ raise InvalidInputError(400, "Missing required field 'password'", f"Missing required field 'password' when adding a new user")
296
344
  password_hash = pwd_context.hash(password)
297
345
  new_user = self.DbUser(password_hash=password_hash, **user_data)
298
346
 
@@ -304,6 +352,28 @@ class Authenticator(_t.Generic[User]):
304
352
 
305
353
  finally:
306
354
  session.close()
355
+
356
+ def create_or_get_user_from_provider(self, provider_name: str, user_info: dict) -> User:
357
+ provider = next((p for p in self.auth_providers if p.name == provider_name), None)
358
+ if provider is None:
359
+ raise InvalidInputError(404, "Provider not found", f"Provider '{provider_name}' not found")
360
+
361
+ user = provider.provider_configs.get_user(user_info)
362
+ session = self.Session()
363
+ try:
364
+ db_user = session.get(self.DbUser, user.username)
365
+ if db_user is None:
366
+ # Create new user
367
+ user_data = user.model_dump()
368
+ password_hash = "" # No hash makes it impossible to login with username and password
369
+ db_user = self.DbUser(password_hash=password_hash, **user_data)
370
+ session.add(db_user)
371
+ session.commit()
372
+
373
+ return self.User.model_validate(db_user)
374
+
375
+ finally:
376
+ session.close()
307
377
 
308
378
  def get_user(self, username: str, password: str) -> User:
309
379
  session = self.Session()
@@ -315,7 +385,7 @@ class Authenticator(_t.Generic[User]):
315
385
  user = self.User.model_validate(db_user)
316
386
  return user # type: ignore
317
387
  else:
318
- raise InvalidInputError(0, f"Username or password not found")
388
+ raise InvalidInputError(401, "Incorrect username or password", f"Incorrect username or password")
319
389
 
320
390
  finally:
321
391
  session.close()
@@ -325,114 +395,132 @@ class Authenticator(_t.Generic[User]):
325
395
  try:
326
396
  db_user = session.get(self.DbUser, username)
327
397
  if db_user is None:
328
- raise InvalidInputError(2, f"User not found")
398
+ raise InvalidInputError(401, "User not found", f"Username '{username}' not found for password change")
329
399
 
330
- if pwd_context.verify(old_password, db_user.password_hash):
400
+ if db_user.password_hash and pwd_context.verify(old_password, db_user.password_hash):
331
401
  db_user.password_hash = pwd_context.hash(new_password)
332
402
  session.commit()
333
403
  else:
334
- raise InvalidInputError(3, f"Incorrect password")
404
+ raise InvalidInputError(401, "Incorrect password", f"Incorrect password")
335
405
  finally:
336
406
  session.close()
337
407
 
338
408
  def delete_user(self, username: str) -> None:
339
409
  if username == c.ADMIN_USERNAME:
340
- raise InvalidInputError(23, "Cannot delete the admin user")
410
+ raise InvalidInputError(403, "Cannot delete admin user", "Cannot delete the admin user")
341
411
 
342
412
  session = self.Session()
343
413
  try:
344
414
  db_user = session.get(self.DbUser, username)
345
415
  if db_user is None:
346
- raise InvalidInputError(41, f"No user found for username: {username}")
416
+ raise InvalidInputError(404, "No user found for username", f"No user found for username: {username}")
347
417
  session.delete(db_user)
348
418
  session.commit()
349
419
  finally:
350
420
  session.close()
351
421
 
352
- def get_all_users(self) -> list[User]:
422
+ def get_all_users(self) -> list:
353
423
  session = self.Session()
354
424
  try:
355
425
  db_users = session.query(self.DbUser).all()
356
- return [self.User.model_validate(user) for user in db_users] # type: ignore
426
+ return [self.User.model_validate(user) for user in db_users]
357
427
  finally:
358
428
  session.close()
359
429
 
360
430
  def create_access_token(self, user: User, expiry_minutes: int | None, *, title: str | None = None) -> tuple[str, datetime]:
431
+ """
432
+ Creates an API key if title is provided. Otherwise, creates a JWT token.
433
+ """
361
434
  created_at = datetime.now(timezone.utc)
362
435
  expire_at = created_at + timedelta(minutes=expiry_minutes) if expiry_minutes is not None else datetime.max
363
- token_id = None
436
+
437
+ if self.secret_key is None:
438
+ raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
439
+
364
440
  if title is not None:
365
441
  session = self.Session()
366
442
  try:
367
- access_token = self.DbAccessToken(title=title, username=user.username, created_at=created_at, expires_at=expire_at)
368
- session.add(access_token)
443
+ token_id = "sqrl-" + uuid.uuid4().hex
444
+ hashed_key = u.hash_string(token_id, salt=self.secret_key)
445
+ api_key = self.DbApiKey(hashed_key=hashed_key, title=title, username=user.username, created_at=created_at, expires_at=expire_at)
446
+ session.add(api_key)
369
447
  session.commit()
370
- token_id = access_token.token_id
371
448
  finally:
372
449
  session.close()
450
+ else:
451
+ to_encode = {"username": user.username, "exp": expire_at}
452
+ token_id = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
373
453
 
374
- if self.secret_key is None:
375
- raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to create an access token")
376
- to_encode = {"username": user.username, "token_id": token_id, "exp": expire_at}
377
- encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm="HS256")
378
- return encoded_jwt, expire_at
454
+ return token_id, expire_at
379
455
 
380
456
  def get_user_from_token(self, token: str | None) -> User | None:
381
- if token is None or token == "":
457
+ """
458
+ Get a user from an access token (JWT, or API key if token starts with 'sqrl-')
459
+ """
460
+ if not token:
382
461
  return None
383
462
 
384
463
  if self.secret_key is None:
385
464
  raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required to get user from an access token")
386
465
 
387
- try:
388
- payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
389
- except InvalidTokenError:
390
- raise InvalidInputError(1, "Invalid authorization token")
391
-
392
466
  session = self.Session()
393
467
  try:
394
- if payload.get("token_id") is not None:
395
- access_token = session.query(self.DbAccessToken).filter(
396
- self.DbAccessToken.username == payload["username"],
397
- self.DbAccessToken.token_id == payload["token_id"],
398
- self.DbAccessToken.expires_at >= func.now()
468
+ if token.startswith("sqrl-"):
469
+ hashed_key = u.hash_string(token, salt=self.secret_key)
470
+ api_key = session.query(self.DbApiKey).filter(
471
+ self.DbApiKey.hashed_key == hashed_key,
472
+ self.DbApiKey.expires_at >= func.now()
399
473
  ).first()
400
- if access_token is None:
401
- raise InvalidInputError(1, "Invalid authorization token")
402
-
403
- db_user = session.get(self.DbUser, payload["username"])
474
+ if api_key is None:
475
+ raise InvalidTokenError()
476
+ username = api_key.username
477
+ else:
478
+ payload: dict = jwt.decode(token, self.secret_key, algorithms=["HS256"])
479
+ username = payload["username"]
480
+
481
+ db_user = session.get(self.DbUser, username)
404
482
  if db_user is None:
405
- raise InvalidInputError(1, "Invalid authorization token")
483
+ raise InvalidTokenError()
484
+
485
+ except InvalidTokenError:
486
+ raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
406
487
  finally:
407
488
  session.close()
408
489
 
409
490
  user = self.User.model_validate(db_user)
410
491
  return user # type: ignore
411
492
 
412
- def get_all_tokens(self, username: str) -> list[AccessToken]:
493
+ def get_all_api_keys(self, username: str) -> list[ApiKey]:
494
+ """
495
+ Get the ID, title, and expiry date of all API keys for a user. Note that the ID is a hash of the API key, not the API key itself.
496
+ """
413
497
  session = self.Session()
414
498
  try:
415
- tokens = session.query(self.DbAccessToken).filter(
416
- self.DbAccessToken.username == username,
417
- self.DbAccessToken.expires_at >= func.now()
499
+ tokens = session.query(self.DbApiKey).filter(
500
+ self.DbApiKey.username == username,
501
+ self.DbApiKey.expires_at >= func.now()
418
502
  ).all()
419
503
 
420
- return [AccessToken.model_validate(token) for token in tokens]
504
+ return [ApiKey.model_validate(token) for token in tokens]
421
505
  finally:
422
506
  session.close()
423
507
 
424
- def revoke_token(self, username: str, token_id: str) -> None:
508
+ def revoke_api_key(self, username: str, api_key_id: str) -> None:
509
+ """
510
+ Revoke an API key
511
+ """
425
512
  session = self.Session()
426
513
  try:
427
- access_token = session.query(self.DbAccessToken).filter(
428
- self.DbAccessToken.username == username,
429
- self.DbAccessToken.token_id == token_id
514
+
515
+ api_key = session.query(self.DbApiKey).filter(
516
+ self.DbApiKey.username == username,
517
+ self.DbApiKey.id == api_key_id
430
518
  ).first()
431
519
 
432
- if access_token is None:
433
- raise InvalidInputError(40, f"No access token found for token_id: {token_id}")
520
+ if api_key is None:
521
+ raise InvalidInputError(404, "API key not found", f"The API key could not be found: {api_key_id}")
434
522
 
435
- session.delete(access_token)
523
+ session.delete(api_key)
436
524
  session.commit()
437
525
  finally:
438
526
  session.close()
@@ -447,5 +535,470 @@ class Authenticator(_t.Generic[User]):
447
535
 
448
536
  return user_level.value >= scope.value
449
537
 
538
# OAuth Client Management Methods

def generate_secret_and_hash(self) -> tuple[str, str]:
    """
    Create a random URL-safe secret and return it together with its pwd_context hash.

    Used for client secrets, registration access tokens, and refresh tokens. Callers persist
    only the hash and hand the plaintext secret back to the client once.

    Returns:
        (plaintext_secret, secret_hash)
    """
    plaintext = secrets.token_urlsafe(64)
    return plaintext, pwd_context.hash(plaintext)
545
+
546
def _validate_client_registration_request(self, request: ClientRegistrationRequest | ClientUpdateRequest) -> dict:
    """
    Validate client metadata from a registration/update request and map it to storage values.

    Only fields that are set (truthy) on the request are validated and included, so the result
    can be applied directly as a partial update of a DbOAuthClient.

    Returns:
        dict of DbOAuthClient attribute names to storage-ready values: ``redirect_uris`` as a JSON
        array string, ``grant_types``/``response_types`` as comma-separated strings, and ``scope``
        as a space-separated string.

    Raises:
        InvalidInputError: (400) for a malformed redirect URI, or an unsupported grant type,
            response type, or scope.
    """
    updates = {}
    if request.client_name:
        updates['client_name'] = request.client_name

    # Validate redirect_uris if being updated
    if request.redirect_uris:
        for uri in request.redirect_uris:
            if not self._validate_redirect_uri_format(uri):
                raise InvalidInputError(400, "invalid_redirect_uri", f"Invalid redirect URI format: {uri}")
        updates['redirect_uris'] = json.dumps(request.redirect_uris)

    # Validate grant_types if being updated
    if request.grant_types:
        if not all(grant_type in c.SUPPORTED_GRANT_TYPES for grant_type in request.grant_types):
            raise InvalidInputError(400, "invalid_grant_types", f"Invalid grant types. Supported grant types are: {c.SUPPORTED_GRANT_TYPES}")
        updates['grant_types'] = ','.join(request.grant_types)

    # Validate response_types if being updated
    if request.response_types:
        if not all(response_type in c.SUPPORTED_RESPONSE_TYPES for response_type in request.response_types):
            raise InvalidInputError(400, "invalid_response_types", f"Invalid response types. Supported response types are: {c.SUPPORTED_RESPONSE_TYPES}")
        updates['response_types'] = ','.join(request.response_types)

    # Validate scope if being updated
    if request.scope:
        scopes = request.scope.split()
        if not all(scope in c.SUPPORTED_SCOPES for scope in scopes):
            raise InvalidInputError(400, "invalid_scope", f"Invalid scope. Supported scopes are: {c.SUPPORTED_SCOPES}")
        # OAuth scope strings are space-delimited (RFC 6749 section 3.3), and register_oauth_client
        # stores the request's scope string as-is. Joining with spaces keeps updated clients in that
        # same format; joining with commas made the stored/returned scope differ after an update.
        updates['scope'] = ' '.join(scopes)

    return updates
578
+
579
def register_oauth_client(
    self, request: ClientRegistrationRequest, client_management_path_format: str
) -> ClientRegistrationResponse:
    """
    Register a new OAuth client.

    Arguments:
        request: The client registration metadata. grant_types defaults to
            ['authorization_code', 'refresh_token'] when not given.
        client_management_path_format: Format string with a ``{client_id}`` placeholder, used to
            build the ``registration_client_uri`` returned to the caller.

    Returns:
        ClientRegistrationResponse carrying the plaintext client_secret and
        registration_access_token. Only their hashes are stored, so this is the only time the
        plaintext values are available.

    Raises:
        InvalidInputError: (400) if the request metadata fails validation.
    """
    grant_types = request.grant_types
    if grant_types is None:
        grant_types = ['authorization_code', 'refresh_token']

    # Raises on invalid metadata; the returned mapping is not needed for a fresh registration
    self._validate_client_registration_request(request)

    # Generate secure client credentials and registration access token
    client_id = secrets.token_urlsafe(16)
    client_secret, client_secret_hash = self.generate_secret_and_hash()
    registration_access_token, registration_access_token_hash = self.generate_secret_and_hash()
    registration_client_uri = client_management_path_format.format(client_id=client_id)

    client_columns = dict(
        client_id=client_id,
        client_secret_hash=client_secret_hash,
        client_name=request.client_name,
        redirect_uris=json.dumps(request.redirect_uris),
        scope=request.scope,
        grant_types=','.join(grant_types),
        registration_access_token_hash=registration_access_token_hash,
    )
    # The validated response_types used to be echoed back in the response but never stored, which
    # left the column at whatever DbOAuthClient defaults it to (not visible here). Persist them
    # comma-separated, the format get_oauth_client_details splits on, when the client provided any.
    if request.response_types:
        client_columns['response_types'] = ','.join(request.response_types)

    session = self.Session()
    try:
        session.add(self.DbOAuthClient(**client_columns))
        session.commit()

        return ClientRegistrationResponse(
            client_id=client_id,
            client_secret=client_secret,
            client_name=request.client_name,
            redirect_uris=request.redirect_uris,
            scope=request.scope,
            grant_types=grant_types,
            response_types=request.response_types,
            created_at=datetime.now(timezone.utc),
            is_active=True,
            registration_client_uri=registration_client_uri,
            registration_access_token=registration_access_token,
        )

    finally:
        session.close()
627
+
628
def get_oauth_client_details(self, client_id: str) -> ClientDetailsResponse:
    """
    Look up an active OAuth client and return its metadata with stored fields decoded.

    redirect_uris is decoded from its JSON string; grant_types and response_types are split on
    commas back into lists.

    Raises:
        InvalidInputError: (404) if the client does not exist or has been deactivated.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "invalid_client_id", "Client not found")
        if not record.is_active:
            raise InvalidInputError(404, "invalid_client_id", "Client is no longer active")

        # Build the response while the session is still open so attribute loads succeed
        details = ClientDetailsResponse(
            client_id=record.client_id,
            client_name=record.client_name,
            redirect_uris=json.loads(record.redirect_uris),
            scope=record.scope,
            grant_types=record.grant_types.split(','),
            response_types=record.response_types.split(','),
            created_at=record.created_at,
            is_active=record.is_active,
        )
    finally:
        session.close()
    return details
650
+
651
def validate_client_credentials(self, client_id: str, client_secret: str) -> bool:
    """
    Check a client_id / client_secret pair.

    Returns:
        False for an unknown or deactivated client; otherwise whether ``client_secret`` matches
        the stored client secret hash.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is not None and record.is_active:
            return pwd_context.verify(client_secret, record.client_secret_hash)
        return False
    finally:
        session.close()
661
+
662
def validate_redirect_uri(self, client_id: str, redirect_uri: str) -> bool:
    """
    Check that ``redirect_uri`` is one of the URIs registered for an active client.

    The comparison is exact membership in the client's stored (JSON-encoded) redirect URI list.
    Unknown or deactivated clients always yield False.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is not None and record.is_active:
            return redirect_uri in json.loads(record.redirect_uris)
        return False
    finally:
        session.close()
674
+
675
def validate_registration_access_token(self, client_id: str, registration_access_token: str) -> bool:
    """
    Check the registration access token presented for client management operations.

    Returns False for an unknown client; otherwise whether the token matches the stored hash.
    The client's is_active flag is deliberately not consulted here.
    NOTE(review): presumably so a deactivated client can still be reactivated through
    update_oauth_client_with_token_rotation; confirm.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        return record is not None and pwd_context.verify(
            registration_access_token, record.registration_access_token_hash
        )
    finally:
        session.close()
686
+
687
def _validate_redirect_uri_format(self, uri: str) -> bool:
    """
    Check that a redirect URI is safe to register.

    Rules:
      - no fragment ('#') anywhere, and an explicit "scheme://" is required;
      - plain http is allowed only when the parsed host is a loopback host
        (localhost, 127.0.0.1, ::1);
      - https is allowed for any host;
      - other (custom, e.g. mobile-app "myapp://callback") schemes are allowed, except
        script-capable ones such as javascript, data, and vbscript.

    Returns:
        True if the URI passes these checks, False otherwise.
    """
    from urllib.parse import urlsplit

    if '#' in uri or '://' not in uri:
        return False

    # Scheme is everything before the first ':'. Compare it case-insensitively and ignore surrounding
    # whitespace so "HTTP://evil.com" cannot slip through the custom-scheme branch below.
    scheme = uri.partition(':')[0].strip().lower()

    if scheme == 'http':
        # Only allow http for localhost/development. Compare the parsed hostname exactly: the old
        # substring test also accepted "http://localhost.evil.com", "http://localhost@evil.com",
        # and "http://evil.com/?next=localhost".
        try:
            host = urlsplit(uri.strip()).hostname
        except ValueError:  # e.g. malformed IPv6 literal such as "http://[::1"
            return False
        return host in ('localhost', '127.0.0.1', '::1')

    if scheme in ('javascript', 'data', 'vbscript'):
        # A "redirect" to one of these would execute content in the user's browser
        return False

    # https for any host, or a custom scheme for native/mobile apps
    return True
703
+
704
def update_oauth_client_with_token_rotation(self, client_id: str, request: ClientUpdateRequest) -> ClientUpdateResponse:
    """
    Apply a metadata update to an OAuth client and issue a fresh registration access token.

    Only fields set on the request are changed. A deactivated client can be updated only when the
    request sets a truthy is_active, i.e. reactivates it. Every successful call replaces the stored
    registration access token hash, so the previously issued token no longer validates.

    Raises:
        InvalidInputError: (400) for invalid metadata or for updating a deactivated client without
            reactivating it; (404) if the client does not exist.
    """
    # Validate first so bad metadata is rejected before touching the database
    changes = self._validate_client_registration_request(request)

    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "invalid_client_id", f"OAuth client not found: {client_id}")
        if not (record.is_active or request.is_active):
            raise InvalidInputError(400, "invalid_request", "Cannot update a deactivated client")

        if request.is_active is not None:
            record.is_active = request.is_active

        fresh_token, fresh_token_hash = self.generate_secret_and_hash()

        # Only assign keys that exist as attributes on the DB model
        for column, value in changes.items():
            if hasattr(record, column):
                setattr(record, column, value)

        record.registration_access_token_hash = fresh_token_hash
        session.commit()

        # Built inside the try so attributes can still be loaded from the open session after commit
        return ClientUpdateResponse(
            client_id=record.client_id,
            client_name=record.client_name,
            redirect_uris=json.loads(record.redirect_uris),
            scope=record.scope,
            grant_types=record.grant_types.split(','),
            response_types=record.response_types.split(','),
            created_at=record.created_at,
            is_active=record.is_active,
            registration_access_token=fresh_token,
        )
    finally:
        session.close()
748
+
749
def revoke_oauth_client(self, client_id: str) -> None:
    """
    Deactivate an OAuth client by setting is_active to False; the row itself is kept.

    NOTE(review): tokens already issued to the client are not revoked here. Refresh and revocation
    requests are still rejected afterwards because those paths check the client's is_active flag.

    Raises:
        InvalidInputError: (404) if no client exists with this ID.
    """
    session = self.Session()
    try:
        record = session.get(self.DbOAuthClient, client_id)
        if record is None:
            raise InvalidInputError(404, "client_not_found", f"OAuth client not found: {client_id}")
        record.is_active = False
        session.commit()
    finally:
        session.close()
761
+
762
# Authorization Code Flow Methods

def create_authorization_code(
    self, client_id: str, username: str, redirect_uri: str, scope: str,
    code_challenge: str, code_challenge_method: str = "S256"
) -> str:
    """
    Create and store a single-use authorization code for the authorization code flow with PKCE.

    The code expires 10 minutes after creation and is stored unused. exchange_authorization_code
    later checks its client_id, redirect_uri, and code_challenge.

    Arguments:
        client_id: The requesting OAuth client; must exist and be active.
        username: The user granting authorization.
        redirect_uri: Must be one of the client's registered redirect URIs.
        scope: The requested scope, stored as given.
        code_challenge: PKCE code challenge (required).
        code_challenge_method: Only "S256" is supported.

    Returns:
        The generated authorization code.

    Raises:
        InvalidInputError: (400) "invalid_client" for an unknown or inactive client,
            "invalid_redirect_uri" if redirect_uri is not registered for the client, and
            "invalid_request" for a missing code_challenge or a method other than S256.
    """
    # get_oauth_client_details never returns None or an inactive client; it raises a 404
    # InvalidInputError instead, so the old `is None / not is_active` check here was unreachable.
    # Translate that failure into the intended 400 invalid_client error.
    try:
        self.get_oauth_client_details(client_id)
    except InvalidInputError as e:
        raise InvalidInputError(400, "invalid_client", "Invalid or inactive client") from e

    # Validate redirect_uri
    if not self.validate_redirect_uri(client_id, redirect_uri):
        raise InvalidInputError(400, "invalid_redirect_uri", "Invalid redirect_uri for this client")

    # Validate PKCE parameters
    if not code_challenge:
        raise InvalidInputError(400, "invalid_request", "code_challenge is required")

    if code_challenge_method not in ["S256"]:
        raise InvalidInputError(400, "invalid_request", "code_challenge_method must be 'S256'")

    # Generate authorization code
    code = secrets.token_urlsafe(48)

    # Set expiration (10 minutes from now)
    created_at = datetime.now(timezone.utc)
    expires_at = created_at + timedelta(minutes=10)

    session = self.Session()
    try:
        auth_code = self.DbAuthorizationCode(
            code=code,
            client_id=client_id,
            username=username,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            created_at=created_at,
            expires_at=expires_at,
            used=False
        )
        session.add(auth_code)
        session.commit()

        return code

    finally:
        session.close()
814
+
815
def exchange_authorization_code(
    self, code: str, client_id: str, redirect_uri: str, code_verifier: str, access_token_expiry_minutes: int,
    refresh_token_expiry_days: int = 30
) -> TokenResponse:
    """
    Exchange an authorization code for an access token and a refresh token.

    The code must belong to ``client_id``, be unexpired and unused, have been issued for the same
    ``redirect_uri``, and pass the PKCE check against ``code_verifier``. On success the code is
    marked as used, and a DbOAuthToken row storing only hashes of both tokens is committed.

    Arguments:
        code: The authorization code issued by create_authorization_code.
        client_id: The client redeeming the code.
        redirect_uri: Must equal the redirect URI the code was created with.
        code_verifier: PKCE verifier checked against the stored code_challenge.
        access_token_expiry_minutes: Lifetime of the new access token.
        refresh_token_expiry_days: Lifetime of the new refresh token (default 30).

    Returns:
        TokenResponse with the access token, ``expires_in`` in seconds, and the refresh token.

    Raises:
        InvalidInputError: (400, "invalid_grant") if any of the checks above fail, or if the
            code's user no longer exists.
    """
    session = self.Session()
    try:
        # Get and validate authorization code
        auth_code = session.query(self.DbAuthorizationCode).filter(
            self.DbAuthorizationCode.code == code,
            self.DbAuthorizationCode.client_id == client_id,
            self.DbAuthorizationCode.expires_at >= func.now(),
            self.DbAuthorizationCode.used == False
        ).first()

        if auth_code is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid authorization code")

        # Validate redirect URI
        if auth_code.redirect_uri != redirect_uri:
            raise InvalidInputError(400, "invalid_grant", "Redirect URI mismatch")

        if not u.validate_pkce_challenge(code_verifier, auth_code.code_challenge):
            raise InvalidInputError(400, "invalid_grant", "Invalid code_verifier")

        # Get user
        user = session.get(self.DbUser, auth_code.username)
        if user is None:
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Mark used so the `used == False` filter above rejects any later redemption
        auth_code.used = True

        # Generate tokens
        user_obj = self.User.model_validate(user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate refresh token
        refresh_token, refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=refresh_token_expiry_days)

        oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=refresh_token_hash,
            client_id=client_id,
            username=auth_code.username,
            scope=auth_code.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=refresh_token
        )

    finally:
        session.close()
880
+
881
def refresh_oauth_access_token(
    self, refresh_token: str, client_id: str, access_token_expiry_minutes: int,
    refresh_token_expiry_days: int = 30
) -> TokenResponse:
    """
    Exchange a refresh token for a new access token, rotating the refresh token.

    The presented refresh token must belong to an active ``client_id``, be unexpired, and not be
    revoked. On success it is marked revoked and a new DbOAuthToken row holding hashes of the new
    access and refresh tokens is committed, so each refresh token can be used only once.

    Arguments:
        refresh_token: The plaintext refresh token presented by the client.
        client_id: The client performing the refresh.
        access_token_expiry_minutes: Lifetime of the new access token.
        refresh_token_expiry_days: Lifetime of the new refresh token (default 30).

    Returns:
        TokenResponse with the new access token, ``expires_in`` in seconds, and the new refresh token.

    Raises:
        InvalidInputError: (400) "invalid_client" for an unknown or inactive client; "invalid_grant"
            for an unknown, expired, or revoked refresh token, or when the token's user no longer exists.
        ConfigurationError: if ``self.secret_key`` is not configured.
    """
    session = self.Session()
    try:
        # Validate client
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Refresh tokens are stored only as pwd_context hashes, so they must be verified rather than
        # looked up. Check the presented token against every live token of this client, as
        # revoke_oauth_token does. The previous `.first()` only checked one row, so refreshing
        # failed whenever the client had more than one active token (several users or sessions).
        candidates = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False
        ).all()
        oauth_token = next(
            (
                candidate for candidate in candidates
                if pwd_context.verify(refresh_token, candidate.refresh_token_hash)
            ),
            None,
        )
        if oauth_token is None:
            raise InvalidInputError(400, "invalid_grant", "Invalid or expired refresh token")

        # Get user
        user = session.get(self.DbUser, oauth_token.username)
        if user is None:
            # The grant (refresh token) is no longer usable, which is what invalid_grant denotes;
            # this also matches exchange_authorization_code
            raise InvalidInputError(400, "invalid_grant", "User not found")

        # Check secret key is available
        if self.secret_key is None:
            raise ConfigurationError(f"Environment variable '{c.SQRL_SECRET_KEY}' is required for OAuth token operations")

        # Generate new tokens
        user_obj = self.User.model_validate(user)
        access_token, token_expires_at = self.create_access_token(user_obj, expiry_minutes=access_token_expiry_minutes)
        access_token_hash = pwd_context.hash(access_token)

        # Generate new refresh token
        new_refresh_token, new_refresh_token_hash = self.generate_secret_and_hash()
        refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=refresh_token_expiry_days)

        # Revoke old token (refresh token rotation)
        oauth_token.is_revoked = True

        # Create new token entry
        new_oauth_token = self.DbOAuthToken(
            access_token_hash=access_token_hash,
            refresh_token_hash=new_refresh_token_hash,
            client_id=client_id,
            username=oauth_token.username,
            scope=oauth_token.scope,
            access_token_expires_at=token_expires_at,
            refresh_token_expires_at=refresh_expires_at
        )
        session.add(new_oauth_token)

        session.commit()

        return TokenResponse(
            access_token=access_token,
            expires_in=access_token_expiry_minutes*60,
            refresh_token=new_refresh_token
        )

    finally:
        session.close()
947
+
948
def revoke_oauth_token(self, client_id: str, token: str, token_type_hint: str | None = None) -> None:
    """
    Revoke one of a client's refresh tokens.

    Only refresh tokens are supported: a token_type_hint other than 'refresh_token' is rejected,
    and revoking access tokens is not supported yet. A token that matches no live refresh token
    of the client is silently ignored, because the OAuth spec expects success in that case.

    Raises:
        InvalidInputError: (400) for an unsupported token_type_hint or an unknown/inactive client.
    """
    if token_type_hint and token_type_hint != 'refresh_token':
        raise InvalidInputError(400, "invalid_request", "Only refresh tokens can be revoked")

    session = self.Session()
    try:
        client = session.get(self.DbOAuthClient, client_id)
        if client is None or not client.is_active:
            raise InvalidInputError(400, "invalid_client", "Invalid or inactive client")

        # Tokens are stored hashed, so verify against each live candidate of this client
        live_tokens = session.query(self.DbOAuthToken).filter(
            self.DbOAuthToken.client_id == client_id,
            self.DbOAuthToken.refresh_token_expires_at >= func.now(),
            self.DbOAuthToken.is_revoked == False
        ).all()
        match = next(
            (row for row in live_tokens if pwd_context.verify(token, row.refresh_token_hash)),
            None,
        )

        if match:
            match.is_revoked = True
            session.commit()
    finally:
        session.close()
984
+
def close(self) -> None:
    """Dispose of ``self.engine``, releasing the database connections it holds in its pool."""
    engine = self.engine
    engine.dispose()
987
+
988
+
989
def provider(name: str, label: str, icon: str):
    """
    Decorator to register an authentication provider.

    The decorated function takes an AuthProviderArgs and returns ProviderConfigs. It is replaced by a
    function taking the same argument that wraps those configs in an AuthProvider. That replacement
    is appended to Authenticator.providers as soon as the decorator is applied.

    Arguments:
        name: The name of the provider (must be unique, e.g. 'google'). Uniqueness is not checked here.
        label: The label of the provider (e.g. 'Google')
        icon: The URL of the icon of the provider (e.g. 'https://www.google.com/favicon.ico')
    """
    def register(build_configs: Callable[[AuthProviderArgs], ProviderConfigs]):
        def make_provider(sqrl: AuthProviderArgs):
            return AuthProvider(name=name, label=label, icon=icon, provider_configs=build_configs(sqrl))

        Authenticator.providers.append(make_provider)
        return make_provider

    return register