sql-glider 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,289 @@
1
+ """Pydantic models for graph-based lineage representation."""
2
+
3
+ import csv
4
+ from datetime import datetime, timezone
5
+ from pathlib import Path
6
+ from typing import List, Optional
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+ from sqlglider.global_models import NodeFormat
11
+
12
+
13
class GraphNode(BaseModel):
    """A column node in the lineage graph.

    ``identifier`` is the node's unique key. The structured ``schema_name`` /
    ``table`` / ``column`` fields default to ``None`` when the model is built
    directly; ``from_identifier`` fills them by splitting the identifier.
    """

    identifier: str = Field(
        ..., description="Unique node identifier (fully-qualified column name)"
    )
    file_path: str = Field(
        ..., description="Source SQL file path where first encountered"
    )
    query_index: int = Field(..., description="Index of query within the file")

    # Structured fields for flexible querying; populated by from_identifier().
    schema_name: Optional[str] = Field(None, description="Schema name (if present)")
    table: Optional[str] = Field(None, description="Table name")
    column: Optional[str] = Field(None, description="Column name")

    @classmethod
    def from_identifier(
        cls,
        identifier: str,
        file_path: str,
        query_index: int,
    ) -> "GraphNode":
        """
        Build a GraphNode by splitting a dotted column identifier.

        The identifier is split on at most its first two dots:

        * ``"schema.table.column"`` -> schema, table and column
        * ``"table.column"`` -> table and column, no schema
        * ``"column"`` -> column only

        With more than two dots, everything after the second dot stays in the
        column component (``"a.b.c.d"`` -> schema ``a``, table ``b``,
        column ``c.d``).

        Args:
            identifier: Dotted column name, e.g. "schema.table.column" or "table.column"
            file_path: Source SQL file path
            query_index: Query index within the file

        Returns:
            GraphNode with the parsed components populated
        """
        # maxsplit=2 keeps any further dots inside the column component.
        pieces = identifier.split(".", 2)

        schema_name: Optional[str] = None
        table: Optional[str] = None
        if len(pieces) == 3:
            schema_name, table, column = pieces
        elif len(pieces) == 2:
            table, column = pieces
        else:
            column = identifier

        return cls(
            identifier=identifier,
            file_path=file_path,
            query_index=query_index,
            schema_name=schema_name,
            table=table,
            column=column,
        )
+ )
72
+
73
+
74
class GraphEdge(BaseModel):
    """A directed ``contributes_to`` edge between two column nodes.

    ``source_node`` contributes to ``target_node``; both hold node identifiers.
    ``file_path`` and ``query_index`` record where the relationship was found.
    """

    # Upstream end of the edge (the contributing column).
    source_node: str = Field(..., description="Source node identifier (contributes from)")
    # Downstream end of the edge (the column being produced).
    target_node: str = Field(..., description="Target node identifier (contributes to)")
    # Location of the SQL that defines this relationship.
    file_path: str = Field(..., description="Source SQL file where relationship is defined")
    query_index: int = Field(..., description="Index of query within the file")
85
+
86
+
87
class ManifestEntry(BaseModel):
    """One manifest row: a SQL file path and an optional dialect override."""

    # Path to the SQL file, exactly as written in the manifest.
    file_path: str = Field(..., description="Path to SQL file")
    # None when the manifest leaves the dialect blank (see field description).
    dialect: Optional[str] = Field(None, description="SQL dialect (optional, uses default if empty)")
94
+
95
+
96
class Manifest(BaseModel):
    """Represents a manifest file with SQL file paths and optional dialects."""

    entries: List[ManifestEntry] = Field(default_factory=list)

    @classmethod
    def from_csv(cls, csv_path: Path) -> "Manifest":
        """
        Load manifest from CSV file.

        Expected CSV format:
        ```
        file_path,dialect
        queries/orders.sql,spark
        queries/customers.sql,postgres
        queries/legacy.sql,
        ```

        Header names are compared after stripping surrounding whitespace, and
        a leading UTF-8 byte-order mark (as written by some spreadsheet tools)
        is ignored. Rows whose ``file_path`` is blank or missing are skipped.
        A blank or missing ``dialect`` (including a row with no trailing comma)
        becomes ``None``.

        Args:
            csv_path: Path to manifest CSV file (a ``str`` path is also accepted)

        Returns:
            Manifest with loaded entries

        Raises:
            FileNotFoundError: If CSV file doesn't exist
            ValueError: If CSV is missing required 'file_path' column
        """
        csv_path = Path(csv_path)
        if not csv_path.exists():
            raise FileNotFoundError(f"Manifest file not found: {csv_path}")

        entries = []
        # utf-8-sig drops a leading BOM if present and reads plain UTF-8 unchanged.
        with open(csv_path, newline="", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)

            # Validate required column (an empty file has no header at all).
            if reader.fieldnames is None:
                raise ValueError("Manifest CSV must have a 'file_path' column")
            # Tolerate headers such as "file_path, dialect".
            reader.fieldnames = [name.strip() for name in reader.fieldnames]
            if "file_path" not in reader.fieldnames:
                raise ValueError("Manifest CSV must have a 'file_path' column")

            for row in reader:
                # DictReader fills fields missing from short rows with None
                # (its restval default), so coerce to "" before stripping.
                file_path = (row.get("file_path") or "").strip()
                if not file_path:
                    continue  # Skip rows without a file path

                dialect = (row.get("dialect") or "").strip() or None
                entries.append(ManifestEntry(file_path=file_path, dialect=dialect))

        return cls(entries=entries)
144
+
145
+
146
class LineagePath(BaseModel):
    """An ordered chain of node identifiers linking a node and the queried column.

    GraphQuerier builds these in data-flow order (contributor first), whether
    the query ran upstream or downstream.
    """

    nodes: List[str] = Field(..., description="Ordered list of node identifiers in the path")

    @property
    def hops(self) -> int:
        """Number of edges traversed (0 for an empty or single-node path)."""
        return max(len(self.nodes) - 1, 0)

    def to_arrow_string(self) -> str:
        """Render the path as ``"a -> b -> c"`` for display."""
        arrow = " -> "
        return arrow.join(self.nodes)
161
+
162
+
163
class LineageNode(BaseModel):
    """
    A node returned by a lineage query, with query-specific context.

    Declares the same six fields as GraphNode (copied, not inherited) and adds
    the hop distance to the queried column, the queried column itself,
    root/leaf flags, and every path linking this node and the queried column.
    """

    # --- Same fields as GraphNode ---
    identifier: str = Field(
        ..., description="Unique node identifier (fully-qualified column name)"
    )
    file_path: str = Field(
        ..., description="Source SQL file path where first encountered"
    )
    query_index: int = Field(..., description="Index of query within the file")
    schema_name: Optional[str] = Field(None, description="Schema name (if present)")
    table: Optional[str] = Field(None, description="Table name")
    column: Optional[str] = Field(None, description="Column name")

    # --- Query result context ---
    hops: int = Field(..., description="Number of hops from the queried column")
    output_column: str = Field(..., description="The column that was queried")

    # --- Path tracking and root/leaf detection ---
    is_root: bool = Field(
        default=False, description="True if node has no upstream dependencies"
    )
    is_leaf: bool = Field(
        default=False, description="True if node has no downstream dependencies"
    )
    paths: List[LineagePath] = Field(
        default_factory=list,
        description="All paths from this node to the queried column",
    )

    @classmethod
    def from_graph_node(
        cls,
        node: "GraphNode",
        hops: int,
        output_column: str,
        is_root: bool = False,
        is_leaf: bool = False,
        paths: Optional[List[LineagePath]] = None,
    ) -> "LineageNode":
        """
        Wrap a GraphNode with lineage-query context.

        Args:
            node: Graph node whose six base fields are copied over
            hops: Shortest hop distance to the queried column
            output_column: Identifier of the column that was queried
            is_root: Whether the node has no incoming edges
            is_leaf: Whether the node has no outgoing edges
            paths: Paths linking this node and the queried column; None or an
                empty list both yield an empty list

        Returns:
            A LineageNode carrying the GraphNode fields plus the query context
        """
        copied = {
            name: getattr(node, name)
            for name in (
                "identifier",
                "file_path",
                "query_index",
                "schema_name",
                "table",
                "column",
            )
        }
        return cls(
            **copied,
            hops=hops,
            output_column=output_column,
            is_root=is_root,
            is_leaf=is_leaf,
            paths=paths if paths else [],
        )
236
+
237
+
238
class GraphMetadata(BaseModel):
    """Descriptive metadata stored alongside a serialized lineage graph."""

    # How node identifiers are rendered when the graph is serialized.
    node_format: NodeFormat = Field(default=NodeFormat.QUALIFIED, description="Format of node identifiers in serialized output")
    # Default SQL dialect recorded for the graph.
    default_dialect: str = Field(default="spark", description="Default SQL dialect used")
    # The factory runs per instance, so each graph gets its own UTC timestamp.
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat(), description="ISO 8601 timestamp of graph creation")
    source_files: List[str] = Field(default_factory=list, description="List of source SQL files included in the graph")
    # Counts default to 0; this model does not compute them itself.
    total_nodes: int = Field(default=0, description="Total number of nodes in the graph")
    total_edges: int = Field(default=0, description="Total number of edges in the graph")
262
+
263
+
264
class LineageGraph(BaseModel):
    """The full lineage graph in serializable form: metadata, nodes and edges."""

    metadata: GraphMetadata = Field(default_factory=GraphMetadata)
    nodes: List[GraphNode] = Field(default_factory=list, description="All nodes in the graph")
    edges: List[GraphEdge] = Field(default_factory=list, description="All edges in the graph")

    def get_node_by_identifier(self, identifier: str) -> Optional[GraphNode]:
        """
        Look up a node by identifier, ignoring case.

        Args:
            identifier: Identifier to search for

        Returns:
            The first node (in ``nodes`` order) whose identifier matches
            case-insensitively, or None if there is no match
        """
        wanted = identifier.lower()
        return next(
            (candidate for candidate in self.nodes if candidate.identifier.lower() == wanted),
            None,
        )
@@ -0,0 +1,287 @@
1
+ """Graph query functionality for upstream/downstream analysis."""
2
+
3
+ from pathlib import Path
4
+ from typing import List, Optional
5
+
6
+ import rustworkx as rx
7
+
8
+ from sqlglider.graph.models import GraphNode, LineageGraph, LineageNode, LineagePath
9
+ from sqlglider.graph.serialization import load_graph, to_rustworkx
10
+
11
+
12
class LineageQueryResult:
    """
    Outcome of an upstream or downstream lineage query.

    ``len()`` and iteration both delegate to ``related_columns``.
    """

    def __init__(
        self,
        query_column: str,
        direction: str,
        related_columns: List[LineageNode],
    ):
        """
        Store the query inputs and the nodes it found.

        Args:
            query_column: Identifier of the column that was queried
            direction: "upstream" or "downstream" (stored as given, not validated)
            related_columns: LineageNode objects returned by the query
        """
        self.direction = direction
        self.query_column = query_column
        self.related_columns = related_columns

    def __len__(self) -> int:
        """Count of related columns."""
        found = self.related_columns
        return len(found)

    def __iter__(self):
        """Yield the related columns in stored order."""
        found = self.related_columns
        return iter(found)
40
+
41
+
42
class GraphQuerier:
    """
    Query lineage graphs for upstream/downstream dependencies.

    The LineageGraph is converted to a rustworkx PyDiGraph in which every edge
    points from a contributing column (``GraphEdge.source_node``) to the column
    it contributes to (``GraphEdge.target_node``). Downstream queries walk that
    graph forwards; upstream queries walk a lazily built reversed copy.
    """

    def __init__(self, graph: LineageGraph):
        """
        Initialize the querier with a graph.

        Args:
            graph: LineageGraph to query
        """
        self.graph = graph
        self.rx_graph, self.node_map = to_rustworkx(graph)
        # Node index -> identifier, used to turn index paths back into names.
        self._reverse_map = {v: k for k, v in self.node_map.items()}
        # Reversed graph for upstream queries (built on first access).
        self._rx_graph_reversed: Optional[rx.PyDiGraph] = None

    @property
    def rx_graph_reversed(self) -> rx.PyDiGraph:
        """Get reversed graph for upstream traversal (created lazily)."""
        if self._rx_graph_reversed is None:
            # Copy first: reverse() is called for its effect on the copy, so
            # the forward graph must not be touched.
            self._rx_graph_reversed = self.rx_graph.copy()
            self._rx_graph_reversed.reverse()
        return self._rx_graph_reversed

    def _is_root(self, node_idx: int) -> bool:
        """Check if node is a root (no incoming edges in original graph)."""
        return self.rx_graph.in_degree(node_idx) == 0

    def _is_leaf(self, node_idx: int) -> bool:
        """Check if node is a leaf (no outgoing edges in original graph)."""
        return self.rx_graph.out_degree(node_idx) == 0

    def _find_all_paths(
        self,
        from_idx: int,
        to_idx: int,
        use_reversed: bool = False,
    ) -> List[List[int]]:
        """
        Find all simple paths between two nodes.

        Args:
            from_idx: Starting node index
            to_idx: Target node index
            use_reversed: If True, use reversed graph for upstream queries

        Returns:
            List of paths, where each path is a list of node indices
        """
        graph = self.rx_graph_reversed if use_reversed else self.rx_graph
        return rx.all_simple_paths(graph, from_idx, to_idx)

    def _convert_path_to_identifiers(
        self,
        path: List[int],
        reverse: bool = False,
    ) -> LineagePath:
        """
        Convert a path of node indices to a LineagePath with identifiers.

        Args:
            path: List of node indices
            reverse: If True, reverse the path order (for upstream queries)

        Returns:
            LineagePath with node identifiers
        """
        identifiers = [self._reverse_map[idx] for idx in path]
        if reverse:
            identifiers = list(reversed(identifiers))
        return LineagePath(nodes=identifiers)

    @classmethod
    def from_file(cls, graph_path: Path) -> "GraphQuerier":
        """
        Create a querier from a graph file.

        Args:
            graph_path: Path to graph JSON file

        Returns:
            GraphQuerier instance

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        graph = load_graph(graph_path)
        return cls(graph)

    def find_upstream(self, column: str) -> LineageQueryResult:
        """
        Find all upstream (source) columns for a given column.

        Uses dijkstra_shortest_path_lengths on a reversed graph to find all
        nodes that have a path leading to the specified column, with hop counts,
        root/leaf detection, and full path information.

        Args:
            column: Column identifier to analyze (matched case-insensitively)

        Returns:
            LineageQueryResult with upstream columns including:
            - hop distances (shortest path)
            - is_root/is_leaf flags
            - all paths to the queried column, ordered ancestor -> queried column

        Raises:
            ValueError: If column not found in graph
        """
        matched_column = self._resolve_column(column)
        target_idx = self.node_map[matched_column]
        return LineageQueryResult(
            query_column=matched_column,
            direction="upstream",
            related_columns=self._collect_related(
                target_idx, matched_column, upstream=True
            ),
        )

    def find_downstream(self, column: str) -> LineageQueryResult:
        """
        Find all downstream (affected) columns for a given column.

        Uses dijkstra_shortest_path_lengths to find all nodes that have a path
        from the specified column, with hop counts, root/leaf detection, and
        full path information.

        Args:
            column: Column identifier to analyze (matched case-insensitively)

        Returns:
            LineageQueryResult with downstream columns including:
            - hop distances (shortest path)
            - is_root/is_leaf flags
            - all paths from the queried column, ordered queried column -> descendant

        Raises:
            ValueError: If column not found in graph
        """
        matched_column = self._resolve_column(column)
        source_idx = self.node_map[matched_column]
        return LineageQueryResult(
            query_column=matched_column,
            direction="downstream",
            related_columns=self._collect_related(
                source_idx, matched_column, upstream=False
            ),
        )

    def _resolve_column(self, column: str) -> str:
        """Return the stored identifier matching ``column`` case-insensitively, else raise ValueError."""
        matched_column = self._find_column(column)
        if matched_column is None:
            raise ValueError(f"Column '{column}' not found in graph")
        return matched_column

    def _collect_related(
        self,
        anchor_idx: int,
        output_column: str,
        upstream: bool,
    ) -> List[LineageNode]:
        """
        Build LineageNodes for every node reachable from ``anchor_idx``.

        Shared by find_upstream/find_downstream. Upstream traversal runs on the
        reversed graph, and each path found there is flipped so it reads in
        data-flow order (ancestor -> queried column). Downstream paths already
        read queried column -> descendant.

        Args:
            anchor_idx: Index of the queried column
            output_column: Identifier of the queried column
            upstream: True to walk ancestors, False to walk descendants

        Returns:
            LineageNodes sorted case-insensitively by identifier
        """
        graph = self.rx_graph_reversed if upstream else self.rx_graph
        # Unit edge weights make Dijkstra distances equal to hop counts.
        distances = rx.dijkstra_shortest_path_lengths(
            graph,
            anchor_idx,
            edge_cost_fn=lambda _: 1.0,
        )

        related: List[LineageNode] = []
        for idx, hops in distances.items():
            # Node payloads are identical in both graphs; read the forward one.
            node_data = self.rx_graph[idx]

            raw_paths = self._find_all_paths(anchor_idx, idx, use_reversed=upstream)
            paths = [
                self._convert_path_to_identifiers(p, reverse=upstream)
                for p in raw_paths
            ]

            related.append(
                LineageNode.from_graph_node(
                    GraphNode(**node_data),
                    hops=int(hops),
                    output_column=output_column,
                    is_root=self._is_root(idx),
                    is_leaf=self._is_leaf(idx),
                    paths=paths,
                )
            )

        # Sort by identifier for consistent output
        related.sort(key=lambda n: n.identifier.lower())
        return related

    def _find_column(self, column: str) -> Optional[str]:
        """
        Find column with case-insensitive matching.

        Args:
            column: Column identifier to find

        Returns:
            First matching identifier in node_map order, or None
        """
        column_lower = column.lower()
        for identifier in self.node_map.keys():
            if identifier.lower() == column_lower:
                return identifier
        return None

    def list_columns(self) -> List[str]:
        """
        List all column identifiers in the graph.

        Returns:
            Column identifiers sorted case-insensitively
        """
        return sorted(self.node_map.keys(), key=str.lower)
@@ -0,0 +1,107 @@
1
+ """Serialization and deserialization for lineage graphs."""
2
+
3
+ from pathlib import Path
4
+ from typing import Dict, Tuple
5
+
6
+ import rustworkx as rx
7
+
8
+ from sqlglider.graph.models import (
9
+ GraphEdge,
10
+ GraphMetadata,
11
+ GraphNode,
12
+ LineageGraph,
13
+ )
14
+
15
+
16
def save_graph(graph: LineageGraph, output_path: Path) -> None:
    """
    Save a LineageGraph to a JSON file.

    The graph is serialized with ``model_dump_json(indent=2)`` and written as
    UTF-8, replacing any existing file at ``output_path``. The parent
    directory must already exist.

    Args:
        graph: LineageGraph to save
        output_path: Output file path (a ``str`` path is also accepted)
    """
    # Coerce so plain string paths work as well as pathlib.Path objects.
    Path(output_path).write_text(
        graph.model_dump_json(indent=2),
        encoding="utf-8",
    )
28
+
29
+
30
def load_graph(input_path: Path) -> LineageGraph:
    """
    Load a LineageGraph from a JSON file.

    Args:
        input_path: Input file path (a ``str`` path is also accepted)

    Returns:
        Loaded LineageGraph

    Raises:
        FileNotFoundError: If input file doesn't exist
        ValueError: If file content is invalid JSON or doesn't match schema
    """
    input_path = Path(input_path)
    # EAFP: read directly instead of exists()-then-read, which races with a
    # concurrent delete; the error message is unchanged.
    try:
        content = input_path.read_text(encoding="utf-8")
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Graph file not found: {input_path}") from err
    return LineageGraph.model_validate_json(content)
49
+
50
+
51
def to_rustworkx(graph: LineageGraph) -> Tuple[rx.PyDiGraph, Dict[str, int]]:
    """
    Build a rustworkx PyDiGraph from a LineageGraph.

    Each node payload is ``GraphNode.model_dump()`` and each edge payload is
    ``GraphEdge.model_dump()``; edges run source_node -> target_node. Edges
    whose endpoints are not among the graph's nodes are skipped. If two nodes
    share an identifier, the map keeps the index of the later one.

    Args:
        graph: LineageGraph to convert

    Returns:
        Tuple of (PyDiGraph, identifier -> node index map)
    """
    rx_graph: rx.PyDiGraph = rx.PyDiGraph()
    node_map: Dict[str, int] = {}

    for graph_node in graph.nodes:
        node_map[graph_node.identifier] = rx_graph.add_node(graph_node.model_dump())

    for graph_edge in graph.edges:
        endpoints = (
            node_map.get(graph_edge.source_node),
            node_map.get(graph_edge.target_node),
        )
        if None in endpoints:
            continue  # dangling edge: one endpoint is not a known node
        rx_graph.add_edge(endpoints[0], endpoints[1], graph_edge.model_dump())

    return rx_graph, node_map
77
+
78
+
79
def from_rustworkx(
    rx_graph: rx.PyDiGraph,
    metadata: GraphMetadata,
) -> LineageGraph:
    """
    Rebuild a LineageGraph from a rustworkx PyDiGraph.

    Node and edge payloads must be mappings of GraphNode / GraphEdge fields,
    which is what ``to_rustworkx`` stores.

    Note: ``metadata`` is modified in place. Its ``total_nodes`` and
    ``total_edges`` are overwritten with the counts found here before it is
    passed to the returned graph.

    Args:
        rx_graph: rustworkx directed graph
        metadata: Graph metadata to include

    Returns:
        LineageGraph with nodes and edges from the rustworkx graph
    """
    graph_nodes = [
        GraphNode(**rx_graph[node_index]) for node_index in rx_graph.node_indices()
    ]
    graph_edges = [
        GraphEdge(**rx_graph.get_edge_data_by_index(edge_index))
        for edge_index in rx_graph.edge_indices()
    ]

    # Keep the metadata counts in sync with what was actually read.
    metadata.total_nodes, metadata.total_edges = len(graph_nodes), len(graph_edges)

    return LineageGraph(nodes=graph_nodes, edges=graph_edges, metadata=metadata)
@@ -0,0 +1,10 @@
1
+ """Lineage analysis module for SQL Glider."""
2
+
3
+ from sqlglider.lineage.analyzer import (
4
+ LineageAnalyzer,
5
+ LineageItem,
6
+ QueryLineageResult,
7
+ QueryMetadata,
8
+ )
9
+
10
+ __all__ = ["LineageAnalyzer", "LineageItem", "QueryLineageResult", "QueryMetadata"]