sqlprism-1.0.0-py3-none-any.whl

@@ -0,0 +1,982 @@
+ """MCP server exposing SQL indexer tools.
+
+ This is the interface LLMs interact with. Tools are provider-agnostic —
+ any MCP client (Claude, Cursor, Continue.dev, etc.) can connect via
+ stdio or streamable HTTP.
+
+ Focused entirely on SQL: tables, views, CTEs, column lineage, transforms,
+ WHERE filters, and dependency tracing across dialects.
+ """
+
+ import asyncio
+ import json
+ from dataclasses import dataclass
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Literal
+
+ from mcp.server.fastmcp import FastMCP
+ from pydantic import BaseModel, Field
+
+ from sqlprism.core.graph import GraphDB
+ from sqlprism.core.indexer import Indexer
+ from sqlprism.languages import is_sql_file
+ from sqlprism.types import NodeResult, ParseResult, parse_repo_config
+
+ # ── Server initialisation ──
+
+ mcp = FastMCP("sqlprism")
+
+
+ @dataclass(frozen=True)
+ class _ServerState:
+     """Immutable bundle of server state, swapped atomically on configure()."""
+
+     graph: GraphDB
+     indexer: Indexer
+     config: dict
+
+
+ # Single atomic reference — readers snapshot this once; no lock needed.
+ _state: _ServerState | None = None
+
+ # Background reindex state — shared across reindex, reindex_dbt, reindex_sqlmesh.
+ # Only one reindex may run at a time to avoid write-lock conflicts.
+ _reindex_lock = asyncio.Lock()
+ _reindex_task: asyncio.Task | None = None
+ _reindex_status: dict = {"state": "idle"}
+ _last_parse_errors: list[str] = []
+
+
+ def configure(db_path: str | Path, repos: dict, sql_dialect: str | None = None):
+     """Initialise the graph and indexer with repo configuration.
+
+     Args:
+         db_path: Path to DuckDB file
+         repos: {repo_name: path_or_config} — value is either a string path
+             or a dict with "path", "dialect", "dialect_overrides" keys
+         sql_dialect: Global fallback SQL dialect (overridden by per-repo config)
+
+     Thread-safety: builds a new immutable ``_ServerState`` and swaps it in
+     with a single assignment, so concurrent readers never see partial updates.
+     """
+     global _state
+     graph = GraphDB(db_path)
+     indexer = Indexer(graph)
+     config = {
+         "db_path": str(db_path),
+         "repos": repos,
+         "sql_dialect": sql_dialect,
+     }
+
+     # Register repos before publishing new state
+     for name, cfg in repos.items():
+         path = cfg["path"] if isinstance(cfg, dict) else cfg
+         graph.upsert_repo(name, path)
+
+     # Atomic swap — readers always get a consistent triple
+     _state = _ServerState(graph=graph, indexer=indexer, config=config)
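+
+ # Editor's sketch (not part of the released file): a minimal configure() call,
+ # matching the signature and the repos shape documented above. The repo name
+ # and paths are hypothetical.
+ #
+ #     configure(
+ #         db_path="~/.sqlprism/index.duckdb",
+ #         repos={"warehouse": {"path": "/repos/warehouse", "dialect": "postgres"}},
+ #         sql_dialect="postgres",
+ #     )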
+
+
+ def _get_state() -> _ServerState:
+     """Snapshot current server state or raise if not yet configured."""
+     state = _state
+     if state is None:
+         raise RuntimeError("Server not configured. Call configure() first.")
+     return state
+
+
+ def _get_graph() -> GraphDB:
+     return _get_state().graph
+
+
+ def _get_indexer() -> Indexer:
+     return _get_state().indexer
+
+
+ def _resolve_repo_config(repo_name: str) -> tuple[str, str | None, dict[str, str] | None]:
+     """Extract (path, dialect, dialect_overrides) from repo config."""
+     config = _get_state().config
+     cfg = config["repos"].get(repo_name)
+     if cfg is None:
+         raise ValueError(f"Repo '{repo_name}' not found in config")
+     return parse_repo_config(cfg, config.get("sql_dialect"))
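+
+ # Illustrative repo configs this resolver accepts (values are made up; the
+ # exact override semantics live in sqlprism.types.parse_repo_config):
+ #   "core":  "/repos/core"            -> ("/repos/core", <global sql_dialect>, None)
+ #   "marts": {"path": "/repos/marts",
+ #             "dialect": "snowflake",
+ #             "dialect_overrides": {"legacy/": "mysql"}}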
+
+
+ # ── Query tools ──
+
+
+ class SearchInput(BaseModel):
+     model_config = {"populate_by_name": True}
+     pattern: str = Field(
+         ...,
+         description="Search pattern (partial name match, case-insensitive)",
+     )
+     kind: str | None = Field(
+         None,
+         description="Filter by node kind: 'table', 'view', 'cte', 'query'",
+     )
+     sql_schema: str | None = Field(
+         None,
+         alias="schema",
+         description="Filter by SQL schema name (e.g. 'staging', 'public')",
+     )
+     repo: str | None = Field(None, description="Filter by repo name. Omit to search all repos.")
+     limit: int = Field(20, description="Max results (default 20)", ge=1, le=100)
+     offset: int = Field(0, description="Number of results to skip for pagination (default 0)", ge=0)
+     include_snippets: bool = Field(True, description="Include source code snippets in results")
+
+
+ @mcp.tool(
+     name="search",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def search(params: SearchInput) -> dict:
+     """Search for SQL entities by name across the codebase graph.
+
+     Finds tables, views, CTEs, and queries by partial name match.
+     Returns matches with name, kind, file path, repo, and line numbers.
+     """
+     return await asyncio.to_thread(
+         _get_graph().query_search,
+         pattern=params.pattern,
+         kind=params.kind,
+         schema=params.sql_schema,
+         repo=params.repo,
+         limit=params.limit,
+         offset=params.offset,
+         include_snippets=params.include_snippets,
+     )
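+
+ # Example tool call from an MCP client (hypothetical payload; note the
+ # "schema" alias on sql_schema works because populate_by_name is set):
+ #   search({"pattern": "orders", "kind": "table", "schema": "staging", "limit": 5})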
+
+
+ class FindReferencesInput(BaseModel):
+     model_config = {"populate_by_name": True}
+     name: str = Field(..., description="Entity name (table, view, CTE, etc.)")
+     kind: str | None = Field(None, description="Filter by node kind to disambiguate")
+     sql_schema: str | None = Field(
+         None,
+         alias="schema",
+         description="Filter by SQL schema name (e.g. 'staging', 'public')",
+     )
+     repo: str | None = Field(
+         None,
+         description="Filter by repo name. Omit to search all repos.",
+     )
+     direction: Literal["both", "inbound", "outbound"] = Field(
+         "both",
+         description="'inbound', 'outbound', or 'both'",
+     )
+     include_snippets: bool = Field(True, description="Include source code snippets in results")
+     limit: int = Field(100, description="Max results per direction (default 100)", ge=1, le=500)
+     offset: int = Field(0, description="Number of results to skip for pagination (default 0)", ge=0)
+
+
+ @mcp.tool(
+     name="find_references",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def find_references(params: FindReferencesInput) -> dict:
+     """Find everything connected to a named SQL entity.
+
+     Returns both inbound (what depends on this) and outbound (what this depends on)
+     relationships. Each result includes: name, kind, relationship type, file path, repo.
+     """
+     return await asyncio.to_thread(
+         _get_graph().query_references,
+         name=params.name,
+         kind=params.kind,
+         schema=params.sql_schema,
+         repo=params.repo,
+         direction=params.direction,
+         include_snippets=params.include_snippets,
+         limit=params.limit,
+         offset=params.offset,
+     )
+
+
+ class FindColumnUsageInput(BaseModel):
+     table: str = Field(..., description="Table name to search column usage for")
+     column: str | None = Field(None, description="Specific column name. Omit for all columns.")
+     usage_type: str | None = Field(
+         None,
+         description=(
+             "Filter: 'select', 'where', 'join_on', 'group_by', "
+             "'order_by', 'having', 'insert', 'update'"
+         ),
+     )
+     repo: str | None = Field(None, description="Filter by repo name. Omit to search all repos.")
+     limit: int = Field(100, description="Max results (default 100)", ge=1, le=500)
+     offset: int = Field(0, description="Number of results to skip for pagination (default 0)", ge=0)
+
+
+ @mcp.tool(
+     name="find_column_usage",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def find_column_usage(params: FindColumnUsageInput) -> dict:
+     """Find where and how columns are used across SQL models.
+
+     Powered by sqlglot's column lineage analysis. Shows usage type,
+     transforms (CAST, COALESCE, etc.), output aliases, and WHERE conditions.
+
+     Answers: "where is customer_id used in WHERE clauses?",
+     "how is animal.breed_id transformed?", "show all column usage on orders."
+     """
+     return await asyncio.to_thread(
+         _get_graph().query_column_usage,
+         table=params.table,
+         column=params.column,
+         usage_type=params.usage_type,
+         repo=params.repo,
+         limit=params.limit,
+         offset=params.offset,
+     )
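+
+ # Example (hypothetical names): find WHERE-clause usage of orders.customer_id:
+ #   find_column_usage({"table": "orders", "column": "customer_id",
+ #                      "usage_type": "where"})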
+
+
+ class TraceDependenciesInput(BaseModel):
+     name: str = Field(..., description="Starting entity name")
+     kind: str | None = Field(None, description="Filter by node kind to disambiguate")
+     direction: Literal["upstream", "downstream", "both"] = Field(
+         "downstream",
+         description="'upstream', 'downstream', or 'both'",
+     )
+     max_depth: int = Field(
+         3,
+         description="Maximum hops to traverse (default 3, max 6)",
+         ge=1,
+         le=6,
+     )
+     repo: str | None = Field(
+         None,
+         description="Filter by repo name. Omit to trace across all repos.",
+     )
+     include_snippets: bool = Field(
+         False,
+         description="Include source code snippets (default false for trace, can be large)",
+     )
+     limit: int = Field(100, description="Max results (default 100)", ge=1, le=500)
+
+
+ @mcp.tool(
+     name="trace_dependencies",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def trace_dependencies(params: TraceDependenciesInput) -> dict:
+     """Trace multi-hop dependency chains through the SQL graph.
+
+     Follows table → view → CTE → query chains. Use for impact analysis:
+     "if I change this table, what models break?"
+     """
+     return await asyncio.to_thread(
+         _get_graph().query_trace,
+         name=params.name,
+         kind=params.kind,
+         direction=params.direction,
+         max_depth=params.max_depth,
+         repo=params.repo,
+         include_snippets=params.include_snippets,
+         limit=params.limit,
+     )
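+
+ # Example impact question (hypothetical table name): "what breaks if
+ # raw_orders changes?"
+ #   trace_dependencies({"name": "raw_orders", "direction": "downstream",
+ #                       "max_depth": 3})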
+
+
+ class TraceColumnLineageInput(BaseModel):
+     table: str | None = Field(
+         None,
+         description="Source or intermediate table name to trace lineage for",
+     )
+     column: str | None = Field(
+         None,
+         description="Column name to trace",
+     )
+     output_node: str | None = Field(
+         None,
+         description="Output entity name (table/view/query) to trace lineage from",
+     )
+     repo: str | None = Field(None, description="Filter by repo name. Omit to search all repos.")
+     limit: int = Field(100, description="Max lineage chains to return (default 100)", ge=1, le=500)
+     offset: int = Field(0, description="Number of chains to skip for pagination (default 0)", ge=0)
+
+
+ @mcp.tool(
+     name="trace_column_lineage",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def trace_column_lineage(params: TraceColumnLineageInput) -> dict:
+     """Trace end-to-end column lineage through CTEs and subqueries.
+
+     Shows how an output column traces back to source table columns, with
+     each intermediate hop (CTE, subquery) and any transforms (CAST, etc.).
+
+     Answers: "where does dim_users.created_date come from?",
+     "which output columns depend on orders.amount?"
+
+     Note: SELECT * lineage requires a schema catalog built from prior column
+     usage data. On a fresh index, SELECT * columns may not be expanded.
+     Run a second full reindex to populate the catalog and resolve them.
+     """
+     return await asyncio.to_thread(
+         _get_graph().query_column_lineage,
+         table=params.table,
+         column=params.column,
+         output_node=params.output_node,
+         repo=params.repo,
+         limit=params.limit,
+         offset=params.offset,
+     )
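+
+ # Illustrative lineage chain (shape and names invented for this comment):
+ #   raw.orders.amount -> cte order_totals (SUM(amount) AS total) -> dim_orders.total
+ # Query it from either end:
+ #   trace_column_lineage({"table": "orders", "column": "amount"})
+ #   trace_column_lineage({"output_node": "dim_orders"})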
+
+
+ class PrImpactInput(BaseModel):
+     base_commit: str = Field(
+         ...,
+         description="Git commit hash or ref to compare against (e.g., 'main', 'abc123f')",
+     )
+     repo: str | None = Field(
+         None,
+         description="Repo to analyse. Required if multiple repos configured.",
+     )
+     max_blast_radius_depth: int = Field(
+         3,
+         description="Hops to trace from changed nodes (default 3)",
+         ge=1,
+         le=6,
+     )
+     compare_mode: Literal["delta", "absolute"] = Field(
+         "delta",
+         description=(
+             "'delta' = net-new impact vs base (default), "
+             "'absolute' = total blast radius (v1 behavior)"
+         ),
+     )
+
+
+ @mcp.tool(
+     name="pr_impact",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": False,
+         "openWorldHint": False,
+     },
+ )
+ async def pr_impact(params: PrImpactInput) -> dict:
+     """Analyse the structural impact of SQL changes since a base commit.
+
+     Computes a structural diff (added/removed/modified tables, views, CTEs,
+     column usage), then traces the blast radius through the full index.
+
+     **Delta mode caveat:** ``compare_mode="delta"`` shows **net-new downstream
+     impact** by approximating the base blast radius via edge exclusion on the
+     HEAD graph. It does **not** detect reduced blast radius from removed edges
+     — ``no_longer_affected`` will be empty when a PR only removes dependencies.
+     Use ``compare_mode="absolute"`` for a full picture when edge removals are
+     the primary change.
+     """
+     state = _get_state()
+     indexer = state.indexer
+     graph = state.graph
+     config = state.config
+
+     # Determine which repo
+     if params.repo:
+         path, dialect, dialect_overrides = _resolve_repo_config(params.repo)
+         repo_path = Path(path)
+     elif len(config["repos"]) == 1:
+         repo_name = next(iter(config["repos"]))
+         path, dialect, dialect_overrides = _resolve_repo_config(repo_name)
+         repo_path = Path(path)
+     else:
+         return {"error": "Multiple repos configured — specify which repo to analyse."}
+
+     def _blocking_pr_impact() -> dict:
+         changed_files = indexer.get_changed_files(repo_path, params.base_commit)
+         if not changed_files:
+             return {"files_changed": [], "structural_diff": {}, "blast_radius": {}}
+
+         old_results: dict[str, ParseResult] = {}
+         new_results: dict[str, ParseResult] = {}
+
+         for file_path in changed_files:
+             full_path = repo_path / file_path
+             if full_path.exists() and is_sql_file(file_path):
+                 content = full_path.read_text(errors="replace")
+                 new_results[file_path] = indexer.parse_file(file_path, content, dialect)
+
+             old = indexer.parse_file_at_commit(repo_path, file_path, params.base_commit, dialect)
+             if old:
+                 old_results[file_path] = old
+
+         diff = _compute_structural_diff(old_results, new_results)
+
+         affected_node_names = (
+             [n["name"] for n in diff["nodes_added"]]
+             + [n["name"] for n in diff["nodes_removed"]]
+             + [n["name"] for n in diff["nodes_modified"]]
+         )
+
+         # Names of truly new nodes (no base trace needed for these)
+         added_names = {n["name"] for n in diff["nodes_added"]}
+
+         # Build exclude set: edges added in HEAD that did not exist at base
+         edges_added_set: set[tuple[str, str]] = {
+             (e["source"], e["target"]) for e in diff.get("edges_added", [])
+         }
+
+         is_delta = params.compare_mode == "delta"
+
+         blast_radius: dict = {}
+         if affected_node_names:
+             head_affected: set[tuple[str, str]] = set()
+             base_affected: set[tuple[str, str]] = set()
+             all_head_paths: list[dict] = []  # flat list for repo counting
+             repos_hit: set[str] = set()
+             truncated = len(affected_node_names) > 20
+
+             affected_node_names.sort()
+             for node_name in affected_node_names[:20]:
+                 # HEAD blast radius (current graph)
+                 head_trace = graph.query_trace(
+                     name=node_name,
+                     direction="downstream",
+                     max_depth=params.max_blast_radius_depth,
+                 )
+                 head_paths = head_trace.get("paths", [])
+                 head_affected.update((p["name"], p["kind"]) for p in head_paths)
+                 all_head_paths.extend(head_paths)
+                 repos_hit.update(head_trace.get("repos_affected", []))
+
+                 # Base blast radius approximation (exclude new edges)
+                 if is_delta and node_name not in added_names:
+                     base_trace = graph.query_trace(
+                         name=node_name,
+                         direction="downstream",
+                         max_depth=params.max_blast_radius_depth,
+                         exclude_edges=edges_added_set,
+                     )
+                     base_affected.update((p["name"], p["kind"]) for p in base_trace.get("paths", []))
+
+             if is_delta:
+                 newly_affected = head_affected - base_affected
+                 no_longer_affected = base_affected - head_affected
+
+                 blast_radius = {
+                     "compare_mode": "delta",
+                     "head_total": len(head_affected),
+                     "base_total": len(base_affected),
+                     "delta": len(head_affected) - len(base_affected),
+                     "newly_affected": [{"name": n, "kind": k} for n, k in sorted(newly_affected)],
+                     "no_longer_affected": [{"name": n, "kind": k} for n, k in sorted(no_longer_affected)],
+                     "unchanged_affected": len(head_affected & base_affected),
+                     "note": (
+                         "Delta mode approximates the base blast radius by "
+                         "excluding newly-added edges from the HEAD graph. "
+                         "It shows net-new downstream impact but does not "
+                         "detect reduced blast radius from removed edges."
+                     ),
+                     # Backward-compat fields
+                     "transitively_affected": len(head_affected),
+                     "repos_affected": sorted(repos_hit),
+                     "truncated": truncated,
+                     "total_affected_nodes": len(affected_node_names),
+                 }
+             else:
+                 # Absolute mode (v1 behavior)
+                 blast_radius = {
+                     "compare_mode": "absolute",
+                     "transitively_affected": len(all_head_paths),
+                     "affected_by_repo": {
+                         r: sum(1 for a in all_head_paths if a.get("repo") == r) for r in repos_hit
+                     },
+                     "repos_affected": sorted(repos_hit),
+                     "truncated": truncated,
+                     "total_affected_nodes": len(affected_node_names),
+                 }
+
+             if truncated:
+                 blast_radius["truncation_message"] = (
+                     f"Blast radius incomplete — {len(affected_node_names)} affected nodes, "
+                     "only first 20 traced. Use trace_dependencies "
+                     "on specific nodes for the full picture."
+                 )
+
+         return {
+             "files_changed": changed_files,
+             "structural_diff": diff,
+             "blast_radius": blast_radius,
+         }
+
+     return await asyncio.to_thread(_blocking_pr_impact)
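+
+ # Sketch of a delta-mode result (field names come from the code above;
+ # all values are invented):
+ #   {"files_changed": ["models/orders.sql"],
+ #    "structural_diff": {"nodes_modified": [{"name": "orders", "kind": "table"}], ...},
+ #    "blast_radius": {"compare_mode": "delta", "head_total": 12, "base_total": 9,
+ #                     "delta": 3, "newly_affected": [...], ...}}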
+
+
+ # ── Index management tools ──
+
+
+ class ReindexInput(BaseModel):
+     repo: str | None = Field(None, description="Specific repo to reindex. Omit for all repos.")
+
+
+ @mcp.tool(
+     name="reindex",
+     annotations={
+         "readOnlyHint": False,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def reindex(params: ReindexInput) -> dict:
+     """Trigger a reindex of SQL files. Checksums and re-parses only what changed.
+
+     Runs in the background so queries remain available during reindex.
+     Supports per-repo SQL dialects and path-based dialect overrides.
+     """
+     global _reindex_task, _reindex_status
+
+     async with _reindex_lock:
+         # If already running, return status
+         if _reindex_task and not _reindex_task.done():
+             return {"status": "in_progress", **_reindex_status}
+
+         state = _get_state()
+         indexer = state.indexer
+
+         repos = state.config["repos"]
+         if params.repo:
+             if params.repo not in repos:
+                 return {"error": f"Repo '{params.repo}' not found in config"}
+             repos = {params.repo: repos[params.repo]}
+
+         repo_names = list(repos.keys())
+         _reindex_status = {
+             "state": "started",
+             "started_at": datetime.now().isoformat(),
+             "repos": repo_names,
+         }
+
+         async def _background_reindex():
+             global _reindex_status
+             try:
+
+                 def _blocking():
+                     global _reindex_status
+                     results = {}
+                     for name in repos:
+                         _reindex_status = {**_reindex_status, "current_repo": name}
+                         path, dialect, dialect_overrides = _resolve_repo_config(name)
+                         results[name] = indexer.reindex_repo(
+                             name,
+                             path,
+                             dialect=dialect,
+                             dialect_overrides=dialect_overrides,
+                         )
+                     return results
+
+                 result = await asyncio.to_thread(_blocking)
+                 global _last_parse_errors
+                 all_errors = []
+                 for repo_result in result.values():
+                     all_errors.extend(repo_result.get("parse_errors", []))
+                 _last_parse_errors = all_errors
+                 _reindex_status = {
+                     **_reindex_status,
+                     "state": "completed",
+                     "completed_at": datetime.now().isoformat(),
+                     "result": result,
+                 }
+                 return result
+             except Exception as e:
+                 _reindex_status = {
+                     **_reindex_status,
+                     "state": "failed",
+                     "error": str(e),
+                     "failed_at": datetime.now().isoformat(),
+                 }
+
+         _reindex_task = asyncio.create_task(_background_reindex())
+
+     return {
+         "status": "started",
+         "message": (
+             "Reindex running in background. Queries remain available. "
+             "Call index_status to check progress."
+         ),
+         "repos": repo_names,
+     }
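+
+ # Typical flow (illustrative; shown as direct in-process calls):
+ #   await reindex(ReindexInput())   -> {"status": "started", ...}
+ #   await index_status()            -> {..., "reindex_in_progress": True, ...}
+ #   ...later...                     -> {..., "last_reindex": {"state": "completed", ...}}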
+
+
+ class ReindexSqlmeshInput(BaseModel):
+     name: str = Field(..., description="Repo name for the index")
+     project_path: str = Field(
+         ...,
+         description="Path to sqlmesh project dir (containing config.yaml)",
+     )
+     env_file: str | None = Field(
+         None,
+         description="Path to .env file for sqlmesh config variables",
+     )
+     dialect: str = Field(
+         "athena",
+         description="SQL dialect for rendering (default: athena)",
+     )
+     variables: dict[str, str] | None = Field(
+         None,
+         description='SQLMesh variables, e.g. {"GRACE_PERIOD": "7"}',
+     )
+     sqlmesh_command: str = Field(
+         "uv run python",
+         description="Command to run python in sqlmesh venv",
+     )
+
+
+ @mcp.tool(
+     name="reindex_sqlmesh",
+     annotations={
+         "readOnlyHint": False,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def reindex_sqlmesh(params: ReindexSqlmeshInput) -> dict:
+     """Index a sqlmesh project by rendering all models into clean SQL.
+
+     Runs in the background so queries remain available during reindex.
+     Uses sqlmesh's rendering engine to expand macros and resolve variables,
+     then parses with sqlglot to extract tables, CTEs, edges, and column lineage.
+     """
+     global _reindex_task, _reindex_status
+
+     async with _reindex_lock:
+         # If already running, return status
+         if _reindex_task and not _reindex_task.done():
+             return {"status": "in_progress", **_reindex_status}
+
+         indexer = _get_indexer()
+
+         var_dict: dict[str, str | int] = {}
+         if params.variables:
+             for k, v in params.variables.items():
+                 try:
+                     var_dict[k] = int(v)
+                 except ValueError:
+                     var_dict[k] = v
+
+         _reindex_status = {
+             "state": "started",
+             "started_at": datetime.now().isoformat(),
+             "repos": [params.name],
+             "tool": "reindex_sqlmesh",
+         }
+
+         async def _background_reindex():
+             global _reindex_status
+             try:
+                 result = await asyncio.to_thread(
+                     indexer.reindex_sqlmesh,
+                     repo_name=params.name,
+                     project_path=params.project_path,
+                     env_file=params.env_file,
+                     variables=var_dict,
+                     dialect=params.dialect,
+                     sqlmesh_command=params.sqlmesh_command,
+                 )
+                 global _last_parse_errors
+                 if isinstance(result, dict):
+                     _last_parse_errors = result.get("parse_errors", [])
+                 _reindex_status = {
+                     **_reindex_status,
+                     "state": "completed",
+                     "completed_at": datetime.now().isoformat(),
+                     "result": result,
+                 }
+                 return result
+             except Exception as e:
+                 _reindex_status = {
+                     **_reindex_status,
+                     "state": "failed",
+                     "error": str(e),
+                     "failed_at": datetime.now().isoformat(),
+                 }
+
+         _reindex_task = asyncio.create_task(_background_reindex())
+
+     return {
+         "status": "started",
+         "message": (
+             "SQLMesh reindex running in background. Queries remain available. "
+             "Call index_status to check progress."
+         ),
+         "repos": [params.name],
+     }
+
+
+ class ReindexDbtInput(BaseModel):
+     name: str = Field(..., description="Repo name for the index")
+     project_path: str = Field(
+         ...,
+         description="Path to dbt project dir (containing dbt_project.yml)",
+     )
+     profiles_dir: str | None = Field(
+         None,
+         description="Path to directory containing profiles.yml",
+     )
+     env_file: str | None = Field(
+         None,
+         description="Path to .env file for dbt connection variables",
+     )
+     target: str | None = Field(None, description="dbt target name")
+     dbt_command: str = Field(
+         "uv run dbt",
+         description="Command to invoke dbt",
+     )
+     dialect: str | None = Field(
+         None,
+         description="SQL dialect for parsing (e.g. 'starrocks', 'mysql', 'postgres')",
+     )
+
+
+ @mcp.tool(
+     name="reindex_dbt",
+     annotations={
+         "readOnlyHint": False,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def reindex_dbt(params: ReindexDbtInput) -> dict:
+     """Index a dbt project by compiling all models into clean SQL.
+
+     Runs in the background so queries remain available during reindex.
+     Runs `dbt compile`, then parses with sqlglot to extract tables, CTEs,
+     edges, and column lineage with transforms.
+     """
+     global _reindex_task, _reindex_status
+
+     async with _reindex_lock:
+         # If already running, return status
+         if _reindex_task and not _reindex_task.done():
+             return {"status": "in_progress", **_reindex_status}
+
+         indexer = _get_indexer()
+
+         _reindex_status = {
+             "state": "started",
+             "started_at": datetime.now().isoformat(),
+             "repos": [params.name],
+             "tool": "reindex_dbt",
+         }
+
+         async def _background_reindex():
+             global _reindex_status
+             try:
+                 result = await asyncio.to_thread(
+                     indexer.reindex_dbt,
+                     repo_name=params.name,
+                     project_path=params.project_path,
+                     profiles_dir=params.profiles_dir,
+                     env_file=params.env_file,
+                     target=params.target,
+                     dbt_command=params.dbt_command,
+                     dialect=params.dialect,
+                 )
+                 global _last_parse_errors
+                 if isinstance(result, dict):
+                     _last_parse_errors = result.get("parse_errors", [])
+                 _reindex_status = {
+                     **_reindex_status,
+                     "state": "completed",
+                     "completed_at": datetime.now().isoformat(),
+                     "result": result,
+                 }
+                 return result
+             except Exception as e:
+                 _reindex_status = {
+                     **_reindex_status,
+                     "state": "failed",
+                     "error": str(e),
+                     "failed_at": datetime.now().isoformat(),
+                 }
+
+         _reindex_task = asyncio.create_task(_background_reindex())
+
+     return {
+         "status": "started",
+         "message": (
+             "dbt reindex running in background. Queries remain available. "
+             "Call index_status to check progress."
+         ),
+         "repos": [params.name],
+     }
+
+
+ @mcp.tool(
+     name="index_status",
+     annotations={
+         "readOnlyHint": True,
+         "destructiveHint": False,
+         "idempotentHint": True,
+         "openWorldHint": False,
+     },
+ )
+ async def index_status() -> dict:
+     """Current state of the index — repos, file counts, last commit, staleness."""
+     status = await asyncio.to_thread(_get_graph().get_index_status)
+     if _reindex_task and not _reindex_task.done():
+         status["reindex_in_progress"] = True
+         status["reindex_status"] = _reindex_status
+     elif _reindex_status.get("state") in ("completed", "failed"):
+         status["last_reindex"] = _reindex_status
+     status["parse_error_count"] = len(_last_parse_errors)
+     if _last_parse_errors:
+         status["last_parse_errors"] = _last_parse_errors[:50]  # cap at 50
+     return status
+
+
+ # ── Internal helpers ──
+
+
+ def _node_fingerprint(node: NodeResult) -> str:
+     """Create a comparable fingerprint for a node including metadata."""
+     return json.dumps(node.metadata, sort_keys=True) if node.metadata else ""
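+
+ # E.g. metadata {"schema": "staging", "materialized": "view"} fingerprints to
+ # '{"materialized": "view", "schema": "staging"}'; sort_keys makes the
+ # comparison insensitive to key order. (Metadata values here are illustrative.)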
+
+
+ def _compute_structural_diff(
+     old_results: dict[str, ParseResult],
+     new_results: dict[str, ParseResult],
+ ) -> dict:
+     """Compare old and new parse results to find structural changes."""
+     old_nodes = set()
+     new_nodes = set()
+     old_edges = set()
+     new_edges = set()
+     old_columns = set()
+     new_columns = set()
+
+     # Track edges, columns, and metadata per node to detect actual modifications
+     old_node_edges: dict[tuple[str, str, str | None], set] = {}
+     new_node_edges: dict[tuple[str, str, str | None], set] = {}
+     old_node_columns: dict[tuple[str, str, str | None], set] = {}
+     new_node_columns: dict[tuple[str, str, str | None], set] = {}
+     old_node_meta: dict[tuple[str, str, str | None], str] = {}
+     new_node_meta: dict[tuple[str, str, str | None], str] = {}
+
+     # Build (name, kind) -> schema lookups so edges/columns can resolve schema
+     old_schema_lookup: dict[tuple[str, str], str | None] = {}
+     new_schema_lookup: dict[tuple[str, str], str | None] = {}
+
+     for result in old_results.values():
+         for n in result.nodes:
+             schema = (n.metadata or {}).get("schema")
+             key = (n.name, n.kind, schema)
+             old_nodes.add(key)
+             old_node_edges.setdefault(key, set())
+             old_node_columns.setdefault(key, set())
+             old_node_meta[key] = _node_fingerprint(n)
+             old_schema_lookup[(n.name, n.kind)] = schema
+         for e in result.edges:
+             edge_tuple = (
+                 e.source_name,
+                 e.source_kind,
+                 e.target_name,
+                 e.target_kind,
+                 e.relationship,
+             )
+             old_edges.add(edge_tuple)
+             src_schema = old_schema_lookup.get((e.source_name, e.source_kind))
+             src_key = (e.source_name, e.source_kind, src_schema)
+             old_node_edges.setdefault(src_key, set()).add(edge_tuple)
+         for c in result.column_usage:
+             col_tuple = (c.node_name, c.table_name, c.column_name, c.usage_type)
+             old_columns.add(col_tuple)
+             col_schema = old_schema_lookup.get((c.node_name, c.node_kind))
+             col_key = (c.node_name, c.node_kind, col_schema)
+             old_node_columns.setdefault(col_key, set()).add(col_tuple)
+
+     for result in new_results.values():
+         for n in result.nodes:
+             schema = (n.metadata or {}).get("schema")
+             key = (n.name, n.kind, schema)
+             new_nodes.add(key)
+             new_node_edges.setdefault(key, set())
+             new_node_columns.setdefault(key, set())
+             new_node_meta[key] = _node_fingerprint(n)
+             new_schema_lookup[(n.name, n.kind)] = schema
+         for e in result.edges:
+             edge_tuple = (
+                 e.source_name,
+                 e.source_kind,
+                 e.target_name,
+                 e.target_kind,
+                 e.relationship,
+             )
+             new_edges.add(edge_tuple)
+             src_schema = new_schema_lookup.get((e.source_name, e.source_kind))
+             src_key = (e.source_name, e.source_kind, src_schema)
+             new_node_edges.setdefault(src_key, set()).add(edge_tuple)
+         for c in result.column_usage:
+             col_tuple = (c.node_name, c.table_name, c.column_name, c.usage_type)
+             new_columns.add(col_tuple)
+             col_schema = new_schema_lookup.get((c.node_name, c.node_kind))
+             col_key = (c.node_name, c.node_kind, col_schema)
+             new_node_columns.setdefault(col_key, set()).add(col_tuple)
+
+     # A node present in both is "modified" if its edges, columns, or metadata changed
+     nodes_modified = []
+     for key in old_nodes & new_nodes:
+         if (
+             old_node_edges.get(key, set()) != new_node_edges.get(key, set())
+             or old_node_columns.get(key, set()) != new_node_columns.get(key, set())
+             or old_node_meta.get(key, "") != new_node_meta.get(key, "")
+         ):
+             entry = {"name": key[0], "kind": key[1]}
+             if key[2] is not None:
+                 entry["schema"] = key[2]
+             nodes_modified.append(entry)
+
+     def _node_dict(n):
+         d = {"name": n[0], "kind": n[1]}
+         if n[2] is not None:
+             d["schema"] = n[2]
+         return d
+
+     return {
+         "nodes_added": [_node_dict(n) for n in new_nodes - old_nodes],
+         "nodes_removed": [_node_dict(n) for n in old_nodes - new_nodes],
+         "nodes_modified": nodes_modified,
+         "edges_added": [
+             {
+                 "source": e[0],
+                 "source_kind": e[1],
+                 "target": e[2],
+                 "target_kind": e[3],
+                 "relationship": e[4],
+             }
+             for e in new_edges - old_edges
+         ],
+         "edges_removed": [
+             {
+                 "source": e[0],
+                 "source_kind": e[1],
+                 "target": e[2],
+                 "target_kind": e[3],
+                 "relationship": e[4],
+             }
+             for e in old_edges - new_edges
+         ],
+         "columns_added": [
+             {"node": c[0], "table": c[1], "column": c[2], "usage_type": c[3]}
+             for c in new_columns - old_columns
+         ],
+         "columns_removed": [
+             {"node": c[0], "table": c[1], "column": c[2], "usage_type": c[3]}
+             for c in old_columns - new_columns
+         ],
+     }
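+
+ # Sketch of the returned shape (keys from the code above; values invented,
+ # including the "reads_from" relationship label):
+ #   {"nodes_added": [{"name": "dim_pets", "kind": "view", "schema": "marts"}],
+ #    "nodes_removed": [], "nodes_modified": [{"name": "orders", "kind": "table"}],
+ #    "edges_added": [{"source": "dim_pets", "source_kind": "view",
+ #                     "target": "raw_pets", "target_kind": "table",
+ #                     "relationship": "reads_from"}],
+ #    "edges_removed": [], "columns_added": [...], "columns_removed": [...]}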