openmaskit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. openmaskit/__init__.py +19 -0
  2. openmaskit/__main__.py +617 -0
  3. openmaskit/backend_client.py +181 -0
  4. openmaskit/cli.py +112 -0
  5. openmaskit/config.py +126 -0
  6. openmaskit/container.py +285 -0
  7. openmaskit/logging_config.py +74 -0
  8. openmaskit/masking/__init__.py +0 -0
  9. openmaskit/masking/engine.py +536 -0
  10. openmaskit/masking/mappers.py +20 -0
  11. openmaskit/masking/parsing.py +51 -0
  12. openmaskit/masking/rules.py +103 -0
  13. openmaskit/masking/store.py +619 -0
  14. openmaskit/models.py +66 -0
  15. openmaskit/oauth/__init__.py +0 -0
  16. openmaskit/oauth/handler.py +431 -0
  17. openmaskit/proxy/__init__.py +0 -0
  18. openmaskit/proxy/core.py +574 -0
  19. openmaskit/proxy/http_downstream.py +159 -0
  20. openmaskit/proxy/manager.py +260 -0
  21. openmaskit/proxy/upstream.py +321 -0
  22. openmaskit/security.py +145 -0
  23. openmaskit/traffic/__init__.py +0 -0
  24. openmaskit/traffic/buffer.py +44 -0
  25. openmaskit/traffic/store.py +223 -0
  26. openmaskit/web/__init__.py +0 -0
  27. openmaskit/web/app.py +142 -0
  28. openmaskit/web/health.py +109 -0
  29. openmaskit/web/origin.py +110 -0
  30. openmaskit/web/routes/__init__.py +0 -0
  31. openmaskit/web/routes/custom_targets.py +280 -0
  32. openmaskit/web/routes/guardrails.py +160 -0
  33. openmaskit/web/routes/hidden_tools.py +40 -0
  34. openmaskit/web/routes/injections.py +142 -0
  35. openmaskit/web/routes/mappers.py +382 -0
  36. openmaskit/web/routes/marketplace.py +607 -0
  37. openmaskit/web/routes/oauth.py +162 -0
  38. openmaskit/web/routes/oauth_callback.py +191 -0
  39. openmaskit/web/routes/pages.py +158 -0
  40. openmaskit/web/routes/rules.py +124 -0
  41. openmaskit/web/routes/traffic.py +82 -0
  42. openmaskit/web/static/big.png +0 -0
  43. openmaskit/web/static/favicon.png +0 -0
  44. openmaskit/web/static/icon.png +0 -0
  45. openmaskit/web/static/marketplace.html +937 -0
  46. openmaskit/web/static/new_maskit-removebg-preview.png +0 -0
  47. openmaskit/web/static/onboarding.css +386 -0
  48. openmaskit/web/static/original_icon.png +0 -0
  49. openmaskit/web/static/shared.js +322 -0
  50. openmaskit/web/static/style.css +5036 -0
  51. openmaskit/web/static/targets.html +1174 -0
  52. openmaskit/web/static/tool_detail.html +936 -0
  53. openmaskit/web/static/tools.html +661 -0
  54. openmaskit/web/static/tutorial.css +377 -0
  55. openmaskit/web/static/tutorial.js +546 -0
  56. openmaskit/web/static/tutorials/guardrails.json +31 -0
  57. openmaskit/web/static/tutorials/hide-tool.json +16 -0
  58. openmaskit/web/static/tutorials/injections.json +31 -0
  59. openmaskit/web/static/tutorials/masking-with-result.json +31 -0
  60. openmaskit/web/static/tutorials/masking.json +16 -0
  61. openmaskit-0.1.1.dist-info/METADATA +229 -0
  62. openmaskit-0.1.1.dist-info/RECORD +66 -0
  63. openmaskit-0.1.1.dist-info/WHEEL +4 -0
  64. openmaskit-0.1.1.dist-info/entry_points.txt +2 -0
  65. openmaskit-0.1.1.dist-info/licenses/LICENSE +201 -0
  66. openmaskit-0.1.1.dist-info/licenses/NOTICE +10 -0
openmaskit/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """OpenMaskit - MCP server wrapper that masks sensitive fields."""
2
+
3
+ from importlib.metadata import version, PackageNotFoundError
4
+
5
def _version_from_pyproject() -> str:
    """Read ``[project].version`` from the source checkout's pyproject.toml.

    Only consulted when the installed distribution metadata is missing
    (i.e. running from a source tree). Any failure -- no pyproject.toml,
    no TOML parser available, malformed TOML -- yields ``"unknown"``.
    """
    try:
        from pathlib import Path

        try:
            import tomllib
        except ImportError:
            import tomli as tomllib  # type: ignore
        # Three levels up from this file; presumably the repository root of a
        # ``src/`` layout -- TODO confirm against the project tree.
        pyproject = Path(__file__).parent.parent.parent / "pyproject.toml"
        with open(pyproject, "rb") as handle:
            return tomllib.load(handle).get("project", {}).get("version", "unknown")
    except Exception:
        return "unknown"


try:
    __version__ = version("openmaskit")
except PackageNotFoundError:
    # Running from source without the package installed
    __version__ = _version_from_pyproject()
openmaskit/__main__.py ADDED
@@ -0,0 +1,617 @@
1
+ """OpenMaskit entry point."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import os
8
+ import signal
9
+ import sys
10
+ from contextlib import AsyncExitStack
11
+ from pathlib import Path
12
+ import random
13
+ import string
14
+
15
+ import anyio
16
+ import uvicorn
17
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
18
+
19
+ from mcp.shared.message import SessionMessage
20
+
21
+ from openmaskit.config import load_config
22
+ from openmaskit.masking.engine import MaskingEngine
23
+ from openmaskit.masking.rules import ArgumentGuardrail, ArgumentInjection, MaskingRule
24
+ from openmaskit.masking.store import MaskingStore
25
+ from openmaskit.oauth.handler import OAuthCallbackServer
26
+ from openmaskit.proxy.core import ProxyState, TargetState, run_proxy_for_target
27
+ from openmaskit.proxy.http_downstream import create_mcp_app
28
+ from openmaskit.proxy.manager import TargetManager, _build_upstream_config
29
+ from openmaskit.proxy.upstream import connect_upstream, is_oauth_token_expired, refresh_backend_oauth_token
30
+ from openmaskit.traffic.buffer import TrafficBuffer
31
+ from openmaskit.traffic.store import TrafficStore
32
+ from openmaskit.web.app import create_app
33
+ from openmaskit import __version__
34
+
35
logger = logging.getLogger(__name__)


def print_banner():
    """Print the OpenMaskit ASCII-art banner to stdout.

    NOTE(review): the version glyphs at the bottom are part of the string
    literal and are not derived from ``__version__``.
    """
    # https://patorjk.com/software/taag
    # DOS Rebel for openmaskit
    # Pagga for version
    banner = """
███████ ██████ ██████ █████ ███ █████
███░░░░░███ ░░██████ ██████ ░░███ ░░░ ░░███
███ ░░███ ████████ ██████ ████████ ░███░█████░███ ██████ █████ ░███ █████ ████ ███████
░███ ░███░░███░░███ ███░░███░░███░░███ ░███░░███ ░███ ░░░░░███ ███░░ ░███░░███ ░░███ ░░░███░
░███ ░███ ░███ ░███░███████ ░███ ░███ ░███ ░░░ ░███ ███████ ░░█████ ░██████░ ░███ ░███
░░███ ███ ░███ ░███░███░░░ ░███ ░███ ░███ ░███ ███░░███ ░░░░███ ░███░░███ ░███ ░███ ███
░░░███████░ ░███████ ░░██████ ████ █████ █████ █████░░████████ ██████ ████ █████ █████ ░░█████
░░░░░░░ ░███░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░ ░░░░░
░███
█████ ░▄▀▄░░░░▀█░░░░░▄▀▄
░░░░░ ░█/█░░░░░█░░░░░█/█
░░▀░░▀░░▀▀▀░▀░░░▀░
"""
    print(banner)
56
+
57
async def _flush_loop(
    engine: MaskingEngine,
    shutdown_event: anyio.Event,
    interval: float = 1.0,
):
    """Periodically flush pending alias writes to the database.

    Wakes every ``interval`` seconds, or immediately once ``shutdown_event``
    is set. On each wake it calls ``engine.flush_pending()`` when
    ``engine.has_pending_writes`` is true. After shutdown is signalled, one
    final flush is attempted.

    Exceptions raised by ``flush_pending`` are logged and do not propagate,
    so a failing flush does not tear down the surrounding task group.

    Args:
        engine: Masking engine whose pending alias writes are flushed.
        shutdown_event: Set by the shutdown sequence to stop the loop.
        interval: Seconds between periodic flush attempts.
    """
    while not shutdown_event.is_set():
        # Wait on the event instead of sleeping blindly. The shutdown sequence
        # only allows a bounded window before it cancels the task group, so
        # the final flush below must start as soon as the event is set.
        with anyio.move_on_after(interval):
            await shutdown_event.wait()
        if shutdown_event.is_set():
            break
        if engine.has_pending_writes:
            try:
                await engine.flush_pending()
            except Exception:
                logger.exception("Failed to flush aliases to database")
    # Final flush on shutdown
    if engine.has_pending_writes:
        try:
            await engine.flush_pending()
        except Exception:
            logger.exception("Failed final flush of aliases to database")
72
+
73
+
74
async def _traffic_flush_loop(
    buffer: TrafficBuffer,
    store: TrafficStore,
    shutdown_event: anyio.Event,
    interval: float = 1.0,
):
    """Periodically drain the traffic buffer into the traffic DB.

    Wakes every ``interval`` seconds, or immediately once ``shutdown_event``
    is set. On each wake it calls ``buffer.flush(store)`` when
    ``buffer.has_pending`` is true. After shutdown is signalled, one final
    drain is attempted.

    Exceptions from ``buffer.flush`` are logged rather than propagated.
    Without this, one database error would escape into the task group and
    cancel every other task, including the web and MCP servers.

    Args:
        buffer: In-memory buffer of captured traffic entries.
        store: Traffic database the buffer is drained into.
        shutdown_event: Set by the shutdown sequence to stop the loop.
        interval: Seconds between periodic drains.
    """
    while not shutdown_event.is_set():
        # Interruptible wait so the final drain starts promptly on shutdown.
        with anyio.move_on_after(interval):
            await shutdown_event.wait()
        if shutdown_event.is_set():
            break
        if buffer.has_pending:
            try:
                await buffer.flush(store)
            except Exception:
                logger.exception("Failed to flush traffic entries to database")
    # Final drain on shutdown
    if buffer.has_pending:
        try:
            await buffer.flush(store)
        except Exception:
            logger.exception("Failed final flush of traffic entries to database")
88
+
89
async def _traffic_rotation_loop(
    store: TrafficStore,
    max_rows: int,
    shutdown_event: anyio.Event,
):
    """Trim the traffic table to ``max_rows`` about every five minutes.

    The five-minute wait is interruptible: once ``shutdown_event`` is set
    the coroutine returns without a final trim. Failures raised by
    ``store.enforce_row_cap`` are logged and the loop keeps going.
    """
    period_seconds = 300
    while True:
        if shutdown_event.is_set():
            return
        # Sleep for the period, but wake early if shutdown is requested.
        with anyio.move_on_after(period_seconds):
            await shutdown_event.wait()
        if shutdown_event.is_set():
            return
        try:
            await store.enforce_row_cap(max_rows)
        except Exception:
            logger.exception("Failed to enforce traffic row cap")
104
+
105
+
106
+ def _install_shutdown_noise_filter(shutdown_event: anyio.Event) -> None:
107
+ """Quiet anyio/MCP-SDK async-generator teardown noise during shutdown.
108
+
109
+ The MCP SDK's stdio_client holds anyio task-group cancel scopes whose
110
+ teardown happens in Python's asyncgen finalizer task — not the task that
111
+ entered them. The resulting 'cancel scope in a different task' and
112
+ 'aclose(): asynchronous generator is already running' errors are harmless
113
+ but loud. Swallow them once shutdown has been signalled.
114
+ """
115
+ loop = asyncio.get_running_loop()
116
+
117
+ def _is_known_shutdown_noise(exc: BaseException | None, message: str) -> bool:
118
+ if "asynchronous generator" in message:
119
+ return True
120
+ if exc is None:
121
+ return False
122
+ if isinstance(exc, GeneratorExit):
123
+ return True
124
+ s = str(exc).lower()
125
+ if "cancel scope" in s or "asynchronous generator" in s:
126
+ return True
127
+ sub = getattr(exc, "exceptions", None)
128
+ if sub:
129
+ return bool(sub) and all(_is_known_shutdown_noise(e, "") for e in sub)
130
+ return False
131
+
132
+ def handler(loop, context):
133
+ if shutdown_event.is_set():
134
+ if _is_known_shutdown_noise(context.get("exception"), context.get("message", "")):
135
+ return
136
+ loop.default_exception_handler(context)
137
+
138
+ loop.set_exception_handler(handler)
139
+
140
+
141
async def _graceful_shutdown(
    state: ProxyState,
    shutdown_event: anyio.Event,
    web_server,
    mcp_server,
    callback_web_server,
    tg,
    drain_timeout: float,
    flush_timeout: float,
) -> None:
    """Coordinate graceful shutdown with timeout enforcement.

    Shutdown sequence:
    1. Stop accepting new requests (servers marked for exit).
    2. Drain in-flight requests: ResponseDispatcher notifies waiters and the
       downstream streams are closed. Bounded by ``drain_timeout``.
    3. Set ``shutdown_event``, drain the traffic buffer, and wait until no
       engine reports pending alias writes. Bounded by ``flush_timeout``.
    4. Cancel the task group (relay loops, servers, any remaining tasks).

    Upstream connections are not closed here. ``async_main`` tears them down
    through its ``AsyncExitStack`` after the task group exits.

    Stage 4 runs in a ``finally`` block, so the task group is cancelled even
    if an earlier stage raises.

    Args:
        state: Proxy state holding every active ``TargetState``.
        shutdown_event: Event observed by the flush/eviction loops.
        web_server: Dashboard uvicorn server.
        mcp_server: MCP uvicorn server.
        callback_web_server: OAuth-callback uvicorn server.
        tg: Task group running the proxy tasks; cancelled in stage 4.
        drain_timeout: Upper bound in seconds for stage 2.
        flush_timeout: Upper bound in seconds for stage 3.
    """
    logger.info("Starting graceful shutdown sequence")
    # Snapshot the targets. Proxy tasks remove failed targets from
    # state.targets, and this coroutine suspends between stages.
    targets = list(state.targets.values())

    try:
        # Stage 1: Stop accepting new requests
        logger.info("Stage 1/4: Stopping request acceptance")
        web_server.should_exit = True
        mcp_server.should_exit = True
        callback_web_server.should_exit = True

        # Stage 2: Drain in-flight requests
        logger.info(f"Stage 2/4: Draining in-flight requests ({drain_timeout}s timeout)")
        with anyio.move_on_after(drain_timeout):
            for target_state in targets:
                # Notify all waiting HTTP clients to abort
                target_state.response_dispatcher.shutdown()

                # Close downstream streams (prevents new messages from clients)
                if target_state.ds_read_send:
                    await target_state.ds_read_send.aclose()

        # Stage 3: Wait for flush loops to complete
        logger.info(f"Stage 3/4: Waiting for database flushes ({flush_timeout}s timeout)")
        # Signal shutdown to flush loops
        shutdown_event.set()
        with anyio.move_on_after(flush_timeout):
            # Give flush loops time to start their final flush
            await anyio.sleep(0.5)

            # Drain any remaining traffic entries
            if state.traffic_buffer is not None and state.traffic_store is not None:
                if state.traffic_buffer.has_pending:
                    try:
                        await state.traffic_buffer.flush(state.traffic_store)
                    except Exception:
                        logger.exception("Failed to drain traffic buffer during shutdown")

            # Poll instead of relying on a fixed sleep. Otherwise the final
            # alias flush can be cut off by the cancellation in stage 4.
            while any(ts.engine.has_pending_writes for ts in targets):
                await anyio.sleep(0.1)

        # Report leftovers (also reached when the wait above timed out)
        pending_count = sum(ts.engine.has_pending_writes for ts in targets)
        if pending_count > 0:
            logger.warning(f"{pending_count} targets still have pending writes")
    finally:
        # Stage 4: Cancel remaining tasks (relay loops, servers)
        logger.info("Stage 4/4: Cancelling remaining tasks")
        tg.cancel_scope.cancel()
203
+
204
+ def _generate_installation_id() -> str:
205
+ length = 25
206
+ random_string = ''.join(
207
+ random.choices(string.ascii_letters + string.digits, k=length)
208
+ )
209
+ return random_string
210
+
211
def _load_installation_id() -> str:
    """Return this installation's persistent identifier, creating it if needed.

    The identifier lives in ``~/.openmaskit/.installation_id``.

    - An existing, non-empty file is reused.
    - A missing file, or one that is empty or whitespace-only (for example
      left by an interrupted first run), gets a freshly generated ID written
      to it.
    - A newly created file gets mode 0o600 at creation time, so its contents
      are never readable by others, even briefly.

    Returns:
        The installation ID as a string.
    """
    id_path = Path("~/.openmaskit/.installation_id").expanduser()
    if id_path.exists():
        existing = id_path.read_bytes().strip().decode("utf-8")
        if existing:
            return existing

    key = _generate_installation_id()
    id_path.parent.mkdir(parents=True, exist_ok=True)
    fd = os.open(id_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "wb") as fh:
        fh.write(key.encode("utf-8"))
    # os.open's mode applies only on creation; also tighten a pre-existing file.
    id_path.chmod(0o600)
    return key
221
+
222
async def async_main():
    """Start OpenMaskit and run until a shutdown signal is handled.

    Steps, in order:

    1. Configure logging, parse CLI arguments and load the config.
    2. Open the masking store and the traffic store.
    3. Build a ``TargetState`` for every configured target and for every
       active marketplace server recorded in the store. Each holds a masking
       engine, the hidden tools and a downstream memory stream.
    4. Build the dashboard, MCP and OAuth-callback uvicorn servers.
    5. Connect every upstream. A marketplace server is deactivated in the
       store and dropped from this run when either:
       - its expired OAuth token cannot be refreshed, or
       - its connection still fails after one refresh-and-retry.
    6. Run the proxies, flush loops, eviction tasks and servers in one task
       group until SIGINT/SIGTERM triggers ``_graceful_shutdown``.

    On the way out, the devnull handle used to mute upstream stderr, the
    backend client and both stores are closed.
    """
    from openmaskit.logging_config import setup_logging
    from openmaskit.cli import parse_args

    setup_logging()

    args = parse_args()
    config = load_config(
        path=args.config_path,
        web_port=args.web_port,
        mcp_port=args.mcp_port,
        oauth_port=args.oauth_port,
        store_path=args.store_path,
    )
    bind_host = os.environ.get("OPENMASKIT_HOST", "127.0.0.1")

    # Container runtime detection
    from openmaskit.container import get_container_runtime
    runtime = get_container_runtime(config.container_runtime)
    if runtime:
        if config.container_runtime:
            logger.info(f"Container runtime: {runtime} (configured)")
        else:
            logger.info(f"Container runtime: {runtime} (auto-detected)")
    else:
        logger.warning("No container runtime detected. Containerized MCP servers will not work.")

    # Shutdown configuration
    # NOTE(review): SHUTDOWN_TIMEOUT is parsed (so a malformed value still
    # fails at startup) but is not used anywhere in this function.
    SHUTDOWN_TIMEOUT = float(os.environ.get("OPENMASKIT_SHUTDOWN_TIMEOUT", "30"))
    DRAIN_TIMEOUT = 5.0  # Time to wait for in-flight requests
    FLUSH_TIMEOUT = 3.0  # Time to wait for database flushes

    store = await MaskingStore.create(config.store_path)

    traffic_db_path = os.environ.get(
        "OPENMASKIT_TRAFFIC_DB_PATH",
        str(Path("~/.openmaskit/traffic.db").expanduser()),
    )
    traffic_store = await TrafficStore.create(traffic_db_path)
    traffic_buffer = TrafficBuffer()
    traffic_max_rows = int(os.environ.get("OPENMASKIT_TRAFFIC_MAX_ROWS", "10000"))

    state = ProxyState()
    state.store = store
    state.traffic_store = traffic_store
    state.traffic_buffer = traffic_buffer
    state.mcp_port = config.mcp_port
    state.oauth_port = config.oauth_port

    # Create per-target state
    for name, target_config in config.targets.items():
        # Config-file rules first, then rules persisted in the store.
        rules = [
            MaskingRule(
                tool_name=r.tool_name,
                field_path=r.field_path,
                alias_prefix=r.alias_prefix,
                action=r.action,
            )
            for r in target_config.rules
        ]
        db_rules = await store.get_rules(target_name=name)
        rules.extend(db_rules)

        engine = MaskingEngine(rules, store, target_name=name)
        await engine.load_aliases()
        await engine.load_mappers()
        await engine.load_guardrails()
        await engine.load_injections()

        # Guardrails/injections from the config are added on top of those
        # loaded from the store above.
        for g in target_config.guardrails:
            engine.add_guardrail(ArgumentGuardrail(
                tool_name=g.tool_name, argument_name=g.argument_name,
                match_type=g.match_type, pattern=g.pattern, message=g.message,
            ))
        for i in target_config.injections:
            engine.add_injection(ArgumentInjection(
                tool_name=i.tool_name, argument_name=i.argument_name,
                value=i.value, mode=i.mode,
            ))

        hidden = await store.get_hidden_tools(target_name=name)

        ds_read_send, ds_read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)

        target_state = TargetState(
            name=name,
            engine=engine,
            hidden_tools=set(hidden),
            ds_read_send=ds_read_send,
            ds_read_recv=ds_read_recv,
            traffic_buffer=traffic_buffer,
        )
        state.targets[name] = target_state

    state.config_target_ids = set(config.targets.keys())

    # Load active marketplace servers from DB
    marketplace_configs: dict[str, dict] = {}
    installed = await store.get_installed_servers(active_only=True)
    for record in installed:
        server_id = record["id"]
        # A config-file target with the same id takes precedence.
        if server_id in state.targets:
            continue

        engine = MaskingEngine([], store, target_name=server_id)
        await engine.load_aliases()
        await engine.load_mappers()
        await engine.load_guardrails()
        await engine.load_injections()
        hidden = await store.get_hidden_tools(target_name=server_id)
        ds_read_send, ds_read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)

        target_state = TargetState(
            name=server_id,
            engine=engine,
            hidden_tools=set(hidden),
            ds_read_send=ds_read_send,
            ds_read_recv=ds_read_recv,
            server_id=server_id,  # Set server_id for OAuth refresh
            traffic_buffer=traffic_buffer,
        )
        state.targets[server_id] = target_state
        marketplace_configs[server_id] = record["config"]

    shutdown_event = anyio.Event()
    _install_shutdown_noise_filter(shutdown_event)

    # Shared OAuth callback server (always running)
    callback_server = OAuthCallbackServer(port=config.oauth_port)
    callback_app = callback_server.create_app()
    callback_uvicorn_config = uvicorn.Config(
        callback_app,
        host=bind_host,
        port=config.oauth_port,
        log_level="warning",
        log_config=None,
    )
    callback_web_server = uvicorn.Server(callback_uvicorn_config)
    # Signals are handled by _shutdown_on_signal, not by uvicorn.
    callback_web_server.install_signal_handlers = lambda: None
    state.callback_server = callback_server

    installation_id = _load_installation_id()
    openmaskit_version = __version__
    print('----------------------------------------------------------------------------------------------------------')
    print_banner()
    print('----------------------------------------------------------------------------------------------------------')

    # Initialize backend client for marketplace and auth integration
    from openmaskit.backend_client import BackendClient

    backend_client = BackendClient(installation_id=installation_id, openmaskit_version=openmaskit_version)
    oauth_states: dict[str, dict] = {}  # {csrf_state: {server_id, handle, timestamp}}

    # Store backend_client in state for token refresh
    state.backend_client = backend_client

    state.version_status = await backend_client.check_version()
    if state.version_status:
        if not state.version_status.get("supported", True):
            logger.warning(
                "OpenMaskit %s is no longer supported. Latest: %s",
                openmaskit_version, state.version_status.get("latest_version"),
            )
        elif state.version_status.get("update_available"):
            logger.info(
                "OpenMaskit update available: %s (you have %s)",
                state.version_status.get("latest_version"), openmaskit_version,
            )

    logger.info("OpenMaskit proxy starting")
    logger.info(f"Dashboard: http://{bind_host}:{config.web_port}")
    logger.info(f"OAuth callback: http://{bind_host}:{config.oauth_port}/callback")
    if backend_client.enabled:
        logger.info("Backend integration enabled.")
    logger.info("MCP servers:")
    for name in state.target_names:
        logger.info(f"  {name}: http://{bind_host}:{config.mcp_port}/{name}/mcp")

    from openmaskit.web.origin import default_localhost_origins
    allowed_origins = default_localhost_origins(config.web_port)
    extra_origins_env = os.environ.get("OPENMASKIT_ALLOWED_ORIGINS", "").strip()
    if extra_origins_env:
        allowed_origins.extend(
            o.strip() for o in extra_origins_env.split(",") if o.strip()
        )
    logger.debug(f"Allowed dashboard origins: {allowed_origins}")

    # Route upstream MCP child stderr based on log level. Upstream servers
    # emit their own protocol chatter ("Processing request of type …") that
    # we cannot demote from inside our process; we can only mute the stream.
    if logger.isEnabledFor(logging.DEBUG):
        upstream_errlog = sys.stderr
    else:
        # Closed in the ``finally`` below once all upstreams are torn down.
        upstream_errlog = open(os.devnull, "w")
        logger.info("Upstream stderr is suppressed. Set OPENMASKIT_LOG_LEVEL=DEBUG to see it.")

    web_app = create_app(state, allowed_origins=allowed_origins)
    web_app.state.backend_client = backend_client
    web_app.state.oauth_states = oauth_states
    mcp_app = create_mcp_app(state, allowed_origins=allowed_origins)

    uvicorn_config = uvicorn.Config(
        web_app,
        host=bind_host,
        port=config.web_port,
        log_level="warning",
        log_config=None,
    )
    web_server = uvicorn.Server(uvicorn_config)
    web_server.install_signal_handlers = lambda: None

    mcp_uvicorn_config = uvicorn.Config(
        mcp_app,
        host=bind_host,
        port=config.mcp_port,
        log_level="warning",
        log_config=None,
    )
    mcp_server = uvicorn.Server(mcp_uvicorn_config)
    mcp_server.install_signal_handlers = lambda: None

    try:
        async with AsyncExitStack() as stack:
            # Connect all upstream targets
            upstream_streams: dict[str, tuple[
                MemoryObjectReceiveStream[SessionMessage | Exception],
                MemoryObjectSendStream[SessionMessage],
            ]] = {}

            failed_targets = []
            for name, target_state in state.targets.items():
                if name in config.targets:
                    # Config-file targets: a connection failure propagates to
                    # the outer ``except`` and stops startup.
                    target_config = config.targets[name]
                    us_read, us_write, container_info = await stack.enter_async_context(
                        connect_upstream(target_config.upstream, config.store_path,
                                         errlog=upstream_errlog, server_id=name,
                                         callback_server=callback_server,
                                         container_runtime=config.container_runtime)
                    )
                    target_state.container_info = container_info
                    upstream_streams[name] = (us_read, us_write)
                elif name in marketplace_configs:
                    upstream_cfg = _build_upstream_config(marketplace_configs[name])

                    # Pre-flight refresh if token is known-expired
                    if is_oauth_token_expired(name, config.store_path):
                        logger.info("OAuth token for %s is expired; attempting refresh before connect", name)
                        refreshed = await refresh_backend_oauth_token(name, config.store_path, backend_client)
                        if not refreshed:
                            logger.warning(
                                "Could not refresh OAuth token for %s; deactivating. User must re-authenticate via the dashboard.",
                                name,
                            )
                            failed_targets.append(name)
                            try:
                                await store.deactivate_server(name)
                            except Exception as deactivate_exc:
                                logger.error("Failed to deactivate server %s: %s", name, deactivate_exc)
                            continue

                    # Each marketplace upstream gets its own exit stack so a
                    # failed connect can be unwound without touching ``stack``.
                    async def _connect_with_isolated_stack():
                        own_stack = AsyncExitStack()
                        await own_stack.__aenter__()
                        try:
                            r, w, ci = await own_stack.enter_async_context(
                                connect_upstream(upstream_cfg, config.store_path,
                                                 errlog=upstream_errlog, server_id=name,
                                                 callback_server=callback_server,
                                                 container_runtime=config.container_runtime)
                            )
                            return own_stack, r, w, ci
                        except BaseException:
                            await own_stack.aclose()
                            raise

                    own_stack = None
                    container_info = None
                    try:
                        own_stack, us_read, us_write, container_info = await _connect_with_isolated_stack()
                    except Exception as exc:
                        # One refresh+retry on failure (covers stale token not flagged by created_at)
                        logger.warning("Failed to connect marketplace server %s: %s", name, exc)
                        refreshed = await refresh_backend_oauth_token(name, config.store_path, backend_client)
                        if refreshed:
                            try:
                                own_stack, us_read, us_write, container_info = await _connect_with_isolated_stack()
                            except Exception as exc2:
                                logger.warning("Retry after refresh failed for %s: %s", name, exc2)
                                own_stack = None
                    if own_stack is None:
                        failed_targets.append(name)
                        try:
                            await store.deactivate_server(name)
                            logger.info("Deactivated server %s in database; re-auth via dashboard to reconnect", name)
                        except Exception as deactivate_exc:
                            logger.error("Failed to deactivate server %s: %s", name, deactivate_exc)
                        continue

                    # Register the isolated stack with the parent so it tears down at shutdown.
                    # Wrap aclose to swallow exit-time errors per target (otherwise one bad
                    # upstream's teardown takes down the whole process).
                    def _make_safe_aclose(target_name=name, s=own_stack):
                        async def _close():
                            try:
                                await s.aclose()
                            except Exception as exc:
                                logger.warning("Error closing upstream %s at shutdown: %s", target_name, exc)
                        return _close
                    stack.push_async_callback(_make_safe_aclose())
                    state.targets[name].container_info = container_info
                    upstream_streams[name] = (us_read, us_write)
            # Removed after the loop: the dict must not change while iterated.
            for name in failed_targets:
                del state.targets[name]

            async with anyio.create_task_group() as tg:
                manager = TargetManager(state, store, config.store_path,
                                        callback_server=callback_server,
                                        container_runtime=config.container_runtime)
                manager.set_task_group(tg, shutdown_event)
                state.target_manager = manager

                # Start OAuth state cleanup task if backend is enabled
                if backend_client.enabled:
                    from openmaskit.web.routes.oauth_callback import cleanup_expired_oauth_states
                    tg.start_soon(cleanup_expired_oauth_states, oauth_states)

                async def _shutdown_on_signal():
                    # Handle only the first signal; the receiver is released
                    # once the loop breaks.
                    with anyio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as signals:
                        async for sig in signals:
                            logger.info(f"Received {sig.name}, initiating graceful shutdown")
                            await _graceful_shutdown(
                                state, shutdown_event, web_server, mcp_server,
                                callback_web_server, tg, DRAIN_TIMEOUT, FLUSH_TIMEOUT
                            )
                            break

                tg.start_soon(_shutdown_on_signal)

                async def _safe_run_proxy(target_state, us_read, us_write):
                    target_name = target_state.name
                    try:
                        await run_proxy_for_target(target_state, us_read, us_write)
                    except Exception as exc:
                        logger.error(
                            "[%s] Proxy task failed (likely OAuth/upstream error); deactivating target. %s",
                            target_name, exc,
                        )
                        # Deactivate so we don't crash again on next startup
                        try:
                            await store.deactivate_server(target_name)
                        except Exception as deactivate_exc:
                            # Best effort: log instead of silently ignoring.
                            logger.error("Failed to deactivate server %s: %s", target_name, deactivate_exc)
                        state.targets.pop(target_name, None)

                # start_soon does not yield, so state.targets cannot change
                # during this loop.
                for name, target_state in state.targets.items():
                    us_read, us_write = upstream_streams[name]
                    tg.start_soon(_safe_run_proxy, target_state, us_read, us_write)
                    tg.start_soon(_flush_loop, target_state.engine, shutdown_event)
                    # Start background eviction to prevent memory leaks
                    tg.start_soon(target_state.response_dispatcher.start_background_eviction, shutdown_event)

                tg.start_soon(_traffic_flush_loop, traffic_buffer, traffic_store, shutdown_event)
                tg.start_soon(_traffic_rotation_loop, traffic_store, traffic_max_rows, shutdown_event)

                tg.start_soon(web_server.serve)
                tg.start_soon(mcp_server.serve)
                tg.start_soon(callback_web_server.serve)

    except Exception as exc:
        logger.exception(f"Error: {type(exc).__name__}: {exc}")
    finally:
        # Only close the handle opened above; never close the real stderr.
        if upstream_errlog is not sys.stderr:
            upstream_errlog.close()
        await backend_client.close()
        await store.close()
        try:
            await traffic_store.close()
        except Exception:
            logger.exception("Failed to close traffic store")

    logger.info("OpenMaskit stopped")
602
+
603
+
604
def main():
    """Console entry point: run :func:`async_main` under anyio.

    Outcomes:

    - ``KeyboardInterrupt`` exits quietly.
    - ``SystemExit`` is re-raised so its status code reaches the shell. This
      covers a CLI parser rejecting arguments (presumably argparse, which
      exits with 2) and any other explicit ``sys.exit(n)``; swallowing it
      would make such failures look like success (status 0).
    - Any other escaping exception is printed with its traceback on stderr,
      and the process exits with status 1.
    """
    try:
        anyio.run(async_main)
    except KeyboardInterrupt:
        pass
    except SystemExit:
        # Preserve the requested exit status.
        raise
    except BaseException as exc:
        print(f"Fatal error: {exc}", file=sys.stderr)
        import traceback

        traceback.print_exc(file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()