mcp-bastion-python 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_bastion/__init__.py +14 -0
- mcp_bastion/base.py +98 -0
- mcp_bastion/errors.py +42 -0
- mcp_bastion/middleware.py +229 -0
- mcp_bastion/pillars/__init__.py +11 -0
- mcp_bastion/pillars/pii_redaction.py +105 -0
- mcp_bastion/pillars/prompt_guard.py +97 -0
- mcp_bastion/pillars/rate_limit.py +106 -0
- mcp_bastion_python-1.0.1.dist-info/METADATA +506 -0
- mcp_bastion_python-1.0.1.dist-info/RECORD +12 -0
- mcp_bastion_python-1.0.1.dist-info/WHEEL +4 -0
- mcp_bastion_python-1.0.1.dist-info/licenses/NOTICE +5 -0
mcp_bastion/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""
MCP-Bastion: Security middleware for Model Context Protocol servers.
"""

from mcp_bastion.middleware import MCPBastionMiddleware
from mcp_bastion.base import Middleware, MiddlewareContext, compose_middleware

# Public API of the top-level package.
__all__ = [
    "MCPBastionMiddleware",
    "Middleware",
    "MiddlewareContext",
    "compose_middleware",
]

# Keep in sync with the distribution version (pyproject/wheel metadata).
# BUG FIX: the published wheel is 1.0.1 but this constant lagged at "1.0.0".
__version__ = "1.0.1"
|
mcp_bastion/base.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base middleware abstractions for MCP-Bastion.
|
|
3
|
+
|
|
4
|
+
Middleware base class, MiddlewareContext dataclass, compose_middleware.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Awaitable, Callable
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any, Generic, TypeVar
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class MiddlewareContext(Generic[T]):
|
|
16
|
+
"""Context for middleware chain: message, metadata, request_id, session_id."""
|
|
17
|
+
|
|
18
|
+
message: T
|
|
19
|
+
metadata: dict[str, Any] = field(default_factory=dict)
|
|
20
|
+
request_id: str | None = None
|
|
21
|
+
session_id: str | None = None
|
|
22
|
+
|
|
23
|
+
def copy(self, **kwargs: Any) -> "MiddlewareContext[T]":
|
|
24
|
+
"""Create a copy with updated fields."""
|
|
25
|
+
data = {
|
|
26
|
+
"message": self.message,
|
|
27
|
+
"metadata": dict(self.metadata),
|
|
28
|
+
"request_id": self.request_id,
|
|
29
|
+
"session_id": self.session_id,
|
|
30
|
+
}
|
|
31
|
+
data.update(kwargs)
|
|
32
|
+
return MiddlewareContext(**data)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
CallNext = Callable[[MiddlewareContext[T]], Awaitable[Any]]


class Middleware(Generic[T]):
    """Base middleware; subclasses override on_message / on_call_tool / on_read_resource.

    The default implementations form a delegation chain: the specific hooks
    fall back to ``on_message``, which simply forwards to ``call_next``.
    """

    async def __call__(self, context: MiddlewareContext[T], call_next: CallNext[T]) -> Any:
        # Calling the middleware object itself routes through the generic hook.
        return await self.on_message(context, call_next)

    async def on_message(self, context: MiddlewareContext[T], call_next: CallNext[T]) -> Any:
        """Generic hook for every message; default is a pure pass-through."""
        return await call_next(context)

    async def on_call_tool(self, context: MiddlewareContext[T], call_next: CallNext[T]) -> Any:
        """Hook for tool calls; defaults to the generic on_message path."""
        return await self.on_message(context, call_next)

    async def on_read_resource(self, context: MiddlewareContext[T], call_next: CallNext[T]) -> Any:
        """Hook for resource reads; defaults to the generic on_message path."""
        return await self.on_message(context, call_next)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def compose_middleware(
    *middleware: "Middleware[Any]",
) -> "Callable[[MiddlewareContext[Any], CallNext[Any]], Awaitable[Any]]":
    """Compose middleware into one handler; the first entry is outermost.

    Each middleware is invoked as ``await mw(ctx, call_next)``.

    BUG FIX: the previous implementation advanced a single shared
    ``nonlocal index`` across the whole dispatch, so if any middleware
    awaited its ``call_next`` more than once (e.g. a retry), the second
    invocation skipped the remaining middleware and went straight to the
    terminal handler. Each chain position now gets its own stable
    continuation, making ``call_next`` safely re-invocable.
    """
    if not middleware:
        async def passthrough(ctx: "MiddlewareContext[Any]", call_next: "CallNext[Any]") -> Any:
            # Nothing to wrap: forward straight to the terminal handler.
            return await call_next(ctx)
        return passthrough

    async def composed(
        context: "MiddlewareContext[Any]",
        call_next: "CallNext[Any]",
    ) -> Any:
        def _continuation(position: int):
            # Build the call_next handed to middleware[position - 1].
            # Re-entrant: `position` is fixed per closure, not shared state.
            async def _next(ctx: "MiddlewareContext[Any]") -> Any:
                if position >= len(middleware):
                    return await call_next(ctx)
                return await middleware[position](ctx, _continuation(position + 1))
            return _next

        return await _continuation(0)(context)

    return composed
|
mcp_bastion/errors.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP-compliant error types for security policy violations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MCPBastionError(Exception):
    """Root of the MCP-Bastion exception hierarchy.

    Carries a JSON-RPC style error code alongside the human-readable
    message so violations can be serialized onto the wire.
    """

    def __init__(self, message: str, code: int = -32000) -> None:
        super().__init__(message)
        self.message = message
        self.code = code

    def to_mcp_error(self) -> dict:
        """Serialize this error as an MCP/JSON-RPC error payload."""
        return {"code": self.code, "message": self.message}


class PromptInjectionError(MCPBastionError):
    """Signals that a prompt injection or jailbreak attempt was detected."""

    def __init__(self, message: str = "Request blocked: potential prompt injection detected") -> None:
        super().__init__(message, code=-32001)


class RateLimitExceededError(MCPBastionError):
    """Signals that a rate limit or iteration cap was hit."""

    def __init__(self, message: str = "Request blocked: rate limit exceeded") -> None:
        super().__init__(message, code=-32002)


class TokenBudgetExceededError(MCPBastionError):
    """Signals that the FinOps token budget has been used up."""

    def __init__(self, message: str = "Request blocked: token budget exhausted") -> None:
        super().__init__(message, code=-32003)
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP-Bastion security middleware.
|
|
3
|
+
|
|
4
|
+
Intercepts CallToolRequest and ReadResourceResult for prompt injection,
|
|
5
|
+
PII redaction, and rate limiting.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from mcp_bastion.base import CallNext, Middleware, MiddlewareContext
|
|
16
|
+
from mcp_bastion.errors import PromptInjectionError, RateLimitExceededError
|
|
17
|
+
from mcp_bastion.pillars.pii_redaction import PIIRedactor
|
|
18
|
+
from mcp_bastion.pillars.prompt_guard import PromptGuardEngine
|
|
19
|
+
from mcp_bastion.pillars.rate_limit import TokenBucketRateLimiter
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _extract_text_from_value(value: Any) -> str:
|
|
25
|
+
"""Flatten args to string for injection check."""
|
|
26
|
+
if value is None:
|
|
27
|
+
return ""
|
|
28
|
+
if isinstance(value, str):
|
|
29
|
+
return value
|
|
30
|
+
if isinstance(value, (int, float, bool)):
|
|
31
|
+
return str(value)
|
|
32
|
+
if isinstance(value, dict):
|
|
33
|
+
return " ".join(_extract_text_from_value(v) for v in value.values())
|
|
34
|
+
if isinstance(value, (list, tuple)):
|
|
35
|
+
return " ".join(_extract_text_from_value(v) for v in value)
|
|
36
|
+
return str(value)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _is_call_tool_request(message: Any) -> bool:
|
|
40
|
+
"""True if message is tools/call."""
|
|
41
|
+
if hasattr(message, "root"):
|
|
42
|
+
msg = message.root
|
|
43
|
+
else:
|
|
44
|
+
msg = message
|
|
45
|
+
if hasattr(msg, "method") and getattr(msg, "method", None) == "tools/call":
|
|
46
|
+
return True
|
|
47
|
+
if isinstance(msg, dict) and msg.get("method") == "tools/call":
|
|
48
|
+
return True
|
|
49
|
+
return False
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _is_read_resource_result(message: Any) -> bool:
|
|
53
|
+
"""True if message has resource contents."""
|
|
54
|
+
if message is None:
|
|
55
|
+
return False
|
|
56
|
+
if hasattr(message, "contents"):
|
|
57
|
+
return True
|
|
58
|
+
if hasattr(message, "root"):
|
|
59
|
+
msg = message.root
|
|
60
|
+
else:
|
|
61
|
+
msg = message
|
|
62
|
+
if isinstance(msg, dict):
|
|
63
|
+
result = msg.get("result") or msg.get("params") or msg
|
|
64
|
+
if isinstance(result, dict) and ("contents" in result or "content" in result):
|
|
65
|
+
return True
|
|
66
|
+
if hasattr(result, "contents"):
|
|
67
|
+
return True
|
|
68
|
+
return False
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_params(message: Any) -> dict | None:
|
|
72
|
+
"""Extract params from message."""
|
|
73
|
+
if hasattr(message, "root"):
|
|
74
|
+
msg = message.root
|
|
75
|
+
else:
|
|
76
|
+
msg = message
|
|
77
|
+
if isinstance(msg, dict):
|
|
78
|
+
return msg.get("params") or msg.get("result")
|
|
79
|
+
if hasattr(msg, "params"):
|
|
80
|
+
return getattr(msg.params, "__dict__", None) or {}
|
|
81
|
+
return None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _get_request_id(message: Any) -> str | None:
|
|
85
|
+
"""Extract request ID from message."""
|
|
86
|
+
if hasattr(message, "root"):
|
|
87
|
+
msg = message.root
|
|
88
|
+
else:
|
|
89
|
+
msg = message
|
|
90
|
+
if isinstance(msg, dict):
|
|
91
|
+
return str(msg.get("id", "")) or None
|
|
92
|
+
if hasattr(msg, "id"):
|
|
93
|
+
return str(getattr(msg, "id", "")) or None
|
|
94
|
+
return None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _get_content_from_result(result: Any) -> list[dict[str, Any]] | None:
|
|
98
|
+
"""Extract content list from result for PII redaction."""
|
|
99
|
+
if result is None:
|
|
100
|
+
return None
|
|
101
|
+
payload = result
|
|
102
|
+
if isinstance(result, dict) and "result" in result:
|
|
103
|
+
payload = result["result"]
|
|
104
|
+
if hasattr(payload, "contents"):
|
|
105
|
+
items = payload.contents
|
|
106
|
+
elif isinstance(payload, dict) and "contents" in payload:
|
|
107
|
+
items = payload["contents"]
|
|
108
|
+
elif isinstance(payload, dict) and "content" in payload:
|
|
109
|
+
items = payload["content"]
|
|
110
|
+
else:
|
|
111
|
+
return None
|
|
112
|
+
if not isinstance(items, list):
|
|
113
|
+
return None
|
|
114
|
+
out = []
|
|
115
|
+
for item in items:
|
|
116
|
+
if hasattr(item, "model_dump"):
|
|
117
|
+
out.append(item.model_dump())
|
|
118
|
+
elif isinstance(item, dict):
|
|
119
|
+
out.append(dict(item))
|
|
120
|
+
else:
|
|
121
|
+
out.append({"type": "text", "text": str(item)})
|
|
122
|
+
return out
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _set_content_in_result(result: Any, content: list[dict[str, Any]]) -> None:
|
|
126
|
+
"""Replace content in result after redaction."""
|
|
127
|
+
payload = result
|
|
128
|
+
if isinstance(result, dict) and "result" in result:
|
|
129
|
+
payload = result["result"]
|
|
130
|
+
if hasattr(payload, "contents"):
|
|
131
|
+
payload.contents = content
|
|
132
|
+
elif isinstance(payload, dict):
|
|
133
|
+
if "contents" in payload:
|
|
134
|
+
payload["contents"] = content
|
|
135
|
+
if "content" in payload:
|
|
136
|
+
payload["content"] = content
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class MCPBastionMiddleware(Middleware[Any]):
    """Drop-in MCP security middleware combining the three pillars.

    - Prompt-injection screening of tool-call arguments (PromptGuardEngine)
    - Per-session rate limiting (TokenBucketRateLimiter)
    - PII redaction of outbound content (PIIRedactor)

    BUG FIX: the resource-read path in ``__call__`` previously ran PII
    redaction even when ``enable_pii_redaction`` was False; the flag is
    now honored there, matching the tool-call path.
    """

    def __init__(
        self,
        prompt_guard: PromptGuardEngine | None = None,
        pii_redactor: PIIRedactor | None = None,
        rate_limiter: TokenBucketRateLimiter | None = None,
        enable_prompt_guard: bool = True,
        enable_pii_redaction: bool = True,
        enable_rate_limit: bool = True,
    ) -> None:
        """Create the middleware; custom pillar instances may be injected."""
        self.prompt_guard = prompt_guard or PromptGuardEngine()
        self.pii_redactor = pii_redactor or PIIRedactor()
        self.rate_limiter = rate_limiter or TokenBucketRateLimiter()
        self.enable_prompt_guard = enable_prompt_guard
        self.enable_pii_redaction = enable_pii_redaction
        self.enable_rate_limit = enable_rate_limit

    async def __call__(
        self,
        context: MiddlewareContext[Any],
        call_next: CallNext[Any],
    ) -> Any:
        """Run security checks around call_next and record elapsed time."""
        start = time.perf_counter()
        msg = context.message

        try:
            if _is_call_tool_request(msg):
                return await self._handle_call_tool(context, call_next)
            result = await call_next(context)
            # Honor the PII flag here as well (previously redaction ran
            # unconditionally on resource-read results).
            if (
                self.enable_pii_redaction
                and result is not None
                and _is_read_resource_result(result)
            ):
                result = self._redact_result_content(result)
            return result
        finally:
            # Always record timing, even when a guard raised.
            elapsed_ms = (time.perf_counter() - start) * 1000
            context.metadata["elapsed_ms"] = round(elapsed_ms, 2)
            logger.debug("request done elapsed_ms=%.2f", elapsed_ms)

    async def _handle_call_tool(
        self,
        context: MiddlewareContext[Any],
        call_next: CallNext[Any],
    ) -> Any:
        """Apply rate limiting and prompt guarding, run the tool, redact output.

        Raises:
            RateLimitExceededError: iteration cap, timeout window, or token
                budget violated.
            PromptInjectionError: tool arguments scored as malicious.
        """
        msg = context.message
        params = _get_params(msg)
        request_id = _get_request_id(msg) or context.request_id
        session_id = context.session_id

        if self.enable_rate_limit:
            allowed, err = self.rate_limiter.check_iteration(
                request_id=request_id,
                session_id=session_id,
            )
            if not allowed:
                logger.warning("rate_limit_blocked request_id=%s session_id=%s reason=%s", request_id, session_id, err)
                # NOTE(review): check_iteration reports only a message, so a
                # token-budget violation is also surfaced as
                # RateLimitExceededError rather than TokenBudgetExceededError.
                raise RateLimitExceededError(err or "Rate limit exceeded")

        if self.enable_prompt_guard and params:
            arguments = params.get("arguments") or params
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError:
                    # Not JSON: scan the raw string as-is.
                    arguments = {"raw": arguments}
            text = _extract_text_from_value(arguments)
            if text and self.prompt_guard.is_malicious(text):
                logger.warning("prompt_injection_blocked request_id=%s", request_id)
                raise PromptInjectionError()

        # Record the iteration only after the request passed the guards.
        self.rate_limiter.consume_iteration(
            request_id=request_id,
            session_id=session_id,
        )

        result = await call_next(context)

        if self.enable_pii_redaction and result is not None:
            result = self._redact_result_content(result)

        return result

    def _redact_result_content(self, result: Any) -> Any:
        """Redact PII from the result's content items, in place when possible."""
        content = _get_content_from_result(result)
        if not content:
            return result
        redacted = self.pii_redactor.redact_content_items(content)
        _set_content_in_result(result, redacted)
        return result
|
|
229
|
+
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Security pillars for MCP-Bastion."""
|
|
2
|
+
|
|
3
|
+
from mcp_bastion.pillars.prompt_guard import PromptGuardEngine
|
|
4
|
+
from mcp_bastion.pillars.pii_redaction import PIIRedactor
|
|
5
|
+
from mcp_bastion.pillars.rate_limit import TokenBucketRateLimiter
|
|
6
|
+
|
|
7
|
+
# Public API of the pillars subpackage.
__all__ = [
    "PromptGuardEngine",
    "PIIRedactor",
    "TokenBucketRateLimiter",
]
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PII redaction via Microsoft Presidio.
|
|
3
|
+
|
|
4
|
+
presidio-analyzer, presidio-anonymizer, spaCy. Sanitizes TextContent.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)


class PIIRedactor:
    """PII scrubber built on Microsoft Presidio with a spaCy NLP engine.

    Presidio is imported lazily on first use; when loading or analysis
    fails (e.g. Presidio is not installed) the original text is returned
    unchanged, so redaction degrades gracefully to a no-op.
    """

    def __init__(
        self,
        entities: list[str] | None = None,
        language: str = "en",
    ) -> None:
        self.entities = entities or [
            "PERSON",
            "EMAIL_ADDRESS",
            "PHONE_NUMBER",
            "CREDIT_CARD",
            "US_SSN",
            "US_PASSPORT",
            "MEDICAL_LICENSE",
            "IBAN_CODE",
        ]
        self.language = language
        # Presidio engines, created on demand by _ensure_loaded().
        self._analyzer = None
        self._anonymizer = None

    def _ensure_loaded(self) -> None:
        """Instantiate the Presidio engines once; re-raises after logging on failure."""
        if self._analyzer is not None:
            return
        try:
            from presidio_analyzer import AnalyzerEngine
            from presidio_analyzer.nlp_engine import NlpEngineProvider
            from presidio_anonymizer import AnonymizerEngine

            nlp_engine = NlpEngineProvider(
                nlp_configuration={
                    "nlp_engine_name": "spacy",
                    "models": [{"lang_code": self.language, "model_name": "en_core_web_sm"}],
                }
            ).create_engine()
            self._analyzer = AnalyzerEngine(nlp_engine=nlp_engine, supported_languages=[self.language])
            self._anonymizer = AnonymizerEngine()
        except Exception as e:
            logger.warning("Presidio load failed: %s. PII redaction disabled.", e)
            raise

    def redact_text(self, text: str) -> str:
        """
        Analyze and anonymize PII in the given text.

        Returns the sanitized text with detected entities replaced by
        placeholders, or the original text when analysis fails or finds
        nothing.
        """
        if not text or not isinstance(text, str):
            return text
        try:
            self._ensure_loaded()
            findings = self._analyzer.analyze(
                text=text,
                language=self.language,
                entities=self.entities,
            )
            if not findings:
                return text
            logger.debug("redacted %d entities", len(findings))
            return self._anonymizer.anonymize(text=text, analyzer_results=findings).text
        except Exception as e:
            logger.warning("PII redaction failed: %s. Returning original text.", e)
            return text

    def redact_content_items(self, content: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Redact PII from MCP content items.

        Only ``{"type": "text", ...}`` dict items are processed; every other
        item is passed through unchanged.
        """
        if not content:
            return content
        sanitized: list[dict[str, Any]] = []
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text" and "text" in item:
                sanitized.append({**item, "text": self.redact_text(str(item["text"]))})
            else:
                sanitized.append(item)
        return sanitized
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Prompt injection detection via Llama Prompt Guard 2.
|
|
3
|
+
|
|
4
|
+
Uses meta-llama/Llama-Prompt-Guard-2-86M, temperature-adjusted softmax.
|
|
5
|
+
Blocks when malicious probability exceeds threshold (default 0.85).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)

MALICIOUS_THRESHOLD = 0.85
TEMPERATURE = 0.1
MODEL_ID = "meta-llama/Llama-Prompt-Guard-2-86M"


class PromptGuardEngine:
    """Local prompt-injection classifier backed by Llama Prompt Guard 2 (86M).

    The Hugging Face model and tokenizer are loaded lazily on first use and
    run on GPU when available, otherwise CPU. Scores come from a
    temperature-adjusted softmax; anything at or above ``threshold`` is
    treated as malicious. Inference failures fail open (allow).
    """

    def __init__(
        self,
        threshold: float = MALICIOUS_THRESHOLD,
        temperature: float = TEMPERATURE,
        model_id: str = MODEL_ID,
        device: str | None = None,
    ) -> None:
        self.threshold = threshold
        self.temperature = temperature
        self.model_id = model_id
        self._device = device
        # Populated lazily by _ensure_loaded().
        self._model = None
        self._tokenizer = None

    def _ensure_loaded(self) -> None:
        """Load model and tokenizer once; re-raises after logging on failure."""
        if self._model is not None:
            return
        try:
            import torch
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self._model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
            if self._device is None:
                self._device = "cuda" if torch.cuda.is_available() else "cpu"
            self._model = self._model.to(self._device)
            self._model.eval()
            logger.info("PromptGuard loaded model=%s device=%s", self.model_id, self._device)
        except Exception as e:
            logger.warning("PromptGuard model load failed: %s. Injection check disabled.", e)
            raise

    def _temperature_adjusted_softmax(self, logits: "torch.Tensor") -> "torch.Tensor":
        """Softmax over logits divided by the configured temperature."""
        import torch

        return torch.softmax(logits / self.temperature, dim=-1)

    def score(self, text: str) -> float:
        """Return the malicious-class probability in [0, 1] for *text*."""
        if not text or not text.strip():
            # Blank input: never load the model just to score nothing.
            return 0.0

        self._ensure_loaded()
        import torch

        encoded = self._tokenizer(
            text[:512],
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self._device)

        with torch.no_grad():
            probs = self._temperature_adjusted_softmax(self._model(**encoded).logits)
        scores = probs.cpu().numpy()

        mapping = self._model.config.label2id
        malicious_id = mapping.get("MALICIOUS", mapping.get("malicious", 1))
        return float(scores[0][malicious_id])

    def is_malicious(self, text: str) -> bool:
        """Return True when score(text) meets the block threshold; fail-open on error."""
        try:
            return self.score(text) >= self.threshold
        except Exception as e:
            logger.warning("PromptGuard inference failed: %s. Allowing request.", e)
            return False
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate limiting: token bucket per session.
|
|
3
|
+
|
|
4
|
+
Max 15 iterations, 60s timeout, optional token budget (50k default).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
DEFAULT_MAX_ITERATIONS = 15
DEFAULT_TIMEOUT_SECONDS = 60
DEFAULT_TOKEN_BUDGET = 50_000


@dataclass
class SessionState:
    """Mutable per-session counters tracked by the rate limiter."""

    # Tool-call iterations recorded so far.
    iterations: int = 0
    # Monotonic timestamp when this session window opened.
    started_at: float = field(default_factory=time.monotonic)
    # Tokens charged against the session's budget.
    tokens_used: int = 0


class TokenBucketRateLimiter:
    """Per-session limiter: iteration cap, time window, and token budget.

    BUG FIX: the timeout violation message previously hard-coded
    "(60s limit)" regardless of the configured ``timeout_seconds``; it now
    reports the actual configured value. Note that ``_cleanup_expired``
    drops (and thus restarts) expired sessions before re-checking, so the
    timeout behaves as a rolling window reset in practice.
    """

    def __init__(
        self,
        max_iterations: int = DEFAULT_MAX_ITERATIONS,
        timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS,
        token_budget: int = DEFAULT_TOKEN_BUDGET,
    ) -> None:
        self.max_iterations = max_iterations
        self.timeout_seconds = timeout_seconds
        self.token_budget = token_budget
        # defaultdict creates a fresh SessionState on first touch of a key.
        self._sessions: dict[str, SessionState] = defaultdict(SessionState)

    def _get_session_id(self, request_id: str | None, session_id: str | None) -> str:
        """Resolve the session key; session ID wins over request ID."""
        return session_id or request_id or "default"

    def _cleanup_expired(self, session_key: str) -> None:
        """Drop the session when its window has exceeded the timeout."""
        state = self._sessions.get(session_key)
        if state is None:
            return
        if time.monotonic() - state.started_at > self.timeout_seconds:
            del self._sessions[session_key]

    def check_iteration(
        self,
        request_id: str | None = None,
        session_id: str | None = None,
    ) -> tuple[bool, str | None]:
        """
        Check whether another iteration is allowed for this session.

        Returns:
            ``(allowed, error_message)`` — ``error_message`` is None when
            allowed, otherwise it describes which limit was violated.
        """
        key = self._get_session_id(request_id, session_id)
        self._cleanup_expired(key)

        state = self._sessions[key]
        elapsed = time.monotonic() - state.started_at

        if elapsed > self.timeout_seconds:
            # Defensive: _cleanup_expired above normally removes expired
            # sessions first, so this branch is rarely reachable.
            del self._sessions[key]
            return False, f"Session timeout exceeded ({self.timeout_seconds}s limit)"

        if state.iterations >= self.max_iterations:
            return False, f"Maximum iterations exceeded ({self.max_iterations} limit)"

        if state.tokens_used >= self.token_budget:
            return False, f"Token budget exhausted ({self.token_budget} limit)"

        return True, None

    def consume_iteration(
        self,
        request_id: str | None = None,
        session_id: str | None = None,
        tokens: int = 0,
    ) -> None:
        """Record one iteration and optional token consumption."""
        key = self._get_session_id(request_id, session_id)
        state = self._sessions[key]
        state.iterations += 1
        state.tokens_used += tokens

    def reset_session(
        self,
        request_id: str | None = None,
        session_id: str | None = None,
    ) -> None:
        """Forget all state for the session (e.g., when a new request begins)."""
        key = self._get_session_id(request_id, session_id)
        self._sessions.pop(key, None)
|
|
@@ -0,0 +1,506 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mcp-bastion-python
|
|
3
|
+
Version: 1.0.1
|
|
4
|
+
Summary: Security middleware for MCP servers protecting LLM agents from prompt injection, resource exhaustion, and PII leakage
|
|
5
|
+
Project-URL: Homepage, https://github.com/mcp-bastion/mcp-bastion
|
|
6
|
+
Project-URL: Repository, https://github.com/mcp-bastion/mcp-bastion
|
|
7
|
+
Project-URL: Documentation, https://github.com/mcp-bastion/mcp-bastion#readme
|
|
8
|
+
Author: Viquar Khan
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: NOTICE
|
|
11
|
+
Keywords: llm,mcp,middleware,pii,prompt-injection,security
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Security
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
|
+
Requires-Dist: mcp>=1.0.0
|
|
22
|
+
Requires-Dist: presidio-analyzer>=2.2.0
|
|
23
|
+
Requires-Dist: presidio-anonymizer>=2.2.0
|
|
24
|
+
Requires-Dist: spacy>=3.5.0
|
|
25
|
+
Requires-Dist: torch>=2.0.0
|
|
26
|
+
Requires-Dist: transformers>=4.30.0
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: pytest-asyncio>=0.21.0; extra == 'dev'
|
|
29
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
|
|
30
|
+
Requires-Dist: pytest>=7.0.0; extra == 'dev'
|
|
31
|
+
Description-Content-Type: text/markdown
|
|
32
|
+
|
|
33
|
+
# MCP-Bastion
|
|
34
|
+
|
|
35
|
+
**Enterprise-Grade Security Middleware for the Model Context Protocol**
|
|
36
|
+
|
|
37
|
+
**Author:** Viquar Khan
|
|
38
|
+
|
|
39
|
+
> Releases are published automatically to npm and PyPI via GitHub Actions when tags are pushed.
|
|
40
|
+
|
|
41
|
+
The Model Context Protocol (MCP) has rapidly become the universally accepted standard for connecting AI agents to enterprise databases and APIs. However, this connectivity introduces a massive new attack surface: unpredictable, non-deterministic agentic behavior.
|
|
42
|
+
|
|
43
|
+
MCP-Bastion is a lightweight, drop-in security middleware designed to wrap around any existing Python or TypeScript MCP server. Instead of relying on passive logging, human-in-the-loop approvals, or third-party APIs, MCP-Bastion provides an active, 100% local defense layer. It intercepts standard JSON-RPC traffic to stop threats before they cross the enterprise boundary.
|
|
44
|
+
|
|
45
|
+
With under 5ms of proxy overhead, MCP-Bastion provides:
|
|
46
|
+
|
|
47
|
+
- **Prompt Injection Defense:** Meta PromptGuard runs locally to block adversarial payloads and jailbreaks.
|
|
48
|
+
- **PII Redaction:** Uses Microsoft Presidio to detect and mask PII before it reaches the LLM context.
|
|
49
|
+
- **Infinite Loop Protection:** Token buckets and cycle detection stop runaway agents from burning API budget.
|
|
50
|
+
|
|
51
|
+
Secure your MCP server without changing business logic.
|
|
52
|
+
|
|
53
|
+
---
|
|
54
|
+
|
|
55
|
+
## Core Features
|
|
56
|
+
|
|
57
|
+
**Zero-Click Prompt Injection Prevention**
|
|
58
|
+
|
|
59
|
+
Integrates Meta's PromptGuard model locally to detect and block malicious payloads, jailbreaks, and adversarial tokenization before they reach your external tools.
|
|
60
|
+
|
|
61
|
+
**PII Redaction**
|
|
62
|
+
|
|
63
|
+
Microsoft Presidio scans outbound tool results and masks PII (redaction, substitution, generalization).
|
|
64
|
+
|
|
65
|
+
**Infinite Loop and Denial of Wallet Protection**
|
|
66
|
+
|
|
67
|
+
Implements stateful cycle detection and configurable FinOps token-bucket algorithms to automatically terminate runaway agents and prevent massive API bill overruns.
|
|
68
|
+
|
|
69
|
+
**100% Local Execution (Data Privacy)**
|
|
70
|
+
|
|
71
|
+
All security classification and data redaction happen entirely within the local memory space of your server. Sensitive data never leaves your enterprise network for third-party safety evaluations.
|
|
72
|
+
|
|
73
|
+
**Low Latency**
|
|
74
|
+
|
|
75
|
+
Drop-in middleware, under 5ms overhead.
|
|
76
|
+
|
|
77
|
+
**Framework Integration**
|
|
78
|
+
|
|
79
|
+
Hooks into MCP SDKs (TypeScript, Python) and FastMCP via standard middleware. No business logic changes.
|
|
80
|
+
|
|
81
|
+
---
|
|
82
|
+
|
|
83
|
+
## Why MCP-Bastion (Competitive Comparison)
|
|
84
|
+
|
|
85
|
+
Early security packages (mcp-guardian, mcp-shield) focus on logging or static scanning. MCP-Bastion adds an active defense layer.
|
|
86
|
+
|
|
87
|
+
### 1. Active Defense vs. Passive Logging
|
|
88
|
+
|
|
89
|
+
| The Competition | MCP-Bastion |
|
|
90
|
+
|-----------------|-------------|
|
|
91
|
+
| Tools like mcp-guardian focus on tracing, logging, human-in-the-loop approvals. | Automated interception. MCP-Bastion scrubs PII before it leaves the server. |
|
|
92
|
+
|
|
93
|
+
### 2. Local Inference vs. Third-Party APIs
|
|
94
|
+
|
|
95
|
+
| The Competition | MCP-Bastion |
|
|
96
|
+
|-----------------|-------------|
|
|
97
|
+
| Many guardrail proxies send prompts to external APIs (e.g. OpenAI moderation) to check for malice. | PromptGuard-86M and Presidio run locally. Data stays on your network. |
|
|
98
|
+
|
|
99
|
+
### 3. Stateful Denial of Wallet Protection
|
|
100
|
+
|
|
101
|
+
| The Competition | MCP-Bastion |
|
|
102
|
+
|-----------------|-------------|
|
|
103
|
+
| Most tools focus on static vulns or basic rate limits. | Tracks tool call history per session. Stops runaway loops before they burn API budget. |
|
|
104
|
+
|
|
105
|
+
### 4. Drop-in Middleware vs. Standalone Gateway
|
|
106
|
+
|
|
107
|
+
| The Competition | MCP-Bastion |
|
|
108
|
+
|-----------------|-------------|
|
|
109
|
+
| Some solutions need standalone proxy servers. | Library hooks into `server.setRequestHandler` (TS) or middleware (Python). No extra infra. |
|
|
110
|
+
|
|
111
|
+
---
|
|
112
|
+
|
|
113
|
+
## Structure
|
|
114
|
+
|
|
115
|
+
| Path | Description |
|
|
116
|
+
|------|-------------|
|
|
117
|
+
| `src/mcp_bastion/` | Python package: PromptGuard, Presidio, rate limiting |
|
|
118
|
+
| `packages/core/` | TypeScript package: rate limiting; ML via Python sidecar |
|
|
119
|
+
| `examples/` | Python examples: basic middleware, full demo ([examples/README.md](examples/README.md)) |
|
|
120
|
+
| `scripts/validate_checklist.py` | Enterprise validation runner |
|
|
121
|
+
| `VALIDATION_CHECKLIST.md` | Validation guide and MCP Inspector steps |
|
|
122
|
+
| `SETUP_GUIDE.md` | Setup, config, and validation |
|
|
123
|
+
|
|
124
|
+
## Installation
|
|
125
|
+
|
|
126
|
+
**Python**
|
|
127
|
+
|
|
128
|
+
```bash
|
|
129
|
+
uv add mcp-bastion-python
|
|
130
|
+
# or
|
|
131
|
+
pip install mcp-bastion-python
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
**TypeScript**
|
|
135
|
+
|
|
136
|
+
```bash
|
|
137
|
+
npm install @mcp-bastion/core
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
## Developer Guide
|
|
141
|
+
|
|
142
|
+
Integration examples for Python and TypeScript.
|
|
143
|
+
|
|
144
|
+
---
|
|
145
|
+
|
|
146
|
+
### Quick Start (Python)
|
|
147
|
+
|
|
148
|
+
Add MCP-Bastion to an existing MCP server in three steps:
|
|
149
|
+
|
|
150
|
+
```python
|
|
151
|
+
from mcp_bastion import MCPBastionMiddleware, compose_middleware
|
|
152
|
+
|
|
153
|
+
# 1. Create the security middleware
|
|
154
|
+
bastion = MCPBastionMiddleware(
|
|
155
|
+
enable_prompt_guard=True,
|
|
156
|
+
enable_pii_redaction=True,
|
|
157
|
+
enable_rate_limit=True,
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
# 2. Compose with your middleware chain (Bastion runs first)
|
|
161
|
+
middleware = compose_middleware(bastion)
|
|
162
|
+
|
|
163
|
+
# 3. Pass the composed middleware to your MCP server
|
|
164
|
+
# (integration depends on your server framework)
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
**Examples:**
|
|
168
|
+
|
|
169
|
+
| Example | Description |
|
|
170
|
+
|---------|-------------|
|
|
171
|
+
| `examples/python_server_example.py` | Basic middleware chain |
|
|
172
|
+
| `examples/full_demo.py` | All features: add, PII, rate limit, prompt injection |
|
|
173
|
+
|
|
174
|
+
```bash
|
|
175
|
+
# Windows: $env:PYTHONPATH="src"; python examples/full_demo.py
|
|
176
|
+
# Linux/Mac: PYTHONPATH=src python examples/full_demo.py
|
|
177
|
+
```
|
|
178
|
+
|
|
179
|
+
**Enterprise validation:**
|
|
180
|
+
|
|
181
|
+
```bash
|
|
182
|
+
PYTHONPATH=src python scripts/validate_checklist.py
|
|
183
|
+
```
|
|
184
|
+
|
|
185
|
+
See `VALIDATION_CHECKLIST.md` and `SETUP_GUIDE.md`.
|
|
186
|
+
|
|
187
|
+
---
|
|
188
|
+
|
|
189
|
+
### Python Tutorial: FastMCP Server
|
|
190
|
+
|
|
191
|
+
FastMCP server with MCP-Bastion.
|
|
192
|
+
|
|
193
|
+
**Step 1: Install dependencies**
|
|
194
|
+
|
|
195
|
+
```bash
|
|
196
|
+
pip install mcp mcp-bastion-python
|
|
197
|
+
```
|
|
198
|
+
|
|
199
|
+
**Step 2: Create your server file** (`server.py`)
|
|
200
|
+
|
|
201
|
+
```python
|
|
202
|
+
from mcp.server.fastmcp import FastMCP
|
|
203
|
+
from mcp_bastion import MCPBastionMiddleware, compose_middleware
|
|
204
|
+
|
|
205
|
+
# Create the MCP server
|
|
206
|
+
mcp = FastMCP("My Secure Server")
|
|
207
|
+
|
|
208
|
+
# Create MCP-Bastion middleware
|
|
209
|
+
# It intercepts tool calls and resource reads before they execute
|
|
210
|
+
bastion = MCPBastionMiddleware(
|
|
211
|
+
enable_prompt_guard=True, # Block malicious prompts via PromptGuard
|
|
212
|
+
enable_pii_redaction=True, # Mask PII in outgoing content
|
|
213
|
+
enable_rate_limit=True, # Cap at 15 iterations, 60s timeout
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
# Compose middleware chain (pass to your server's middleware config if supported)
|
|
217
|
+
middleware = compose_middleware(bastion)
|
|
218
|
+
|
|
219
|
+
# Register a tool (protected when middleware is wired into your server)
|
|
220
|
+
@mcp.tool()
|
|
221
|
+
def get_weather(city: str) -> str:
|
|
222
|
+
"""Get weather for a city."""
|
|
223
|
+
return f"Weather in {city}: 22C, sunny"
|
|
224
|
+
|
|
225
|
+
# Resource (PII redacted)
|
|
226
|
+
@mcp.resource("user://profile/{user_id}")
|
|
227
|
+
def get_profile(user_id: str) -> str:
|
|
228
|
+
"""Get user profile. PII redacted."""
|
|
229
|
+
return f"User {user_id}: John Doe, SSN 123-45-6789, john@example.com"
|
|
230
|
+
|
|
231
|
+
if __name__ == "__main__":
|
|
232
|
+
mcp.run(transport="streamable-http")
|
|
233
|
+
```
|
|
234
|
+
|
|
235
|
+
**Step 3: Run the server**
|
|
236
|
+
|
|
237
|
+
```bash
|
|
238
|
+
python server.py
|
|
239
|
+
```
|
|
240
|
+
|
|
241
|
+
MCP-Bastion:
|
|
242
|
+
- Scans tool args for prompt injection
|
|
243
|
+
- Redacts PII from resource responses
|
|
244
|
+
- Blocks sessions over 15 calls or 60s
|
|
245
|
+
|
|
246
|
+
---
|
|
247
|
+
|
|
248
|
+
### Python: Custom Rate Limits
|
|
249
|
+
|
|
250
|
+
Example of a custom rate-limit and prompt-guard configuration:
|
|
251
|
+
|
|
252
|
+
```python
|
|
253
|
+
from mcp_bastion import MCPBastionMiddleware
|
|
254
|
+
from mcp_bastion.pillars.rate_limit import TokenBucketRateLimiter
|
|
255
|
+
from mcp_bastion.pillars.prompt_guard import PromptGuardEngine
|
|
256
|
+
|
|
257
|
+
# Stricter limits
|
|
258
|
+
rate_limiter = TokenBucketRateLimiter(
|
|
259
|
+
max_iterations=10,
|
|
260
|
+
timeout_seconds=30,
|
|
261
|
+
token_budget=25_000,
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# Higher threshold = fewer blocks, more risk
|
|
265
|
+
prompt_guard = PromptGuardEngine(threshold=0.92)
|
|
266
|
+
|
|
267
|
+
bastion = MCPBastionMiddleware(
|
|
268
|
+
prompt_guard=prompt_guard,
|
|
269
|
+
rate_limiter=rate_limiter,
|
|
270
|
+
enable_prompt_guard=True,
|
|
271
|
+
enable_pii_redaction=True,
|
|
272
|
+
enable_rate_limit=True,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
# Disable PII redaction if your data has no PII
|
|
276
|
+
bastion_no_pii = MCPBastionMiddleware(enable_pii_redaction=False)
|
|
277
|
+
```
|
|
278
|
+
|
|
279
|
+
---
|
|
280
|
+
|
|
281
|
+
### Python: Custom Middleware
|
|
282
|
+
|
|
283
|
+
Extend `Middleware` to add logging, metrics, or custom logic:
|
|
284
|
+
|
|
285
|
+
```python
|
|
286
|
+
from mcp_bastion.base import Middleware, MiddlewareContext, compose_middleware
|
|
287
|
+
|
|
288
|
+
class LoggingMiddleware(Middleware):
|
|
289
|
+
async def on_message(self, context, call_next):
|
|
290
|
+
result = await call_next(context)
|
|
291
|
+
# log method, elapsed, etc.
|
|
292
|
+
return result
|
|
293
|
+
|
|
294
|
+
middleware = compose_middleware(bastion, LoggingMiddleware())
|
|
295
|
+
```
|
|
296
|
+
|
|
297
|
+
See `examples/full_demo.py` for a complete example.
|
|
298
|
+
|
|
299
|
+
---
|
|
300
|
+
|
|
301
|
+
### TypeScript: Wrap an MCP Server
|
|
302
|
+
|
|
303
|
+
**Step 1: Install dependencies**
|
|
304
|
+
|
|
305
|
+
```bash
|
|
306
|
+
npm install @modelcontextprotocol/sdk @mcp-bastion/core
|
|
307
|
+
```
|
|
308
|
+
|
|
309
|
+
**Step 2: Create your server** (`server.ts`)
|
|
310
|
+
|
|
311
|
+
```typescript
|
|
312
|
+
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
|
|
313
|
+
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
|
|
314
|
+
import {
|
|
315
|
+
wrapWithMcpBastion,
|
|
316
|
+
wrapCallToolHandler,
|
|
317
|
+
} from "@mcp-bastion/core";
|
|
318
|
+
|
|
319
|
+
const server = new Server({ name: "my-mcp-server", version: "1.0.0" });
|
|
320
|
+
|
|
321
|
+
// Wrap the server with MCP-Bastion (rate limiting only by default)
|
|
322
|
+
// For prompt injection and PII, run the Python sidecar and set sidecarUrl
|
|
323
|
+
wrapWithMcpBastion(server, {
|
|
324
|
+
enableRateLimit: true,
|
|
325
|
+
maxIterations: 15,
|
|
326
|
+
timeoutMs: 60_000,
|
|
327
|
+
// Optional: enable ML features via Python sidecar
|
|
328
|
+
sidecarUrl: process.env.MCP_BASTION_SIDECAR || "",
|
|
329
|
+
enablePromptGuard: !!process.env.MCP_BASTION_SIDECAR,
|
|
330
|
+
enablePiiRedaction: !!process.env.MCP_BASTION_SIDECAR,
|
|
331
|
+
});
|
|
332
|
+
|
|
333
|
+
// Register tools (handlers are automatically wrapped)
|
|
334
|
+
server.setRequestHandler("tools/call" as any, async (request) => {
|
|
335
|
+
if (request.params?.name === "get_weather") {
|
|
336
|
+
return {
|
|
337
|
+
content: [{ type: "text", text: "Sunny, 22C" }],
|
|
338
|
+
isError: false,
|
|
339
|
+
};
|
|
340
|
+
}
|
|
341
|
+
throw new Error("Unknown tool");
|
|
342
|
+
});
|
|
343
|
+
|
|
344
|
+
async function main() {
|
|
345
|
+
const transport = new StdioServerTransport();
|
|
346
|
+
await server.connect(transport);
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
main();
|
|
350
|
+
```
|
|
351
|
+
|
|
352
|
+
**Step 3: Run with rate limiting only**
|
|
353
|
+
|
|
354
|
+
```bash
|
|
355
|
+
npx tsx server.ts
|
|
356
|
+
```
|
|
357
|
+
|
|
358
|
+
**Step 4: Run with full ML features (Python sidecar)**
|
|
359
|
+
|
|
360
|
+
For prompt injection and PII redaction, run a Python HTTP service that exposes `/prompt-guard` and `/pii-redact` endpoints (see the Python package for sidecar implementation). Then:
|
|
361
|
+
|
|
362
|
+
```bash
|
|
363
|
+
# Start the Python sidecar, then the TypeScript server
|
|
364
|
+
MCP_BASTION_SIDECAR=http://localhost:8000 npx tsx server.ts
|
|
365
|
+
```
|
|
366
|
+
|
|
367
|
+
---
|
|
368
|
+
|
|
369
|
+
### TypeScript: Wrap Individual Handlers
|
|
370
|
+
|
|
371
|
+
Wrap specific handlers only:
|
|
372
|
+
|
|
373
|
+
```typescript
|
|
374
|
+
import {
|
|
375
|
+
wrapCallToolHandler,
|
|
376
|
+
wrapReadResourceHandler,
|
|
377
|
+
} from "@mcp-bastion/core";
|
|
378
|
+
import {
|
|
379
|
+
CallToolRequestSchema,
|
|
380
|
+
ReadResourceRequestSchema,
|
|
381
|
+
} from "@modelcontextprotocol/sdk/types.js";
|
|
382
|
+
|
|
383
|
+
// Wrap only the tool handler
|
|
384
|
+
const safeToolHandler = wrapCallToolHandler(
|
|
385
|
+
async (request) => {
|
|
386
|
+
// Your tool logic
|
|
387
|
+
return { content: [{ type: "text", text: "OK" }], isError: false };
|
|
388
|
+
},
|
|
389
|
+
{ enableRateLimit: true, maxIterations: 10 }
|
|
390
|
+
);
|
|
391
|
+
|
|
392
|
+
// Wrap only the resource handler (for PII redaction)
|
|
393
|
+
const safeResourceHandler = wrapReadResourceHandler(
|
|
394
|
+
async (request) => {
|
|
395
|
+
const contents = await fetchResource(request.params.uri);
|
|
396
|
+
return { contents };
|
|
397
|
+
},
|
|
398
|
+
{ sidecarUrl: "http://localhost:8000", enablePiiRedaction: true }
|
|
399
|
+
);
|
|
400
|
+
|
|
401
|
+
server.setRequestHandler(CallToolRequestSchema, safeToolHandler);
|
|
402
|
+
server.setRequestHandler(ReadResourceRequestSchema, safeResourceHandler);
|
|
403
|
+
```
|
|
404
|
+
|
|
405
|
+
---
|
|
406
|
+
|
|
407
|
+
### Configuration Reference
|
|
408
|
+
|
|
409
|
+
| Option | Python | TypeScript | Default | Description |
|
|
410
|
+
|--------|--------|------------|---------|-------------|
|
|
411
|
+
| `enable_prompt_guard` | Yes | Yes | `True` (Python) / `False` (TS) | Block malicious prompts via PromptGuard |
|
|
412
|
+
| `enable_pii_redaction` | Yes | Yes | `True` (Python) / `False` (TS) | Mask PII in outgoing content |
|
|
413
|
+
| `enable_rate_limit` | Yes | Yes | `True` | Enforce iteration and timeout caps |
|
|
414
|
+
| `max_iterations` | Via `TokenBucketRateLimiter` | Yes | 15 | Max tool calls per session |
|
|
415
|
+
| `timeout_seconds` / `timeoutMs` | Via `TokenBucketRateLimiter` | Yes | 60 | Session timeout |
|
|
416
|
+
| `token_budget` | Via `TokenBucketRateLimiter` | - | 50,000 | FinOps token cap per request |
|
|
417
|
+
| `sidecarUrl` | - | Yes | `""` | Python sidecar URL for ML features |
|
|
418
|
+
| `threshold` | Via `PromptGuardEngine` | - | 0.85 | Malicious probability cutoff |
|
|
419
|
+
| `setLogLevel` | - | Yes | `"info"` | TypeScript: `"debug"` \| `"info"` \| `"warn"` \| `"error"` |
|
|
420
|
+
|
|
421
|
+
---
|
|
422
|
+
|
|
423
|
+
### Error Handling
|
|
424
|
+
|
|
425
|
+
When MCP-Bastion blocks a request, it returns standard MCP/JSON-RPC errors:
|
|
426
|
+
|
|
427
|
+
| Code | Exception | When |
|
|
428
|
+
|------|-----------|------|
|
|
429
|
+
| -32001 | `PromptInjectionError` | Tool args contain jailbreak/injection |
|
|
430
|
+
| -32002 | `RateLimitExceededError` | Session exceeds iteration or timeout limit |
|
|
431
|
+
| -32003 | `TokenBudgetExceededError` | Session exceeds token budget |
|
|
432
|
+
|
|
433
|
+
```python
|
|
434
|
+
# Python: exceptions
|
|
435
|
+
from mcp_bastion.errors import (
|
|
436
|
+
PromptInjectionError,
|
|
437
|
+
RateLimitExceededError,
|
|
438
|
+
TokenBudgetExceededError,
|
|
439
|
+
)
|
|
440
|
+
import logging
|
|
441
|
+
logger = logging.getLogger(__name__)
|
|
442
|
+
|
|
443
|
+
try:
|
|
444
|
+
result = await middleware(context, call_next)
|
|
445
|
+
except PromptInjectionError as e:
|
|
446
|
+
logger.warning("blocked: %s", e.to_mcp_error())
|
|
447
|
+
except RateLimitExceededError as e:
|
|
448
|
+
logger.warning("blocked: %s", e.to_mcp_error())
|
|
449
|
+
except TokenBudgetExceededError as e:
|
|
450
|
+
logger.warning("blocked: %s", e.to_mcp_error())
|
|
451
|
+
```
|
|
452
|
+
|
|
453
|
+
```typescript
|
|
454
|
+
// TypeScript: handlers return isError: true
|
|
455
|
+
import { logger, setLogLevel } from "@mcp-bastion/core";
|
|
456
|
+
setLogLevel("debug"); // optional: "debug" | "info" | "warn" | "error"
|
|
457
|
+
const result = await guardedHandler(request);
|
|
458
|
+
if (result.isError) {
|
|
459
|
+
logger.error("blocked", result.content);
|
|
460
|
+
}
|
|
461
|
+
```
|
|
462
|
+
|
|
463
|
+
---
|
|
464
|
+
|
|
465
|
+
### Testing
|
|
466
|
+
|
|
467
|
+
Test the guarded server interactively with MCP Inspector:
|
|
468
|
+
|
|
469
|
+
```bash
|
|
470
|
+
# Start your guarded server
|
|
471
|
+
python server.py # or: npx tsx server.ts
|
|
472
|
+
|
|
473
|
+
# In another terminal, launch the Inspector
|
|
474
|
+
npx -y @modelcontextprotocol/inspector
|
|
475
|
+
```
|
|
476
|
+
|
|
477
|
+
Connect via HTTP (`http://localhost:8000/mcp`) or stdio, then:
|
|
478
|
+
1. List tools and call one with benign arguments (should succeed)
|
|
479
|
+
2. Call a tool with "Ignore previous instructions" (should be blocked)
|
|
480
|
+
3. Trigger 16+ tool calls in one session (should hit rate limit)
|
|
481
|
+
|
|
482
|
+
---
|
|
483
|
+
|
|
484
|
+
## Testing
|
|
485
|
+
|
|
486
|
+
```bash
|
|
487
|
+
# Python (PYTHONPATH=src on Windows: $env:PYTHONPATH="src")
|
|
488
|
+
pytest tests/ -v
|
|
489
|
+
|
|
490
|
+
# TypeScript
|
|
491
|
+
npm run test --workspace=@mcp-bastion/core
|
|
492
|
+
|
|
493
|
+
# Full validation checklist (build, pillars, latency)
|
|
494
|
+
PYTHONPATH=src python scripts/validate_checklist.py
|
|
495
|
+
|
|
496
|
+
# MCP Inspector (manual)
|
|
497
|
+
npx -y @modelcontextprotocol/inspector
|
|
498
|
+
```
|
|
499
|
+
|
|
500
|
+
## Third-Party Components
|
|
501
|
+
|
|
502
|
+
See `NOTICE` for licenses. MCP-Bastion uses Meta Llama Prompt Guard 2 (Llama 4 Community License) and Microsoft Presidio.
|
|
503
|
+
|
|
504
|
+
## License
|
|
505
|
+
|
|
506
|
+
MIT
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
mcp_bastion/__init__.py,sha256=izXzfw8A1AXqS-b22XPTtTSolsdMluCjPsUugd6CLoQ,347
|
|
2
|
+
mcp_bastion/base.py,sha256=zZ0YG0tVn01cS7osjJDgBqK_6PVFme3GQ6tmhyfhITo,2910
|
|
3
|
+
mcp_bastion/errors.py,sha256=oMGb3cjjSo3sOof8Hv9DdkucczhOAV2LY74xeEAkiXY,1295
|
|
4
|
+
mcp_bastion/middleware.py,sha256=_FHFfINBP8Qmdl6KliEQbBIZIxlprO8-lkMoK0N0I_o,7799
|
|
5
|
+
mcp_bastion/pillars/__init__.py,sha256=ybeqOYWVTGT3PEnquAz8nXSrofMV1F64H9eTC4P7wuU,317
|
|
6
|
+
mcp_bastion/pillars/pii_redaction.py,sha256=U6jl33qwV53q7I14McRD7tvRadhpHz39gecMJBsps0c,3326
|
|
7
|
+
mcp_bastion/pillars/prompt_guard.py,sha256=wR50TikDcnEjyzWz5Qut2QN523fW8ZIwUNfNlB14lqY,3148
|
|
8
|
+
mcp_bastion/pillars/rate_limit.py,sha256=vp_r8TpNVx9nyVzgf96h6EW7O3YmX_tfkHyaRLlxj6M,3328
|
|
9
|
+
mcp_bastion_python-1.0.1.dist-info/METADATA,sha256=Qp4fkViR2Vva4XVC1ai2JIidleFAjsGWTfeH1Ihf3Wg,15464
|
|
10
|
+
mcp_bastion_python-1.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
mcp_bastion_python-1.0.1.dist-info/licenses/NOTICE,sha256=_DlzQBhNBsf8mK-N55MCiz9juUyYjIaEPxZlMjFvPmc,273
|
|
12
|
+
mcp_bastion_python-1.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
MCP-Bastion uses the following third-party components:
|
|
2
|
+
|
|
3
|
+
Llama Prompt Guard 2 (meta-llama/Llama-Prompt-Guard-2-86M)
|
|
4
|
+
Llama 4 is licensed under the Llama 4 Community License, Copyright (c) Meta Platforms, Inc. All Rights Reserved.
|
|
5
|
+
See: https://www.llama.com/docs/overview
|