mdbt 0.4.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdbt/main.py ADDED
@@ -0,0 +1,474 @@
1
+ import os
2
+ import re
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ import typing as t
7
+
8
+ import pyperclip
9
+ from click.core import Command
10
+ from click.core import Context
11
+
12
+ from mdbt.core import Core
13
+
14
+
15
class MDBT(Core):
    """dbt command helpers layered on top of ``Core``."""

    def __init__(self, test_mode: bool = False):
        """Create the helper.

        Args:
            test_mode: Passed through unchanged to ``Core.__init__``.
        """
        super().__init__(test_mode=test_mode)
19
+
20
def build(self, ctx: Context, full_refresh, select, fail_fast, threads):
    """Run ``dbt build`` with the common command-line flags.

    Args:
        ctx: Click context; not read by this method.
        full_refresh: When truthy, ``--full-refresh`` is added.
        select: Optional dbt selector passed with ``--select``.
        fail_fast: When truthy, ``--fail-fast`` is added.
        threads: Optional thread count passed with ``--threads``.

    Raises:
        DbtError: If the command reports failure, or if the subprocess
            error handler returns instead of terminating.
    """
    args = self._create_common_args(
        {
            "select": select,
            "fail_fast": fail_fast,
            "threads": threads,
            "full_refresh": full_refresh,
        }
    )
    # Pre-bind the result: if the command raises and the error handler returns,
    # the check below must see a falsy value rather than an unbound local.
    run_result = None
    try:
        run_result = self.execute_dbt_command_stream("build", args)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)

    if not run_result:
        raise DbtError("DBT build failed with errors.")
35
+
36
def trun(self, ctx: Context, full_refresh, select, fail_fast, threads):
    """Run ``dbt build`` while excluding ``self.exclude_seed_snapshot``.

    Args:
        ctx: Click context; not read by this method.
        full_refresh: When truthy, ``--full-refresh`` is added.
        select: Optional dbt selector passed with ``--select``.
        fail_fast: When truthy, ``--fail-fast`` is added.
        threads: Optional thread count passed with ``--threads``.

    Raises:
        DbtError: If the command reports failure, or if the subprocess
            error handler returns instead of terminating.
    """
    args = self._create_common_args(
        {
            "select": select,
            "fail_fast": fail_fast,
            "threads": threads,
            "full_refresh": full_refresh,
        }
    )
    args = args + ["--exclude", self.exclude_seed_snapshot]
    # Pre-bind the result so a handled subprocess error cannot leave it unbound.
    run_result = None
    try:
        run_result = self.execute_dbt_command_stream("build", args)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)

    if not run_result:
        raise DbtError("DBT build failed with errors.")
52
+
53
def run(self, ctx: Context, full_refresh, select, fail_fast, threads):
    """Run ``dbt run`` with the common command-line flags.

    Args:
        ctx: Click context; not read by this method.
        full_refresh: When truthy, ``--full-refresh`` is added.
        select: Optional dbt selector passed with ``--select``.
        fail_fast: When truthy, ``--fail-fast`` is added.
        threads: Optional thread count passed with ``--threads``.

    Raises:
        DbtError: If the command reports failure, or if the subprocess
            error handler returns instead of terminating.
    """
    args = self._create_common_args(
        {
            "select": select,
            "fail_fast": fail_fast,
            "threads": threads,
            "full_refresh": full_refresh,
        }
    )
    # Pre-bind the result so a handled subprocess error cannot leave it unbound.
    run_result = None
    try:
        run_result = self.execute_dbt_command_stream("run", args)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)

    if not run_result:
        raise DbtError("DBT build failed with errors.")
68
+
69
def test(self, ctx: Context, select, fail_fast, threads):
    """Run ``dbt test`` with the common command-line flags.

    Args:
        ctx: Click context; not read by this method.
        select: Optional dbt selector passed with ``--select``.
        fail_fast: When truthy, ``--fail-fast`` is added.
        threads: Optional thread count passed with ``--threads``.

    Raises:
        DbtError: If the command reports failure, or if the subprocess
            error handler returns instead of terminating.
    """
    args = self._create_common_args(
        {"select": select, "fail_fast": fail_fast, "threads": threads}
    )
    # Pre-bind the result so a handled subprocess error cannot leave it unbound.
    run_result = None
    try:
        run_result = self.execute_dbt_command_stream("test", args)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)

    if not run_result:
        raise DbtError("DBT build failed with errors.")
79
+
80
def unittest(self, ctx: Context, select, fail_fast):
    """Run ``dbt test`` restricted to nodes tagged ``unit-test``.

    Args:
        ctx: Click context; not read by this method.
        select: Optional dbt selector. When given it is intersected with
            ``tag:unit-test``; when empty or ``None`` only the tag is used.
        fail_fast: When truthy, ``--fail-fast`` is added.

    Raises:
        DbtError: If the command reports failure, or if the subprocess
            error handler returns instead of terminating.
    """
    unit_tag = "tag:unit-test"
    # A comma is an "and" condition in a dbt selector. Without a user selector
    # the tag must stand alone, otherwise the literal text "None" is selected.
    select = f"{select},{unit_tag}" if select else unit_tag
    args = self._create_common_args({"select": select, "fail_fast": fail_fast})
    # Pre-bind the result so a handled subprocess error cannot leave it unbound.
    run_result = None
    try:
        run_result = self.execute_dbt_command_stream("test", args)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)

    if not run_result:
        raise DbtError("DBT build failed with errors.")
91
+
92
def compile(self, ctx: Context, select):
    """Run ``dbt compile``, optionally limited to a selector.

    Args:
        ctx: Click context; not read by this method.
        select: Optional dbt selector. When falsy the command is run with
            no extra arguments, as before; otherwise it is forwarded with
            ``--select`` instead of being silently dropped.
    """
    args = ["--select", select] if select else []
    try:
        self.execute_dbt_command_stream("compile", args)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)
98
+
99
def clip_compile(self, ctx: Context, select):
    """Compile the selected node and copy the extracted SQL to the clipboard.

    Args:
        ctx: Click context; not read by this method.
        select: dbt selector passed to ``dbt compile`` with ``-s``.
    """
    try:
        self.execute_dbt_command_stream("compile", ["-s", select])
        # The streamed command output is read back from the instance and
        # reduced to its SQL portion before it goes to the clipboard.
        captured = self.dbt_execute_command_output
        pyperclip.copy(self.extract_sql_code(captured))
    except subprocess.CalledProcessError as err:
        self.handle_cmd_line_error(err)
109
+
110
def sbuild(self, ctx: Context, full_refresh, threads):
    """Start a state-based build against the locally saved manifest.

    The prior state is ``./_artifacts/manifest.json`` and the current state
    is ``./target/manifest.json``. Manifest backup/rollback is requested.

    Args:
        ctx: Click context, forwarded to ``execute_state_based_build``.
        full_refresh: Forwarded full-refresh flag.
        threads: Forwarded thread count.
    """
    print("Starting a state build based on local manifest.json")
    prior_state_dir = "_artifacts"
    # Manifest that the dbt compile step will produce (current state).
    current_manifest = os.path.join("./", "target", "manifest.json")
    # Manifest kept from the previous build (prior state).
    prior_manifest = os.path.join("./", prior_state_dir, "manifest.json")

    self.execute_state_based_build(
        ctx,
        prior_state_dir,
        prior_manifest,
        current_manifest,
        full_refresh,
        threads,
        roll_back_manifest_flag=True,
    )
128
+
129
def pbuild(self, ctx: Context, full_refresh, threads, skip_download):
    """Start a state-based build against the production manifest.

    Unless in test mode or ``skip_download`` is truthy, the
    ``get_last_artifacts`` dbt run-operation is executed first. The prior
    state is ``logs/manifest.json``; manifest rollback is not requested.

    Args:
        ctx: Click context, forwarded to ``execute_state_based_build``.
        full_refresh: Forwarded full-refresh flag.
        threads: Forwarded thread count.
        skip_download: When truthy, the artifact download step is skipped.
    """
    print("Starting a state build based on production manifest.json")
    prior_state_dir = "logs"
    should_download = not (self.test_mode or skip_download)
    try:
        if should_download:
            # Pull the latest production artifacts.
            subprocess.run(
                ["dbt", "run-operation", "get_last_artifacts"], check=True
            )
    except subprocess.CalledProcessError as err:
        self.handle_cmd_line_error(err)

    current_manifest = os.path.join("", "target", "manifest.json")
    prior_manifest = os.path.join("", prior_state_dir, "manifest.json")

    self.execute_state_based_build(
        ctx,
        prior_state_dir,
        prior_manifest,
        current_manifest,
        full_refresh,
        threads,
        roll_back_manifest_flag=False,
    )
154
+
155
def gbuild(self, ctx: Context, main, full_refresh, threads):
    """
    Build based off of a Git diff of changed models.

    Args:
        ctx: Click context. ``ctx.obj`` is read for ``build_children``,
            ``build_children_count``, ``build_parents`` and
            ``build_parents_count``.
        main: When truthy, diff against the ``main`` branch; otherwise a
            plain ``git diff`` is used.
        full_refresh: Requested full-refresh flag; it may be switched on by
            ``_scan_for_incremental_full_refresh``.
        threads: Optional thread count passed with ``--threads``.

    Returns:
        None. Returns early, without invoking dbt, when the diff contains
        no ``.sql`` file whose path includes ``models``.
    """
    diff_cmd = ["git", "diff", "--name-only"]
    if main:
        print("Building based on changes from main branch.")
        diff_cmd = ["git", "diff", "main", "--name-only"]
    result = subprocess.run(diff_cmd, stdout=subprocess.PIPE, text=True)

    suffix = ".sql"
    # Keep the bare model name: last path component with the suffix removed.
    model_names = [
        path.split("/")[-1][: -len(suffix)]
        for path in result.stdout.splitlines()
        if "models" in path and path.endswith(suffix)
    ]
    if not model_names:
        # An empty selector would otherwise be handed to dbt as ``--select ""``.
        print("No changed model .sql files found; nothing to build.")
        return

    # Construct the graph operators around every model name.
    build_children = ctx.obj.get("build_children", False)
    build_children_count = ctx.obj.get("build_children_count", None)
    build_parents = ctx.obj.get("build_parents", False)
    build_parent_count = ctx.obj.get("build_parents_count", None)

    child_suffix = ""
    if build_children:
        child_suffix = f"+{build_children_count}" if build_children_count else "+"
    parent_prefix = ""
    if build_parents:
        parent_prefix = f"{build_parent_count}+" if build_parent_count else "+"

    select_list = " ".join(
        f"{parent_prefix}{name}{child_suffix}" for name in model_names
    )

    full_refresh = self._scan_for_incremental_full_refresh(
        state_flags=["--select", select_list],
        exclude_flags=None,
        full_refresh=full_refresh,
    )

    args = [
        "--select",
        select_list,
        "--exclude",
        "resource_type:seed,resource_type:snapshot",
    ]
    if threads:
        args.extend(["--threads", str(threads)])
    if full_refresh:
        args.append("--full-refresh")

    self.execute_dbt_command_stream("build", args)
230
+
231
def execute_state_based_build(
    self,
    ctx: Context,
    artifact_dir: str,
    manifest_artifact_path: str,
    manifest_path: str,
    full_refresh: bool,
    threads: int,
    roll_back_manifest_flag: bool,
):
    """Run ``dbt build`` for nodes selected by ``state:modified``.

    Args:
        ctx: Click context. ``ctx.obj`` is read for ``build_children``,
            ``build_children_count``, ``build_parents`` and
            ``build_parents_count``.
        artifact_dir: Directory passed to dbt with ``--state``.
        manifest_artifact_path: Where the prior-state manifest lives.
        manifest_path: Where the current manifest lives.
        full_refresh: Requested full-refresh flag; it may be switched on by
            ``_scan_for_incremental_full_refresh``.
        threads: Optional thread count.
        roll_back_manifest_flag: When truthy (and not in test mode) the
            current manifest is moved to ``manifest_artifact_path`` before
            the build and moved back if the build fails.

    Raises:
        DbtError: If the build command returns a falsy result.
        SystemExit: If the build command raises ``CalledProcessError``.
    """
    manage_manifest = roll_back_manifest_flag and not self.test_mode
    if manage_manifest:
        print(
            f"Making a backup of the current manifest.json at {manifest_path} to {manifest_artifact_path}"
        )
        # The moved manifest becomes the prior state. Only used for the local
        # state build, not for the production build.
        shutil.move(manifest_path, manifest_artifact_path)
    # Execute dbt compile
    try:
        if not self.test_mode:
            subprocess.run(["dbt", "compile"], check=True)
    except subprocess.CalledProcessError as e:
        self.handle_cmd_line_error(e)

    # Construct the state selector with optional graph operators.
    build_children = ctx.obj.get("build_children", False)
    build_children_count = ctx.obj.get("build_children_count", None)
    build_parents = ctx.obj.get("build_parents", False)
    build_parent_count = ctx.obj.get("build_parents_count", None)
    state_modified_str = "state:modified"
    if build_children:
        state_modified_str = f"{state_modified_str}+"
        if build_children_count:
            state_modified_str = f"{state_modified_str}{build_children_count}"
    if build_parents:
        state_modified_str = f"+{state_modified_str}"
        if build_parent_count:
            state_modified_str = f"{build_parent_count}{state_modified_str}"

    state_flags = [
        "--select",
        state_modified_str,
        "--state",
        os.path.join("", artifact_dir) + "/",
    ]
    exclude_flags = ["--exclude", self.exclude_seed_snapshot]

    full_refresh = self._scan_for_incremental_full_refresh(
        state_flags, exclude_flags, full_refresh
    )

    run_result = None
    args = self._create_common_args({"threads": threads})
    args = args + state_flags + exclude_flags
    if full_refresh:
        args = args + ["--full-refresh"]

    try:
        run_result = self.execute_dbt_command_stream("build", args)
    except subprocess.CalledProcessError as e:
        print(f'Failure while running command: {" ".join(e.cmd)}')
        print(e.stderr)
        print(e.stdout)
        if manage_manifest:
            self.roll_back_manifest(e, manifest_artifact_path, manifest_path)
        else:
            sys.exit(e.returncode)

    if not run_result:
        message = "DBT build failed with errors."
        if manage_manifest:
            # Restore inline: roll_back_manifest expects an exception carrying
            # a return code and exits the process, which would hide DbtError.
            print(
                f"DBT build failed. Rolling back manifest state with error\n {message}"
            )
            shutil.move(manifest_artifact_path, manifest_path)
        raise DbtError(message)
306
+
307
def _scan_for_incremental_full_refresh(
    self, state_flags, exclude_flags, full_refresh
):
    """Decide whether the build has to run with a full refresh.

    The selected nodes are listed through ``self.dbt_ls_to_json``. If
    ``full_refresh`` is not already set and any listed node is configured
    with ``materialized == "incremental"``, a full refresh is turned on.

    Args:
        state_flags: Optional list of selection arguments.
        exclude_flags: Optional list of exclusion arguments.
        full_refresh: The flag requested by the caller.

    Returns:
        ``full_refresh`` unchanged when it is already truthy or no
        incremental node is found, otherwise ``True``.
    """
    args = list(state_flags or []) + list(exclude_flags or [])
    args = args + ["--output-keys", "name resource_type config"]
    models_json = self.dbt_ls_to_json(args)
    if full_refresh:
        return full_refresh

    for model in models_json or []:
        # Not every listed node carries a "materialized" setting, so look it
        # up defensively instead of indexing.
        config = model.get("config") or {}
        if config.get("materialized") == "incremental":
            print(
                f'Found incremental build model: {model.get("name")}. Initiating full refresh.'
            )
            return True
    return full_refresh
330
+
331
@staticmethod
def _create_common_args(flags: t.Dict[str, t.Any]) -> t.List[str]:
    """Translate a flag mapping into dbt command-line arguments.

    Recognised keys are ``threads``, ``select``, ``fail_fast`` and
    ``full_refresh``; missing or falsy entries contribute nothing.

    Args:
        flags: Mapping of flag name to value.

    Returns:
        The arguments in the order threads, select, fail-fast, full-refresh.
    """
    thread_count = flags.get("threads")
    selector = flags.get("select")

    cli_args = []
    if thread_count:
        cli_args += ["--threads", str(thread_count)]
    if selector:
        cli_args += ["--select", selector]
    # The remaining flags are plain switches without a value.
    cli_args += [
        switch
        for key, switch in (
            ("fail_fast", "--fail-fast"),
            ("full_refresh", "--full-refresh"),
        )
        if flags.get(key)
    ]
    return cli_args
349
+
350
@staticmethod
def roll_back_manifest(e, manifest_artifact_path, manifest_path):
    """Restore the saved manifest and terminate the process.

    Args:
        e: The failure being reported. Usually a ``CalledProcessError``, but
            a plain message string is also accepted.
        manifest_artifact_path: Location of the saved prior-state manifest.
        manifest_path: Location the manifest is moved back to.

    Raises:
        SystemExit: Always. The exit status is ``e.returncode`` when that
            attribute exists, otherwise ``1``.
    """
    print(f"DBT build failed. Rolling back manifest state with error\n {e}")
    # If the build failed we want to rebuild against the prior state, not the
    # one generated by the compile command.
    shutil.move(manifest_artifact_path, manifest_path)
    # Callers also pass a bare message string, which has no return code.
    sys.exit(getattr(e, "returncode", 1))
357
+
358
def execute_dbt_command_stream(self, command: str, args: t.List[str]) -> bool:
    """
    Run a dbt command and stream its output as it is produced.

    In test mode nothing is executed: the assembled command is stored on
    ``self.dbt_test_mode_command_check_value`` and success is reported.

    Args:
        command: The dbt sub-command to run.
        args: Arguments placed after the sub-command.

    Returns:
        True if successful, False if the captured output matches the
        error pattern checked by ``contains_errors``.
    """
    dbt_command = ["dbt", command] + args
    print(f'Running command: {" ".join(dbt_command)}')

    if self.test_mode:
        self.dbt_test_mode_command_check_value = dbt_command
        return True

    stderr, stdout = self.subprocess_stream(dbt_command)
    # A clean exit code is not trusted on its own; the text is scanned too.
    return not self.contains_errors(stdout + stderr)
382
+
383
def subprocess_stream(self, args):
    """Run a command, echoing stdout line by line while capturing it.

    Every stdout line is printed, appended to
    ``self.dbt_execute_command_output`` and collected for the return value.
    stderr is read on a helper thread so that a command writing a lot to
    stderr cannot block while this method is busy reading stdout.

    Args:
        args: The command and its arguments as a list.

    Returns:
        A ``(stderr, stdout)`` tuple containing the complete text of both
        streams (note the order).

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status; the captured stderr is attached as ``output``.
    """
    import threading  # local import: only this method needs it

    stderr_parts = []
    stdout_lines = []
    with subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,  # Ensure outputs are in text mode rather than bytes
    ) as process:
        stderr_reader = threading.Thread(
            target=lambda: stderr_parts.append(process.stderr.read()),
            daemon=True,
        )
        stderr_reader.start()

        # Real-time output streaming; iteration ends when stdout is closed.
        for raw_line in process.stdout:
            line = raw_line.rstrip()
            print(line)
            self.dbt_execute_command_output += line + "\n"
            stdout_lines.append(line)

        process.wait()
        stderr_reader.join()

    # Return everything that was streamed. Reading "the rest" after the loop
    # yields nothing, which left the caller's error scan without any output.
    stdout = "\n".join(stdout_lines)
    stderr = "".join(stderr_parts)

    if process.returncode != 0:
        print(f"Command resulted in an error: {stderr}")
        raise subprocess.CalledProcessError(
            returncode=process.returncode, cmd=args, output=stderr
        )
    return stderr, stdout
409
+
410
@staticmethod
def extract_sql_code(log: str) -> str:
    """
    Return the part of a command log that follows its header lines.

    Header lines are recognised by a leading ANSI escape sequence
    (``ESC [``). Everything after the last such line is returned; when no
    line starts with the escape sequence the whole log is returned.

    Parameters:
    log (str): The log string containing the header info and SQL code.

    Returns:
    str: The extracted SQL code, with lines joined by a newline.
    """
    lines = log.splitlines()
    header_rows = [
        position
        for position, text in enumerate(lines)
        if text.startswith("\x1b[")
    ]
    # The SQL starts right after the final header row, if there is one.
    first_sql_row = header_rows[-1] + 1 if header_rows else 0
    return "\n".join(lines[first_sql_row:])
434
+
435
@staticmethod
def contains_errors(text):
    """Report whether ``text`` mentions a count of errors.

    A match is a number followed by whitespace and ``error`` or ``errors``,
    where the number is a single digit from 2 to 9 or has two or more
    digits. A lone ``0`` or ``1`` count therefore does not match.

    Args:
        text: The text to scan.

    Returns:
        True when the pattern is found anywhere in ``text``, else False.
    """
    return re.search(r"([2-9]|\d{2,})\s+errors?", text) is not None
440
+
441
+
442
class DbtError(Exception):
    """Raised when a dbt command finishes with errors."""

    def __init__(self, message="DBT build failed with errors."):
        """Store the message.

        Args:
            message: Text describing the failure. It is kept instead of being
                replaced by a fixed string, and is also passed to
                ``Exception`` so that ``args`` is populated.
        """
        super().__init__(message)
        self.message = message

    def __str__(self):
        return str(self.message)
448
+
449
+
450
class MockCtx(Context):
    """Click context pre-populated with the ``obj`` keys the builds read."""

    def __init__(self, command: t.Optional["Command"] = None) -> None:
        """Create the context and seed ``obj`` with default values.

        Args:
            command: The click command passed to ``Context.__init__``.
        """
        super().__init__(command)
        self.obj = {
            "build_children": False,
            "build_children_count": None,
            # These two names are the ones the build methods look up.
            "build_parents": False,
            "build_parents_count": None,
            # Former key names, kept so existing users of them still work.
            "parents_children": False,
            "build_parent_count": None,
        }
459
+
460
+
461
if __name__ == "__main__":
    # Scratch harness for running a single command by hand: enable one of the
    # commented calls below as needed.
    mdbt = MDBT()
    mock_ctx = MockCtx(Command("Duck"))
    mock_ctx.obj.update(build_children=True)
    # mdbt.build(full_refresh=False, select=None, fail_fast=False)
    # mdbt.trun(full_refresh=False, select=None, fail_fast=False)
    # mdbt.run(full_refresh=False, select=None, fail_fast=False)
    # mdbt.test(select=None, fail_fast=False)
    # mdbt.compile()
    # mdbt.sbuild(ctx=None, full_refresh=False)
    # mdbt.pbuild(ctx=MockCtx(Command('Duck')), full_refresh=False)
    # mdbt.gbuild(ctx=mock_ctx, full_refresh=False, threads=8)
    # mdbt.format(ctx=mock_ctx, select=None, all=True, main=False)
    sys.exit(0)
@@ -0,0 +1,84 @@
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
+ from click.core import Context
6
+
7
+ from mdbt.core import Core
8
+
9
+
10
class PrecommitFormat(Core):
    """Helpers that run pre-commit and ``sqlfluff fix`` from the command line."""

    def __init__(self, test_mode=False):
        """Create the helper; ``test_mode`` is passed through to ``Core``."""
        super().__init__(test_mode=test_mode)

    def pre_commit(self, ctx):
        """Run ``pre-commit run --all-files``.

        Args:
            ctx: Context object; not read by this method.
        """
        try:
            subprocess.run(["pre-commit", "run", "--all-files"], check=True)
        except subprocess.CalledProcessError as err:
            self.handle_cmd_line_error(err)

    def format(self, ctx: Context, select, all=False, main=False):
        """
        Run ``sqlfluff fix`` on changed model SQL files, or on ``./models``.

        Args:
            ctx: Context object; not read by this method.
            select: Not read by this method.
            all: When truthy, ``./models`` is fixed as a whole instead of the
                changed files.
            main: When truthy, changes are taken from a diff against
                ``main``; otherwise from a plain ``git diff``.
        """
        print("Scanning for changed files since last commit.")
        # Set the env path to the .sqlfluffignore
        os.environ["SQLFLUFF_CONFIG"] = "../.sqlfluffignore"

        git_cmd = ["git", "diff", "--name-only"]
        if main:
            # Check against main rather than the last commit.
            git_cmd.append("main")
        try:
            completed = subprocess.run(
                git_cmd,
                stdout=subprocess.PIPE,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as err:
            print(f'Failure while running git command: {" ".join(err.cmd)}')
            print(err.stderr)
            print(err.stdout)
            sys.exit(err.returncode)

        # Only SQL files whose path mentions the models directory qualify.
        targets = [
            path
            for path in completed.stdout.splitlines()
            if path.endswith(".sql") and "models" in path
        ]

        if all:
            targets = ["./models"]
        elif not targets:
            print("No SQL files have changed since the last commit.")
            return

        for target in targets:
            try:
                print(f"Running sqlfluff fix on {target}")
                subprocess.run(
                    ["sqlfluff", "fix", target, "--config", "../.sqlfluff"],
                    check=True,
                )
            except subprocess.CalledProcessError as err:
                # A failure on one file is reported and the loop carries on
                # with the remaining files.
                print(f"Failure while running sqlfluff fix command on {target}")
                print(err.stderr)
                print(err.stdout)

        print("Sqlfluff fix completed for all changed SQL files.")