sqlglot 27.29.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.0.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -65,7 +65,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
65
65
  # Existing CTES in the root expression. We'll use this for deduplication.
66
66
  existing_ctes: ExistingCTEsMapping = {}
67
67
 
68
- with_ = root.expression.args.get("with")
68
+ with_ = root.expression.args.get("with_")
69
69
  recursive = False
70
70
  if with_:
71
71
  recursive = with_.args.get("recursive")
@@ -97,7 +97,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
97
97
 
98
98
  if new_ctes:
99
99
  query = expression.expression if isinstance(expression, exp.DDL) else expression
100
- query.set("with", exp.With(expressions=new_ctes, recursive=recursive))
100
+ query.set("with_", exp.With(expressions=new_ctes, recursive=recursive))
101
101
 
102
102
  return expression
103
103
 
@@ -48,7 +48,7 @@ def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
48
48
  # If a derived table has these Select args, it can't be merged
49
49
  UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
50
50
  "expressions",
51
- "from",
51
+ "from_",
52
52
  "joins",
53
53
  "where",
54
54
  "order",
@@ -165,7 +165,7 @@ def _mergeable(
165
165
  if not on:
166
166
  return False
167
167
  selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
168
- inner_from = inner_scope.expression.args.get("from")
168
+ inner_from = inner_scope.expression.args.get("from_")
169
169
  if not inner_from:
170
170
  return False
171
171
  inner_from_table = inner_from.alias_or_name
@@ -197,7 +197,7 @@ def _mergeable(
197
197
  and not outer_scope.expression.is_star
198
198
  and isinstance(inner_select, exp.Select)
199
199
  and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
200
- and inner_select.args.get("from") is not None
200
+ and inner_select.args.get("from_") is not None
201
201
  and not outer_scope.pivots
202
202
  and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
203
203
  and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
@@ -261,7 +261,7 @@ def _merge_from(
261
261
  """
262
262
  Merge FROM clause of inner query into outer query.
263
263
  """
264
- new_subquery = inner_scope.expression.args["from"].this
264
+ new_subquery = inner_scope.expression.args["from_"].this
265
265
  new_subquery.set("joins", node_to_replace.args.get("joins"))
266
266
  node_to_replace.replace(new_subquery)
267
267
  for join_hint in outer_scope.join_hints:
@@ -357,7 +357,7 @@ def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoi
357
357
  if isinstance(from_or_join, exp.Join):
358
358
  # Merge predicates from an outer join to the ON clause
359
359
  # if it only has columns that are already joined
360
- from_ = expression.args.get("from")
360
+ from_ = expression.args.get("from_")
361
361
  sources = {from_.alias_or_name} if from_ else set()
362
362
 
363
363
  for join in expression.args["joins"]:
@@ -6,7 +6,7 @@ from sqlglot import exp
6
6
  from sqlglot.errors import OptimizeError
7
7
  from sqlglot.helper import while_changing
8
8
  from sqlglot.optimizer.scope import find_all_in_scope
9
- from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
9
+ from sqlglot.optimizer.simplify import Simplifier, flatten
10
10
 
11
11
  logger = logging.getLogger("sqlglot")
12
12
 
@@ -28,6 +28,8 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
28
28
  Returns:
29
29
  sqlglot.Expression: normalized expression
30
30
  """
31
+ simplifier = Simplifier(annotate_new_expressions=False)
32
+
31
33
  for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
32
34
  if isinstance(node, exp.Connector):
33
35
  if normalized(node, dnf=dnf):
@@ -35,7 +37,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
35
37
  root = node is expression
36
38
  original = node.copy()
37
39
 
38
- node.transform(rewrite_between, copy=False)
40
+ node.transform(simplifier.rewrite_between, copy=False)
39
41
  distance = normalization_distance(node, dnf=dnf, max_=max_distance)
40
42
 
41
43
  if distance > max_distance:
@@ -46,7 +48,10 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
46
48
 
47
49
  try:
48
50
  node = node.replace(
49
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
51
+ while_changing(
52
+ node,
53
+ lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
54
+ )
50
55
  )
51
56
  except OptimizeError as e:
52
57
  logger.info(e)
@@ -146,7 +151,7 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
146
151
  yield from _predicate_lengths(right, dnf, max_, depth)
147
152
 
148
153
 
149
- def distributive_law(expression, dnf, max_distance):
154
+ def distributive_law(expression, dnf, max_distance, simplifier=None):
150
155
  """
151
156
  x OR (y AND z) -> (x OR y) AND (x OR z)
152
157
  (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -168,32 +173,34 @@ def distributive_law(expression, dnf, max_distance):
168
173
  from_func = exp.and_ if from_exp == exp.And else exp.or_
169
174
  to_func = exp.and_ if to_exp == exp.And else exp.or_
170
175
 
176
+ simplifier = simplifier or Simplifier(annotate_new_expressions=False)
177
+
171
178
  if isinstance(a, to_exp) and isinstance(b, to_exp):
172
179
  if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
173
- return _distribute(a, b, from_func, to_func)
174
- return _distribute(b, a, from_func, to_func)
180
+ return _distribute(a, b, from_func, to_func, simplifier)
181
+ return _distribute(b, a, from_func, to_func, simplifier)
175
182
  if isinstance(a, to_exp):
176
- return _distribute(b, a, from_func, to_func)
183
+ return _distribute(b, a, from_func, to_func, simplifier)
177
184
  if isinstance(b, to_exp):
178
- return _distribute(a, b, from_func, to_func)
185
+ return _distribute(a, b, from_func, to_func, simplifier)
179
186
 
180
187
  return expression
181
188
 
182
189
 
183
- def _distribute(a, b, from_func, to_func):
190
+ def _distribute(a, b, from_func, to_func, simplifier):
184
191
  if isinstance(a, exp.Connector):
185
192
  exp.replace_children(
186
193
  a,
187
194
  lambda c: to_func(
188
- uniq_sort(flatten(from_func(c, b.left))),
189
- uniq_sort(flatten(from_func(c, b.right))),
195
+ simplifier.uniq_sort(flatten(from_func(c, b.left))),
196
+ simplifier.uniq_sort(flatten(from_func(c, b.right))),
190
197
  copy=False,
191
198
  ),
192
199
  )
193
200
  else:
194
201
  a = to_func(
195
- uniq_sort(flatten(from_func(a, b.left))),
196
- uniq_sort(flatten(from_func(a, b.right))),
202
+ simplifier.uniq_sort(flatten(from_func(a, b.left))),
203
+ simplifier.uniq_sort(flatten(from_func(a, b.right))),
197
204
  copy=False,
198
205
  )
199
206
 
@@ -10,14 +10,18 @@ if t.TYPE_CHECKING:
10
10
 
11
11
 
12
12
  @t.overload
13
- def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
13
+ def normalize_identifiers(
14
+ expression: E, dialect: DialectType = None, store_original_column_identifiers: bool = False
15
+ ) -> E: ...
14
16
 
15
17
 
16
18
  @t.overload
17
- def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
19
+ def normalize_identifiers(
20
+ expression: str, dialect: DialectType = None, store_original_column_identifiers: bool = False
21
+ ) -> exp.Identifier: ...
18
22
 
19
23
 
20
- def normalize_identifiers(expression, dialect=None):
24
+ def normalize_identifiers(expression, dialect=None, store_original_column_identifiers=False):
21
25
  """
22
26
  Normalize identifiers by converting them to either lower or upper case,
23
27
  ensuring the semantics are preserved in each case (e.g. by respecting
@@ -48,6 +52,8 @@ def normalize_identifiers(expression, dialect=None):
48
52
  Args:
49
53
  expression: The expression to transform.
50
54
  dialect: The dialect to use in order to decide how to normalize identifiers.
55
+ store_original_column_identifiers: Whether to store the original column identifiers in
56
+ the meta data of the expression in case we want to undo the normalization at a later point.
51
57
 
52
58
  Returns:
53
59
  The transformed expression.
@@ -59,6 +65,14 @@ def normalize_identifiers(expression, dialect=None):
59
65
 
60
66
  for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")):
61
67
  if not node.meta.get("case_sensitive"):
68
+ if store_original_column_identifiers and isinstance(node, exp.Column):
69
+ # TODO: This does not handle non-column cases, e.g PARSE_JSON(...).key
70
+ parent = node
71
+ while parent and isinstance(parent.parent, exp.Dot):
72
+ parent = parent.parent
73
+
74
+ node.meta["dot_parts"] = [p.name for p in parent.parts]
75
+
62
76
  dialect.normalize_identifier(node)
63
77
 
64
78
  return expression
@@ -46,6 +46,7 @@ def optimize(
46
46
  catalog: t.Optional[str | exp.Identifier] = None,
47
47
  dialect: DialectType = None,
48
48
  rules: t.Sequence[t.Callable] = RULES,
49
+ sql: t.Optional[str] = None,
49
50
  **kwargs,
50
51
  ) -> exp.Expression:
51
52
  """
@@ -66,6 +67,8 @@ def optimize(
66
67
  rules: sequence of optimizer rules to use.
67
68
  Many of the rules require tables and columns to be qualified.
68
69
  Do not remove `qualify` from the sequence of rules unless you know what you're doing!
70
+ sql: Original SQL string for error highlighting. If not provided, errors will not include
71
+ highlighting. Requires that the expression has position metadata from parsing.
69
72
  **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
70
73
 
71
74
  Returns:
@@ -77,6 +80,7 @@ def optimize(
77
80
  "catalog": catalog,
78
81
  "schema": schema,
79
82
  "dialect": dialect,
83
+ "sql": sql,
80
84
  "isolate_tables": True, # needed for other optimizations to perform well
81
85
  "quote_identifiers": False,
82
86
  **kwargs,
@@ -181,7 +181,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
181
181
 
182
182
  # a node can reference a CTE which should be pushed down
183
183
  if isinstance(node, exp.From) and not isinstance(source, exp.Table):
184
- with_ = source.parent.expression.args.get("with")
184
+ with_ = source.parent.expression.args.get("with_")
185
185
  if with_ and with_.recursive:
186
186
  return {}
187
187
  node = source.expression
@@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import Dialect, DialectType
7
7
  from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
8
8
  from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
9
9
  from sqlglot.optimizer.qualify_columns import (
10
- pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
11
10
  qualify_columns as qualify_columns_func,
12
11
  quote_identifiers as quote_identifiers_func,
13
12
  validate_qualify_columns as validate_qualify_columns_func,
@@ -31,7 +30,9 @@ def qualify(
31
30
  validate_qualify_columns: bool = True,
32
31
  quote_identifiers: bool = True,
33
32
  identify: bool = True,
33
+ canonicalize_table_aliases: bool = False,
34
34
  on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
35
+ sql: t.Optional[str] = None,
35
36
  ) -> exp.Expression:
36
37
  """
37
38
  Rewrite sqlglot AST to have normalized and qualified tables and columns.
@@ -63,28 +64,35 @@ def qualify(
63
64
  This step is necessary to ensure correctness for case sensitive queries.
64
65
  But this flag is provided in case this step is performed at a later time.
65
66
  identify: If True, quote all identifiers, else only necessary ones.
67
+ canonicalize_table_aliases: Whether to use canonical aliases (_0, _1, ...) for all sources
68
+ instead of preserving table names.
66
69
  on_qualify: Callback after a table has been qualified.
70
+ sql: Original SQL string for error highlighting. If not provided, errors will not include
71
+ highlighting. Requires that the expression has position metadata from parsing.
67
72
 
68
73
  Returns:
69
74
  The qualified expression.
70
75
  """
71
76
  schema = ensure_schema(schema, dialect=dialect)
77
+ dialect = Dialect.get_or_raise(dialect)
72
78
 
73
- expression = normalize_identifiers(expression, dialect=dialect)
79
+ expression = normalize_identifiers(
80
+ expression,
81
+ dialect=dialect,
82
+ store_original_column_identifiers=True,
83
+ )
74
84
  expression = qualify_tables(
75
85
  expression,
76
86
  db=db,
77
87
  catalog=catalog,
78
88
  dialect=dialect,
79
89
  on_qualify=on_qualify,
90
+ canonicalize_table_aliases=canonicalize_table_aliases,
80
91
  )
81
92
 
82
93
  if isolate_tables:
83
94
  expression = isolate_table_selects(expression, schema=schema)
84
95
 
85
- if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN:
86
- expression = pushdown_cte_alias_columns_func(expression)
87
-
88
96
  if qualify_columns:
89
97
  expression = qualify_columns_func(
90
98
  expression,
@@ -99,6 +107,6 @@ def qualify(
99
107
  expression = quote_identifiers_func(expression, dialect=dialect, identify=identify)
100
108
 
101
109
  if validate_qualify_columns:
102
- validate_qualify_columns_func(expression)
110
+ validate_qualify_columns_func(expression, sql=sql)
103
111
 
104
112
  return expression