sql-glider 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,349 @@
1
+ """Graph builder for constructing lineage graphs from SQL files."""
2
+
3
+ from datetime import datetime, timezone
4
+ from pathlib import Path
5
+ from typing import Callable, Dict, List, Optional, Set
6
+
7
+ import rustworkx as rx
8
+ from rich.console import Console
9
+ from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn
10
+
11
+ from sqlglider.global_models import AnalysisLevel, NodeFormat
12
+ from sqlglider.graph.models import (
13
+ GraphEdge,
14
+ GraphMetadata,
15
+ GraphNode,
16
+ LineageGraph,
17
+ Manifest,
18
+ )
19
+ from sqlglider.lineage.analyzer import LineageAnalyzer
20
+ from sqlglider.utils.file_utils import read_sql_file
21
+
22
# Console bound to stderr so warnings/progress never pollute stdout,
# which may carry the tool's machine-readable output.
console = Console(stderr=True)

# Type alias for SQL preprocessor functions.
# Signature: (sql_text, file_path) -> transformed sql_text, e.g. a Jinja2 render.
SqlPreprocessor = Callable[[str, Path], str]
26
+
27
+
28
class GraphBuilder:
    """Build lineage graphs from SQL files using rustworkx.

    Identifiers are lowercased before node/edge insertion, so deduplication
    is case-insensitive. Files whose SQL cannot be analyzed at all are
    recorded (see ``skipped_files``) and reported with a warning instead of
    aborting the whole build.
    """

    def __init__(
        self,
        node_format: NodeFormat = NodeFormat.QUALIFIED,
        dialect: str = "spark",
        sql_preprocessor: Optional[SqlPreprocessor] = None,
    ) -> None:
        """
        Initialize the graph builder.

        Args:
            node_format: Format for node identifiers (QUALIFIED or STRUCTURED)
            dialect: Default SQL dialect (used when not specified per-file)
            sql_preprocessor: Optional function to preprocess SQL before analysis.
                Takes (sql: str, file_path: Path) and returns processed SQL.
                Useful for templating (e.g., Jinja2 rendering).
        """
        self.node_format = node_format
        self.dialect = dialect
        self.sql_preprocessor = sql_preprocessor
        self.graph: rx.PyDiGraph = rx.PyDiGraph()
        self._node_index_map: Dict[str, int] = {}  # identifier -> rustworkx node index
        # Resolved paths of files that were analyzed (added even if they yield no edges).
        self._source_files: Set[str] = set()
        self._edge_set: Set[tuple] = set()  # (source, target) for dedup
        self._skipped_files: List[tuple[str, str]] = []  # (file_path, reason)

    def add_file(
        self,
        file_path: Path,
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Add lineage from a single SQL file to the graph.

        Args:
            file_path: Path to SQL file
            dialect: SQL dialect (uses builder default if not specified)

        Returns:
            self for method chaining

        Raises:
            FileNotFoundError: If file doesn't exist
            ParseError: If SQL cannot be parsed
        """
        file_dialect = dialect or self.dialect
        file_path_str = str(file_path.resolve())

        try:
            sql_content = read_sql_file(file_path)

            # Apply SQL preprocessor if configured (e.g., for templating)
            if self.sql_preprocessor:
                sql_content = self.sql_preprocessor(sql_content, file_path)

            analyzer = LineageAnalyzer(sql_content, dialect=file_dialect)
            results = analyzer.analyze_queries(level=AnalysisLevel.COLUMN)

            # Print warnings for any skipped queries within the file
            for skipped in analyzer.skipped_queries:
                console.print(
                    f"[yellow]Warning:[/yellow] Skipping query {skipped.query_index} "
                    f"in {file_path.name} ({skipped.statement_type}): {skipped.reason}"
                )

            self._source_files.add(file_path_str)

            for result in results:
                query_index = result.metadata.query_index

                for item in result.lineage_items:
                    if not item.source_name:  # Skip empty sources
                        continue

                    # Add/get nodes (lowercased inside _ensure_node)
                    source_node_idx = self._ensure_node(
                        item.source_name,
                        file_path_str,
                        query_index,
                    )
                    target_node_idx = self._ensure_node(
                        item.output_name,
                        file_path_str,
                        query_index,
                    )

                    # Add edge (source contributes_to target) - deduplicate
                    # globally across all files added to this builder.
                    edge_key = (item.source_name.lower(), item.output_name.lower())
                    if edge_key not in self._edge_set:
                        edge = GraphEdge(
                            source_node=item.source_name.lower(),
                            target_node=item.output_name.lower(),
                            file_path=file_path_str,
                            query_index=query_index,
                        )
                        self.graph.add_edge(
                            source_node_idx, target_node_idx, edge.model_dump()
                        )
                        self._edge_set.add(edge_key)

        except ValueError as e:
            # Skip files that fail completely (all statements unsupported);
            # the reason is kept for the end-of-build summary.
            error_msg = str(e)
            self._skipped_files.append((file_path_str, error_msg))
            console.print(
                f"[yellow]Warning:[/yellow] Skipping {file_path.name}: {error_msg}"
            )

        return self

    def add_directory(
        self,
        dir_path: Path,
        recursive: bool = False,
        glob_pattern: str = "*.sql",
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Add lineage from all SQL files in a directory.

        Files are processed in sorted order so builds are deterministic.

        Args:
            dir_path: Path to directory
            recursive: Whether to search recursively
            glob_pattern: Glob pattern for SQL files
            dialect: SQL dialect (uses builder default if not specified)

        Returns:
            self for method chaining

        Raises:
            ValueError: If path is not a directory
        """
        if not dir_path.is_dir():
            raise ValueError(f"Not a directory: {dir_path}")

        if recursive:
            pattern = f"**/{glob_pattern}"
        else:
            pattern = glob_pattern

        # is_file() filters out directories that happen to match the glob.
        sql_files = [f for f in sorted(dir_path.glob(pattern)) if f.is_file()]
        return self.add_files(sql_files, dialect)

    def add_manifest(
        self,
        manifest_path: Path,
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Add lineage from files specified in a manifest CSV.

        Args:
            manifest_path: Path to manifest CSV file
            dialect: Default SQL dialect (overridden by manifest entries)

        Returns:
            self for method chaining

        Raises:
            FileNotFoundError: If manifest or referenced files don't exist
            ValueError: If manifest format is invalid
        """
        manifest = Manifest.from_csv(manifest_path)
        base_dir = manifest_path.parent

        # Collect files with their dialects
        files_with_dialects: List[tuple[Path, str]] = []
        for entry in manifest.entries:
            # Resolve file path relative to manifest location
            file_path = Path(entry.file_path)
            if not file_path.is_absolute():
                file_path = (base_dir / entry.file_path).resolve()

            # Use entry dialect, then CLI dialect, then builder default
            entry_dialect = entry.dialect or dialect or self.dialect
            files_with_dialects.append((file_path, entry_dialect))

        # Process with progress (same rendering as add_files, but each file
        # carries its own dialect resolved above).
        if files_with_dialects:
            total = len(files_with_dialects)
            with Progress(
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                console=console,
                transient=False,
            ) as progress:
                task = progress.add_task("Parsing", total=total)
                for i, (file_path, file_dialect) in enumerate(
                    files_with_dialects, start=1
                ):
                    console.print(f"Parsing file {i}/{total}: {file_path.name}")
                    self.add_file(file_path, file_dialect)
                    progress.advance(task)

        return self

    def add_files(
        self,
        file_paths: List[Path],
        dialect: Optional[str] = None,
        show_progress: bool = True,
    ) -> "GraphBuilder":
        """
        Add lineage from multiple SQL files.

        Args:
            file_paths: List of paths to SQL files
            dialect: SQL dialect (uses builder default if not specified)
            show_progress: Whether to print progress messages

        Returns:
            self for method chaining
        """
        if not file_paths:
            return self

        if show_progress:
            total = len(file_paths)
            with Progress(
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                console=console,
                transient=False,
            ) as progress:
                task = progress.add_task("Parsing", total=total)
                for i, file_path in enumerate(file_paths, start=1):
                    console.print(f"Parsing file {i}/{total}: {file_path.name}")
                    self.add_file(file_path, dialect)
                    progress.advance(task)
        else:
            for file_path in file_paths:
                self.add_file(file_path, dialect)
        return self

    def _ensure_node(
        self,
        identifier: str,
        file_path: str,
        query_index: int,
    ) -> int:
        """
        Ensure a node exists in the graph, creating it if necessary.

        First occurrence wins: the file_path/query_index of subsequent
        sightings of the same identifier are not recorded.

        Args:
            identifier: Node identifier (e.g., "table.column")
            file_path: Source file path
            query_index: Query index within file

        Returns:
            rustworkx node index
        """
        key = identifier.lower()
        if key in self._node_index_map:
            return self._node_index_map[key]

        node = GraphNode.from_identifier(
            identifier=key,
            file_path=file_path,
            query_index=query_index,
        )

        node_idx = self.graph.add_node(node.model_dump())
        self._node_index_map[key] = node_idx
        return node_idx

    def build(self) -> LineageGraph:
        """
        Build and return the final LineageGraph.

        Returns:
            LineageGraph with metadata, nodes, and edges
        """
        # Rehydrate model objects from the dict payloads stored in rustworkx.
        nodes = []
        for idx in self.graph.node_indices():
            node_data = self.graph[idx]
            nodes.append(GraphNode(**node_data))

        edges = []
        for edge_idx in self.graph.edge_indices():
            edge_data = self.graph.get_edge_data_by_index(edge_idx)
            edges.append(GraphEdge(**edge_data))

        metadata = GraphMetadata(
            node_format=self.node_format,
            default_dialect=self.dialect,
            created_at=datetime.now(timezone.utc).isoformat(),
            source_files=sorted(self._source_files),
            total_nodes=len(nodes),
            total_edges=len(edges),
        )

        # Print summary of skipped files if any
        if self._skipped_files:
            console.print(
                f"\n[yellow]Summary:[/yellow] Skipped {len(self._skipped_files)} "
                f"file(s) that could not be analyzed for lineage."
            )

        return LineageGraph(
            metadata=metadata,
            nodes=nodes,
            edges=edges,
        )

    @property
    def rustworkx_graph(self) -> rx.PyDiGraph:
        """Get the underlying rustworkx graph for direct operations."""
        return self.graph

    @property
    def node_index_map(self) -> Dict[str, int]:
        """Get mapping from node identifiers to rustworkx indices (a copy)."""
        return self._node_index_map.copy()

    @property
    def skipped_files(self) -> List[tuple[str, str]]:
        """Get list of (file_path, reason) pairs skipped during building (a copy)."""
        return self._skipped_files.copy()
@@ -0,0 +1,136 @@
1
+ """Graph merging functionality."""
2
+
3
+ from datetime import datetime, timezone
4
+ from pathlib import Path
5
+ from typing import Dict, List, Set
6
+
7
+ import rustworkx as rx
8
+
9
+ from sqlglider.global_models import NodeFormat
10
+ from sqlglider.graph.models import (
11
+ GraphEdge,
12
+ GraphMetadata,
13
+ GraphNode,
14
+ LineageGraph,
15
+ )
16
+ from sqlglider.graph.serialization import load_graph
17
+
18
+
19
class GraphMerger:
    """Combine several lineage graphs into one deduplicated graph."""

    def __init__(self):
        """Set up empty merge state."""
        self.merged_graph: rx.PyDiGraph = rx.PyDiGraph()
        self._node_map: Dict[str, int] = {}  # identifier -> merged-graph index
        self._source_files: Set[str] = set()
        self._edge_set: Set[tuple] = set()  # (source, target) pairs already merged

    def add_graph(self, graph: LineageGraph) -> "GraphMerger":
        """
        Add a graph to be merged.

        Nodes are deduplicated by identifier (first occurrence wins).
        Edges are deduplicated by (source_node, target_node) pair; an edge
        whose endpoints were never registered as nodes is silently dropped.

        Args:
            graph: LineageGraph to add

        Returns:
            self for method chaining
        """
        self._source_files.update(graph.metadata.source_files)

        # Register nodes, keeping the first payload seen per identifier.
        for candidate in graph.nodes:
            if candidate.identifier in self._node_map:
                continue
            self._node_map[candidate.identifier] = self.merged_graph.add_node(
                candidate.model_dump()
            )

        # Register edges once per (source, target) pair.
        for link in graph.edges:
            pair = (link.source_node, link.target_node)
            if pair in self._edge_set:
                continue
            src_idx = self._node_map.get(link.source_node)
            dst_idx = self._node_map.get(link.target_node)
            if src_idx is None or dst_idx is None:
                continue  # dangling edge: endpoint missing from node set
            self.merged_graph.add_edge(src_idx, dst_idx, link.model_dump())
            self._edge_set.add(pair)

        return self

    def add_file(self, graph_path: Path) -> "GraphMerger":
        """
        Add a graph from a JSON file.

        Args:
            graph_path: Path to graph JSON file

        Returns:
            self for method chaining

        Raises:
            FileNotFoundError: If file doesn't exist
            ValueError: If file is not valid graph JSON
        """
        return self.add_graph(load_graph(graph_path))

    def add_files(self, graph_paths: List[Path]) -> "GraphMerger":
        """
        Add multiple graphs from JSON files.

        Args:
            graph_paths: List of paths to graph JSON files

        Returns:
            self for method chaining
        """
        for graph_path in graph_paths:
            self.add_file(graph_path)
        return self

    def merge(self) -> LineageGraph:
        """
        Build the merged graph.

        Returns:
            Merged LineageGraph with combined nodes and edges
        """
        # Rehydrate models from the dict payloads held by rustworkx.
        nodes = [
            GraphNode(**self.merged_graph[node_idx])
            for node_idx in self.merged_graph.node_indices()
        ]
        edges = [
            GraphEdge(**self.merged_graph.get_edge_data_by_index(edge_idx))
            for edge_idx in self.merged_graph.edge_indices()
        ]

        metadata = GraphMetadata(
            node_format=NodeFormat.QUALIFIED,  # Merged graphs use qualified format
            default_dialect="spark",
            created_at=datetime.now(timezone.utc).isoformat(),
            source_files=sorted(self._source_files),
            total_nodes=len(nodes),
            total_edges=len(edges),
        )

        return LineageGraph(metadata=metadata, nodes=nodes, edges=edges)
121
+
122
+
123
def merge_graphs(graph_paths: List[Path]) -> LineageGraph:
    """
    Convenience function to merge multiple graph files.

    Delegates to GraphMerger.add_files so the path-iteration logic lives in
    exactly one place instead of being duplicated here.

    Args:
        graph_paths: List of paths to graph JSON files

    Returns:
        Merged LineageGraph

    Raises:
        FileNotFoundError: If any referenced file doesn't exist
        ValueError: If any file is not valid graph JSON
    """
    return GraphMerger().add_files(graph_paths).merge()