mdbt 0.4.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdbt/__init__.py +0 -0
- mdbt/ai_core.py +116 -0
- mdbt/build_dbt_docs_ai.py +147 -0
- mdbt/build_unit_test_data_ai.py +129 -0
- mdbt/cmdline.py +368 -0
- mdbt/core.py +113 -0
- mdbt/expectations_output_builder.py +74 -0
- mdbt/lightdash.py +84 -0
- mdbt/main.py +474 -0
- mdbt/precommit_format.py +84 -0
- mdbt/prompts.py +244 -0
- mdbt/recce.py +66 -0
- mdbt/sort_yaml_fields.py +148 -0
- mdbt/sql_sorter.py +165 -0
- mdbt-0.4.27.dist-info/METADATA +28 -0
- mdbt-0.4.27.dist-info/RECORD +20 -0
- mdbt-0.4.27.dist-info/WHEEL +5 -0
- mdbt-0.4.27.dist-info/entry_points.txt +2 -0
- mdbt-0.4.27.dist-info/licenses/LICENSE +21 -0
- mdbt-0.4.27.dist-info/top_level.txt +1 -0
mdbt/__init__.py
ADDED
File without changes
mdbt/ai_core.py
ADDED
@@ -0,0 +1,116 @@
import os
import re
import subprocess
from typing import Dict
from typing import List

import openai
from snowflake.connector import DatabaseError

from mdbt.core import Core
from mdbt.prompts import Prompts

# Have to load env before importing the openai package.
# flake8: noqa: E402


class AiCore(Core):

    def __init__(self, model: str = "gpt-4.1", test_mode: bool = False):
        super().__init__(test_mode=test_mode)
        self.model = model
        # Make sure you have OPENAI_API_KEY set in your environment variables.
        self.client = openai.OpenAI()

        self.prompts = Prompts()

    def send_message(self, _messages: List[Dict[str, str]]) -> object:
        print("Sending to API")
        completion = self.client.chat.completions.create(
            model=self.model, messages=_messages
        )
        return completion.choices[0].message.content

    @staticmethod
    def read_file(path: str) -> str:
        with open(path, "r") as file:
            return file.read()

    @staticmethod
    def is_file_committed(file_path):
        try:
            # Check the Git status of the file
            subprocess.run(
                ["git", "ls-files", "--error-unmatch", file_path],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            # If the file is tracked, check whether it has any modifications
            status_result = subprocess.run(
                ["git", "status", "--porcelain", file_path], stdout=subprocess.PIPE
            )
            status_output = status_result.stdout.decode().strip()
            # If the output is empty, the file is committed and has no modifications
            return len(status_output) == 0
        except subprocess.CalledProcessError:
            # The file is either untracked or does not exist
            return False

    def _get_sample_data_from_snowflake(self, model_names: List[str]) -> Dict[str, str]:
        """
        Compiles the target model to SQL, then breaks out each sub query and CTE into
        separate SQL strings, executing each to get a sample of the data.

        Args:
            model_names: A list of target model names to pull sample data from.

        Returns:
            A dictionary of model names and their sample data in CSV format.
        """
        sample_results = {}
        for model_name in model_names:
            print(f"Getting sample data for {model_name}")
            args = ["--select", model_name]
            cmd = "compile"
            results = self.execute_dbt_command_capture(cmd, args)
            extracted_sql = self.extract_sql(results)
            sample_sql = self.build_sample_sql(extracted_sql)
            try:
                self._cur.execute(sample_sql)
            except DatabaseError as e:
                print(f"Error executing sample SQL for {model_name}")
                print(e)
                print("\n\n" + sample_sql + "\n\n")
                raise e
            tmp_df = self._cur.fetch_pandas_all()
            sample_results[model_name] = tmp_df.to_csv(index=False)
        print(f"Sample results: {sample_results}")
        return sample_results

    @staticmethod
    def build_sample_sql(sql: str) -> str:
        sql = f"""
        with tgt_table as (
        {sql}
        )
        select *
        from tgt_table
        sample (10 rows)
        """
        return sql

    @staticmethod
    def extract_sql(log):
        sql_lines = [line for line in log.splitlines() if not re.match(r"--\s.*", line)]

        keyword_line_index = 0
        for i, line in enumerate(sql_lines):
            if "Compiled node" in line:
                keyword_line_index = i + 1
                break

        sql_lines = sql_lines[keyword_line_index:]

        # Join the remaining lines and remove escape sequences
        sql = "\n".join(sql_lines).replace("\x1b[0m", "").strip()
        return sql
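As a usage sketch (editor-added, not shipped in the wheel): the two static helpers above can be tried without any API or warehouse credentials, since only `extract_sql` and `build_sample_sql` are exercised. The compile log below is hypothetical text invented for illustration, shaped the way the code expects (a "Compiled node" marker followed by SQL).

# Illustrative only; assumes mdbt and its dependencies are importable.
from mdbt.ai_core import AiCore

fake_compile_log = """\
12:00:01  Running with dbt=1.7.0
12:00:02  Compiled node 'revenue_by_dvm' is:
select client_id, amount
from analytics.revenue
-- comment lines like this one are filtered out
"""

compiled_sql = AiCore.extract_sql(fake_compile_log)  # keeps only the lines after "Compiled node"
sample_sql = AiCore.build_sample_sql(compiled_sql)   # wraps the query in a CTE with SAMPLE (10 ROWS)
print(sample_sql)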
mdbt/build_dbt_docs_ai.py
ADDED
@@ -0,0 +1,147 @@
import subprocess

import pyperclip
from dotenv import find_dotenv
from dotenv import load_dotenv

from mdbt.ai_core import AiCore
from mdbt.prompts import Prompts

load_dotenv(find_dotenv("../.env"))
load_dotenv(find_dotenv(".env"))


class BuildDBTDocs(AiCore):
    """
    # Make sure you have OPENAI_API_KEY set in your environment variables.
    """

    def __init__(self):
        super().__init__()

    def main(self, model_name, sys_context, is_new=False):
        if model_name.endswith(".sql"):
            model_name = model_name[:-4]
        if not is_new:
            print(
                """
                1) Build new DBT documentation.
                2) Check existing DBT documentation against model for missing definitions.
                """
            )
            mode = int(input())
        else:
            mode = 1
        print("Getting file.")
        sql_file_path = self.get_file_path(model_name)

        if "l4" in sql_file_path.lower() or "l3" in sql_file_path.lower():
            system_instructions = Prompts().dbt_docs_gte_l3_prompt
        else:
            system_instructions = Prompts().dbt_docs_lte_l2_prompt

        if sys_context:
            system_instructions += f"\nContext about the system the docs are generated for: \n{sys_context}\n"

        sample_data = self._get_sample_data_from_snowflake([model_name])

        system_instructions = system_instructions + sample_data[model_name]

        # Might bring this back in the future.
        extra_info = ""

        if mode == 1:
            # Build new documentation
            user_input = self.build_user_msg_mode_1(sql_file_path, extra_info)
            yml_file_path = sql_file_path.replace(".sql", ".yml")
        elif mode == 2:
            # Check existing documentation
            yml_file_path = sql_file_path[:-4] + ".yml"
            user_input = self.build_user_msg_mode_2(
                sql_file_path, yml_file_path, extra_info
            )
        else:
            print(mode)
            raise ValueError("Invalid mode")

        messages = [
            {"role": "user", "content": system_instructions + "\n" + user_input}
        ]

        assistant_responses = []
        result = self.send_message(messages)
        assistant_responses.append(result)

        messages.append({"role": "assistant", "content": assistant_responses[0]})
        print(assistant_responses[0])
        output = assistant_responses[0]
        # Remove trailing markdown code fences ("```") from the end of the output, if present
        lines = output.split('\n')
        if lines and '```' in lines[-1].strip():
            lines = lines[:-1]
        elif len(lines) > 1 and '```' in lines[-2].strip():
            # Remove the second-to-last line if it's a code fence
            lines.pop(-2)
        output = '\n'.join(lines)
        if not is_new:
            clip_or_file = input(
                f"1. to copy to clipboard\n2. to write to file ({yml_file_path})\n:"
            )
        else:
            clip_or_file = "2"

        if clip_or_file == "1":
            print("Output copied to clipboard")
            pyperclip.copy(output)
        elif clip_or_file == "2":
            if mode == 2:
                # Make a backup of the current YML file.
                self.backup_existing_yml_file(yml_file_path)
            output = assistant_responses[0].split("\n")
            # output = output[1:-1]
            output = output[1:]
            output = "\n".join(output)
            with open(yml_file_path, "w") as file:
                file.write(output)
            if not self.is_file_committed(yml_file_path):
                if not is_new:
                    commit_file = input("Press 1 to add to git, any other key to bypass: ")
                else:
                    commit_file = "1"

                if commit_file == "1":
                    subprocess.run(["git", "add", yml_file_path])

    @staticmethod
    def backup_existing_yml_file(yml_file_path):
        with open(yml_file_path, "r") as file:
            yml_content = file.read()
        with open(yml_file_path + ".bak", "w") as file:
            file.write(yml_content)

    def build_user_msg_mode_1(self, _sql_file_path: str, extra_info: str) -> str:
        self.read_file(_sql_file_path)
        model_name = _sql_file_path.split("/")[-1].split(".")[0]
        prompt_str = f"Build new DBT documentation for the following SQL query with model name {model_name}"
        if len(extra_info):
            prompt_str += f"\n{extra_info}"

        return prompt_str

    def build_user_msg_mode_2(
        self, _sql_file_path: str, _yml_file_path: str, extra_info: str
    ) -> str:
        self.read_file(_sql_file_path)
        yml = self.read_file(_yml_file_path)
        model_name = _sql_file_path.split("/")[-1].split(".")[0]
        prompt_str = f"Check for missing columns in the following DBT documentation for the following SQL query with model name {model_name}. Identify any columns in the DBT documentation that do not exist in the SQL and comment them out."
        if len(extra_info):
            prompt_str += f"\n {extra_info}"
        prompt_str += f"\nYML File Contents:\n{yml}"

        return prompt_str


if __name__ == "__main__":
    BuildDBTDocs().main("revenue_by_dvm", sys_context="")
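A hedged invocation sketch (not part of the package): the `__main__` block above passes only a model name, but `main` also takes `sys_context` and `is_new`. Run non-interactively it might look like the following; the context string is hypothetical, and a real run needs OPENAI_API_KEY plus a working dbt/Snowflake profile, since the class compiles the model and samples it from the warehouse.

# Illustrative only; arguments shown with hypothetical values.
from mdbt.build_dbt_docs_ai import BuildDBTDocs

BuildDBTDocs().main(
    "revenue_by_dvm",                                       # model name; a trailing ".sql" is stripped
    sys_context="Veterinary clinic revenue reporting",      # hypothetical extra context added to the prompt
    is_new=True,                                            # skip interactive prompts and write the .yml directly
)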
mdbt/build_unit_test_data_ai.py
ADDED
@@ -0,0 +1,129 @@
import logging
import os
import re
import warnings
from typing import Dict

import pyperclip
from dotenv import find_dotenv
from dotenv import load_dotenv


load_dotenv(find_dotenv("../.env"))
load_dotenv(find_dotenv(".env"))
# flake8: noqa: E402
from mdbt.ai_core import AiCore


# Have to load env before importing the openai package.
warnings.simplefilter(action="ignore", category=FutureWarning)
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)


class BuildUnitTestDataAI(AiCore):

    def __init__(self):
        super().__init__(model="o3-mini")

    def main(self, model_name: str):

        file_path = self.get_file_path(model_name)
        # Extract the folder immediately after 'models'. Not sure I need to use this just yet, holding on to it for
        # later.
        layer_name = file_path.split("/")[1][:2]
        sub_folder = file_path.split("/")[2]
        file_name = os.path.splitext(os.path.basename(file_path))[0]

        test_file_path = (
            f"tests/unit_tests/{layer_name}/{sub_folder}/test_{file_name}.sql"
        )

        input_sql_file_name = file_path

        input_sql = self.read_file(input_sql_file_name)

        models_in_model_file = self.extract_model_names(input_sql)

        sample_data = self._get_sample_data_from_snowflake(models_in_model_file)

        prompt = self.build_prompt(
            self.prompts.build_unit_test_prompt.format(model_name=model_name),
            model_name,
            input_sql,
            sample_data,
        )

        print(f"##################\n{prompt}\n##################")

        messages = [
            {
                "role": "user",
                "content": "You are helping to build unit tests for dbt (data build tool) models.\n"
                + prompt,
            },
        ]

        response = self.send_message(messages)

        output = self._remove_first_and_last_line_from_string(response)
        print(output)

        clip_or_file = input(
            f"1. to copy to clipboard\n2. to write to file ({test_file_path})"
        )

        if clip_or_file == "1":
            print("Output copied to clipboard")
            pyperclip.copy(output)
        elif clip_or_file == "2":
            # Check if the file exists and ask whether it should be overwritten.
            if os.path.exists(test_file_path):
                overwrite = input(f"File {test_file_path} exists. Overwrite? (y/n)")
                if overwrite.lower() == "y":
                    with open(test_file_path, "w") as file:
                        file.write(output)
                    print(f"Output written to {test_file_path}")
            else:
                with open(test_file_path, "w") as file:
                    file.write(output)
                print(f"Output written to {test_file_path}")

    def _remove_first_and_last_line_from_string(self, s: str) -> str:
        return "\n".join(s.split("\n")[1:-1])

    @staticmethod
    def extract_model_names(dbt_script):
        # Regular expression to find all occurrences of {{ ref('model_name') }}
        pattern = r"\{\{\s*ref\('([^']+)'\)\s*\}\}"
        # Find all matches in the script
        model_names = re.findall(pattern, dbt_script)
        return model_names

    @staticmethod
    def build_prompt(
        prompt_template: str,
        model_name: str,
        model_sql,
        sample_models_and_data: Dict[str, str],
    ):
        sample_str = ""
        for sample_model_name, sample_data in sample_models_and_data.items():
            sample_str += f"""{sample_model_name}: \n{sample_data}\n"""

        output = f"""
        The model name we are building the test for is {model_name}. In the example, this says "model_name". Put this value in that same place.
        {prompt_template}

        The SQL for the model is:
        {model_sql}

        Here is sample data for each input model. This just represents a random sample. Use it to create realistic test data, but try to build the test input data so that it tests the logic found within the model, regardless of the particular combination of sample data. Imagine that certain flags might be true or false, even if that flag is always true or false in the sample data.

        {sample_str}

        """
        return output


if __name__ == "__main__":
    BuildUnitTestDataAI().main("avg_client_rev_per_year")
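As a final sketch (editor-added): the `extract_model_names` helper above determines which upstream models get sampled for the prompt. It is a static method, so it can be tried on its own without any credentials; the SQL and model names below are made up for illustration.

# Illustrative only; the stg_* model names are hypothetical.
from mdbt.build_unit_test_data_ai import BuildUnitTestDataAI

dbt_sql = """
select c.client_id, sum(r.amount) as revenue
from {{ ref('stg_clients') }} c
join {{ ref('stg_revenue') }} r
    on r.client_id = c.client_id
group by 1
"""

# Finds every {{ ref('...') }} occurrence in the model SQL.
print(BuildUnitTestDataAI.extract_model_names(dbt_sql))
# Expected: ['stg_clients', 'stg_revenue']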