belgie-oauth 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- belgie_oauth-0.2.0/PKG-INFO +22 -0
- belgie_oauth-0.2.0/README.md +8 -0
- {belgie_oauth-0.1.0 → belgie_oauth-0.2.0}/pyproject.toml +12 -3
- belgie_oauth-0.2.0/src/belgie_oauth/__init__.py +10 -0
- belgie_oauth-0.2.0/src/belgie_oauth/metadata.py +43 -0
- belgie_oauth-0.2.0/src/belgie_oauth/models.py +111 -0
- belgie_oauth-0.2.0/src/belgie_oauth/plugin.py +444 -0
- belgie_oauth-0.2.0/src/belgie_oauth/provider.py +223 -0
- belgie_oauth-0.2.0/src/belgie_oauth/settings.py +38 -0
- belgie_oauth-0.2.0/src/belgie_oauth/utils.py +27 -0
- belgie_oauth-0.1.0/PKG-INFO +0 -9
- belgie_oauth-0.1.0/README.md +0 -0
- belgie_oauth-0.1.0/src/belgie_oauth/__init__.py +0 -2
- {belgie_oauth-0.1.0 → belgie_oauth-0.2.0}/src/belgie_oauth/py.typed +0 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: belgie-oauth
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author: Matt LeMay
|
|
6
|
+
Author-email: Matt LeMay <mplemay@users.noreply.github.com>
|
|
7
|
+
Requires-Dist: belgie-core
|
|
8
|
+
Requires-Dist: fastapi>=0.100
|
|
9
|
+
Requires-Dist: pydantic>=2.0
|
|
10
|
+
Requires-Dist: pydantic-settings>=2.0
|
|
11
|
+
Requires-Dist: python-multipart>=0.0.20
|
|
12
|
+
Requires-Python: >=3.12, <3.15
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
# belgie-oauth
|
|
16
|
+
|
|
17
|
+
OAuth 2.1 authorization server package for Belgie.
|
|
18
|
+
|
|
19
|
+
## Persistence
|
|
20
|
+
|
|
21
|
+
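`SimpleOAuthProvider` keeps registered clients, authorization codes, pending authorization state, and tokens in memory, so all of it is lost when the process restarts. For production deployments, replace or extend the provider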
|
|
22
|
+
with persistent storage.
|
|
@@ -1,14 +1,23 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "belgie-oauth"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.2.0"
|
|
4
4
|
description = "Add your description here"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
authors = [
|
|
7
7
|
{ name = "Matt LeMay", email = "mplemay@users.noreply.github.com" }
|
|
8
8
|
]
|
|
9
|
-
requires-python = ">=3.12"
|
|
10
|
-
dependencies = [
|
|
9
|
+
requires-python = ">=3.12,<3.15"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"belgie-core",
|
|
12
|
+
"fastapi>=0.100",
|
|
13
|
+
"pydantic>=2.0",
|
|
14
|
+
"pydantic-settings>=2.0",
|
|
15
|
+
"python-multipart>=0.0.20",
|
|
16
|
+
]
|
|
11
17
|
|
|
12
18
|
[build-system]
|
|
13
19
|
requires = ["uv_build>=0.9.28,<0.10.0"]
|
|
14
20
|
build-backend = "uv_build"
|
|
21
|
+
|
|
22
|
+
[tool.uv.build-backend]
|
|
23
|
+
source-exclude = ["**/__tests__/**"]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from belgie_oauth.metadata import build_oauth_metadata, build_oauth_metadata_well_known_path
|
|
2
|
+
from belgie_oauth.plugin import OAuthPlugin
|
|
3
|
+
from belgie_oauth.settings import OAuthSettings
|
|
4
|
+
|
|
5
|
+
# Public API of the package: the plugin, its settings, and the RFC 8414
# metadata helpers (all imported above from their submodules).
__all__ = [
    "OAuthPlugin",
    "OAuthSettings",
    "build_oauth_metadata",
    "build_oauth_metadata_well_known_path",
]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
from urllib.parse import urlparse
|
|
5
|
+
|
|
6
|
+
from pydantic import AnyHttpUrl
|
|
7
|
+
|
|
8
|
+
from belgie_oauth.models import OAuthMetadata
|
|
9
|
+
from belgie_oauth.utils import join_url
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from belgie_oauth.settings import OAuthSettings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def build_oauth_metadata(issuer_url: str, settings: OAuthSettings) -> OAuthMetadata:
    """Build the authorization server metadata document for *issuer_url*.

    Every endpoint URL is the issuer URL joined with a fixed path segment
    (``authorize``, ``token``, ``register``, ``revoke``, ``introspect``).
    The only value read from *settings* is ``default_scope``, which becomes
    the single advertised scope.
    """
    # Resolve all endpoint URLs first, in the same order as before, then the issuer.
    endpoints = {
        segment: AnyHttpUrl(join_url(issuer_url, segment))
        for segment in ("authorize", "token", "register", "revoke", "introspect")
    }

    return OAuthMetadata(
        issuer=AnyHttpUrl(issuer_url),
        authorization_endpoint=endpoints["authorize"],
        token_endpoint=endpoints["token"],
        registration_endpoint=endpoints["register"],
        scopes_supported=[settings.default_scope],
        # Only the authorization-code flow with PKCE S256 is implemented.
        response_types_supported=["code"],
        grant_types_supported=["authorization_code"],
        token_endpoint_auth_methods_supported=["client_secret_post"],
        code_challenge_methods_supported=["S256"],
        revocation_endpoint=endpoints["revoke"],
        revocation_endpoint_auth_methods_supported=["client_secret_post"],
        introspection_endpoint=endpoints["introspect"],
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def build_oauth_metadata_well_known_path(issuer_url: str) -> str:
    """Return the well-known metadata path for *issuer_url*.

    The issuer's path component, with trailing slashes removed, is appended
    after ``/.well-known/oauth-authorization-server``. For example,
    ``https://a.example/auth`` maps to
    ``/.well-known/oauth-authorization-server/auth``. An issuer with no path
    (or only slashes) maps to the bare well-known path.
    """
    well_known = "/.well-known/oauth-authorization-server"
    issuer_path = urlparse(issuer_url).path.rstrip("/")
    return f"{well_known}{issuer_path}" if issuer_path else well_known
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class OAuthToken(BaseModel):
    """Body of a successful token endpoint response."""

    access_token: str
    token_type: Literal["Bearer"] = "Bearer"  # noqa: S105
    expires_in: int | None = None
    scope: str | None = None
    refresh_token: str | None = None

    @field_validator("token_type", mode="before")
    @classmethod
    def normalize_token_type(cls, value: str | None) -> str | None:
        """Title-case string token types (``"bearer"`` -> ``"Bearer"``) before the literal check.

        Non-string values are passed through unchanged.
        """
        return value.title() if isinstance(value, str) else value
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class InvalidScopeError(Exception):
    """Raised when an authorization request asks for a scope the client did not register.

    Attributes:
        message: Human-readable description, also available as ``str(exc)``.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to Exception so str(exc), exc.args, tracebacks and
        # pickling all carry it; previously exc.args was empty.
        super().__init__(message)
        self.message = message
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class InvalidRedirectUriError(Exception):
    """Raised when an authorization request's redirect URI cannot be accepted for the client.

    Attributes:
        message: Human-readable description, also available as ``str(exc)``.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to Exception so str(exc), exc.args, tracebacks and
        # pickling all carry it; previously exc.args was empty.
        super().__init__(message)
        self.message = message
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OAuthClientMetadata(BaseModel):
    """Client metadata submitted to the registration endpoint.

    Besides the data fields, the model offers two helpers. The authorize
    endpoint uses them to check a request against what the client registered.
    """

    # Required key; when given as a list it must contain at least one URI.
    redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
    token_endpoint_auth_method: (
        Literal[
            "none",
            "client_secret_post",
            "client_secret_basic",
            "private_key_jwt",
        ]
        | None
    ) = None
    grant_types: list[str] = ["authorization_code", "refresh_token"]
    response_types: list[str] = ["code"]
    # Space-separated list of scopes the client may request.
    scope: str | None = None

    client_name: str | None = None
    client_uri: AnyHttpUrl | None = None
    logo_uri: AnyHttpUrl | None = None
    contacts: list[str] | None = None
    tos_uri: AnyHttpUrl | None = None
    policy_uri: AnyHttpUrl | None = None
    jwks_uri: AnyHttpUrl | None = None
    jwks: Any | None = None
    software_id: str | None = None
    software_version: str | None = None

    def validate_scope(self, requested_scope: str | None) -> list[str] | None:
        """Check a space-separated scope string against the client's registered scopes.

        Returns:
            ``None`` when no scope was requested; otherwise the requested scopes,
            split on single spaces.

        Raises:
            InvalidScopeError: if any requested scope is not in ``self.scope``.
        """
        if requested_scope is None:
            return None
        wanted = requested_scope.split(" ")
        registered = [] if self.scope is None else self.scope.split(" ")
        for candidate in wanted:
            if candidate in registered:
                continue
            message = f"Client was not registered with scope {candidate}"
            raise InvalidScopeError(message)
        return wanted

    def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
        """Resolve the redirect URI to use for an authorization request.

        An explicit *redirect_uri* must be one of ``redirect_uris``. When it is
        omitted, the client's single registered URI is used.

        Raises:
            InvalidRedirectUriError: if the explicit URI is not registered, or if
                none was given and the client does not have exactly one
                registered URI.
        """
        registered = self.redirect_uris
        if redirect_uri is None:
            if registered is None or len(registered) != 1:
                message = "redirect_uri must be specified when client has multiple registered URIs"
                raise InvalidRedirectUriError(message)
            return registered[0]
        if registered is None or redirect_uri not in registered:
            message = f"Redirect URI '{redirect_uri}' not registered for client"
            raise InvalidRedirectUriError(message)
        return redirect_uri
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class OAuthClientInformationFull(OAuthClientMetadata):
    """A registered client: its metadata plus the credentials the server holds for it."""

    client_id: str | None = None
    # Falsy for clients without a secret; the token endpoint then skips the secret check.
    client_secret: str | None = None
    # NOTE(review): presumably Unix timestamps as in RFC 7591 section 3.2.1 - confirm against register_client.
    client_id_issued_at: int | None = None
    client_secret_expires_at: int | None = None
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class OAuthMetadata(BaseModel):
    """Authorization server metadata document served at the well-known path.

    Field names follow the RFC 8414 metadata parameters; optional fields
    default to ``None``.
    """

    issuer: AnyHttpUrl
    authorization_endpoint: AnyHttpUrl
    token_endpoint: AnyHttpUrl
    registration_endpoint: AnyHttpUrl | None = None
    scopes_supported: list[str] | None = None
    response_types_supported: list[str] = ["code"]
    response_modes_supported: list[str] | None = None
    grant_types_supported: list[str] | None = None
    token_endpoint_auth_methods_supported: list[str] | None = None
    token_endpoint_auth_signing_alg_values_supported: list[str] | None = None
    service_documentation: AnyHttpUrl | None = None
    ui_locales_supported: list[str] | None = None
    op_policy_uri: AnyHttpUrl | None = None
    op_tos_uri: AnyHttpUrl | None = None
    revocation_endpoint: AnyHttpUrl | None = None
    revocation_endpoint_auth_methods_supported: list[str] | None = None
    revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None
    introspection_endpoint: AnyHttpUrl | None = None
    introspection_endpoint_auth_methods_supported: list[str] | None = None
    introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None
    code_challenge_methods_supported: list[str] | None = None
    # NOTE(review): not an RFC 8414 parameter; presumably an extension field
    # (e.g. from the MCP auth spec) - confirm before relying on it.
    client_id_metadata_document_supported: bool | None = None
|
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
import time
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
from urllib.parse import urlparse, urlunparse
|
|
7
|
+
|
|
8
|
+
from belgie_core.core.protocols import Plugin
|
|
9
|
+
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
10
|
+
from fastapi.responses import JSONResponse, RedirectResponse, Response
|
|
11
|
+
from fastapi.security import SecurityScopes
|
|
12
|
+
from pydantic import AnyUrl, ValidationError
|
|
13
|
+
|
|
14
|
+
from belgie_oauth.metadata import build_oauth_metadata, build_oauth_metadata_well_known_path
|
|
15
|
+
from belgie_oauth.models import (
|
|
16
|
+
InvalidRedirectUriError,
|
|
17
|
+
InvalidScopeError,
|
|
18
|
+
OAuthClientInformationFull,
|
|
19
|
+
OAuthClientMetadata,
|
|
20
|
+
OAuthMetadata,
|
|
21
|
+
)
|
|
22
|
+
from belgie_oauth.provider import AuthorizationParams, SimpleOAuthProvider
|
|
23
|
+
from belgie_oauth.utils import construct_redirect_uri, create_code_challenge, join_url
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from collections.abc import Mapping
|
|
27
|
+
|
|
28
|
+
from belgie_core.core.belgie import Belgie
|
|
29
|
+
from belgie_core.core.client import BelgieClient
|
|
30
|
+
|
|
31
|
+
from belgie_oauth.settings import OAuthSettings
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class OAuthPlugin(Plugin):
    """Belgie plugin that serves an OAuth 2.1 authorization server.

    ``router()`` mounts these endpoints under ``settings.route_prefix``:
    authorize, token, register, revoke, login, login callback, introspect and
    metadata. ``public_router()`` exposes the metadata document at the
    issuer's RFC 8414 well-known path. Client, code and token state is
    delegated to a ``SimpleOAuthProvider`` created on the first ``router()``
    call.
    """

    def __init__(self, settings: OAuthSettings) -> None:
        self._settings = settings
        # Built lazily in router(), once the issuer URL can be derived from Belgie.
        self._provider: SimpleOAuthProvider | None = None
        # Cached so public_router() can return the router built during router().
        self._metadata_router: APIRouter | None = None

    def router(self, belgie: Belgie) -> APIRouter:
        """Return the prefixed router containing every OAuth endpoint.

        The issuer URL is ``settings.issuer_url`` when set. Otherwise it is
        derived from ``belgie.settings.base_url`` (see ``_build_issuer_url``).
        The provider is created on the first call and reused afterwards.
        """
        issuer_url = (
            str(self._settings.issuer_url) if self._settings.issuer_url else _build_issuer_url(belgie, self._settings)
        )
        if self._provider is None:
            self._provider = SimpleOAuthProvider(self._settings, issuer_url=issuer_url)
        provider = self._provider

        self._metadata_router = self.metadata_router(belgie)

        router = APIRouter(prefix=self._settings.route_prefix, tags=["oauth"])
        metadata = build_oauth_metadata(issuer_url, self._settings)

        router = self._add_metadata_route(router, metadata)
        router = self._add_authorize_route(router, belgie, provider, self._settings, issuer_url)
        router = self._add_token_route(router, provider)
        router = self._add_register_route(router, provider)
        router = self._add_revoke_route(router, provider)
        router = self._add_login_route(router, belgie, issuer_url, self._settings)
        router = self._add_login_callback_route(router, belgie, provider)
        return self._add_introspect_route(router, provider)

    def metadata_router(self, belgie: Belgie) -> APIRouter:
        """Return an unprefixed router serving the metadata document at the issuer's well-known path."""
        issuer_url = (
            str(self._settings.issuer_url) if self._settings.issuer_url else _build_issuer_url(belgie, self._settings)
        )
        metadata = build_oauth_metadata(issuer_url, self._settings)
        well_known_path = build_oauth_metadata_well_known_path(issuer_url)

        def create_oauth_metadata_router() -> APIRouter:
            router = APIRouter(tags=["oauth"])

            async def metadata_handler(_: Request) -> Response:
                return JSONResponse(metadata.model_dump(mode="json"))

            router.add_api_route(well_known_path, metadata_handler, methods=["GET"])
            return router

        return create_oauth_metadata_router()

    def public_router(self, belgie: Belgie) -> APIRouter:
        """Return the well-known metadata router, building it if ``router()`` has not run yet."""
        if self._metadata_router is None:
            self._metadata_router = self.metadata_router(belgie)
        return self._metadata_router

    @staticmethod
    def _add_metadata_route(router: APIRouter, metadata: OAuthMetadata) -> APIRouter:
        """Serve *metadata* at ``/.well-known/oauth-authorization-server`` relative to the router prefix."""

        async def metadata_handler(_: Request) -> Response:
            return JSONResponse(metadata.model_dump(mode="json"))

        router.add_api_route(
            "/.well-known/oauth-authorization-server",
            metadata_handler,
            methods=["GET"],
        )
        return router

    @staticmethod
    def _add_authorize_route(
        router: APIRouter,
        belgie: Belgie,
        provider: SimpleOAuthProvider,
        settings: OAuthSettings,
        issuer_url: str,
    ) -> APIRouter:
        """Register ``GET``/``POST /authorize``.

        A signed-in user receives a redirect carrying an authorization code.
        If the user is not signed in (401 from ``get_user``), the request is
        stored and the browser is sent to the plugin's ``/login`` route when
        ``settings.login_url`` is configured. Otherwise a 401
        ``login_required`` is raised.
        """

        async def authorize_handler(
            request: Request,
            client: BelgieClient = Depends(belgie),  # noqa: B008
        ) -> Response:
            data = await _get_request_params(request)
            oauth_client, params = await _parse_authorize_params(data, provider, settings)

            try:
                await client.get_user(SecurityScopes(), request)
            except HTTPException as exc:
                if exc.status_code == status.HTTP_401_UNAUTHORIZED:
                    if not settings.login_url:
                        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="login_required") from exc

                    # Persist the request so /login/callback can finish it after sign-in.
                    state_value = await _authorize_state(provider, oauth_client, params)
                    login_url = _build_login_redirect(issuer_url, state_value)
                    return RedirectResponse(url=login_url, status_code=status.HTTP_302_FOUND)
                raise

            state_value = await _authorize_state(provider, oauth_client, params)
            redirect_url = await _issue_authorization_code(provider, state_value)
            return RedirectResponse(url=redirect_url, status_code=status.HTTP_302_FOUND)

        router.add_api_route("/authorize", authorize_handler, methods=["GET", "POST"])
        return router

    @staticmethod
    def _add_token_route(router: APIRouter, provider: SimpleOAuthProvider) -> APIRouter:  # noqa: C901
        """Register ``POST /token`` for the ``authorization_code`` grant with PKCE (S256)."""

        async def token_handler(request: Request) -> Response:  # noqa: C901, PLR0911
            form = await request.form()
            grant_type = _get_str(form, "grant_type")
            if grant_type != "authorization_code":
                return _oauth_error("unsupported_grant_type", status_code=400)

            code = _get_str(form, "code")
            if not code:
                return _oauth_error("invalid_request", "missing code", status_code=400)

            client_id = _get_str(form, "client_id")
            if not client_id:
                return _oauth_error("invalid_client", status_code=401)

            oauth_client = await provider.get_client(client_id)
            if not oauth_client:
                return _oauth_error("invalid_client", status_code=401)

            # Clients without a stored secret (public clients) skip this check.
            # Constant-time comparison avoids leaking the secret through response timing.
            client_secret = _get_str(form, "client_secret")
            if oauth_client.client_secret and not _constant_time_equals(client_secret, oauth_client.client_secret):
                return _oauth_error("invalid_client", status_code=401)

            authorization_code = await provider.load_authorization_code(code)
            if not authorization_code:
                return _oauth_error("invalid_grant", status_code=400)

            if authorization_code.expires_at < time.time():
                return _oauth_error("invalid_grant", "code expired", status_code=400)

            redirect_uri_raw = _get_str(form, "redirect_uri")
            if client_id != authorization_code.client_id:
                return _oauth_error("invalid_grant", "client_id mismatch", status_code=400)

            # redirect_uri is mandatory here only if the authorization request carried one.
            if authorization_code.redirect_uri_provided_explicitly and not redirect_uri_raw:
                return _oauth_error("invalid_request", "missing redirect_uri", status_code=400)
            if redirect_uri_raw and redirect_uri_raw != str(authorization_code.redirect_uri):
                return _oauth_error("invalid_grant", "redirect_uri mismatch", status_code=400)

            code_verifier = _get_str(form, "code_verifier")
            if not code_verifier:
                return _oauth_error("invalid_request", "missing code_verifier", status_code=400)

            expected_challenge = create_code_challenge(code_verifier)
            if not _constant_time_equals(expected_challenge, authorization_code.code_challenge):
                return _oauth_error("invalid_grant", "invalid code_verifier", status_code=400)

            token = await provider.exchange_authorization_code(authorization_code)
            # RFC 6749 section 5.1: responses containing tokens must not be cached.
            return JSONResponse(
                token.model_dump(),
                headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
            )

        router.add_api_route("/token", token_handler, methods=["POST"])
        return router

    @staticmethod
    def _add_register_route(router: APIRouter, provider: SimpleOAuthProvider) -> APIRouter:
        """Register ``POST /register`` (dynamic client registration from a JSON body)."""

        async def register_handler(request: Request) -> Response:
            try:
                payload = await request.json()
                metadata = OAuthClientMetadata.model_validate(payload)
            except ValidationError as exc:
                return _oauth_error(
                    "invalid_request",
                    _format_validation_error(exc),
                    status_code=400,
                )
            except ValueError as exc:
                # Covers malformed JSON bodies (json.JSONDecodeError is a ValueError).
                description = str(exc) or "invalid client metadata"
                return _oauth_error("invalid_request", description, status_code=400)

            try:
                client_info = await provider.register_client(metadata)
            except ValueError as exc:
                description = str(exc) or "invalid client metadata"
                return _oauth_error("invalid_request", description, status_code=400)
            return JSONResponse(client_info.model_dump(mode="json"))

        router.add_api_route("/register", register_handler, methods=["POST"])
        return router

    @staticmethod
    def _add_revoke_route(router: APIRouter, provider: SimpleOAuthProvider) -> APIRouter:
        """Register ``POST /revoke`` for access-token revocation by an authenticated client."""

        async def revoke_handler(request: Request) -> Response:
            form = await request.form()
            client_id: str | None = _get_str(form, "client_id")
            if not client_id:
                return _oauth_error("invalid_request", "missing client_id", status_code=400)

            oauth_client = await provider.get_client(client_id)
            if not oauth_client:
                return _oauth_error("invalid_client", status_code=401)

            client_secret: str | None = _get_str(form, "client_secret")
            if oauth_client.client_secret:
                if not client_secret:
                    return _oauth_error("invalid_request", "missing client_secret", status_code=400)
                if not _constant_time_equals(client_secret, oauth_client.client_secret):
                    return _oauth_error("invalid_client", status_code=401)

            token: str | None = _get_str(form, "token")
            if not token:
                return _oauth_error("invalid_request", "missing token", status_code=400)

            access_token = await provider.load_access_token(token)
            # RFC 7009 section 2.1: a client may only revoke tokens issued to it.
            # Other clients' tokens are left intact, and the uniform 200 does not
            # reveal whether the token exists.
            if access_token and access_token.client_id == client_id:
                await provider.revoke_token(access_token)
            return JSONResponse({})

        router.add_api_route("/revoke", revoke_handler, methods=["POST"])
        return router

    @staticmethod
    def _add_login_route(
        router: APIRouter,
        belgie: Belgie,
        issuer_url: str,
        settings: OAuthSettings,
    ) -> APIRouter:
        """Register ``GET /login``, which forwards the browser to ``settings.login_url``.

        A relative ``login_url`` is joined to ``belgie.settings.base_url``. The
        login page receives a ``return_to`` URL pointing at this plugin's
        ``/login/callback`` with the pending ``state``.
        """

        async def login_handler(request: Request) -> Response:
            state = request.query_params.get("state")
            if not state:
                raise HTTPException(status_code=400, detail="missing state")

            if not settings.login_url:
                raise HTTPException(status_code=400, detail="login_url not configured")

            parsed_login_url = urlparse(settings.login_url)
            if parsed_login_url.scheme in {"http", "https"}:
                login_url = settings.login_url
            else:
                login_url = join_url(belgie.settings.base_url, settings.login_url)

            return_to_base = join_url(issuer_url, "login/callback")
            # Build a callback URL with state, then wrap it into the login redirect as return_to.
            return_to_url = construct_redirect_uri(return_to_base, state=state)
            redirect_url = construct_redirect_uri(login_url, return_to=return_to_url)
            return RedirectResponse(url=redirect_url, status_code=status.HTTP_302_FOUND)

        router.add_api_route("/login", login_handler, methods=["GET"])
        return router

    @staticmethod
    def _add_login_callback_route(
        router: APIRouter,
        belgie: Belgie,
        provider: SimpleOAuthProvider,
    ) -> APIRouter:
        """Register ``GET /login/callback``: once the user is signed in, issue the code for ``state``."""

        async def login_callback_handler(
            request: Request,
            client: BelgieClient = Depends(belgie),  # noqa: B008
        ) -> Response:
            state = request.query_params.get("state")
            if not state:
                raise HTTPException(status_code=400, detail="missing state")

            try:
                await client.get_user(SecurityScopes(), request)
            except HTTPException as exc:
                if exc.status_code == status.HTTP_401_UNAUTHORIZED:
                    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="login_required") from exc
                raise

            try:
                redirect_url = await provider.issue_authorization_code(state)
            except ValueError as exc:
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            return RedirectResponse(url=redirect_url, status_code=status.HTTP_302_FOUND)

        router.add_api_route("/login/callback", login_callback_handler, methods=["GET"])
        return router

    @staticmethod
    def _add_introspect_route(router: APIRouter, provider: SimpleOAuthProvider) -> APIRouter:
        """Register ``POST /introspect``, reporting whether an access token is active.

        NOTE(review): this endpoint does not authenticate the caller. RFC 7662
        section 2.1 expects protected-resource authentication; consider adding it.
        """

        async def introspect_handler(request: Request) -> Response:
            form = await request.form()
            token = _get_str(form, "token")
            if not token:
                return JSONResponse({"active": False}, status_code=400)

            access_token = await provider.load_access_token(token)
            if not access_token:
                return JSONResponse({"active": False})

            # Never report an expired token as active, even if the store still holds it.
            if access_token.expires_at is not None and access_token.expires_at < time.time():
                return JSONResponse({"active": False})

            return JSONResponse(
                {
                    "active": True,
                    "client_id": access_token.client_id,
                    "scope": " ".join(access_token.scopes),
                    "exp": access_token.expires_at,
                    "iat": access_token.created_at,
                    "token_type": "Bearer",
                    "aud": access_token.resource,
                },
            )

        router.add_api_route("/introspect", introspect_handler, methods=["POST"])
        return router


def _constant_time_equals(provided: str | None, expected: str) -> bool:
    """Compare a client-supplied credential with the expected value in constant time.

    Both sides are UTF-8 encoded first, because ``secrets.compare_digest``
    raises ``TypeError`` for non-ASCII ``str`` input. A missing (``None``)
    value never matches.
    """
    if provided is None:
        return False
    return secrets.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _build_issuer_url(belgie: Belgie, settings: OAuthSettings) -> str:
    """Derive the default issuer URL from Belgie's base URL.

    The result is ``<base_url path>/auth/<route_prefix>``, or ``<base_url
    path>/auth`` when the prefix is empty. Scheme, host and port are kept;
    query and fragment are dropped.
    """
    base = urlparse(belgie.settings.base_url)
    segments = [base.path.rstrip("/"), "auth"]
    route_prefix = settings.route_prefix.strip("/")
    if route_prefix:
        segments.append(route_prefix)
    return urlunparse(base._replace(path="/".join(segments), query="", fragment=""))
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
async def _parse_authorize_params(
    data: dict[str, str],
    provider: SimpleOAuthProvider,
    settings: OAuthSettings,
) -> tuple[OAuthClientInformationFull, AuthorizationParams]:
    """Validate an authorization request and turn it into provider parameters.

    Args:
        data: Request parameters (query string or form body).
        provider: Used to look up the client by ``client_id``.
        settings: Supplies the default scope and the default PKCE method.

    Returns:
        The registered client and the validated ``AuthorizationParams``.

    Raises:
        HTTPException: 400 for an unsupported ``response_type``, a
            missing/unknown client, a malformed or unregistered
            ``redirect_uri``, an unregistered scope, or a missing/unsupported
            PKCE challenge.
    """
    response_type = _get_str(data, "response_type")
    if response_type != "code":
        raise HTTPException(status_code=400, detail="unsupported_response_type")

    client_id = _get_str(data, "client_id")
    if not client_id:
        raise HTTPException(status_code=400, detail="missing client_id")

    oauth_client = await provider.get_client(client_id)
    if not oauth_client:
        raise HTTPException(status_code=400, detail="invalid_client")

    redirect_uri_raw = _get_str(data, "redirect_uri")
    redirect_uri: AnyUrl | None = None
    if redirect_uri_raw:
        try:
            redirect_uri = AnyUrl(redirect_uri_raw)
        except ValidationError as exc:
            # A malformed URI is a client error; without this it surfaced as a 500.
            raise HTTPException(status_code=400, detail="invalid redirect_uri") from exc
    try:
        validated_redirect_uri = oauth_client.validate_redirect_uri(redirect_uri)
    except InvalidRedirectUriError as exc:
        raise HTTPException(status_code=400, detail=exc.message) from exc

    scope_raw = _get_str(data, "scope")
    try:
        scopes = oauth_client.validate_scope(scope_raw)
    except InvalidScopeError as exc:
        raise HTTPException(status_code=400, detail=exc.message) from exc
    if scopes is None:
        scopes = [settings.default_scope]

    code_challenge = _get_str(data, "code_challenge")
    if not code_challenge:
        raise HTTPException(status_code=400, detail="missing code_challenge")

    code_challenge_method = _get_str(data, "code_challenge_method") or settings.code_challenge_method
    if code_challenge_method != "S256":
        raise HTTPException(status_code=400, detail="unsupported code_challenge_method")

    resource = _get_str(data, "resource")
    # The state doubles as the provider's lookup key, so always have one.
    state = _get_str(data, "state") or secrets.token_hex(16)

    params = AuthorizationParams(
        state=state,
        scopes=scopes,
        code_challenge=code_challenge,
        redirect_uri=validated_redirect_uri,
        # An empty redirect_uri= is treated as absent above, so it does not count
        # as provided; the token endpoint then does not demand one.
        redirect_uri_provided_explicitly=redirect_uri is not None,
        resource=resource,
    )
    return oauth_client, params
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
async def _authorize_state(
    provider: SimpleOAuthProvider,
    oauth_client: OAuthClientInformationFull,
    params: AuthorizationParams,
) -> str:
    """Register the authorization request with *provider* and return the state it was stored under.

    A ``ValueError`` from the provider becomes an HTTP 400 with the error text as detail.
    """
    try:
        state_value = await provider.authorize(oauth_client, params)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return state_value
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
async def _issue_authorization_code(provider: SimpleOAuthProvider, state: str) -> str:
    """Ask *provider* to issue a code for the pending request *state* and return its redirect URL.

    A ``ValueError`` from the provider becomes an HTTP 400 with the error text as detail.
    """
    try:
        redirect_url = await provider.issue_authorization_code(state)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return redirect_url
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _build_login_redirect(issuer_url: str, state: str) -> str:
    """Return this plugin's ``login`` URL under *issuer_url*, with *state* added via ``construct_redirect_uri``."""
    login_endpoint = join_url(issuer_url, "login")
    return construct_redirect_uri(login_endpoint, state=state)
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def _oauth_error(error: str, description: str | None = None, status_code: int = 400) -> JSONResponse:
    """Build an OAuth JSON error response with the given HTTP status.

    The body is ``{"error": error}``. ``"error_description"`` is added only
    when *description* is non-empty.
    """
    extra: dict[str, Any] = {"error_description": description} if description else {}
    payload: dict[str, Any] = {"error": error} | extra
    return JSONResponse(payload, status_code=status_code)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _format_validation_error(error: ValidationError) -> str:
    """Summarise the first validation error as ``"<loc>: <msg>"``.

    The location parts are dot-joined, skipping ``None`` parts. When there is
    no location, only the message is returned. When there are no errors at
    all, a generic message is returned.
    """
    fallback = "invalid client metadata"
    first = next(iter(error.errors()), None)
    if first is None:
        return fallback
    location = ".".join(str(part) for part in first.get("loc", []) if part is not None)
    message = first.get("msg", fallback)
    return f"{location}: {message}" if location else message
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
async def _get_request_params(request: Request) -> dict[str, str]:
    """Collect request parameters: the query string for ``GET``, otherwise the parsed form body.

    NOTE(review): form values may be uploads rather than strings; callers read
    them through ``_get_str``, which ignores non-string values.
    """
    if request.method != "GET":
        return dict(await request.form())
    return dict(request.query_params)
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def _get_str(data: Mapping[str, Any], key: str) -> str | None:
|
|
441
|
+
value = data.get(key)
|
|
442
|
+
if isinstance(value, str):
|
|
443
|
+
return value
|
|
444
|
+
return None
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from pydantic import AnyUrl
|
|
9
|
+
|
|
10
|
+
from belgie_oauth.models import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
|
11
|
+
from belgie_oauth.utils import construct_redirect_uri
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from belgie_oauth.settings import OAuthSettings
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
class AuthorizationParams:
    """Validated parameters of an authorization request, passed to ``SimpleOAuthProvider.authorize``."""

    # Client-supplied state; the authorize route generates a random value when none is sent.
    state: str | None
    # Requested scopes; the authorize route substitutes the default scope when none are requested.
    scopes: list[str] | None
    # PKCE code challenge (the authorize route only accepts the S256 method).
    code_challenge: str
    # Redirect URI after validation against the client's registered URIs.
    redirect_uri: AnyUrl
    # Whether the request itself named a redirect_uri (vs. falling back to the sole registered one).
    redirect_uri_provided_explicitly: bool
    # Optional ``resource`` request parameter, passed through unchanged.
    resource: str | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
class AuthorizationCode:
    """An issued authorization code, as looked up and checked by the token endpoint."""

    code: str
    scopes: list[str]
    # Absolute expiry, compared against ``time.time()`` by the token endpoint.
    expires_at: float
    # Must equal the client_id presented at the token endpoint.
    client_id: str
    # PKCE challenge; the token endpoint compares it with the S256 hash of code_verifier.
    code_challenge: str
    # Compared (as a string) with any redirect_uri sent to the token endpoint.
    redirect_uri: AnyUrl
    # When True, the token endpoint requires the redirect_uri parameter.
    redirect_uri_provided_explicitly: bool
    resource: str | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
class RefreshToken:
    """Refresh token record.

    NOTE(review): not referenced by the visible plugin routes; the token
    endpoint only accepts the authorization_code grant.
    """

    token: str
    client_id: str
    scopes: list[str]
    # NOTE(review): presumably a Unix timestamp, with None meaning no expiry - confirm.
    expires_at: int | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
|
|
48
|
+
class AccessToken:
|
|
49
|
+
token: str
|
|
50
|
+
client_id: str
|
|
51
|
+
scopes: list[str]
|
|
52
|
+
created_at: int
|
|
53
|
+
expires_at: int | None = None
|
|
54
|
+
resource: str | None = None
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
|
|
58
|
+
class StateEntry:
|
|
59
|
+
redirect_uri: str
|
|
60
|
+
code_challenge: str
|
|
61
|
+
redirect_uri_provided_explicitly: bool
|
|
62
|
+
client_id: str
|
|
63
|
+
resource: str | None
|
|
64
|
+
scopes: list[str] | None
|
|
65
|
+
created_at: float
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SimpleOAuthProvider:
    """In-memory OAuth 2.1 authorization-server backend.

    Registered clients, pending authorization states, authorization codes and access tokens
    are held in per-instance dictionaries; nothing is persisted. The client configured in
    ``OAuthSettings`` is pre-registered on construction. Refresh tokens are not supported:
    ``load_refresh_token`` always returns ``None`` and ``exchange_refresh_token`` raises
    ``NotImplementedError``.
    """

    def __init__(self, settings: OAuthSettings, issuer_url: str) -> None:
        self.settings = settings
        self.issuer_url = issuer_url
        self.clients: dict[str, OAuthClientInformationFull] = {}
        self.auth_codes: dict[str, AuthorizationCode] = {}
        self.tokens: dict[str, AccessToken] = {}
        self.state_mapping: dict[str, StateEntry] = {}

        # Pre-register the statically configured client so it works without dynamic registration.
        client_secret = settings.client_secret.get_secret_value() if settings.client_secret is not None else None
        self.clients[settings.client_id] = OAuthClientInformationFull(
            client_id=settings.client_id,
            client_secret=client_secret,
            redirect_uris=settings.redirect_uris,
            scope=settings.default_scope,
        )

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Return the registered client for ``client_id``, or ``None`` if it is unknown."""
        return self.clients.get(client_id)

    async def register_client(self, metadata: OAuthClientMetadata) -> OAuthClientInformationFull:
        """Register a new client from ``metadata`` and return its stored information.

        ``token_endpoint_auth_method`` defaults to ``"client_secret_post"``. Only that method and
        ``"none"`` (public client, no secret issued) are accepted.

        Raises:
            ValueError: If an unsupported ``token_endpoint_auth_method`` is requested.
        """
        token_endpoint_auth_method = metadata.token_endpoint_auth_method or "client_secret_post"
        if token_endpoint_auth_method not in {"client_secret_post", "none"}:
            msg = f"unsupported token_endpoint_auth_method: {token_endpoint_auth_method}"
            raise ValueError(msg)
        client_secret = None
        if token_endpoint_auth_method != "none":  # noqa: S105
            client_secret = secrets.token_hex(16)

        # Regenerate on the (unlikely) collision with an already registered client id.
        client_id = f"belgie_client_{secrets.token_hex(8)}"
        while client_id in self.clients:
            client_id = f"belgie_client_{secrets.token_hex(8)}"

        metadata_payload = metadata.model_dump()
        metadata_payload["token_endpoint_auth_method"] = token_endpoint_auth_method
        client_info = OAuthClientInformationFull(
            **metadata_payload,
            client_id=client_id,
            client_secret=client_secret,
            client_id_issued_at=int(time.time()),
            client_secret_expires_at=None,
        )
        self.clients[client_id] = client_info
        return client_info

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """Record a pending authorization request and return the state key that identifies it.

        The client's ``state`` is used as the key when supplied; otherwise a random one is
        generated. Expired pending states are purged first.

        Raises:
            ValueError: If a pending request already uses the same state.
        """
        self._purge_state_mapping()
        state = params.state or secrets.token_hex(16)
        if state in self.state_mapping:
            msg = "Authorization state already exists"
            raise ValueError(msg)
        self.state_mapping[state] = StateEntry(
            redirect_uri=str(params.redirect_uri),
            code_challenge=params.code_challenge,
            redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly,
            client_id=client.client_id,
            resource=params.resource,
            scopes=params.scopes,
            created_at=time.time(),
        )
        return state

    async def issue_authorization_code(self, state: str) -> str:
        """Consume a pending authorization state and issue a one-time authorization code.

        Returns the stored redirect URI with ``code`` and ``state`` query parameters appended.
        When the original request named no scopes, ``settings.default_scope`` is used.

        Raises:
            ValueError: If ``state`` is unknown (or has expired and been purged).
        """
        self._purge_state_mapping()
        # Drop codes that expired without being redeemed so the store does not grow without bound.
        self._purge_authorization_codes()
        state_data = self.state_mapping.get(state)
        if not state_data:
            msg = "Invalid state parameter"
            raise ValueError(msg)

        redirect_uri = state_data.redirect_uri
        code_challenge = state_data.code_challenge
        redirect_uri_provided_explicitly = state_data.redirect_uri_provided_explicitly
        client_id = state_data.client_id
        resource = state_data.resource
        scopes = state_data.scopes or [self.settings.default_scope]

        # Defensive: StateEntry is typed non-optional, but guard against malformed entries.
        if redirect_uri is None or code_challenge is None or client_id is None:
            msg = "Invalid authorization state"
            raise ValueError(msg)

        new_code = f"belgie_{secrets.token_hex(16)}"
        auth_code = AuthorizationCode(
            code=new_code,
            client_id=client_id,
            redirect_uri=AnyUrl(redirect_uri),
            redirect_uri_provided_explicitly=bool(redirect_uri_provided_explicitly),
            expires_at=time.time() + self.settings.authorization_code_ttl_seconds,
            scopes=scopes,
            code_challenge=code_challenge,
            resource=resource,
        )
        self.auth_codes[new_code] = auth_code

        # Each state can be redeemed only once.
        del self.state_mapping[state]
        return construct_redirect_uri(redirect_uri, code=new_code, state=state)

    async def load_authorization_code(self, authorization_code: str) -> AuthorizationCode | None:
        """Return the stored authorization code, or ``None`` if it is unknown or expired.

        An expired code is removed from the store, mirroring ``load_access_token``.
        """
        code = self.auth_codes.get(authorization_code)
        if code is None:
            return None
        if code.expires_at < time.time():
            del self.auth_codes[authorization_code]
            return None
        return code

    async def exchange_authorization_code(self, authorization_code: AuthorizationCode) -> OAuthToken:
        """Redeem ``authorization_code`` for a new bearer access token.

        The code is removed from the store so it cannot be redeemed twice.

        Raises:
            ValueError: If the code is not (or no longer) in the store.
        """
        if authorization_code.code not in self.auth_codes:
            msg = "Invalid authorization code"
            raise ValueError(msg)

        # Single timestamp so created_at and expires_at are exactly one TTL apart.
        now = int(time.time())
        access_token = f"belgie_{secrets.token_hex(32)}"
        self.tokens[access_token] = AccessToken(
            token=access_token,
            client_id=authorization_code.client_id,
            scopes=authorization_code.scopes,
            created_at=now,
            expires_at=now + self.settings.access_token_ttl_seconds,
            resource=authorization_code.resource,
        )

        del self.auth_codes[authorization_code.code]

        return OAuthToken(
            access_token=access_token,
            token_type="Bearer",  # noqa: S106
            expires_in=self.settings.access_token_ttl_seconds,
            scope=" ".join(authorization_code.scopes),
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the stored access token, or ``None`` if unknown or expired (expired tokens are removed)."""
        access_token = self.tokens.get(token)
        if not access_token:
            return None

        if access_token.expires_at is not None and access_token.expires_at < time.time():
            del self.tokens[token]
            return None

        return access_token

    def _purge_state_mapping(self, now: float | None = None) -> None:
        """Remove pending states older than ``settings.state_ttl_seconds``.

        A non-positive TTL disables expiry. ``now`` overrides the current time.
        """
        if not self.state_mapping:
            return
        current = time.time() if now is None else now
        ttl_seconds = self.settings.state_ttl_seconds
        if ttl_seconds <= 0:
            return
        expired_states = [
            state for state, entry in self.state_mapping.items() if entry.created_at + ttl_seconds < current
        ]
        for state in expired_states:
            self.state_mapping.pop(state, None)

    def _purge_authorization_codes(self, now: float | None = None) -> None:
        """Remove authorization codes whose ``expires_at`` has passed. ``now`` overrides the current time."""
        if not self.auth_codes:
            return
        current = time.time() if now is None else now
        expired_codes = [code for code, entry in self.auth_codes.items() if entry.expires_at < current]
        for code in expired_codes:
            self.auth_codes.pop(code, None)

    async def load_refresh_token(self, _refresh_token: str) -> RefreshToken | None:
        """Refresh tokens are not issued by this provider; always returns ``None``."""
        return None

    async def exchange_refresh_token(self, refresh_token: RefreshToken, scopes: list[str]) -> OAuthToken:
        """Refresh-token exchange is not supported by this provider."""
        raise NotImplementedError

    async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
        """Revoke an access token; refresh tokens are ignored since none are stored."""
        if isinstance(token, AccessToken):
            self.tokens.pop(token.token, None)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import Literal
|
|
5
|
+
from urllib.parse import urlparse, urlunparse
|
|
6
|
+
|
|
7
|
+
from pydantic import AnyHttpUrl, AnyUrl, Field, SecretStr
|
|
8
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OAuthSettings(BaseSettings):
    # Settings for the Belgie OAuth authorization server; every field can be supplied
    # through an environment variable prefixed with BELGIE_OAUTH_.
    model_config = SettingsConfigDict(env_prefix="BELGIE_OAUTH_")

    # Public base URL of the application; when unset, issuer_url is None.
    base_url: AnyHttpUrl | None = None
    route_prefix: str = "/oauth"
    login_url: str | None = None

    # Statically configured client; at least one redirect URI is required.
    client_id: str = "belgie_client"
    client_secret: SecretStr | None = None
    redirect_uris: list[AnyUrl] = Field(..., min_length=1)
    default_scope: str = "user"

    # Lifetimes, in seconds.
    authorization_code_ttl_seconds: int = 300
    access_token_ttl_seconds: int = 3600
    state_ttl_seconds: int = 600
    code_challenge_method: Literal["S256"] = "S256"

    @cached_property
    def issuer_url(self) -> AnyHttpUrl | None:
        """Issuer URL: ``base_url``'s path followed by ``/auth`` and the trimmed ``route_prefix``.

        Query and fragment of ``base_url`` are dropped. Returns ``None`` when ``base_url`` is unset.
        The value is computed once and cached on the instance.
        """
        if self.base_url is None:
            return None

        components = urlparse(str(self.base_url))
        segments = [components.path.rstrip("/"), "auth"]
        trimmed_prefix = self.route_prefix.strip("/")
        if trimmed_prefix:
            segments.append(trimmed_prefix)
        issuer = components._replace(path="/".join(segments), query="", fragment="")
        return AnyHttpUrl(urlunparse(issuer))
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import hashlib
|
|
5
|
+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
|
|
9
|
+
parsed_uri = urlparse(redirect_uri_base)
|
|
10
|
+
query_params = [(key, value) for key, values in parse_qs(parsed_uri.query).items() for value in values]
|
|
11
|
+
for key, value in params.items():
|
|
12
|
+
if value is not None:
|
|
13
|
+
query_params.append((key, value))
|
|
14
|
+
return urlunparse(parsed_uri._replace(query=urlencode(query_params)))
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def join_url(base_url: str, path: str) -> str:
    """Append ``path`` to the path component of ``base_url``, with exactly one ``/`` between them.

    Trailing slashes on the base path and leading slashes on ``path`` are collapsed. An empty
    ``path`` returns the base URL with its trailing slashes removed. Other URL components
    (scheme, host, query, fragment) are preserved.
    """
    components = urlparse(base_url)
    segments = [components.path.rstrip("/")]
    tail = path.lstrip("/")
    if tail:
        segments.append(tail)
    return urlunparse(components._replace(path="/".join(segments)))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_code_challenge(code_verifier: str) -> str:
|
|
26
|
+
digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
|
27
|
+
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("utf-8")
|
belgie_oauth-0.1.0/PKG-INFO
DELETED
belgie_oauth-0.1.0/README.md
DELETED
|
File without changes
|
|
File without changes
|