iflow-mcp_niclasolofsson-dbt-core-mcp 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dbt_core_mcp/__init__.py +18 -0
- dbt_core_mcp/__main__.py +436 -0
- dbt_core_mcp/context.py +459 -0
- dbt_core_mcp/cte_generator.py +601 -0
- dbt_core_mcp/dbt/__init__.py +1 -0
- dbt_core_mcp/dbt/bridge_runner.py +1361 -0
- dbt_core_mcp/dbt/manifest.py +781 -0
- dbt_core_mcp/dbt/runner.py +67 -0
- dbt_core_mcp/dependencies.py +50 -0
- dbt_core_mcp/server.py +381 -0
- dbt_core_mcp/tools/__init__.py +77 -0
- dbt_core_mcp/tools/analyze_impact.py +78 -0
- dbt_core_mcp/tools/build_models.py +190 -0
- dbt_core_mcp/tools/demo/__init__.py +1 -0
- dbt_core_mcp/tools/demo/hello.html +267 -0
- dbt_core_mcp/tools/demo/ui_demo.py +41 -0
- dbt_core_mcp/tools/get_column_lineage.py +1988 -0
- dbt_core_mcp/tools/get_lineage.py +89 -0
- dbt_core_mcp/tools/get_project_info.py +96 -0
- dbt_core_mcp/tools/get_resource_info.py +134 -0
- dbt_core_mcp/tools/install_deps.py +102 -0
- dbt_core_mcp/tools/list_resources.py +84 -0
- dbt_core_mcp/tools/load_seeds.py +179 -0
- dbt_core_mcp/tools/query_database.py +459 -0
- dbt_core_mcp/tools/run_models.py +234 -0
- dbt_core_mcp/tools/snapshot_models.py +120 -0
- dbt_core_mcp/tools/test_models.py +238 -0
- dbt_core_mcp/utils/__init__.py +1 -0
- dbt_core_mcp/utils/env_detector.py +186 -0
- dbt_core_mcp/utils/process_check.py +130 -0
- dbt_core_mcp/utils/tool_utils.py +411 -0
- dbt_core_mcp/utils/warehouse_adapter.py +82 -0
- dbt_core_mcp/utils/warehouse_databricks.py +297 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/METADATA +784 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/RECORD +38 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/WHEEL +4 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/entry_points.txt +2 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/licenses/LICENSE +21 -0
+++ dbt_core_mcp/utils/tool_utils.py
@@ -0,0 +1,411 @@
"""Utility functions for dbt tools.

Helper methods for result parsing, progress reporting, schema querying, and state management.
"""

import json
import logging
import shutil
from pathlib import Path
from typing import Any

from fastmcp.server.context import Context

logger = logging.getLogger(__name__)

def parse_run_results(project_dir: Path | None) -> dict[str, Any]:
    """Parse target/run_results.json after dbt run/test/build.

    Returns:
        Dictionary with results array and metadata
    """
    if not project_dir:
        return {"results": [], "elapsed_time": 0}

    run_results_path = project_dir / "target" / "run_results.json"
    if not run_results_path.exists():
        return {"results": [], "elapsed_time": 0}

    try:
        with open(run_results_path, encoding="utf-8") as f:
            data = json.load(f)

        # Simplify results for output
        simplified_results = []
        for result in data.get("results", []):
            simplified_result = {
                "unique_id": result.get("unique_id"),
                "status": result.get("status"),
                "message": result.get("message"),
                "execution_time": result.get("execution_time"),
                "failures": result.get("failures"),
            }

            # Include additional diagnostic fields for failed tests
            if result.get("status") in ("fail", "error"):
                simplified_result["compiled_code"] = result.get("compiled_code")
                simplified_result["adapter_response"] = result.get("adapter_response")

            simplified_results.append(simplified_result)

        return {
            "results": simplified_results,
            "elapsed_time": data.get("elapsed_time", 0),
        }
    except Exception as e:
        logger.warning(f"Failed to parse run_results.json: {e}")
        return {"results": [], "elapsed_time": 0}

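For orientation, the sketch below shows a minimal run_results.json of the shape this parser reads; the exact artifact schema varies by dbt version, and the model name and values here are purely illustrative. It assumes parse_run_results is imported from dbt_core_mcp.utils.tool_utils.

    import json
    import tempfile
    from pathlib import Path

    from dbt_core_mcp.utils.tool_utils import parse_run_results

    # Hypothetical artifact content; real run_results.json files carry many more fields.
    sample = {
        "elapsed_time": 1.23,
        "results": [
            {
                "unique_id": "model.my_project.stg_orders",
                "status": "success",
                "message": "OK",
                "execution_time": 0.45,
                "failures": None,
            }
        ],
    }

    project_dir = Path(tempfile.mkdtemp())
    (project_dir / "target").mkdir()
    (project_dir / "target" / "run_results.json").write_text(json.dumps(sample), encoding="utf-8")

    parsed = parse_run_results(project_dir)
    print(parsed["results"][0]["status"])  # -> "success"
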
def validate_and_parse_results(project_dir: Path | None, result: Any, command_name: str) -> dict[str, Any]:
    """Parse run_results.json and validate execution succeeded.

    Args:
        project_dir: Path to the dbt project
        result: The execution result from dbt runner
        command_name: Name of dbt command (e.g., "run", "test", "build", "seed")

    Returns:
        Parsed run_results dictionary

    Raises:
        RuntimeError: If dbt failed before execution (parse error, connection failure, etc.)
    """
    run_results = parse_run_results(project_dir)

    if not run_results.get("results"):
        # No results means dbt failed before execution
        if result and not result.success:
            error_msg = str(result.exception) if result.exception else f"dbt {command_name} execution failed"
            # Extract specific error from stdout if available
            if result.stdout and "Error" in result.stdout:
                lines = result.stdout.split("\n")
                for i, line in enumerate(lines):
                    if "Error" in line or "error" in line:
                        error_msg = "\n".join(lines[i : min(i + 5, len(lines))]).strip()
                        break
            else:
                # Include full stdout/stderr for debugging when no specific error found
                stdout_preview = (result.stdout[:500] + "...") if result.stdout and len(result.stdout) > 500 else (result.stdout or "(no stdout)")
                stderr_preview = (result.stderr[:500] + "...") if result.stderr and len(result.stderr) > 500 else (result.stderr or "(no stderr)")
                error_msg = f"{error_msg}\nstdout: {stdout_preview}\nstderr: {stderr_preview}"
            raise RuntimeError(f"dbt {command_name} failed to execute: {error_msg}")

    return run_results


async def report_final_progress(
    ctx: Context | None,
    results_list: list[dict[str, Any]],
    command_name: str,
    resource_type: str,
) -> None:
    """Report final progress with status breakdown.

    Args:
        ctx: MCP context for progress reporting
        results_list: List of result dictionaries from dbt execution
        command_name: Command prefix for message (e.g., "Run", "Test", "Build")
        resource_type: Resource type for message (e.g., "models", "tests", "resources")
    """
    if not ctx:
        return

    if not results_list:
        await ctx.report_progress(progress=0, total=0, message=f"0 {resource_type} matched selector")
        return

    # Count statuses - different commands use different status values
    total = len(results_list)
    passed_count = sum(1 for r in results_list if r.get("status") in ("success", "pass"))
    failed_count = sum(1 for r in results_list if r.get("status") in ("error", "fail"))
    skip_count = sum(1 for r in results_list if r.get("status") in ("skipped", "skip"))
    warn_count = sum(1 for r in results_list if r.get("status") == "warn")

    # Build status parts
    parts = []
    if passed_count > 0:
        # Use "All passed" only if no other statuses present
        has_other_statuses = failed_count > 0 or warn_count > 0 or skip_count > 0
        parts.append(f"✅ {passed_count} passed" if has_other_statuses else "✅ All passed")
    if failed_count > 0:
        parts.append(f"❌ {failed_count} failed")
    if warn_count > 0:
        parts.append(f"⚠️ {warn_count} warned")
    if skip_count > 0:
        parts.append(f"⏭️ {skip_count} skipped")

    summary = f"{command_name}: {total}/{total} {resource_type} completed ({', '.join(parts)})"
    await ctx.report_progress(progress=total, total=total, message=summary)

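Taken together, these helpers suggest a typical call order inside a tool: clear the stale artifact (via clear_stale_run_results, defined further below), invoke dbt, validate the resulting artifact, then report a summary. The sketch below is illustrative only; execute_dbt_run is a hypothetical stand-in for whatever runner call actually executes dbt and is not part of this module.

    from pathlib import Path

    from fastmcp.server.context import Context

    from dbt_core_mcp.utils.tool_utils import (
        clear_stale_run_results,
        report_final_progress,
        validate_and_parse_results,
    )


    async def run_and_report(project_dir: Path, ctx: Context | None, execute_dbt_run) -> dict:
        clear_stale_run_results(project_dir)  # avoid reading a previous run's artifact
        result = await execute_dbt_run()      # assumed to expose .success / .stdout / .stderr / .exception
        run_results = validate_and_parse_results(project_dir, result, "run")
        await report_final_progress(ctx, run_results["results"], "Run", "models")
        return run_results
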
async def get_table_schema_from_db(runner: Any, model_name: str, source_name: str | None = None) -> list[dict[str, Any]]:
    """Get full table schema from database using DESCRIBE.

    Args:
        runner: BridgeRunner instance
        model_name: Name of the model/table
        source_name: If provided, treat as source and use source() instead of ref()

    Returns:
        List of column dictionaries with details (column_name, column_type, null, etc.)
        Empty list if query fails or table doesn't exist
    """
    try:
        if source_name:
            sql = f"DESCRIBE {{{{ source('{source_name}', '{model_name}') }}}}"
        else:
            sql = f"DESCRIBE {{{{ ref('{model_name}') }}}}"
        result = await runner.invoke_query(sql)  # type: ignore

        if not result.success or not result.stdout:
            return []

        # Parse JSON output using robust regex + JSONDecoder
        import re

        json_match = re.search(r'\{\s*"show"\s*:\s*\[', result.stdout)
        if not json_match:
            return []

        decoder = json.JSONDecoder()
        data, _ = decoder.raw_decode(result.stdout, json_match.start())

        if "show" in data:
            return data["show"]  # type: ignore[no-any-return]

        return []
    except Exception as e:
        logger.warning(f"Failed to query table schema for {model_name}: {e}")
        return []


async def get_table_columns_from_db(runner: Any, model_name: str) -> list[str]:
    """Get actual column names from database table.

    Args:
        runner: BridgeRunner instance
        model_name: Name of the model

    Returns:
        List of column names from the actual table
    """
    schema = await get_table_schema_from_db(runner, model_name)
    if not schema:
        return []

    # Extract column names from schema
    columns: list[str] = []
    for row in schema:
        # Try common column name fields
        col_name = row.get("column_name") or row.get("Field") or row.get("name") or row.get("COLUMN_NAME")
        if col_name and isinstance(col_name, str):
            columns.append(col_name)

    logger.info(f"Extracted {len(columns)} columns for {model_name}: {columns}")
    return sorted(columns)


def clear_stale_run_results(project_dir: Path | None) -> None:
    """Delete stale run_results.json before command execution.

    This prevents reading cached results from previous runs.
    """
    if not project_dir:
        return

    run_results_path = project_dir / "target" / "run_results.json"
    if run_results_path.exists():
        try:
            run_results_path.unlink()
            logger.debug("Deleted stale run_results.json before execution")
        except OSError as e:
            logger.warning(f"Could not delete stale run_results.json: {e}")


async def save_execution_state(runner: Any, project_dir: Path | None) -> None:
    """Save current manifest as state for future state-based runs.

    After successful execution, saves manifest.json to target/state_last_run/
    so future runs can use --state to detect modifications.
    """
    if not project_dir:
        return

    state_dir = project_dir / "target" / "state_last_run"
    state_dir.mkdir(parents=True, exist_ok=True)

    manifest_path = runner.get_manifest_path()  # type: ignore

    try:
        shutil.copy(manifest_path, state_dir / "manifest.json")
        logger.debug(f"Saved execution state to {state_dir}")
    except OSError as e:
        logger.warning(f"Failed to save execution state: {e}")


def get_project_paths(project_dir: Path | None) -> dict[str, list[str]]:
    """Read configured paths from dbt_project.yml.

    Returns:
        Dictionary with path types as keys and lists of paths as values
    """
    if not project_dir:
        return {}

    project_file = project_dir / "dbt_project.yml"
    if not project_file.exists():
        return {}

    try:
        import yaml

        with open(project_file, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        return {
            "model-paths": config.get("model-paths", ["models"]),
            "seed-paths": config.get("seed-paths", ["seeds"]),
            "snapshot-paths": config.get("snapshot-paths", ["snapshots"]),
            "analysis-paths": config.get("analysis-paths", ["analyses"]),
            "macro-paths": config.get("macro-paths", ["macros"]),
            "test-paths": config.get("test-paths", ["tests"]),
            "target-path": config.get("target-path", "target"),
        }
    except Exception as e:
        logger.warning(f"Failed to parse dbt_project.yml: {e}")
        return {}

def compare_model_schemas(
    project_dir: Path | None,
    model_unique_ids: list[str],
    state_manifest_path: Path,
    current_manifest_data: dict[str, Any],
) -> dict[str, Any]:
    """Compare schemas of models before and after run.

    Args:
        project_dir: Path to dbt project
        model_unique_ids: List of model unique IDs that were run
        state_manifest_path: Path to the saved state manifest.json
        current_manifest_data: Current manifest dictionary

    Returns:
        Dictionary with schema changes per model
    """
    if not state_manifest_path.exists():
        return {}

    try:
        # Load state (before) manifest
        with open(state_manifest_path, encoding="utf-8") as f:
            state_manifest = json.load(f)

        schema_changes: dict[str, dict[str, Any]] = {}

        for unique_id in model_unique_ids:
            # Skip non-model nodes (like tests)
            if not unique_id.startswith("model."):
                continue

            # Get before and after column definitions
            before_node = state_manifest.get("nodes", {}).get(unique_id, {})
            after_node = current_manifest_data.get("nodes", {}).get(unique_id, {})

            before_columns = before_node.get("columns", {})
            after_columns = after_node.get("columns", {})

            # Skip if no column definitions exist (not in schema.yml)
            if not before_columns and not after_columns:
                continue

            # Compare columns
            before_names = set(before_columns.keys())
            after_names = set(after_columns.keys())

            added = sorted(after_names - before_names)
            removed = sorted(before_names - after_names)

            # Check for type changes in common columns
            changed_types = {}
            for col in before_names & after_names:
                before_type = before_columns[col].get("data_type")
                after_type = after_columns[col].get("data_type")
                if before_type != after_type and before_type is not None and after_type is not None:
                    changed_types[col] = {"from": before_type, "to": after_type}

            # Only record if there are actual changes
            if added or removed or changed_types:
                model_name = after_node.get("name", unique_id.split(".")[-1])
                schema_changes[model_name] = {
                    "changed": True,
                    "added_columns": added,
                    "removed_columns": removed,
                    "changed_types": changed_types,
                }

        return schema_changes

    except Exception as e:
        logger.warning(f"Failed to compare schemas: {e}")
        return {}

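The returned structure groups changes by model name, using the keys built in the loop above. As an illustration (model and column names hypothetical), a model that gained one column, lost another, and widened a type would be reported roughly as:

    schema_changes = {
        "stg_orders": {
            "changed": True,
            "added_columns": ["order_total"],
            "removed_columns": ["amount"],
            "changed_types": {"order_id": {"from": "int", "to": "bigint"}},
        }
    }
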
def manifest_exists(project_dir: Path | None) -> bool:
    """Check if manifest.json exists.

    Simple check - tools will handle their own parsing as needed.
    """
    if project_dir is None:
        return False
    manifest_path = project_dir / "target" / "manifest.json"
    return manifest_path.exists()


async def prepare_state_based_selection(
    project_dir: Path | None,
    select_state_modified: bool,
    select_state_modified_plus_downstream: bool,
    select: str | None,
) -> str | None:
    """Validate and prepare state-based selection.

    Args:
        project_dir: Path to dbt project
        select_state_modified: Use state:modified selector
        select_state_modified_plus_downstream: Extend to state:modified+
        select: Manual selector (conflicts with state-based)

    Returns:
        The dbt selector string to use ("state:modified" or "state:modified+"), or None if:
        - Not using state-based selection
        - No previous state exists (cannot determine modifications)

    Raises:
        ValueError: If validation fails
    """
    # Validate: hierarchical requirement
    if select_state_modified_plus_downstream and not select_state_modified:
        raise ValueError("select_state_modified_plus_downstream requires select_state_modified=True")

    # Validate: can't use both state-based and manual selection
    if select_state_modified and select:
        raise ValueError("Cannot use both select_state_modified* flags and select parameter")

    # If not using state-based selection, return None
    if not select_state_modified:
        return None

    # Check if state exists
    if not project_dir:
        return None

    state_dir = project_dir / "target" / "state_last_run"
    if not state_dir.exists():
        # No state - cannot determine modifications
        return None

    # Return selector (state exists)
    return "state:modified+" if select_state_modified_plus_downstream else "state:modified"
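One way the state helpers fit together, sketched under the assumption that the caller assembles dbt's --select/--state CLI arguments itself (the surrounding build_selection_args function is hypothetical, not part of this package):

    from pathlib import Path

    from dbt_core_mcp.utils.tool_utils import prepare_state_based_selection


    async def build_selection_args(project_dir: Path) -> list[str]:
        # Prefer running only models modified since the last saved state, plus downstream.
        selector = await prepare_state_based_selection(
            project_dir,
            select_state_modified=True,
            select_state_modified_plus_downstream=True,
            select=None,
        )
        if selector is None:
            # No prior state under target/state_last_run: fall back to an unselected run.
            return []
        return ["--select", selector, "--state", str(project_dir / "target" / "state_last_run")]

After a successful run, save_execution_state(runner, project_dir) copies the fresh manifest into target/state_last_run so the next call can detect modifications.
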
+++ dbt_core_mcp/utils/warehouse_adapter.py
@@ -0,0 +1,82 @@
"""
Warehouse Adapter Protocol.

Provides an interface for database-specific warehouse operations like pre-warming,
with implementations for different database platforms (Databricks, Snowflake, etc.).
"""

import logging
from pathlib import Path
from typing import Any, Callable, Protocol

logger = logging.getLogger(__name__)


class WarehouseAdapter(Protocol):
    """Protocol for warehouse-specific operations."""

    async def prewarm(self, progress_callback: Callable[[int, int, str], Any] | None = None) -> None:
        """
        Pre-warm the warehouse/cluster before executing dbt commands.

        This method is called before dbt operations that require database access.
        For serverless warehouses, this starts the warehouse and waits for it to be ready.
        For other databases, this may be a no-op.

        Multiple calls to prewarm() should be safe - if the warehouse is already running,
        the operation should be idempotent.

        Args:
            progress_callback: Optional callback for progress updates (current, total, message)
        """
        ...


class NoOpWarehouseAdapter:
    """
    Default no-op warehouse adapter for databases that don't need pre-warming.

    Used for databases like Postgres, DuckDB, BigQuery, etc. that don't have
    cold-start delays or where pre-warming isn't beneficial.
    """

    async def prewarm(self, progress_callback: Callable[[int, int, str], Any] | None = None) -> None:
        """No-op pre-warm for databases that don't need it."""
        logger.debug("No warehouse pre-warming needed for this database type")


def create_warehouse_adapter(project_dir: Path, adapter_type: str) -> WarehouseAdapter:
    """
    Factory function to create the appropriate warehouse adapter.

    Args:
        project_dir: Path to the dbt project directory
        adapter_type: The dbt adapter type (e.g., 'databricks', 'snowflake', 'postgres')

    Returns:
        WarehouseAdapter instance for the specified database type

    Examples:
        >>> adapter = create_warehouse_adapter(Path("/project"), "databricks")
        >>> await adapter.prewarm()  # Starts Databricks serverless warehouse

        >>> adapter = create_warehouse_adapter(Path("/project"), "postgres")
        >>> await adapter.prewarm()  # No-op for Postgres
    """
    adapter_type_lower = adapter_type.lower()

    if adapter_type_lower == "databricks":
        # Import here to avoid dependency issues if databricks libs not installed
        from .warehouse_databricks import DatabricksWarehouseAdapter

        logger.info(f"Creating Databricks warehouse adapter for {project_dir}")
        return DatabricksWarehouseAdapter(project_dir)

    # TODO: Add Snowflake adapter when needed
    # elif adapter_type_lower == "snowflake":
    #     from .warehouse_snowflake import SnowflakeWarehouseAdapter
    #     return SnowflakeWarehouseAdapter(project_dir)

    # Default to no-op for all other databases
    logger.info(f"Using no-op warehouse adapter for {adapter_type}")
    return NoOpWarehouseAdapter()