mcpsnare 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcpsnare/__init__.py ADDED
@@ -0,0 +1 @@
1
# Version string of the mcpsnare package.
__version__ = "0.3.0"
mcpsnare/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from mcpsnare.cli import main
2
+
3
if __name__ == "__main__":
    # `python -m mcpsnare ...` delegates straight to the CLI entry point.
    main()
@@ -0,0 +1,2 @@
1
+ from mcpsnare.checks import (path_traversal, info_leak, cmd_injection, ssrf, auth_bypass, # noqa: F401
2
+ sql_injection) # noqa: F401
@@ -0,0 +1,48 @@
1
+ import re
2
+ from mcpsnare.models import Probe, Finding, Severity, Confidence
3
+ from mcpsnare.checks.base import register
4
+
5
# Substrings that legitimately differ between two otherwise-identical responses.
# They are removed before the tolerant comparison, so a bypass is still detected
# when the bodies differ only by a timestamp / request-id / nonce. A bare record
# "id" is deliberately NOT in this list: a different record id is a real data
# difference, and stripping it could manufacture a false bypass.
_VOLATILE = re.compile(
    "|".join((
        # ISO-8601 timestamps, optional fractional seconds and trailing Z.
        r"\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?Z?",
        # Canonical 8-4-4-4-12 hex UUIDs.
        r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}",
        # JSON members named ts/timestamp/nonce/request-id/trace-id, with their value.
        r"\"(?:ts|timestamp|nonce|request[_-]?id|trace[_-]?id)\"\s*:\s*\"?[^\",}]*\"?",
    )),
    re.IGNORECASE,
)
15
+
16
+
17
def _normalize(s: str) -> str:
    """Return *s* with every ``_VOLATILE`` match removed and outer whitespace trimmed.

    A falsy input (``None`` or ``""``) is treated as the empty string.
    """
    text = s if s else ""
    cleaned = _VOLATILE.sub("", text)
    return cleaned.strip()
19
+
20
+
21
@register
class AuthBypass:
    """Missing-authentication check (CWE-306) for tools on the HTTP transport.

    ``generate`` emits one probe that replays the tool's base arguments and sets
    ``meta["needs_unauth"]``. Presumably the engine then also issues the call
    without the auth header and stores that reply in
    ``probe.meta["unauth_response"]`` (TODO confirm against the engine).
    ``evaluate`` compares that unauthenticated reply with the authenticated one.
    """

    id = "auth_bypass"

    def generate(self, point, ctx):
        # Only meaningful over HTTP, and only when an unauthenticated caller exists.
        usable = ctx.transport == "http" and ctx.call_tool_unauth is not None
        if not usable:
            return []
        probe = Probe(check=self.id, point=point, payload="<no auth header>",
                      args=dict(point.base_args), meta={"needs_unauth": True})
        return [probe]

    def evaluate(self, probe, response, ctx):
        unauth = probe.meta.get("unauth_response")
        if not unauth:
            return None  # nothing to compare against
        if unauth == response:
            # Byte-identical replies: a directly observed bypass.
            return self._finding(probe, Confidence.CONFIRMED,
                                 "tool callable without auth header (identical response)")
        norm_unauth = _normalize(unauth)
        norm_auth = _normalize(response)
        if norm_unauth and norm_unauth == norm_auth:
            # Equal only once volatile fields are stripped: strong but inferred.
            return self._finding(probe, Confidence.FIRM,
                                 "tool callable without auth header (responses match modulo volatile fields)")
        return None

    def _finding(self, probe, conf, evidence):
        # Tool-level finding: no single parameter is at fault, hence param="-".
        return Finding(check=self.id, tool=probe.point.tool, param="-",
                       severity=Severity.HIGH, confidence=conf, cwe="CWE-306",
                       title=f"Missing authentication on {probe.point.tool}",
                       payload=probe.payload, evidence=evidence,
                       remediation="Enforce auth on the HTTP transport for all sensitive tools.")
@@ -0,0 +1,27 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Protocol
3
+ from mcpsnare.models import InjectionPoint, Probe, Finding, ToolBaseline
4
+
5
+
6
@dataclass
class CheckContext:
    """State shared with every check's ``generate`` / ``evaluate`` calls.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Out-of-band callback provider; checks treat None as "OOB unavailable".
    oob: object | None
    # Transport name; checks compare it against strings such as "http".
    transport: str
    # Tool-call callables. NOTE(review): annotated as sync functions returning
    # str; confirm whether coroutine functions are passed here instead.
    call_tool: Callable[[str, dict], str] | None = None
    # Caller without the auth header; auth_bypass only runs when this is set.
    call_tool_unauth: Callable[[str, dict], str] | None = None
    # Benign-input baseline (response text / latency) used to suppress noise.
    baseline: ToolBaseline | None = None
    # When True, checks also emit blocking time-based probes.
    aggressive: bool = False
14
+
15
+
16
class Check(Protocol):
    """Structural interface that every registered check implements."""

    # Registry key for the check, e.g. "sql_injection".
    id: str

    def generate(self, point: InjectionPoint, ctx: "CheckContext") -> list[Probe]:
        """Return the probes to send for the injection point *point*."""

    def evaluate(self, probe: Probe, response: str, ctx: "CheckContext") -> Finding | None:
        """Judge *response* to *probe*; return a Finding, or None when clean."""
20
+
21
+
22
# Check id -> the single instance created for it by @register.
REGISTRY: dict[str, "Check"] = {}


def register(cls):
    """Class decorator: instantiate *cls* once and store it under ``cls.id``.

    The class itself is returned unchanged. Registering a second class with the
    same id replaces the earlier instance.
    """
    instance = cls()
    REGISTRY[cls.id] = instance
    return cls
@@ -0,0 +1,73 @@
1
+ from mcpsnare.models import Probe, Finding, Severity, Confidence
2
+ from mcpsnare.checks.base import register
3
+
4
# Delay, in seconds, requested by the time-based payloads.
_SLEEP_SECONDS = 5
# With a latency baseline, the delay threshold is at least baseline * this factor.
_LATENCY_MULT = 3

# Out-of-band confirmation payloads: each one makes the target fetch {url}.
# Non-blocking, so they are sent whenever an OOB provider exists. They cover
# POSIX sh, Windows cmd.exe and PowerShell without duplicates.
_OOB_TEMPLATES = (
    "; curl {url}",      # POSIX command separator
    "$(curl {url})",     # POSIX command substitution
    "& curl {url}",      # cmd.exe separator / POSIX background job
    "| curl {url}",      # cmd.exe / POSIX pipe
    "; iwr {url}",       # PowerShell Invoke-WebRequest
    "; curl.exe {url}",  # PowerShell: the real curl binary, not the iwr alias
)

# Blocking payloads (~_SLEEP_SECONDS of delay), only sent in aggressive mode.
# {n} is the delay; {n1} is n + 1 for ping, which waits about 1s between echoes.
_SLEEP_TEMPLATES = (
    "; sleep {n}",               # POSIX
    "$(sleep {n})",              # POSIX command substitution
    "& ping -n {n1} 127.0.0.1",  # cmd.exe has no sleep builtin
    "; Start-Sleep -s {n}",      # PowerShell
)
25
+
26
@register
class CmdInjection:
    """OS command injection check (CWE-78).

    Confirmation comes from an out-of-band callback when an OOB provider is
    configured. In aggressive mode, blocking sleep payloads are also sent and
    judged by response latency.
    """

    id = "cmd_injection"

    def generate(self, point, ctx):
        probes = []
        oob = ctx.oob
        if oob is not None:
            # Function-scoped import, presumably to avoid an import cycle; it is
            # resolved once per call rather than once per variant.
            from mcpsnare.inject.jsonpath import deep_get

            for tpl in _OOB_TEMPLATES:
                token, url = oob.new_token()
                cmd = tpl.format(url=url)  # separator + command, e.g. "; curl <url>"
                whole = point.set(f"mcpsnare{cmd}")
                embedded = point.embed(cmd)  # valid-value prefix + cmd
                # Also send the embedded form, unless it yields the same arguments.
                candidates = [whole] if embedded == whole else [whole, embedded]
                for args in candidates:
                    sent = deep_get(args, point.json_path)
                    probes.append(Probe(check=self.id, point=point, payload=str(sent),
                                        args=args, token=token))
        if getattr(ctx, "aggressive", False):
            fmt = {"n": _SLEEP_SECONDS, "n1": _SLEEP_SECONDS + 1}
            for tpl in _SLEEP_TEMPLATES:
                pl = "mcpsnare" + tpl.format(**fmt)
                probes.append(Probe(check=self.id, point=point, payload=pl, args=point.set(pl),
                                    meta={"time_based": True, "threshold": _SLEEP_SECONDS}))
        return probes

    def evaluate(self, probe, response, ctx):
        called_back = probe.token and ctx.oob and ctx.oob.interactions(probe.token)
        if called_back:
            return self._finding(probe, Confidence.CONFIRMED,
                                 f"OOB callback received for payload {probe.payload!r}")
        if not probe.meta.get("time_based"):
            return None
        elapsed = probe.meta.get("elapsed", 0)
        sleep_s = probe.meta["threshold"]
        baseline = getattr(ctx, "baseline", None)
        if baseline is None:
            # No calibration: fall back to the fixed sleep threshold.
            margin = sleep_s
            evidence = f"response delayed {elapsed:.1f}s"
        else:
            margin = max(baseline.latency + sleep_s * 0.8, baseline.latency * _LATENCY_MULT)
            evidence = f"response delayed {elapsed:.1f}s vs baseline {baseline.latency:.1f}s"
        if elapsed >= margin:
            return self._finding(probe, Confidence.FIRM, evidence)
        return None

    def _finding(self, probe, conf, evidence):
        return Finding(check=self.id, tool=probe.point.tool, param=probe.point.param_name,
                       severity=Severity.CRITICAL, confidence=conf, cwe="CWE-78",
                       title=f"Command injection in {probe.point.tool}.{probe.point.param_name}",
                       payload=probe.payload, evidence=evidence,
                       remediation="Never pass tool input to a shell; use exec with arg arrays / allowlists.")
@@ -0,0 +1,43 @@
1
+ import re
2
+ from mcpsnare.models import Probe, Finding, Severity, Confidence
3
+ from mcpsnare.checks.base import register
4
+
5
# Secret-shaped signatures, tried in this order:
#   PEM private-key header, AKIA-prefixed access key ids, api_key= / secret=
#   assignments, xox?- Slack-style tokens, and JWT-like "eyJ....xxx." strings.
_MARKERS = list(map(re.compile, (
    r"-----BEGIN [A-Z ]*PRIVATE KEY-----",
    r"AKIA[0-9A-Z]{16}",
    r"(?i)api[_-]?key\s*[=:]\s*\S+",
    r"(?i)secret\s*[=:]\s*\S+",
    r"xox[baprs]-[0-9A-Za-z-]+",
    r"eyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}\.",
)))
9
+
10
+
11
@register
class InfoLeak:
    """Secret / information disclosure check (CWE-200) based on pattern matching."""

    id = "info_leak"

    def generate(self, point, ctx):
        payload = "mcpsnare-probe"
        return [Probe(check=self.id, point=point, payload=payload, args=point.set(payload))]

    def evaluate(self, probe, response, ctx):
        text = response or ""
        hits = [rx.pattern for rx in _MARKERS if rx.search(text)]
        if not hits:
            return None
        baseline = getattr(ctx, "baseline", None)
        if baseline is None:
            # Without a baseline, a single match is too weak; require two patterns.
            if len(hits) < 2:
                return None
            return self._finding(probe, Confidence.TENTATIVE,
                                 f"secret-shaped pattern match (no baseline to diff): {hits}")
        # Diff by WHICH pattern matched, not by the matched text. This favours FP
        # elimination (a docs tool with a fixed example key is suppressed). Known
        # trade-off (disclosed in docs/claims-matrix.md): a baseline placeholder and
        # a real leak of the SAME shape match one pattern, so such a leak is missed.
        base_text = baseline.response or ""
        base_hits = {rx.pattern for rx in _MARKERS if rx.search(base_text)}
        triggered = [h for h in hits if h not in base_hits]
        if not triggered:
            return None  # same secret shapes in the benign baseline: normal output
        return self._finding(probe, Confidence.FIRM,
                             f"secret-shaped match triggered by input (absent in baseline): {triggered}")

    def _finding(self, probe, conf, evidence):
        return Finding(check=self.id, tool=probe.point.tool, param=probe.point.param_name,
                       severity=Severity.HIGH, confidence=conf, cwe="CWE-200",
                       title=f"Secret/info leak via {probe.point.tool}",
                       payload=probe.payload, evidence=evidence,
                       remediation="Never return secrets/credentials in tool output or errors.")
@@ -0,0 +1,25 @@
1
+ import re
2
+ from mcpsnare.models import Probe, Finding, Severity, Confidence
3
+ from mcpsnare.checks.base import register
4
+
5
# Six levels of relative traversal toward a well-known POSIX / Windows file.
_PAYLOADS = [
    "../" * 6 + "etc/passwd",
    "..\\" * 6 + "windows\\win.ini",
]
# Proof-of-read markers: root's passwd entry, or win.ini's [fonts] section.
_CANARY = re.compile("|".join((r"root:x:0:0:", r"\[fonts\]")), re.IGNORECASE)
8
+
9
+
10
@register
class PathTraversal:
    """Path traversal check (CWE-22).

    Sends relative traversal payloads for well-known OS files and reports a
    finding when the response contains a canary line from one of them.
    """

    id = "path_traversal"
    # Evidence excerpt: up to _EVIDENCE_LEN chars, starting _EVIDENCE_LEAD chars
    # before the canary match (clamped to the start of the response).
    _EVIDENCE_LEAD = 40
    _EVIDENCE_LEN = 200

    def generate(self, point, ctx):
        return [Probe(check=self.id, point=point, payload=pl, args=point.set(pl))
                for pl in _PAYLOADS]

    def evaluate(self, probe, response, ctx):
        text = response or ""
        match = _CANARY.search(text)
        if match is None:
            return None
        # Anchor the excerpt on the canary. A plain text[:200] omits the proof
        # entirely when the file content follows a long preamble or wrapper.
        start = max(0, match.start() - self._EVIDENCE_LEAD)
        evidence = text[start:start + self._EVIDENCE_LEN]
        return Finding(check=self.id, tool=probe.point.tool, param=probe.point.param_name,
                       severity=Severity.HIGH, confidence=Confidence.CONFIRMED, cwe="CWE-22",
                       title=f"Path traversal in {probe.point.tool}.{probe.point.param_name}",
                       payload=probe.payload, evidence=evidence,
                       remediation="Resolve and contain paths within an allowed base dir.")
@@ -0,0 +1,76 @@
1
+ import re
2
+ from mcpsnare.models import Probe, Finding, Severity, Confidence
3
+ from mcpsnare.checks.base import register
4
+
5
# Delay, in seconds, requested by the time-based payloads.
_SLEEP_SECONDS = 5
# With a latency baseline, the delay threshold is at least baseline * this factor.
_LATENCY_MULT = 3

# Distinctive SQL error signatures across common engines and drivers.
_ERROR_SIGNS = re.compile(
    "|".join((
        r"SQL syntax", r"SQLSTATE", r"ORA-\d{5}", r"mysql_fetch", r"mysql_num_rows",
        r"unclosed quotation mark", r"quoted string not properly terminated",
        r"you have an error in your SQL", r"near \"[^\"]*\": syntax error",
        r"PG::\w+Error", r"pg_query", r"psql:.*(?:ERROR|FATAL)", r"SQLite3?::",
        r"Microsoft OLE DB", r"ODBC SQL Server", r"Npgsql\.", r"System\.Data\.SqlClient",
    )),
    re.IGNORECASE,
)

# Non-blocking quote-breaking payloads (always sent).
_ERROR_PAYLOADS = ("'", '"', "')", "' OR '1'='1")

# Blocking payloads (~_SLEEP_SECONDS), aggressive mode only:
# MySQL SLEEP, MSSQL WAITFOR DELAY, PostgreSQL pg_sleep.
_TIME_TEMPLATES = (
    "' OR SLEEP({n})-- ",
    "'; WAITFOR DELAY '0:0:{n}'-- ",
    "' OR pg_sleep({n})-- ",
)
27
+
28
+
29
@register
class SqlInjection:
    """SQL injection check (CWE-89).

    Error-based quote probes are always sent. In aggressive mode, time-based
    sleep probes are also sent and judged by response latency.
    """

    id = "sql_injection"

    def generate(self, point, ctx):
        probes = [Probe(check=self.id, point=point, payload=pl,
                        args=point.set(pl), meta={"error_based": True})
                  for pl in _ERROR_PAYLOADS]
        if getattr(ctx, "aggressive", False):
            for tpl in _TIME_TEMPLATES:
                pl = tpl.format(n=_SLEEP_SECONDS)
                probes.append(Probe(check=self.id, point=point, payload=pl,
                                    args=point.set(pl),
                                    meta={"time_based": True, "threshold": _SLEEP_SECONDS}))
        return probes

    def evaluate(self, probe, response, ctx):
        if probe.meta.get("time_based"):
            return self._evaluate_timing(probe, ctx)
        return self._evaluate_error(probe, response, ctx)

    def _evaluate_timing(self, probe, ctx):
        # Compare the measured latency against the fixed or baseline-derived threshold.
        elapsed = probe.meta.get("elapsed", 0)
        sleep_s = probe.meta["threshold"]
        baseline = getattr(ctx, "baseline", None)
        if baseline is None:
            margin = sleep_s
            evidence = f"response delayed {elapsed:.1f}s (SQL sleep)"
        else:
            margin = max(baseline.latency + sleep_s * 0.8, baseline.latency * _LATENCY_MULT)
            evidence = f"response delayed {elapsed:.1f}s vs baseline {baseline.latency:.1f}s (SQL sleep)"
        if elapsed >= margin:
            return self._finding(probe, Confidence.FIRM, evidence)
        return None

    def _evaluate_error(self, probe, response, ctx):
        # Look for a DB error signature, discounting ones already present in the baseline.
        if not _ERROR_SIGNS.search(response or ""):
            return None
        baseline = getattr(ctx, "baseline", None)
        if baseline is None:
            return self._finding(probe, Confidence.TENTATIVE,
                                 "SQL error signature matched (no baseline to corroborate)")
        if _ERROR_SIGNS.search(baseline.response or ""):
            return None  # the benign baseline already errors: not triggered by us
        return self._finding(probe, Confidence.FIRM,
                             "SQL error signature triggered by quote payload (absent in baseline)")

    def _finding(self, probe, conf, evidence):
        return Finding(check=self.id, tool=probe.point.tool, param=probe.point.param_name,
                       severity=Severity.HIGH, confidence=conf, cwe="CWE-89",
                       title=f"SQL injection in {probe.point.tool}.{probe.point.param_name}",
                       payload=probe.payload, evidence=evidence,
                       remediation="Use parameterised queries / prepared statements; never concatenate input into SQL.")
@@ -0,0 +1,19 @@
1
+ from mcpsnare.models import Probe, Finding, Severity, Confidence
2
+ from mcpsnare.checks.base import register
3
+
4
@register
class SSRF:
    """Server-side request forgery check (CWE-918), confirmed by an OOB callback."""

    id = "ssrf"

    def generate(self, point, ctx):
        oob = ctx.oob
        if oob is None:
            return []  # SSRF is only detectable with an OOB provider
        token, url = oob.new_token()
        # The payload is the bare callback URL; a callback on it is reported as SSRF.
        return [Probe(check=self.id, point=point, payload=url, args=point.set(url), token=token)]

    def evaluate(self, probe, response, ctx):
        called_back = probe.token and ctx.oob and ctx.oob.interactions(probe.token)
        if not called_back:
            return None
        return Finding(check=self.id, tool=probe.point.tool, param=probe.point.param_name,
                       severity=Severity.HIGH, confidence=Confidence.CONFIRMED, cwe="CWE-918",
                       title=f"SSRF in {probe.point.tool}.{probe.point.param_name}",
                       payload=probe.payload, evidence="OOB callback received",
                       remediation="Validate/allowlist outbound URLs; block internal ranges & metadata IPs.")
mcpsnare/cli.py ADDED
@@ -0,0 +1,118 @@
1
+ import argparse
2
+ import asyncio
3
+ import os
4
+ import shlex
5
+ import sys
6
+
7
+ from mcpsnare.report.render import to_json, to_sarif, to_markdown
8
+
9
+
10
def _positive_float(s):
    """argparse ``type=`` converter: parse *s* as a float strictly greater than 0.

    Raises:
        ValueError: *s* is not a number (argparse reports it as an invalid value).
        argparse.ArgumentTypeError: the value is <= 0 or NaN.
    """
    v = float(s)
    # `not v > 0` rather than `v <= 0`: every comparison with NaN is False, so the
    # latter would accept "nan" as a rate.
    if not v > 0:
        raise argparse.ArgumentTypeError("must be > 0")
    return v
15
+
16
+
17
+ def aggressive_note(aggressive: bool) -> str | None:
18
+ """Honest note for default (non-aggressive) scans: blocking time-based probes
19
+ were skipped, so an empty report must not be read as 'secure'. None when
20
+ aggressive (nothing was skipped)."""
21
+ if aggressive:
22
+ return None
23
+ return ("[i] Default mode: blocking time-based probes were skipped. "
24
+ "Re-run with --aggressive to add time-based command-injection and SQL-injection detection.")
25
+
26
+
27
def _positive_int(s):
    """argparse ``type=`` converter: parse *s* as an int >= 1.

    Used for --concurrency: a limit of 0 (or fewer) concurrent requests would
    allow no probe to run.
    """
    v = int(s)
    if v < 1:
        raise argparse.ArgumentTypeError("must be >= 1")
    return v


def build_parser():
    """Build the ``mcpsnare`` argument parser with its ``scan`` subcommand."""
    p = argparse.ArgumentParser(prog="mcpsnare")
    sub = p.add_subparsers(dest="cmd", required=True)
    s = sub.add_parser("scan", help="scan an MCP server")
    # Exactly one target: a stdio command or an HTTP endpoint.
    g = s.add_mutually_exclusive_group(required=True)
    g.add_argument("--stdio", help='command to launch the server, e.g. "python server.py"')
    g.add_argument("--http", help="streamable HTTP MCP endpoint URL")
    s.add_argument("--header", action="append", default=[], help="HTTP header k:v (repeatable)")
    s.add_argument("--oob", choices=["local", "interactsh", "none"], default="local")
    s.add_argument("--interactsh-server", default="oast.fun",
                   help="interactsh/OAST server domain for --oob interactsh (default oast.fun)")
    s.add_argument("--aggressive", action="store_true",
                   help="also send blocking time-based (sleep) probes; default sends only non-blocking OOB/canary/pattern probes")
    # Validated as >= 1: 0 or negative limits previously slipped through.
    s.add_argument("--concurrency", type=_positive_int, default=4,
                   help="max concurrent probe requests (default 4)")
    s.add_argument("--rate", type=_positive_float, default=None,
                   help="max requests/second (default unlimited)")
    s.add_argument("--oob-timeout", type=float, default=20.0,
                   help="seconds to poll for OOB callbacks (default 20)")
    s.add_argument("--oob-poll-interval", type=float, default=2.5,
                   help="OOB poll interval seconds (default 2.5)")
    s.add_argument("--output", choices=["console", "json", "sarif", "md"], default="console")
    return p
50
+
51
+
52
def _parse_headers(raw_headers):
    """Turn repeated ``--header`` values into a ``{name: value}`` dict.

    Each entry is ``Name:value``. Whitespace around both parts is stripped, so
    the natural ``"Authorization: Bearer x"`` yields ``"Bearer x"`` rather than
    ``" Bearer x"``. A later duplicate name overrides an earlier one.

    Exits via ``sys.exit`` (message on stderr) when an entry has no ``:`` or an
    empty name, instead of surfacing an opaque ``dict()`` ValueError.
    """
    headers = {}
    for raw in raw_headers:
        name, sep, value = raw.partition(":")
        name = name.strip()
        if not sep or not name:
            sys.exit(f"mcpsnare: invalid --header {raw!r} (expected 'Name: value')")
        headers[name] = value.strip()
    return headers


async def _scan_tools_and_resources(sess, args, *, oob, transport, call_tool_unauth=None):
    """Scan *sess* and return the combined findings list.

    Two passes are made:
    - *sess*'s tools, with scan_session's default check set;
    - its templated resources (via ResourceToolView), with only the
      path_traversal and info_leak checks.

    ``call_tool_unauth`` is forwarded only when given, so header-less scans call
    scan_session without that keyword.
    """
    from mcpsnare.connect.resources import ResourceToolView
    from mcpsnare.engine import scan_session

    extra = {} if call_tool_unauth is None else {"call_tool_unauth": call_tool_unauth}
    findings = await scan_session(sess, oob=oob, transport=transport, aggressive=args.aggressive,
                                  concurrency=args.concurrency, rate=args.rate,
                                  oob_timeout=args.oob_timeout,
                                  oob_poll_interval=args.oob_poll_interval, **extra)
    findings += await scan_session(ResourceToolView(sess), oob=oob, transport=transport,
                                   aggressive=args.aggressive, concurrency=args.concurrency,
                                   rate=args.rate, check_ids=["path_traversal", "info_leak"])
    return findings


async def _run(args):
    """Scan the server described by the parsed CLI *args* and print the report.

    The report is the only thing written to stdout. The banner and the
    default-mode note go to stderr, so ``--output json|sarif|md`` stays
    machine-readable (e.g. ``mcpsnare scan ... --output json > report.json``).
    """
    from contextlib import ExitStack

    from mcpsnare.connect.session import stdio_session, http_session
    import mcpsnare.checks  # noqa: F401 - importing the package registers the checks
    from mcpsnare.oob.local import LocalOOB

    print("[!] mcpsnare - authorized testing only.", file=sys.stderr)
    # Validate headers before any OOB listener is started.
    headers = {} if args.stdio else _parse_headers(args.header)

    with ExitStack() as stack:
        oob = None
        if args.oob == "local":
            oob = stack.enter_context(LocalOOB())
        elif args.oob == "interactsh":
            from mcpsnare.oob.interactsh import InteractshOOB
            from mcpsnare.oob.interactsh_client import InteractshClient
            oob = InteractshOOB(InteractshClient(server=args.interactsh_server))

        if args.stdio:
            argv = shlex.split(args.stdio, posix=(os.name != "nt"))
            async with stdio_session(argv) as sess:
                findings = await _scan_tools_and_resources(sess, args, oob=oob, transport="stdio")
        else:
            async with http_session(args.http, headers=headers) as sess:
                if headers:
                    # A second, header-less session lets auth_bypass replay calls unauthenticated.
                    async with http_session(args.http, headers={}) as sess_unauth:
                        findings = await _scan_tools_and_resources(
                            sess, args, oob=oob, transport="http",
                            call_tool_unauth=sess_unauth.call_tool)
                else:
                    findings = await _scan_tools_and_resources(sess, args, oob=oob, transport="http")

    renderers = {"json": to_json, "sarif": to_sarif, "md": to_markdown}
    render = renderers.get(args.output)
    if render is not None:
        print(render(findings))
    else:
        print(f"\n{len(findings)} finding(s):")
        for f in findings:
            print(f" [{f.severity.value.upper()}] {f.title} ({f.confidence.value})")
    note = aggressive_note(args.aggressive)
    if note:
        print(note, file=sys.stderr)
114
+
115
+
116
def main():
    """Console entry point: parse ``sys.argv`` and run the scan to completion."""
    parsed = build_parser().parse_args()
    coro = _run(parsed)
    asyncio.run(coro)
File without changes
@@ -0,0 +1,42 @@
1
+ import re
2
+
3
+ from mcpsnare.models import ToolInfo
4
+
5
# A simple "{name}" placeholder in a URI template; group 1 is the name, which may
# contain anything except "}" and "/".
_TMPL_PARAM = re.compile("\\{([^}/]+)\\}")
6
+
7
+
8
class ResourceToolView:
    """Present a session's resource templates as tool-like objects.

    This lets the existing engine (injection_points + checks + oracles) scan
    resources. Each templated ``{param}`` in a URI template becomes a string
    injection point; ``call_tool`` fills the template and ``read_resource``s it.

    The wrapped object must expose
    ``list_resource_templates() -> list[(name, uriTemplate)]`` and
    ``read_resource(uri) -> str`` (mcpsnare's Session does).
    """

    def __init__(self, session):
        self._session = session
        self._templates = {}  # tool_name -> uriTemplate

    async def list_tools(self):
        tools = []
        for name, tmpl in await self._session.list_resource_templates():
            # dict.fromkeys de-duplicates repeated placeholders while keeping order,
            # so "required" stays a valid unique-items list.
            params = list(dict.fromkeys(_TMPL_PARAM.findall(tmpl)))
            if not params:
                continue  # a fixed URI has nothing to inject into
            props = {p: {"type": "string"} for p in params}
            schema = {"type": "object", "properties": props, "required": params}
            tool_name = f"resource:{tmpl}"
            self._templates[tool_name] = tmpl
            tools.append(ToolInfo(name=tool_name, description=name, input_schema=schema))
        return tools

    async def call_tool(self, name, args):
        # Precondition: list_tools() must run first to populate self._templates
        # (KeyError otherwise). scan_session calls list_tools() before call_tool().
        tmpl = self._templates[name]

        def _fill(match):
            # Placeholders without a supplied argument are left as-is.
            key = match.group(1)
            return str(args[key]) if key in args else match.group(0)

        # Single pass: an injected value that itself contains "{other}" is sent
        # verbatim, instead of being rewritten by a later str.replace.
        uri = _TMPL_PARAM.sub(_fill, tmpl)
        return await self._session.read_resource(uri)
@@ -0,0 +1,68 @@
1
+ from contextlib import asynccontextmanager
2
+
3
+ from mcp import ClientSession, StdioServerParameters
4
+ from mcp.client.stdio import stdio_client
5
+ from mcp.client.streamable_http import create_mcp_http_client, streamable_http_client
6
+
7
+ from mcpsnare.models import ToolInfo
8
+
9
+
10
class Session:
    """Thin adapter over an MCP ``ClientSession`` that flattens results to text."""

    def __init__(self, cs):
        self._cs = cs

    async def list_tools(self):
        resp = await self._cs.list_tools()
        return [
            ToolInfo(
                name=t.name,
                description=t.description or "",
                input_schema=t.inputSchema or {},
            )
            for t in resp.tools
        ]

    async def call_tool(self, name, args):
        """Call tool *name* and return its text parts, plus any structured content
        serialized as JSON, joined by newlines (empty parts dropped)."""
        resp = await self._cs.call_tool(name, args)
        parts = [getattr(c, "text", "") or "" for c in resp.content]
        structured = getattr(resp, "structuredContent", None)
        if structured:
            import json
            parts.append(json.dumps(structured, default=str))
        return "\n".join(p for p in parts if p)

    async def list_resource_templates(self):
        """Return ``(name, uriTemplate)`` pairs for the server's resource templates."""
        resp = await self._cs.list_resource_templates()
        return [(t.name, t.uriTemplate) for t in resp.resourceTemplates]

    async def read_resource(self, uri):
        """Read *uri* and return all of its contents as text, joined by newlines.

        Binary contents (base64 ``blob``) are decoded too, so checks can see
        files the server serves as binary rather than text.
        """
        resp = await self._cs.read_resource(uri)
        parts = [self._contents_text(c) for c in resp.contents]
        return "\n".join(p for p in parts if p)

    @staticmethod
    def _contents_text(item):
        """Text of one resource-contents item; '' when it has neither text nor a
        decodable base64 blob. Undecodable UTF-8 bytes are replaced."""
        text = getattr(item, "text", "") or ""
        if text:
            return text
        blob = getattr(item, "blob", None)
        if not blob:
            return ""
        import base64
        try:
            raw = base64.b64decode(blob)
        except (ValueError, TypeError):  # binascii.Error subclasses ValueError
            return ""
        return raw.decode("utf-8", errors="replace")
46
+
47
+
48
@asynccontextmanager
async def stdio_session(command):
    """Launch *command* as a stdio MCP server and yield an initialized ``Session``.

    *command* is an argv-style sequence: the program, then its arguments.

    Raises:
        ValueError: *command* is empty (e.g. ``--stdio ""``), which previously
            surfaced as a bare IndexError.
    """
    if not command:
        raise ValueError("stdio command is empty; expected a program to launch")
    params = StdioServerParameters(command=command[0], args=command[1:])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as cs:
            await cs.initialize()
            yield Session(cs)
55
+
56
+
57
@asynccontextmanager
async def http_session(url, headers=None):
    """Open a streamable-HTTP MCP session to *url* and yield an initialized ``Session``.

    *headers* (default: none) are sent on every request. The SDK's
    streamable_http_client takes headers via a caller-owned httpx client (the old
    headers= kwarg is gone). create_mcp_http_client applies the same MCP defaults
    the old path used (30s timeout, 300s SSE read, follow_redirects), so routing
    headers through it is behaviour-preserving. A bare AsyncClient would not be:
    it defaults to a ~5s read timeout and breaks long-lived streams.
    """
    request_headers = headers if headers else {}
    async with create_mcp_http_client(headers=request_headers) as client:
        async with streamable_http_client(url, http_client=client) as streams:
            # Only the read/write streams are used; any extra elements are ignored.
            read_stream, write_stream, *_extra = streams
            async with ClientSession(read_stream, write_stream) as cs:
                await cs.initialize()
                yield Session(cs)