sqlglot 27.28.0__py3-none-any.whl → 27.29.0__py3-none-any.whl

sqlglot/expressions.py CHANGED
@@ -118,7 +118,7 @@ class Expression(metaclass=_Expression):
         self._set_parent(arg_key, value)
 
     def __eq__(self, other) -> bool:
-        return type(self) is type(other) and hash(self) == hash(other)
+        return self is other or (type(self) is type(other) and hash(self) == hash(other))
 
     def __hash__(self) -> int:
         if self._hash is None:
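The rewritten `__eq__` adds an identity fast path before the type-and-hash comparison, so comparing a node against itself no longer forces a hash computation over the whole subtree. A minimal sketch of the observable semantics (standard `parse_one` API):

```python
import sqlglot

a = sqlglot.parse_one("SELECT x + 1 FROM t")
b = sqlglot.parse_one("SELECT x + 1 FROM t")

assert a == a      # short-circuits on identity, no hashing needed
assert a == b      # distinct objects: falls back to type + hash comparison
assert a is not b
```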
@@ -1893,7 +1893,13 @@ class Comment(Expression):
 
 
 class Comprehension(Expression):
-    arg_types = {"this": True, "expression": True, "iterator": True, "condition": False}
+    arg_types = {
+        "this": True,
+        "expression": True,
+        "position": False,
+        "iterator": True,
+        "condition": False,
+    }
 
 
 # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
@@ -5620,6 +5626,14 @@ class Boolnot(Func):
     pass
 
 
+class Booland(Func):
+    arg_types = {"this": True, "expression": True}
+
+
+class Boolor(Func):
+    arg_types = {"this": True, "expression": True}
+
+
 # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#bool_for_json
 class JSONBool(Func):
     pass
@@ -5974,11 +5988,11 @@ class Lead(AggFunc):
 # some dialects have a distinction between first and first_value, usually first is an aggregate func
 # and first_value is a window func
 class First(AggFunc):
-    pass
+    arg_types = {"this": True, "expression": False}
 
 
 class Last(AggFunc):
-    pass
+    arg_types = {"this": True, "expression": False}
 
 
 class FirstValue(AggFunc):
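`First` and `Last` previously inherited the bare `this`-only default from `Func`, so a second argument could not be represented; the optional `expression` slot makes a two-argument form round-trippable. A quick sketch with the expression builders (default-dialect rendering shown):

```python
from sqlglot import exp

# Both forms now fit First's declared arg_types.
print(exp.First(this=exp.column("x")).sql())                         # FIRST(x)
print(exp.First(this=exp.column("x"), expression=exp.true()).sql())  # FIRST(x, TRUE)
```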
@@ -6276,6 +6290,14 @@ class WeekOfYear(Func):
     _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
 
 
+class YearOfWeek(Func):
+    _sql_names = ["YEAR_OF_WEEK", "YEAROFWEEK"]
+
+
+class YearOfWeekIso(Func):
+    _sql_names = ["YEAR_OF_WEEK_ISO", "YEAROFWEEKISO"]
+
+
 class MonthsBetween(Func):
     arg_types = {"this": True, "expression": True, "roundoff": False}
 
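Since `_sql_names` feeds the generic function registry, both spellings of each name parse to the new nodes in any dialect that doesn't override them. For example:

```python
import sqlglot
from sqlglot import exp

select = sqlglot.parse_one("SELECT YEAROFWEEK(d), YEAR_OF_WEEK_ISO(d) FROM t")
print(sorted(type(f).__name__ for f in select.find_all(exp.YearOfWeek, exp.YearOfWeekIso)))
# ['YearOfWeek', 'YearOfWeekIso']
```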
@@ -6426,6 +6448,10 @@ class Encode(Func):
     arg_types = {"this": True, "charset": True}
 
 
+class EqualNull(Func):
+    arg_types = {"this": True, "expression": True}
+
+
 class Exp(Func):
     pass
 
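`EqualNull` declares no `_sql_names`, so its SQL name is derived from the class name (`EQUAL_NULL`); that matches Snowflake's null-safe equality function of the same name, which is presumably the motivation. A parsing sketch:

```python
import sqlglot
from sqlglot import exp

select = sqlglot.parse_one("SELECT EQUAL_NULL(a, b) FROM t")
assert isinstance(select.selects[0], exp.EqualNull)
print(select.sql())  # SELECT EQUAL_NULL(a, b) FROM t
```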
@@ -6570,6 +6596,16 @@ class Greatest(Func):
     is_var_len_args = True
 
 
+class GreatestIgnoreNulls(Func):
+    arg_types = {"expressions": True}
+    is_var_len_args = True
+
+
+class LeastIgnoreNulls(Func):
+    arg_types = {"expressions": True}
+    is_var_len_args = True
+
+
 # Trino's `ON OVERFLOW TRUNCATE [filler_string] {WITH | WITHOUT} COUNT`
 # https://trino.io/docs/current/functions/aggregate.html#listagg
 class OverflowTruncateBehavior(Expression):
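Like their `Greatest`/`Least` counterparts, both new classes are variadic (`is_var_len_args`), so any number of arguments folds into `expressions`. Built directly (default-dialect output shown; these presumably back an IGNORE NULLS variant such as Snowflake's GREATEST_IGNORE_NULLS):

```python
from sqlglot import exp

node = exp.GreatestIgnoreNulls(
    expressions=[exp.column("a"), exp.column("b"), exp.column("c")]
)
print(node.sql())  # GREATEST_IGNORE_NULLS(a, b, c)
```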
@@ -6668,6 +6704,10 @@ class IsInf(Func):
     _sql_names = ["IS_INF", "ISINF"]
 
 
+class IsNullValue(Func):
+    pass
+
+
 # https://www.postgresql.org/docs/current/functions-json.html
 class JSON(Expression):
     arg_types = {"this": False, "with": False, "unique": False}
@@ -7349,6 +7389,7 @@ class RegexpReplace(Func):
         "position": False,
         "occurrence": False,
         "modifiers": False,
+        "single_replace": False,
     }
 
 
@@ -7391,6 +7432,14 @@ class RegexpCount(Func):
     }
 
 
+class RegrValx(Func):
+    arg_types = {"this": True, "expression": True}
+
+
+class RegrValy(Func):
+    arg_types = {"this": True, "expression": True}
+
+
 class Repeat(Func):
     arg_types = {"this": True, "times": True}
 
@@ -7754,18 +7803,38 @@ class Uuid(Func):
     arg_types = {"this": False, "name": False}
 
 
+TIMESTAMP_PARTS = {
+    "year": False,
+    "month": False,
+    "day": False,
+    "hour": False,
+    "min": False,
+    "sec": False,
+    "nano": False,
+}
+
+
 class TimestampFromParts(Func):
     _sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"]
     arg_types = {
-        "year": True,
-        "month": True,
-        "day": True,
-        "hour": True,
-        "min": True,
-        "sec": True,
-        "nano": False,
+        **TIMESTAMP_PARTS,
         "zone": False,
         "milli": False,
+        "this": False,
+        "expression": False,
+    }
+
+
+class TimestampLtzFromParts(Func):
+    _sql_names = ["TIMESTAMP_LTZ_FROM_PARTS", "TIMESTAMPLTZFROMPARTS"]
+    arg_types = TIMESTAMP_PARTS.copy()
+
+
+class TimestampTzFromParts(Func):
+    _sql_names = ["TIMESTAMP_TZ_FROM_PARTS", "TIMESTAMPTZFROMPARTS"]
+    arg_types = {
+        **TIMESTAMP_PARTS,
+        "zone": False,
     }
 
 
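Factoring the shared keys into `TIMESTAMP_PARTS` keeps the three `*_FROM_PARTS` variants in sync. Every part on `TimestampFromParts` became optional and `this`/`expression` slots were added, which accommodates an overload that builds a timestamp from a date plus a time rather than numeric parts (the Snowflake connection is an inference from the `_sql_names`). A parse sketch, assuming the generic positional mapping of arguments onto the part keys:

```python
import sqlglot

select = sqlglot.parse_one(
    "SELECT TIMESTAMP_LTZ_FROM_PARTS(2024, 1, 15, 8, 30, 0)", read="snowflake"
)
func = select.selects[0]
print(type(func).__name__)  # TimestampLtzFromParts
print(func.args["year"])    # the 2024 literal, mapped positionally
```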
@@ -7851,7 +7920,8 @@ class Merge(DML):
     arg_types = {
         "this": True,
         "using": True,
-        "on": True,
+        "on": False,
+        "using_cond": False,
         "whens": True,
         "with": False,
         "returning": False,
@@ -9355,6 +9425,26 @@ def replace_tree(
     return new_node
 
 
+def find_tables(expression: Expression) -> t.Set[Table]:
+    """
+    Find all tables referenced in a query.
+
+    Args:
+        expression: The query to find the tables in.
+
+    Returns:
+        A set of all the tables.
+    """
+    from sqlglot.optimizer.scope import traverse_scope
+
+    return {
+        table
+        for scope in traverse_scope(expression)
+        for table in scope.tables
+        if table.name and table.name not in scope.cte_sources
+    }
+
+
 def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
     """
     Return all table names referenced through columns in an expression.
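`find_tables` resolves scopes instead of naively collecting every `exp.Table` node, so tables inside CTEs and subqueries are found while references *to* CTEs are excluded. Usage:

```python
import sqlglot
from sqlglot import exp

sql = "WITH c AS (SELECT * FROM t1) SELECT * FROM c JOIN t2 ON c.id = t2.id"
tables = exp.find_tables(sqlglot.parse_one(sql))
print(sorted(t.name for t in tables))  # ['t1', 't2'] -- 'c' is skipped as a CTE source
```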
sqlglot/generator.py CHANGED
@@ -2531,6 +2531,12 @@ class Generator(metaclass=_Generator):
     def boolean_sql(self, expression: exp.Boolean) -> str:
         return "TRUE" if expression.this else "FALSE"
 
+    def booland_sql(self, expression: exp.Booland) -> str:
+        return f"(({self.sql(expression, 'this')}) AND ({self.sql(expression, 'expression')}))"
+
+    def boolor_sql(self, expression: exp.Boolor) -> str:
+        return f"(({self.sql(expression, 'this')}) OR ({self.sql(expression, 'expression')}))"
+
     def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
         this = self.sql(expression, "this")
         this = f"{this} " if this else this
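These handlers lower `exp.Booland`/`exp.Boolor` (added in expressions.py above, presumably for Snowflake's BOOLAND/BOOLOR) into plain logical operators, with defensive parentheses around each operand. Rendering through the base generator:

```python
from sqlglot import exp

node = exp.Booland(this=exp.column("a"), expression=exp.column("b"))
print(node.sql())  # ((a) AND (b))
```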
@@ -4078,9 +4084,15 @@
 
         this = self.sql(table)
         using = f"USING {self.sql(expression, 'using')}"
-        on = f"ON {self.sql(expression, 'on')}"
         whens = self.sql(expression, "whens")
 
+        on = self.sql(expression, "on")
+        on = f"ON {on}" if on else ""
+
+        if not on:
+            on = self.expressions(expression, key="using_cond")
+            on = f"USING ({on})" if on else ""
+
         returning = self.sql(expression, "returning")
         if returning:
             whens = f"{whens}{returning}"
@@ -4244,10 +4256,12 @@
     def comprehension_sql(self, expression: exp.Comprehension) -> str:
         this = self.sql(expression, "this")
         expr = self.sql(expression, "expression")
+        position = self.sql(expression, "position")
+        position = f", {position}" if position else ""
         iterator = self.sql(expression, "iterator")
         condition = self.sql(expression, "condition")
         condition = f" IF {condition}" if condition else ""
-        return f"{this} FOR {expr} IN {iterator}{condition}"
+        return f"{this} FOR {expr}{position} IN {iterator}{condition}"
 
     def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str:
         return f"{self.sql(expression, 'this')}({self.sql(expression, 'expression')})"
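With the new `position` slot on `exp.Comprehension` (see the expressions.py change above), a comprehension can bind an index variable next to the element, rendered as `expr, position` before `IN`. A construction sketch against the base generator:

```python
from sqlglot import exp

node = exp.Comprehension(
    this=exp.column("x"),        # the projected expression
    expression=exp.column("x"),  # the bound element variable
    position=exp.column("i"),    # the new optional index variable
    iterator=exp.column("arr"),
)
print(node.sql())  # x FOR x, i IN arr
```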
sqlglot/helper.py CHANGED
@@ -7,7 +7,6 @@ import re
 import sys
 import typing as t
 from collections.abc import Collection, Set
-from contextlib import contextmanager
 from copy import copy
 from difflib import get_close_matches
 from enum import Enum
@@ -272,47 +271,6 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
     return result
 
 
-def open_file(file_name: str) -> t.TextIO:
-    """Open a file that may be compressed as gzip and return it in universal newline mode."""
-    with open(file_name, "rb") as f:
-        gzipped = f.read(2) == b"\x1f\x8b"
-
-    if gzipped:
-        import gzip
-
-        return gzip.open(file_name, "rt", newline="")
-
-    return open(file_name, encoding="utf-8", newline="")
-
-
-@contextmanager
-def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
-    """
-    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
-
-    Args:
-        read_csv: A `ReadCSV` function call.
-
-    Yields:
-        A python csv reader.
-    """
-    args = read_csv.expressions
-    file = open_file(read_csv.name)
-
-    delimiter = ","
-    args = iter(arg.name for arg in args)  # type: ignore
-    for k, v in zip(args, args):
-        if k == "delimiter":
-            delimiter = v
-
-    try:
-        import csv as csv_
-
-        yield csv_.reader(file, delimiter=delimiter)
-    finally:
-        file.close()
-
-
 def find_new_name(taken: t.Collection[str], base: str) -> str:
     """
     Searches for a new name.
sqlglot/lineage.py CHANGED
@@ -232,7 +232,7 @@ def to_node(
             )
 
     # if the select is a star add all scope sources as downstreams
-    if select.is_star:
+    if isinstance(select, exp.Star):
         for source in scope.sources.values():
             if isinstance(source, Scope):
                 source = source.expression
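The two checks are not equivalent: `is_star` is true for any projection that *contains* a star (`t.*` is an `exp.Column` wrapping a star), while `isinstance(select, exp.Star)` matches only a bare `*`, so the all-sources fan-out in lineage is now limited to that case:

```python
from sqlglot import exp, parse_one

bare = parse_one("SELECT * FROM t").selects[0]
prefixed = parse_one("SELECT t.* FROM t").selects[0]

print(isinstance(bare, exp.Star), bare.is_star)          # True True
print(isinstance(prefixed, exp.Star), prefixed.is_star)  # False True
```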
sqlglot/optimizer/qualify.py CHANGED
@@ -31,7 +31,7 @@ def qualify(
     validate_qualify_columns: bool = True,
     quote_identifiers: bool = True,
     identify: bool = True,
-    infer_csv_schemas: bool = False,
+    on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
 ) -> exp.Expression:
     """
     Rewrite sqlglot AST to have normalized and qualified tables and columns.
@@ -63,21 +63,21 @@
             This step is necessary to ensure correctness for case sensitive queries.
             But this flag is provided in case this step is performed at a later time.
         identify: If True, quote all identifiers, else only necessary ones.
-        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
+        on_qualify: Callback after a table has been qualified.
 
     Returns:
         The qualified expression.
     """
     schema = ensure_schema(schema, dialect=dialect)
+
+    expression = normalize_identifiers(expression, dialect=dialect)
     expression = qualify_tables(
         expression,
         db=db,
         catalog=catalog,
-        schema=schema,
         dialect=dialect,
-        infer_csv_schemas=infer_csv_schemas,
+        on_qualify=on_qualify,
     )
-    expression = normalize_identifiers(expression, dialect=dialect)
 
     if isolate_tables:
         expression = isolate_table_selects(expression, schema=schema)
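The built-in CSV schema inference (`infer_csv_schemas`, removed here together with `csv_reader` in helper.py above) is replaced by a generic hook: `on_qualify` is invoked with each table source right after it is qualified, so callers can populate a schema or collect sources themselves. Note also that identifier normalization now runs *before* table qualification, which is what lets qualify_tables normalize the aliases it synthesizes (see qualify_tables.py below). A usage sketch:

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

qualified_sources = []
qualify(
    sqlglot.parse_one("SELECT a FROM t"),
    schema={"t": {"a": "INT"}},
    on_qualify=qualified_sources.append,  # called once per qualified table source
)
print([s.sql() for s in qualified_sources])
```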
sqlglot/optimizer/qualify_columns.py CHANGED
@@ -551,7 +551,8 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
             continue
 
         # column_table can be a '' because bigquery unnest has no table alias
-        column_table = resolver.get_table(column_name)
+        column_table = resolver.get_table(column)
+
         if column_table:
             column.set("table", column_table)
         elif (
@@ -948,21 +949,29 @@ class Resolver:
         self._infer_schema = infer_schema
         self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
 
-    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
+    def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
         """
         Get the table for a column name.
 
         Args:
-            column_name: The column name to find the table for.
+            column: The column expression (or column name) to find the table for.
         Returns:
             The table name if it can be found/inferred.
         """
-        if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(
-                self._get_all_source_columns()
-            )
-
-        table_name = self._unambiguous_columns.get(column_name)
+        column_name = column if isinstance(column, str) else column.name
+
+        table_name = self._get_table_name_from_sources(column_name)
+
+        if not table_name and isinstance(column, exp.Column):
+            # Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
+            # attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
+            # we may be able to disambiguate based on the source order.
+            if join_context := self._get_column_join_context(column):
+                # In this case, the return value will be the join that _may_ be able to disambiguate the column
+                # and we can use the source columns available at that join to get the table name
+                table_name = self._get_table_name_from_sources(
+                    column_name, self._get_available_source_columns(join_context)
+                )
 
         if not table_name and self._infer_schema:
             sources_without_schema = tuple(
@@ -1101,6 +1110,77 @@
         }
         return self._source_columns
 
+    def _get_table_name_from_sources(
+        self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
+    ) -> t.Optional[str]:
+        if not source_columns:
+            # If not supplied, get all sources to calculate unambiguous columns
+            if self._unambiguous_columns is None:
+                self._unambiguous_columns = self._get_unambiguous_columns(
+                    self._get_all_source_columns()
+                )
+
+            unambiguous_columns = self._unambiguous_columns
+        else:
+            unambiguous_columns = self._get_unambiguous_columns(source_columns)
+
+        return unambiguous_columns.get(column_name)
+
+    def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
+        """
+        Check if a column participating in a join can be qualified based on the source order.
+        """
+        args = self.scope.expression.args
+        joins = args.get("joins")
+
+        if not joins or args.get("laterals") or args.get("pivots"):
+            # Feature gap: We currently don't try to disambiguate columns if other sources
+            # (e.g laterals, pivots) exist alongside joins
+            return None
+
+        join_ancestor = column.find_ancestor(exp.Join, exp.Select)
+
+        if (
+            isinstance(join_ancestor, exp.Join)
+            and join_ancestor.alias_or_name in self.scope.selected_sources
+        ):
+            # Ensure that the found ancestor is a join that contains an actual source,
+            # e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
+            return join_ancestor
+
+        return None
+
+    def _get_available_source_columns(
+        self, join_ancestor: exp.Join
+    ) -> t.Dict[str, t.Sequence[str]]:
+        """
+        Get the source columns that are available at the point where a column is referenced.
+
+        For columns in JOIN conditions, this only includes tables that have been joined
+        up to that point. Example:
+
+        ```
+        SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
+        ```                                                         ^
+                                                                    |
+
+        The unqualified column `c` is not ambiguous if no other sources up until that
+        join i.e t_1, ..., t_n, contain a column named `c`.
+        """
+        args = self.scope.expression.args
+
+        # Collect tables in order: FROM clause tables + joined tables up to current join
+        from_name = args["from"].alias_or_name
+        available_sources = {from_name: self.get_source_columns(from_name)}
+
+        for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
+            available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
+
+        return available_sources
+
     def _get_unambiguous_columns(
         self, source_columns: t.Dict[str, t.Sequence[str]]
     ) -> t.Mapping[str, str]:
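Net effect of the three helpers: an unqualified column inside a JOIN condition no longer has to be unambiguous across *all* sources, only across the sources joined up to that point. A sketch of a query this newly qualifies (`c` exists in both `t1` and `t3`, but only `t1` and `t2` are in scope at the first join's condition; the expected resolution is my reading of the code above):

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {
    "t1": {"a": "INT", "c": "INT"},
    "t2": {"b": "INT"},
    "t3": {"c": "INT"},
}
sql = "SELECT t1.a FROM t1 JOIN t2 ON c = t2.b JOIN t3 ON t3.c = t1.a"
print(qualify(sqlglot.parse_one(sql), schema=schema).sql())
# the first join condition becomes t1.c = t2.b (identifiers are quoted by default)
```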
sqlglot/optimizer/qualify_tables.py CHANGED
@@ -4,11 +4,10 @@ import itertools
 import typing as t
 
 from sqlglot import alias, exp
-from sqlglot.dialects.dialect import DialectType
-from sqlglot.helper import csv_reader, name_sequence
+from sqlglot.dialects.dialect import Dialect, DialectType
+from sqlglot.helper import name_sequence
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
 from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import Schema
-from sqlglot.dialects.dialect import Dialect
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E
@@ -18,8 +17,7 @@ def qualify_tables(
     expression: E,
     db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
-    schema: t.Optional[Schema] = None,
-    infer_csv_schemas: bool = False,
+    on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
     dialect: DialectType = None,
 ) -> E:
     """
@@ -40,18 +38,28 @@ def qualify_tables(
         expression: Expression to qualify
         db: Database name
         catalog: Catalog name
-        schema: A schema to populate
-        infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
+        on_qualify: Callback after a table has been qualified.
         dialect: The dialect to parse catalog and schema into.
 
     Returns:
         The qualified expression.
     """
-    next_alias_name = name_sequence("_q_")
-    db = exp.parse_identifier(db, dialect=dialect) if db else None
-    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
     dialect = Dialect.get_or_raise(dialect)
 
+    alias_sequence = name_sequence("_q_")
+
+    def next_alias_name() -> str:
+        return normalize_identifiers(alias_sequence(), dialect=dialect).name
+
+    if db := db or None:
+        db = exp.parse_identifier(db, dialect=dialect)
+        db.meta["is_table"] = True
+        db = normalize_identifiers(db, dialect=dialect)
+    if catalog := catalog or None:
+        catalog = exp.parse_identifier(catalog, dialect=dialect)
+        catalog.meta["is_table"] = True
+        catalog = normalize_identifiers(catalog, dialect=dialect)
+
     def _qualify(table: exp.Table) -> None:
         if isinstance(table.this, exp.Identifier):
             if db and not table.args.get("db"):
@@ -97,7 +105,10 @@
                     name = source.name
 
                     # Mutates the source by attaching an alias to it
-                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
+                    normalized_alias = normalize_identifiers(
+                        name or source.name or alias_sequence(), dialect=dialect
+                    )
+                    alias(source, normalized_alias, copy=False, table=True)
 
                 table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
                     source.alias
@@ -106,7 +117,10 @@
                 if pivots:
                     pivot = pivots[0]
                     if not pivot.alias:
-                        pivot_alias = source.alias if pivot.unpivot else next_alias_name()
+                        pivot_alias = normalize_identifiers(
+                            source.alias if pivot.unpivot else alias_sequence(),
+                            dialect=dialect,
+                        )
                         pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
 
                 # This case corresponds to a pivoted CTE, we don't want to qualify that
@@ -115,15 +129,8 @@
 
                 _qualify(source)
 
-                if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
-                    with csv_reader(source.this) as reader:
-                        header = next(reader)
-                        columns = next(reader)
-                        schema.add_table(
-                            source,
-                            {k: type(v).__name__ for k, v in zip(header, columns)},
-                            match_depth=False,
-                        )
+                if on_qualify:
+                    on_qualify(source)
             elif isinstance(source, Scope) and source.is_udtf:
                 udtf = source.expression
                 table_alias = udtf.args.get("alias") or exp.TableAlias(
@@ -134,7 +141,10 @@
                 if not table_alias.name:
                     table_alias.set("this", exp.to_identifier(next_alias_name()))
                 if isinstance(udtf, exp.Values) and not table_alias.columns:
-                    column_aliases = dialect.generate_values_aliases(udtf)
+                    column_aliases = [
+                        normalize_identifiers(i, dialect=dialect)
+                        for i in dialect.generate_values_aliases(udtf)
+                    ]
                     table_alias.set("columns", column_aliases)
             else:
                 for node in scope.walk():
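Throughout qualify_tables, every synthesized identifier (generated table aliases, pivot aliases, VALUES column aliases, plus the user-supplied `db`/`catalog`) now passes through `normalize_identifiers`, so generated names follow the dialect's normalization strategy. A sketch for Snowflake, which normalizes unquoted identifiers to uppercase (the expected output is an assumption based on that strategy):

```python
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

expr = sqlglot.parse_one("SELECT * FROM (SELECT 1)", read="snowflake")
print(qualify_tables(expr, dialect="snowflake").sql("snowflake"))
# SELECT * FROM (SELECT 1) AS _Q_0  -- alias uppercased to match Snowflake normalization
```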
sqlglot/optimizer/simplify.py CHANGED
@@ -125,7 +125,7 @@ def simplify(
                 node.set(k, v)
 
         # Post-order transformations
-        new_node = simplify_not(node)
+        new_node = simplify_not(node, dialect)
         new_node = flatten(new_node)
         new_node = simplify_connectors(new_node, root)
         new_node = remove_complements(new_node, root)
@@ -202,7 +202,7 @@ COMPLEMENT_SUBQUERY_PREDICATES = {
 }
 
 
-def simplify_not(expression):
+def simplify_not(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
     """
     Demorgan's Law
     NOT (x OR y) -> NOT x AND NOT y
@@ -243,10 +243,12 @@
             return exp.false()
         if is_false(this):
             return exp.true()
-        if isinstance(this, exp.Not):
-            # double negation
-            # NOT NOT x -> x
-            return this.this
+        if isinstance(this, exp.Not) and dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
+            inner = this.this
+            if inner.is_type(exp.DataType.Type.BOOLEAN) or isinstance(inner, exp.Predicate):
+                # double negation
+                # NOT NOT x -> x, if x is BOOLEAN type
+                return inner
     return expression
 
 
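Unconditionally rewriting `NOT NOT x -> x` is unsound when `x` is not boolean (some dialects coerce the inner operand, so the double negation is not a no-op), hence the rewrite now requires both a dialect opt-in (`SAFE_TO_ELIMINATE_DOUBLE_NEGATION`) and an inner operand that is provably boolean: typed `BOOLEAN` or a predicate. A sketch, assuming the default dialect opts in:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# The inner operand is a predicate, so the double negation still collapses:
print(simplify(parse_one("NOT NOT (a > 1)")).sql())  # a > 1

# A bare column has no known BOOLEAN type, so the expression is left alone:
print(simplify(parse_one("NOT NOT a")).sql())        # NOT NOT a
```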
@@ -760,7 +762,10 @@ def simplify_parens(expression: exp.Expression, dialect: DialectType = None) ->
                 not isinstance(this, exp.Binary)
                 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
             )
-            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
+            or (
+                isinstance(this, exp.Predicate)
+                and not (parent_is_predicate or isinstance(parent, exp.Neg))
+            )
             or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
             or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
             or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
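The extra `exp.Neg` guard stops `simplify_parens` from dropping parentheses around a predicate that is being arithmetically negated, since `-(a = b)` and `-a = b` parse differently. A sketch:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# The parentheses are load-bearing under a unary minus and are now preserved:
print(simplify(parse_one("SELECT -(a = b) FROM t")).sql())  # SELECT -(a = b) FROM t
```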