mdbt 0.4.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdbt/__init__.py ADDED
File without changes
mdbt/ai_core.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ import re
+ import subprocess
+ from typing import Dict
+ from typing import List
+
+ import openai
+ from snowflake.connector import DatabaseError
+
+ from mdbt.core import Core
+ from mdbt.prompts import Prompts
+
+ # Have to load the env before importing the openai package.
+ # flake8: noqa: E402
+
+
+ class AiCore(Core):
+
+     def __init__(self, model: str = "gpt-4.1", test_mode: bool = False):
+         super().__init__(test_mode=test_mode)
+         self.model = model
+         # Make sure you have OPENAI_API_KEY set in your environment variables.
+         self.client = openai.OpenAI()
+
+         self.prompts = Prompts()
+
+     def send_message(self, _messages: List[Dict[str, str]]) -> object:
+         print("Sending to API")
+         completion = self.client.chat.completions.create(
+             model=self.model, messages=_messages
+         )
+         return completion.choices[0].message.content
+
+     @staticmethod
+     def read_file(path: str) -> str:
+         with open(path, "r") as file:
+             return file.read()
+
+     @staticmethod
+     def is_file_committed(file_path):
+         try:
+             # Check the Git status of the file
+             subprocess.run(
+                 ["git", "ls-files", "--error-unmatch", file_path],
+                 check=True,
+                 stdout=subprocess.PIPE,
+                 stderr=subprocess.PIPE,
+             )
+             # If the file is tracked, check if it has any modifications
+             status_result = subprocess.run(
+                 ["git", "status", "--porcelain", file_path], stdout=subprocess.PIPE
+             )
+             status_output = status_result.stdout.decode().strip()
+             # If the output is empty, file is committed and has no modifications
+             return len(status_output) == 0
+         except subprocess.CalledProcessError:
+             # The file is either untracked or does not exist
+             return False
+
+     def _get_sample_data_from_snowflake(self, model_names: List[str]) -> Dict[str, str]:
+         """
+         Compiles each target model to SQL, then breaks out each sub-query and CTE into separate SQL strings,
+         executing each to get a sample of the data.
+         Args:
+             model_names: A list of target model names to pull sample data from.
+
+         Returns:
+             A dictionary of model names and their sample data in CSV format.
+         """
+         sample_results = {}
+         for model_name in model_names:
+             print(f"Getting sample data for {model_name}")
+             args = ["--select", model_name]
+             cmd = "compile"
+             results = self.execute_dbt_command_capture(cmd, args)
+             extracted_sql = self.extract_sql(results)
+             sample_sql = self.build_sample_sql(extracted_sql)
+             try:
+                 self._cur.execute(sample_sql)
+             except DatabaseError as e:
+                 print(f"Error executing sample SQL for {model_name}")
+                 print(e)
+                 print("\n\n" + sample_sql + "\n\n")
+                 raise e
+             tmp_df = self._cur.fetch_pandas_all()
+             sample_results[model_name] = tmp_df.to_csv(index=False)
+         print(f"Sample results: {sample_results}")
+         return sample_results
+
+     @staticmethod
+     def build_sample_sql(sql: str) -> str:
+         sql = f"""
+         with tgt_table as (
+         {sql}
+         )
+         select *
+         from tgt_table
+         sample (10 rows)
+         """
+         return sql
+
+     @staticmethod
+     def extract_sql(log):
+         sql_lines = [line for line in log.splitlines() if not re.match(r"--\s.*", line)]
+
+         keyword_line_index = 0
+         for i, line in enumerate(sql_lines):
+             if "Compiled node" in line:
+                 keyword_line_index = i + 1
+                 break
+
+         sql_lines = sql_lines[keyword_line_index:]
+
+         # Join the remaining lines and remove escape sequences
+         sql = "\n".join(sql_lines).replace("\x1b[0m", "").strip()
+         return sql
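
For orientation, a minimal sketch (not part of the package) of how the two static helpers above compose: extract_sql strips the dbt log preamble from a compiled model, and build_sample_sql wraps the result in a Snowflake sampling query. The compiled_log value is an illustrative stand-in for real dbt compile output; importing the module assumes the package dependencies (openai, snowflake-connector) are installed.

from mdbt.ai_core import AiCore

# Fake dbt compile log: a "Compiled node" header followed by the compiled SQL.
compiled_log = "12:00:01  Compiled node 'my_model' is:\nselect id, amount from raw.orders"

sql = AiCore.extract_sql(compiled_log)   # -> "select id, amount from raw.orders"
print(AiCore.build_sample_sql(sql))      # -> the query wrapped in a CTE ending in "sample (10 rows)"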
@@ -0,0 +1,147 @@
+ import subprocess
+
+ import pyperclip
+ from dotenv import find_dotenv
+ from dotenv import load_dotenv
+
+ from mdbt.ai_core import AiCore
+ from mdbt.prompts import Prompts
+
+ load_dotenv(find_dotenv("../.env"))
+ load_dotenv(find_dotenv(".env"))
+
+
+ class BuildDBTDocs(AiCore):
+     """
+     # Make sure you have OPENAI_API_KEY set in your environment variables.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def main(self, model_name, sys_context, is_new=False):
+         if model_name.endswith(".sql"):
+             model_name = model_name[:-4]
+         if not is_new:
+             print(
+                 """
+                 1) Build new DBT documentation.
+                 2) Check existing DBT documentation against the model for missing definitions.
+                 """
+             )
+             mode = int(input())
+         else:
+             mode = 1
+         print("Getting file.")
+         sql_file_path = self.get_file_path(model_name)
+
+         if "l4" in sql_file_path.lower() or "l3" in sql_file_path.lower():
+             system_instructions = Prompts().dbt_docs_gte_l3_prompt
+         else:
+             system_instructions = Prompts().dbt_docs_lte_l2_prompt
+
+         if sys_context:
+             system_instructions += f"\nContext about the system the docs are generated for:\n{sys_context}\n"
+
+         sample_data = self._get_sample_data_from_snowflake([model_name])
+
+         system_instructions = system_instructions + sample_data[model_name]
+
+         # Might bring this back in the future.
+         extra_info = ""
+
+         if mode == 1:
+             # Build new documentation
+             user_input = self.build_user_msg_mode_1(sql_file_path, extra_info)
+             yml_file_path = sql_file_path.replace(".sql", ".yml")
+         elif mode == 2:
+             # Check existing documentation
+             yml_file_path = sql_file_path[:-4] + ".yml"
+             user_input = self.build_user_msg_mode_2(
+                 sql_file_path, yml_file_path, extra_info
+             )
+         else:
+             print(mode)
+             raise ValueError("Invalid mode")
+
+         messages = [
+             {"role": "user", "content": system_instructions + "\n" + user_input}
+         ]
+
+         assistant_responses = []
+         result = self.send_message(messages)
+         assistant_responses.append(result)
+
+         messages.append({"role": "assistant", "content": assistant_responses[0]})
+         print(assistant_responses[0])
+         output = assistant_responses[0]
+         # Check for ``` at end of output (str) and remove
+         # Remove trailing markdown code fences if present
+         lines = output.split('\n')
+         if lines and '```' in lines[-1].strip():
+             lines = lines[:-1]
+         elif len(lines) > 1 and '```' in lines[-2].strip():
+             # Remove the second-to-last line if it's a code fence
+             lines.pop(-2)
+         output = '\n'.join(lines)
+         if not is_new:
+             clip_or_file = input(
+                 f"1. to copy to clipboard\n2. to write to file ({yml_file_path})\n: "
+             )
+         else:
+             clip_or_file = "2"
+
+         if clip_or_file == "1":
+             print("Output copied to clipboard")
+             pyperclip.copy(output)
+         elif clip_or_file == "2":
+             if mode == 2:
+                 # Make a backup of the current YML file.
+                 self.backup_existing_yml_file(yml_file_path)
+                 output = assistant_responses[0].split("\n")
+                 # output = output[1:-1]
+                 output = output[1:]
+                 output = "\n".join(output)
+             with open(yml_file_path, "w") as file:
+                 file.write(output)
+             if not self.is_file_committed(yml_file_path):
+                 if not is_new:
+                     commit_file = input("Press 1 to add to git, any other key to bypass: ")
+                 else:
+                     commit_file = "1"
+
+                 if commit_file == "1":
+                     subprocess.run(["git", "add", yml_file_path])
+
+     @staticmethod
+     def backup_existing_yml_file(yml_file_path):
+         with open(yml_file_path, "r") as file:
+             yml_content = file.read()
+         with open(yml_file_path + ".bak", "w") as file:
+             file.write(yml_content)
+
+     def build_user_msg_mode_1(self, _sql_file_path: str, extra_info: str) -> str:
+         self.read_file(_sql_file_path)
+         model_name = _sql_file_path.split("/")[-1].split(".")[0]
+         prompt_str = f"Build new DBT documentation for the following SQL query with model name {model_name}"
+         if len(extra_info):
+             prompt_str += f"\n{extra_info}"
+
+         return prompt_str
+
+     def build_user_msg_mode_2(
+         self, _sql_file_path: str, _yml_file_path: str, extra_info: str
+     ) -> str:
+         self.read_file(_sql_file_path)
+         yml = self.read_file(_yml_file_path)
+         model_name = _sql_file_path.split("/")[-1].split(".")[0]
+         prompt_str = f"Check for missing columns in the following DBT documentation for the following SQL query with model name {model_name}. Identify any columns in the DBT documentation that do not exist in the SQL and comment them out."
+         if len(extra_info):
+             prompt_str += f"\n {extra_info}"
+         prompt_str += f"\nYML File Contents:\n{yml}"
+
+         return prompt_str
+
+
+ if __name__ == "__main__":
+     BuildDBTDocs().main("revenue_by_dvm", sys_context="")
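
Taken together, the non-interactive path can be driven in a single call. A hedged sketch follows: the module path is inferred from the class name and not confirmed by this diff, the sys_context string is illustrative, and OPENAI_API_KEY plus dbt/Snowflake access are assumed.

from mdbt.build_dbt_docs import BuildDBTDocs  # module path assumed, not confirmed by this diff

BuildDBTDocs().main(
    "revenue_by_dvm",
    sys_context="Short description of the system the docs are generated for",  # illustrative
    is_new=True,  # skips the interactive prompts, writes <model>.yml, and stages it with git add
)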
@@ -0,0 +1,129 @@
+ import logging
+ import os
+ import re
+ import warnings
+ from typing import Dict
+
+ import pyperclip
+ from dotenv import find_dotenv
+ from dotenv import load_dotenv
+
+
+ load_dotenv(find_dotenv("../.env"))
+ load_dotenv(find_dotenv(".env"))
+ # flake8: noqa: E402
+ from mdbt.ai_core import AiCore
+
+
+ # Have to load the env before importing the openai package.
+ warnings.simplefilter(action="ignore", category=FutureWarning)
+ logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
+
+
+ class BuildUnitTestDataAI(AiCore):
+
+     def __init__(self):
+         super().__init__(model="o3-mini")
+
+     def main(self, model_name: str):
+
+         file_path = self.get_file_path(model_name)
+         # Extract the folder immediately after 'models'. Not sure I need to use this just yet, holding on to it for
+         # later.
+         layer_name = file_path.split("/")[1][:2]
+         sub_folder = file_path.split("/")[2]
+         file_name = os.path.splitext(os.path.basename(file_path))[0]
+
+         test_file_path = (
+             f"tests/unit_tests/{layer_name}/{sub_folder}/test_{file_name}.sql"
+         )
+
+         input_sql_file_name = file_path
+
+         input_sql = self.read_file(input_sql_file_name)
+
+         models_in_model_file = self.extract_model_names(input_sql)
+
+         sample_data = self._get_sample_data_from_snowflake(models_in_model_file)
+
+         prompt = self.build_prompt(
+             self.prompts.build_unit_test_prompt.format(model_name=model_name),
+             model_name,
+             input_sql,
+             sample_data,
+         )
+
+         print(f"##################\n{prompt}\n##################")
+
+         messages = [
+             {
+                 "role": "user",
+                 "content": "You are helping to build unit tests for dbt (data build tool) models.\n"
+                 + prompt,
+             },
+         ]
+
+         response = self.send_message(messages)
+
+         output = self._remove_first_and_last_line_from_string(response)
+         print(output)
+
+         clip_or_file = input(
+             f"1. to copy to clipboard\n2. to write to file ({test_file_path})\n: "
+         )
+
+         if clip_or_file == "1":
+             print("Output copied to clipboard")
+             pyperclip.copy(output)
+         elif clip_or_file == "2":
+             # Check if file exists and ask if it should be overwritten.
+             if os.path.exists(test_file_path):
+                 overwrite = input(f"File {test_file_path} exists. Overwrite? (y/n)")
+                 if overwrite.lower() == "y":
+                     with open(test_file_path, "w") as file:
+                         file.write(output)
+                     print(f"Output written to {test_file_path}")
+             else:
+                 with open(test_file_path, "w") as file:
+                     file.write(output)
+                 print(f"Output written to {test_file_path}")
+
+     def _remove_first_and_last_line_from_string(self, s: str) -> str:
+         return "\n".join(s.split("\n")[1:-1])
+
+     @staticmethod
+     def extract_model_names(dbt_script):
+         # Regular expression to find all occurrences of {{ ref('model_name') }}
+         pattern = r"\{\{\s*ref\('([^']+)'\)\s*\}\}"
+         # Find all matches in the script
+         model_names = re.findall(pattern, dbt_script)
+         return model_names
+
+     @staticmethod
+     def build_prompt(
+         prompt_template: str,
+         model_name: str,
+         model_sql,
+         sample_models_and_data: Dict[str, str],
+     ):
+         sample_str = ""
+         # Use a distinct loop variable so the target model_name parameter stays intact for the prompt below.
+         for sample_model_name, sample_data in sample_models_and_data.items():
+             sample_str += f"""{sample_model_name}: \n{sample_data}\n"""
+
+         output = f"""
+         The model name we are building the test for is {model_name}. In the example, this says "model_name". Put this value in that same place.
+         {prompt_template}
+
+         The SQL for the model is:
+         {model_sql}
+
+         Here is sample data for each input model. This just represents a random sample. Use it to create realistic test data, but try to build the test input data so that it tests the logic found within the model, regardless of the particular combination of sample data. Imagine that certain flags might be true or false, even if that flag is always true or false in the sample data.
+
+         {sample_str}
+
+         """
+         return output
+
+
+ if __name__ == "__main__":
+     BuildUnitTestDataAI().main("avg_client_rev_per_year")
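
For reference, a small standalone sketch (not part of the package) of what extract_model_names does: the regex pulls every model referenced via {{ ref('...') }} out of a dbt model file, and those names are what _get_sample_data_from_snowflake then samples. The model names below are illustrative.

import re

# Same pattern used by BuildUnitTestDataAI.extract_model_names
pattern = r"\{\{\s*ref\('([^']+)'\)\s*\}\}"

dbt_sql = """
select o.order_id, c.customer_name
from {{ ref('stg_orders') }} o
join {{ ref('stg_customers') }} c on c.customer_id = o.customer_id
"""

print(re.findall(pattern, dbt_sql))  # ['stg_orders', 'stg_customers']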