modaryn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modaryn/__init__.py ADDED
File without changes
File without changes
File without changes
File without changes
@@ -0,0 +1,140 @@
1
+ import warnings
2
+ from typing import Callable, Dict, List, Optional, Set
3
+ import sqlglot
4
+ from sqlglot import exp
5
+ from sqlglot.lineage import lineage
6
+ from modaryn.domain.model import DbtProject, ColumnReference, DbtModel
7
+
8
+
9
class LineageAnalyzer:
    """Derives column-level lineage between dbt models using sqlglot.

    For every column of every model that has SQL, sqlglot's ``lineage`` is run
    and the resulting node graph is walked to find columns of other project
    models feeding it. Matches are recorded in place on the models: as
    ``upstream_columns`` on the target column and ``downstream_columns`` on
    the source column.
    """

    def __init__(self, dialect: str = "bigquery"):
        # sqlglot dialect name passed to ``lineage``; it also selects the order
        # in which column-name quoting variants are tried.
        self.dialect = dialect

    def analyze(self, project: DbtProject, on_progress: Optional[Callable[[int, int], None]] = None):
        """
        Analyzes column-level lineage for all models in the project.

        Mutates the upstream/downstream ColumnReference lists of the project's
        models in place. Per-column failures are reported with
        ``warnings.warn(..., UserWarning)`` and do not stop the analysis.

        on_progress: optional callback(current, total) called after each model
            is processed (models without SQL count as processed).
        """
        schema = self._build_schema(project)
        # Store table names in lowercase for case-insensitive lookup
        table_to_id = {model.model_name.lower(): model.unique_id for model in project.models.values()}

        models = list(project.models.values())
        total = len(models)
        for i, model in enumerate(models, start=1):
            if model.raw_sql:
                for column_name in model.columns:
                    self._analyze_column(model, column_name, schema, table_to_id, project)
            # Report progress only once this model's work is finished, as documented.
            if on_progress:
                on_progress(i, total)

    def _analyze_column(self, model: DbtModel, column_name: str, schema: Dict, table_to_id: Dict[str, str], project: DbtProject):
        """Run lineage for one column and record any upstream project columns found.

        Name variants from ``_get_column_variations`` are tried in order until
        sqlglot returns a node. Failures become UserWarnings; ``stacklevel=3``
        attributes them to the caller of ``analyze`` (this helper adds a frame).
        """
        try:
            node = None
            last_error = None
            for variation in self._get_column_variations(column_name):
                try:
                    node = lineage(variation, sql=model.raw_sql, schema=schema, dialect=self.dialect)
                except Exception as e:
                    # This spelling did not resolve; remember why and try the next one.
                    last_error = e
                    continue
                if node:
                    break

            if node:
                self._extract_source_columns(model, column_name, node, table_to_id, project)
            elif last_error:
                warnings.warn(
                    f"Lineage unavailable for column '{column_name}' in model '{model.model_name}': {last_error}",
                    UserWarning,
                    stacklevel=3,
                )
        except Exception as e:
            warnings.warn(
                f"Lineage analysis failed for column '{column_name}' in model '{model.model_name}': {e}",
                UserWarning,
                stacklevel=3,
            )

    def _get_column_variations(self, column_name: str) -> List[str]:
        """Returns column name variations ordered by likelihood for the current dialect.

        BigQuery tries the backtick-quoted form first; Snowflake/Redshift try
        the uppercase form first; other dialects try the name as given first.
        Repeated spellings (e.g. when the name is already uppercase) are
        dropped so the costly lineage call is not repeated for the same string.
        """
        backticked = f'`{column_name}`'
        double_quoted = f'"{column_name}"'
        upper = column_name.upper()
        if self.dialect == "bigquery":
            candidates = [backticked, column_name, upper, double_quoted]
        elif self.dialect in ("snowflake", "redshift"):
            candidates = [upper, column_name, double_quoted, backticked]
        else:
            candidates = [column_name, upper, double_quoted, backticked]
        # dict.fromkeys keeps first-occurrence order while removing duplicates.
        return list(dict.fromkeys(candidates))

    def _build_schema(self, project: DbtProject) -> Dict:
        """
        Builds a sqlglot compatible schema from the dbt project.

        Maps lowercase model name -> {lowercase column name: "UNKNOWN"}. If two
        models share a name, the one iterated last wins.
        """
        schema = {}
        for model in project.models.values():
            # Use lowercase for table and column names in schema to allow flexible matching
            model_name_lower = model.model_name.lower()
            schema[model_name_lower] = {col.name.lower(): "UNKNOWN" for col in model.columns.values()}
        return schema

    def _extract_source_columns(self, target_model: DbtModel, target_column_name: str, node, table_to_id: Dict[str, str], project: DbtProject):
        """
        Walks the lineage graph rooted at ``node`` and links project source columns.

        Every reachable node is visited once, depth-first in pre-order. A node
        yields a (table, column) candidate either from an ``exp.Table``
        expression or from a dotted ``table.column`` name. When the table is a
        project model and the column exists on it (case-insensitive), a
        ColumnReference is appended to the target column's ``upstream_columns``
        and a reverse one to the source column's ``downstream_columns``, unless
        that upstream reference is already present.

        An explicit stack is used instead of recursion so that deep lineage
        graphs cannot exceed Python's recursion limit.
        """
        visited: Set[int] = set()
        stack = [node]

        while stack:
            current_node = stack.pop()
            if id(current_node) in visited:
                continue
            visited.add(id(current_node))

            # Identify if this node represents a source table and column.
            # We prioritize Table expressions but fallback to parsing the node name (e.g., 'table.column').
            table_name = None
            source_col_raw = None

            if isinstance(current_node.expression, exp.Table):
                table_id_raw = current_node.expression.this
                if hasattr(table_id_raw, 'name'):
                    table_name = table_id_raw.name.lower().strip('"`')
                else:
                    table_name = str(table_id_raw).lower().strip('"`')
                source_col_raw = current_node.name.split('.')[-1].lower().strip('"`')
            elif '.' in current_node.name:
                parts = current_node.name.split('.')
                table_name = parts[-2].lower().strip('"`')
                source_col_raw = parts[-1].lower().strip('"`')

            # If we found a candidate table name, check if it's in our dbt project
            if table_name and table_name in table_to_id:
                source_model_id = table_to_id[table_name]
                source_model = project.models.get(source_model_id)

                if source_model:
                    # Map normalized source_col_raw back to the actual column name in the source model
                    actual_source_col = next(
                        (col_name for col_name in source_model.columns if col_name.lower() == source_col_raw),
                        None,
                    )

                    if actual_source_col:
                        upstream = target_model.columns[target_column_name].upstream_columns
                        # Add reference if not already present
                        if not any(ref.model_unique_id == source_model_id and ref.column_name == actual_source_col
                                   for ref in upstream):
                            upstream.append(
                                ColumnReference(model_unique_id=source_model_id, column_name=actual_source_col)
                            )
                            source_model.columns[actual_source_col].downstream_columns.append(
                                ColumnReference(model_unique_id=target_model.unique_id, column_name=target_column_name)
                            )

            # Push children reversed so the first child is popped (visited) first,
            # reproducing recursive pre-order traversal.
            stack.extend(reversed(current_node.downstream))
@@ -0,0 +1,48 @@
1
+ from dataclasses import dataclass
2
+ import sqlglot
3
+
4
+
5
@dataclass()
class SqlComplexityResult:
    """Integer complexity metrics gathered for a single SQL query.

    Field order matters: instances may be built positionally in the order
    declared below.
    """

    # How many JOIN clauses appear anywhere in the parsed query tree.
    join_count: int
    # How many common table expressions (CTEs) appear in the query.
    cte_count: int
    # How many sqlglot ``If`` (conditional) nodes appear in the query.
    conditional_count: int
    # How many WHERE clauses appear anywhere in the parsed query tree.
    where_count: int
    # Character count of the SQL text (as computed by the analyzer).
    sql_char_count: int
12
+
13
+
14
class SqlComplexityAnalyzer:
    """Computes simple structural complexity metrics for a SQL string via sqlglot."""

    def __init__(self, dialect: str = "bigquery"):
        # sqlglot dialect name, passed as ``read=`` when parsing.
        self.dialect = dialect

    @staticmethod
    def _empty_result() -> SqlComplexityResult:
        """Return the all-zero result used when the SQL cannot be analyzed."""
        return SqlComplexityResult(join_count=0, cte_count=0, conditional_count=0, where_count=0, sql_char_count=0)

    def analyze(self, sql: str) -> SqlComplexityResult:
        """
        Analyzes the complexity of a SQL query.

        Args:
            sql: The SQL query string to analyze. ``None``, empty, or
                whitespace-only input is treated like unparseable SQL.

        Returns:
            A SqlComplexityResult holding the number of Join, CTE, If
            (conditional) and Where nodes found anywhere in the parsed tree,
            plus ``sql_char_count``. If the SQL cannot be tokenized or parsed,
            every metric (including ``sql_char_count``) is zero.
        """
        if not sql or not sql.strip():
            return self._empty_result()

        try:
            expression = sqlglot.parse_one(sql, read=self.dialect)
        except sqlglot.errors.SqlglotError:
            # Catch the common base class: besides ParseError, sqlglot raises
            # TokenError (not a ParseError subclass) for input it cannot even
            # tokenize, such as an unterminated string literal.
            # We don't print warnings here to avoid polluting test output.
            return self._empty_result()

        def count(node_type) -> int:
            # Number of nodes of ``node_type`` anywhere in the parsed tree.
            return sum(1 for _ in expression.find_all(node_type))

        # NOTE(review): only ' ' is removed throughout; tabs/newlines are
        # trimmed at the ends only, so interior ones still count. Kept as-is so
        # existing metric values do not shift.
        sql_char_count = len(sql.replace(' ', '').strip())

        return SqlComplexityResult(
            join_count=count(sqlglot.exp.Join),
            cte_count=count(sqlglot.exp.CTE),
            conditional_count=count(sqlglot.exp.If),
            where_count=count(sqlglot.exp.Where),
            sql_char_count=sql_char_count,
        )
@@ -0,0 +1,6 @@
1
+ ███╗ ███╗ ██████╗ ██████╗ █████╗ ██████╗ ██╗ ██╗███╗ ██╗
2
+ ████╗ ████║██╔═══██╗██╔══██╗██╔══██╗██╔══██╗╚██╗ ██╔╝████╗ ██║
3
+ ██╔████╔██║██║ ██║██║ ██║███████║██████╔╝ ╚████╔╝ ██╔██╗ ██║
4
+ ██║╚██╔╝██║██║ ██║██║ ██║██╔══██║██╔══██╗ ╚██╔╝ ██║╚██╗██║
5
+ ██║ ╚═╝ ██║╚██████╔╝██████╔╝██║ ██║██║ ██║ ██║ ██║ ╚████║
6
+ ╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═══╝