sql-glider 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,17 @@
1
+ """Shared models and enums used across sqlglider modules."""
2
+
3
+ from enum import Enum
4
+
5
+
6
class AnalysisLevel(str, Enum):
    """Analysis granularity level for lineage."""

    # Track lineage per individual column.
    COLUMN = "column"
    # Track lineage at whole-table granularity only.
    TABLE = "table"
11
+
12
+
13
class NodeFormat(str, Enum):
    """Format for node identifiers in graph output."""

    # Dotted-string identifiers, e.g. "schema.table.column".
    QUALIFIED = "qualified"
    # Identifiers broken out into structured parts.
    STRUCTURED = "structured"
@@ -0,0 +1,42 @@
1
+ """Graph-based lineage analysis module for SQL Glider."""
2
+
3
+ from sqlglider.graph.builder import GraphBuilder
4
+ from sqlglider.graph.merge import GraphMerger, merge_graphs
5
+ from sqlglider.graph.models import (
6
+ GraphEdge,
7
+ GraphMetadata,
8
+ GraphNode,
9
+ LineageGraph,
10
+ Manifest,
11
+ ManifestEntry,
12
+ )
13
+ from sqlglider.graph.query import GraphQuerier, LineageQueryResult
14
+ from sqlglider.graph.serialization import (
15
+ from_rustworkx,
16
+ load_graph,
17
+ save_graph,
18
+ to_rustworkx,
19
+ )
20
+
21
# Public API of the sqlglider.graph package; these names are re-exported
# from the builder, merge, models, query, and serialization submodules.
__all__ = [
    # Models
    "GraphNode",
    "GraphEdge",
    "GraphMetadata",
    "LineageGraph",
    "Manifest",
    "ManifestEntry",
    # Builder
    "GraphBuilder",
    # Merge
    "GraphMerger",
    "merge_graphs",
    # Query
    "GraphQuerier",
    "LineageQueryResult",
    # Serialization
    "load_graph",
    "save_graph",
    "to_rustworkx",
    "from_rustworkx",
]
@@ -0,0 +1,310 @@
1
+ """Graph builder for constructing lineage graphs from SQL files."""
2
+
3
+ from datetime import datetime, timezone
4
+ from pathlib import Path
5
+ from typing import Callable, Dict, List, Optional, Set
6
+
7
+ import rustworkx as rx
8
+ from rich.console import Console
9
+
10
+ from sqlglider.global_models import AnalysisLevel, NodeFormat
11
+ from sqlglider.graph.models import (
12
+ GraphEdge,
13
+ GraphMetadata,
14
+ GraphNode,
15
+ LineageGraph,
16
+ Manifest,
17
+ )
18
+ from sqlglider.lineage.analyzer import LineageAnalyzer
19
+ from sqlglider.utils.file_utils import read_sql_file
20
+
21
# Shared console for user-facing warnings; stderr keeps stdout clean
# so piped/structured output is not polluted by diagnostics.
console = Console(stderr=True)

# Type alias for SQL preprocessor functions:
# takes (sql: str, file_path: Path) and returns the transformed SQL text.
SqlPreprocessor = Callable[[str, Path], str]
25
+
26
+
27
class GraphBuilder:
    """Build lineage graphs from SQL files using rustworkx."""

    def __init__(
        self,
        node_format: NodeFormat = NodeFormat.QUALIFIED,
        dialect: str = "spark",
        sql_preprocessor: Optional[SqlPreprocessor] = None,
    ):
        """
        Create a new graph builder.

        Args:
            node_format: How node identifiers are rendered (QUALIFIED or STRUCTURED)
            dialect: Fallback SQL dialect used when none is supplied per-file
            sql_preprocessor: Optional hook run on raw SQL before analysis.
                Receives (sql: str, file_path: Path) and returns the SQL to
                analyze. Handy for templating (e.g. Jinja2 rendering).
        """
        self.node_format = node_format
        self.dialect = dialect
        self.sql_preprocessor = sql_preprocessor
        self.graph: rx.PyDiGraph = rx.PyDiGraph()
        # Maps a node identifier to its rustworkx node index.
        self._node_index_map: Dict[str, int] = {}
        self._source_files: Set[str] = set()
        # (source, target) pairs already present, used to deduplicate edges.
        self._edge_set: Set[tuple] = set()
        # (file_path, reason) for each file that could not be analyzed.
        self._skipped_files: List[tuple[str, str]] = []

    def add_file(
        self,
        file_path: Path,
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Fold the lineage of a single SQL file into the graph.

        Args:
            file_path: Path to the SQL file
            dialect: SQL dialect; falls back to the builder default

        Returns:
            self, so calls can be chained

        Raises:
            FileNotFoundError: If the file does not exist
            ParseError: If the SQL cannot be parsed
        """
        effective_dialect = dialect or self.dialect
        resolved_path = str(file_path.resolve())

        try:
            sql_text = read_sql_file(file_path)

            # Run the preprocessing hook (e.g. template rendering) if present.
            if self.sql_preprocessor is not None:
                sql_text = self.sql_preprocessor(sql_text, file_path)

            analyzer = LineageAnalyzer(sql_text, dialect=effective_dialect)
            query_results = analyzer.analyze_queries(level=AnalysisLevel.COLUMN)

            # Surface a warning for every query the analyzer had to skip.
            for skipped in analyzer.skipped_queries:
                console.print(
                    f"[yellow]Warning:[/yellow] Skipping query {skipped.query_index} "
                    f"in {file_path.name} ({skipped.statement_type}): {skipped.reason}"
                )

            self._source_files.add(resolved_path)

            for query_result in query_results:
                q_idx = query_result.metadata.query_index

                for item in query_result.lineage_items:
                    # Entries without a source carry no lineage information.
                    if not item.source_name:
                        continue

                    src_idx = self._ensure_node(
                        item.source_name, resolved_path, q_idx
                    )
                    dst_idx = self._ensure_node(
                        item.output_name, resolved_path, q_idx
                    )

                    # Record the contributes_to edge once per (source, target).
                    pair = (item.source_name, item.output_name)
                    if pair in self._edge_set:
                        continue
                    payload = GraphEdge(
                        source_node=item.source_name,
                        target_node=item.output_name,
                        file_path=resolved_path,
                        query_index=q_idx,
                    ).model_dump()
                    self.graph.add_edge(src_idx, dst_idx, payload)
                    self._edge_set.add(pair)

        except ValueError as e:
            # The file failed wholesale (all statements unsupported):
            # remember why and keep going with the remaining files.
            reason = str(e)
            self._skipped_files.append((resolved_path, reason))
            console.print(
                f"[yellow]Warning:[/yellow] Skipping {file_path.name}: {reason}"
            )

        return self

    def add_directory(
        self,
        dir_path: Path,
        recursive: bool = False,
        glob_pattern: str = "*.sql",
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Fold in every SQL file found under a directory.

        Args:
            dir_path: Directory to scan
            recursive: Search subdirectories as well
            glob_pattern: Glob used to select SQL files
            dialect: SQL dialect; falls back to the builder default

        Returns:
            self, so calls can be chained

        Raises:
            ValueError: If dir_path is not a directory
        """
        if not dir_path.is_dir():
            raise ValueError(f"Not a directory: {dir_path}")

        # Prefix with **/ for a recursive walk; sort for deterministic order.
        pattern = f"**/{glob_pattern}" if recursive else glob_pattern
        for candidate in sorted(dir_path.glob(pattern)):
            if candidate.is_file():
                self.add_file(candidate, dialect)

        return self

    def add_manifest(
        self,
        manifest_path: Path,
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Fold in the files listed in a manifest CSV.

        Args:
            manifest_path: Path to the manifest CSV file
            dialect: Default SQL dialect (per-entry dialects take precedence)

        Returns:
            self, so calls can be chained

        Raises:
            FileNotFoundError: If the manifest or a referenced file is missing
            ValueError: If the manifest format is invalid
        """
        manifest = Manifest.from_csv(manifest_path)
        base_dir = manifest_path.parent

        for entry in manifest.entries:
            # Relative entries are resolved against the manifest's directory.
            entry_path = Path(entry.file_path)
            if not entry_path.is_absolute():
                entry_path = (base_dir / entry.file_path).resolve()

            # Precedence: manifest entry dialect, then caller dialect, then default.
            chosen_dialect = entry.dialect or dialect or self.dialect
            self.add_file(entry_path, chosen_dialect)

        return self

    def add_files(
        self,
        file_paths: List[Path],
        dialect: Optional[str] = None,
    ) -> "GraphBuilder":
        """
        Fold in several SQL files at once.

        Args:
            file_paths: Paths to SQL files
            dialect: SQL dialect; falls back to the builder default

        Returns:
            self, so calls can be chained
        """
        for path in file_paths:
            self.add_file(path, dialect)
        return self

    def _ensure_node(
        self,
        identifier: str,
        file_path: str,
        query_index: int,
    ) -> int:
        """
        Return the rustworkx index for identifier, creating the node on first use.

        Args:
            identifier: Node identifier (e.g. "table.column")
            file_path: Source file path recorded on the node
            query_index: Query index within the file

        Returns:
            rustworkx node index
        """
        existing = self._node_index_map.get(identifier)
        if existing is not None:
            return existing

        node = GraphNode.from_identifier(
            identifier=identifier,
            file_path=file_path,
            query_index=query_index,
        )
        idx = self.graph.add_node(node.model_dump())
        self._node_index_map[identifier] = idx
        return idx

    def build(self) -> LineageGraph:
        """
        Assemble and return the final LineageGraph.

        Returns:
            LineageGraph carrying metadata, nodes, and edges
        """
        nodes = [GraphNode(**self.graph[i]) for i in self.graph.node_indices()]
        edges = [
            GraphEdge(**self.graph.get_edge_data_by_index(i))
            for i in self.graph.edge_indices()
        ]

        metadata = GraphMetadata(
            node_format=self.node_format,
            default_dialect=self.dialect,
            created_at=datetime.now(timezone.utc).isoformat(),
            source_files=sorted(self._source_files),
            total_nodes=len(nodes),
            total_edges=len(edges),
        )

        # Summarize any files that had to be skipped along the way.
        if self._skipped_files:
            console.print(
                f"\n[yellow]Summary:[/yellow] Skipped {len(self._skipped_files)} "
                f"file(s) that could not be analyzed for lineage."
            )

        return LineageGraph(metadata=metadata, nodes=nodes, edges=edges)

    @property
    def rustworkx_graph(self) -> rx.PyDiGraph:
        """Get the underlying rustworkx graph for direct operations."""
        return self.graph

    @property
    def node_index_map(self) -> Dict[str, int]:
        """Get mapping from node identifiers to rustworkx indices."""
        return self._node_index_map.copy()

    @property
    def skipped_files(self) -> List[tuple[str, str]]:
        """Get list of files that were skipped during graph building."""
        return self._skipped_files.copy()
@@ -0,0 +1,136 @@
1
+ """Graph merging functionality."""
2
+
3
+ from datetime import datetime, timezone
4
+ from pathlib import Path
5
+ from typing import Dict, List, Set
6
+
7
+ import rustworkx as rx
8
+
9
+ from sqlglider.global_models import NodeFormat
10
+ from sqlglider.graph.models import (
11
+ GraphEdge,
12
+ GraphMetadata,
13
+ GraphNode,
14
+ LineageGraph,
15
+ )
16
+ from sqlglider.graph.serialization import load_graph
17
+
18
+
19
class GraphMerger:
    """Merge multiple lineage graphs into one."""

    def __init__(self):
        """Initialize the merger."""
        self.merged_graph: rx.PyDiGraph = rx.PyDiGraph()
        # Maps a node identifier to its rustworkx node index.
        self._node_map: Dict[str, int] = {}
        self._source_files: Set[str] = set()
        # (source, target) pairs already merged, for edge dedup.
        self._edge_set: Set[tuple] = set()

    def add_graph(self, graph: LineageGraph) -> "GraphMerger":
        """
        Fold a graph into the merge.

        Nodes are deduplicated by identifier (first occurrence wins).
        Edges are deduplicated by (source_node, target_node) pair.

        Args:
            graph: LineageGraph to add

        Returns:
            self, so calls can be chained
        """
        self._source_files.update(graph.metadata.source_files)

        # Nodes: keep only the first occurrence of each identifier.
        for node in graph.nodes:
            if node.identifier in self._node_map:
                continue
            self._node_map[node.identifier] = self.merged_graph.add_node(
                node.model_dump()
            )

        # Edges: skip pairs already merged; both endpoints must be known.
        for edge in graph.edges:
            key = (edge.source_node, edge.target_node)
            if key in self._edge_set:
                continue
            src = self._node_map.get(edge.source_node)
            dst = self._node_map.get(edge.target_node)
            if src is None or dst is None:
                continue
            self.merged_graph.add_edge(src, dst, edge.model_dump())
            self._edge_set.add(key)

        return self

    def add_file(self, graph_path: Path) -> "GraphMerger":
        """
        Fold in a graph loaded from a JSON file.

        Args:
            graph_path: Path to graph JSON file

        Returns:
            self, so calls can be chained

        Raises:
            FileNotFoundError: If the file does not exist
            ValueError: If the file is not valid graph JSON
        """
        return self.add_graph(load_graph(graph_path))

    def add_files(self, graph_paths: List[Path]) -> "GraphMerger":
        """
        Fold in several graphs from JSON files.

        Args:
            graph_paths: Paths to graph JSON files

        Returns:
            self, so calls can be chained
        """
        for graph_path in graph_paths:
            self.add_file(graph_path)
        return self

    def merge(self) -> LineageGraph:
        """
        Assemble the merged graph.

        Returns:
            Merged LineageGraph with combined nodes and edges
        """
        nodes = []
        for idx in self.merged_graph.node_indices():
            nodes.append(GraphNode(**self.merged_graph[idx]))

        edges = []
        for idx in self.merged_graph.edge_indices():
            edges.append(GraphEdge(**self.merged_graph.get_edge_data_by_index(idx)))

        metadata = GraphMetadata(
            node_format=NodeFormat.QUALIFIED,  # Merged graphs use qualified format
            default_dialect="spark",
            created_at=datetime.now(timezone.utc).isoformat(),
            source_files=sorted(self._source_files),
            total_nodes=len(nodes),
            total_edges=len(edges),
        )

        return LineageGraph(metadata=metadata, nodes=nodes, edges=edges)
121
+
122
+
123
def merge_graphs(graph_paths: List[Path]) -> LineageGraph:
    """
    Convenience function to merge multiple graph files.

    Args:
        graph_paths: List of paths to graph JSON files

    Returns:
        Merged LineageGraph

    Raises:
        FileNotFoundError: If any file doesn't exist
        ValueError: If any file is not valid graph JSON
    """
    # Delegate to GraphMerger.add_files instead of duplicating its loop,
    # so loading/merging behavior stays defined in exactly one place.
    return GraphMerger().add_files(graph_paths).merge()