openmaskit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openmaskit/__init__.py +19 -0
- openmaskit/__main__.py +617 -0
- openmaskit/backend_client.py +181 -0
- openmaskit/cli.py +112 -0
- openmaskit/config.py +126 -0
- openmaskit/container.py +285 -0
- openmaskit/logging_config.py +74 -0
- openmaskit/masking/__init__.py +0 -0
- openmaskit/masking/engine.py +536 -0
- openmaskit/masking/mappers.py +20 -0
- openmaskit/masking/parsing.py +51 -0
- openmaskit/masking/rules.py +103 -0
- openmaskit/masking/store.py +619 -0
- openmaskit/models.py +66 -0
- openmaskit/oauth/__init__.py +0 -0
- openmaskit/oauth/handler.py +431 -0
- openmaskit/proxy/__init__.py +0 -0
- openmaskit/proxy/core.py +574 -0
- openmaskit/proxy/http_downstream.py +159 -0
- openmaskit/proxy/manager.py +260 -0
- openmaskit/proxy/upstream.py +321 -0
- openmaskit/security.py +145 -0
- openmaskit/traffic/__init__.py +0 -0
- openmaskit/traffic/buffer.py +44 -0
- openmaskit/traffic/store.py +223 -0
- openmaskit/web/__init__.py +0 -0
- openmaskit/web/app.py +142 -0
- openmaskit/web/health.py +109 -0
- openmaskit/web/origin.py +110 -0
- openmaskit/web/routes/__init__.py +0 -0
- openmaskit/web/routes/custom_targets.py +280 -0
- openmaskit/web/routes/guardrails.py +160 -0
- openmaskit/web/routes/hidden_tools.py +40 -0
- openmaskit/web/routes/injections.py +142 -0
- openmaskit/web/routes/mappers.py +382 -0
- openmaskit/web/routes/marketplace.py +607 -0
- openmaskit/web/routes/oauth.py +162 -0
- openmaskit/web/routes/oauth_callback.py +191 -0
- openmaskit/web/routes/pages.py +158 -0
- openmaskit/web/routes/rules.py +124 -0
- openmaskit/web/routes/traffic.py +82 -0
- openmaskit/web/static/big.png +0 -0
- openmaskit/web/static/favicon.png +0 -0
- openmaskit/web/static/icon.png +0 -0
- openmaskit/web/static/marketplace.html +937 -0
- openmaskit/web/static/new_maskit-removebg-preview.png +0 -0
- openmaskit/web/static/onboarding.css +386 -0
- openmaskit/web/static/original_icon.png +0 -0
- openmaskit/web/static/shared.js +322 -0
- openmaskit/web/static/style.css +5036 -0
- openmaskit/web/static/targets.html +1174 -0
- openmaskit/web/static/tool_detail.html +936 -0
- openmaskit/web/static/tools.html +661 -0
- openmaskit/web/static/tutorial.css +377 -0
- openmaskit/web/static/tutorial.js +546 -0
- openmaskit/web/static/tutorials/guardrails.json +31 -0
- openmaskit/web/static/tutorials/hide-tool.json +16 -0
- openmaskit/web/static/tutorials/injections.json +31 -0
- openmaskit/web/static/tutorials/masking-with-result.json +31 -0
- openmaskit/web/static/tutorials/masking.json +16 -0
- openmaskit-0.1.1.dist-info/METADATA +229 -0
- openmaskit-0.1.1.dist-info/RECORD +66 -0
- openmaskit-0.1.1.dist-info/WHEEL +4 -0
- openmaskit-0.1.1.dist-info/entry_points.txt +2 -0
- openmaskit-0.1.1.dist-info/licenses/LICENSE +201 -0
- openmaskit-0.1.1.dist-info/licenses/NOTICE +10 -0
openmaskit/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""OpenMaskit - MCP server wrapper that masks sensitive fields."""

from importlib.metadata import version, PackageNotFoundError

try:
    # Installed distribution: the package metadata is the source of truth.
    __version__ = version("openmaskit")
except PackageNotFoundError:
    # Running from source without the package installed: fall back to the
    # ``[project].version`` entry of the pyproject.toml found at
    # ``Path(__file__).parent.parent.parent`` (two directories above the
    # package directory). Any failure along the way -- no TOML parser,
    # missing or malformed file -- yields "unknown" instead of breaking import.
    try:
        from pathlib import Path

        try:
            import tomllib
        except ImportError:
            import tomli as tomllib  # type: ignore

        _pyproject = Path(__file__).parent.parent.parent / "pyproject.toml"
        _project_table = tomllib.loads(_pyproject.read_bytes().decode()).get("project", {})
        __version__ = _project_table.get("version", "unknown")
    except Exception:
        __version__ = "unknown"
|
openmaskit/__main__.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
1
|
+
"""OpenMaskit entry point."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import signal
|
|
9
|
+
import sys
|
|
10
|
+
from contextlib import AsyncExitStack
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
import random
|
|
13
|
+
import string
|
|
14
|
+
|
|
15
|
+
import anyio
|
|
16
|
+
import uvicorn
|
|
17
|
+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
18
|
+
|
|
19
|
+
from mcp.shared.message import SessionMessage
|
|
20
|
+
|
|
21
|
+
from openmaskit.config import load_config
|
|
22
|
+
from openmaskit.masking.engine import MaskingEngine
|
|
23
|
+
from openmaskit.masking.rules import ArgumentGuardrail, ArgumentInjection, MaskingRule
|
|
24
|
+
from openmaskit.masking.store import MaskingStore
|
|
25
|
+
from openmaskit.oauth.handler import OAuthCallbackServer
|
|
26
|
+
from openmaskit.proxy.core import ProxyState, TargetState, run_proxy_for_target
|
|
27
|
+
from openmaskit.proxy.http_downstream import create_mcp_app
|
|
28
|
+
from openmaskit.proxy.manager import TargetManager, _build_upstream_config
|
|
29
|
+
from openmaskit.proxy.upstream import connect_upstream, is_oauth_token_expired, refresh_backend_oauth_token
|
|
30
|
+
from openmaskit.traffic.buffer import TrafficBuffer
|
|
31
|
+
from openmaskit.traffic.store import TrafficStore
|
|
32
|
+
from openmaskit.web.app import create_app
|
|
33
|
+
from openmaskit import __version__
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
def print_banner():
    """Print the OpenMaskit ASCII-art banner to stdout.

    NOTE(review): the art's alignment depends on the exact whitespace inside
    the literal below; edit it only against the original source file.
    """
    # https://patorjk.com/software/taag
    # DOS Rebel for openmaskit
    # Pagga for version
    banner = """
███████ ██████ ██████ █████ ███ █████
███░░░░░███ ░░██████ ██████ ░░███ ░░░ ░░███
███ ░░███ ████████ ██████ ████████ ░███░█████░███ ██████ █████ ░███ █████ ████ ███████
░███ ░███░░███░░███ ███░░███░░███░░███ ░███░░███ ░███ ░░░░░███ ███░░ ░███░░███ ░░███ ░░░███░
░███ ░███ ░███ ░███░███████ ░███ ░███ ░███ ░░░ ░███ ███████ ░░█████ ░██████░ ░███ ░███
░░███ ███ ░███ ░███░███░░░ ░███ ░███ ░███ ░███ ███░░███ ░░░░███ ░███░░███ ░███ ░███ ███
░░░███████░ ░███████ ░░██████ ████ █████ █████ █████░░████████ ██████ ████ █████ █████ ░░█████
░░░░░░░ ░███░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░ ░░░░░
░███
█████ ░▄▀▄░░░░▀█░░░░░▄▀▄
░░░░░ ░█/█░░░░░█░░░░░█/█
░░▀░░▀░░▀▀▀░▀░░░▀░
"""
    print(banner)
|
|
56
|
+
|
|
57
|
+
async def _flush_loop(
    engine: MaskingEngine,
    shutdown_event: anyio.Event,
    interval: float = 1.0,
):
    """Periodically flush pending alias writes to the database.

    Every *interval* seconds, any pending alias writes held by *engine* are
    flushed. The wait is interrupted as soon as *shutdown_event* is set, after
    which one final flush runs and the loop returns. Flush errors are logged
    and never propagate, so a database hiccup cannot tear down the task group
    this loop runs in.

    Args:
        engine: Masking engine whose pending alias writes are persisted.
        shutdown_event: Set when the process is shutting down.
        interval: Seconds between periodic flushes (default 1.0).
    """

    async def _flush(failure_message: str) -> None:
        # Flush only when there is something to write; log (don't raise) errors.
        if not engine.has_pending_writes:
            return
        try:
            await engine.flush_pending()
        except Exception:
            logger.exception(failure_message)

    while not shutdown_event.is_set():
        # Interruptible wait: a plain sleep would delay the final flush by up to
        # a full interval after shutdown is signalled, while graceful shutdown
        # only waits a bounded time before cancelling this task.
        with anyio.move_on_after(interval):
            await shutdown_event.wait()
        if shutdown_event.is_set():
            break
        await _flush("Failed to flush aliases to database")

    # Final flush on shutdown
    await _flush("Failed final flush of aliases to database")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def _traffic_flush_loop(
    buffer: TrafficBuffer,
    store: TrafficStore,
    shutdown_event: anyio.Event,
    interval: float = 1.0,
):
    """Periodically drain the traffic buffer into the traffic DB.

    Every *interval* seconds, or immediately once *shutdown_event* is set,
    buffered traffic entries are written to *store*. A final drain runs on
    shutdown. Flush errors are logged rather than raised: an exception escaping
    this task would cancel the whole task group and stop the proxy.

    Args:
        buffer: In-memory traffic buffer to drain.
        store: Traffic database the buffer is flushed into.
        shutdown_event: Set when the process is shutting down.
        interval: Seconds between periodic drains (default 1.0).
    """

    async def _drain(failure_message: str) -> None:
        # Drain only when something is buffered; log (don't raise) errors.
        if not buffer.has_pending:
            return
        try:
            await buffer.flush(store)
        except Exception:
            logger.exception(failure_message)

    while not shutdown_event.is_set():
        # Interruptible wait so the final drain starts as soon as shutdown is
        # signalled instead of after the current sleep expires.
        with anyio.move_on_after(interval):
            await shutdown_event.wait()
        if shutdown_event.is_set():
            break
        await _drain("Failed to flush traffic buffer to database")

    # Final drain on shutdown
    await _drain("Failed final flush of traffic buffer to database")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
async def _traffic_rotation_loop(
    store: TrafficStore,
    max_rows: int,
    shutdown_event: anyio.Event,
):
    """Trim the traffic table to at most *max_rows* rows every 300 seconds.

    Returns as soon as *shutdown_event* is set; the wait between trims is
    interrupted by the event rather than sleeping out the full period. Errors
    from ``enforce_row_cap`` are logged and the loop keeps going.
    """
    while True:
        if shutdown_event.is_set():
            return
        # Wait up to five minutes, waking early if shutdown is requested.
        with anyio.move_on_after(300):
            await shutdown_event.wait()
        if shutdown_event.is_set():
            return
        try:
            await store.enforce_row_cap(max_rows)
        except Exception:
            logger.exception("Failed to enforce traffic row cap")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _install_shutdown_noise_filter(shutdown_event: anyio.Event) -> None:
|
|
107
|
+
"""Quiet anyio/MCP-SDK async-generator teardown noise during shutdown.
|
|
108
|
+
|
|
109
|
+
The MCP SDK's stdio_client holds anyio task-group cancel scopes whose
|
|
110
|
+
teardown happens in Python's asyncgen finalizer task — not the task that
|
|
111
|
+
entered them. The resulting 'cancel scope in a different task' and
|
|
112
|
+
'aclose(): asynchronous generator is already running' errors are harmless
|
|
113
|
+
but loud. Swallow them once shutdown has been signalled.
|
|
114
|
+
"""
|
|
115
|
+
loop = asyncio.get_running_loop()
|
|
116
|
+
|
|
117
|
+
def _is_known_shutdown_noise(exc: BaseException | None, message: str) -> bool:
|
|
118
|
+
if "asynchronous generator" in message:
|
|
119
|
+
return True
|
|
120
|
+
if exc is None:
|
|
121
|
+
return False
|
|
122
|
+
if isinstance(exc, GeneratorExit):
|
|
123
|
+
return True
|
|
124
|
+
s = str(exc).lower()
|
|
125
|
+
if "cancel scope" in s or "asynchronous generator" in s:
|
|
126
|
+
return True
|
|
127
|
+
sub = getattr(exc, "exceptions", None)
|
|
128
|
+
if sub:
|
|
129
|
+
return bool(sub) and all(_is_known_shutdown_noise(e, "") for e in sub)
|
|
130
|
+
return False
|
|
131
|
+
|
|
132
|
+
def handler(loop, context):
|
|
133
|
+
if shutdown_event.is_set():
|
|
134
|
+
if _is_known_shutdown_noise(context.get("exception"), context.get("message", "")):
|
|
135
|
+
return
|
|
136
|
+
loop.default_exception_handler(context)
|
|
137
|
+
|
|
138
|
+
loop.set_exception_handler(handler)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def _graceful_shutdown(
    state: ProxyState,
    shutdown_event: anyio.Event,
    web_server,
    mcp_server,
    callback_web_server,
    tg,
    drain_timeout: float,
    flush_timeout: float,
) -> None:
    """Coordinate graceful shutdown with timeout enforcement.

    Shutdown sequence:
    1. Stop accepting new requests (the uvicorn servers are marked for exit).
    2. Drain in-flight requests: each target's ResponseDispatcher is shut down
       and its downstream read stream closed (bounded by *drain_timeout*).
    3. Set *shutdown_event* so the flush loops run their final flush, drain the
       traffic buffer, and wait while any engine still reports pending writes
       (bounded by *flush_timeout*).
    4. Cancel the task group *tg* (relay loops, servers, background tasks).
       Upstream connections are closed afterwards by the caller's exit stack.
    """
    logger.info("Starting graceful shutdown sequence")

    def _targets_snapshot():
        # Proxy tasks may remove entries from state.targets while we await, so
        # never iterate the live dict across an await point.
        return list(state.targets.values())

    def _pending_target_count() -> int:
        return sum(ts.engine.has_pending_writes for ts in _targets_snapshot())

    # Stage 1: Stop accepting new requests
    logger.info("Stage 1/4: Stopping request acceptance")
    web_server.should_exit = True
    mcp_server.should_exit = True
    callback_web_server.should_exit = True

    # Stage 2: Drain in-flight requests
    logger.info(f"Stage 2/4: Draining in-flight requests ({drain_timeout}s timeout)")
    with anyio.move_on_after(drain_timeout):
        for target_state in _targets_snapshot():
            try:
                # Notify all waiting HTTP clients to abort
                target_state.response_dispatcher.shutdown()

                # Close downstream streams (prevents new messages from clients)
                if target_state.ds_read_send:
                    await target_state.ds_read_send.aclose()
            except Exception:
                # One misbehaving target must not stop the others from draining.
                logger.exception("[%s] Error while draining target", target_state.name)

    # Stage 3: Wait for flush loops to complete
    logger.info(f"Stage 3/4: Waiting for database flushes ({flush_timeout}s timeout)")
    # Signal shutdown to flush loops (outside the timeout scope so it always happens)
    shutdown_event.set()
    with anyio.move_on_after(flush_timeout):
        # Give flush loops time to notice the event and start their final flush
        await anyio.sleep(0.5)

        # Drain any remaining traffic entries (best-effort: a failure here must
        # not abort the sequence before the task group is cancelled)
        if state.traffic_buffer is not None and state.traffic_store is not None:
            if state.traffic_buffer.has_pending:
                try:
                    await state.traffic_buffer.flush(state.traffic_store)
                except Exception:
                    logger.exception("Failed to drain traffic buffer during shutdown")

        # Rather than assuming the fixed pause was enough, keep waiting (within
        # flush_timeout) while any engine still has unflushed alias writes.
        while _pending_target_count() > 0:
            await anyio.sleep(0.1)

    # Check if flushes completed (also reached when flush_timeout expired)
    pending_count = _pending_target_count()
    if pending_count > 0:
        logger.warning(f"{pending_count} targets still have pending writes")

    # Stage 4: Cancel remaining tasks (relay loops, servers)
    logger.info("Stage 4/4: Cancelling remaining tasks")
    tg.cancel_scope.cancel()
|
|
203
|
+
|
|
204
|
+
def _generate_installation_id() -> str:
|
|
205
|
+
length = 25
|
|
206
|
+
random_string = ''.join(
|
|
207
|
+
random.choices(string.ascii_letters + string.digits, k=length)
|
|
208
|
+
)
|
|
209
|
+
return random_string
|
|
210
|
+
|
|
211
|
+
def _load_installation_id() -> str:
    """Return the persistent installation ID, creating it on first use.

    The ID is stored in ``~/.openmaskit/.installation_id``. If that file is
    missing, or holds only whitespace, a new ID from
    ``_generate_installation_id()`` is written there with owner-only
    permissions (0o600) and returned.
    """
    id_path = Path("~/.openmaskit/.installation_id").expanduser()
    if id_path.exists():
        existing = id_path.read_bytes().strip().decode('utf-8')
        if existing:
            return existing
        # An empty/whitespace-only file would otherwise yield an empty ID;
        # fall through and replace it with a fresh one.

    key = _generate_installation_id()
    id_path.parent.mkdir(parents=True, exist_ok=True)
    # Create the file with 0o600 up front instead of write-then-chmod, which
    # leaves a window where the ID has the default umask-derived permissions.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(id_path, flags, 0o600)
    with os.fdopen(fd, "wb") as fh:
        fh.write(key.encode('utf-8'))
    # os.open's mode only applies when the file is created, so still tighten
    # the permissions of a pre-existing (empty) file.
    id_path.chmod(0o600)
    return key
|
|
221
|
+
|
|
222
|
+
async def async_main():
    """Run the OpenMaskit proxy until shutdown.

    Startup:
    - Parse CLI args and load the config.
    - Open the masking store and the traffic store.
    - Build a TargetState (masking engine, hidden tools, downstream stream) for
      every configured target and every active marketplace server.
    - Create the OAuth-callback, dashboard and MCP uvicorn servers.

    Runtime:
    - Connect every upstream inside one AsyncExitStack. Marketplace servers that
      cannot connect, even after one OAuth token refresh, are deactivated and
      dropped.
    - Run the proxy relays, flush/rotation loops and the three servers in a
      single task group until SIGINT/SIGTERM triggers ``_graceful_shutdown``.

    Teardown: the ``finally`` clause always closes the backend client, the
    masking store and the traffic store. This happens after the exit stack has
    already torn down the upstream connections.
    """
    from openmaskit.logging_config import setup_logging
    from openmaskit.cli import parse_args

    setup_logging()
    logger = logging.getLogger(__name__)

    args = parse_args()
    config = load_config(
        path=args.config_path,
        web_port=args.web_port,
        mcp_port=args.mcp_port,
        oauth_port=args.oauth_port,
        store_path=args.store_path,
    )
    # All three servers bind to this host (loopback unless overridden).
    bind_host = os.environ.get("OPENMASKIT_HOST", "127.0.0.1")

    # Container runtime detection
    from openmaskit.container import get_container_runtime
    runtime = get_container_runtime(config.container_runtime)
    if runtime:
        if config.container_runtime:
            logger.info(f"Container runtime: {runtime} (configured)")
        else:
            logger.info(f"Container runtime: {runtime} (auto-detected)")
    else:
        logger.warning("No container runtime detected. Containerized MCP servers will not work.")

    # Shutdown configuration
    # NOTE(review): SHUTDOWN_TIMEOUT is read here but not referenced anywhere
    # else in this function.
    SHUTDOWN_TIMEOUT = float(os.environ.get("OPENMASKIT_SHUTDOWN_TIMEOUT", "30"))
    DRAIN_TIMEOUT = 5.0  # Time to wait for in-flight requests
    FLUSH_TIMEOUT = 3.0  # Time to wait for database flushes

    store = await MaskingStore.create(config.store_path)

    traffic_db_path = os.environ.get(
        "OPENMASKIT_TRAFFIC_DB_PATH",
        str(Path("~/.openmaskit/traffic.db").expanduser()),
    )
    traffic_store = await TrafficStore.create(traffic_db_path)
    traffic_buffer = TrafficBuffer()
    traffic_max_rows = int(os.environ.get("OPENMASKIT_TRAFFIC_MAX_ROWS", "10000"))

    state = ProxyState()
    state.store = store
    state.traffic_store = traffic_store
    state.traffic_buffer = traffic_buffer
    state.mcp_port = config.mcp_port
    state.oauth_port = config.oauth_port

    # Create per-target state
    # Config-file targets: rules come from the config plus any stored in the DB;
    # config guardrails/injections are added on top of those loaded from the DB.
    for name, target_config in config.targets.items():
        rules = [
            MaskingRule(
                tool_name=r.tool_name,
                field_path=r.field_path,
                alias_prefix=r.alias_prefix,
                action=r.action,
            )
            for r in target_config.rules
        ]
        db_rules = await store.get_rules(target_name=name)
        rules.extend(db_rules)

        engine = MaskingEngine(rules, store, target_name=name)
        await engine.load_aliases()
        await engine.load_mappers()
        await engine.load_guardrails()
        await engine.load_injections()

        for g in target_config.guardrails:
            engine.add_guardrail(ArgumentGuardrail(
                tool_name=g.tool_name, argument_name=g.argument_name,
                match_type=g.match_type, pattern=g.pattern, message=g.message,
            ))
        for i in target_config.injections:
            engine.add_injection(ArgumentInjection(
                tool_name=i.tool_name, argument_name=i.argument_name,
                value=i.value, mode=i.mode,
            ))

        hidden = await store.get_hidden_tools(target_name=name)

        # Downstream read stream (buffer size 32) for messages headed to this target.
        ds_read_send, ds_read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)

        target_state = TargetState(
            name=name,
            engine=engine,
            hidden_tools=set(hidden),
            ds_read_send=ds_read_send,
            ds_read_recv=ds_read_recv,
            traffic_buffer=traffic_buffer,
        )
        state.targets[name] = target_state

    state.config_target_ids = set(config.targets.keys())

    # Load active marketplace servers from DB
    # A config-file target with the same id takes precedence (skipped below).
    marketplace_configs: dict[str, dict] = {}
    installed = await store.get_installed_servers(active_only=True)
    for record in installed:
        server_id = record["id"]
        if server_id in state.targets:
            continue

        engine = MaskingEngine([], store, target_name=server_id)
        await engine.load_aliases()
        await engine.load_mappers()
        await engine.load_guardrails()
        await engine.load_injections()
        hidden = await store.get_hidden_tools(target_name=server_id)
        ds_read_send, ds_read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)

        target_state = TargetState(
            name=server_id,
            engine=engine,
            hidden_tools=set(hidden),
            ds_read_send=ds_read_send,
            ds_read_recv=ds_read_recv,
            server_id=server_id,  # Set server_id for OAuth refresh
            traffic_buffer=traffic_buffer,
        )
        state.targets[server_id] = target_state
        marketplace_configs[server_id] = record["config"]

    shutdown_event = anyio.Event()
    _install_shutdown_noise_filter(shutdown_event)

    # Shared OAuth callback server (always running)
    callback_server = OAuthCallbackServer(port=config.oauth_port)
    callback_app = callback_server.create_app()
    callback_uvicorn_config = uvicorn.Config(
        callback_app,
        host=bind_host,
        port=config.oauth_port,
        log_level="warning",
        log_config=None,
    )
    callback_web_server = uvicorn.Server(callback_uvicorn_config)
    # Signals are handled centrally by _shutdown_on_signal below, not by uvicorn.
    callback_web_server.install_signal_handlers = lambda: None
    state.callback_server = callback_server

    installation_id = _load_installation_id()
    openmaskit_version = __version__
    print('----------------------------------------------------------------------------------------------------------')
    print_banner()
    print('----------------------------------------------------------------------------------------------------------')

    # Initialize backend client for marketplace and auth integration
    from openmaskit.backend_client import BackendClient

    backend_client = BackendClient(installation_id=installation_id, openmaskit_version=openmaskit_version)
    oauth_states: dict[str, dict] = {}  # {csrf_state: {server_id, handle, timestamp}}

    # Store backend_client in state for token refresh
    state.backend_client = backend_client

    # Startup version check; only logged here, never blocks startup.
    state.version_status = await backend_client.check_version()
    if state.version_status:
        if not state.version_status.get("supported", True):
            logger.warning(
                "OpenMaskit %s is no longer supported. Latest: %s",
                openmaskit_version, state.version_status.get("latest_version"),
            )
        elif state.version_status.get("update_available"):
            logger.info(
                "OpenMaskit update available: %s (you have %s)",
                state.version_status.get("latest_version"), openmaskit_version,
            )

    logger.info("OpenMaskit proxy starting")
    logger.info(f"Dashboard: http://{bind_host}:{config.web_port}")
    logger.info(f"OAuth callback: http://{bind_host}:{config.oauth_port}/callback")
    if backend_client.enabled:
        logger.info("Backend integration enabled.")
    logger.info("MCP servers:")
    for name in state.target_names:
        logger.info(f" {name}: http://{bind_host}:{config.mcp_port}/{name}/mcp")

    # Allowed origins: localhost defaults for the dashboard port plus any
    # comma-separated extras from OPENMASKIT_ALLOWED_ORIGINS.
    from openmaskit.web.origin import default_localhost_origins
    allowed_origins = default_localhost_origins(config.web_port)
    extra_origins_env = os.environ.get("OPENMASKIT_ALLOWED_ORIGINS", "").strip()
    if extra_origins_env:
        allowed_origins.extend(
            o.strip() for o in extra_origins_env.split(",") if o.strip()
        )
    logger.debug(f"Allowed dashboard origins: {allowed_origins}")

    # Route upstream MCP child stderr based on log level. Upstream servers
    # emit their own protocol chatter ("Processing request of type …") that
    # we cannot demote from inside our process; we can only mute the stream.
    if logger.isEnabledFor(logging.DEBUG):
        upstream_errlog = sys.stderr
    else:
        # NOTE(review): this devnull handle is never closed; it stays open
        # for the lifetime of the process.
        upstream_errlog = open(os.devnull, "w")
        logger.info("Upstream stderr is suppressed. Set OPENMASKIT_LOG_LEVEL=DEBUG to see it.")

    web_app = create_app(state, allowed_origins=allowed_origins)
    web_app.state.backend_client = backend_client
    web_app.state.oauth_states = oauth_states
    mcp_app = create_mcp_app(state, allowed_origins=allowed_origins)

    uvicorn_config = uvicorn.Config(
        web_app,
        host=bind_host,
        port=config.web_port,
        log_level="warning",
        log_config=None,
    )
    web_server = uvicorn.Server(uvicorn_config)
    web_server.install_signal_handlers = lambda: None

    mcp_uvicorn_config = uvicorn.Config(
        mcp_app,
        host=bind_host,
        port=config.mcp_port,
        log_level="warning",
        log_config=None,
    )
    mcp_server = uvicorn.Server(mcp_uvicorn_config)
    mcp_server.install_signal_handlers = lambda: None

    try:
        # Every upstream connection is registered on this exit stack, so all of
        # them are closed when the task group below finishes.
        async with AsyncExitStack() as stack:
            # Connect all upstream targets
            upstream_streams: dict[str, tuple[
                MemoryObjectReceiveStream[SessionMessage | Exception],
                MemoryObjectSendStream[SessionMessage],
            ]] = {}

            failed_targets = []
            for name, target_state in state.targets.items():
                if name in config.targets:
                    # NOTE: a connect failure for a config-file target is not caught
                    # here; it propagates to the outer `except Exception`, which logs
                    # it and ends startup.
                    target_config = config.targets[name]
                    us_read, us_write, container_info = await stack.enter_async_context(
                        connect_upstream(target_config.upstream, config.store_path,
                                         errlog=upstream_errlog, server_id=name,
                                         callback_server=callback_server,
                                         container_runtime=config.container_runtime)
                    )
                    target_state.container_info = container_info
                    upstream_streams[name] = (us_read, us_write)
                elif name in marketplace_configs:
                    upstream_cfg = _build_upstream_config(marketplace_configs[name])

                    # Pre-flight refresh if token is known-expired
                    if is_oauth_token_expired(name, config.store_path):
                        logger.info("OAuth token for %s is expired; attempting refresh before connect", name)
                        refreshed = await refresh_backend_oauth_token(name, config.store_path, backend_client)
                        if not refreshed:
                            logger.warning(
                                "Could not refresh OAuth token for %s; deactivating. User must re-authenticate via the dashboard.",
                                name,
                            )
                            failed_targets.append(name)
                            try:
                                await store.deactivate_server(name)
                            except Exception as deactivate_exc:
                                logger.error("Failed to deactivate server %s: %s", name, deactivate_exc)
                            continue

                    # Each marketplace upstream gets its own exit stack so a failed
                    # connect attempt can be unwound without touching `stack`.
                    async def _connect_with_isolated_stack():
                        own_stack = AsyncExitStack()
                        await own_stack.__aenter__()
                        try:
                            r, w, ci = await own_stack.enter_async_context(
                                connect_upstream(upstream_cfg, config.store_path,
                                                 errlog=upstream_errlog, server_id=name,
                                                 callback_server=callback_server,
                                                 container_runtime=config.container_runtime)
                            )
                            return own_stack, r, w, ci
                        except BaseException:
                            await own_stack.aclose()
                            raise

                    own_stack = None
                    container_info = None
                    try:
                        own_stack, us_read, us_write, container_info = await _connect_with_isolated_stack()
                    except Exception as exc:
                        # One refresh+retry on failure (covers stale token not flagged by created_at)
                        logger.warning("Failed to connect marketplace server %s: %s", name, exc)
                        refreshed = await refresh_backend_oauth_token(name, config.store_path, backend_client)
                        if refreshed:
                            try:
                                own_stack, us_read, us_write, container_info = await _connect_with_isolated_stack()
                            except Exception as exc2:
                                logger.warning("Retry after refresh failed for %s: %s", name, exc2)
                                own_stack = None
                    # Still no connection: deactivate and drop this target.
                    if own_stack is None:
                        failed_targets.append(name)
                        try:
                            await store.deactivate_server(name)
                            logger.info("Deactivated server %s in database; re-auth via dashboard to reconnect", name)
                        except Exception as deactivate_exc:
                            logger.error("Failed to deactivate server %s: %s", name, deactivate_exc)
                        continue

                    # Register the isolated stack with the parent so it tears down at shutdown.
                    # Wrap aclose to swallow exit-time errors per target (otherwise one bad
                    # upstream's teardown takes down the whole process).
                    # Default arguments bind this iteration's name/stack (avoids late binding).
                    def _make_safe_aclose(target_name=name, s=own_stack):
                        async def _close():
                            try:
                                await s.aclose()
                            except Exception as exc:
                                logger.warning("Error closing upstream %s at shutdown: %s", target_name, exc)
                        return _close
                    stack.push_async_callback(_make_safe_aclose())
                    state.targets[name].container_info = container_info
                    upstream_streams[name] = (us_read, us_write)
            # Removed after the loop so state.targets is not mutated while iterated.
            for name in failed_targets:
                del state.targets[name]

            async with anyio.create_task_group() as tg:
                manager = TargetManager(state, store, config.store_path,
                                        callback_server=callback_server,
                                        container_runtime=config.container_runtime)
                manager.set_task_group(tg, shutdown_event)
                state.target_manager = manager

                # Start OAuth state cleanup task if backend is enabled
                if backend_client.enabled:
                    from openmaskit.web.routes.oauth_callback import cleanup_expired_oauth_states
                    tg.start_soon(cleanup_expired_oauth_states, oauth_states)

                # First SIGINT/SIGTERM runs the graceful shutdown, which ends by
                # cancelling this task group.
                async def _shutdown_on_signal():
                    with anyio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as signals:
                        async for sig in signals:
                            logger.info(f"Received {sig.name}, initiating graceful shutdown")
                            await _graceful_shutdown(
                                state, shutdown_event, web_server, mcp_server,
                                callback_web_server, tg, DRAIN_TIMEOUT, FLUSH_TIMEOUT
                            )
                            break

                tg.start_soon(_shutdown_on_signal)

                # Keeps one target's proxy failure from crashing the task group.
                async def _safe_run_proxy(target_state, us_read, us_write):
                    target_name = target_state.name
                    try:
                        await run_proxy_for_target(target_state, us_read, us_write)
                    except Exception as exc:
                        logger.error(
                            "[%s] Proxy task failed (likely OAuth/upstream error); deactivating target. %s",
                            target_name, exc,
                        )
                        # Deactivate so we don't crash again on next startup
                        try:
                            await store.deactivate_server(target_name)
                        except Exception:
                            pass
                        state.targets.pop(target_name, None)

                for name, target_state in state.targets.items():
                    us_read, us_write = upstream_streams[name]
                    tg.start_soon(_safe_run_proxy, target_state, us_read, us_write)
                    tg.start_soon(_flush_loop, target_state.engine, shutdown_event)
                    # Start background eviction to prevent memory leaks
                    tg.start_soon(target_state.response_dispatcher.start_background_eviction, shutdown_event)

                tg.start_soon(_traffic_flush_loop, traffic_buffer, traffic_store, shutdown_event)
                tg.start_soon(_traffic_rotation_loop, traffic_store, traffic_max_rows, shutdown_event)

                tg.start_soon(web_server.serve)
                tg.start_soon(mcp_server.serve)
                tg.start_soon(callback_web_server.serve)

    except Exception as exc:
        logger.exception(f"Error: {type(exc).__name__}: {exc}")
    finally:
        await backend_client.close()
        await store.close()
        try:
            await traffic_store.close()
        except Exception:
            logger.exception("Failed to close traffic store")

    logger.info("OpenMaskit stopped")
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
def main():
    """Console entry point: run ``async_main`` and map failures to exit codes.

    - Ctrl-C (``KeyboardInterrupt``) exits quietly.
    - ``SystemExit`` is re-raised unchanged. An explicit exit status raised
      during startup (for example by the CLI parser on invalid arguments)
      therefore reaches the shell instead of being turned into status 0.
    - Any other exception is printed with its traceback to stderr, and the
      process exits with status 1.
    """
    try:
        anyio.run(async_main)
    except KeyboardInterrupt:
        pass
    except SystemExit:
        # Preserve the requested exit status; previously this was swallowed.
        raise
    except BaseException as exc:
        print(f"Fatal error: {exc}", file=sys.stderr)
        import traceback
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
# Run when this module is executed as a script, which includes
# `python -m openmaskit` (Python runs a package's __main__.py as "__main__").
if __name__ == "__main__":
    main()
|