mcp-proxy-oauth-dcr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,71 @@
1
+ """Logging configuration for the MCP Proxy.
2
+
3
+ This module sets up structured logging using structlog for better
4
+ debugging and monitoring capabilities.
5
+ """
6
+
7
+ import logging
8
+ import sys
9
+ from typing import Any, Dict
10
+
11
+ import structlog
12
+
13
+
14
def configure_logging(log_level: str = "info") -> None:
    """Configure structured logging for the application.

    Sets up the standard library root logger and structlog so that both
    write to stderr, leaving stdout untouched.

    Args:
        log_level: Logging level name (debug, info, warning, error,
            critical), matched case-insensitively. A name that does not
            resolve to a numeric level on the ``logging`` module falls
            back to INFO.
    """
    # Convert string level to logging constant. Guard against names that
    # happen to match a non-numeric attribute of the logging module
    # (for example "basic_format"), which would otherwise be used as a level.
    resolved = getattr(logging, log_level.upper(), None)
    numeric_level = resolved if isinstance(resolved, int) else logging.INFO

    # Configure standard logging
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stderr,  # Use stderr to avoid interfering with stdio MCP
        level=numeric_level,
    )

    processors = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.StackInfoRenderer(),
        structlog.dev.set_exc_info,
        structlog.processors.TimeStamper(fmt="iso", utc=True),
    ]
    if sys.stderr.isatty():
        # The console renderer formats exception info itself.
        processors.append(structlog.dev.ConsoleRenderer())
    else:
        # JSONRenderer does not format tracebacks; without this step an
        # exception would be emitted only as a raw ``exc_info`` value.
        processors.append(structlog.processors.format_exc_info)
        processors.append(structlog.processors.JSONRenderer())

    # Configure structlog
    structlog.configure(
        processors=processors,
        wrapper_class=structlog.make_filtering_bound_logger(numeric_level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(file=sys.stderr),
        cache_logger_on_first_use=True,
    )
45
+
46
+
47
def get_logger(name: str) -> structlog.BoundLogger:
    """Return a structlog logger for the given module name.

    Args:
        name: Logger name (typically ``__name__``); forwarded as the sole
            positional argument to ``structlog.get_logger``.

    Returns:
        The logger object produced by ``structlog.get_logger``.
    """
    module_logger = structlog.get_logger(name)
    return module_logger
57
+
58
+
59
def add_context(**kwargs: Any) -> None:
    """Add context variables to all subsequent log messages.

    The given pairs are bound on top of whatever context is already bound
    through structlog's contextvars support: existing keys are kept, and a
    key passed here overwrites an earlier value for the same key. Call
    ``clear_context`` to drop everything.

    Args:
        **kwargs: Context key-value pairs
    """
    # Do not clear first: wiping the existing context here would turn
    # "add" into "replace" and silently drop earlier bindings.
    structlog.contextvars.bind_contextvars(**kwargs)
67
+
68
+
69
def clear_context() -> None:
    """Remove every context variable bound through structlog's contextvars."""
    contextvars_api = structlog.contextvars
    contextvars_api.clear_contextvars()
mcp_proxy/models.py ADDED
@@ -0,0 +1,259 @@
1
+ """Core data models and interfaces for the MCP Proxy.
2
+
3
+ This module defines the data structures used throughout the proxy system,
4
+ including JSON-RPC messages, HTTP requests/responses, authentication state,
5
+ and configuration models.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime
13
+ from enum import Enum
14
+ from typing import Any, Dict, List, Optional, Union
15
+
16
+ from pydantic import BaseModel, Field, HttpUrl, field_validator
17
+
18
+
19
class MessageStatus(str, Enum):
    """Processing state of a message.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    PENDING = "pending"      # processing has not finished yet
    COMPLETED = "completed"  # processing finished successfully
    FAILED = "failed"        # processing ended in failure
24
+
25
+
26
+ # ============================================================================
27
+ # JSON-RPC Models
28
+ # ============================================================================
29
+
30
+
31
class JsonRpcError(BaseModel):
    """Error object carried by a JSON-RPC 2.0 error response.

    ``code`` and ``message`` are required; ``data`` is optional and may
    hold any value.
    """

    code: int = Field(description="Error code")
    message: str = Field(description="Error message")
    data: Optional[Any] = Field(default=None, description="Additional error data")
36
+
37
+
38
class JsonRpcMessage(BaseModel):
    """JSON-RPC 2.0 message structure.

    One model covers requests, notifications and responses; the
    ``is_*`` helpers classify an instance by which members are populated.
    """

    jsonrpc: str = Field(default="2.0", description="JSON-RPC version")
    id: Optional[Union[str, int]] = Field(default=None, description="Request/response ID")
    method: Optional[str] = Field(default=None, description="Method name for requests")
    params: Optional[Any] = Field(default=None, description="Method parameters")
    result: Optional[Any] = Field(default=None, description="Result for responses")
    error: Optional[JsonRpcError] = Field(default=None, description="Error for error responses")

    @field_validator("jsonrpc")
    @classmethod
    def validate_jsonrpc_version(cls, v: str) -> str:
        """Ensure JSON-RPC version is 2.0.

        Raises:
            ValueError: If ``v`` is anything other than ``"2.0"``.
        """
        if v != "2.0":
            raise ValueError("JSON-RPC version must be 2.0")
        return v

    def is_request(self) -> bool:
        """Check if this is a request message (any message with a method)."""
        return self.method is not None

    def is_response(self) -> bool:
        """Check if this is a response message.

        A message is a response when it carries an error or a result. A
        result that was explicitly supplied as ``null`` also counts, as
        long as the message has no method.
        """
        if self.result is not None or self.error is not None:
            return True
        # "result": null is a legal success response; detect it through
        # the set of fields that were explicitly provided.
        return self.method is None and "result" in self.model_fields_set

    def is_notification(self) -> bool:
        """Check if this is a notification (request without ID)."""
        return self.method is not None and self.id is None

    def to_json(self) -> str:
        """Serialize to a JSON string, omitting members that are ``None``.

        The one exception is a response whose result was explicitly set
        to ``None``: its ``"result": null`` member is kept, because
        dropping it would leave a response with neither result nor error.
        """
        explicit_null_result = (
            self.result is None
            and self.error is None
            and self.method is None
            and "result" in self.model_fields_set
        )
        if not explicit_null_result:
            return self.model_dump_json(exclude_none=True)

        # method and error are known to be None here; id and params are
        # dropped only when unset, matching the exclude_none behaviour.
        omitted = {"method", "error"}
        omitted.update(
            name for name in ("id", "params") if getattr(self, name) is None
        )
        return self.model_dump_json(exclude=omitted)

    @classmethod
    def from_json(cls, json_str: str) -> JsonRpcMessage:
        """Parse and validate a message from a JSON string."""
        return cls.model_validate_json(json_str)
75
+
76
+
77
+ # ============================================================================
78
+ # HTTP MCP Models
79
+ # ============================================================================
80
+
81
+
82
class HttpMcpRequest(BaseModel):
    """Description of one HTTP request to an MCP server.

    The ``method`` field is validated against a fixed set of verbs and
    stored in upper case.
    """

    method: str = Field(description="HTTP method (POST, GET)")
    url: str = Field(description="Request URL")
    headers: Dict[str, str] = Field(default_factory=dict, description="HTTP headers")
    body: Optional[str] = Field(default=None, description="Request body")
    session_id: Optional[str] = Field(default=None, description="MCP session ID")

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: str) -> str:
        """Upper-case the HTTP method and reject unsupported verbs.

        Raises:
            ValueError: If the method is not GET, POST, PUT, DELETE or PATCH.
        """
        normalized = v.upper()
        if normalized in ("GET", "POST", "PUT", "DELETE", "PATCH"):
            return normalized
        # The message reports the value exactly as the caller supplied it.
        raise ValueError(f"Invalid HTTP method: {v}")
97
+
98
+
99
class HttpMcpResponse(BaseModel):
    """Description of one HTTP response from an MCP server."""

    status: int = Field(description="HTTP status code")
    headers: Dict[str, str] = Field(default_factory=dict, description="Response headers")
    body: str = Field(default="", description="Response body")
    content_type: str = Field(default="application/json", description="Content type")

    def is_success(self) -> bool:
        """Return True when the status code is in the 2xx range."""
        at_least_200 = self.status >= 200
        return at_least_200 and self.status < 300

    def is_auth_error(self) -> bool:
        """Return True when the status code is 401."""
        unauthorized = 401
        return self.status == unauthorized
113
+
114
+
115
+ # ============================================================================
116
+ # Authentication Models
117
+ # ============================================================================
118
+
119
+
120
class ClientCredentials(BaseModel):
    """OAuth client credentials.

    ``expires_at`` is optional; credentials without it never expire as
    far as ``is_expired`` is concerned.
    """

    client_id: str = Field(..., description="OAuth client ID")
    client_secret: str = Field(..., description="OAuth client secret")
    expires_at: Optional[datetime] = Field(None, description="Credential expiration time")

    def is_expired(self) -> bool:
        """Check if credentials are expired.

        Returns:
            False when no expiration time is set; otherwise True once the
            current time has reached ``expires_at``.
        """
        if self.expires_at is None:
            return False
        # Build "now" with the same tzinfo as expires_at: naive stays naive
        # (local time, as before), aware gets an aware "now". Comparing a
        # naive datetime with an aware one would raise TypeError.
        now = datetime.now(tz=self.expires_at.tzinfo)
        return now >= self.expires_at
131
+
132
+
133
class OAuthTokenResponse(BaseModel):
    """Fields of an OAuth token response.

    ``access_token``, ``token_type`` and ``expires_in`` are required;
    ``scope`` and ``refresh_token`` are optional.
    """

    access_token: str = Field(description="Access token")
    token_type: str = Field(description="Token type (usually Bearer)")
    expires_in: int = Field(description="Token lifetime in seconds")
    scope: Optional[str] = Field(default=None, description="Token scope")
    refresh_token: Optional[str] = Field(default=None, description="Refresh token")

    def get_expiration_time(self) -> datetime:
        """Return the current local time plus ``expires_in`` seconds.

        The result is a naive ``datetime`` computed at call time.
        """
        from datetime import timedelta

        lifetime = timedelta(seconds=self.expires_in)
        return datetime.now() + lifetime
145
+
146
+
147
class AuthenticationState(BaseModel):
    """Current authentication state.

    Holds the registered client credentials, the current tokens and
    bookkeeping for Dynamic Client Registration (DCR) attempts.
    """

    client_credentials: Optional[ClientCredentials] = Field(
        None, description="Client credentials"
    )
    access_token: Optional[str] = Field(None, description="Current access token")
    refresh_token: Optional[str] = Field(None, description="Refresh token for token renewal")
    token_expires_at: Optional[datetime] = Field(
        None, description="Token expiration time"
    )
    last_dcr_attempt: Optional[datetime] = Field(
        None, description="Last DCR attempt timestamp"
    )
    dcr_retry_count: int = Field(default=0, description="DCR retry counter")
    is_authenticated: bool = Field(default=False, description="Authentication status")

    def is_token_valid(self) -> bool:
        """Check if current token is valid.

        Returns:
            False when there is no access token or no expiration time;
            otherwise True while the current time is more than 60 seconds
            before ``token_expires_at``.
        """
        if not self.access_token or not self.token_expires_at:
            return False
        from datetime import timedelta

        # Match the tzinfo of the stored expiry so a timezone-aware value
        # does not raise TypeError when compared with "now"; a naive expiry
        # is still compared against naive local time.
        now = datetime.now(tz=self.token_expires_at.tzinfo)
        # Consider token invalid 60 seconds before actual expiration
        return now < (self.token_expires_at - timedelta(seconds=60))

    def needs_token_refresh(self) -> bool:
        """Check if token needs refresh (authenticated but token not valid)."""
        return self.is_authenticated and not self.is_token_valid()
174
+
175
+
176
+ # ============================================================================
177
+ # Session Models
178
+ # ============================================================================
179
+
180
+
181
class SessionState(BaseModel):
    """State tracked for one MCP session.

    Every field has a default, so an empty instance can be created.
    """

    session_id: Optional[str] = Field(
        default=None,
        description="Session ID",
    )
    is_initialized: bool = Field(
        default=False,
        description="Initialization status",
    )
    protocol_version: str = Field(
        default="1.0",
        description="MCP protocol version",
    )
    capabilities: List[str] = Field(
        default_factory=list,
        description="Session capabilities",
    )
    sse_stream_active: bool = Field(
        default=False,
        description="SSE stream status",
    )
188
+
189
+
190
+ # ============================================================================
191
+ # Message Correlation Models
192
+ # ============================================================================
193
+
194
+
195
@dataclass
class MessageCorrelation:
    """Links a stdio request ID to the HTTP request ID used for it.

    A new record starts in the PENDING state and is moved to COMPLETED or
    FAILED through the ``mark_*`` methods.
    """

    stdio_request_id: Union[str, int]
    http_request_id: str
    timestamp: datetime
    method: str
    status: MessageStatus = MessageStatus.PENDING

    def is_pending(self) -> bool:
        """Return True while the status is still PENDING."""
        current = self.status
        return current == MessageStatus.PENDING

    def mark_completed(self) -> None:
        """Set the status to COMPLETED."""
        completed = MessageStatus.COMPLETED
        self.status = completed

    def mark_failed(self) -> None:
        """Set the status to FAILED."""
        failed = MessageStatus.FAILED
        self.status = failed
215
+
216
+
217
+ # ============================================================================
218
+ # Configuration Models
219
+ # ============================================================================
220
+
221
+
222
class ProxyConfig(BaseModel):
    """Settings for the proxy.

    The two URLs are required; every other field has a default. Numeric
    fields carry range constraints, ``log_level`` is normalised to lower
    case, and ``scopes`` must not be empty.
    """

    mcp_server_url: HttpUrl = Field(description="HTTP MCP server URL")
    oauth_provider_url: HttpUrl = Field(description="OAuth provider URL")
    client_name: str = Field(default="mcp-proxy-client", description="OAuth client name")
    scopes: List[str] = Field(
        default_factory=lambda: ["mcp:read", "mcp:write"],
        description="OAuth scopes",
    )
    connection_timeout: int = Field(
        default=30,
        ge=1,
        le=300,
        description="Connection timeout in seconds",
    )
    retry_attempts: int = Field(
        default=3,
        ge=0,
        le=10,
        description="Number of retry attempts",
    )
    log_level: str = Field(default="info", description="Logging level")
    max_backoff_seconds: int = Field(
        default=60,
        ge=1,
        le=300,
        description="Maximum backoff time in seconds",
    )

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Lower-case the log level and reject unknown names.

        Raises:
            ValueError: If the level is not one of debug, info, warning,
                error or critical (case-insensitive).
        """
        valid_levels = ["debug", "info", "warning", "error", "critical"]
        lowered = v.lower()
        if lowered in valid_levels:
            return lowered
        raise ValueError(f"Invalid log level: {v}. Must be one of {valid_levels}")

    @field_validator("scopes")
    @classmethod
    def validate_scopes(cls, v: List[str]) -> List[str]:
        """Reject an empty scopes list; return a non-empty one unchanged.

        Raises:
            ValueError: If ``v`` is empty.
        """
        if v:
            return v
        raise ValueError("Scopes list cannot be empty")
mcp_proxy/protocols.py ADDED
@@ -0,0 +1,122 @@
1
+ """Protocol interfaces for the MCP Proxy components.
2
+
3
+ This module defines the abstract interfaces (protocols) that components
4
+ must implement, enabling loose coupling and testability.
5
+ """
6
+
7
+ from typing import Protocol, Optional, AsyncIterator
8
+ from .models import (
9
+ JsonRpcMessage,
10
+ HttpMcpRequest,
11
+ HttpMcpResponse,
12
+ ClientCredentials,
13
+ OAuthTokenResponse,
14
+ ProxyConfig,
15
+ AuthenticationState,
16
+ )
17
+
18
+
19
class StdioInterface(Protocol):
    """Structural interface for the stdio side of the proxy (Kiro).

    Implementations exchange JSON-RPC messages over stdin/stdout.
    """

    async def start(self) -> None:
        """Bring the stdio interface up."""
        ...

    async def stop(self) -> None:
        """Shut the stdio interface down."""
        ...

    async def send_message(self, message: JsonRpcMessage) -> None:
        """Write one JSON-RPC message to stdout."""
        ...

    async def receive_message(self) -> JsonRpcMessage:
        """Read one JSON-RPC message from stdin."""
        ...
37
+
38
+
39
class ProtocolTranslator(Protocol):
    """Structural interface for converting between stdio and HTTP forms."""

    async def translate_stdio_to_http(self, message: JsonRpcMessage) -> HttpMcpRequest:
        """Build the HTTP MCP request for a stdio JSON-RPC message."""
        ...

    async def translate_http_to_stdio(self, response: HttpMcpResponse) -> JsonRpcMessage:
        """Build the stdio JSON-RPC message for an HTTP MCP response."""
        ...

    def correlate_messages(self, request_id: str, response_id: str) -> None:
        """Record that the given request and response IDs belong together."""
        ...
57
+
58
+
59
class AuthenticationManager(Protocol):
    """Structural interface for OAuth Dynamic Client Registration (DCR)
    and access-token handling.
    """

    async def initialize(self) -> None:
        """Set the manager up, performing DCR if it is needed."""
        ...

    async def get_access_token(self) -> str:
        """Return a valid access token, refreshing it first if necessary."""
        ...

    async def refresh_token(self) -> str:
        """Refresh the access token and return it."""
        ...

    async def perform_dcr(self) -> ClientCredentials:
        """Run OAuth Dynamic Client Registration and return the credentials."""
        ...

    def is_token_valid(self) -> bool:
        """Report whether the current token is valid."""
        ...

    def get_state(self) -> AuthenticationState:
        """Return the current authentication state."""
        ...
85
+
86
+
87
class HttpClient(Protocol):
    """Structural interface for HTTP traffic to the backend MCP server."""

    async def send_request(self, request: HttpMcpRequest) -> HttpMcpResponse:
        """Send one HTTP request to the backend and return its response."""
        ...

    # NOTE(review): as declared, this is a coroutine whose awaited result is
    # an AsyncIterator. An implementation written as an async generator
    # (``async def`` containing ``yield``) has a different call shape --
    # confirm which form the implementers use.
    async def open_sse_stream(self, session_id: Optional[str] = None) -> AsyncIterator[str]:
        """Open a Server-Sent Events stream, optionally for a session ID."""
        ...

    async def close_connection(self) -> None:
        """Close the HTTP connection."""
        ...

    def set_auth_token(self, token: str) -> None:
        """Store the authentication token to use for requests."""
        ...
107
+
108
+
109
class ConfigurationManager(Protocol):
    """Structural interface for obtaining and checking proxy configuration."""

    async def load(self) -> ProxyConfig:
        """Load the configuration from environment and files."""
        ...

    def validate(self, config: ProxyConfig) -> bool:
        """Validate the parameters of the given configuration."""
        ...

    def get_defaults(self) -> ProxyConfig:
        """Return a configuration holding the default values."""
        ...