auth-middleware 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23) hide show
  1. auth_middleware-0.1.0/LICENSE +21 -0
  2. auth_middleware-0.1.0/PKG-INFO +98 -0
  3. auth_middleware-0.1.0/README.md +79 -0
  4. auth_middleware-0.1.0/pyproject.toml +31 -0
  5. auth_middleware-0.1.0/src/auth_middleware/__init__.py +20 -0
  6. auth_middleware-0.1.0/src/auth_middleware/exceptions.py +5 -0
  7. auth_middleware-0.1.0/src/auth_middleware/functions.py +33 -0
  8. auth_middleware-0.1.0/src/auth_middleware/group_checker.py +34 -0
  9. auth_middleware-0.1.0/src/auth_middleware/jwt_auth_middleware.py +117 -0
  10. auth_middleware-0.1.0/src/auth_middleware/jwt_auth_provider.py +67 -0
  11. auth_middleware-0.1.0/src/auth_middleware/jwt_bearer_manager.py +83 -0
  12. auth_middleware-0.1.0/src/auth_middleware/logging.py +33 -0
  13. auth_middleware-0.1.0/src/auth_middleware/providers/__init__.py +0 -0
  14. auth_middleware-0.1.0/src/auth_middleware/providers/cognito/__init__.py +0 -0
  15. auth_middleware-0.1.0/src/auth_middleware/providers/cognito/cognito_provider.py +93 -0
  16. auth_middleware-0.1.0/src/auth_middleware/providers/cognito/exceptions.py +8 -0
  17. auth_middleware-0.1.0/src/auth_middleware/providers/cognito/settings.py +44 -0
  18. auth_middleware-0.1.0/src/auth_middleware/providers/entra_id/__init__.py +0 -0
  19. auth_middleware-0.1.0/src/auth_middleware/providers/entra_id/entra_id_provider.py +125 -0
  20. auth_middleware-0.1.0/src/auth_middleware/providers/entra_id/exceptions.py +8 -0
  21. auth_middleware-0.1.0/src/auth_middleware/providers/entra_id/settings.py +33 -0
  22. auth_middleware-0.1.0/src/auth_middleware/settings.py +33 -0
  23. auth_middleware-0.1.0/src/auth_middleware/types.py +62 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Press Any Key
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,98 @@
1
+ Metadata-Version: 2.1
2
+ Name: auth-middleware
3
+ Version: 0.1.0
4
+ Summary: Async Auth Middleware for FastAPI/Starlette
5
+ Author: impalah
6
+ Author-email: impalah@gmail.com
7
+ Requires-Python: >=3.12,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.12
10
+ Requires-Dist: colorlog (>=6.8.0,<7.0.0)
11
+ Requires-Dist: fastapi (>=0.105.0,<0.106.0)
12
+ Requires-Dist: pydantic[email] (>=2.5.3,<3.0.0)
13
+ Requires-Dist: python-dotenv (>=1.0.0,<2.0.0)
14
+ Requires-Dist: python-jose[cryptography] (>=3.3.0,<4.0.0)
15
+ Requires-Dist: requests (>=2.31.0,<3.0.0)
16
+ Requires-Dist: svix-ksuid (>=0.6.2,<0.7.0)
17
+ Description-Content-Type: text/markdown
18
+
19
+ # auth-middleware
20
+
21
+ Async Auth Middleware for FastAPI/Starlette.
22
+
23
+ ## Technology Stack:
24
+
25
+ - FastAPI
26
+ - Pytest (\*)
27
+
28
+ ## Development environment
29
+
30
+ ### Requirements:
31
+
32
+ - Python >= 3.12 (Pyenv, best option)
33
+ - Poetry as dependency manager
34
+
35
+ ### Activate development environment
36
+
37
+ ```
38
+ poetry install
39
+ ```
40
+
41
+ This will create a new virtual environment (if it does not exist) and install all the dependencies.
42
+
43
+ To activate the virtual environment use:
44
+
45
+ ```
46
+ poetry shell
47
+ ```
48
+
49
+ ### Add/remove dependencies
50
+
51
+ ```
52
+ poetry add PIP_PACKAGE [-G group.name]
53
+ ```
54
+
55
+ Adds the dependency to the given group. If no group is specified, it is added to the default group.
56
+
57
+ ```
58
+ poetry remove PIP_PACKAGE [-G group.name]
59
+ ```
60
+
61
+ Removes the dependency from the given group.
62
+
63
+ ## Tests
64
+
65
+ ### Debug From VS Code
66
+
67
+ Get the path of the virtual environment created by poetry:
68
+
69
+ ```bash
70
+ poetry env info -p
71
+ ```
72
+
73
+ In Visual Studio Code, set the default interpreter to the virtual environment created by Poetry (SHIFT+CTRL+P, then "Python: Select Interpreter").
74
+
75
+ Launch "Pytest launch" from the run/debug tab.
76
+
77
+ You can then set breakpoints and inspect variables.
78
+
79
+ ### Launch tests from command line
80
+
81
+ ```
82
+ poetry run pytest --cov-report term-missing --cov=auth_middleware ./tests
83
+ ```
84
+
85
+ This runs the tests and creates a code coverage report.
86
+
87
+ ### Exclude code from coverage
88
+
89
+ To exclude code from the coverage report, add the following comment to the lines or functions to be excluded:
90
+
91
+ ```
92
+ # pragma: no cover
93
+ ```
94
+
95
+ See: https://coverage.readthedocs.io/en/6.4.4/excluding.html
96
+
97
+
98
+
@@ -0,0 +1,79 @@
1
+ # auth-middleware
2
+
3
+ Async Auth Middleware for FastAPI/Starlette.
4
+
5
+ ## Technology Stack:
6
+
7
+ - FastAPI
8
+ - Pytest (\*)
9
+
10
+ ## Development environment
11
+
12
+ ### Requirements:
13
+
14
+ - Python >= 3.12 (Pyenv, best option)
15
+ - Poetry as dependency manager
16
+
17
+ ### Activate development environment
18
+
19
+ ```
20
+ poetry install
21
+ ```
22
+
23
+ This will create a new virtual environment (if it does not exist) and install all the dependencies.
24
+
25
+ To activate the virtual environment use:
26
+
27
+ ```
28
+ poetry shell
29
+ ```
30
+
31
+ ### Add/remove dependencies
32
+
33
+ ```
34
+ poetry add PIP_PACKAGE [-G group.name]
35
+ ```
36
+
37
+ Adds the dependency to the given group. If no group is specified, it is added to the default group.
38
+
39
+ ```
40
+ poetry remove PIP_PACKAGE [-G group.name]
41
+ ```
42
+
43
+ Removes the dependency from the given group.
44
+
45
+ ## Tests
46
+
47
+ ### Debug From VS Code
48
+
49
+ Get the path of the virtual environment created by poetry:
50
+
51
+ ```bash
52
+ poetry env info -p
53
+ ```
54
+
55
+ In Visual Studio Code, set the default interpreter to the virtual environment created by Poetry (SHIFT+CTRL+P, then "Python: Select Interpreter").
56
+
57
+ Launch "Pytest launch" from the run/debug tab.
58
+
59
+ You can then set breakpoints and inspect variables.
60
+
61
+ ### Launch tests from command line
62
+
63
+ ```
64
+ poetry run pytest --cov-report term-missing --cov=auth_middleware ./tests
65
+ ```
66
+
67
+ This runs the tests and creates a code coverage report.
68
+
69
+ ### Exclude code from coverage
70
+
71
+ To exclude code from the coverage report, add the following comment to the lines or functions to be excluded:
72
+
73
+ ```
74
+ # pragma: no cover
75
+ ```
76
+
77
+ See: https://coverage.readthedocs.io/en/6.4.4/excluding.html
78
+
79
+
@@ -0,0 +1,31 @@
1
+ [tool.poetry]
2
+ name = "auth-middleware"
3
+ version = "0.1.0"
4
+ description = "Async Auth Middleware for FastAPI/Starlette"
5
+ authors = ["impalah <impalah@gmail.com>"]
6
+ readme = "README.md"
7
+ packages = [{include = "auth_middleware", from = "src"}]
8
+
9
+ [tool.poetry.dependencies]
10
+ python = "^3.12"
11
+ fastapi = "^0.105.0"
12
+ python-dotenv = "^1.0.0"
13
+ svix-ksuid = "^0.6.2"
14
+ colorlog = "^6.8.0"
15
+ python-jose = {extras = ["cryptography"], version = "^3.3.0"}
16
+ requests = "^2.31.0"
17
+ pydantic = {extras = ["email"], version = "^2.5.3"}
18
+
19
+ [tool.poetry.group.dev.dependencies]
20
+ pytest = "^7.4.4"
21
+ pytest-mock = "^3.12.0"
22
+ pytest-asyncio = "^0.23.3"
23
+ mock = "^5.1.0"
24
+ pytest-cov = "^4.1.0"
25
+ black = "^23.12.1"
26
+ pytest-env = "^1.1.3"
27
+ mypy = "^1.8.0"
28
+
29
+ [build-system]
30
+ requires = ["poetry-core>=1.0.0"]
31
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,20 @@
1
+ from .functions import require_groups, require_user
2
+ from .group_checker import GroupChecker
3
+ from .exceptions import InvalidTokenException
4
+ from .jwt_auth_middleware import JwtAuthMiddleware
5
+ from .jwt_auth_provider import JWTAuthProvider
6
+ from .types import JWK, JWKS, JWTAuthorizationCredentials, User
7
+
8
# Public API re-exported by the package; each name is listed exactly once
# ("User" was previously duplicated).
__all__ = [
    "require_groups",
    "require_user",
    "GroupChecker",
    "InvalidTokenException",
    "JwtAuthMiddleware",
    "User",
    "JWK",
    "JWKS",
    "JWTAuthorizationCredentials",
    "JWTAuthProvider",
]
@@ -0,0 +1,5 @@
1
+ from fastapi import HTTPException
2
+
3
+
4
class InvalidTokenException(HTTPException):
    """Raised when a bearer token is malformed or fails verification.

    Carries the usual ``status_code`` / ``detail`` / ``headers`` of FastAPI's
    ``HTTPException`` and adds no behavior of its own.
    """
@@ -0,0 +1,33 @@
1
+ from typing import List
2
+
3
+ from fastapi import HTTPException, Request
4
+
5
+ from auth_middleware.group_checker import GroupChecker
6
+ from auth_middleware.settings import settings
7
+
8
+
9
def require_groups(allowed_groups: List[str]):
    """Build a FastAPI dependency requiring membership in one of the given groups.

    Args:
        allowed_groups (List[str]): group names accepted for the endpoint.

    Returns:
        Callable[[Request], None]: a dependency that, on every request,
        delegates the check to a fresh ``GroupChecker``.
    """

    def _check_groups(request: Request):
        # A new checker per request; the allow-list is captured by the closure.
        checker = GroupChecker(allowed_groups)
        return checker(request)

    return _check_groups
20
+
21
+
22
def require_user():
    """Build a FastAPI dependency that rejects unauthenticated requests.

    Returns:
        Callable[[Request], None]: a dependency that raises
        ``HTTPException(401)`` when ``request.state`` has no (truthy)
        ``current_user``, unless ``settings.AUTH_DISABLED`` is set.
    """

    def _ensure_authenticated(request: Request):
        # The global kill-switch is consulted on every request.
        if settings.AUTH_DISABLED:
            return

        current_user = getattr(request.state, "current_user", None)
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

    return _ensure_authenticated
@@ -0,0 +1,34 @@
1
+ from typing import List
2
+
3
+ from fastapi import HTTPException, Request
4
+
5
+ from auth_middleware.logging import logger
6
+ from auth_middleware.settings import settings
7
+ from auth_middleware.types import User
8
+
9
+
10
class GroupChecker:
    """FastAPI dependency: the current user must belong to an allowed group (user_type).

    Raises ``HTTPException(401)`` when no user is attached to the request and
    ``HTTPException(403)`` when the user shares no group with the allow-list.
    All checks are skipped when ``settings.AUTH_DISABLED`` is set.

    Args:
        allowed_groups (List): group names that grant access.
    """

    # Set per instance in __init__ (no shared mutable class-level default).
    __allowed_groups: List

    def __init__(self, allowed_groups: List):
        self.__allowed_groups = allowed_groups

    def __call__(self, request: Request):
        """Validate the user stored in ``request.state.current_user``.

        Raises:
            HTTPException: 401 if unauthenticated, 403 if not in an allowed group.
        """
        if settings.AUTH_DISABLED:
            return

        if not hasattr(request.state, "current_user") or not request.state.current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user: User = request.state.current_user

        # Fail closed: unknown groups (None) mean "no groups". Previously a
        # user with groups=None skipped the check entirely.
        user_groups = user.groups or []
        if not any(group in self.__allowed_groups for group in user_groups):
            logger.debug(
                "User with groups %s not in %s", user.groups, self.__allowed_groups
            )
            raise HTTPException(status_code=403, detail="Operation not allowed")
@@ -0,0 +1,117 @@
1
+ from typing import Optional
2
+
3
+ from fastapi import Request, status
4
+ from fastapi.security.utils import get_authorization_scheme_param
5
+ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
6
+ from starlette.responses import JSONResponse, Response
7
+
8
+ from auth_middleware.exceptions import InvalidTokenException
9
+ from auth_middleware.jwt_auth_provider import JWTAuthProvider
10
+ from auth_middleware.jwt_bearer_manager import JWTBearerManager
11
+ from auth_middleware.logging import logger
12
+ from auth_middleware.types import JWTAuthorizationCredentials, User
13
+
14
+
15
class JwtAuthMiddleware(BaseHTTPMiddleware):
    """JWT authorization middleware for FastAPI/Starlette.

    Resolves the caller from the ``Authorization`` header (through
    ``JWTBearerManager`` and the configured ``JWTAuthProvider``) and stores it
    in ``request.state.current_user`` before the request reaches the route.
    ``current_user`` is ``None`` when the request carries no credentials.

    Args:
        auth_provider (JWTAuthProvider): verifies tokens and builds ``User``
            objects from their claims.
    """

    _auth_provider: JWTAuthProvider
    # Annotation only: the instance is created in __init__.
    _jwt_bearer_manager: JWTBearerManager

    def __init__(self, auth_provider: JWTAuthProvider, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._auth_provider = auth_provider
        self._jwt_bearer_manager = JWTBearerManager(
            auth_provider=self._auth_provider,
        )

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response | JSONResponse:
        """Attach the current user to the request state, then forward the request.

        Returns:
            Response | JSONResponse: the downstream response. Instead it returns
            a 401 for invalid tokens, the original status for any other
            ``HTTPException`` raised while reading credentials, or a generic
            500 for unexpected errors.
        """
        # Local import: only needed to classify errors raised below.
        from fastapi import HTTPException

        try:
            request.state.current_user = await self.get_current_user(request=request)
        except InvalidTokenException as ite:
            logger.error("Invalid Token %s", str(ite))
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Invalid token"},
                headers={"WWW-Authenticate": "Bearer"},
            )
        except HTTPException as he:
            # e.g. HTTPBearer rejecting a non-Bearer scheme: a client error
            # that previously fell through to the 500 branch below.
            logger.error("Authentication error in AuthMiddleware: %s", str(he))
            return JSONResponse(
                status_code=he.status_code,
                content={"detail": he.detail},
                headers=he.headers,
            )
        except Exception as e:
            logger.error("Error in AuthMiddleware: %s", str(e))
            # Do not echo internal exception text to the client; it is logged above.
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Server error"},
            )

        response = await call_next(request)
        return response

    async def get_current_user(self, request: Request) -> User | None:
        """Resolve the user making ``request``.

        Raises:
            InvalidTokenException: the token is malformed or fails verification.
            Exception: any other error from the bearer manager or provider is
                logged and re-raised unchanged.

        Returns:
            User | None:
            - ``None`` when the request has no Authorization credentials.
            - The provider-built user for a verified token.
            - A synthetic user when the bearer manager returns no token; the
              manager shown in this package does so when ``AUTH_DISABLED`` is set.
        """
        logger.debug("Get Current Active User ...")

        try:
            if not self.__validate_credentials(request=request):
                logger.debug("There are no credentials in the request")
                return None

            token: Optional[JWTAuthorizationCredentials] = (
                await self._jwt_bearer_manager.get_credentials(request=request)
            )

            # Create User object from token
            user: User = (
                self._auth_provider.create_user_from_token(token=token)
                if token
                else self.__create_synthetic_user()
            )
            logger.debug("Returning %s", user)
            return user
        except InvalidTokenException as ite:
            logger.error("Invalid Token %s", str(ite))
            raise
        except Exception as e:
            logger.error("Not controlled exception %s", str(e))
            raise

    def __validate_credentials(self, request: Request) -> bool:
        """Return True when the Authorization header has both a scheme and credentials.

        Args:
            request (Request): incoming request.

        Returns:
            bool: whether there is anything to authenticate.
        """
        authorization = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        return bool(authorization and scheme and credentials)

    def __create_synthetic_user(self) -> User:
        """Create a placeholder user, used when the bearer manager returns no token.

        Returns:
            User: Domain object with fixed "synthetic" identity and no groups.
        """
        return User(
            id="synthetic",
            name="synthetic",
            groups=[],
            email="synthetic@email.com",
        )
@@ -0,0 +1,67 @@
1
+ from abc import ABCMeta, abstractmethod
2
+ from time import time, time_ns
3
+ from typing import Optional
4
+
5
+ from jose import jwk
6
+ from jose.utils import base64url_decode
7
+
8
+ from auth_middleware.types import JWK, JWKS, JWTAuthorizationCredentials, User
9
+
10
+
11
class JWTAuthProvider(metaclass=ABCMeta):
    """Base class for JWT identity providers backed by a JSON Web Key Set.

    Subclasses implement ``load_jwks``, ``verify_token`` and
    ``create_user_from_token``. This base class caches the key set on the
    instance (``self.jks``) and reloads it when it expires or when its usage
    budget runs out.
    """

    def _get_jwks(self) -> JWKS | None:
        """Return the cached key set, reloading it when stale.

        The cache is reloaded in three cases:
        - it has never been loaded;
        - its ``timestamp`` (expiry, ns since the epoch) has passed;
        - its ``usage_counter`` is exhausted.
        Otherwise one usage is consumed.

        Returns:
            JWKS | None: the key set, or ``None`` if ``load_jwks`` raised
            ``KeyError``.
        """
        reload_cache = False
        try:
            if (
                not hasattr(self, "jks")
                or self.jks.timestamp is None
                or self.jks.timestamp < time_ns()
                or self.jks.usage_counter is None
                or self.jks.usage_counter <= 0
            ):
                reload_cache = True
        except AttributeError:
            # the first time after application startup, self.jks is NOT defined
            reload_cache = True

        try:
            if reload_cache:
                self.jks: JWKS = self.load_jwks()
            elif self.jks.usage_counter is not None:
                self.jks.usage_counter -= 1
        except KeyError:
            # e.g. a JWKS response without a "keys" member
            return None

        return self.jks

    def _get_hmac_key(self, token: JWTAuthorizationCredentials) -> Optional[JWK]:
        """Return the JWKS entry whose ``kid`` matches the token header.

        Args:
            token (JWTAuthorizationCredentials): parsed bearer token.

        Returns:
            Optional[JWK]: the matching key. It is ``None`` in these cases:
            - the token has no ``kid``;
            - the key set is unavailable;
            - no key matches.
            Previously a missing ``kid`` raised ``KeyError`` instead.
        """
        jwks: Optional[JWKS] = self._get_jwks()
        kid = token.header.get("kid")
        if kid is None or jwks is None or not jwks.keys:
            return None
        return next((key for key in jwks.keys if key.get("kid") == kid), None)

    @abstractmethod
    def load_jwks(
        self,
    ) -> JWKS:
        """Fetch the provider's signing keys and return a fresh cache entry."""
        ...

    @abstractmethod
    def verify_token(
        self,
        token: JWTAuthorizationCredentials,
    ) -> bool:
        """Return True when ``token`` is authentic and currently valid.

        Declared as a regular (synchronous) method: the bundled providers
        implement it that way and the bearer manager calls it without ``await``.
        """
        ...

    @abstractmethod
    def create_user_from_token(
        self,
        token: JWTAuthorizationCredentials,
    ) -> User:
        """Build a domain ``User`` from the claims of a verified token."""
        ...
@@ -0,0 +1,83 @@
1
+ from typing import Optional
2
+
3
+ from fastapi import HTTPException
4
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
+ from jose import JWTError, jwt
6
+ from starlette.requests import Request
7
+ from starlette.status import HTTP_403_FORBIDDEN
8
+
9
+ from auth_middleware.exceptions import InvalidTokenException
10
+ from auth_middleware.jwt_auth_provider import JWTAuthProvider
11
+ from auth_middleware.logging import logger
12
+ from auth_middleware.settings import settings
13
+ from auth_middleware.types import JWTAuthorizationCredentials
14
+
15
+
16
class JWTBearerManager(HTTPBearer):
    """Extract a bearer JWT from a request and have the provider verify it.

    Uses FastAPI's ``HTTPBearer`` to read the ``Authorization`` header. It then
    parses the token's header, claims and signature without verifying them,
    and delegates verification to ``auth_provider.verify_token``.

    Args:
        auth_provider (JWTAuthProvider): provider used to verify tokens.
        auto_error (bool): forwarded to ``HTTPBearer``.
    """

    def __init__(
        self,
        auth_provider: JWTAuthProvider,
        auto_error: bool = True,
    ):
        super().__init__(auto_error=auto_error)
        self.auth_provider = auth_provider

    async def get_credentials(
        self, request: Request
    ) -> Optional[JWTAuthorizationCredentials]:
        """Return the verified JWT credentials carried by ``request``.

        Returns:
            Optional[JWTAuthorizationCredentials]: the parsed token, or ``None``
            when auth is disabled or ``HTTPBearer`` yields no credentials.

        Raises:
            HTTPException: propagated from ``HTTPBearer``.
            InvalidTokenException: wrong scheme, malformed token, or failed
                verification.
        """
        import inspect

        if settings.AUTH_DISABLED:
            return None

        try:
            credentials: Optional[HTTPAuthorizationCredentials] = (
                await super().__call__(request)
            )
        except HTTPException as e:
            logger.error("Error in JWTBearerManager: %s", str(e))
            raise
        except Exception as e:
            logger.error("Error in JWTBearerManager: %s", str(e))
            raise InvalidTokenException(
                status_code=HTTP_403_FORBIDDEN,
                detail="JWK-invalid",
            ) from e

        if not credentials:
            return None

        # Auth schemes are case-insensitive (RFC 7235); HTTPBearer already
        # accepts any casing, so "bearer" must not be rejected here.
        if credentials.scheme.lower() != "bearer":
            logger.error("Error in JWTBearerManager: Wrong authentication method")
            raise InvalidTokenException(
                status_code=HTTP_403_FORBIDDEN,
                detail="Wrong authentication method",
            )

        jwt_token = credentials.credentials

        try:
            # Splitting inside the try: a token without "." used to raise an
            # unhandled ValueError (HTTP 500) instead of an invalid-token error.
            message, signature = jwt_token.rsplit(".", 1)
            jwt_credentials = JWTAuthorizationCredentials(
                jwt_token=jwt_token,
                header=jwt.get_unverified_header(jwt_token),
                claims=jwt.get_unverified_claims(jwt_token),
                signature=signature,
                message=message,
            )
        except (JWTError, ValueError) as e:
            # ValueError also covers pydantic's ValidationError (a ValueError
            # subclass) when the parsed parts do not fit the model.
            logger.error("Error in JWTBearerManager: %s", type(e).__name__)
            raise InvalidTokenException(
                status_code=HTTP_403_FORBIDDEN,
                detail="JWK-invalid",
            ) from e

        verified = self.auth_provider.verify_token(jwt_credentials)
        if inspect.isawaitable(verified):
            # An async verify_token returns a (truthy) coroutine; it must be
            # awaited or every token would pass the check below.
            verified = await verified

        if not verified:
            logger.error("Error in JWTBearerManager: token not verified")
            raise InvalidTokenException(
                status_code=HTTP_403_FORBIDDEN,
                detail="JWK_invalid",
            )

        return jwt_credentials
@@ -0,0 +1,33 @@
1
+ import logging
2
+
3
+ import colorlog
4
+
5
+ from auth_middleware.settings import settings
6
+
7
# Package logger: uses settings.LOGGER_NAME, or this module's name when empty.
logger = logging.getLogger(settings.LOGGER_NAME or __name__)
logger.setLevel(settings.LOG_LEVEL)

# Colorized, per-level formatter for console output.
formatter = colorlog.ColoredFormatter(
    settings.LOG_FORMAT,
    log_colors={
        "DEBUG": "cyan",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red,bg_white",
    },
    reset=True,
)

# Console handler filtered at the configured level, attached to the logger.
ch = logging.StreamHandler()
ch.setLevel(settings.LOG_LEVEL)
ch.setFormatter(formatter)
logger.addHandler(ch)
@@ -0,0 +1,93 @@
1
+ from time import time, time_ns
2
+ from typing import List
3
+
4
+ import requests
5
+ from jose import jwk
6
+ from jose.utils import base64url_decode
7
+
8
+ from auth_middleware.jwt_auth_provider import JWTAuthProvider
9
+ from auth_middleware.logging import logger
10
+ from auth_middleware.providers.cognito.exceptions import AWSException
11
+ from auth_middleware.providers.cognito.settings import settings
12
+ from auth_middleware.types import JWK, JWKS, JWTAuthorizationCredentials, User
13
+
14
+
15
class CognitoProvider(JWTAuthProvider):
    """JWT auth provider for AWS Cognito user pools (one shared instance).

    Signing keys come from the pool's JWKS endpoint. Its URL is built from
    ``settings.AWS_COGNITO_JWKS_URL_TEMPLATE`` with the pool region and pool id.

    NOTE(review): only signature and expiry are checked here. ``aud``,
    ``client_id`` and ``token_use`` are not verified, and
    ``AWS_COGNITO_USER_POOL_CLIENT_ID`` is not read. Confirm this is intended.
    """

    # Seconds to wait for the JWKS endpoint; without a timeout ``requests``
    # may block indefinitely.
    JWKS_REQUEST_TIMEOUT_SECONDS: float = 10

    def __new__(cls):
        # Reuse a single instance so its cached JWKS (self.jks) is shared.
        if not hasattr(cls, "instance"):
            cls.instance = super(CognitoProvider, cls).__new__(cls)
        return cls.instance

    def load_jwks(
        self,
    ) -> JWKS:
        """Load the signing keys from the Cognito user pool.

        Returns:
            JWKS: the keys, plus two cache-control fields:
            - an expiry ``AUTH_JWKS_CACHE_INTERVAL_MINUTES`` from now
              (ns since the epoch);
            - a budget of ``AUTH_JWKS_CACHE_USAGES`` reads.
        """
        # TODO: Control errors
        jwks_url = settings.AWS_COGNITO_JWKS_URL_TEMPLATE.format(
            settings.AWS_COGNITO_USER_POOL_REGION,
            settings.AWS_COGNITO_USER_POOL_ID,
        )
        response = requests.get(jwks_url, timeout=self.JWKS_REQUEST_TIMEOUT_SECONDS)
        keys: List[JWK] = response.json()["keys"]
        timestamp: int = (
            time_ns() + settings.AUTH_JWKS_CACHE_INTERVAL_MINUTES * 60 * 1000000000
        )
        usage_counter: int = settings.AUTH_JWKS_CACHE_USAGES
        jks: JWKS = JWKS(keys=keys, timestamp=timestamp, usage_counter=usage_counter)

        return jks

    def verify_token(self, token: JWTAuthorizationCredentials) -> bool:
        """Verify the token signature against the pool's JWKS, then its expiry.

        Raises:
            AWSException: when no key in the JWKS matches the token's ``kid``.

        Returns:
            bool: True when the signature is valid and ``exp`` is in the future.
            False otherwise, including a missing or non-numeric ``exp``, which
            used to raise ``KeyError``.
        """
        hmac_key_candidate = self._get_hmac_key(token)

        if not hmac_key_candidate:
            logger.error(
                "No public key found that matches the one present in the TOKEN!"
            )
            raise AWSException("No public key found!")

        # Key object built from the matching JWKS entry (named "hmac" historically).
        public_key = jwk.construct(hmac_key_candidate)
        decoded_signature = base64url_decode(token.signature.encode())

        if not public_key.verify(token.message.encode(), decoded_signature):
            return False

        expires_at = token.claims.get("exp")
        return isinstance(expires_at, (int, float)) and expires_at > time()

    def create_user_from_token(self, token: JWTAuthorizationCredentials) -> User:
        """Initializes a domain User object with data recovered from a JWT TOKEN.

        Args:
            token (JWTAuthorizationCredentials): the verified token.

        Returns:
            User: Domain object.
            - Groups come from ``cognito:groups`` when present; otherwise the
              last "/" segment of ``scope``.
            - With neither claim, groups is an empty list (previously a ``KeyError``).
        """
        claims = token.claims

        name_property: str = "username" if "username" in claims else "cognito:username"

        if "cognito:groups" in claims:
            groups = claims["cognito:groups"]
        elif "scope" in claims:
            groups = [str(claims["scope"]).split("/")[-1]]
        else:
            groups = []

        return User(
            id=claims["sub"],
            name=claims[name_property] if name_property in claims else claims["sub"],
            groups=groups,
            email=claims["email"] if "email" in claims else None,
        )
@@ -0,0 +1,8 @@
1
class AWSException(Exception):
    """Domain exception for wrapping AWS-related errors.

    Behaves exactly like ``Exception``; it exists so callers can single out
    failures raised by the Cognito provider.
    """
@@ -0,0 +1,44 @@
1
+ from typing import Optional
2
+
3
+ from starlette.config import Config
4
+
5
+ from auth_middleware.settings import Settings
6
+
7
config = Config()


def _optional_env_str(key: str) -> Optional[str]:
    """Read ``key`` from the environment as ``str``, or ``None`` when unset."""
    return config(key, cast=str, default=None)


class ModuleSettings(Settings):
    """AWS Cognito settings layered on top of the base ``Settings``.

    All values are read from environment variables when the class is defined.
    """

    # User pool identity; both values fill the JWKS URL template below.
    AWS_COGNITO_USER_POOL_ID: Optional[str] = _optional_env_str("AWS_COGNITO_USER_POOL_ID")
    AWS_COGNITO_USER_POOL_REGION: Optional[str] = _optional_env_str(
        "AWS_COGNITO_USER_POOL_REGION"
    )

    # Positional format string: {region}, then {user pool id}.
    AWS_COGNITO_JWKS_URL_TEMPLATE: str = config(
        "AWS_COGNITO_JWKS_URL_TEMPLATE",
        default="https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json",
        cast=str,
    )

    # App client credentials.
    AWS_COGNITO_USER_POOL_CLIENT_ID: Optional[str] = _optional_env_str(
        "AWS_COGNITO_USER_POOL_CLIENT_ID"
    )
    AWS_COGNITO_USER_POOL_CLIENT_SECRET: Optional[str] = _optional_env_str(
        "AWS_COGNITO_USER_POOL_CLIENT_SECRET"
    )


settings = ModuleSettings()
@@ -0,0 +1,125 @@
1
+ from time import time, time_ns
2
+
3
+ import requests
4
+ from jose import JWTError, jwt
5
+
6
+ from auth_middleware.jwt_auth_provider import JWTAuthProvider
7
+ from auth_middleware.logging import logger
8
+ from auth_middleware.providers.entra_id.exceptions import AzureException
9
+ from auth_middleware.providers.entra_id.settings import settings
10
+ from auth_middleware.types import JWK, JWKS, JWTAuthorizationCredentials, User
11
+
12
+
13
class EntraIDProvider(JWTAuthProvider):
    """JWT auth provider for Microsoft Entra ID (one shared instance).

    Signing keys are discovered in two steps:
    1. Fetch the tenant's OpenID configuration document from
       ``settings.AZURE_ENTRA_ID_JWKS_URL_TEMPLATE``.
    2. Fetch the key set from the document's ``jwks_uri``.

    NOTE(review): the token issuer (``iss``/``tid``) is not validated here,
    only signature, audience and the default python-jose claim checks.
    Confirm that this is acceptable for the deployment.
    """

    # Seconds to wait for each discovery/JWKS request; without a timeout
    # ``requests`` may block indefinitely.
    JWKS_REQUEST_TIMEOUT_SECONDS: float = 10

    # JWK members handed to python-jose. Optional ones such as "use" are
    # copied only when present instead of raising KeyError.
    _RSA_KEY_FIELDS = ("kty", "kid", "use", "n", "e")

    def __new__(cls):
        # Reuse a single instance so its cached JWKS (self.jks) is shared.
        if not hasattr(cls, "instance"):
            cls.instance = super(EntraIDProvider, cls).__new__(cls)
        return cls.instance

    def load_jwks(
        self,
    ) -> JWKS:
        """Load the tenant's signing keys from Entra ID.

        Returns:
            JWKS: the keys, plus two cache-control fields:
            - an expiry ``AUTH_JWKS_CACHE_INTERVAL_MINUTES`` from now
              (ns since the epoch);
            - a budget of ``AUTH_JWKS_CACHE_USAGES`` reads.
        """
        # TODO: Control errors
        openid_config_url = settings.AZURE_ENTRA_ID_JWKS_URL_TEMPLATE.format(
            settings.AZURE_ENTRA_ID_TENANT_ID,
        )
        openid_config = requests.get(
            openid_config_url, timeout=self.JWKS_REQUEST_TIMEOUT_SECONDS
        ).json()
        jwks_uri = openid_config["jwks_uri"]
        keys = requests.get(jwks_uri, timeout=self.JWKS_REQUEST_TIMEOUT_SECONDS).json()[
            "keys"
        ]

        # JWK is typed Dict[str, str], so flatten the 'x5c' list into a string.
        for key in keys:
            if "x5c" in key and isinstance(key["x5c"], list):
                key["x5c"] = "".join(key["x5c"])

        timestamp: int = (
            time_ns() + settings.AUTH_JWKS_CACHE_INTERVAL_MINUTES * 60 * 1000000000
        )
        usage_counter: int = settings.AUTH_JWKS_CACHE_USAGES
        jks: JWKS = JWKS(keys=keys, timestamp=timestamp, usage_counter=usage_counter)
        return jks

    def verify_token(self, token: JWTAuthorizationCredentials) -> bool:
        """Verify the token with python-jose (RS256, configured audience).

        Args:
            token (JWTAuthorizationCredentials): the parsed bearer token.

        Raises:
            AzureException: no signing key matches the token's ``kid``, or an
                unexpected (non-JWTError) error occurred while decoding.

        Returns:
            bool: True when ``jwt.decode`` succeeds and the payload has a
            ``sub`` claim; False when python-jose rejects the token.
        """
        hmac_key_candidate = self._get_hmac_key(token)

        if not hmac_key_candidate:
            logger.error(
                "No public key found that matches the one present in the TOKEN!"
            )
            raise AzureException("No public key found!")

        try:
            rsa_key = {
                field: hmac_key_candidate[field]
                for field in self._RSA_KEY_FIELDS
                if field in hmac_key_candidate
            }

            payload = jwt.decode(
                token.jwt_token,
                rsa_key,
                algorithms=["RS256"],
                audience=settings.AZURE_ENTRA_ID_AUDIENCE_ID,
                options={"verify_at_hash": False},  # Disable at_hash verification
            )
            return payload.get("sub") is not None
        except JWTError as je:
            logger.error("Error in EntraIDClient: %s", str(je))
            return False
        except Exception as e:
            logger.error("Error in JWTBearerManager: %s", str(e))
            raise AzureException("Error in JWTBearerManager") from e

    def create_user_from_token(self, token: JWTAuthorizationCredentials) -> User:
        """Initializes a domain User object with data recovered from a JWT TOKEN.

        Args:
            token (JWTAuthorizationCredentials): the verified token.

        Returns:
            User: Domain object.
            - Groups come from the ``groups`` claim when present; otherwise the
              last "/" segment of ``scope``.
            - With neither claim, groups is an empty list (previously a ``KeyError``).
        """
        claims = token.claims

        name_property: str = (
            "username" if "username" in claims else "preferred_username"
        )

        if "groups" in claims:
            groups = claims["groups"]
        elif "scope" in claims:
            groups = [str(claims["scope"]).split("/")[-1]]
        else:
            groups = []

        return User(
            id=claims["sub"],
            name=claims[name_property] if name_property in claims else claims["sub"],
            groups=groups,
            email=claims["email"] if "email" in claims else None,
        )
@@ -0,0 +1,8 @@
1
class AzureException(Exception):
    """Domain exception for wrapping Azure-related errors.

    Behaves exactly like ``Exception``; it exists so callers can single out
    failures raised by the Entra ID provider.
    """
@@ -0,0 +1,33 @@
1
+ from typing import Optional
2
+
3
+ from starlette.config import Config
4
+
5
+ from auth_middleware.settings import Settings
6
+
7
config = Config()


def _optional_env_str(key: str) -> Optional[str]:
    """Read ``key`` from the environment as ``str``, or ``None`` when unset."""
    return config(key, cast=str, default=None)


class ModuleSettings(Settings):
    """Microsoft Entra ID settings layered on top of the base ``Settings``.

    All values are read from environment variables when the class is defined.
    """

    # Tenant whose OpenID configuration is used for key discovery.
    AZURE_ENTRA_ID_TENANT_ID: Optional[str] = _optional_env_str("AZURE_ENTRA_ID_TENANT_ID")

    # The audience id is the client id of the application
    AZURE_ENTRA_ID_AUDIENCE_ID: Optional[str] = _optional_env_str(
        "AZURE_ENTRA_ID_AUDIENCE_ID"
    )

    # Positional format string receiving the tenant id; despite the name it
    # points at the OpenID configuration document, not directly at the JWKS.
    AZURE_ENTRA_ID_JWKS_URL_TEMPLATE: str = config(
        "AZURE_ENTRA_ID_JWKS_URL_TEMPLATE",
        default="https://login.microsoftonline.com/{}/v2.0/.well-known/openid-configuration",
        cast=str,
    )


settings = ModuleSettings()
@@ -0,0 +1,33 @@
1
+ from starlette.config import Config
2
+
3
config = Config()

_DEFAULT_LOG_FORMAT = (
    "%(log_color)s%(levelname)-9s%(reset)s %(asctime)s %(name)s %(message)s"
)


class Settings:
    """Base settings shared by the middleware and its providers.

    Every value is read from environment variables when the class is defined.
    """

    # --- Logging ---
    LOG_LEVEL: str = config("LOG_LEVEL", default="INFO", cast=str).upper()
    LOG_FORMAT: str = config("LOG_FORMAT", default=_DEFAULT_LOG_FORMAT, cast=str)
    # An empty name makes the logging module fall back to its own __name__.
    LOGGER_NAME: str = config("LOGGER_NAME", default="authmiddleware", cast=str)

    # --- Authentication ---
    # Disable authentication for the whole application
    AUTH_DISABLED: bool = config("AUTH_DISABLED", default=False, cast=bool)

    # JWKS cache: reload after this many minutes or this many reads,
    # whichever comes first.
    AUTH_JWKS_CACHE_INTERVAL_MINUTES: int = config(
        "AUTH_JWKS_CACHE_INTERVAL_MINUTES", default=20, cast=int
    )
    AUTH_JWKS_CACHE_USAGES: int = config("AUTH_JWKS_CACHE_USAGES", default=1000, cast=int)


settings = Settings()
@@ -0,0 +1,62 @@
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from pydantic import BaseModel, EmailStr, Field
4
+
5
# A single JSON Web Key, as a mapping of member name to string value.
JWK = Dict[str, str]


class JWKS(BaseModel):
    """Cached JSON Web Key Set plus its cache-control metadata.

    Attributes:
        keys: signing keys returned by the identity provider.
        timestamp: cache expiry, in nanoseconds since the epoch.
        usage_counter: reads left before the cache is reloaded.
    """

    keys: Optional[List[JWK]] = Field(default=[])
    timestamp: Optional[int] = Field(default=None)
    usage_counter: Optional[int] = Field(default=0)
12
+
13
+
14
class JWTAuthorizationCredentials(BaseModel):
    """A bearer JWT split into its parts (parsed from the request, not verified).

    Attributes:
        jwt_token: the raw compact token.
        header: the decoded JOSE header. Typed ``Dict[str, Any]`` because
            header parameters need not be strings (e.g. ``x5c`` is an array);
            ``Dict[str, str]`` rejected such tokens with a validation error.
        claims: the decoded payload claims.
        signature: the final "."-separated segment of the token.
        message: everything before that segment (the signed content).
    """

    jwt_token: str
    header: Dict[str, Any]
    claims: Dict[str, Any]
    signature: str
    message: str
20
+
21
+
22
class User(BaseModel):
    """Application User, as built by the auth providers from token claims.

    Args:
        BaseModel (BaseModel): Inherited properties
    """

    id: str = Field(
        ...,
        max_length=500,
        json_schema_extra={
            "description": "Unique user ID (sub)",
            "example": "0ujsswThIGTUYm2K8FjOOfXtY1K",
        },
    )

    name: Optional[str] = Field(
        default=None,
        max_length=500,
        json_schema_extra={
            "description": "User name",
            "example": "test_user",
        },
    )

    email: Optional[EmailStr] = Field(
        default=None,
        max_length=500,
        json_schema_extra={
            "description": "User's email address (Optional)",
            "example": "useradmin@user.com",
        },
    )

    groups: Optional[List[str]] = Field(
        default=[],
        json_schema_extra={
            "description": "List of user groups",
            # A real list so the schema example matches the field type
            # (was a JSON-encoded string).
            "example": ["admin", "user"],
        },
    )