spooling 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spooling/mcp_server.py ADDED
@@ -0,0 +1,312 @@
1
+ """Spooling MCP server.
2
+
3
+ Exposes Spooling's trace, span, eval, and stats data over the Model Context
4
+ Protocol so any MCP-compatible agent (Codex, Cursor, etc.) can
5
+ query it as a source of context. Defaults to streamable-HTTP transport on
6
+ http://127.0.0.1:3004/mcp so web-based and remote agents can connect; stdio
7
+ is still available for stdio-only clients via `serve_stdio()`.
8
+
9
+ Tools exposed:
10
+ - list_traces(limit, provider, project)
11
+ - get_trace(trace_id)
12
+ - search_sessions(query, limit, project)
13
+ - get_stats()
14
+ - get_top_vendors()
15
+ - list_evals(rubric_id, limit)
16
+ - run_eval(rubric_id, trace_id)
17
+
18
+ The server is read-mostly: `run_eval` is the only mutation, and it writes
19
+ to the same evals table the GUI reads from.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ from typing import Any, Optional
26
+
27
+ from mcp.server.fastmcp import FastMCP
28
+
29
+ from spooling.db import get_connection
30
+
31
+
32
# Network location of the streamable-HTTP endpoint. 127.0.0.1 is the loopback
# address, so only processes on this machine can reach the server.
MCP_HOST = "127.0.0.1"
MCP_PORT = 3004
MCP_PATH = "/mcp"
# Full endpoint URL assembled from the three parts above.
MCP_URL = "http://{}:{}{}".format(MCP_HOST, MCP_PORT, MCP_PATH)
36
+
37
+
38
mcp = FastMCP(
    name="spooling",
    # Free-text guidance handed to FastMCP's `instructions`. Keep the
    # provider list free of duplicates and in step with the tools below.
    instructions=(
        "Spooling tracks your AI coding sessions across Codex, Cursor, "
        "Copilot, Windsurf, Kiro, and Antigravity. Use these tools "
        "to recall past sessions, search history semantically, inspect "
        "span trees, and score sessions with Strands evaluators. Traces "
        "and their spans carry token usage, cost, vendor tags, and eval "
        "scores, so you can answer questions like 'how much did I spend "
        "on Linear tool calls last week?' or 'show me the longest-running "
        "agent span from Cursor'."
    ),
    host=MCP_HOST,
    port=MCP_PORT,
    # Bind the endpoint path explicitly so the served URL always matches
    # MCP_URL instead of silently depending on the SDK's default path.
    streamable_http_path=MCP_PATH,
    # NOTE(review): presumably no per-client session state is kept between
    # HTTP requests in this mode — confirm against the MCP SDK docs.
    stateless_http=True,
)
54
+
55
+
56
+ def _row(r) -> dict | None:
57
+ return dict(r) if r else None
58
+
59
+
60
+ def _rows(rs) -> list[dict]:
61
+ return [dict(r) for r in rs]
62
+
63
+
64
+ # --- tools -----------------------------------------------------------------
65
+
66
@mcp.tool()
def list_traces(
    limit: int = 25,
    provider: Optional[str] = None,
    project: Optional[str] = None,
) -> list[dict]:
    """Recent Spooling traces. Use this to find recent sessions before drilling in.

    Rows come back newest first (ordered by ``started_at``).

    Args:
        limit: Max rows to return (default 25, capped at 200).
        provider: Filter to one provider id (jsonl-session, codex, cursor, copilot, windsurf, kiro, antigravity, gemini, opencode).
        project: Filter to sessions whose project name matches exactly.
    """
    row_cap = max(1, min(limit, 200))

    # (column, value) pairs for the optional equality filters; empty or unset
    # filters are dropped so they don't constrain the query.
    active = [
        (column, value)
        for column, value in (("provider_id", provider), ("project", project))
        if value
    ]
    where = ""
    if active:
        # Only the fixed column names above are interpolated; user values are
        # always passed as bound parameters.
        where = "WHERE " + " AND ".join(f"{column} = %s" for column, _ in active)
    bind = tuple(value for _, value in active) + (row_cap,)

    conn = get_connection()
    try:
        rows = conn.execute(
            f"""SELECT id, session_id, provider_id, project, title, started_at,
                       duration_ms, span_count, agent_count, tool_count, llm_count,
                       error_count, total_cost_usd, model
                FROM traces {where}
                ORDER BY started_at DESC LIMIT %s""",
            bind,
        ).fetchall()
    finally:
        conn.close()
    return _rows(rows)
104
+
105
+
106
@mcp.tool()
def get_trace(trace_id: str) -> dict:
    """Full detail for one trace: header row, span tree (flattened), and eval scores.

    Spans are returned as a flat list ordered by ``sequence``; evals are newest
    first, with the rubric name joined in when the rubric row exists.

    Args:
        trace_id: The id from list_traces (looks like `trace-<session-uuid>`).

    Returns:
        ``{"trace": ..., "spans": [...], "evals": [...]}``, or
        ``{"error": "trace not found: <id>"}`` when no trace row matches.
    """
    bind = (trace_id,)
    conn = get_connection()
    try:
        header = conn.execute("SELECT * FROM traces WHERE id = %s", bind).fetchone()
        if not header:
            return {"error": f"trace not found: {trace_id}"}

        span_rows = conn.execute(
            """SELECT id, parent_id, kind, name, status, started_at, ended_at,
                      duration_ms, depth, sequence, input_tokens, output_tokens,
                      cost_usd, model, tool_name, tool_is_error, vendor, category,
                      agent_type, agent_prompt
               FROM spans WHERE trace_id = %s ORDER BY sequence""",
            bind,
        ).fetchall()

        eval_rows = conn.execute(
            """SELECT e.rubric_id, r.name AS rubric_name, e.score, e.passed,
                      e.label, e.rationale, e.run_at
               FROM evals e LEFT JOIN eval_rubrics r ON r.id = e.rubric_id
               WHERE e.trace_id = %s ORDER BY e.run_at DESC""",
            bind,
        ).fetchall()
    finally:
        conn.close()

    detail = {}
    detail["trace"] = _row(header)
    detail["spans"] = _rows(span_rows)
    detail["evals"] = _rows(eval_rows)
    return detail
143
+
144
+
145
@mcp.tool()
def search_sessions(
    query: str,
    limit: int = 10,
    project: Optional[str] = None,
) -> list[dict]:
    """Semantic search over Spooling's embedded session chunks. Returns ranked matches.

    Delegates to ``spooling.search.search``; ranking and the result shape are
    defined there.

    Args:
        query: Natural-language description of what to find.
        limit: Max results (default 10, capped at 50).
        project: Optional project name filter.
    """
    # Function-local import: presumably keeps the search stack from being
    # loaded until this tool is actually called.
    from spooling.search import search as semantic_search

    capped = min(limit, 50)
    return semantic_search(query, limit=max(1, capped), project=project)
161
+
162
+
163
@mcp.tool()
def get_stats() -> dict:
    """Top-line Spooling stats across all traces, plus a per-provider breakdown.

    Returns:
        ``{"summary": {...}, "by_provider": [...]}``.

        ``summary`` holds ``traces``, ``spans``, ``agents``, ``tools``,
        ``llm_calls``, ``errors``, ``input_tokens``, ``output_tokens`` and
        ``cost_usd``; every sum falls back to 0 rather than NULL.

        ``by_provider`` has one row per ``provider_id`` with ``traces`` and
        ``cost_usd`` (also 0 rather than NULL), ordered by trace count
        descending.
    """
    conn = get_connection()
    try:
        row = conn.execute(
            """SELECT
                   COUNT(*) AS traces,
                   COALESCE(SUM(span_count), 0) AS spans,
                   COALESCE(SUM(agent_count), 0) AS agents,
                   COALESCE(SUM(tool_count), 0) AS tools,
                   COALESCE(SUM(llm_count), 0) AS llm_calls,
                   COALESCE(SUM(error_count), 0) AS errors,
                   COALESCE(SUM(total_input_tokens), 0) AS input_tokens,
                   COALESCE(SUM(total_output_tokens), 0) AS output_tokens,
                   COALESCE(SUM(total_cost_usd), 0) AS cost_usd
               FROM traces"""
        ).fetchone()

        # SUM over a group whose costs are all NULL yields NULL. COALESCE
        # here, as the summary query does, so callers always get a number.
        per_provider = conn.execute(
            """SELECT provider_id, COUNT(*) AS traces,
                      COALESCE(SUM(total_cost_usd), 0) AS cost_usd
               FROM traces GROUP BY provider_id ORDER BY traces DESC"""
        ).fetchall()
    finally:
        conn.close()

    return {
        "summary": _row(row) or {},
        "by_provider": _rows(per_provider),
    }
194
+
195
+
196
@mcp.tool()
def get_top_vendors(limit: int = 20) -> list[dict]:
    """Top external vendors (Linear, GitHub, Slack, Snowflake, ...) by tool-call count.

    Only ``kind = 'tool'`` spans with a non-NULL vendor are counted, and the
    generic 'filesystem', 'shell', 'search' and 'unknown' buckets are
    excluded. Each row carries:

    - ``vendor`` and ``category``
    - ``uses``: matching spans
    - ``errors``: spans with ``tool_is_error`` set
    - ``traces``: distinct traces touching the vendor

    Args:
        limit: Max rows (default 20, capped at 100).
    """
    sql = """SELECT vendor, category,
                    COUNT(*) AS uses,
                    SUM(CASE WHEN tool_is_error THEN 1 ELSE 0 END) AS errors,
                    COUNT(DISTINCT trace_id) AS traces
             FROM spans
             WHERE kind = 'tool' AND vendor IS NOT NULL
               AND vendor NOT IN ('filesystem', 'shell', 'search', 'unknown')
             GROUP BY vendor, category
             ORDER BY uses DESC LIMIT %s"""
    capped = max(1, min(limit, 100))

    conn = get_connection()
    try:
        cursor = conn.execute(sql, (capped,))
        found = cursor.fetchall()
    finally:
        conn.close()
    return _rows(found)
221
+
222
+
223
@mcp.tool()
def list_evals(
    rubric_id: Optional[str] = None,
    limit: int = 50,
) -> list[dict]:
    """Recent eval runs, newest first. Optionally filter by rubric id.

    Each row includes the rubric's display name when the rubric row exists
    (LEFT JOIN on eval_rubrics).

    Args:
        rubric_id: e.g. "helpfulness", "tool-error-rate".
        limit: Max rows (default 50, capped at 200).
    """
    row_cap = max(1, min(limit, 200))

    # A single query with an optional filter replaces two near-identical
    # copies. The interpolated fragment is a fixed string; the rubric id
    # itself is always passed as a bound parameter.
    if rubric_id:
        where, bind = "WHERE e.rubric_id = %s ", (rubric_id, row_cap)
    else:
        where, bind = "", (row_cap,)

    conn = get_connection()
    try:
        rows = conn.execute(
            f"""SELECT e.id, e.rubric_id, r.name AS rubric_name, e.trace_id,
                       e.score, e.passed, e.label, e.rationale, e.run_at
                FROM evals e LEFT JOIN eval_rubrics r ON r.id = e.rubric_id
                {where}ORDER BY e.run_at DESC LIMIT %s""",
            bind,
        ).fetchall()
    finally:
        conn.close()
    return _rows(rows)
256
+
257
+
258
@mcp.tool()
def list_rubrics() -> list[dict]:
    """All configured eval rubrics (Strands evaluators + function rubrics), ordered by id."""
    sql = (
        "SELECT id, name, description, kind, target_kind, "
        "evaluator_type, rubric_text, is_default "
        "FROM eval_rubrics ORDER BY id"
    )
    conn = get_connection()
    try:
        fetched = conn.execute(sql).fetchall()
    finally:
        conn.close()
    return _rows(fetched)
271
+
272
+
273
@mcp.tool()
def run_eval(rubric_id: str, trace_id: str) -> dict:
    """Run a rubric against a single trace and persist the result.

    Args:
        rubric_id: From list_rubrics().
        trace_id: From list_traces() or get_trace().

    Returns:
        ``{"status": "skipped", "rubric_id", "trace_id"}`` when ``run_rubric``
        returns no eval id.

        Otherwise ``{"status": "ok", "rubric_id", "trace_id", "result"}``,
        where ``result`` is the stored eval row re-read from the evals table
        (or None if that row is missing).
    """
    from spooling.evals import run_rubric

    ident = {"rubric_id": rubric_id, "trace_id": trace_id}
    eval_id = run_rubric(rubric_id, trace_id)
    if eval_id is None:
        return {"status": "skipped", **ident}

    conn = get_connection()
    try:
        stored = conn.execute(
            """SELECT id, score, passed, label, rationale, judge_model
               FROM evals WHERE id = %s""",
            (eval_id,),
        ).fetchone()
    finally:
        conn.close()
    return {"status": "ok", **ident, "result": _row(stored)}
297
+
298
+
299
+ # --- entrypoint ------------------------------------------------------------
300
+
301
def serve_stdio() -> None:
    """Serve the Spooling tools over stdin/stdout.

    Intended for MCP clients that only speak the stdio transport.
    """
    transport = "stdio"
    mcp.run(transport=transport)
304
+
305
+
306
def serve_http() -> None:
    """Serve the Spooling tools over streamable HTTP, expected at MCP_URL.

    Host and port come from the module-level ``mcp`` instance (MCP_HOST /
    MCP_PORT).
    """
    transport = "streamable-http"
    mcp.run(transport=transport)
309
+
310
+
311
if __name__ == "__main__":
    import sys

    # `python -m spooling.mcp_server --stdio` serves over stdio for
    # stdio-only clients. With no flag, keep the streamable-HTTP default so
    # existing invocations behave exactly as before.
    if "--stdio" in sys.argv[1:]:
        serve_stdio()
    else:
        serve_http()