altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+ from sqlglot.helper import tsort
7
+
# Join args checked by `normalize`: a join with none of these set is turned into a CROSS join.
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
9
+
10
+
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorder joins based on predicate dependencies.

    For every SELECT, joins whose ON clause references other tables are indexed by those
    table names, and joins with no such references are treated as cross joins. Predicates
    that mention a cross-joined table are moved out of the referencing join's ON clause
    (replaced with TRUE there) and into the cross join's own ON clause. Finally the joins
    are reordered with `reorder_joins` and cleaned up with `normalize`.

    Example:
        >>> from sqlglot import parse_one
        >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
        'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'

    Args:
        expression: the expression to optimize; its joins are modified in place.

    Returns:
        The optimized expression.
    """
    for select in expression.find_all(exp.Select):
        # table name -> joins whose ON clause references that table
        references = {}
        cross_joins = []

        # `or []` also covers a "joins" arg that is present but set to None.
        for join in select.args.get("joins") or []:
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references.setdefault(table, []).append(join)
            else:
                cross_joins.append((join.alias_or_name, join))

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]

                if isinstance(on, exp.Connector):
                    # Leave `dep` alone when its condition references fewer than two
                    # other tables (e.g. only the cross-joined one).
                    if len(other_table_names(dep)) < 2:
                        continue

                    operator = type(on)
                    for predicate in on.flatten():
                        if name in exp.column_table_names(predicate):
                            predicate.replace(exp.true())
                            predicate = exp._combine(
                                [join.args.get("on"), predicate], operator, copy=False
                            )
                            join.on(predicate, append=False, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
    return expression
54
+
55
+
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.

    For each FROM clause, the joins of its parent are sorted with `tsort` so that a join
    comes after the tables its ON clause references. Joins are keyed by `alias_or_name`.
    When those names are not unique, including a join named like the FROM source, the
    reordering is skipped. Keying by name would otherwise silently drop joins from the
    query.

    Args:
        expression: the expression whose joins are reordered in place.

    Returns:
        The same expression.
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        join_list = parent.args.get("joins") or []
        if not join_list:
            continue

        from_name = from_.alias_or_name
        joins = {join.alias_or_name: join for join in join_list}
        if len(joins) != len(join_list) or from_name in joins:
            # Ambiguous names: a name-keyed rebuild of the join list would lose joins.
            continue

        dag = {name: other_table_names(join) for name, join in joins.items()}
        parent.set(
            "joins",
            # tsort also yields the FROM source and any outer references; keep only joins.
            [joins[name] for name in tsort(dag) if name in joins],
        )
    return expression
69
+
70
+
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    - A join with none of `JOIN_ATTRS` set becomes a CROSS join.
    - CROSS joins lose their ON clause.
    - For any other join, only an INNER or OUTER kind is dropped. The side (e.g. LEFT) is
      kept, and so are kinds such as SEMI or ANTI, because removing them would change the
      join's meaning (a LEFT SEMI JOIN is not a LEFT JOIN).
    - Non-cross joins with neither ON nor USING get `ON TRUE`.

    Args:
        expression: the expression whose joins are normalized in place.

    Returns:
        The same expression.
    """
    for join in expression.find_all(exp.Join):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind == "CROSS":
            join.set("on", None)
        else:
            if join.kind in ("INNER", "OUTER"):
                join.set("kind", None)

            if not join.args.get("on") and not join.args.get("using"):
                join.set("on", exp.true())
    return expression
87
+
88
+
def other_table_names(join: exp.Join) -> t.Set[str]:
    """
    Return the table names referenced by the columns in ``join``'s ON clause.

    The join's own alias/name is passed to `exp.column_table_names` as the name to leave
    out, so the result holds the *other* tables the join depends on. A join without an ON
    clause yields an empty set.
    """
    condition = join.args.get("on")
    if not condition:
        return set()
    return exp.column_table_names(condition, join.alias_or_name)
@@ -0,0 +1,94 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import typing as t
5
+
6
+ from sqlglot import Schema, exp
7
+ from sqlglot.dialects.dialect import DialectType
8
+ from sqlglot.optimizer.annotate_types import annotate_types
9
+ from sqlglot.optimizer.canonicalize import canonicalize
10
+ from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
11
+ from sqlglot.optimizer.eliminate_joins import eliminate_joins
12
+ from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
13
+ from sqlglot.optimizer.merge_subqueries import merge_subqueries
14
+ from sqlglot.optimizer.normalize import normalize
15
+ from sqlglot.optimizer.optimize_joins import optimize_joins
16
+ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
17
+ from sqlglot.optimizer.pushdown_projections import pushdown_projections
18
+ from sqlglot.optimizer.qualify import qualify
19
+ from sqlglot.optimizer.qualify_columns import quote_identifiers
20
+ from sqlglot.optimizer.simplify import simplify
21
+ from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
22
+ from sqlglot.schema import ensure_schema
23
+
# Default rules for `optimize`, applied in this order. Many of them require tables and
# columns to already be qualified, which is why `qualify` comes first.
RULES = (
    qualify,
    pushdown_projections,
    normalize,
    unnest_subqueries,
    pushdown_predicates,
    optimize_joins,
    eliminate_subqueries,
    merge_subqueries,
    eliminate_joins,
    eliminate_ctes,
    quote_identifiers,
    annotate_types,
    canonicalize,
    simplify,
)
40
+
41
+
def optimize(
    expression: str | exp.Expression,
    schema: t.Optional[dict | Schema] = None,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    dialect: DialectType = None,
    rules: t.Sequence[t.Callable] = RULES,
    **kwargs,
) -> exp.Expression:
    """
    Rewrite a sqlglot AST into an optimized form.

    Args:
        expression: expression to optimize
        schema: database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
            If no schema is provided then the default schema defined at `sqlglot.schema` will be used
        db: specify the default database, as might be set by a `USE DATABASE db` statement
        catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement
        dialect: The dialect used to parse the sql string and build the schema. It is also
            passed to every rule that accepts a `dialect` parameter.
        rules: sequence of optimizer rules to use.
            Many of the rules require tables and columns to be qualified.
            Do not remove `qualify` from the sequence of rules unless you know what you're doing!
        **kwargs: If a rule has a keyword argument with the same name in **kwargs, it will be passed in.
            Both positional-or-keyword and keyword-only rule parameters are matched.

    Returns:
        The optimized expression.
    """
    schema = ensure_schema(schema, dialect=dialect)
    possible_kwargs = {
        "db": db,
        "catalog": catalog,
        "schema": schema,
        "dialect": dialect,
        "isolate_tables": True,  # needed for other optimizations to perform well
        "quote_identifiers": False,
        **kwargs,
    }

    optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
    for rule in rules:
        # Find any additional rule parameters, beyond `expression`. `args` alone misses
        # keyword-only parameters (e.g. `def rule(e, *, dialect=None)`), so include those too.
        spec = inspect.getfullargspec(rule)
        rule_kwargs = {
            param: possible_kwargs[param]
            for param in (*spec.args, *spec.kwonlyargs)
            if param in possible_kwargs
        }
        optimized = rule(optimized, **rule_kwargs)

    return optimized
@@ -0,0 +1,222 @@
1
+ from sqlglot import exp
2
+ from sqlglot.optimizer.normalize import normalized
3
+ from sqlglot.optimizer.scope import build_scope, find_in_scope
4
+ from sqlglot.optimizer.simplify import simplify
5
+ from sqlglot import Dialect
6
+
7
+
def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
        dialect: resolved with `Dialect.get_or_raise`. It is used to detect Presto-family
            dialects and is forwarded to `pushdown` for simplification.
    Returns:
        sqlglot.Expression: optimized expression (the input, modified in place)
    """
    from sqlglot.dialects.presto import Presto

    root = build_scope(expression)

    dialect = Dialect.get_or_raise(dialect)
    unnest_requires_cross_join = isinstance(dialect, Presto)

    if root:
        scope_ref_count = root.ref_count()

        # Scopes are visited in the reverse of `root.traverse()` order.
        # NOTE(review): presumably so outer scopes are processed before inner ones and a
        # pushed-down predicate can keep moving inward; confirm against Scope.traverse.
        for scope in reversed(list(root.traverse())):
            select = scope.expression
            where = select.args.get("where")
            if where:
                selected_sources = scope.selected_sources
                # Position of each join in this SELECT, keyed by alias/name. pushdown_cnf
                # uses it to avoid moving a predicate into a join that precedes a table the
                # predicate references.
                join_index = {
                    join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
                }

                # a right join can only push down to itself and not the source FROM table
                # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression
                pushdown_allowed = True
                for k, (node, source) in selected_sources.items():
                    parent = node.find_ancestor(exp.Join, exp.From)
                    if isinstance(parent, exp.Join):
                        if parent.side == "RIGHT":
                            selected_sources = {k: (node, source)}
                            break
                        if isinstance(node, exp.Unnest) and unnest_requires_cross_join:
                            pushdown_allowed = False
                            break

                if pushdown_allowed:
                    pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)

            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.alias_or_name
                if name in scope.selected_sources:
                    pushdown(
                        join.args.get("on"),
                        {name: scope.selected_sources[name]},
                        scope_ref_count,
                        dialect,
                    )

    return expression
72
+
73
+
def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
    """
    Simplify ``condition`` in its tree and push its predicates down into ``sources``.

    A condition that is CNF-like (normalized, or at least not in DNF) is split on AND and
    handed to `pushdown_cnf`. Otherwise it is split on OR and handed to `pushdown_dnf`.
    A condition that is not itself of the splitting connector is treated as one predicate.
    ``join_index`` is only forwarded to `pushdown_cnf`. A falsy ``condition`` is a no-op.
    """
    if not condition:
        return

    simplified = simplify(condition, dialect=dialect)
    condition = condition.replace(simplified)
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)
    connector = exp.And if cnf_like else exp.Or

    # Materialize the predicates before the pushdown helpers start mutating the tree.
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    if not cnf_like:
        pushdown_dnf(predicates, sources, scope_ref_count)
        return

    pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
91
+
92
+
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each predicate (one conjunct) is moved into the node(s) `nodes_for_predicate` returns
    for it and replaced by TRUE at its original location:
    - into a JOIN's ON clause, but only if every other table it references is not joined
      after that JOIN according to ``join_index``. The loop stops at the first JOIN used.
    - into a SELECT's HAVING if, after alias replacement, it contains an aggregate
      function; otherwise into its WHERE.

    Args:
        predicates: the conjuncts to push down; they are mutated in place.
        sources: mapping of source name to ``(node, source)`` pairs.
        scope_ref_count: scope reference counts, keyed by ``id(scope)``.
        join_index: mapping of join alias/name to its position among the SELECT's joins.
    """
    join_index = join_index or {}
    for predicate in predicates:
        for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
            if isinstance(node, exp.Join):
                name = node.alias_or_name
                predicate_tables = exp.column_table_names(predicate, name)

                # Don't push the predicate if it references tables that appear in later joins
                # (tables missing from join_index, e.g. the FROM source, count as earlier)
                this_index = join_index[name]
                if all(join_index.get(table, -1) < this_index for table in predicate_tables):
                    predicate.replace(exp.true())
                    node.on(predicate, copy=False)
                    break
            if isinstance(node, exp.Select):
                predicate.replace(exp.true())
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)
117
+
118
+
def pushdown_dnf(predicates, sources, scope_ref_count):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    For every table referenced in all blocks, the blocks that `nodes_for_predicate` maps
    to that table are OR-ed together. The result is attached to the target JOIN's ON
    clause or to the target SELECT's WHERE/HAVING, and the original condition is left
    in place.

    Args:
        predicates: the disjuncts (OR blocks) of the condition.
        sources: mapping of source name to ``(node, source)`` pairs.
        scope_ref_count: scope reference counts, keyed by ``id(scope)``.
    """
    # find all the tables that can be pushdown too
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be push down
    pushdown_tables = set()

    for a in predicates:
        a_tables = exp.column_table_names(a)

        for b in predicates:
            a_tables &= exp.column_table_names(b)

        pushdown_tables.update(a_tables)

    # table name -> OR of the blocks that can be pushed into that table's node
    conditions = {}

    # pushdown all predicates to their respective nodes
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

            if table not in nodes:
                continue

            conditions[table] = (
                exp.or_(conditions[table], predicate) if table in conditions else predicate
            )

        # NOTE(review): `nodes` here is whatever the last iteration of the loop above
        # produced. `conditions` persists across outer iterations, so a condition built for
        # an earlier table can be attached again on a later iteration; confirm this is intended.
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                inner_predicate = replace_aliases(node, predicate)
                if find_in_scope(inner_predicate, exp.AggFunc):
                    node.having(inner_predicate, copy=False)
                else:
                    node.where(inner_predicate, copy=False)
166
+
167
+
def nodes_for_predicate(predicate, sources, scope_ref_count):
    """
    Map each table referenced by ``predicate`` to the node the predicate can be pushed into.

    For a predicate inside a WHERE clause, a referenced source is resolved to its enclosing
    JOIN or FROM. A JOIN is a target (its ON clause) only if it has no side. A non-table
    source in FROM, such as a CTE, resolves to its SELECT. For a predicate elsewhere
    (a join's ON clause), only a source that is itself a SELECT is a target.

    A SELECT target is used only when all of the following hold:
    - the predicate references a single table;
    - the SELECT has no GROUP BY;
    - no projection contains a window function;
    - the source is referenced fewer than two times.

    Returns:
        A dict of table name to target node. It is empty when a referenced table's
        enclosing JOIN is an outer (sided) join, or its CTE is recursive.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            # Moving a WHERE predicate into the ON clause of any outer join changes the
            # result. For `x RIGHT JOIN y ON x.a = y.a WHERE y.b = 1`, putting `y.b = 1` in
            # ON keeps the y rows with b <> 1 (null-extended) instead of removing them.
            # Only side-less (inner/cross) joins are safe targets.
            if node.side:
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # We can't push down window expressions
            has_window_expression = any(
                select for select in node.selects if select.find(exp.Window)
            )
            # we can't push down predicates to select statements if they are referenced in
            # multiple places.
            if (
                not node.args.get("group")
                and scope_ref_count[id(source)] < 2
                and not has_window_expression
            ):
                nodes[table] = node
    return nodes
206
+
207
+
def replace_aliases(source, predicate):
    """
    Rewrite ``predicate`` in terms of ``source``'s projections.

    A column whose name matches a projection of ``source`` is replaced by a copy of that
    projection's expression: the aliased expression for ``expr AS name``, otherwise the
    projection itself. Returns the result of ``predicate.transform``.
    """
    substitutions = {}
    for projection in source.selects:
        if isinstance(projection, exp.Alias):
            substitutions[projection.alias] = projection.this
        else:
            substitutions[projection.name] = projection

    def _substitute(node):
        if not isinstance(node, exp.Column):
            return node
        if node.name not in substitutions:
            return node
        return substitutions[node.name].copy()

    return predicate.transform(_substitute)
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+ from collections import defaultdict
5
+
6
+ from sqlglot import alias, exp
7
+ from sqlglot.optimizer.qualify_columns import Resolver
8
+ from sqlglot.optimizer.scope import Scope, traverse_scope
9
+ from sqlglot.schema import ensure_schema
10
+ from sqlglot.errors import OptimizeError
11
+ from sqlglot.helper import seq_get
12
+
13
+ if t.TYPE_CHECKING:
14
+ from sqlglot._typing import E
15
+ from sqlglot.schema import Schema
16
+ from sqlglot.dialects.dialect import DialectType
17
+
# Sentinel value that means an outer query selecting ALL columns. It is stored alongside
# column names in the per-scope reference sets; as a bare object() it never equals a name.
SELECT_ALL = object()
20
+
21
+
# Selection to use if selection list is empty
def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder projection, aliased as ``_``, used when every projection was pruned.

    Aggregate queries get ``MAX(1)``, presumably so the query keeps its aggregate shape.
    Any other query gets the constant ``1``.
    """
    if is_agg:
        placeholder = exp.Max(this=exp.Literal.number(1))
    else:
        placeholder = "1"
    return alias(placeholder, "_")
25
+
26
+
def pushdown_projections(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    remove_unused_selections: bool = True,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to remove unused columns projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema (or mapping accepted by `ensure_schema`). It is used to resolve the
            columns needed from a `*` projection that is being pruned.
        remove_unused_selections (bool): remove selects that are unused
        dialect: dialect used to build the schema and to render SQL in error messages.
    Returns:
        sqlglot.Expression: optimized expression
    Raises:
        OptimizeError: if the two sides of a set operation (without BigQuery kind/side
            modifiers) select a different number of columns.
    """
    schema = ensure_schema(schema, dialect=dialect)
    # Number of column aliases declared on a source's alias, keyed by source scope.
    source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {}
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from a SELECT DISTINCT or a UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            set_op = scope.expression
            if not (set_op.kind or set_op.side):
                # Do not optimize this set operation if it's using the BigQuery specific
                # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation
                left, right = scope.union_scopes
                if len(left.expression.selects) != len(right.expression.selects):
                    scope_sql = scope.expression.sql(dialect=dialect)
                    raise OptimizeError(
                        f"Invalid set operation due to column mismatch: {scope_sql}."
                    )

                referenced_columns[left] = parent_selections

                if any(select.is_star for select in right.expression.selects):
                    referenced_columns[right] = parent_selections
                elif not any(select.is_star for select in left.expression.selects):
                    if scope.expression.args.get("by_name"):
                        referenced_columns[right] = referenced_columns[left]
                    else:
                        # Positional set operation: the right side's column i is needed
                        # whenever the left side's column i is.
                        referenced_columns[right] = {
                            right.expression.selects[i].alias_or_name
                            for i, select in enumerate(left.expression.selects)
                            if SELECT_ALL in parent_selections
                            or select.alias_or_name in parent_selections
                        }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                table_name = col.table
                col_name = col.name
                selects[table_name].add(col_name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    select = seq_get(source.expression.selects, 0)

                    # With pivots or a QueryTransform projection, keep every column of the source.
                    if scope.pivots or isinstance(select, exp.QueryTransform):
                        columns = {SELECT_ALL}
                    else:
                        columns = selects.get(name) or set()

                    referenced_columns[source].update(columns)

                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression
123
+
124
+
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    """
    Prune the projections of ``scope``'s SELECT that no outer query references.

    A projection is kept when any of the following holds:
    - ``parent_selections`` contains SELECT_ALL;
    - the parent references it by name;
    - an unqualified column in ORDER BY references it;
    - it is among the first ``alias_count`` projections.

    If a pruned projection was a star, the columns the parent needs are re-added as explicit
    ``table.column AS column`` projections, with tables resolved by a Resolver over
    ``schema``. If nothing remains, `default_selection` provides a single constant
    projection, an aggregate one when any original projection contained an aggregate.
    The SELECT is modified in place and the scope cache is cleared if anything was removed.
    """
    order = scope.expression.args.get("order")

    if order:
        # Assume columns without a qualified table are references to output columns
        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
    else:
        order_refs = set()

    new_selections = []
    removed = False
    star = False
    is_agg = False

    select_all = SELECT_ALL in parent_selections

    for selection in scope.expression.selects:
        name = selection.alias_or_name

        # alias_count > 0 keeps the leading projections that the source's column aliases map onto
        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
            new_selections.append(selection)
            alias_count -= 1
        else:
            if selection.is_star:
                star = True
            removed = True

        if not is_agg and selection.find(exp.AggFunc):
            is_agg = True

    if star:
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}

        # star is only set when select_all is False, so parent_selections holds names only
        for name in sorted(parent_selections):
            if name not in names:
                new_selections.append(
                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
                )

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)

    if removed:
        scope.clear_cache()
@@ -0,0 +1,104 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+ from sqlglot.dialects.dialect import Dialect, DialectType
7
+ from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
8
+ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
9
+ from sqlglot.optimizer.qualify_columns import (
10
+ pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
11
+ qualify_columns as qualify_columns_func,
12
+ quote_identifiers as quote_identifiers_func,
13
+ validate_qualify_columns as validate_qualify_columns_func,
14
+ )
15
+ from sqlglot.optimizer.qualify_tables import qualify_tables
16
+ from sqlglot.schema import Schema, ensure_schema
17
+
18
+
def qualify(
    expression: exp.Expression,
    dialect: DialectType = None,
    db: t.Optional[str] = None,
    catalog: t.Optional[str] = None,
    schema: t.Optional[dict | Schema] = None,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
    isolate_tables: bool = False,
    qualify_columns: bool = True,
    allow_partial_qualification: bool = False,
    validate_qualify_columns: bool = True,
    quote_identifiers: bool = True,
    identify: bool = True,
    infer_csv_schemas: bool = False,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have normalized and qualified tables and columns.

    This step is necessary for all further SQLGlot optimizations.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify(expression, schema=schema).sql()
        'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'

    Args:
        expression: Expression to qualify.
        dialect: Dialect used to build the schema, qualify tables, normalize and quote
            identifiers, and decide whether CTE alias columns are pushed down.
        db: Default database name for tables.
        catalog: Default catalog name for tables.
        schema: Schema to infer column names and types.
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing.
        isolate_tables: Whether to isolate table selects.
        qualify_columns: Whether to qualify columns.
        allow_partial_qualification: Whether to allow partial qualification.
        validate_qualify_columns: Whether to validate columns.
        quote_identifiers: Whether to run the quote_identifiers step.
            This step is necessary to ensure correctness for case sensitive queries.
            But this flag is provided in case this step is performed at a later time.
        identify: If True, quote all identifiers, else only necessary ones.
        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.

    Returns:
        The qualified expression.
    """
    resolved_schema = ensure_schema(schema, dialect=dialect)

    qualified = qualify_tables(
        expression,
        db=db,
        catalog=catalog,
        schema=resolved_schema,
        dialect=dialect,
        infer_csv_schemas=infer_csv_schemas,
    )
    qualified = normalize_identifiers(qualified, dialect=dialect)

    if isolate_tables:
        qualified = isolate_table_selects(qualified, schema=resolved_schema)

    if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN:
        qualified = pushdown_cte_alias_columns_func(qualified)

    if qualify_columns:
        qualified = qualify_columns_func(
            qualified,
            resolved_schema,
            expand_alias_refs=expand_alias_refs,
            expand_stars=expand_stars,
            infer_schema=infer_schema,
            allow_partial_qualification=allow_partial_qualification,
        )

    if quote_identifiers:
        qualified = quote_identifiers_func(qualified, dialect=dialect, identify=identify)

    if validate_qualify_columns:
        validate_qualify_columns_func(qualified)

    return qualified