sql-dag-flow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
sql_dag_flow/main.py ADDED
@@ -0,0 +1,203 @@
1
+ from fastapi import FastAPI, HTTPException, Body
2
+ from pydantic import BaseModel
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi.responses import FileResponse
6
+ import uvicorn
7
+ import os
8
+ import sys
9
+ import json
10
+ import webbrowser
11
+ import threading
12
+ import time
13
+ from .parser import parse_sql_files, build_graph
14
+
15
# FastAPI application instance serving both the JSON API and the static frontend.
app = FastAPI()

# Enable CORS for frontend
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is very
# permissive; acceptable for a localhost-only dev tool (served on 127.0.0.1),
# but confirm before exposing this server more widely.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
25
+
26
# Package structure
# __file__ is inside src/sql_dag_flow/main.py
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory containing this module
STATIC_DIR = os.path.join(BASE_DIR, "static")  # bundled frontend build output

# Global state
CURRENT_DIRECTORY = os.getcwd()  # Default, updated by start()
# Default diagram file name. NOTE(review): appears unused in this module
# (endpoints take an explicit `filename` instead) — confirm before removing.
DIAGRAM_FILE = "sql_diagram.json"
34
+
35
+ @app.get("/graph")
36
+ def get_graph(dialect: str = "bigquery"):
37
+ """Parses SQL files in the current directory and returns graph data."""
38
+ if not os.path.exists(CURRENT_DIRECTORY):
39
+ return {"nodes": [], "edges": [], "error": "Directory not found"}
40
+
41
+ tables = parse_sql_files(CURRENT_DIRECTORY, dialect=dialect)
42
+ nodes, edges = build_graph(tables)
43
+ return {"nodes": nodes, "edges": edges}
44
+
45
+ @app.post("/config/path")
46
+ def set_path(path_data: dict = Body(...)):
47
+ """Updates the directory to scan."""
48
+ global CURRENT_DIRECTORY
49
+ path = path_data.get("path")
50
+ # Basic validation
51
+ if not path or not os.path.exists(path):
52
+ raise HTTPException(status_code=400, detail="Directory does not exist")
53
+
54
+
55
+ CURRENT_DIRECTORY = path
56
+ return {"message": "Path updated", "path": CURRENT_DIRECTORY}
57
+
58
+ @app.post("/scan/folders")
59
+ def scan_folders(path_data: dict = Body(...)):
60
+ """Scans a directory and returns all subfolders (recursive, relative paths)."""
61
+ path = path_data.get("path")
62
+ if not path or not os.path.exists(path):
63
+ raise HTTPException(status_code=400, detail="Directory does not exist")
64
+
65
+ try:
66
+ subfolders = []
67
+ # Walk the directory tree
68
+ for root, dirs, files in os.walk(path):
69
+ # Skip hidden folders
70
+ dirs[:] = [d for d in dirs if not d.startswith('.')]
71
+
72
+ for d in dirs:
73
+ # Create relative path from the root path
74
+ full_path = os.path.join(root, d)
75
+ rel_path = os.path.relpath(full_path, path)
76
+ # Normalize separators to forward slashes for consistency
77
+ rel_path = rel_path.replace(os.sep, '/')
78
+ subfolders.append(rel_path)
79
+
80
+ # Sort for better UX
81
+ subfolders.sort()
82
+ return {"folders": subfolders}
83
+ except Exception as e:
84
+ raise HTTPException(status_code=500, detail=str(e))
85
+
86
+ @app.post("/graph/filtered")
87
+ def get_filtered_graph(data: dict = Body(...)):
88
+ """Parses SQL files with subfolder filtering."""
89
+ if not os.path.exists(CURRENT_DIRECTORY):
90
+ return {"nodes": [], "edges": [], "error": "Directory not found"}
91
+
92
+ subfolders = data.get("subfolders") # List of strings or None
93
+ dialect = data.get("dialect", "bigquery")
94
+ tables = parse_sql_files(CURRENT_DIRECTORY, allowed_subfolders=subfolders, dialect=dialect)
95
+ nodes, edges = build_graph(tables)
96
+ return {"nodes": nodes, "edges": edges}
97
+
98
+ @app.get("/config/path")
99
+ def get_path():
100
+ return {"path": CURRENT_DIRECTORY}
101
+
102
class SaveRequest(BaseModel):
    """Payload for /save: full diagram state plus the target file name."""
    nodes: list  # React Flow node dicts
    edges: list  # React Flow edge dicts
    viewport: dict  # camera state — presumably {"x", "y", "zoom"}; see /load's fallback shape
    metadata: dict  # free-form; /save honors metadata["path"] as the target directory
    filename: str = "sql_diagram.json"  # Default filename
108
+
109
+ @app.post("/save")
110
+ def save_graph(request: SaveRequest):
111
+ try:
112
+ # Use the path from metadata if available, otherwise default
113
+ path = request.metadata.get("path", ".")
114
+ if not os.path.isabs(path):
115
+ path = os.path.abspath(path)
116
+
117
+ filepath = os.path.join(path, request.filename)
118
+
119
+ data = {
120
+ "nodes": request.nodes,
121
+ "edges": request.edges,
122
+ "viewport": request.viewport,
123
+ "metadata": request.metadata
124
+ }
125
+ with open(filepath, "w") as f:
126
+ json.dump(data, f, indent=4)
127
+ return {"message": f"Graph saved successfully to {filepath}"}
128
+ except Exception as e:
129
+ raise HTTPException(status_code=500, detail=str(e))
130
+
131
+ @app.get("/load")
132
+ def load_graph(path: str = ".", filename: str = "sql_diagram.json"):
133
+ try:
134
+ if not os.path.isabs(path):
135
+ path = os.path.abspath(path)
136
+
137
+ filepath = os.path.join(path, filename)
138
+
139
+ if not os.path.exists(filepath):
140
+ return {"nodes": [], "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 1}, "metadata": {}}
141
+
142
+ with open(filepath, "r") as f:
143
+ data = json.load(f)
144
+ return data
145
+ except Exception as e:
146
+ print(f"Error loading graph: {e}")
147
+ return {"nodes": [], "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 1}, "metadata": {}}
148
+
149
+ @app.get("/config_files")
150
+ def list_config_files(path: str = "."):
151
+ try:
152
+ if not os.path.isabs(path):
153
+ path = os.path.abspath(path)
154
+
155
+ if not os.path.exists(path):
156
+ return {"files": []}
157
+
158
+ files = [f for f in os.listdir(path) if f.endswith(".json") and os.path.isfile(os.path.join(path, f))]
159
+ return {"files": files}
160
+ except Exception as e:
161
+ print(f"Error listing config files: {e}")
162
+ return {"files": []}
163
+
164
# Serve Static Files (Frontend)
if os.path.exists(STATIC_DIR):
    app.mount("/assets", StaticFiles(directory=os.path.join(STATIC_DIR, "assets")), name="assets")

    # Catch-all for SPA routing
    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Serve the requested static file, or index.html for SPA routes.

        The resolved path is confined to STATIC_DIR to prevent path-traversal
        requests (e.g. "../...") from reading files outside the bundle.
        """
        static_root = os.path.abspath(STATIC_DIR)
        file_path = os.path.abspath(os.path.join(static_root, full_path))
        if file_path.startswith(static_root + os.sep) and os.path.isfile(file_path):
            return FileResponse(file_path)
        return FileResponse(os.path.join(STATIC_DIR, "index.html"))
175
+
176
def start():
    """Entry point for the CLI tool.

    An optional first CLI argument selects the directory to scan; otherwise
    the current working directory is used. Opens a browser tab and serves
    the app on http://127.0.0.1:8000.
    """
    global CURRENT_DIRECTORY

    # CLI Argument Parsing
    cli_args = sys.argv[1:]
    if cli_args:
        candidate = cli_args[0]
        if os.path.exists(candidate):
            CURRENT_DIRECTORY = os.path.abspath(candidate)
            print(f"Setting project path from CLI: {CURRENT_DIRECTORY}")
        else:
            print(f"Warning: Path '{candidate}' does not exist. Using defaults.")
    else:
        CURRENT_DIRECTORY = os.getcwd()
        print(f"Using current directory: {CURRENT_DIRECTORY}")

    def open_browser():
        # Give uvicorn a moment to bind before opening the tab.
        time.sleep(1.5)
        webbrowser.open("http://localhost:8000")

    threading.Thread(target=open_browser, daemon=True).start()

    # Run uvicorn programmatically.
    # Note: reload=True is not easily supported when launched this way.
    uvicorn.run(app, host="127.0.0.1", port=8000)
201
+
202
+ if __name__ == "__main__":
203
+ start()
sql_dag_flow/parser.py ADDED
@@ -0,0 +1,290 @@
1
+ import os
2
+ import sqlglot
3
+ from sqlglot import exp
4
+ import networkx as nx
5
+
6
+ import os
7
+ import sqlglot
8
+ from sqlglot import exp
9
+ import networkx as nx
10
+ import re
11
+
12
def _is_traversal_allowed(rel_dir, allowed_subfolders):
    """Return True when os.walk should descend into *rel_dir*.

    A folder is traversable when it is itself selected, or when it is an
    ancestor of a selected folder (so the walk can reach the selection).
    """
    if rel_dir in allowed_subfolders:
        return True
    return any(allowed.startswith(rel_dir + '/') for allowed in allowed_subfolders)


def parse_sql_files(directory, allowed_subfolders=None, dialect="bigquery"):
    """
    Recursively scans a directory for .sql files and parses them.

    Args:
        directory: Root directory to walk.
        allowed_subfolders: Optional list of relative folder paths
            (forward-slash separated, as produced by /scan/folders). When
            given, only files directly inside one of the listed folders are
            parsed; other folders are traversed only as needed to reach them.
        dialect: sqlglot dialect name used for parsing (default "bigquery").

    Returns:
        Dict mapping each file's base name to its metadata: id, label,
        layer, type, project, dataset, path, dependencies, content
        (plus "error" when a file failed to read or parse).
    """
    tables = {}

    for root, dirs, files in os.walk(directory):
        # Default: parse everything. With a filter, only parse files when the
        # current folder is explicitly selected, and prune 'dirs' in place so
        # os.walk only descends toward selected folders.
        # (A previous second pruning pass that kept exact/parent/child matches
        # was dead code: its kept set was a strict superset of this one, so it
        # never affected the traversal and has been removed.)
        should_parse_files = True
        if allowed_subfolders is not None:
            rel_root = os.path.relpath(root, directory).replace(os.sep, '/')
            if rel_root == ".":
                rel_root = ""

            should_parse_files = rel_root in allowed_subfolders

            dirs[:] = [
                d for d in dirs
                if _is_traversal_allowed(
                    f"{rel_root}/{d}" if rel_root else d, allowed_subfolders
                )
            ]

        if not should_parse_files:
            continue

        for file in files:
            if not file.endswith(".sql"):
                continue

            filepath = os.path.join(root, file)
            # Heuristic for table name: filename without extension
            filename_base = os.path.splitext(file)[0]

            # Layer detection based on folder structure first, then filename.
            # ("bronce" covers the Spanish spelling.)
            lower_path = filepath.lower()
            layer = "other"
            if "bronze" in lower_path or "bronce" in lower_path:
                layer = "bronze"
            elif "silver" in lower_path:
                layer = "silver"
            elif "gold" in lower_path:
                layer = "gold"

            # Read inside the try so an unreadable/mis-encoded file produces
            # an error entry instead of aborting the whole scan.
            sql_content = ""
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    sql_content = f.read()

                # Parse with the requested dialect to support constructs like
                # CREATE OR REPLACE TABLE/VIEW.
                parsed = sqlglot.parse_one(sql_content, read=dialect)

                # Detect Node Type (Table or View)
                node_type = "table"  # default
                if isinstance(parsed, exp.Create):
                    if parsed.kind == "VIEW":
                        node_type = "view"

                # Attempt to extract project/dataset from the CREATE target
                # (pattern: project.dataset.table or dataset.table).
                target_table_name = filename_base
                project = "default"
                dataset = "default"

                create_node = parsed.find(exp.Create)
                if create_node and create_node.this:
                    # sqlglot represents the target as exp.Table or exp.Schema.
                    target_exp = create_node.this
                    if isinstance(target_exp, exp.Table):
                        target_table_name = target_exp.name
                        dataset = target_exp.db or "default"
                        project = target_exp.catalog or "default"

                # Fallback: extract from filename (project.dataset.table.sql).
                if project == "default" and dataset == "default":
                    parts = filename_base.split('.')
                    if len(parts) == 3:
                        project, dataset, target_table_name = parts
                    elif len(parts) == 2:
                        dataset, target_table_name = parts

                # Fallback: treat the parent folder as the dataset, unless it
                # is just a layer name. (Riskier heuristics — e.g. grandparent
                # as project — are deliberately avoided.)
                if project == "default" and dataset == "default":
                    path_parts = os.path.normpath(filepath).split(os.sep)
                    parent_dir = path_parts[-2] if len(path_parts) > 1 else ""
                    if parent_dir.lower() not in ["bronze", "bronce", "silver", "gold", "other"] and dataset == "default":
                        dataset = parent_dir

                # Collect every table referenced in the query, skipping
                # self-references to the creation target. Fuzzy matching of
                # partially-qualified names happens later in build_graph.
                dependencies = set()
                for table in parsed.find_all(exp.Table):
                    dep_name = table.name
                    full_name = dep_name
                    if table.db:
                        full_name = f"{table.db}.{dep_name}"
                    if table.catalog:
                        full_name = f"{table.catalog}.{table.db}.{dep_name}"

                    if dep_name == target_table_name:
                        continue

                    dependencies.add(full_name)

                tables[filename_base] = {
                    # filename_base is the unique graph id; the visual label
                    # is the actual table name.
                    "id": filename_base,
                    "label": target_table_name,
                    "layer": layer,
                    "type": node_type,
                    "project": project,
                    "dataset": dataset,
                    "path": filepath,
                    "dependencies": list(dependencies),
                    "content": sql_content
                }
            except Exception as e:
                # Keep a placeholder node so the file still appears in the
                # graph with the error attached.
                print(f"Error parsing {filepath}: {e}")
                tables[filename_base] = {
                    "id": filename_base,
                    "label": filename_base,
                    "layer": layer,
                    "type": "unknown",
                    "project": "n/a",
                    "dataset": "n/a",
                    "path": filepath,
                    "dependencies": [],
                    "error": str(e),
                    "content": sql_content
                }

    return tables
209
+
210
+
211
def build_graph(tables):
    """
    Constructs nodes and edges for React Flow.

    Args:
        tables: Mapping of node id -> metadata dict as produced by
            parse_sql_files (must contain "label", "layer", "dependencies";
            "project"/"dataset" are optional).

    Returns:
        (nodes, edges): React Flow lists. Each node's data carries
        incomingCount (direct upstream tables) and nestedCount (all
        transitive upstream tables).
    """
    nodes = []
    edges = []

    # Lookup map: every identifier a dependency string might use -> node id.
    lookup = {}
    for node_id, data in tables.items():
        lookup[node_id] = node_id
        if "label" in data:
            lookup[data["label"]] = node_id

        project = data.get("project", "default")
        dataset = data.get("dataset", "default")
        table = data.get("label", "")

        if table:
            if dataset != "default":
                lookup[f"{dataset}.{table}"] = node_id
            if project != "default":
                lookup[f"{project}.{dataset}.{table}"] = node_id

    # Track incoming edges for accurate dependency counting
    incoming_edges_count = {node_id: 0 for node_id in tables}

    # Create edges. Deduplicate (target, source) pairs: a table referenced
    # under several qualified names would otherwise produce multiple edges
    # with the same id, which React Flow requires to be unique, and would
    # double-count incomingCount.
    seen_pairs = set()
    for source_id, data in tables.items():
        for dep in data["dependencies"]:
            target_id = lookup.get(dep)

            # Fuzzy lookup: if the exact match fails, fall back to the bare
            # table name (last dotted component).
            if not target_id and "." in dep:
                target_id = lookup.get(dep.split(".")[-1])

            if not target_id or target_id == source_id:
                continue  # unresolved dependency or self-reference
            if (target_id, source_id) in seen_pairs:
                continue  # duplicate reference to the same upstream table
            seen_pairs.add((target_id, source_id))

            edges.append({
                "id": f"{target_id}-{source_id}",
                "source": target_id,
                "target": source_id,
                "animated": True,
                "style": {"stroke": "#b1b1b7"}
            })
            incoming_edges_count[source_id] = incoming_edges_count.get(source_id, 0) + 1

    # Transitive dependency counts via DFS over reverse edges. This replaces
    # the previous networkx ancestors() call with an equivalent stdlib walk
    # (same result for DAGs; safe with accidental cycles).
    predecessors = {}
    for edge in edges:
        predecessors.setdefault(edge["target"], set()).add(edge["source"])

    def _ancestor_count(node_id):
        # Count all nodes with a path INTO node_id, excluding node_id itself.
        reached = set()
        stack = list(predecessors.get(node_id, ()))
        while stack:
            current = stack.pop()
            if current in reached:
                continue
            reached.add(current)
            stack.extend(predecessors.get(current, ()))
        reached.discard(node_id)  # guard: a cycle could reach the start node
        return len(reached)

    for table_name, data in tables.items():
        nodes.append({
            "id": table_name,
            "data": {
                "label": data["label"],
                "layer": data["layer"],
                "details": data,
                "incomingCount": incoming_edges_count.get(table_name, 0),
                "nestedCount": _ancestor_count(table_name)
            },
            "position": {"x": 0, "y": 0},
            "type": "custom",
        })

    return nodes, edges