mcp-authflow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,81 @@
1
+ """MCP Auth Framework - Reusable OAuth Authorization Server components.
2
+
3
+ Provides building blocks for OAuth 2.0 authorization servers that protect
4
+ MCP (Model Context Protocol) tool access:
5
+
6
+ - **Token storage** — Pluggable backends (in-memory, PostgreSQL) for access
7
+ and refresh tokens.
8
+ - **CORS** — Origin validation helpers for OAuth/MCP endpoints.
9
+ - **Rate limiting** — Sliding-window rate limiter for token endpoints.
10
+ - **Responses** — Standardized OAuth 2.0 error responses (RFC 6749).
11
+ - **Validation** — Input sanitization for OAuth identifiers and scopes.
12
+ """
13
+
14
+ from mcp_auth_framework.cors import build_cors_headers, get_cors_origin, parse_allowed_origins
15
+ from mcp_auth_framework.rate_limiting import SlidingWindowRateLimiter
16
+ from mcp_auth_framework.responses import (
17
+ OAUTH_NO_CACHE_HEADERS,
18
+ backend_connection_error,
19
+ backend_invalid_response,
20
+ backend_oauth_error,
21
+ backend_timeout,
22
+ invalid_client,
23
+ invalid_grant,
24
+ invalid_request,
25
+ invalid_scope,
26
+ oauth_error,
27
+ rate_limit_exceeded,
28
+ server_error,
29
+ slow_down,
30
+ )
31
+ from mcp_auth_framework.storage import MemoryTokenStorage, TokenStorage
32
+ from mcp_auth_framework.validation import (
33
+ VALID_ID_PATTERN,
34
+ parse_json_field,
35
+ parse_scope_field,
36
+ validate_client_id,
37
+ )
38
+
39
# Version string of this package; it is also exported through ``__all__``.
__version__ = "0.1.0"

# Public API of the package, grouped by the submodule each name comes from.
# "PostgresTokenStorage" is not imported eagerly; it is resolved on first
# access by the module-level ``__getattr__``.
__all__ = [
    # Version
    "__version__",
    # CORS
    "build_cors_headers",
    "get_cors_origin",
    "parse_allowed_origins",
    # Rate limiting
    "SlidingWindowRateLimiter",
    # OAuth responses
    "OAUTH_NO_CACHE_HEADERS",
    "backend_connection_error",
    "backend_invalid_response",
    "backend_oauth_error",
    "backend_timeout",
    "invalid_client",
    "invalid_grant",
    "invalid_request",
    "invalid_scope",
    "oauth_error",
    "rate_limit_exceeded",
    "server_error",
    "slow_down",
    # Storage
    "MemoryTokenStorage",
    "PostgresTokenStorage",
    "TokenStorage",
    # Validation
    "VALID_ID_PATTERN",
    "parse_json_field",
    "parse_scope_field",
    "validate_client_id",
]
74
+
75
+
76
def __getattr__(name: str) -> type:
    """Resolve ``PostgresTokenStorage`` lazily on first attribute access.

    The PostgreSQL backend is imported only when it is requested, so the
    import of ``mcp_auth_framework.storage.postgres`` is deferred until then.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The ``PostgresTokenStorage`` class.

    Raises:
        AttributeError: If ``name`` is anything other than
            ``"PostgresTokenStorage"``.
    """
    if name != "PostgresTokenStorage":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from mcp_auth_framework.storage.postgres import PostgresTokenStorage  # noqa: PLC0415

    return PostgresTokenStorage
@@ -0,0 +1,63 @@
1
+ """CORS origin validation for MCP OAuth endpoints."""
2
+
3
+ import os
4
+
5
+ from starlette.requests import Request
6
+
7
+
8
def parse_allowed_origins(env_var: str = "ALLOWED_MCP_ORIGINS") -> list[str]:
    """Read a comma-separated origin allowlist from the environment.

    Args:
        env_var: Name of the environment variable to read.

    Returns:
        The origins in the order they appear, each stripped of surrounding
        whitespace, with empty entries dropped. An unset or empty variable
        yields an empty list.
    """
    configured = os.getenv(env_var, "")
    origins: list[str] = []
    for piece in configured.split(","):
        cleaned = piece.strip()
        # Skip blanks produced by stray or trailing commas (and by an empty value).
        if cleaned:
            origins.append(cleaned)
    return origins
21
+
22
+
23
def get_cors_origin(request: Request, allowed_origins: list[str]) -> str:
    """Pick the value to send as Access-Control-Allow-Origin.

    The request's ``origin`` header is echoed back only when it matches an
    entry of the allowlist exactly; any other origin (or a missing header)
    produces an empty string, which denies CORS access.

    Args:
        request: The incoming request.
        allowed_origins: List of allowed origin strings.

    Returns:
        The request origin if it is allowlisted, otherwise ``""``.
    """
    origin_header = request.headers.get("origin", "")
    return origin_header if origin_header in allowed_origins else ""
40
+
41
+
42
def build_cors_headers(request: Request, allowed_origins: list[str]) -> dict[str, str]:
    """Build standard CORS headers for OAuth discovery endpoints.

    Access-Control-Allow-Origin is included only when the request origin is
    in the allowlist and is omitted entirely for disallowed origins.
    ``Vary: Origin`` is always included because the header set depends on the
    requesting origin; without it a shared cache could replay a response
    produced for one origin to a different origin.

    Args:
        request: The incoming request.
        allowed_origins: List of allowed origin strings.

    Returns:
        Dict of CORS headers.
    """
    headers: dict[str, str] = {
        "Access-Control-Allow-Methods": "GET, OPTIONS",
        "Access-Control-Allow-Headers": "*",
        # The origin is reflected dynamically, so caches must key on it.
        "Vary": "Origin",
    }
    origin = get_cors_origin(request, allowed_origins)
    if origin:
        headers["Access-Control-Allow-Origin"] = origin
    return headers
File without changes
@@ -0,0 +1,59 @@
1
+ """Rate limiting utilities for OAuth endpoints."""
2
+
3
+ import time
4
+ from collections import defaultdict
5
+
6
+
7
class SlidingWindowRateLimiter:
    """Simple in-memory sliding-window rate limiter for OAuth endpoints.

    For each client the limiter keeps the timestamps (``time.time()``) of the
    requests it accepted during the last ``window_seconds`` seconds, and it
    rejects a request once ``requests_per_window`` of them are on record.

    State is held in a dict inside this object. No method awaits, so
    coroutines on a single event loop cannot interleave inside a check; no
    lock guards the state for multi-threaded use.
    """

    def __init__(self, requests_per_window: int, window_seconds: int):
        """Initialize the rate limiter.

        Args:
            requests_per_window: Maximum number of requests allowed in the window
            window_seconds: Size of the time window in seconds
        """
        self.requests_per_window = requests_per_window
        self.window_seconds = window_seconds
        # client_id -> timestamps of accepted requests still being tracked
        self.clients: dict[str, list[float]] = defaultdict(list)

    def is_allowed(self, client_id: str) -> bool:
        """Check if the client is allowed to make a request, and record it if so.

        Args:
            client_id: OAuth client identifier

        Returns:
            True if the request is allowed, False if rate limited
        """
        now = time.time()
        # Keep only the requests that are still inside the window. ``.get`` is
        # used so that merely looking a client up never creates an entry.
        recent = [
            req_time
            for req_time in self.clients.get(client_id, ())
            if now - req_time < self.window_seconds
        ]

        if len(recent) >= self.requests_per_window:
            if recent:
                self.clients[client_id] = recent
            else:
                # Nothing left to track (e.g. a limit of 0): do not keep an
                # empty list around for every client id ever seen.
                self.clients.pop(client_id, None)
            return False

        recent.append(now)
        self.clients[client_id] = recent
        return True

    def get_retry_after(self, client_id: str) -> int:
        """Get the number of seconds until the client can retry.

        This is a read-only query: it never adds ``client_id`` to the
        tracked clients.

        Args:
            client_id: OAuth client identifier

        Returns:
            0 when no requests are recorded for the client; otherwise the
            number of seconds until its oldest recorded request leaves the
            window (minimum 1)
        """
        timestamps = self.clients.get(client_id)
        if not timestamps:
            return 0
        oldest_request = min(timestamps)
        # +1 rounds up so the caller waits until the oldest entry has expired.
        retry_after = int(self.window_seconds - (time.time() - oldest_request)) + 1
        return max(retry_after, 1)
@@ -0,0 +1,140 @@
1
+ """
2
+ OAuth error response helpers for MCP Auth server.
3
+
4
+ This module provides standardized error response functions following
5
+ OAuth 2.0 error response format (RFC 6749).
6
+ """
7
+
8
+ from starlette.responses import JSONResponse
9
+
10
# Headers attached to every OAuth response built in this module: they tell
# clients and intermediaries not to store the response.
OAUTH_NO_CACHE_HEADERS: dict[str, str] = {
    "Cache-Control": "no-store",
}
12
+
13
+
14
def oauth_error(
    error: str,
    description: str,
    status_code: int = 400,
    extra_headers: dict[str, str] | None = None,
) -> JSONResponse:
    """Build a JSON error response in the OAuth 2.0 error format.

    Args:
        error: OAuth error code (e.g., "invalid_request", "server_error")
        description: Human-readable error description
        status_code: HTTP status code (default 400)
        extra_headers: Additional headers to include (e.g., Retry-After)

    Returns:
        JSONResponse whose body is ``{"error": ..., "error_description": ...}``
        and whose headers are the no-cache defaults plus ``extra_headers``
    """
    payload = {"error": error, "error_description": description}
    # Build a fresh dict so the shared defaults are never mutated; on a key
    # clash the caller-supplied header wins.
    merged_headers = {**OAUTH_NO_CACHE_HEADERS, **(extra_headers or {})}
    return JSONResponse(payload, status_code=status_code, headers=merged_headers)
40
+
41
+
42
def invalid_request(description: str) -> JSONResponse:
    """Build an ``invalid_request`` error response with HTTP status 400.

    Use for: missing required parameters, invalid parameter format.

    Args:
        description: Human-readable error description

    Returns:
        JSONResponse in OAuth error format
    """
    return oauth_error(error="invalid_request", description=description, status_code=400)
48
+
49
+
50
def invalid_client(description: str) -> JSONResponse:
    """Build an ``invalid_client`` error response with HTTP status 401.

    Use for: client authentication failed, unknown client.

    Args:
        description: Human-readable error description

    Returns:
        JSONResponse in OAuth error format
    """
    return oauth_error(error="invalid_client", description=description, status_code=401)
56
+
57
+
58
def slow_down(description: str, retry_after: int | None = None) -> JSONResponse:
    """Create a slow_down error response (always HTTP 400).

    Use for: rate limiting during device flow polling.

    Args:
        description: Error description
        retry_after: Optional retry-after value in seconds; when given
            (including 0) it is sent as the Retry-After header

    Returns:
        JSONResponse with error code "slow_down" and status 400
    """
    # Compare against None so an explicit 0 is not mistaken for "no value".
    extra_headers = {"Retry-After": str(retry_after)} if retry_after is not None else None
    return oauth_error("slow_down", description, 400, extra_headers)
69
+
70
+
71
def rate_limit_exceeded(description: str, retry_after: int | None = None) -> JSONResponse:
    """Create a rate limit exceeded error response (429).

    Use for: too many requests. The error code sent in the body is
    "slow_down"; the HTTP status is 429.

    Args:
        description: Error description
        retry_after: Optional retry-after value in seconds; when given
            (including 0) it is sent as the Retry-After header

    Returns:
        JSONResponse with error code "slow_down" and status 429
    """
    # Compare against None so an explicit 0 is not mistaken for "no value".
    extra_headers = {"Retry-After": str(retry_after)} if retry_after is not None else None
    return oauth_error("slow_down", description, 429, extra_headers)
78
+
79
+
80
def server_error(description: str, status_code: int = 500) -> JSONResponse:
    """Build a ``server_error`` response.

    Use for: internal server errors, backend failures.

    Args:
        description: Error description
        status_code: HTTP status code (500, 502, 504, etc.)

    Returns:
        JSONResponse in OAuth error format
    """
    return oauth_error(error="server_error", description=description, status_code=status_code)
90
+
91
+
92
def backend_timeout() -> JSONResponse:
    """Build the ``server_error`` response for a backend timeout (HTTP 504)."""
    return server_error(description="Backend timeout", status_code=504)
95
+
96
+
97
def backend_connection_error() -> JSONResponse:
    """Build the ``server_error`` response for a backend connection failure (HTTP 502)."""
    return server_error(description="Backend connection error", status_code=502)
100
+
101
+
102
def backend_invalid_response() -> JSONResponse:
    """Build the ``server_error`` response for an invalid backend reply (HTTP 502)."""
    return server_error(description="Invalid response from backend", status_code=502)
105
+
106
+
107
def invalid_grant(description: str) -> JSONResponse:
    """Build an ``invalid_grant`` error response with HTTP status 400.

    Use for: invalid or expired authorization codes, refresh tokens, or device codes.

    Args:
        description: Human-readable error description

    Returns:
        JSONResponse in OAuth error format
    """
    return oauth_error(error="invalid_grant", description=description, status_code=400)
113
+
114
+
115
def invalid_scope(description: str) -> JSONResponse:
    """Build an ``invalid_scope`` error response with HTTP status 400.

    Use for: requested scope is invalid, unknown, or exceeds what was granted.

    Args:
        description: Human-readable error description

    Returns:
        JSONResponse in OAuth error format
    """
    return oauth_error(error="invalid_scope", description=description, status_code=400)
121
+
122
+
123
def backend_oauth_error(error_dict: dict[str, str], status_code: int) -> JSONResponse:
    """Wrap an already-formatted OAuth error dict in a JSONResponse.

    Use for: forwarding transformed backend errors that are already in
    {"error": "...", "error_description": "..."} format. The dict is sent
    as the body unchanged.

    Args:
        error_dict: Dict with "error" and "error_description" keys
        status_code: HTTP status code from the backend response

    Returns:
        JSONResponse with Cache-Control: no-store header
    """
    # Hand the response its own copy so the shared defaults stay untouched.
    no_store_headers = dict(OAUTH_NO_CACHE_HEADERS)
    return JSONResponse(error_dict, status_code=status_code, headers=no_store_headers)
@@ -0,0 +1,30 @@
1
+ """Token storage abstractions and implementations.
2
+
3
+ ``PostgresTokenStorage`` is importable from this module but loaded lazily
4
+ so that ``asyncpg`` is only required when actually used. Install the
5
+ ``postgres`` extra to enable it: ``pip install mcp-auth-framework[postgres]``
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING
11
+
12
+ from mcp_auth_framework.storage.base import TokenStorage
13
+ from mcp_auth_framework.storage.memory import MemoryTokenStorage
14
+
15
+ if TYPE_CHECKING:
16
+ from mcp_auth_framework.storage.postgres import PostgresTokenStorage
17
+
18
# Public names of the storage package. "PostgresTokenStorage" is listed here
# but only imported on first access by the module-level ``__getattr__``.
__all__ = ["TokenStorage", "MemoryTokenStorage", "PostgresTokenStorage"]
23
+
24
+
25
def __getattr__(name: str) -> type:
    """Resolve ``PostgresTokenStorage`` lazily on first attribute access.

    Importing ``mcp_auth_framework.storage.postgres`` is deferred until the
    class is actually requested from this module.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The ``PostgresTokenStorage`` class.

    Raises:
        AttributeError: If ``name`` is anything other than
            ``"PostgresTokenStorage"``.
    """
    if name != "PostgresTokenStorage":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from mcp_auth_framework.storage.postgres import PostgresTokenStorage  # noqa: PLC0415

    return PostgresTokenStorage
@@ -0,0 +1,131 @@
1
+ """Abstract base class for token storage implementations."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+
7
+ class TokenStorage(ABC):
8
+ """Abstract interface for MCP token storage."""
9
+
10
+ @abstractmethod
11
+ async def initialize(self) -> None:
12
+ """Initialize the storage backend."""
13
+ ...
14
+
15
+ @abstractmethod
16
+ async def close(self) -> None:
17
+ """Close the storage backend and clean up resources."""
18
+ ...
19
+
20
+ @abstractmethod
21
+ async def store_token(
22
+ self,
23
+ token: str,
24
+ client_id: str,
25
+ scopes: list[str],
26
+ expires_at: int,
27
+ resource: str | None = None,
28
+ user_id: int | None = None,
29
+ ) -> None:
30
+ """Store an access token.
31
+
32
+ Args:
33
+ token: The access token string
34
+ client_id: OAuth client ID
35
+ scopes: List of granted scopes
36
+ expires_at: Unix timestamp when token expires
37
+ resource: Optional RFC 8707 resource binding
38
+ user_id: Optional ID of the user who authorized the token
39
+ """
40
+ ...
41
+
42
+ @abstractmethod
43
+ async def load_token(self, token: str) -> dict[str, Any] | None:
44
+ """Load an access token.
45
+
46
+ Args:
47
+ token: The access token string to look up
48
+
49
+ Returns:
50
+ Token data dict if found and not expired, None otherwise
51
+ """
52
+ ...
53
+
54
+ @abstractmethod
55
+ async def delete_token(self, token: str) -> None:
56
+ """Delete a token.
57
+
58
+ Args:
59
+ token: The access token string to delete
60
+ """
61
+ ...
62
+
63
+ @abstractmethod
64
+ async def store_refresh_token(
65
+ self,
66
+ refresh_token: str,
67
+ client_id: str,
68
+ scopes: list[str],
69
+ expires_at: int,
70
+ resource: str | None = None,
71
+ user_id: int | None = None,
72
+ ) -> None:
73
+ """Store a refresh token.
74
+
75
+ Args:
76
+ refresh_token: The refresh token string
77
+ client_id: OAuth client ID
78
+ scopes: List of granted scopes
79
+ expires_at: Unix timestamp when token expires
80
+ resource: Optional RFC 8707 resource binding
81
+ user_id: Optional ID of the user who authorized the token
82
+ """
83
+ ...
84
+
85
+ @abstractmethod
86
+ async def load_refresh_token(self, refresh_token: str) -> dict[str, Any] | None:
87
+ """Load a refresh token.
88
+
89
+ Args:
90
+ refresh_token: The refresh token string to look up
91
+
92
+ Returns:
93
+ Token data dict if found and not expired, None otherwise
94
+ """
95
+ ...
96
+
97
+ @abstractmethod
98
+ async def delete_refresh_token(self, refresh_token: str) -> None:
99
+ """Delete a refresh token.
100
+
101
+ Args:
102
+ refresh_token: The refresh token string to delete
103
+ """
104
+ ...
105
+
106
+ @abstractmethod
107
+ async def cleanup_expired_tokens(self) -> int:
108
+ """Remove all expired access tokens.
109
+
110
+ Returns:
111
+ Number of tokens removed
112
+ """
113
+ ...
114
+
115
+ @abstractmethod
116
+ async def cleanup_expired_refresh_tokens(self) -> int:
117
+ """Remove all expired refresh tokens.
118
+
119
+ Returns:
120
+ Number of tokens removed
121
+ """
122
+ ...
123
+
124
+ @abstractmethod
125
+ async def get_token_count(self) -> int:
126
+ """Get the total number of access tokens in storage.
127
+
128
+ Returns:
129
+ Number of tokens stored
130
+ """
131
+ ...