open-shield-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- open_shield/adapters/__init__.py +5 -0
- open_shield/adapters/config.py +26 -0
- open_shield/adapters/key_provider.py +109 -0
- open_shield/adapters/token_validator.py +75 -0
- open_shield/api/__init__.py +3 -0
- open_shield/api/fastapi/__init__.py +4 -0
- open_shield/api/fastapi/dependencies.py +47 -0
- open_shield/api/fastapi/middleware.py +85 -0
- open_shield/domain/__init__.py +18 -0
- open_shield/domain/entities.py +59 -0
- open_shield/domain/exceptions.py +40 -0
- open_shield/domain/ports/__init__.py +4 -0
- open_shield/domain/ports/key_provider.py +32 -0
- open_shield/domain/ports/token_validator.py +38 -0
- open_shield/domain/services/__init__.py +4 -0
- open_shield/domain/services/authorization_service.py +45 -0
- open_shield/domain/services/token_service.py +83 -0
- open_shield_python-0.1.0.dist-info/METADATA +130 -0
- open_shield_python-0.1.0.dist-info/RECORD +21 -0
- open_shield_python-0.1.0.dist-info/WHEEL +4 -0
- open_shield_python-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class OpenShieldConfig(BaseSettings):
    """
    Configuration for the Open Shield SDK.

    Values are loaded from environment variables carrying the ``OPEN_SHIELD_``
    prefix (e.g. ``OPEN_SHIELD_ISSUER_URL``) and from a ``.env`` file decoded
    as UTF-8. Variables that do not match a field are ignored
    (``extra="ignore"``).

    Attributes:
        ISSUER_URL: OIDC issuer URL. Required (no default).
        AUDIENCE: Expected ``aud`` claim. The FastAPI middleware passes it to
            the validator, which skips the audience check when it is ``None``.
        ALGORITHMS: Allowed JWT signing algorithms. List-typed settings are
            read from the environment as JSON, e.g. ``'["RS256", "ES256"]'``.
        REQUIRE_SCOPES: Authorization default flag.
        REQUIRE_ROLES: Authorization default flag.
        TENANT_ID_CLAIM: Name of the claim carrying the tenant identifier.
    """

    model_config = SettingsConfigDict(
        env_prefix="OPEN_SHIELD_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    ISSUER_URL: str
    AUDIENCE: str | None = None
    ALGORITHMS: list[str] = ["RS256"]

    # Authorization defaults
    # NOTE(review): not read by OpenShieldMiddleware or the domain services in
    # this package; confirm where (if anywhere) these flags are enforced.
    REQUIRE_SCOPES: bool = True
    REQUIRE_ROLES: bool = False

    # Tenant extraction
    # NOTE(review): TokenService looks up 'tid', 'org_id' and 'tenant_id'
    # directly and does not read this setting; confirm intended wiring.
    TENANT_ID_CLAIM: str = "tid"
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import httpx
|
|
4
|
+
|
|
5
|
+
from open_shield.domain.exceptions import ConfigurationError, OpenShieldError
|
|
6
|
+
from open_shield.domain.ports import KeyProviderPort
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OIDCDiscoKeyProvider(KeyProviderPort):
    """
    Adapter that fetches JWKS from an OIDC provider's well-known configuration.

    Features:
    - Automatic discovery of ``jwks_uri`` via
      ``{issuer_url}/.well-known/openid-configuration``.
    - Per-instance in-memory key cache.
    - Key rotation support: an unknown ``kid`` triggers one JWKS refresh.
    - Optional refresh throttling (``min_refresh_interval``), so tokens with
      random ``kid`` values cannot force an outbound request on every call.

    JWKs are converted with PyJWT's ``Algorithm.from_jwk``, so ``get_key``
    returns key objects that can be passed directly to ``jwt.decode``. Keys
    are skipped rather than failing the whole refresh when they:
    - lack a ``kid``;
    - are marked for encryption (``use == "enc"``);
    - are symmetric (``kty == "oct"``);
    - name an algorithm PyJWT does not provide;
    - cannot be converted.
    """

    def __init__(self, issuer_url: str, min_refresh_interval: float = 0.0):
        """
        Args:
            issuer_url: OIDC issuer URL. A trailing ``/`` is stripped.
            min_refresh_interval: Minimum number of seconds between two JWKS
                refreshes. ``0`` (default) allows a refresh on every cache
                miss. Setting e.g. ``30`` bounds the outbound traffic that
                unknown ``kid`` values can cause.
        """
        self.issuer_url = issuer_url.rstrip("/")
        self.min_refresh_interval = min_refresh_interval
        self._jwks_uri: str | None = None
        # kid -> key object ready for jwt.decode
        self._keys: dict[str, Any] = {}
        # kid -> raw JWK dict as published by the provider (same kids as _keys)
        self._jwks: dict[str, dict[str, Any]] = {}
        # time.monotonic() of the last refresh attempt; None before the first
        self._last_refresh: float | None = None
        self._client = httpx.Client()  # Use sync client for simplicity in core logic

    def get_key(self, kid: str) -> Any:
        """
        Return the verification key for ``kid``, refreshing the JWKS on a miss.

        Args:
            kid: Key ID taken from the token header.

        Returns:
            The key object produced by PyJWT's ``from_jwk``.

        Raises:
            OpenShieldError: If ``kid`` is still unknown after a refresh (or
                the refresh is throttled), or if fetching the JWKS fails.
            ConfigurationError: If OIDC discovery fails.
        """
        if kid not in self._keys and self._refresh_allowed():
            # One refresh is enough. A second request issued immediately after
            # would return the same document and only doubles outbound traffic.
            self._refresh_keys()

        if kid not in self._keys:
            raise OpenShieldError(
                f"Key ID {kid} not found in JWKS from {self.issuer_url}"
            )

        return self._keys[kid]

    def get_all_keys(self) -> list[dict[str, Any]]:
        """
        Return all usable keys as raw JWK dictionaries.

        The JWKS is fetched on first use, or while the cache is empty.
        """
        if not self._jwks and self._refresh_allowed():
            self._refresh_keys()
        return list(self._jwks.values())

    def close(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        self._client.close()

    def _refresh_allowed(self) -> bool:
        """Return True unless the last refresh happened < min_refresh_interval ago."""
        import time

        if self._last_refresh is None or self.min_refresh_interval <= 0:
            return True
        return time.monotonic() - self._last_refresh >= self.min_refresh_interval

    def _refresh_keys(self) -> None:
        """
        Fetch the JWKS document and rebuild the key cache.

        Raises:
            ConfigurationError: If ``jwks_uri`` discovery fails.
            OpenShieldError: If the JWKS request fails, or the document is not
                a JSON object with a ``keys`` list. Previously cached keys are
                kept in that case.
        """
        import time

        # Record the attempt before any network I/O, so a failing provider is
        # throttled too.
        self._last_refresh = time.monotonic()

        if not self._jwks_uri:
            self._discover()

        try:
            response = self._client.get(self._jwks_uri)  # type: ignore[arg-type]
            response.raise_for_status()
            jwks = response.json()
        except Exception as e:
            raise OpenShieldError(f"Failed to refresh JWKS: {e!s}") from e

        raw_keys = jwks.get("keys") if isinstance(jwks, dict) else None
        if not isinstance(raw_keys, list):
            raise OpenShieldError(
                "Failed to refresh JWKS: document has no 'keys' list"
            )

        new_keys: dict[str, Any] = {}
        new_jwks: dict[str, dict[str, Any]] = {}
        for key_data in raw_keys:
            converted = self._convert_jwk(key_data)
            if converted is None:
                continue
            kid, public_key = converted
            new_keys[kid] = public_key
            new_jwks[kid] = key_data

        self._keys = new_keys
        self._jwks = new_jwks

    @staticmethod
    def _convert_jwk(key_data: Any) -> tuple[Any, Any] | None:
        """Convert one JWK into ``(kid, key_object)``, or return None to skip it."""
        import jwt.algorithms
        from jwt.exceptions import PyJWTError

        if not isinstance(key_data, dict):
            return None
        kid = key_data.get("kid")
        if not kid or key_data.get("use") == "enc":
            return None
        if key_data.get("kty") == "oct":
            # A symmetric secret published in a public JWKS must never be
            # trusted as a verification key.
            return None

        # Prefer the explicit 'alg' (e.g. RS256). Otherwise fall back to RSA
        # when only kty=RSA is given.
        alg_name = key_data.get("alg")
        algo: Any = None
        if alg_name:
            algo = jwt.algorithms.get_default_algorithms().get(alg_name)
        elif key_data.get("kty") == "RSA":
            algo = jwt.algorithms.RSAAlgorithm
        if algo is None:
            return None

        try:
            return kid, algo.from_jwk(key_data)
        except (PyJWTError, ValueError, TypeError, KeyError, NotImplementedError):
            # One malformed/unsupported key must not disable all the others.
            return None

    def _discover(self) -> None:
        """
        Resolve ``jwks_uri`` from the issuer's OpenID Connect discovery document.

        Raises:
            ConfigurationError: If the document cannot be fetched or parsed, or
                if it has no ``jwks_uri``.
        """
        disco_url = f"{self.issuer_url}/.well-known/openid-configuration"
        try:
            response = self._client.get(disco_url)
            response.raise_for_status()
            config = response.json()
        except Exception as e:
            raise ConfigurationError(
                f"OIDC discovery failed for {self.issuer_url}: {e!s}"
            ) from e

        jwks_uri = config.get("jwks_uri") if isinstance(config, dict) else None
        if not jwks_uri:
            raise ConfigurationError(
                f"OIDC discovery failed for {self.issuer_url}: "
                "Pre-discovery failed: jwks_uri not found in OIDC config"
            )
        self._jwks_uri = jwks_uri
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import jwt
|
|
4
|
+
from jwt.exceptions import ExpiredSignatureError
|
|
5
|
+
from jwt.exceptions import InvalidTokenError as PyJWTError
|
|
6
|
+
|
|
7
|
+
from open_shield.domain.entities import Token
|
|
8
|
+
from open_shield.domain.exceptions import (
|
|
9
|
+
ExpiredTokenError,
|
|
10
|
+
InvalidSignatureError,
|
|
11
|
+
TokenValidationError,
|
|
12
|
+
)
|
|
13
|
+
from open_shield.domain.ports import TokenValidatorPort
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PyJWTValidator(TokenValidatorPort):
    """
    Adapter that uses PyJWT to validate tokens.

    The verification key is obtained from ``key_provider.get_key(kid)``, where
    ``kid`` comes from the (unverified) token header. Expiry is always checked
    when an ``exp`` claim is present. Audience and issuer are checked only when
    ``audience`` / ``issuer`` are configured.
    """

    def __init__(
        self,
        key_provider: Any | None = None,
        algorithms: list[str] | None = None,
        audience: str | None = None,
        issuer: str | None = None,
        required_claims: list[str] | None = None,
    ) -> None:
        """
        Args:
            key_provider: Object exposing ``get_key(kid)`` (e.g. a
                ``KeyProviderPort``). Without it no token can be verified.
            algorithms: Allowed signing algorithms. Defaults to ``["RS256"]``.
            audience: Expected ``aud``. ``None`` disables the audience check.
            issuer: Expected ``iss``. ``None`` disables the issuer check.
            required_claims: Claims that must be present (PyJWT ``require``
                option), e.g. ``["exp", "sub"]``. By default none are required,
                so a token without ``exp`` never expires.
        """
        self.key_provider = key_provider
        self.algorithms = algorithms or ["RS256"]
        self.audience = audience
        self.issuer = issuer
        self.required_claims = list(required_claims or [])

    def validate_token(self, token_string: str) -> Token:
        """
        Verify ``token_string`` and return it as a ``Token``.

        Raises:
            ExpiredTokenError: ``exp`` is in the past.
            InvalidSignatureError: The signature does not verify.
            InvalidClaimsError: The ``aud``, ``iss``, ``nbf`` or ``iat`` check
                fails, or a required claim is missing.
            TokenValidationError: Any other failure, e.g.:
                - malformed token;
                - disallowed algorithm;
                - missing ``kid``;
                - key provider errors.
        """
        from open_shield.domain.exceptions import InvalidClaimsError

        try:
            # 1. The unverified header selects the algorithm and the key.
            header = jwt.get_unverified_header(token_string)
            alg = header.get("alg")
            if alg not in self.algorithms:
                # Reject before key lookup, so junk tokens never trigger a
                # JWKS fetch.
                raise TokenValidationError(
                    f"Token validation failed: algorithm {alg!r} is not allowed"
                )
            if self.key_provider is None:
                raise TokenValidationError(
                    "Token validation failed: no key provider configured"
                )
            kid = header.get("kid")
            if not kid:
                raise TokenValidationError(
                    "Token validation failed: token header has no 'kid'"
                )
            key = self.key_provider.get_key(kid)

            # 2. Verify the signature and the registered claims.
            payload = jwt.decode(
                token_string,
                key=key,
                algorithms=self.algorithms,
                audience=self.audience,
                issuer=self.issuer,
                options={
                    "verify_exp": True,
                    "verify_aud": bool(self.audience),
                    "verify_iss": bool(self.issuer),
                    "require": self.required_claims,
                },
            )

            return Token(raw=token_string, claims=payload)

        except TokenValidationError:
            # Already a domain error raised above; don't re-wrap it.
            raise
        except ExpiredSignatureError as e:
            raise ExpiredTokenError(f"Token expired: {e!s}") from e
        except jwt.exceptions.InvalidSignatureError as e:
            raise InvalidSignatureError(f"Invalid signature: {e!s}") from e
        except (
            jwt.exceptions.InvalidAudienceError,
            jwt.exceptions.InvalidIssuerError,
            jwt.exceptions.ImmatureSignatureError,
            jwt.exceptions.InvalidIssuedAtError,
            jwt.exceptions.MissingRequiredClaimError,
        ) as e:
            raise InvalidClaimsError(f"Token validation failed: {e!s}") from e
        except PyJWTError as e:
            raise TokenValidationError(f"Token validation failed: {e!s}") from e
        except Exception as e:
            # Fallback for unexpected errors (e.g. key provider failures)
            raise TokenValidationError(f"Unexpected validation error: {e!s}") from e

    def decode_unverified(self, token_string: str) -> dict[str, Any]:
        """
        Decode the claims WITHOUT verifying the signature (inspection only).

        Raises:
            TokenValidationError: If the token cannot be decoded.
        """
        try:
            return jwt.decode(token_string, options={"verify_signature": False})
        except PyJWTError as e:
            raise TokenValidationError(f"Failed to decode token: {e!s}") from e
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from fastapi import Depends, HTTPException, Request
|
|
2
|
+
|
|
3
|
+
from open_shield.domain.entities import UserContext
|
|
4
|
+
from open_shield.domain.exceptions import AuthorizationError
|
|
5
|
+
from open_shield.domain.services import AuthorizationService
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_user_context(request: Request) -> UserContext:
    """
    FastAPI dependency returning the ``UserContext`` stored on ``request.state``.

    ``OpenShieldMiddleware`` is expected to have set
    ``request.state.user_context``. If that attribute is absent (e.g. the path
    was excluded from authentication), an HTTP 401 is raised.
    """
    from typing import cast

    missing = object()
    stored = getattr(request.state, "user_context", missing)
    if stored is missing:
        raise HTTPException(status_code=401, detail="Authentication required")
    return cast(UserContext, stored)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RequireScope:
    """
    FastAPI dependency that rejects requests lacking one specific scope.

    Usage: ``Depends(RequireScope("read:admin"))``. The ``UserContext`` is
    resolved through ``get_user_context`` and returned unchanged when the scope
    is present. A denial from ``AuthorizationService.require_scope`` becomes an
    HTTP 403 whose detail is the error message.
    """

    def __init__(self, scope: str):
        self.auth_service = AuthorizationService()
        # Scope string checked on every request.
        self.scope = scope

    def __call__(self, context: UserContext = Depends(get_user_context)) -> UserContext:
        try:
            self.auth_service.require_scope(context, self.scope)
        except AuthorizationError as denial:
            raise HTTPException(status_code=403, detail=str(denial)) from denial
        else:
            return context
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RequireRole:
    """
    FastAPI dependency that requires the user to hold at least one given role.

    Usage: ``Depends(RequireRole(["manager", "admin"]))``. A single role may also
    be passed as a plain string: ``RequireRole("admin")``. A denial from
    ``AuthorizationService.require_any_role`` becomes an HTTP 403 whose detail
    is the error message. An empty role list can never be satisfied.
    """

    def __init__(self, roles: list[str] | str):
        """
        Args:
            roles: Accepted roles. A bare string is treated as one role.
        """
        # Without this normalization, RequireRole("admin") would be iterated
        # character by character. Any user holding a role named "a", "d", "m",
        # "i" or "n" would then pass. Copying also shields the dependency from
        # later mutation of the caller's list.
        self.roles = [roles] if isinstance(roles, str) else list(roles)
        self.auth_service = AuthorizationService()

    def __call__(self, context: UserContext = Depends(get_user_context)) -> UserContext:
        try:
            self.auth_service.require_any_role(context, self.roles)
        except AuthorizationError as e:
            raise HTTPException(status_code=403, detail=str(e)) from e
        return context
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from fastapi import Request, Response
|
|
2
|
+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
3
|
+
from starlette.types import ASGIApp
|
|
4
|
+
|
|
5
|
+
from open_shield.adapters import OIDCDiscoKeyProvider, OpenShieldConfig, PyJWTValidator
|
|
6
|
+
from open_shield.domain.exceptions import OpenShieldError, TokenValidationError
|
|
7
|
+
from open_shield.domain.services import TokenService
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OpenShieldMiddleware(BaseHTTPMiddleware):
    """
    Middleware that intercepts requests, validates the Authorization header,
    and attaches the UserContext to the request state.

    Requests whose path exactly matches an entry in ``exclude_paths`` bypass
    authentication. Every other request needs ``Authorization: Bearer <jwt>``.
    On success the resulting ``UserContext`` is stored in
    ``request.state.user_context``.

    Responses produced by the middleware itself:
        401 (with ``WWW-Authenticate: Bearer``): missing or malformed header,
            non-Bearer scheme, or token validation failure.
        403: any other ``OpenShieldError``.
        500: unexpected error during authentication (logged).

    Attributes:
        app: The ASGI application.
        token_service: The domain service for validation.
        config: The SDK configuration.
        exclude_paths: A set of paths to exclude from authentication.
    """

    def __init__(
        self,
        app: ASGIApp,
        config: OpenShieldConfig,
        exclude_paths: set[str] | None = None,
    ):
        super().__init__(app)
        self.config = config
        # ``None`` selects the defaults. An explicit empty set means "protect
        # every path"; a truthiness check would silently re-enable the defaults.
        self.exclude_paths = (
            {
                "/docs",
                "/openapi.json",
                "/redoc",
                "/health",
            }
            if exclude_paths is None
            else set(exclude_paths)
        )

        # Initialize dependencies
        # In a real app, these might be injected, but middleware initialization is often
        # the composition root.
        key_provider = OIDCDiscoKeyProvider(issuer_url=config.ISSUER_URL)
        validator = PyJWTValidator(
            key_provider=key_provider,
            algorithms=config.ALGORITHMS,
            audience=config.AUDIENCE,
            issuer=config.ISSUER_URL,
        )
        self.token_service = TokenService(validator=validator)

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Authenticate the request, then hand it to the next ASGI handler."""
        from starlette.concurrency import run_in_threadpool

        if request.url.path in self.exclude_paths:
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._unauthorized("Missing Authorization Header")

        try:
            # ValueError here (not exactly two parts) is mapped to 401 below.
            scheme, token = auth_header.split()
            if scheme.lower() != "bearer":
                return self._unauthorized("Invalid Authorization Scheme")

            # Validation is synchronous and may perform blocking HTTP (OIDC
            # discovery / JWKS refresh through a sync client). Run it off the
            # event loop, so one slow key fetch doesn't stall every request.
            user_context = await run_in_threadpool(
                self.token_service.validate_and_extract, token
            )

            # Attach to request.state for downstream access
            request.state.user_context = user_context

            # Scope/role enforcement is left to route dependencies
            # (RequireScope / RequireRole).

        except (ValueError, TokenValidationError) as e:
            return self._unauthorized(f"Unauthorized: {e!s}")
        except OpenShieldError as e:
            return Response(f"Forbidden: {e!s}", status_code=403)
        except Exception:
            import logging

            logging.getLogger(__name__).exception(
                "Unexpected error while authenticating request"
            )
            return Response(
                "Internal Server Error during Authentication", status_code=500
            )

        return await call_next(request)

    @staticmethod
    def _unauthorized(message: str) -> Response:
        """Build a 401 response carrying the RFC 6750 ``WWW-Authenticate`` challenge."""
        return Response(
            message, status_code=401, headers={"WWW-Authenticate": "Bearer"}
        )
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .entities import TenantContext, Token, User, UserContext
|
|
2
|
+
from .exceptions import (
|
|
3
|
+
AuthorizationError,
|
|
4
|
+
ConfigurationError,
|
|
5
|
+
OpenShieldError,
|
|
6
|
+
TokenValidationError,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
# Public API of ``open_shield.domain``: the entity models and the core
# exception types, re-exported from .entities and .exceptions.
# NOTE(review): the more specific InvalidSignatureError, ExpiredTokenError and
# InvalidClaimsError (defined in .exceptions) are not re-exported here.
__all__ = [
    "AuthorizationError",
    "ConfigurationError",
    "OpenShieldError",
    "TenantContext",
    "Token",
    "TokenValidationError",
    "User",
    "UserContext",
]
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Entity(BaseModel):
    """Base class for all domain entities.

    Instances are immutable (``frozen=True``): assigning to a field after
    construction raises an error.
    """

    # NOTE: frozen is shallow. List/dict field values in subclasses (e.g.
    # User.roles, Token.claims) can still be mutated in place.
    model_config = ConfigDict(frozen=True)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TenantContext(Entity):
    """Represents the tenant context extracted from a token.

    Immutable through ``Entity``'s frozen model config.
    """

    # NOTE(review): pydantic v2 does not coerce numbers to ``str`` by default,
    # so callers must pass the tenant id as a string.
    tenant_id: str = Field(..., description="Unique identifier for the tenant")
    # Free-form extras; each instance gets its own empty dict by default.
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional tenant metadata"
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class User(Entity):
    """Represents an authenticated user.

    Instances are built from verified token claims. ``metadata`` is where the
    token service places the full claim set.
    """

    # Taken from the token's 'sub' claim.
    id: str = Field(..., description="Unique user identifier (sub)")
    email: str | None = Field(None, description="User email address")
    # Role names used by role-based checks (any-of matching).
    roles: list[str] = Field(default_factory=list, description="Assigned roles")
    # Scope strings used by scope-based checks (exact matching).
    scopes: list[str] = Field(
        default_factory=list, description="Granted scopes/permissions"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional user attributes"
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Token(Entity):
    """A JWT in both forms: the encoded string and its decoded claim set.

    The convenience properties expose registered claims and yield ``None``
    whenever the claim is absent from ``claims``.
    """

    raw: str = Field(..., description="The raw JWT string")
    claims: dict[str, Any] = Field(..., description="The parsed claims dictionary")

    @property
    def issuer(self) -> str | None:
        """Value of the ``iss`` claim, if any."""
        return self.claims["iss"] if "iss" in self.claims else None

    @property
    def audience(self) -> str | list[str] | None:
        """Value of the ``aud`` claim (a string or a list), if any."""
        return self.claims["aud"] if "aud" in self.claims else None

    @property
    def subject(self) -> str | None:
        """Value of the ``sub`` claim, if any."""
        return self.claims["sub"] if "sub" in self.claims else None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class UserContext(Entity):
    """Aggregates User, Token, and Tenant information for a request.

    This is the object the FastAPI integration stores on
    ``request.state.user_context``.
    """

    # Identity and permissions derived from the token.
    user: User
    # The validated token the user was derived from.
    token: Token
    # None when no tenant identifier was found in the token.
    tenant: TenantContext | None = None
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
class OpenShieldError(Exception):
    """Root of the Open Shield exception hierarchy.

    Every SDK-specific exception derives from this class, so catching it
    handles them all.
    """
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ConfigurationError(OpenShieldError):
    """Signals that the SDK is misconfigured or cannot bootstrap.

    One example is OIDC discovery failing to produce a ``jwks_uri``.
    """
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TokenValidationError(OpenShieldError):
    """Parent class for every reason a token is rejected.

    The FastAPI middleware answers it with HTTP 401.
    """
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class InvalidSignatureError(TokenValidationError):
    """The token's signature did not verify against the selected key."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ExpiredTokenError(TokenValidationError):
    """The token is past its ``exp`` (expiration) time."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InvalidClaimsError(TokenValidationError):
    """A registered claim such as ``iss`` or ``aud`` failed validation."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AuthorizationError(OpenShieldError):
    """The authenticated user lacks a required scope or role.

    The FastAPI dependencies translate it into HTTP 403.
    """
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class KeyProviderPort(ABC):
    """Abstract interface for retrieving and caching JSON Web Keys (JWKS)."""

    @abstractmethod
    def get_key(self, kid: str) -> Any:
        """
        Retrieve a specific key by Key ID (kid).

        Args:
            kid: The Key ID to search for.

        Returns:
            The key object (implementation dependent, usually a public key in a
            form the token validator can use directly).

        Raises:
            OpenShieldError: If the key cannot be found or retrieved.
                Implementations may raise a subclass, e.g.
                ``ConfigurationError`` when provider discovery fails.
        """

    @abstractmethod
    def get_all_keys(self) -> list[dict[str, Any]]:
        """
        Retrieve all available keys in JWK format.

        Returns:
            A list of JWK dictionaries.
        """
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from open_shield.domain.entities import Token
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TokenValidatorPort(ABC):
    """Port (abstract interface) implemented by JWT validation back-ends."""

    @abstractmethod
    def validate_token(self, token_string: str) -> Token:
        """
        Verify an encoded JWT and wrap it in a ``Token`` entity.

        Args:
            token_string: The encoded JWT, without any ``Bearer`` prefix.

        Returns:
            The validated ``Token``.

        Raises:
            TokenValidationError: Whenever the token is rejected, for example:
                - expired;
                - bad signature;
                - invalid claims;
                - malformed input.
        """
        ...

    @abstractmethod
    def decode_unverified(self, token_string: str) -> dict[str, Any]:
        """
        Decode a JWT's claims WITHOUT checking its signature.

        Intended only for inspecting a token ahead of validation. Never base
        access decisions on the result.

        Args:
            token_string: The encoded JWT.

        Returns:
            The decoded claims dictionary.
        """
        ...
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from open_shield.domain.entities import UserContext
|
|
2
|
+
from open_shield.domain.exceptions import AuthorizationError
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AuthorizationService:
    """
    Domain service for enforcing access control policies on a UserContext.

    Holds no state: each check reads ``context.user`` and raises
    ``AuthorizationError`` on denial.
    """

    def require_scope(self, context: UserContext, required_scope: str) -> None:
        """
        Ensure the user has the required scope.

        Args:
            context: The authenticated user context.
            required_scope: The specific scope string required (exact match).

        Raises:
            AuthorizationError: If the scope is missing.
        """
        # Exact string match only.
        # TODO: Implement hierarchical scope checking (e.g. read:users implies
        # read:users:self) if needed.
        if required_scope not in context.user.scopes:
            raise AuthorizationError(f"Missing required scope: {required_scope}")

    def require_any_role(self, context: UserContext, roles: list[str]) -> None:
        """
        Ensure the user has at least one of the required roles.

        Args:
            context: The authenticated user context.
            roles: A list of allowed roles. A bare string is treated as a
                single role. An empty list can never be satisfied.

        Raises:
            AuthorizationError: If the user has none of the required roles.
        """
        user_roles = set(context.user.roles)
        # Guard against set("admin") == {"a", "d", "m", "i", "n"}.
        required_roles = {roles} if isinstance(roles, str) else set(roles)

        if user_roles.isdisjoint(required_roles):
            # Sorted so the message is deterministic. Set iteration order of
            # strings varies between processes (hash randomization).
            raise AuthorizationError(
                f"Missing required role. User has: {sorted(user_roles)}. "
                f"Required one of: {sorted(required_roles)}"
            )
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from open_shield.domain.entities import TenantContext, Token, User, UserContext
|
|
2
|
+
from open_shield.domain.exceptions import TokenValidationError
|
|
3
|
+
from open_shield.domain.ports import TokenValidatorPort
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TokenService:
    """
    Domain service for orchestrating token validation and context extraction.
    Dependencies are injected via constructor (DIP).
    """

    def __init__(self, validator: TokenValidatorPort):
        self.validator = validator

    def validate_and_extract(self, token_string: str) -> UserContext:
        """
        Validate a raw token string and extract the user context.

        Args:
            token_string: The raw JWT from the Authorization header.

        Returns:
            A populated UserContext object.

        Raises:
            TokenValidationError: If validation fails, the 'sub' claim is
                missing, or the tenant claim has an unsupported type.
        """
        token = self.validator.validate_token(token_string)
        user = self._extract_user(token)
        tenant = self._extract_tenant(token)

        return UserContext(user=user, token=token, tenant=tenant)

    @staticmethod
    def _string_list(value: object) -> list[str]:
        """Normalize a claim holding a string or a list into a new list of strings."""
        if isinstance(value, str):
            # Some providers emit a single role as a plain string.
            return [value]
        if isinstance(value, (list, tuple)):
            return [item for item in value if isinstance(item, str)]
        return []

    def _extract_user(self, token: Token) -> User:
        """Extract user identity and permissions from the token."""
        # Default claim mapping (stateless)
        # TODO: Make claim mapping configurable
        sub = token.subject
        if not sub:
            raise TokenValidationError("Token missing 'sub' claim")

        email = token.claims.get("email")

        # Role claims vary between providers. Merge the top-level 'roles' claim
        # with Keycloak's 'realm_access.roles'. Fresh lists are built, so the
        # token's own claim values are never mutated.
        roles = self._string_list(token.claims.get("roles"))
        realm_access = token.claims.get("realm_access")
        if isinstance(realm_access, dict):
            roles += self._string_list(realm_access.get("roles"))

        # OAuth 2.0 'scope' is a space-delimited string; other shapes are ignored.
        scope_claim = token.claims.get("scope")
        scopes = scope_claim.split() if isinstance(scope_claim, str) else []

        return User(
            id=sub,
            email=email,
            roles=list(dict.fromkeys(roles)),  # Deduplicate, keep first-seen order
            scopes=scopes,
            metadata=token.claims,
        )

    def _extract_tenant(self, token: Token) -> TenantContext | None:
        """
        Extract tenant context from the token.

        Uses the first truthy claim among 'tid', 'org_id' and 'tenant_id'.
        Integer identifiers are converted to strings. Other non-string values
        are rejected.

        Raises:
            TokenValidationError: If the tenant claim is neither a string nor
                an integer.
        """
        # TODO: Make tenant extraction strategy configurable (e.g. parsing the
        # issuer URL, https://auth.com/realms/{tenant}).
        tid = (
            token.claims.get("tid")
            or token.claims.get("org_id")
            or token.claims.get("tenant_id")
        )

        if not tid:
            return None
        if isinstance(tid, int) and not isinstance(tid, bool):
            tid = str(tid)
        if not isinstance(tid, str):
            raise TokenValidationError("Token tenant claim has an unsupported type")
        return TenantContext(tenant_id=tid, metadata={})
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: open-shield-python
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Vendor-agnostic authentication and authorization enforcement SDK
|
|
5
|
+
Project-URL: Repository, https://github.com/prayog-ai-labs/open-shield-python
|
|
6
|
+
Project-URL: Issues, https://github.com/prayog-ai-labs/open-shield-python/issues
|
|
7
|
+
Author-email: Avinash <avinash@prayog.ai>
|
|
8
|
+
License: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: authentication,authorization,fastapi,jwt,oidc,security
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Operating System :: OS Independent
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Security
|
|
18
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
19
|
+
Requires-Python: >=3.12
|
|
20
|
+
Requires-Dist: cryptography>=42.0.0
|
|
21
|
+
Requires-Dist: fastapi>=0.129.0
|
|
22
|
+
Requires-Dist: httpx>=0.27.0
|
|
23
|
+
Requires-Dist: pydantic-settings>=2.0.0
|
|
24
|
+
Requires-Dist: pydantic>=2.0.0
|
|
25
|
+
Requires-Dist: pyjwt>=2.8.0
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
|
|
28
|
+
# Open Shield Python SDK
|
|
29
|
+
|
|
30
|
+
Vendor-agnostic authentication and authorization enforcement SDK for Python.
|
|
31
|
+
|
|
32
|
+
Open Shield allows you to enforce authentication (AuthN) and authorization (AuthZ) in your Python applications without tightly coupling your code to a specific identity provider (like Auth0, Keycloak, or Cognito).
|
|
33
|
+
|
|
34
|
+
## Features
|
|
35
|
+
|
|
36
|
+
- **Vendor Neutral**: Works with any OIDC-compliant provider.
|
|
37
|
+
- **Framework Agnostic**: Core logic is pure Python; first-class support for **FastAPI**.
|
|
38
|
+
- **Clean Architecture**: Domain logic is isolated from infrastructure concerns.
|
|
39
|
+
- **Type Safe**: Fully typed and checked with `mypy`.
|
|
40
|
+
- **Automatic JWKS Rotation**: Fetches and caches keys from your provider's OIDC discovery endpoint.
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
pip install open-shield-python
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## Quick Start (FastAPI)
|
|
49
|
+
|
|
50
|
+
1. **Configure Environment**
|
|
51
|
+
|
|
52
|
+
Set the following environment variables:
|
|
53
|
+
|
|
54
|
+
```bash
|
|
55
|
+
OPEN_SHIELD_ISSUER_URL=https://your-auth-domain.com/realms/myrealm
|
|
56
|
+
OPEN_SHIELD_AUDIENCE=my-api-identifier
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
2. **Add Middleware**
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
from fastapi import FastAPI, Depends
|
|
63
|
+
from open_shield.api.fastapi import OpenShieldMiddleware, get_user_context, RequireScope, RequireRole
|
|
64
|
+
from open_shield.adapters import OpenShieldConfig
|
|
65
|
+
from open_shield.domain.entities import UserContext
|
|
66
|
+
|
|
67
|
+
app = FastAPI()
|
|
68
|
+
|
|
69
|
+
# Load config from environment
|
|
70
|
+
config = OpenShieldConfig()
|
|
71
|
+
|
|
72
|
+
# Add global authentication middleware
|
|
73
|
+
app.add_middleware(OpenShieldMiddleware, config=config)
|
|
74
|
+
|
|
75
|
+
# Public route (excluded by default: /docs, /openapi.json, /redoc, /health)
|
|
76
|
+
@app.get("/health")
|
|
77
|
+
def health():
|
|
78
|
+
return {"status": "ok"}
|
|
79
|
+
|
|
80
|
+
# Protected route (Authentication required)
|
|
81
|
+
@app.get("/users/me")
|
|
82
|
+
def read_current_user(context: UserContext = Depends(get_user_context)):
|
|
83
|
+
return {
|
|
84
|
+
"id": context.user.id,
|
|
85
|
+
"email": context.user.email,
|
|
86
|
+
"roles": context.user.roles
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
# Scoped route (Authorization required)
|
|
90
|
+
@app.get("/admin/dashboard")
|
|
91
|
+
def admin_dashboard(context: UserContext = Depends(RequireScope("read:admin"))):
|
|
92
|
+
return {"data": "secret admin data"}
|
|
93
|
+
|
|
94
|
+
# Role-based route
|
|
95
|
+
@app.get("/manager/reports")
|
|
96
|
+
def reports(context: UserContext = Depends(RequireRole(["manager", "admin"]))):
|
|
97
|
+
return {"reports": []}
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
## Architecture
|
|
101
|
+
|
|
102
|
+
This SDK follows **Clean Architecture** principles:
|
|
103
|
+
|
|
104
|
+
- **Domain Layer**: Pure Python logic, entities, and interfaces (Ports). No web-framework or network dependencies (entities are built on Pydantic).
|
|
105
|
+
- **Adapters Layer**: Concrete implementations of Ports (e.g., `PyJWT` for validation, `httpx` for JWKS).
|
|
106
|
+
- **API Layer**: Framework specific glue code (e.g., FastAPI Middleware).
|
|
107
|
+
|
|
108
|
+
## Configuration
|
|
109
|
+
|
|
110
|
+
| Environment Variable | Description | Default |
|
|
111
|
+
|----------------------|-------------|---------|
|
|
112
|
+
| `OPEN_SHIELD_ISSUER_URL` | OIDC Issuer URL (required) | - |
|
|
113
|
+
| `OPEN_SHIELD_AUDIENCE` | Expected audience (`aud` claim) | None |
|
|
114
|
+
| `OPEN_SHIELD_ALGORITHMS` | Allowed signing algorithms (JSON list, e.g. `["RS256", "ES256"]`) | `["RS256"]` |
|
|
115
|
+
| `OPEN_SHIELD_REQUIRE_SCOPES` | Enforce scope presence | `True` |
|
|
116
|
+
|
|
117
|
+
## Development
|
|
118
|
+
|
|
119
|
+
This project uses `uv` for dependency management.
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
# Install dependencies
|
|
123
|
+
uv sync
|
|
124
|
+
|
|
125
|
+
# Run tests
|
|
126
|
+
uv run pytest
|
|
127
|
+
|
|
128
|
+
# Lint and Format
|
|
129
|
+
uv run ruff check .
|
|
130
|
+
```
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
open_shield/adapters/__init__.py,sha256=8FC2rMABRPcXZkl_d5DlFeknMYE5uYCe78mB-LsKLYE,202
|
|
2
|
+
open_shield/adapters/config.py,sha256=qx4XTFxMrtan1OLEmsZU1Os5-zKBs83G8KHI3DZDCKI,634
|
|
3
|
+
open_shield/adapters/key_provider.py,sha256=9zB3PyZZ_3t8wgzUxVWeXHv73Ktc9aI4PnmtAxVGFXc,4271
|
|
4
|
+
open_shield/adapters/token_validator.py,sha256=yXtnMiABrtsKEXMBDoNrNOrrx2Tpajqz7nCx1PV3Ipw,2653
|
|
5
|
+
open_shield/api/__init__.py,sha256=nWkMRvKxlDGQ3rvy6sBTePl8wxCQAAYzoFLhswk-J4Y,45
|
|
6
|
+
open_shield/api/fastapi/__init__.py,sha256=Z2ZLoj_sUY_ZWDx3hCZGvg3YN-BaK2c7yS6HvY_Jh6Q,202
|
|
7
|
+
open_shield/api/fastapi/dependencies.py,sha256=cal5CprB9ASnCjnFAwmU5qH7NwxXPuSHnKljYvRf2sc,1622
|
|
8
|
+
open_shield/api/fastapi/middleware.py,sha256=iEdtB4krvuMJm3zwj3bxQiAgNWaZ1WgnYvBO2K8CtCM,3144
|
|
9
|
+
open_shield/domain/__init__.py,sha256=Oca7GxVhjz5klLy5qTXogPT1jBSSqRtyYO6LN9qokk4,368
|
|
10
|
+
open_shield/domain/entities.py,sha256=5VDt5iJ3PQmQYrj-npEPM39sUfglF0H5dVhlfzcYwK8,1676
|
|
11
|
+
open_shield/domain/exceptions.py,sha256=ECizwHbgKNJeSdP0tTotDoDXth3gWnIoQjzdGjq8pkI,779
|
|
12
|
+
open_shield/domain/ports/__init__.py,sha256=Uv_NXCSY2QY18M875r6DBX5ibmw6ea19I-6S091N0Dc,143
|
|
13
|
+
open_shield/domain/ports/key_provider.py,sha256=sKb5ePgDPbphMcAuNVpCD1jCgu3KGk07ZHNjbq7g2qs,776
|
|
14
|
+
open_shield/domain/ports/token_validator.py,sha256=y9LnLoaMhiOWDwnURmJ4PKAypzr--tq1YjWnar-wfiA,938
|
|
15
|
+
open_shield/domain/services/__init__.py,sha256=H65tB4iEDNyKT2e34GkqCovzREK7yJXAKrQqS-lv4ZM,148
|
|
16
|
+
open_shield/domain/services/authorization_service.py,sha256=YtMpLhzCUFxoQpxn_e28qZ9MVcaBaGGsqccK67Hd4tw,1581
|
|
17
|
+
open_shield/domain/services/token_service.py,sha256=aNu5PNZJzTJ4Z0SPtMAD5isYecYn4m7qCMzG72j0ts0,2778
|
|
18
|
+
open_shield_python-0.1.0.dist-info/METADATA,sha256=aNt-xkCA4FBJatYOHKKqtVu6q4raOdO_YItx0ct8vNY,4200
|
|
19
|
+
open_shield_python-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
20
|
+
open_shield_python-0.1.0.dist-info/licenses/LICENSE,sha256=33ACXoHh8AAfHNbU6jphGhoIHyJOccH1mVG1lXn60DQ,1071
|
|
21
|
+
open_shield_python-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Prayog AI Labs
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|