sqlprism 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
1
+ """SQL parser registry and file-type utilities.
2
+
3
+ All SQL dialects are handled by a single ``SqlParser`` class backed by sqlglot.
4
+ The dialect is specified at indexing time per-repo or per-path, not per-parser.
5
+
6
+ Attributes:
7
+ SQL_EXTENSIONS: Set of file extensions recognised as SQL
8
+ (``{".sql", ".ddl", ".hql"}``).
9
+ """
10
+
11
# File extensions (lower-case, with a leading dot) that are treated as SQL sources.
SQL_EXTENSIONS: set[str] = {".sql", ".ddl", ".hql"}


def is_sql_file(file_path: str) -> bool:
    """Return whether *file_path* ends with a recognised SQL extension.

    The match is case-insensitive: the path is lower-cased before it is
    compared against the entries of ``SQL_EXTENSIONS``.

    Args:
        file_path: File path or bare file name to check.

    Returns:
        ``True`` if the lower-cased path ends with any extension in
        ``SQL_EXTENSIONS``, otherwise ``False``.
    """
    # str.endswith accepts a tuple of suffixes and tests them all in one call.
    # The set is read on every call, so runtime changes to it are honoured.
    return file_path.lower().endswith(tuple(SQL_EXTENSIONS))
@@ -0,0 +1,199 @@
1
+ """dbt model renderer.
2
+
3
+ Runs `dbt compile` via subprocess to expand Jinja macros and resolve refs,
4
+ then reads the compiled SQL from target/compiled/ and feeds each model
5
+ to the standard SQL parser.
6
+
7
+ Unlike sqlmesh, dbt is NOT a Python dependency — we shell out to `dbt compile`
8
+ via `uv run` (so it uses the dbt project's own virtualenv). The user passes
9
+ the path to the dbt project directory and optionally a profiles dir and env file.
10
+
11
+ The venv may live in a parent directory (e.g. dbt/ has the .venv, but the
12
+ actual project is dbt/dp_starrocks/). Use `venv_dir` to control where
13
+ `uv run` executes from (defaults to project_path's parent if a .venv
14
+ is found there).
15
+
16
+ The dbt command can be customised (e.g. "uvx --with dbt-starrocks dbt" or
17
+ just "dbt" if globally installed).
18
+ """
19
+
20
+ import shlex
21
+ import subprocess
22
+ from pathlib import Path
23
+
24
+ from sqlprism.languages.sql import SqlParser
25
+ from sqlprism.languages.sqlmesh import _validate_command
26
+ from sqlprism.languages.utils import build_env, enrich_nodes, find_venv_dir
27
+ from sqlprism.types import ParseResult
28
+
29
+
30
class DbtRenderer:
    """Compiles dbt models via ``dbt compile`` and parses the resulting SQL.

    Shells out to dbt (via ``uv run`` or a custom command) in the project's
    own virtualenv, then reads compiled SQL from ``target/compiled/`` and
    feeds each model through ``SqlParser``. dbt is not a Python dependency
    of the indexer -- it uses whatever version the project has installed.
    """

    def __init__(self, sql_parser: SqlParser | None = None):
        """Initialise the renderer.

        Args:
            sql_parser: ``SqlParser`` instance to use for parsing compiled SQL.
                Creates a default instance if not provided.
        """
        self.sql_parser = sql_parser if sql_parser is not None else SqlParser()

    def render_project(
        self,
        project_path: str | Path,
        profiles_dir: str | Path | None = None,
        env_file: str | Path | None = None,
        target: str | None = None,
        dbt_command: str = "uv run dbt",
        venv_dir: str | Path | None = None,
        dialect: str | None = None,
        schema_catalog: dict | None = None,
    ) -> dict[str, ParseResult]:
        """Compile all dbt models and parse the resulting SQL.

        Args:
            project_path: Path to dbt project dir (containing dbt_project.yml).
            profiles_dir: Path to directory containing profiles.yml
                (defaults to project_path).
            env_file: Optional .env file passed to ``build_env`` to construct
                the environment for the dbt subprocess.
            target: dbt target name (default: whatever profiles.yml specifies).
            dbt_command: Command to invoke dbt (default: "uv run dbt"). It is
                checked by ``_validate_command`` before it is run.
            venv_dir: Directory to run the dbt command from (where .venv
                lives). When omitted, ``find_venv_dir(project_path)`` chooses
                it. Per the module notes, that is project_path, or its parent
                if only the parent has a .venv.
            dialect: SQL dialect for parsing (e.g. "starrocks", "mysql",
                "postgres"). If it differs from the configured parser's
                dialect, a fresh ``SqlParser(dialect=dialect)`` is used.
            schema_catalog: Optional schema mapping forwarded unchanged to
                ``SqlParser.parse(..., schema=...)``.

        Returns:
            Dict mapping each model's path relative to the compiled
            ``models/`` directory (e.g. ``staging/orders.sql``) to its
            ParseResult. Empty if dbt produced no compiled models directory.
            Files whose compiled SQL is blank are omitted.

        Raises:
            RuntimeError: If ``dbt compile`` exits with a non-zero status.
            subprocess.TimeoutExpired: If ``dbt compile`` runs longer than
                300 seconds.
            FileNotFoundError: If dbt_project.yml is missing.
            ValueError: If the project name cannot be read from
                dbt_project.yml.
        """
        project_path = Path(project_path).resolve()
        profiles_dir = Path(profiles_dir).resolve() if profiles_dir else project_path

        # Determine where to run the dbt command from (where .venv lives).
        cwd = Path(venv_dir).resolve() if venv_dir else find_venv_dir(project_path)

        # Use a dialect-specific parser if needed (e.g. starrocks uses backticks).
        parser = self.sql_parser
        if dialect and dialect != getattr(parser, "dialect", None):
            parser = SqlParser(dialect=dialect)

        env = build_env(env_file)

        self._run_dbt_compile(
            project_path=project_path,
            profiles_dir=profiles_dir,
            cwd=cwd,
            env=env,
            target=target,
            dbt_command=dbt_command,
        )

        # dbt writes compiled models under target/compiled/<project_name>/models/.
        project_name = self._get_project_name(project_path)
        compiled_dir = project_path / "target" / "compiled" / project_name / "models"
        if not compiled_dir.exists():
            return {}

        results: dict[str, ParseResult] = {}
        # Sorted so the result order does not depend on filesystem listing order.
        for sql_file in sorted(compiled_dir.rglob("*.sql")):
            # rglob also yields directories whose names end in ".sql"; skip them.
            if not sql_file.is_file():
                continue

            rel = sql_file.relative_to(compiled_dir)
            rel_path = str(rel)
            # Read as UTF-8 explicitly rather than relying on the locale encoding.
            content = sql_file.read_text(encoding="utf-8", errors="replace")
            if not content.strip():
                continue

            model_schema, model_name = self._model_identity(rel)
            wrapped_sql = self._wrap_as_create_table(model_schema, model_name, content)

            result = parser.parse(rel_path, wrapped_sql, schema=schema_catalog)
            enrich_nodes(result, "dbt_model", rel_path)
            results[rel_path] = result

        return results

    @staticmethod
    def _model_identity(rel: Path) -> tuple[str | None, str]:
        """Derive ``(schema, name)`` from a model path relative to ``models/``.

        The name is the file name without its ``.sql`` suffix. The schema is
        the parent directories joined with ``/``, or ``None`` for a top-level
        file. ``Path.parts`` is used so this works with any OS path separator.
        """
        parts = rel.parts
        model_name = parts[-1].removesuffix(".sql")  # e.g. "orders"
        model_schema = "/".join(parts[:-1]) or None  # e.g. "staging"
        return model_schema, model_name

    @staticmethod
    def _wrap_as_create_table(model_schema: str | None, model_name: str, content: str) -> str:
        """Wrap a compiled model as ``CREATE TABLE <model> AS <content>``.

        dbt compiled SQL is a bare SELECT. Giving it a target table lets the
        SQL parser extract nodes, edges, and column usage for the model.
        """
        # Quote names (doubling embedded quotes) to handle dashes and special chars.
        safe_name = model_name.replace('"', '""')
        if model_schema:
            safe_schema = model_schema.replace('"', '""')
            return f'CREATE TABLE "{safe_schema}"."{safe_name}" AS\n{content}'
        return f'CREATE TABLE "{safe_name}" AS\n{content}'

    def _run_dbt_compile(
        self,
        project_path: Path,
        profiles_dir: Path,
        cwd: Path,
        env: dict[str, str],
        target: str | None,
        dbt_command: str,
    ) -> None:
        """Run ``dbt compile`` against *project_path*.

        Raises:
            RuntimeError: If dbt exits non-zero. The message includes the
                process's stderr and stdout.
            subprocess.TimeoutExpired: If dbt runs longer than 300 seconds.
        """
        _validate_command(dbt_command, allowed_keywords={"dbt", "uv", "uvx"})
        cmd = shlex.split(dbt_command) + [
            "compile",
            "--project-dir",
            str(project_path),
            "--profiles-dir",
            str(profiles_dir),
        ]
        if target:
            cmd.extend(["--target", target])

        result = subprocess.run(
            cmd,
            cwd=cwd,
            env=env,
            capture_output=True,
            text=True,
            timeout=300,  # 5 min timeout for large projects
        )

        if result.returncode != 0:
            # dbt typically writes its log output, including compile errors, to
            # stdout. Reporting stderr alone often yields an empty message.
            details = "\n".join(
                stream.strip()
                for stream in (result.stderr, result.stdout)
                if stream and stream.strip()
            )
            raise RuntimeError(f"dbt compile failed (exit {result.returncode}):\n{details}")

    def _get_project_name(self, project_path: Path) -> str:
        """Read the top-level ``name`` from dbt_project.yml.

        PyYAML is used when it is importable. Otherwise, or if YAML parsing
        fails or yields no usable name, unindented ``name:`` lines are scanned.

        Raises:
            FileNotFoundError: If dbt_project.yml does not exist.
            ValueError: If no non-empty top-level name can be found.
        """
        dbt_project_file = project_path / "dbt_project.yml"
        if not dbt_project_file.exists():
            raise FileNotFoundError(f"No dbt_project.yml found in {project_path}")

        content = dbt_project_file.read_text(encoding="utf-8")

        # Try proper YAML parsing first (pyyaml may be available via dbt).
        try:
            import yaml
        except ImportError:
            data = None
        else:
            try:
                data = yaml.safe_load(content)
            except Exception:
                # Deliberately best-effort: the file may hold unquoted Jinja, or
                # values yaml cannot construct (which can raise non-YAMLError
                # exceptions). Fall back to line scanning below.
                data = None

        if isinstance(data, dict):
            name = data.get("name")
            # Skip a null/empty value rather than returning "None" or "".
            if name is not None and str(name).strip():
                return str(name)

        # Fallback: scan for a top-level (unindented) "name:" key.
        for line in content.splitlines():
            if not line.startswith("name:"):
                continue
            name = self._parse_scalar(line[len("name:"):])
            if name:
                return name

        raise ValueError(f"Could not find 'name:' in {dbt_project_file}")

    @staticmethod
    def _parse_scalar(raw: str) -> str:
        """Extract a simple YAML scalar from the text after a ``key:``.

        Handles single- or double-quoted values and strips a trailing
        ``# comment`` from unquoted values. This is not a full YAML parser.
        """
        value = raw.strip()
        if value[:1] in ("'", '"'):
            quote = value[0]
            end = value.find(quote, 1)
            if end != -1:
                return value[1:end]
            # Unterminated quote: keep the previous lenient behavior.
            return value.strip("'\"")
        # In YAML an unquoted value ends at whitespace followed by '#'.
        value = value.split(" #", 1)[0].split("\t#", 1)[0]
        return value.strip()