sqlglot 27.28.1__py3-none-any.whl → 27.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlglot/__init__.py +1 -0
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +1 -0
- sqlglot/dialects/dialect.py +45 -7
- sqlglot/dialects/duckdb.py +17 -3
- sqlglot/dialects/mysql.py +1 -0
- sqlglot/dialects/postgres.py +14 -2
- sqlglot/dialects/snowflake.py +55 -18
- sqlglot/dialects/spark.py +3 -0
- sqlglot/dialects/sqlite.py +1 -0
- sqlglot/executor/__init__.py +5 -10
- sqlglot/executor/python.py +1 -29
- sqlglot/expressions.py +102 -12
- sqlglot/generator.py +16 -2
- sqlglot/helper.py +0 -42
- sqlglot/lineage.py +1 -1
- sqlglot/optimizer/qualify.py +5 -5
- sqlglot/optimizer/qualify_columns.py +89 -9
- sqlglot/optimizer/qualify_tables.py +33 -23
- sqlglot/optimizer/simplify.py +12 -7
- sqlglot/parser.py +16 -8
- {sqlglot-27.28.1.dist-info → sqlglot-27.29.0.dist-info}/METADATA +1 -1
- {sqlglot-27.28.1.dist-info → sqlglot-27.29.0.dist-info}/RECORD +26 -26
- {sqlglot-27.28.1.dist-info → sqlglot-27.29.0.dist-info}/WHEEL +0 -0
- {sqlglot-27.28.1.dist-info → sqlglot-27.29.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-27.28.1.dist-info → sqlglot-27.29.0.dist-info}/top_level.txt +0 -0
sqlglot/expressions.py
CHANGED
|
@@ -118,7 +118,7 @@ class Expression(metaclass=_Expression):
|
|
|
118
118
|
self._set_parent(arg_key, value)
|
|
119
119
|
|
|
120
120
|
def __eq__(self, other) -> bool:
|
|
121
|
-
return type(self) is type(other) and hash(self) == hash(other)
|
|
121
|
+
return self is other or (type(self) is type(other) and hash(self) == hash(other))
|
|
122
122
|
|
|
123
123
|
def __hash__(self) -> int:
|
|
124
124
|
if self._hash is None:
|
|
@@ -1893,7 +1893,13 @@ class Comment(Expression):
|
|
|
1893
1893
|
|
|
1894
1894
|
|
|
1895
1895
|
class Comprehension(Expression):
|
|
1896
|
-
arg_types = {
|
|
1896
|
+
arg_types = {
|
|
1897
|
+
"this": True,
|
|
1898
|
+
"expression": True,
|
|
1899
|
+
"position": False,
|
|
1900
|
+
"iterator": True,
|
|
1901
|
+
"condition": False,
|
|
1902
|
+
}
|
|
1897
1903
|
|
|
1898
1904
|
|
|
1899
1905
|
# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
|
|
@@ -5620,6 +5626,14 @@ class Boolnot(Func):
|
|
|
5620
5626
|
pass
|
|
5621
5627
|
|
|
5622
5628
|
|
|
5629
|
+
class Booland(Func):
|
|
5630
|
+
arg_types = {"this": True, "expression": True}
|
|
5631
|
+
|
|
5632
|
+
|
|
5633
|
+
class Boolor(Func):
|
|
5634
|
+
arg_types = {"this": True, "expression": True}
|
|
5635
|
+
|
|
5636
|
+
|
|
5623
5637
|
# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#bool_for_json
|
|
5624
5638
|
class JSONBool(Func):
|
|
5625
5639
|
pass
|
|
@@ -5974,11 +5988,11 @@ class Lead(AggFunc):
|
|
|
5974
5988
|
# some dialects have a distinction between first and first_value, usually first is an aggregate func
|
|
5975
5989
|
# and first_value is a window func
|
|
5976
5990
|
class First(AggFunc):
|
|
5977
|
-
|
|
5991
|
+
arg_types = {"this": True, "expression": False}
|
|
5978
5992
|
|
|
5979
5993
|
|
|
5980
5994
|
class Last(AggFunc):
|
|
5981
|
-
|
|
5995
|
+
arg_types = {"this": True, "expression": False}
|
|
5982
5996
|
|
|
5983
5997
|
|
|
5984
5998
|
class FirstValue(AggFunc):
|
|
@@ -6276,6 +6290,14 @@ class WeekOfYear(Func):
|
|
|
6276
6290
|
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
|
|
6277
6291
|
|
|
6278
6292
|
|
|
6293
|
+
class YearOfWeek(Func):
|
|
6294
|
+
_sql_names = ["YEAR_OF_WEEK", "YEAROFWEEK"]
|
|
6295
|
+
|
|
6296
|
+
|
|
6297
|
+
class YearOfWeekIso(Func):
|
|
6298
|
+
_sql_names = ["YEAR_OF_WEEK_ISO", "YEAROFWEEKISO"]
|
|
6299
|
+
|
|
6300
|
+
|
|
6279
6301
|
class MonthsBetween(Func):
|
|
6280
6302
|
arg_types = {"this": True, "expression": True, "roundoff": False}
|
|
6281
6303
|
|
|
@@ -6426,6 +6448,10 @@ class Encode(Func):
|
|
|
6426
6448
|
arg_types = {"this": True, "charset": True}
|
|
6427
6449
|
|
|
6428
6450
|
|
|
6451
|
+
class EqualNull(Func):
|
|
6452
|
+
arg_types = {"this": True, "expression": True}
|
|
6453
|
+
|
|
6454
|
+
|
|
6429
6455
|
class Exp(Func):
|
|
6430
6456
|
pass
|
|
6431
6457
|
|
|
@@ -6570,6 +6596,16 @@ class Greatest(Func):
|
|
|
6570
6596
|
is_var_len_args = True
|
|
6571
6597
|
|
|
6572
6598
|
|
|
6599
|
+
class GreatestIgnoreNulls(Func):
|
|
6600
|
+
arg_types = {"expressions": True}
|
|
6601
|
+
is_var_len_args = True
|
|
6602
|
+
|
|
6603
|
+
|
|
6604
|
+
class LeastIgnoreNulls(Func):
|
|
6605
|
+
arg_types = {"expressions": True}
|
|
6606
|
+
is_var_len_args = True
|
|
6607
|
+
|
|
6608
|
+
|
|
6573
6609
|
# Trino's `ON OVERFLOW TRUNCATE [filler_string] {WITH | WITHOUT} COUNT`
|
|
6574
6610
|
# https://trino.io/docs/current/functions/aggregate.html#listagg
|
|
6575
6611
|
class OverflowTruncateBehavior(Expression):
|
|
@@ -6668,6 +6704,10 @@ class IsInf(Func):
|
|
|
6668
6704
|
_sql_names = ["IS_INF", "ISINF"]
|
|
6669
6705
|
|
|
6670
6706
|
|
|
6707
|
+
class IsNullValue(Func):
|
|
6708
|
+
pass
|
|
6709
|
+
|
|
6710
|
+
|
|
6671
6711
|
# https://www.postgresql.org/docs/current/functions-json.html
|
|
6672
6712
|
class JSON(Expression):
|
|
6673
6713
|
arg_types = {"this": False, "with": False, "unique": False}
|
|
@@ -7349,6 +7389,7 @@ class RegexpReplace(Func):
|
|
|
7349
7389
|
"position": False,
|
|
7350
7390
|
"occurrence": False,
|
|
7351
7391
|
"modifiers": False,
|
|
7392
|
+
"single_replace": False,
|
|
7352
7393
|
}
|
|
7353
7394
|
|
|
7354
7395
|
|
|
@@ -7391,6 +7432,14 @@ class RegexpCount(Func):
|
|
|
7391
7432
|
}
|
|
7392
7433
|
|
|
7393
7434
|
|
|
7435
|
+
class RegrValx(Func):
|
|
7436
|
+
arg_types = {"this": True, "expression": True}
|
|
7437
|
+
|
|
7438
|
+
|
|
7439
|
+
class RegrValy(Func):
|
|
7440
|
+
arg_types = {"this": True, "expression": True}
|
|
7441
|
+
|
|
7442
|
+
|
|
7394
7443
|
class Repeat(Func):
|
|
7395
7444
|
arg_types = {"this": True, "times": True}
|
|
7396
7445
|
|
|
@@ -7754,18 +7803,38 @@ class Uuid(Func):
|
|
|
7754
7803
|
arg_types = {"this": False, "name": False}
|
|
7755
7804
|
|
|
7756
7805
|
|
|
7806
|
+
TIMESTAMP_PARTS = {
|
|
7807
|
+
"year": False,
|
|
7808
|
+
"month": False,
|
|
7809
|
+
"day": False,
|
|
7810
|
+
"hour": False,
|
|
7811
|
+
"min": False,
|
|
7812
|
+
"sec": False,
|
|
7813
|
+
"nano": False,
|
|
7814
|
+
}
|
|
7815
|
+
|
|
7816
|
+
|
|
7757
7817
|
class TimestampFromParts(Func):
|
|
7758
7818
|
_sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"]
|
|
7759
7819
|
arg_types = {
|
|
7760
|
-
|
|
7761
|
-
"month": True,
|
|
7762
|
-
"day": True,
|
|
7763
|
-
"hour": True,
|
|
7764
|
-
"min": True,
|
|
7765
|
-
"sec": True,
|
|
7766
|
-
"nano": False,
|
|
7820
|
+
**TIMESTAMP_PARTS,
|
|
7767
7821
|
"zone": False,
|
|
7768
7822
|
"milli": False,
|
|
7823
|
+
"this": False,
|
|
7824
|
+
"expression": False,
|
|
7825
|
+
}
|
|
7826
|
+
|
|
7827
|
+
|
|
7828
|
+
class TimestampLtzFromParts(Func):
|
|
7829
|
+
_sql_names = ["TIMESTAMP_LTZ_FROM_PARTS", "TIMESTAMPLTZFROMPARTS"]
|
|
7830
|
+
arg_types = TIMESTAMP_PARTS.copy()
|
|
7831
|
+
|
|
7832
|
+
|
|
7833
|
+
class TimestampTzFromParts(Func):
|
|
7834
|
+
_sql_names = ["TIMESTAMP_TZ_FROM_PARTS", "TIMESTAMPTZFROMPARTS"]
|
|
7835
|
+
arg_types = {
|
|
7836
|
+
**TIMESTAMP_PARTS,
|
|
7837
|
+
"zone": False,
|
|
7769
7838
|
}
|
|
7770
7839
|
|
|
7771
7840
|
|
|
@@ -7851,7 +7920,8 @@ class Merge(DML):
|
|
|
7851
7920
|
arg_types = {
|
|
7852
7921
|
"this": True,
|
|
7853
7922
|
"using": True,
|
|
7854
|
-
"on":
|
|
7923
|
+
"on": False,
|
|
7924
|
+
"using_cond": False,
|
|
7855
7925
|
"whens": True,
|
|
7856
7926
|
"with": False,
|
|
7857
7927
|
"returning": False,
|
|
@@ -9355,6 +9425,26 @@ def replace_tree(
|
|
|
9355
9425
|
return new_node
|
|
9356
9426
|
|
|
9357
9427
|
|
|
9428
|
+
def find_tables(expression: Expression) -> t.Set[Table]:
|
|
9429
|
+
"""
|
|
9430
|
+
Find all tables referenced in a query.
|
|
9431
|
+
|
|
9432
|
+
Args:
|
|
9433
|
+
expressions: The query to find the tables in.
|
|
9434
|
+
|
|
9435
|
+
Returns:
|
|
9436
|
+
A set of all the tables.
|
|
9437
|
+
"""
|
|
9438
|
+
from sqlglot.optimizer.scope import traverse_scope
|
|
9439
|
+
|
|
9440
|
+
return {
|
|
9441
|
+
table
|
|
9442
|
+
for scope in traverse_scope(expression)
|
|
9443
|
+
for table in scope.tables
|
|
9444
|
+
if table.name and table.name not in scope.cte_sources
|
|
9445
|
+
}
|
|
9446
|
+
|
|
9447
|
+
|
|
9358
9448
|
def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
|
|
9359
9449
|
"""
|
|
9360
9450
|
Return all table names referenced through columns in an expression.
|
sqlglot/generator.py
CHANGED
|
@@ -2531,6 +2531,12 @@ class Generator(metaclass=_Generator):
|
|
|
2531
2531
|
def boolean_sql(self, expression: exp.Boolean) -> str:
|
|
2532
2532
|
return "TRUE" if expression.this else "FALSE"
|
|
2533
2533
|
|
|
2534
|
+
def booland_sql(self, expression: exp.Booland) -> str:
|
|
2535
|
+
return f"(({self.sql(expression, 'this')}) AND ({self.sql(expression, 'expression')}))"
|
|
2536
|
+
|
|
2537
|
+
def boolor_sql(self, expression: exp.Boolor) -> str:
|
|
2538
|
+
return f"(({self.sql(expression, 'this')}) OR ({self.sql(expression, 'expression')}))"
|
|
2539
|
+
|
|
2534
2540
|
def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
|
|
2535
2541
|
this = self.sql(expression, "this")
|
|
2536
2542
|
this = f"{this} " if this else this
|
|
@@ -4078,9 +4084,15 @@ class Generator(metaclass=_Generator):
|
|
|
4078
4084
|
|
|
4079
4085
|
this = self.sql(table)
|
|
4080
4086
|
using = f"USING {self.sql(expression, 'using')}"
|
|
4081
|
-
on = f"ON {self.sql(expression, 'on')}"
|
|
4082
4087
|
whens = self.sql(expression, "whens")
|
|
4083
4088
|
|
|
4089
|
+
on = self.sql(expression, "on")
|
|
4090
|
+
on = f"ON {on}" if on else ""
|
|
4091
|
+
|
|
4092
|
+
if not on:
|
|
4093
|
+
on = self.expressions(expression, key="using_cond")
|
|
4094
|
+
on = f"USING ({on})" if on else ""
|
|
4095
|
+
|
|
4084
4096
|
returning = self.sql(expression, "returning")
|
|
4085
4097
|
if returning:
|
|
4086
4098
|
whens = f"{whens}{returning}"
|
|
@@ -4244,10 +4256,12 @@ class Generator(metaclass=_Generator):
|
|
|
4244
4256
|
def comprehension_sql(self, expression: exp.Comprehension) -> str:
|
|
4245
4257
|
this = self.sql(expression, "this")
|
|
4246
4258
|
expr = self.sql(expression, "expression")
|
|
4259
|
+
position = self.sql(expression, "position")
|
|
4260
|
+
position = f", {position}" if position else ""
|
|
4247
4261
|
iterator = self.sql(expression, "iterator")
|
|
4248
4262
|
condition = self.sql(expression, "condition")
|
|
4249
4263
|
condition = f" IF {condition}" if condition else ""
|
|
4250
|
-
return f"{this} FOR {expr} IN {iterator}{condition}"
|
|
4264
|
+
return f"{this} FOR {expr}{position} IN {iterator}{condition}"
|
|
4251
4265
|
|
|
4252
4266
|
def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str:
|
|
4253
4267
|
return f"{self.sql(expression, 'this')}({self.sql(expression, 'expression')})"
|
sqlglot/helper.py
CHANGED
|
@@ -7,7 +7,6 @@ import re
|
|
|
7
7
|
import sys
|
|
8
8
|
import typing as t
|
|
9
9
|
from collections.abc import Collection, Set
|
|
10
|
-
from contextlib import contextmanager
|
|
11
10
|
from copy import copy
|
|
12
11
|
from difflib import get_close_matches
|
|
13
12
|
from enum import Enum
|
|
@@ -272,47 +271,6 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
|
|
|
272
271
|
return result
|
|
273
272
|
|
|
274
273
|
|
|
275
|
-
def open_file(file_name: str) -> t.TextIO:
|
|
276
|
-
"""Open a file that may be compressed as gzip and return it in universal newline mode."""
|
|
277
|
-
with open(file_name, "rb") as f:
|
|
278
|
-
gzipped = f.read(2) == b"\x1f\x8b"
|
|
279
|
-
|
|
280
|
-
if gzipped:
|
|
281
|
-
import gzip
|
|
282
|
-
|
|
283
|
-
return gzip.open(file_name, "rt", newline="")
|
|
284
|
-
|
|
285
|
-
return open(file_name, encoding="utf-8", newline="")
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
@contextmanager
|
|
289
|
-
def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
|
|
290
|
-
"""
|
|
291
|
-
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
|
|
292
|
-
|
|
293
|
-
Args:
|
|
294
|
-
read_csv: A `ReadCSV` function call.
|
|
295
|
-
|
|
296
|
-
Yields:
|
|
297
|
-
A python csv reader.
|
|
298
|
-
"""
|
|
299
|
-
args = read_csv.expressions
|
|
300
|
-
file = open_file(read_csv.name)
|
|
301
|
-
|
|
302
|
-
delimiter = ","
|
|
303
|
-
args = iter(arg.name for arg in args) # type: ignore
|
|
304
|
-
for k, v in zip(args, args):
|
|
305
|
-
if k == "delimiter":
|
|
306
|
-
delimiter = v
|
|
307
|
-
|
|
308
|
-
try:
|
|
309
|
-
import csv as csv_
|
|
310
|
-
|
|
311
|
-
yield csv_.reader(file, delimiter=delimiter)
|
|
312
|
-
finally:
|
|
313
|
-
file.close()
|
|
314
|
-
|
|
315
|
-
|
|
316
274
|
def find_new_name(taken: t.Collection[str], base: str) -> str:
|
|
317
275
|
"""
|
|
318
276
|
Searches for a new name.
|
sqlglot/lineage.py
CHANGED
|
@@ -232,7 +232,7 @@ def to_node(
|
|
|
232
232
|
)
|
|
233
233
|
|
|
234
234
|
# if the select is a star add all scope sources as downstreams
|
|
235
|
-
if select.is_star:
|
|
235
|
+
if isinstance(select, exp.Star):
|
|
236
236
|
for source in scope.sources.values():
|
|
237
237
|
if isinstance(source, Scope):
|
|
238
238
|
source = source.expression
|
sqlglot/optimizer/qualify.py
CHANGED
|
@@ -31,7 +31,7 @@ def qualify(
|
|
|
31
31
|
validate_qualify_columns: bool = True,
|
|
32
32
|
quote_identifiers: bool = True,
|
|
33
33
|
identify: bool = True,
|
|
34
|
-
infer_csv_schemas: bool = False,
|
|
34
|
+
on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
|
|
35
35
|
) -> exp.Expression:
|
|
36
36
|
"""
|
|
37
37
|
Rewrite sqlglot AST to have normalized and qualified tables and columns.
|
|
@@ -63,21 +63,21 @@ def qualify(
|
|
|
63
63
|
This step is necessary to ensure correctness for case sensitive queries.
|
|
64
64
|
But this flag is provided in case this step is performed at a later time.
|
|
65
65
|
identify: If True, quote all identifiers, else only necessary ones.
|
|
66
|
-
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
|
|
66
|
+
on_qualify: Callback after a table has been qualified.
|
|
67
67
|
|
|
68
68
|
Returns:
|
|
69
69
|
The qualified expression.
|
|
70
70
|
"""
|
|
71
71
|
schema = ensure_schema(schema, dialect=dialect)
|
|
72
|
+
|
|
73
|
+
expression = normalize_identifiers(expression, dialect=dialect)
|
|
72
74
|
expression = qualify_tables(
|
|
73
75
|
expression,
|
|
74
76
|
db=db,
|
|
75
77
|
catalog=catalog,
|
|
76
|
-
schema=schema,
|
|
77
78
|
dialect=dialect,
|
|
78
|
-
infer_csv_schemas=infer_csv_schemas,
|
|
79
|
+
on_qualify=on_qualify,
|
|
79
80
|
)
|
|
80
|
-
expression = normalize_identifiers(expression, dialect=dialect)
|
|
81
81
|
|
|
82
82
|
if isolate_tables:
|
|
83
83
|
expression = isolate_table_selects(expression, schema=schema)
|
|
@@ -551,7 +551,8 @@ def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualificati
|
|
|
551
551
|
continue
|
|
552
552
|
|
|
553
553
|
# column_table can be a '' because bigquery unnest has no table alias
|
|
554
|
-
column_table = resolver.get_table(column_name)
|
|
554
|
+
column_table = resolver.get_table(column)
|
|
555
|
+
|
|
555
556
|
if column_table:
|
|
556
557
|
column.set("table", column_table)
|
|
557
558
|
elif (
|
|
@@ -948,21 +949,29 @@ class Resolver:
|
|
|
948
949
|
self._infer_schema = infer_schema
|
|
949
950
|
self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
|
|
950
951
|
|
|
951
|
-
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
|
|
952
|
+
def get_table(self, column: str | exp.Column) -> t.Optional[exp.Identifier]:
|
|
952
953
|
"""
|
|
953
954
|
Get the table for a column name.
|
|
954
955
|
|
|
955
956
|
Args:
|
|
956
|
-
column_name: The column name to find the table for.
|
|
957
|
+
column: The column expression (or column name) to find the table for.
|
|
957
958
|
Returns:
|
|
958
959
|
The table name if it can be found/inferred.
|
|
959
960
|
"""
|
|
960
|
-
if self._unambiguous_columns is None:
|
|
961
|
-
self._unambiguous_columns = self._get_unambiguous_columns(
|
|
962
|
-
self._get_all_source_columns()
|
|
963
|
-
)
|
|
964
|
-
|
|
965
|
-
table_name = self._unambiguous_columns.get(column_name)
|
|
961
|
+
column_name = column if isinstance(column, str) else column.name
|
|
962
|
+
|
|
963
|
+
table_name = self._get_table_name_from_sources(column_name)
|
|
964
|
+
|
|
965
|
+
if not table_name and isinstance(column, exp.Column):
|
|
966
|
+
# Fall-back case: If we couldn't find the `table_name` from ALL of the sources,
|
|
967
|
+
# attempt to disambiguate the column based on other characteristics e.g if this column is in a join condition,
|
|
968
|
+
# we may be able to disambiguate based on the source order.
|
|
969
|
+
if join_context := self._get_column_join_context(column):
|
|
970
|
+
# In this case, the return value will be the join that _may_ be able to disambiguate the column
|
|
971
|
+
# and we can use the source columns available at that join to get the table name
|
|
972
|
+
table_name = self._get_table_name_from_sources(
|
|
973
|
+
column_name, self._get_available_source_columns(join_context)
|
|
974
|
+
)
|
|
966
975
|
|
|
967
976
|
if not table_name and self._infer_schema:
|
|
968
977
|
sources_without_schema = tuple(
|
|
@@ -1101,6 +1110,77 @@ class Resolver:
|
|
|
1101
1110
|
}
|
|
1102
1111
|
return self._source_columns
|
|
1103
1112
|
|
|
1113
|
+
def _get_table_name_from_sources(
|
|
1114
|
+
self, column_name: str, source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
|
|
1115
|
+
) -> t.Optional[str]:
|
|
1116
|
+
if not source_columns:
|
|
1117
|
+
# If not supplied, get all sources to calculate unambiguous columns
|
|
1118
|
+
if self._unambiguous_columns is None:
|
|
1119
|
+
self._unambiguous_columns = self._get_unambiguous_columns(
|
|
1120
|
+
self._get_all_source_columns()
|
|
1121
|
+
)
|
|
1122
|
+
|
|
1123
|
+
unambiguous_columns = self._unambiguous_columns
|
|
1124
|
+
else:
|
|
1125
|
+
unambiguous_columns = self._get_unambiguous_columns(source_columns)
|
|
1126
|
+
|
|
1127
|
+
return unambiguous_columns.get(column_name)
|
|
1128
|
+
|
|
1129
|
+
def _get_column_join_context(self, column: exp.Column) -> t.Optional[exp.Join]:
|
|
1130
|
+
"""
|
|
1131
|
+
Check if a column participating in a join can be qualified based on the source order.
|
|
1132
|
+
"""
|
|
1133
|
+
args = self.scope.expression.args
|
|
1134
|
+
joins = args.get("joins")
|
|
1135
|
+
|
|
1136
|
+
if not joins or args.get("laterals") or args.get("pivots"):
|
|
1137
|
+
# Feature gap: We currently don't try to disambiguate columns if other sources
|
|
1138
|
+
# (e.g laterals, pivots) exist alongside joins
|
|
1139
|
+
return None
|
|
1140
|
+
|
|
1141
|
+
join_ancestor = column.find_ancestor(exp.Join, exp.Select)
|
|
1142
|
+
|
|
1143
|
+
if (
|
|
1144
|
+
isinstance(join_ancestor, exp.Join)
|
|
1145
|
+
and join_ancestor.alias_or_name in self.scope.selected_sources
|
|
1146
|
+
):
|
|
1147
|
+
# Ensure that the found ancestor is a join that contains an actual source,
|
|
1148
|
+
# e.g in Clickhouse `b` is an array expression in `a ARRAY JOIN b`
|
|
1149
|
+
return join_ancestor
|
|
1150
|
+
|
|
1151
|
+
return None
|
|
1152
|
+
|
|
1153
|
+
def _get_available_source_columns(
|
|
1154
|
+
self, join_ancestor: exp.Join
|
|
1155
|
+
) -> t.Dict[str, t.Sequence[str]]:
|
|
1156
|
+
"""
|
|
1157
|
+
Get the source columns that are available at the point where a column is referenced.
|
|
1158
|
+
|
|
1159
|
+
For columns in JOIN conditions, this only includes tables that have been joined
|
|
1160
|
+
up to that point. Example:
|
|
1161
|
+
|
|
1162
|
+
```
|
|
1163
|
+
SELECT * FROM t_1 INNER JOIN ... INNER JOIN t_n ON t_1.a = c INNER JOIN t_n+1 ON ...
|
|
1164
|
+
``` ^
|
|
1165
|
+
|
|
|
1166
|
+
+----------------------------------+
|
|
1167
|
+
|
|
|
1168
|
+
⌄
|
|
1169
|
+
The unqualified column `c` is not ambiguous if no other sources up until that
|
|
1170
|
+
join i.e t_1, ..., t_n, contain a column named `c`.
|
|
1171
|
+
|
|
1172
|
+
"""
|
|
1173
|
+
args = self.scope.expression.args
|
|
1174
|
+
|
|
1175
|
+
# Collect tables in order: FROM clause tables + joined tables up to current join
|
|
1176
|
+
from_name = args["from"].alias_or_name
|
|
1177
|
+
available_sources = {from_name: self.get_source_columns(from_name)}
|
|
1178
|
+
|
|
1179
|
+
for join in args["joins"][: t.cast(int, join_ancestor.index) + 1]:
|
|
1180
|
+
available_sources[join.alias_or_name] = self.get_source_columns(join.alias_or_name)
|
|
1181
|
+
|
|
1182
|
+
return available_sources
|
|
1183
|
+
|
|
1104
1184
|
def _get_unambiguous_columns(
|
|
1105
1185
|
self, source_columns: t.Dict[str, t.Sequence[str]]
|
|
1106
1186
|
) -> t.Mapping[str, str]:
|
|
@@ -4,11 +4,10 @@ import itertools
|
|
|
4
4
|
import typing as t
|
|
5
5
|
|
|
6
6
|
from sqlglot import alias, exp
|
|
7
|
-
from sqlglot.dialects.dialect import DialectType
|
|
8
|
-
from sqlglot.helper import csv_reader, name_sequence
|
|
7
|
+
from sqlglot.dialects.dialect import Dialect, DialectType
|
|
8
|
+
from sqlglot.helper import name_sequence
|
|
9
|
+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
|
|
9
10
|
from sqlglot.optimizer.scope import Scope, traverse_scope
|
|
10
|
-
from sqlglot.schema import Schema
|
|
11
|
-
from sqlglot.dialects.dialect import Dialect
|
|
12
11
|
|
|
13
12
|
if t.TYPE_CHECKING:
|
|
14
13
|
from sqlglot._typing import E
|
|
@@ -18,8 +17,7 @@ def qualify_tables(
|
|
|
18
17
|
expression: E,
|
|
19
18
|
db: t.Optional[str | exp.Identifier] = None,
|
|
20
19
|
catalog: t.Optional[str | exp.Identifier] = None,
|
|
21
|
-
schema: t.Optional[Schema] = None,
|
|
22
|
-
infer_csv_schemas: bool = False,
|
|
20
|
+
on_qualify: t.Optional[t.Callable[[exp.Expression], None]] = None,
|
|
23
21
|
dialect: DialectType = None,
|
|
24
22
|
) -> E:
|
|
25
23
|
"""
|
|
@@ -40,18 +38,28 @@ def qualify_tables(
|
|
|
40
38
|
expression: Expression to qualify
|
|
41
39
|
db: Database name
|
|
42
40
|
catalog: Catalog name
|
|
43
|
-
|
|
44
|
-
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
|
|
41
|
+
on_qualify: Callback after a table has been qualified.
|
|
45
42
|
dialect: The dialect to parse catalog and schema into.
|
|
46
43
|
|
|
47
44
|
Returns:
|
|
48
45
|
The qualified expression.
|
|
49
46
|
"""
|
|
50
|
-
next_alias_name = name_sequence("_q_")
|
|
51
|
-
db = exp.parse_identifier(db, dialect=dialect) if db else None
|
|
52
|
-
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
|
|
53
47
|
dialect = Dialect.get_or_raise(dialect)
|
|
54
48
|
|
|
49
|
+
alias_sequence = name_sequence("_q_")
|
|
50
|
+
|
|
51
|
+
def next_alias_name() -> str:
|
|
52
|
+
return normalize_identifiers(alias_sequence(), dialect=dialect).name
|
|
53
|
+
|
|
54
|
+
if db := db or None:
|
|
55
|
+
db = exp.parse_identifier(db, dialect=dialect)
|
|
56
|
+
db.meta["is_table"] = True
|
|
57
|
+
db = normalize_identifiers(db, dialect=dialect)
|
|
58
|
+
if catalog := catalog or None:
|
|
59
|
+
catalog = exp.parse_identifier(catalog, dialect=dialect)
|
|
60
|
+
catalog.meta["is_table"] = True
|
|
61
|
+
catalog = normalize_identifiers(catalog, dialect=dialect)
|
|
62
|
+
|
|
55
63
|
def _qualify(table: exp.Table) -> None:
|
|
56
64
|
if isinstance(table.this, exp.Identifier):
|
|
57
65
|
if db and not table.args.get("db"):
|
|
@@ -97,7 +105,10 @@ def qualify_tables(
|
|
|
97
105
|
name = source.name
|
|
98
106
|
|
|
99
107
|
# Mutates the source by attaching an alias to it
|
|
100
|
-
alias(source, name or source.name or next_alias_name(), copy=False, table=True)
|
|
108
|
+
normalized_alias = normalize_identifiers(
|
|
109
|
+
name or source.name or alias_sequence(), dialect=dialect
|
|
110
|
+
)
|
|
111
|
+
alias(source, normalized_alias, copy=False, table=True)
|
|
101
112
|
|
|
102
113
|
table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier(
|
|
103
114
|
source.alias
|
|
@@ -106,7 +117,10 @@ def qualify_tables(
|
|
|
106
117
|
if pivots:
|
|
107
118
|
pivot = pivots[0]
|
|
108
119
|
if not pivot.alias:
|
|
109
|
-
pivot_alias = source.alias if pivot.unpivot else next_alias_name()
|
|
120
|
+
pivot_alias = normalize_identifiers(
|
|
121
|
+
source.alias if pivot.unpivot else alias_sequence(),
|
|
122
|
+
dialect=dialect,
|
|
123
|
+
)
|
|
110
124
|
pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias)))
|
|
111
125
|
|
|
112
126
|
# This case corresponds to a pivoted CTE, we don't want to qualify that
|
|
@@ -115,15 +129,8 @@ def qualify_tables(
|
|
|
115
129
|
|
|
116
130
|
_qualify(source)
|
|
117
131
|
|
|
118
|
-
if infer_csv_schemas and isinstance(source.this, exp.ReadCSV):
|
|
119
|
-
with csv_reader(source.this) as reader:
|
|
120
|
-
header = next(reader)
|
|
121
|
-
columns = next(reader)
|
|
122
|
-
schema.add_table(
|
|
123
|
-
source,
|
|
124
|
-
{k: type(v).__name__ for k, v in zip(header, columns)},
|
|
125
|
-
match_depth=False,
|
|
126
|
-
)
|
|
132
|
+
if on_qualify:
|
|
133
|
+
on_qualify(source)
|
|
127
134
|
elif isinstance(source, Scope) and source.is_udtf:
|
|
128
135
|
udtf = source.expression
|
|
129
136
|
table_alias = udtf.args.get("alias") or exp.TableAlias(
|
|
@@ -134,7 +141,10 @@ def qualify_tables(
|
|
|
134
141
|
if not table_alias.name:
|
|
135
142
|
table_alias.set("this", exp.to_identifier(next_alias_name()))
|
|
136
143
|
if isinstance(udtf, exp.Values) and not table_alias.columns:
|
|
137
|
-
column_aliases = dialect.generate_values_aliases(udtf)
|
|
144
|
+
column_aliases = [
|
|
145
|
+
normalize_identifiers(i, dialect=dialect)
|
|
146
|
+
for i in dialect.generate_values_aliases(udtf)
|
|
147
|
+
]
|
|
138
148
|
table_alias.set("columns", column_aliases)
|
|
139
149
|
else:
|
|
140
150
|
for node in scope.walk():
|
sqlglot/optimizer/simplify.py
CHANGED
|
@@ -125,7 +125,7 @@ def simplify(
|
|
|
125
125
|
node.set(k, v)
|
|
126
126
|
|
|
127
127
|
# Post-order transformations
|
|
128
|
-
new_node = simplify_not(node)
|
|
128
|
+
new_node = simplify_not(node, dialect)
|
|
129
129
|
new_node = flatten(new_node)
|
|
130
130
|
new_node = simplify_connectors(new_node, root)
|
|
131
131
|
new_node = remove_complements(new_node, root)
|
|
@@ -202,7 +202,7 @@ COMPLEMENT_SUBQUERY_PREDICATES = {
|
|
|
202
202
|
}
|
|
203
203
|
|
|
204
204
|
|
|
205
|
-
def simplify_not(expression):
|
|
205
|
+
def simplify_not(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
|
|
206
206
|
"""
|
|
207
207
|
Demorgan's Law
|
|
208
208
|
NOT (x OR y) -> NOT x AND NOT y
|
|
@@ -243,10 +243,12 @@ def simplify_not(expression):
|
|
|
243
243
|
return exp.false()
|
|
244
244
|
if is_false(this):
|
|
245
245
|
return exp.true()
|
|
246
|
-
if isinstance(this, exp.Not):
|
|
247
|
-
# double negation
|
|
248
|
-
# NOT NOT x -> x
|
|
249
|
-
return this.this
|
|
246
|
+
if isinstance(this, exp.Not) and dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
|
|
247
|
+
inner = this.this
|
|
248
|
+
if inner.is_type(exp.DataType.Type.BOOLEAN) or isinstance(inner, exp.Predicate):
|
|
249
|
+
# double negation
|
|
250
|
+
# NOT NOT x -> x, if x is BOOLEAN type
|
|
251
|
+
return inner
|
|
250
252
|
return expression
|
|
251
253
|
|
|
252
254
|
|
|
@@ -760,7 +762,10 @@ def simplify_parens(expression: exp.Expression, dialect: DialectType = None) ->
|
|
|
760
762
|
not isinstance(this, exp.Binary)
|
|
761
763
|
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
|
762
764
|
)
|
|
763
|
-
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
|
765
|
+
or (
|
|
766
|
+
isinstance(this, exp.Predicate)
|
|
767
|
+
and not (parent_is_predicate or isinstance(parent, exp.Neg))
|
|
768
|
+
)
|
|
764
769
|
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
|
765
770
|
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
|
766
771
|
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|