sqlh 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlh/.DS_Store ADDED
Binary file
sqlh/__init__.py ADDED
@@ -0,0 +1,32 @@
1
"""Public API of the ``sqlh`` package (a mini SQL-lineage toolkit).

Re-exports the DAG graph type, the SQL helper functions and the
lineage query utilities so callers can import everything from ``sqlh``
directly.
"""

from .core.graph import DagGraph
from .core.helper import split_sql, trim_comment
from .utils import (
    get_all_leaf_tables,
    get_all_root_tables,
    get_all_tables,
    read_sql_from_directory,
    search_command_json,
    search_related_downstream_tables,
    search_related_root_tables,
    search_related_tables,
    search_related_upstream_tables,
    visualize_dag,
)

# Package version; also reported by the CLI's -v/--version flag.
__version__ = "0.2.3"

# Names exported by `from sqlh import *`.
__all__ = [
    "split_sql",
    "trim_comment",
    "DagGraph",
    "read_sql_from_directory",
    "get_all_tables",
    "get_all_root_tables",
    "get_all_leaf_tables",
    "search_related_root_tables",
    "search_related_upstream_tables",
    "search_related_downstream_tables",
    "search_related_tables",
    "search_command_json",
    "visualize_dag",
]
sqlh/cli.py ADDED
@@ -0,0 +1,153 @@
1
+ import argparse
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ from sqlh import __version__
6
+
7
+ from .utils import (
8
+ get_all_dag,
9
+ get_all_leaf_tables,
10
+ get_all_root_tables,
11
+ get_all_tables,
12
+ list_command_json,
13
+ list_command_text,
14
+ read_sql_from_directory,
15
+ search_command_json,
16
+ search_command_text,
17
+ search_related_downstream_tables,
18
+ search_related_root_tables,
19
+ search_related_tables,
20
+ search_related_upstream_tables,
21
+ visualize_dag,
22
+ )
23
+
24
+
25
+ def _create_parent_parser():
26
+ """创建包含共享参数的父解析器"""
27
+ parent_parser = argparse.ArgumentParser(add_help=False)
28
+
29
+ # 共享参数: 所有子命令都支持 -s/--sql 或 -p/--path
30
+ sql_or_path = parent_parser.add_mutually_exclusive_group(required=True)
31
+ sql_or_path.add_argument("-s", "--sql", dest="sql", help="sql statement")
32
+ sql_or_path.add_argument("-p", "--path", dest="path", help="sql file or directory path")
33
+
34
+ # 共享参数: 所有子命令都支持输出格式
35
+ parent_parser.add_argument(
36
+ "-f",
37
+ "--output-format",
38
+ choices=["json", "text", "web", "html"],
39
+ default="json",
40
+ help="output format",
41
+ )
42
+
43
+ return parent_parser
44
+
45
+
46
def arg_parse():
    """Define the CLI and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace`` holding the chosen sub-command
        (``command``) and its options.
    """
    root = argparse.ArgumentParser(usage="%(prog)s [OPTIONS] <COMMAND>", description="mini-sqllineage")
    root.add_argument("-v", "--version", action="version", version=__version__)

    # Options shared by every sub-command (-s/-p source, -f format).
    shared = _create_parent_parser()

    commands = root.add_subparsers(dest="command", required=True)

    # `list`: enumerate tables found in the SQL.
    list_cmd = commands.add_parser(
        "list",
        parents=[shared],
        help="list all tables / root tables / leaf tables",
        add_help=False,
    )
    which_tables = list_cmd.add_mutually_exclusive_group(required=True)
    which_tables.add_argument("--all", action="store_true", help="list all tables")
    which_tables.add_argument("--root", action="store_true", help="list root tables (tables with no dependencies)")
    which_tables.add_argument("--leaf", action="store_true", help="list leaf tables (tables not used by others)")
    list_cmd.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="show this help message")

    # `search`: explore lineage around one table.
    search_cmd = commands.add_parser(
        "search",
        parents=[shared],
        help="search table relationships",
        add_help=False,
    )
    direction = search_cmd.add_mutually_exclusive_group(required=True)
    direction.add_argument("--root", action="store_true", help="search root tables of the specified table")
    direction.add_argument("--upstream", action="store_true", help="search upstream tables (dependencies)")
    direction.add_argument("--downstream", action="store_true", help="search downstream tables (dependents)")
    direction.add_argument("--all", action="store_true", help="search both upstream and downstream tables")
    search_cmd.add_argument("-t", "--table", help="table name to search", required=True)
    search_cmd.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="show this help message")

    # `web`: render the full DAG as an HTML page.
    web_cmd = commands.add_parser(
        "web",
        parents=[shared],
        help="start web server for visualization",
        add_help=False,
    )
    web_cmd.add_argument("--html-path", help="html file path for visualization", default=".")
    web_cmd.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="show this help message")

    return root.parse_args()
95
+
96
+
97
def _run_list(args, sql_stmt_str: str) -> None:
    """Handle the `list` sub-command: print all / root / leaf tables."""
    if args.all:
        output = get_all_tables(sql_stmt_str)
        sub_command_arg = "--all"
    elif args.root:
        output = get_all_root_tables(sql_stmt_str)
        sub_command_arg = "--root"
    else:
        # `--leaf`; the required mutually-exclusive group guarantees one flag,
        # and a final `else` keeps `output` bound even if flags change later.
        output = get_all_leaf_tables(sql_stmt_str)
        sub_command_arg = "--leaf"

    if args.output_format == "json":
        print(list_command_json(output, sub_command_arg))
    else:
        # "text", "web" and "html" all fall back to plain text for `list`.
        print(list_command_text(output))


def _run_search(args, sql_stmt_str: str) -> None:
    """Handle the `search` sub-command: print lineage around one table."""
    if args.root:
        output = search_related_root_tables(sql_stmt_str, args.table)
        sub_command_arg = "--root"
    elif args.upstream:
        output = search_related_upstream_tables(sql_stmt_str, args.table)
        sub_command_arg = "--upstream"
    elif args.downstream:
        output = search_related_downstream_tables(sql_stmt_str, args.table)
        sub_command_arg = "--downstream"
    else:
        # `--all` (guaranteed by the required mutually-exclusive group).
        output = search_related_tables(sql_stmt_str, args.table)
        sub_command_arg = "--all"

    if args.output_format == "json":
        print(search_command_json(output, sub_command_arg))
    elif args.output_format in ["web", "html"]:
        # Presumably search_related_* returns (tables, dag) when a DAG is
        # available -- TODO confirm against sqlh.utils.
        if isinstance(output, tuple):
            visualize_dag(output[1], template_type="mermaid", filename="lineage_mermaid.html")
        else:
            print(output)
    elif args.output_format == "text":
        print(search_command_text(output))


def _run_web(args, sql_stmt_str: str) -> None:
    """Handle the `web` sub-command: render the full DAG to an HTML file."""
    html_file_path = Path(args.html_path) / "lineage_dagre.html"
    print(f"open web page: {html_file_path}")
    visualize_dag(get_all_dag(sql_stmt_str), template_type="dagre", filename=html_file_path)


def main():
    """CLI entry point: load the SQL source, then dispatch the sub-command.

    Exits with status 1 when the given path does not exist.
    """
    args = arg_parse()

    if args.sql:
        sql_stmt_str = args.sql
    else:
        try:
            sql_stmt_str = read_sql_from_directory(args.path)
        except FileNotFoundError:
            # BUGFIX: report errors on stderr so piped stdout stays clean.
            print(f"Error: File not found: {args.path}", file=sys.stderr)
            sys.exit(1)

    if args.command == "list":
        _run_list(args, sql_stmt_str)
    elif args.command == "search":
        _run_search(args, sql_stmt_str)
    elif args.command == "web":
        _run_web(args, sql_stmt_str)
sqlh/core/graph.py ADDED
@@ -0,0 +1,385 @@
1
+ """
2
+ DAG (Directed Acyclic Graph) implementation for table dependency tracking.
3
+
4
+ This module provides a graph data structure for representing and analyzing
5
+ table dependencies in SQL queries, with support for:
6
+ - Node and edge management
7
+ - Upstream/downstream traversal
8
+ - Cycle detection
9
+ - Mermaid.js format visualization
10
+
11
+ Example:
12
+     from sqlh.core.graph import DagGraph
13
+
14
+ dg = DagGraph()
15
+ dg.add_edge("table_a", "table_b")
16
+ dg.add_edge("table_b", "table_c")
17
+ dg.print_all_edges_to_mermaid()
18
+ """
19
+
20
+ from collections import deque
21
+ from pathlib import Path
22
+ from string import Template
23
+ from typing import Literal, Union
24
+
25
+
26
class NodeNotFoundException(Exception):
    """Raised when an operation references a node absent from the graph."""
30
+
31
+
32
class NodeExistsException(Exception):
    """Raised when adding a node that is already present in the graph."""
36
+
37
+
38
class CycleDetectedException(Exception):
    """Raised when an operation would introduce (or has found) a cycle."""
42
+
43
+
44
# Result of the find_upstream/find_downstream traversals: a sub-graph on
# success, or a NodeNotFoundException *instance* (returned, not raised)
# when the start node is missing.
FindResult = Union["DagGraph", NodeNotFoundException]
45
+
46
+
47
+ class DagGraph:
48
+ def __init__(self, nodes: list[str] | None = None, edges: list[tuple[str, str]] | None = None) -> None:
49
+ """
50
+ 初始化DAG图
51
+
52
+ Args:
53
+ nodes: 初始节点列表,默认为空列表
54
+ """
55
+ if nodes is None:
56
+ nodes = []
57
+ self.__nodes = set(nodes) # 使用集合提升查找效率
58
+ if edges is None:
59
+ # edges = []
60
+ self.__edges = set() # 使用集合存储边
61
+ else:
62
+ self.__edges = set(edges) # 使用集合存储边
63
+ # self.__nodes.update({node for edge in edges for node in edge}) # 从边中提取节点并添加到节点集合
64
+ for edge in edges:
65
+ self.__nodes.add(edge[0])
66
+ self.__nodes.add(edge[1])
67
+
68
+ self.__adjacency_list: dict[str, set[str]] = {} # 邻接表,用于快速遍历(下游)
69
+ self.__reverse_adjacency_list: dict[str, set[str]] = {} # 反向邻接表,用于上游查找
70
+ for node in nodes:
71
+ self.__adjacency_list[node] = set()
72
+ self.__reverse_adjacency_list[node] = set()
73
+
74
+ def add_node(self, node: str) -> None:
75
+ """
76
+ 添加节点
77
+
78
+ Args:
79
+ node: 节点对象
80
+
81
+ Raises:
82
+ NodeExistsException: 节点已存在
83
+ """
84
+ if node in self.__nodes:
85
+ raise NodeExistsException(f"节点已存在:{node}")
86
+ self.__nodes.add(node)
87
+ if node not in self.__adjacency_list:
88
+ self.__adjacency_list[node] = set()
89
+ if node not in self.__reverse_adjacency_list:
90
+ self.__reverse_adjacency_list[node] = set()
91
+
92
+ def remove_node(self, node: str) -> None:
93
+ """
94
+ 删除节点及其相关边
95
+
96
+ Args:
97
+ node: 节点对象
98
+
99
+ Raises:
100
+ NodeNotFoundException: 节点不存在
101
+ """
102
+ if node not in self.__nodes:
103
+ raise NodeNotFoundException(f"节点不存在:{node}")
104
+
105
+ # 删除节点
106
+ self.__nodes.discard(node)
107
+
108
+ # 删除与该节点相关的所有边
109
+ edges_to_remove = {(f, t) for f, t in self.__edges if f == node or t == node}
110
+ self.__edges -= edges_to_remove
111
+
112
+ # 更新邻接表
113
+ if node in self.__adjacency_list:
114
+ del self.__adjacency_list[node]
115
+ for adjacent in self.__adjacency_list:
116
+ self.__adjacency_list[adjacent].discard(node)
117
+
118
+ # 更新反向邻接表
119
+ if node in self.__reverse_adjacency_list:
120
+ del self.__reverse_adjacency_list[node]
121
+ for adjacent in self.__reverse_adjacency_list:
122
+ self.__reverse_adjacency_list[adjacent].discard(node)
123
+
124
+ def add_edge(self, _from: str, _to: str) -> None:
125
+ """
126
+ 添加边,如果节点不存在则自动添加
127
+
128
+ Args:
129
+ _from: 起始节点
130
+ _to: 目标节点
131
+
132
+ Raises:
133
+ CycleDetectedException: 如果添加该边会形成环
134
+ """
135
+ # 自动添加不存在的节点
136
+ if _from not in self.__nodes:
137
+ self.add_node(_from)
138
+ if _to not in self.__nodes:
139
+ self.add_node(_to)
140
+
141
+ # 检查是否会形成环
142
+ # if self.__would_create_cycle(_from.name, _to.name):
143
+ # raise CycleDetectedException(f"添加边 {_from} -> {_to} 会形成环")
144
+
145
+ # 添加边
146
+ edge = (_from, _to)
147
+ self.__edges.add(edge)
148
+ self.__adjacency_list[_from].add(_to)
149
+ self.__reverse_adjacency_list[_to].add(_from)
150
+
151
+ def remove_edge(self, _from: str, _to: str) -> None:
152
+ """
153
+ 删除边
154
+
155
+ Args:
156
+ _from: 起始节点
157
+ _to: 目标节点
158
+
159
+ Raises:
160
+ NodeNotFoundException: 节点不存在
161
+ """
162
+ if _from not in self.__nodes:
163
+ raise NodeNotFoundException(f"节点不存在:{_from}")
164
+ if _to not in self.__nodes:
165
+ raise NodeNotFoundException(f"节点不存在:{_to}")
166
+
167
+ # edge = (_from, _to)
168
+ self.__edges.discard((_from, _to))
169
+ if _from in self.__adjacency_list:
170
+ self.__adjacency_list[_from].discard(_to)
171
+ if _to in self.__reverse_adjacency_list:
172
+ self.__reverse_adjacency_list[_to].discard(_from)
173
+
174
+ def get_nodes(self) -> list[str]:
175
+ """
176
+ 获取所有节点
177
+ Returns:
178
+ 节点列表
179
+ """
180
+ return sorted(list(self.__nodes))
181
+
182
+ def get_edges(self) -> list[tuple[str, str]]:
183
+ """
184
+ 获取所有边(去重)
185
+
186
+ Returns:
187
+ 边列表,每个元素为 (from, to) 元组
188
+ """
189
+ return sorted(list(self.__edges))
190
+
191
+ def union(self, other: "DagGraph") -> "DagGraph":
192
+ """
193
+ 合并两个DAG图
194
+
195
+ Args:
196
+ other: 另一个DAG图
197
+
198
+ Returns:
199
+ 合并后的DAG图
200
+ """
201
+ new_nodes = self.get_nodes() + other.get_nodes()
202
+ new_edges = self.get_edges() + other.get_edges()
203
+ return DagGraph(new_nodes, new_edges)
204
+
205
+ @property
206
+ def empty(self) -> bool:
207
+ return len(self.__nodes) == 0
208
+
209
+ def __would_create_cycle(self, from_node: str, to_node: str) -> bool:
210
+ """
211
+ 检查添加边是否会形成环
212
+
213
+ Args:
214
+ from_node: 起始节点
215
+ to_node: 目标节点
216
+
217
+ Returns:
218
+ True 如果会形成环
219
+ """
220
+ if from_node == to_node:
221
+ return True
222
+
223
+ # 从 to_node 开始DFS,看能否到达 from_node
224
+ visited = set()
225
+ stack = [to_node]
226
+
227
+ while stack:
228
+ current = stack.pop()
229
+ if current == from_node:
230
+ return True
231
+ if current not in visited:
232
+ visited.add(current)
233
+ # 获取当前节点的所有下游节点
234
+ if current in self.__adjacency_list:
235
+ stack.extend(self.__adjacency_list[current])
236
+
237
+ return False
238
+
239
+ def has_cycle(self) -> bool:
240
+ """
241
+ 检测图中是否存在环
242
+
243
+ Returns:
244
+ True 如果存在环
245
+ """
246
+ visited = set()
247
+ rec_stack = set()
248
+
249
+ def dfs(node: str) -> bool:
250
+ visited.add(node)
251
+ rec_stack.add(node)
252
+
253
+ for neighbor in self.__adjacency_list.get(node, set()):
254
+ if neighbor not in visited:
255
+ if dfs(neighbor):
256
+ return True
257
+ elif neighbor in rec_stack:
258
+ return True
259
+
260
+ rec_stack.remove(node)
261
+ return False
262
+
263
+ for node in self.__nodes:
264
+ if node not in visited:
265
+ if dfs(node):
266
+ return True
267
+
268
+ return False
269
+
270
+ def find_upstream(self, node: str) -> FindResult:
271
+ """
272
+ 查找所有上游依赖的边
273
+
274
+ Args:
275
+ node: 目标节点
276
+
277
+ Returns:
278
+ 上游边的集合
279
+ """
280
+ if node not in self.__nodes:
281
+ return NodeNotFoundException(f"节点不存在:{node}")
282
+
283
+ queue = deque([node])
284
+ visited = set([node])
285
+ all_relations = []
286
+
287
+ while queue:
288
+ current = queue.popleft()
289
+ # 使用反向邻接表直接查找上游节点,提升性能
290
+ predecessors = self.__reverse_adjacency_list.get(current, set())
291
+ for predecessor in predecessors:
292
+ edge = (predecessor, current)
293
+ all_relations.append(edge)
294
+ if predecessor not in visited:
295
+ visited.add(predecessor)
296
+ queue.append(predecessor)
297
+ return DagGraph(edges=all_relations)
298
+
299
+ def find_downstream(self, node: str) -> FindResult:
300
+ """
301
+ 查找所有下游依赖的边
302
+
303
+ Args:
304
+ node: 起始节点
305
+
306
+ Returns:
307
+ 下游边的集合
308
+ """
309
+ if node not in self.__nodes:
310
+ return NodeNotFoundException(f"节点不存在:{node}")
311
+
312
+ queue = deque([node])
313
+ visited = set([node])
314
+ all_relations = []
315
+
316
+ while queue:
317
+ current = queue.popleft()
318
+ # 使用邻接表直接查找下游节点,提升性能
319
+ neighbors = self.__adjacency_list.get(current, set())
320
+ for neighbor in neighbors:
321
+ edge = (current, neighbor)
322
+ all_relations.append(edge)
323
+ if neighbor not in visited:
324
+ visited.add(neighbor)
325
+ queue.append(neighbor)
326
+
327
+ return DagGraph(edges=all_relations)
328
+
329
+ def to_mermaid(self, direction="LR") -> str:
330
+ """
331
+ 转换为 Mermaid 格式字符串
332
+
333
+ Returns:
334
+ Mermaid格式的图描述字符串
335
+ """
336
+ if not self.__nodes:
337
+ return ""
338
+ else:
339
+ mermaid_str = f"graph {direction}"
340
+ for _from, _to in self.__edges:
341
+ mermaid_str += f"\n {_from} --> {_to}"
342
+ return mermaid_str
343
+
344
+ def to_dict(self) -> dict:
345
+ """
346
+ 转换为字典格式
347
+
348
+ Returns:
349
+ 字典格式的图描述
350
+ """
351
+ nodes = [{"id": node, "label": node.split(".")[:-1]} for node in self.__nodes]
352
+ edges = [{"source": _from, "target": _to} for _from, _to in self.__edges]
353
+ return {"nodes": nodes, "edges": edges, "node_count": len(nodes)}
354
+
355
+ def to_html(self, template_type: Literal["mermaid", "dagre"] = "mermaid") -> str:
356
+ """
357
+ 生成包含Mermaid.js可视化的HTML代码
358
+
359
+ Args:
360
+ edges: 可选的边集合,如果为None则使用所有边
361
+
362
+ Returns:
363
+ 包含Mermaid.js可视化的HTML字符串
364
+ """
365
+ mermaid_content = self.to_mermaid()
366
+ lineage_data = self.to_dict()
367
+ # 读取HTML模板文件
368
+ if template_type == "mermaid":
369
+ template_path = Path(__file__).parent.parent / "static" / "mermaid_template.html"
370
+ elif template_type == "dagre":
371
+ template_path = Path(__file__).parent.parent / "static" / "dagre_template.html"
372
+ else:
373
+ raise ValueError(f"Unknown template type: {template_type}")
374
+
375
+ with open(template_path, encoding="utf-8") as f:
376
+ template = Template(f.read())
377
+
378
+ # 替换模板变量
379
+ html_content = template.safe_substitute(
380
+ title="DAG Visualization",
381
+ mermaid_content=mermaid_content,
382
+ lineage_data=lineage_data
383
+ )
384
+
385
+ return html_content