sqlprism 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlprism/__init__.py +1 -0
- sqlprism/cli.py +625 -0
- sqlprism/core/__init__.py +0 -0
- sqlprism/core/graph.py +1547 -0
- sqlprism/core/indexer.py +677 -0
- sqlprism/core/mcp_tools.py +982 -0
- sqlprism/languages/__init__.py +28 -0
- sqlprism/languages/dbt.py +199 -0
- sqlprism/languages/sql.py +1031 -0
- sqlprism/languages/sqlmesh.py +203 -0
- sqlprism/languages/utils.py +73 -0
- sqlprism/types.py +190 -0
- sqlprism-1.0.0.dist-info/METADATA +429 -0
- sqlprism-1.0.0.dist-info/RECORD +17 -0
- sqlprism-1.0.0.dist-info/WHEEL +4 -0
- sqlprism-1.0.0.dist-info/entry_points.txt +2 -0
- sqlprism-1.0.0.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,982 @@
|
|
|
1
|
+
"""MCP server exposing SQL indexer tools.
|
|
2
|
+
|
|
3
|
+
This is the interface LLMs interact with. Tools are provider-agnostic —
|
|
4
|
+
any MCP client (Claude, Cursor, Continue.dev, etc.) can connect via
|
|
5
|
+
stdio or streamable HTTP.
|
|
6
|
+
|
|
7
|
+
Focused entirely on SQL: tables, views, CTEs, column lineage, transforms,
|
|
8
|
+
WHERE filters, and dependency tracing across dialects.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import json
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
18
|
+
from mcp.server.fastmcp import FastMCP
|
|
19
|
+
from pydantic import BaseModel, Field
|
|
20
|
+
|
|
21
|
+
from sqlprism.core.graph import GraphDB
|
|
22
|
+
from sqlprism.core.indexer import Indexer
|
|
23
|
+
from sqlprism.languages import is_sql_file
|
|
24
|
+
from sqlprism.types import NodeResult, ParseResult, parse_repo_config
|
|
25
|
+
|
|
26
|
+
# ── Server initialisation ──

# Global FastMCP server instance; the @mcp.tool-decorated functions below
# register themselves onto it at import time.
mcp = FastMCP("sqlprism")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class _ServerState:
    """Immutable bundle of server state, swapped atomically on configure()."""

    # Graph database handle used by all read-only query tools.
    graph: GraphDB
    # Indexer bound to ``graph``; performs parsing and index writes.
    indexer: Indexer
    # Raw server configuration: keys "db_path", "repos", "sql_dialect".
    config: dict
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Single atomic reference — readers snapshot this once; no lock needed.
_state: _ServerState | None = None

# Background reindex state — shared across reindex, reindex_dbt, reindex_sqlmesh.
# Only one reindex may run at a time to avoid write-lock conflicts.
_reindex_lock: asyncio.Lock = asyncio.Lock()
# Handle to the currently running (or last) background reindex task, if any.
_reindex_task: asyncio.Task | None = None
# Progress snapshot; replaced wholesale (never mutated in place) so concurrent
# readers always observe a consistent dict.
_reindex_status: dict = {"state": "idle"}
# Parse errors collected from the most recently completed reindex run.
_last_parse_errors: list[str] = []
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def configure(db_path: str | Path, repos: dict, sql_dialect: str | None = None):
    """Initialise the graph and indexer with repo configuration.

    Args:
        db_path: Path to DuckDB file
        repos: {repo_name: path_or_config} — value is either a string path
            or a dict with "path", "dialect", "dialect_overrides" keys
        sql_dialect: Global fallback SQL dialect (overridden by per-repo config)

    Thread-safety: builds a new immutable ``_ServerState`` and swaps it in
    with a single assignment, so concurrent readers never see partial updates.
    """
    global _state

    new_graph = GraphDB(db_path)
    new_indexer = Indexer(new_graph)
    new_config = {
        "db_path": str(db_path),
        "repos": repos,
        "sql_dialect": sql_dialect,
    }

    # Register every repo before the new state becomes visible to readers.
    for repo_name, repo_cfg in repos.items():
        repo_path = repo_cfg["path"] if isinstance(repo_cfg, dict) else repo_cfg
        new_graph.upsert_repo(repo_name, repo_path)

    # Publish with one assignment — readers always see a consistent triple.
    _state = _ServerState(graph=new_graph, indexer=new_indexer, config=new_config)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _get_state() -> _ServerState:
    """Return a snapshot of the current server state.

    Raises:
        RuntimeError: if ``configure()`` has not been called yet.
    """
    snapshot = _state
    if snapshot is None:
        raise RuntimeError("Server not configured. Call configure() first.")
    return snapshot
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _get_graph() -> GraphDB:
    """Shortcut: the graph from the current server state."""
    state = _get_state()
    return state.graph
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _get_indexer() -> Indexer:
    """Shortcut: the indexer from the current server state."""
    state = _get_state()
    return state.indexer
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _resolve_repo_config(repo_name: str) -> tuple[str, str | None, dict[str, str] | None]:
    """Extract (path, dialect, dialect_overrides) from repo config."""
    server_config = _get_state().config
    repo_cfg = server_config["repos"].get(repo_name)
    if repo_cfg is None:
        raise ValueError(f"Repo '{repo_name}' not found in config")
    fallback_dialect = server_config.get("sql_dialect")
    return parse_repo_config(repo_cfg, fallback_dialect)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ── Query tools ──
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class SearchInput(BaseModel):
    # Accept either the field name or its alias (needed for sql_schema/"schema").
    model_config = {"populate_by_name": True}
    pattern: str = Field(
        ...,
        description="Search pattern (partial name match, case-insensitive)",
    )
    kind: str | None = Field(
        None,
        description="Filter by node kind: 'table', 'view', 'cte', 'query'",
    )
    # Named sql_schema to avoid clashing with BaseModel's reserved "schema";
    # exposed to clients under the alias "schema".
    sql_schema: str | None = Field(
        None,
        alias="schema",
        description="Filter by SQL schema name (e.g. 'staging', 'public')",
    )
    repo: str | None = Field(None, description="Filter by repo name. Omit to search all repos.")
    limit: int = Field(20, description="Max results (default 20)", ge=1, le=100)
    offset: int = Field(0, description="Number of results to skip for pagination (default 0)", ge=0)
    include_snippets: bool = Field(True, description="Include source code snippets in results")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@mcp.tool(
    name="search",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def search(params: SearchInput) -> dict:
    """Search for SQL entities by name across the codebase graph.

    Finds tables, views, CTEs, and queries by partial name match.
    Returns matches with name, kind, file path, repo, and line numbers.
    """
    # Collect arguments first, then run the blocking DB query off the event loop.
    query_kwargs = {
        "pattern": params.pattern,
        "kind": params.kind,
        "schema": params.sql_schema,
        "repo": params.repo,
        "limit": params.limit,
        "offset": params.offset,
        "include_snippets": params.include_snippets,
    }
    return await asyncio.to_thread(_get_graph().query_search, **query_kwargs)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class FindReferencesInput(BaseModel):
    # Accept either the field name or its alias (needed for sql_schema/"schema").
    model_config = {"populate_by_name": True}
    name: str = Field(..., description="Entity name (table, view, CTE, etc.)")
    kind: str | None = Field(None, description="Filter by node kind to disambiguate")
    # Named sql_schema to avoid clashing with BaseModel's reserved "schema";
    # exposed to clients under the alias "schema".
    sql_schema: str | None = Field(
        None,
        alias="schema",
        description="Filter by SQL schema name (e.g. 'staging', 'public')",
    )
    repo: str | None = Field(
        None,
        description="Filter by repo name. Omit to search all repos.",
    )
    direction: Literal["both", "inbound", "outbound"] = Field(
        "both",
        description="'inbound', 'outbound', or 'both'",
    )
    include_snippets: bool = Field(True, description="Include source code snippets in results")
    limit: int = Field(100, description="Max results per direction (default 100)", ge=1, le=500)
    offset: int = Field(0, description="Number of results to skip for pagination (default 0)", ge=0)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@mcp.tool(
    name="find_references",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def find_references(params: FindReferencesInput) -> dict:
    """Find everything connected to a named SQL entity.

    Returns both inbound (what depends on this) and outbound (what this depends on)
    relationships. Each result includes: name, kind, relationship type, file path, repo.
    """
    # Collect arguments first, then run the blocking DB query off the event loop.
    query_kwargs = {
        "name": params.name,
        "kind": params.kind,
        "schema": params.sql_schema,
        "repo": params.repo,
        "direction": params.direction,
        "include_snippets": params.include_snippets,
        "limit": params.limit,
        "offset": params.offset,
    }
    return await asyncio.to_thread(_get_graph().query_references, **query_kwargs)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class FindColumnUsageInput(BaseModel):
    table: str = Field(..., description="Table name to search column usage for")
    column: str | None = Field(None, description="Specific column name. Omit for all columns.")
    # Clause-level filter; values mirror the usage types recorded by the indexer.
    usage_type: str | None = Field(
        None,
        description=("Filter: 'select', 'where', 'join_on', 'group_by', 'order_by', 'having', 'insert', 'update'"),
    )
    repo: str | None = Field(None, description="Filter by repo name. Omit to search all repos.")
    limit: int = Field(100, description="Max results (default 100)", ge=1, le=500)
    offset: int = Field(0, description="Number of results to skip for pagination (default 0)", ge=0)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@mcp.tool(
    name="find_column_usage",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def find_column_usage(params: FindColumnUsageInput) -> dict:
    """Find where and how columns are used across SQL models.

    Powered by sqlglot's column lineage analysis. Shows usage type,
    transforms (CAST, COALESCE, etc.), output aliases, and WHERE conditions.

    Answers: "where is customer_id used in WHERE clauses?",
    "how is animal.breed_id transformed?", "show all column usage on orders."
    """
    # Collect arguments first, then run the blocking DB query off the event loop.
    query_kwargs = {
        "table": params.table,
        "column": params.column,
        "usage_type": params.usage_type,
        "repo": params.repo,
        "limit": params.limit,
        "offset": params.offset,
    }
    return await asyncio.to_thread(_get_graph().query_column_usage, **query_kwargs)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class TraceDependenciesInput(BaseModel):
    name: str = Field(..., description="Starting entity name")
    kind: str | None = Field(None, description="Filter by node kind to disambiguate")
    direction: Literal["upstream", "downstream", "both"] = Field(
        "downstream",
        description="'upstream', 'downstream', or 'both'",
    )
    # Depth capped at 6 to keep traversal cost bounded.
    max_depth: int = Field(
        3,
        description="Maximum hops to traverse (default 3, max 6)",
        ge=1,
        le=6,
    )
    repo: str | None = Field(
        None,
        description="Filter by repo name. Omit to trace across all repos.",
    )
    # Snippets default off here (unlike search) since traces can span many nodes.
    include_snippets: bool = Field(
        False,
        description="Include source code snippets (default false for trace, can be large)",
    )
    limit: int = Field(100, description="Max results (default 100)", ge=1, le=500)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@mcp.tool(
    name="trace_dependencies",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def trace_dependencies(params: TraceDependenciesInput) -> dict:
    """Trace multi-hop dependency chains through the SQL graph.

    Follows table → view → CTE → query chains. Use for impact analysis:
    "if I change this table, what models break?"
    """
    # Collect arguments first, then run the blocking DB query off the event loop.
    query_kwargs = {
        "name": params.name,
        "kind": params.kind,
        "direction": params.direction,
        "max_depth": params.max_depth,
        "repo": params.repo,
        "include_snippets": params.include_snippets,
        "limit": params.limit,
    }
    return await asyncio.to_thread(_get_graph().query_trace, **query_kwargs)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class TraceColumnLineageInput(BaseModel):
    # All of table/column/output_node are optional; the graph query decides
    # which combination anchors the lineage search.
    table: str | None = Field(
        None,
        description="Source or intermediate table name to trace lineage for",
    )
    column: str | None = Field(
        None,
        description="Column name to trace",
    )
    output_node: str | None = Field(
        None,
        description="Output entity name (table/view/query) to trace lineage from",
    )
    repo: str | None = Field(None, description="Filter by repo name. Omit to search all repos.")
    limit: int = Field(100, description="Max lineage chains to return (default 100)", ge=1, le=500)
    offset: int = Field(0, description="Number of chains to skip for pagination (default 0)", ge=0)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
@mcp.tool(
    name="trace_column_lineage",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def trace_column_lineage(params: TraceColumnLineageInput) -> dict:
    """Trace end-to-end column lineage through CTEs and subqueries.

    Shows how an output column traces back to source table columns, with
    each intermediate hop (CTE, subquery) and any transforms (CAST, etc.).

    Answers: "where does dim_users.created_date come from?",
    "which output columns depend on orders.amount?"

    Note: SELECT * lineage requires a schema catalog built from prior column
    usage data. On a fresh index, SELECT * columns may not be expanded.
    Run a second full reindex to populate the catalog and resolve them.
    """
    # Collect arguments first, then run the blocking DB query off the event loop.
    query_kwargs = {
        "table": params.table,
        "column": params.column,
        "output_node": params.output_node,
        "repo": params.repo,
        "limit": params.limit,
        "offset": params.offset,
    }
    return await asyncio.to_thread(_get_graph().query_column_lineage, **query_kwargs)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class PrImpactInput(BaseModel):
    base_commit: str = Field(
        ...,
        description="Git commit hash or ref to compare against (e.g., 'main', 'abc123f')",
    )
    repo: str | None = Field(
        None,
        description="Repo to analyse. Required if multiple repos configured.",
    )
    # Depth capped at 6, mirroring trace_dependencies.
    max_blast_radius_depth: int = Field(
        3,
        description="Hops to trace from changed nodes (default 3)",
        ge=1,
        le=6,
    )
    compare_mode: Literal["delta", "absolute"] = Field(
        "delta",
        description=("'delta' = net-new impact vs base (default), 'absolute' = total blast radius (v1 behavior)"),
    )
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
@mcp.tool(
    name="pr_impact",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": False,
        "openWorldHint": False,
    },
)
async def pr_impact(params: PrImpactInput) -> dict:
    """Analyse the structural impact of SQL changes since a base commit.

    Computes structural diff (added/removed/modified tables, views, CTEs,
    column usage) then traces the blast radius through the full index.

    **Delta mode caveat:** ``compare_mode="delta"`` shows **net-new downstream
    impact** by approximating the base blast radius via edge exclusion on the
    HEAD graph. It does **not** detect reduced blast radius from removed edges
    — ``no_longer_affected`` will be empty when a PR only removes dependencies.
    Use ``compare_mode="absolute"`` for a full picture when edge removals are
    the primary change.
    """
    # Snapshot state once so graph/indexer/config stay mutually consistent
    # even if configure() swaps _state while we run.
    state = _get_state()
    indexer = state.indexer
    graph = state.graph
    config = state.config

    # Determine which repo
    if params.repo:
        path, dialect, dialect_overrides = _resolve_repo_config(params.repo)
        repo_path = Path(path)
    elif len(config["repos"]) == 1:
        repo_name = list(config["repos"].keys())[0]
        path, dialect, dialect_overrides = _resolve_repo_config(repo_name)
        repo_path = Path(path)
    else:
        return {"error": "Multiple repos configured — specify which repo to analyse."}

    def _blocking_pr_impact() -> dict:
        # Everything below is synchronous (git + parsing + DB) — run via to_thread.
        changed_files = indexer.get_changed_files(repo_path, params.base_commit)
        if not changed_files:
            return {"files_changed": [], "structural_diff": {}, "blast_radius": {}}

        # Parse each changed file at HEAD and at the base commit.
        old_results: dict[str, ParseResult] = {}
        new_results: dict[str, ParseResult] = {}

        for file_path in changed_files:
            full_path = repo_path / file_path
            if full_path.exists() and is_sql_file(file_path):
                content = full_path.read_text(errors="replace")
                new_results[file_path] = indexer.parse_file(file_path, content, dialect)

            # NOTE(review): base-commit parse sits at loop level (not inside the
            # exists() branch) so files deleted at HEAD still contribute their
            # base version — needed for nodes_removed detection. Confirm against
            # upstream source if exact placement matters.
            old = indexer.parse_file_at_commit(repo_path, file_path, params.base_commit, dialect)
            if old:
                old_results[file_path] = old

        diff = _compute_structural_diff(old_results, new_results)

        # Union of nodes touched in any way — these seed the blast-radius traces.
        affected_node_names = (
            [n["name"] for n in diff["nodes_added"]]
            + [n["name"] for n in diff["nodes_removed"]]
            + [n["name"] for n in diff["nodes_modified"]]
        )

        # Names of truly new nodes (no base trace needed for these)
        added_names = {n["name"] for n in diff["nodes_added"]}

        # Build exclude set: edges added in HEAD that did not exist at base
        edges_added_set: set[tuple[str, str]] = {(e["source"], e["target"]) for e in diff.get("edges_added", [])}

        is_delta = params.compare_mode == "delta"

        blast_radius: dict = {}
        if affected_node_names:
            head_affected: set[tuple[str, str]] = set()
            base_affected: set[tuple[str, str]] = set()
            all_head_paths: list[dict] = []  # flat list for repo counting
            repos_hit: set[str] = set()
            # Hard cap: only the first 20 (sorted) nodes are traced to bound cost.
            truncated = len(affected_node_names) > 20

            affected_node_names.sort()
            for node_name in affected_node_names[:20]:
                # HEAD blast radius (current graph)
                head_trace = graph.query_trace(
                    name=node_name,
                    direction="downstream",
                    max_depth=params.max_blast_radius_depth,
                )
                head_paths = head_trace.get("paths", [])
                head_affected.update((p["name"], p["kind"]) for p in head_paths)
                all_head_paths.extend(head_paths)
                repos_hit.update(head_trace.get("repos_affected", []))

                # Base blast radius approximation (exclude new edges)
                if is_delta and node_name not in added_names:
                    base_trace = graph.query_trace(
                        name=node_name,
                        direction="downstream",
                        max_depth=params.max_blast_radius_depth,
                        exclude_edges=edges_added_set,
                    )
                    base_affected.update((p["name"], p["kind"]) for p in base_trace.get("paths", []))

            if is_delta:
                # Set algebra between the HEAD and approximated-base radii.
                newly_affected = head_affected - base_affected
                no_longer_affected = base_affected - head_affected

                blast_radius = {
                    "compare_mode": "delta",
                    "head_total": len(head_affected),
                    "base_total": len(base_affected),
                    "delta": len(head_affected) - len(base_affected),
                    "newly_affected": [{"name": n, "kind": k} for n, k in sorted(newly_affected)],
                    "no_longer_affected": [{"name": n, "kind": k} for n, k in sorted(no_longer_affected)],
                    "unchanged_affected": len(head_affected & base_affected),
                    "note": (
                        "Delta mode approximates the base blast radius by "
                        "excluding newly-added edges from the HEAD graph. "
                        "It shows net-new downstream impact but does not "
                        "detect reduced blast radius from removed edges."
                    ),
                    # Backward-compat fields
                    "transitively_affected": len(head_affected),
                    "repos_affected": sorted(repos_hit),
                    "truncated": truncated,
                    "total_affected_nodes": len(affected_node_names),
                }
            else:
                # Absolute mode (v1 behavior)
                blast_radius = {
                    "compare_mode": "absolute",
                    "transitively_affected": len(all_head_paths),
                    "affected_by_repo": {r: sum(1 for a in all_head_paths if a.get("repo") == r) for r in repos_hit},
                    "repos_affected": sorted(repos_hit),
                    "truncated": truncated,
                    "total_affected_nodes": len(affected_node_names),
                }

            if truncated:
                blast_radius["truncation_message"] = (
                    f"Blast radius incomplete — {len(affected_node_names)} affected nodes, "
                    "only first 20 traced. Use trace_dependencies "
                    "on specific nodes for full picture."
                )

        return {
            "files_changed": changed_files,
            "structural_diff": diff,
            "blast_radius": blast_radius,
        }

    return await asyncio.to_thread(_blocking_pr_impact)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
# ── Index management tools ──
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
class ReindexInput(BaseModel):
    # Narrowing to one repo skips reindexing of all others.
    repo: str | None = Field(None, description="Specific repo to reindex. Omit for all repos.")
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
@mcp.tool(
    name="reindex",
    annotations={
        "readOnlyHint": False,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def reindex(params: ReindexInput) -> dict:
    """Trigger a reindex of SQL files. Checksums and re-parses only what changed.

    Runs in the background so queries remain available during reindex.
    Supports per-repo SQL dialects and path-based dialect overrides.
    """
    global _reindex_task, _reindex_status

    # The lock serialises tool-level starts; the actual work runs after we
    # return, in the background task created below.
    async with _reindex_lock:
        # If already running, return status
        if _reindex_task and not _reindex_task.done():
            return {"status": "in_progress", **_reindex_status}

        state = _get_state()
        indexer = state.indexer

        repos = state.config["repos"]
        if params.repo:
            if params.repo not in repos:
                return {"error": f"Repo '{params.repo}' not found in config"}
            # Narrow to the single requested repo.
            repos = {params.repo: repos[params.repo]}

        repo_names = list(repos.keys())
        # Replace (not mutate) the status dict so concurrent readers see a
        # consistent snapshot.
        _reindex_status = {
            "state": "started",
            "started_at": datetime.now().isoformat(),
            "repos": repo_names,
        }

        async def _background_reindex():
            global _reindex_status
            try:

                def _blocking():
                    # Runs in a worker thread; only rebinds the module-level
                    # status reference, never mutates the dict in place.
                    global _reindex_status
                    results = {}
                    for name, cfg in repos.items():
                        _reindex_status = {**_reindex_status, "current_repo": name}
                        # NOTE(review): re-resolves via _get_state() at run time;
                        # if configure() swapped state mid-reindex this may read
                        # newer config than the captured `indexer` — confirm.
                        path, dialect, dialect_overrides = _resolve_repo_config(name)
                        results[name] = indexer.reindex_repo(
                            name,
                            path,
                            dialect=dialect,
                            dialect_overrides=dialect_overrides,
                        )
                    return results

                result = await asyncio.to_thread(_blocking)
                global _last_parse_errors
                # Aggregate parse errors across all reindexed repos.
                all_errors = []
                for repo_result in result.values():
                    all_errors.extend(repo_result.get("parse_errors", []))
                _last_parse_errors = all_errors
                _reindex_status = {
                    **_reindex_status,
                    "state": "completed",
                    "completed_at": datetime.now().isoformat(),
                    "result": result,
                }
                return result
            except Exception as e:
                # Record the failure in the shared status; the task itself
                # swallows the exception so it never surfaces unawaited.
                _reindex_status = {
                    **_reindex_status,
                    "state": "failed",
                    "error": str(e),
                    "failed_at": datetime.now().isoformat(),
                }

        _reindex_task = asyncio.create_task(_background_reindex())

        return {
            "status": "started",
            "message": ("Reindex running in background. Queries remain available. Call index_status to check progress."),
            "repos": repo_names,
        }
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
class ReindexSqlmeshInput(BaseModel):
    name: str = Field(..., description="Repo name for the index")
    project_path: str = Field(
        ...,
        description="Path to sqlmesh project dir (containing config.yaml)",
    )
    env_file: str | None = Field(
        None,
        description="Path to .env file for sqlmesh config variables",
    )
    dialect: str = Field(
        "athena",
        description="SQL dialect for rendering (default: athena)",
    )
    # String values that look numeric are coerced to int by the tool.
    variables: dict[str, str] | None = Field(
        None,
        description='SQLMesh variables, e.g. {"GRACE_PERIOD": "7"}',
    )
    # Allows pointing at the project's own virtualenv/runner.
    sqlmesh_command: str = Field(
        "uv run python",
        description="Command to run python in sqlmesh venv",
    )
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
@mcp.tool(
    name="reindex_sqlmesh",
    annotations={
        "readOnlyHint": False,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def reindex_sqlmesh(params: ReindexSqlmeshInput) -> dict:
    """Index a sqlmesh project by rendering all models into clean SQL.

    Runs in the background so queries remain available during reindex.
    Uses sqlmesh's rendering engine to expand macros and resolve variables,
    then parses with sqlglot to extract tables, CTEs, edges, column lineage.
    """
    global _reindex_task, _reindex_status

    # Shares the lock/task/status globals with reindex and reindex_dbt so
    # only one background reindex of any kind runs at a time.
    async with _reindex_lock:
        # If already running, return status
        if _reindex_task and not _reindex_task.done():
            return {"status": "in_progress", **_reindex_status}

        indexer = _get_indexer()

        # Coerce numeric-looking variable values to int; non-numeric strings
        # are passed through unchanged.
        var_dict: dict[str, str | int] = {}
        if params.variables:
            for k, v in params.variables.items():
                try:
                    var_dict[k] = int(v)
                except ValueError:
                    var_dict[k] = v

        _reindex_status = {
            "state": "started",
            "started_at": datetime.now().isoformat(),
            "repos": [params.name],
            "tool": "reindex_sqlmesh",
        }

        async def _background_reindex():
            global _reindex_status
            try:
                result = await asyncio.to_thread(
                    indexer.reindex_sqlmesh,
                    repo_name=params.name,
                    project_path=params.project_path,
                    env_file=params.env_file,
                    variables=var_dict,
                    dialect=params.dialect,
                    sqlmesh_command=params.sqlmesh_command,
                )
                global _last_parse_errors
                if isinstance(result, dict):
                    _last_parse_errors = result.get("parse_errors", [])
                _reindex_status = {
                    **_reindex_status,
                    "state": "completed",
                    "completed_at": datetime.now().isoformat(),
                    "result": result,
                }
                return result
            except Exception as e:
                # Record failure; exception is swallowed by design so the
                # background task never dies unobserved.
                _reindex_status = {
                    **_reindex_status,
                    "state": "failed",
                    "error": str(e),
                    "failed_at": datetime.now().isoformat(),
                }

        _reindex_task = asyncio.create_task(_background_reindex())

        return {
            "status": "started",
            "message": (
                "SQLMesh reindex running in background. Queries remain available. Call index_status to check progress."
            ),
            "repos": [params.name],
        }
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
class ReindexDbtInput(BaseModel):
    name: str = Field(..., description="Repo name for the index")
    project_path: str = Field(
        ...,
        description="Path to dbt project dir (containing dbt_project.yml)",
    )
    profiles_dir: str | None = Field(
        None,
        description="Path to directory containing profiles.yml",
    )
    env_file: str | None = Field(
        None,
        description="Path to .env file for dbt connection variables",
    )
    target: str | None = Field(None, description="dbt target name")
    # Allows pointing at the project's own virtualenv/runner.
    dbt_command: str = Field(
        "uv run dbt",
        description="Command to invoke dbt",
    )
    dialect: str | None = Field(
        None,
        description="SQL dialect for parsing (e.g. 'starrocks', 'mysql', 'postgres')",
    )
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
@mcp.tool(
    name="reindex_dbt",
    annotations={
        "readOnlyHint": False,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def reindex_dbt(params: ReindexDbtInput) -> dict:
    """Index a dbt project by compiling all models into clean SQL.

    Runs in the background so queries remain available during reindex.
    Runs `dbt compile`, then parses with sqlglot to extract tables, CTEs,
    edges, column lineage with transforms.
    """
    global _reindex_task, _reindex_status

    # Shares the lock/task/status globals with reindex and reindex_sqlmesh so
    # only one background reindex of any kind runs at a time.
    async with _reindex_lock:
        # If already running, return status
        if _reindex_task and not _reindex_task.done():
            return {"status": "in_progress", **_reindex_status}

        indexer = _get_indexer()

        _reindex_status = {
            "state": "started",
            "started_at": datetime.now().isoformat(),
            "repos": [params.name],
            "tool": "reindex_dbt",
        }

        async def _background_reindex():
            global _reindex_status
            try:
                result = await asyncio.to_thread(
                    indexer.reindex_dbt,
                    repo_name=params.name,
                    project_path=params.project_path,
                    profiles_dir=params.profiles_dir,
                    env_file=params.env_file,
                    target=params.target,
                    dbt_command=params.dbt_command,
                    dialect=params.dialect,
                )
                global _last_parse_errors
                if isinstance(result, dict):
                    _last_parse_errors = result.get("parse_errors", [])
                _reindex_status = {
                    **_reindex_status,
                    "state": "completed",
                    "completed_at": datetime.now().isoformat(),
                    "result": result,
                }
                return result
            except Exception as e:
                # Record failure; exception is swallowed by design so the
                # background task never dies unobserved.
                _reindex_status = {
                    **_reindex_status,
                    "state": "failed",
                    "error": str(e),
                    "failed_at": datetime.now().isoformat(),
                }

        _reindex_task = asyncio.create_task(_background_reindex())

        return {
            "status": "started",
            "message": (
                "dbt reindex running in background. Queries remain available. Call index_status to check progress."
            ),
            "repos": [params.name],
        }
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
@mcp.tool(
    name="index_status",
    annotations={
        "readOnlyHint": True,
        "destructiveHint": False,
        "idempotentHint": True,
        "openWorldHint": False,
    },
)
async def index_status() -> dict:
    """Current state of the index — repos, file counts, last commit, staleness."""
    status = await asyncio.to_thread(_get_graph().get_index_status)
    reindex_running = bool(_reindex_task) and not _reindex_task.done()
    if reindex_running:
        status.update(reindex_in_progress=True, reindex_status=_reindex_status)
    elif _reindex_status.get("state") in ("completed", "failed"):
        status["last_reindex"] = _reindex_status
    status["parse_error_count"] = len(_last_parse_errors)
    if _last_parse_errors:
        # Cap at 50 entries so the response stays compact.
        status["last_parse_errors"] = _last_parse_errors[:50]
    return status
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
# ── Internal helpers ──
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def _node_fingerprint(node: NodeResult) -> str:
    """Create a comparable fingerprint for a node including metadata."""
    meta = node.metadata
    if not meta:
        return ""
    # Canonical key order so equal metadata always serializes identically.
    return json.dumps(meta, sort_keys=True)
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
def _compute_structural_diff(
    old_results: dict[str, ParseResult],
    new_results: dict[str, ParseResult],
) -> dict:
    """Compare old and new parse results to find structural changes.

    Each side is flattened into comparable sets: node keys are
    ``(name, kind, schema)``, edge keys are 5-tuples, and column-usage keys
    are 4-tuples. Set differences produce the added/removed entries. A node
    present on both sides is reported as *modified* when its outgoing edges,
    its column usage, or its metadata fingerprint changed.

    Args:
        old_results: Mapping of file path -> ParseResult from the old index.
        new_results: Mapping of file path -> ParseResult from the new parse.

    Returns:
        Dict with keys ``nodes_added``, ``nodes_removed``, ``nodes_modified``,
        ``edges_added``, ``edges_removed``, ``columns_added``,
        ``columns_removed``.
    """

    def _collect(results):
        """Flatten one side's ParseResults into comparable sets and per-node maps.

        Previously this loop was duplicated verbatim for the old and new
        sides; extracting it keeps the two sides guaranteed-consistent.
        """
        nodes: set = set()
        edges: set = set()
        columns: set = set()
        node_edges: dict[tuple[str, str, str | None], set] = {}
        node_columns: dict[tuple[str, str, str | None], set] = {}
        node_meta: dict[tuple[str, str, str | None], str] = {}
        # (name, kind) -> schema so edges/columns can resolve their node's schema.
        schema_lookup: dict[tuple[str, str], str | None] = {}

        for result in results.values():
            for n in result.nodes:
                schema = (n.metadata or {}).get("schema")
                key = (n.name, n.kind, schema)
                nodes.add(key)
                node_edges.setdefault(key, set())
                node_columns.setdefault(key, set())
                node_meta[key] = _node_fingerprint(n)
                schema_lookup[(n.name, n.kind)] = schema
            for e in result.edges:
                edge_tuple = (
                    e.source_name,
                    e.source_kind,
                    e.target_name,
                    e.target_kind,
                    e.relationship,
                )
                edges.add(edge_tuple)
                src_schema = schema_lookup.get((e.source_name, e.source_kind))
                node_edges.setdefault(
                    (e.source_name, e.source_kind, src_schema), set()
                ).add(edge_tuple)
            for c in result.column_usage:
                col_tuple = (c.node_name, c.table_name, c.column_name, c.usage_type)
                columns.add(col_tuple)
                col_schema = schema_lookup.get((c.node_name, c.node_kind))
                node_columns.setdefault(
                    (c.node_name, c.node_kind, col_schema), set()
                ).add(col_tuple)

        return nodes, edges, columns, node_edges, node_columns, node_meta

    old_nodes, old_edges, old_columns, old_node_edges, old_node_columns, old_meta = (
        _collect(old_results)
    )
    new_nodes, new_edges, new_columns, new_node_edges, new_node_columns, new_meta = (
        _collect(new_results)
    )

    def _node_dict(n):
        """Serialize a (name, kind, schema) node key, omitting a null schema."""
        d = {"name": n[0], "kind": n[1]}
        if n[2] is not None:
            d["schema"] = n[2]
        return d

    def _edge_dict(e):
        """Serialize a 5-tuple edge key."""
        return {
            "source": e[0],
            "source_kind": e[1],
            "target": e[2],
            "target_kind": e[3],
            "relationship": e[4],
        }

    def _col_dict(c):
        """Serialize a 4-tuple column-usage key."""
        return {"node": c[0], "table": c[1], "column": c[2], "usage_type": c[3]}

    # A node present in both is "modified" if its edges, columns, or metadata changed.
    nodes_modified = [
        _node_dict(key)
        for key in old_nodes & new_nodes
        if old_node_edges.get(key, set()) != new_node_edges.get(key, set())
        or old_node_columns.get(key, set()) != new_node_columns.get(key, set())
        or old_meta.get(key, "") != new_meta.get(key, "")
    ]

    return {
        "nodes_added": [_node_dict(n) for n in new_nodes - old_nodes],
        "nodes_removed": [_node_dict(n) for n in old_nodes - new_nodes],
        "nodes_modified": nodes_modified,
        "edges_added": [_edge_dict(e) for e in new_edges - old_edges],
        "edges_removed": [_edge_dict(e) for e in old_edges - new_edges],
        "columns_added": [
            _col_dict(c) for c in new_columns - old_columns
        ],
        "columns_removed": [
            _col_dict(c) for c in old_columns - new_columns
        ],
    }
|