sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlglot/__init__.py +1 -0
- sqlglot/__main__.py +6 -4
- sqlglot/_version.py +2 -2
- sqlglot/dialects/bigquery.py +118 -279
- sqlglot/dialects/clickhouse.py +73 -5
- sqlglot/dialects/databricks.py +38 -1
- sqlglot/dialects/dialect.py +354 -275
- sqlglot/dialects/dremio.py +4 -1
- sqlglot/dialects/duckdb.py +754 -25
- sqlglot/dialects/exasol.py +243 -10
- sqlglot/dialects/hive.py +8 -8
- sqlglot/dialects/mysql.py +14 -4
- sqlglot/dialects/oracle.py +29 -0
- sqlglot/dialects/postgres.py +60 -26
- sqlglot/dialects/presto.py +47 -16
- sqlglot/dialects/redshift.py +16 -0
- sqlglot/dialects/risingwave.py +3 -0
- sqlglot/dialects/singlestore.py +12 -3
- sqlglot/dialects/snowflake.py +239 -218
- sqlglot/dialects/spark.py +15 -4
- sqlglot/dialects/spark2.py +11 -48
- sqlglot/dialects/sqlite.py +10 -0
- sqlglot/dialects/starrocks.py +3 -0
- sqlglot/dialects/teradata.py +5 -8
- sqlglot/dialects/trino.py +6 -0
- sqlglot/dialects/tsql.py +61 -22
- sqlglot/diff.py +4 -2
- sqlglot/errors.py +69 -0
- sqlglot/executor/__init__.py +5 -10
- sqlglot/executor/python.py +1 -29
- sqlglot/expressions.py +637 -100
- sqlglot/generator.py +160 -43
- sqlglot/helper.py +2 -44
- sqlglot/lineage.py +10 -4
- sqlglot/optimizer/annotate_types.py +247 -140
- sqlglot/optimizer/canonicalize.py +6 -1
- sqlglot/optimizer/eliminate_joins.py +1 -1
- sqlglot/optimizer/eliminate_subqueries.py +2 -2
- sqlglot/optimizer/merge_subqueries.py +5 -5
- sqlglot/optimizer/normalize.py +20 -13
- sqlglot/optimizer/normalize_identifiers.py +17 -3
- sqlglot/optimizer/optimizer.py +4 -0
- sqlglot/optimizer/pushdown_predicates.py +1 -1
- sqlglot/optimizer/qualify.py +18 -10
- sqlglot/optimizer/qualify_columns.py +122 -275
- sqlglot/optimizer/qualify_tables.py +128 -76
- sqlglot/optimizer/resolver.py +374 -0
- sqlglot/optimizer/scope.py +27 -16
- sqlglot/optimizer/simplify.py +1075 -959
- sqlglot/optimizer/unnest_subqueries.py +12 -2
- sqlglot/parser.py +296 -170
- sqlglot/planner.py +2 -2
- sqlglot/schema.py +15 -4
- sqlglot/tokens.py +42 -7
- sqlglot/transforms.py +77 -22
- sqlglot/typing/__init__.py +316 -0
- sqlglot/typing/bigquery.py +376 -0
- sqlglot/typing/hive.py +12 -0
- sqlglot/typing/presto.py +24 -0
- sqlglot/typing/snowflake.py +505 -0
- sqlglot/typing/spark2.py +58 -0
- sqlglot/typing/tsql.py +9 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
- sqlglot-28.4.0.dist-info/RECORD +92 -0
- sqlglot-27.27.0.dist-info/RECORD +0 -84
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
- {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
sqlglot/optimizer/simplify.py
CHANGED
|
@@ -6,34 +6,37 @@ import functools
|
|
|
6
6
|
import itertools
|
|
7
7
|
import typing as t
|
|
8
8
|
from collections import deque, defaultdict
|
|
9
|
-
from functools import reduce
|
|
9
|
+
from functools import reduce, wraps
|
|
10
10
|
|
|
11
11
|
import sqlglot
|
|
12
12
|
from sqlglot import Dialect, exp
|
|
13
13
|
from sqlglot.helper import first, merge_ranges, while_changing
|
|
14
|
+
from sqlglot.optimizer.annotate_types import TypeAnnotator
|
|
14
15
|
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
|
|
16
|
+
from sqlglot.schema import ensure_schema
|
|
15
17
|
|
|
16
18
|
if t.TYPE_CHECKING:
|
|
17
19
|
from sqlglot.dialects.dialect import DialectType
|
|
18
20
|
|
|
21
|
+
DateRange = t.Tuple[datetime.date, datetime.date]
|
|
19
22
|
DateTruncBinaryTransform = t.Callable[
|
|
20
23
|
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
|
|
21
24
|
]
|
|
22
25
|
|
|
26
|
+
|
|
23
27
|
logger = logging.getLogger("sqlglot")
|
|
24
28
|
|
|
29
|
+
|
|
25
30
|
# Final means that an expression should not be simplified
|
|
26
31
|
FINAL = "final"
|
|
27
32
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class UnsupportedUnit(Exception):
|
|
36
|
-
pass
|
|
33
|
+
SIMPLIFIABLE = (
|
|
34
|
+
exp.Binary,
|
|
35
|
+
exp.Func,
|
|
36
|
+
exp.Lambda,
|
|
37
|
+
exp.Predicate,
|
|
38
|
+
exp.Unary,
|
|
39
|
+
)
|
|
37
40
|
|
|
38
41
|
|
|
39
42
|
def simplify(
|
|
@@ -60,96 +63,15 @@ def simplify(
|
|
|
60
63
|
Returns:
|
|
61
64
|
sqlglot.Expression: simplified expression
|
|
62
65
|
"""
|
|
66
|
+
return Simplifier(dialect=dialect).simplify(
|
|
67
|
+
expression,
|
|
68
|
+
constant_propagation=constant_propagation,
|
|
69
|
+
coalesce_simplification=coalesce_simplification,
|
|
70
|
+
)
|
|
63
71
|
|
|
64
|
-
dialect = Dialect.get_or_raise(dialect)
|
|
65
|
-
|
|
66
|
-
def _simplify(expression):
|
|
67
|
-
pre_transformation_stack = [expression]
|
|
68
|
-
post_transformation_stack = []
|
|
69
|
-
|
|
70
|
-
while pre_transformation_stack:
|
|
71
|
-
node = pre_transformation_stack.pop()
|
|
72
|
-
|
|
73
|
-
if node.meta.get(FINAL):
|
|
74
|
-
continue
|
|
75
|
-
|
|
76
|
-
# group by expressions cannot be simplified, for example
|
|
77
|
-
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
|
78
|
-
# the projection must exactly match the group by key
|
|
79
|
-
group = node.args.get("group")
|
|
80
|
-
|
|
81
|
-
if group and hasattr(node, "selects"):
|
|
82
|
-
groups = set(group.expressions)
|
|
83
|
-
group.meta[FINAL] = True
|
|
84
|
-
|
|
85
|
-
for s in node.selects:
|
|
86
|
-
for n in s.walk():
|
|
87
|
-
if n in groups:
|
|
88
|
-
s.meta[FINAL] = True
|
|
89
|
-
break
|
|
90
|
-
|
|
91
|
-
having = node.args.get("having")
|
|
92
|
-
if having:
|
|
93
|
-
for n in having.walk():
|
|
94
|
-
if n in groups:
|
|
95
|
-
having.meta[FINAL] = True
|
|
96
|
-
break
|
|
97
|
-
|
|
98
|
-
parent = node.parent
|
|
99
|
-
root = node is expression
|
|
100
|
-
|
|
101
|
-
new_node = rewrite_between(node)
|
|
102
|
-
new_node = uniq_sort(new_node, root)
|
|
103
|
-
new_node = absorb_and_eliminate(new_node, root)
|
|
104
|
-
new_node = simplify_concat(new_node)
|
|
105
|
-
new_node = simplify_conditionals(new_node)
|
|
106
|
-
|
|
107
|
-
if constant_propagation:
|
|
108
|
-
new_node = propagate_constants(new_node, root)
|
|
109
|
-
|
|
110
|
-
if new_node is not node:
|
|
111
|
-
node.replace(new_node)
|
|
112
|
-
|
|
113
|
-
pre_transformation_stack.extend(
|
|
114
|
-
n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
|
|
115
|
-
)
|
|
116
|
-
post_transformation_stack.append((new_node, parent))
|
|
117
|
-
|
|
118
|
-
while post_transformation_stack:
|
|
119
|
-
node, parent = post_transformation_stack.pop()
|
|
120
|
-
root = node is expression
|
|
121
|
-
|
|
122
|
-
# Resets parent, arg_key, index pointers– this is needed because some of the
|
|
123
|
-
# previous transformations mutate the AST, leading to an inconsistent state
|
|
124
|
-
for k, v in tuple(node.args.items()):
|
|
125
|
-
node.set(k, v)
|
|
126
|
-
|
|
127
|
-
# Post-order transformations
|
|
128
|
-
new_node = simplify_not(node)
|
|
129
|
-
new_node = flatten(new_node)
|
|
130
|
-
new_node = simplify_connectors(new_node, root)
|
|
131
|
-
new_node = remove_complements(new_node, root)
|
|
132
|
-
|
|
133
|
-
if coalesce_simplification:
|
|
134
|
-
new_node = simplify_coalesce(new_node, dialect)
|
|
135
|
-
|
|
136
|
-
new_node.parent = parent
|
|
137
|
-
|
|
138
|
-
new_node = simplify_literals(new_node, root)
|
|
139
|
-
new_node = simplify_equality(new_node)
|
|
140
|
-
new_node = simplify_parens(new_node, dialect)
|
|
141
|
-
new_node = simplify_datetrunc(new_node, dialect)
|
|
142
|
-
new_node = sort_comparison(new_node)
|
|
143
|
-
new_node = simplify_startswith(new_node)
|
|
144
|
-
|
|
145
|
-
if new_node is not node:
|
|
146
|
-
node.replace(new_node)
|
|
147
|
-
|
|
148
|
-
return new_node
|
|
149
72
|
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
return expression
|
|
73
|
+
class UnsupportedUnit(Exception):
|
|
74
|
+
pass
|
|
153
75
|
|
|
154
76
|
|
|
155
77
|
def catch(*exceptions):
|
|
@@ -167,87 +89,30 @@ def catch(*exceptions):
|
|
|
167
89
|
return decorator
|
|
168
90
|
|
|
169
91
|
|
|
170
|
-
def
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
"""
|
|
175
|
-
if isinstance(expression, exp.Between):
|
|
176
|
-
negate = isinstance(expression.parent, exp.Not)
|
|
177
|
-
|
|
178
|
-
expression = exp.and_(
|
|
179
|
-
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
|
|
180
|
-
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
|
|
181
|
-
copy=False,
|
|
182
|
-
)
|
|
183
|
-
|
|
184
|
-
if negate:
|
|
185
|
-
expression = exp.paren(expression, copy=False)
|
|
92
|
+
def annotate_types_on_change(func):
|
|
93
|
+
@wraps(func)
|
|
94
|
+
def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]:
|
|
95
|
+
new_expression = func(self, expression, *args, **kwargs)
|
|
186
96
|
|
|
187
|
-
|
|
97
|
+
if new_expression is None:
|
|
98
|
+
return new_expression
|
|
188
99
|
|
|
100
|
+
if self.annotate_new_expressions and expression != new_expression:
|
|
101
|
+
self._annotator.clear()
|
|
189
102
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
exp.EQ: exp.NEQ,
|
|
196
|
-
exp.NEQ: exp.EQ,
|
|
197
|
-
}
|
|
103
|
+
# We annotate this to ensure new children nodes are also annotated
|
|
104
|
+
new_expression = self._annotator.annotate(
|
|
105
|
+
expression=new_expression,
|
|
106
|
+
annotate_scope=False,
|
|
107
|
+
)
|
|
198
108
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
}
|
|
109
|
+
# Whatever expression the original expression is transformed into needs to preserve
|
|
110
|
+
# the original type, otherwise the simplification could result in a different schema
|
|
111
|
+
new_expression.type = expression.type
|
|
203
112
|
|
|
113
|
+
return new_expression
|
|
204
114
|
|
|
205
|
-
|
|
206
|
-
"""
|
|
207
|
-
Demorgan's Law
|
|
208
|
-
NOT (x OR y) -> NOT x AND NOT y
|
|
209
|
-
NOT (x AND y) -> NOT x OR NOT y
|
|
210
|
-
"""
|
|
211
|
-
if isinstance(expression, exp.Not):
|
|
212
|
-
this = expression.this
|
|
213
|
-
if is_null(this):
|
|
214
|
-
return exp.null()
|
|
215
|
-
if this.__class__ in COMPLEMENT_COMPARISONS:
|
|
216
|
-
right = this.expression
|
|
217
|
-
complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
|
|
218
|
-
if complement_subquery_predicate:
|
|
219
|
-
right = complement_subquery_predicate(this=right.this)
|
|
220
|
-
|
|
221
|
-
return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
|
|
222
|
-
if isinstance(this, exp.Paren):
|
|
223
|
-
condition = this.unnest()
|
|
224
|
-
if isinstance(condition, exp.And):
|
|
225
|
-
return exp.paren(
|
|
226
|
-
exp.or_(
|
|
227
|
-
exp.not_(condition.left, copy=False),
|
|
228
|
-
exp.not_(condition.right, copy=False),
|
|
229
|
-
copy=False,
|
|
230
|
-
)
|
|
231
|
-
)
|
|
232
|
-
if isinstance(condition, exp.Or):
|
|
233
|
-
return exp.paren(
|
|
234
|
-
exp.and_(
|
|
235
|
-
exp.not_(condition.left, copy=False),
|
|
236
|
-
exp.not_(condition.right, copy=False),
|
|
237
|
-
copy=False,
|
|
238
|
-
)
|
|
239
|
-
)
|
|
240
|
-
if is_null(condition):
|
|
241
|
-
return exp.null()
|
|
242
|
-
if always_true(this):
|
|
243
|
-
return exp.false()
|
|
244
|
-
if is_false(this):
|
|
245
|
-
return exp.true()
|
|
246
|
-
if isinstance(this, exp.Not):
|
|
247
|
-
# double negation
|
|
248
|
-
# NOT NOT x -> x
|
|
249
|
-
return this.this
|
|
250
|
-
return expression
|
|
115
|
+
return _func
|
|
251
116
|
|
|
252
117
|
|
|
253
118
|
def flatten(expression):
|
|
@@ -263,246 +128,43 @@ def flatten(expression):
|
|
|
263
128
|
return expression
|
|
264
129
|
|
|
265
130
|
|
|
266
|
-
def
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
if is_false(left) or is_false(right):
|
|
270
|
-
return exp.false()
|
|
271
|
-
if is_zero(left) or is_zero(right):
|
|
272
|
-
return exp.false()
|
|
273
|
-
if is_null(left) or is_null(right):
|
|
274
|
-
return exp.null()
|
|
275
|
-
if always_true(left) and always_true(right):
|
|
276
|
-
return exp.true()
|
|
277
|
-
if always_true(left):
|
|
278
|
-
return right
|
|
279
|
-
if always_true(right):
|
|
280
|
-
return left
|
|
281
|
-
return _simplify_comparison(expression, left, right)
|
|
282
|
-
elif isinstance(expression, exp.Or):
|
|
283
|
-
if always_true(left) or always_true(right):
|
|
284
|
-
return exp.true()
|
|
285
|
-
if (
|
|
286
|
-
(is_null(left) and is_null(right))
|
|
287
|
-
or (is_null(left) and always_false(right))
|
|
288
|
-
or (always_false(left) and is_null(right))
|
|
289
|
-
):
|
|
290
|
-
return exp.null()
|
|
291
|
-
if is_false(left):
|
|
292
|
-
return right
|
|
293
|
-
if is_false(right):
|
|
294
|
-
return left
|
|
295
|
-
return _simplify_comparison(expression, left, right, or_=True)
|
|
296
|
-
elif isinstance(expression, exp.Xor):
|
|
297
|
-
if left == right:
|
|
298
|
-
return exp.false()
|
|
299
|
-
|
|
300
|
-
if isinstance(expression, exp.Connector):
|
|
301
|
-
return _flat_simplify(expression, _simplify_connectors, root)
|
|
302
|
-
return expression
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
LT_LTE = (exp.LT, exp.LTE)
|
|
306
|
-
GT_GTE = (exp.GT, exp.GTE)
|
|
307
|
-
|
|
308
|
-
COMPARISONS = (
|
|
309
|
-
*LT_LTE,
|
|
310
|
-
*GT_GTE,
|
|
311
|
-
exp.EQ,
|
|
312
|
-
exp.NEQ,
|
|
313
|
-
exp.Is,
|
|
314
|
-
)
|
|
315
|
-
|
|
316
|
-
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|
317
|
-
exp.LT: exp.GT,
|
|
318
|
-
exp.GT: exp.LT,
|
|
319
|
-
exp.LTE: exp.GTE,
|
|
320
|
-
exp.GTE: exp.LTE,
|
|
321
|
-
}
|
|
322
|
-
|
|
323
|
-
NONDETERMINISTIC = (exp.Rand, exp.Randn)
|
|
324
|
-
AND_OR = (exp.And, exp.Or)
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
def _simplify_comparison(expression, left, right, or_=False):
|
|
328
|
-
if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
|
|
329
|
-
ll, lr = left.args.values()
|
|
330
|
-
rl, rr = right.args.values()
|
|
331
|
-
|
|
332
|
-
largs = {ll, lr}
|
|
333
|
-
rargs = {rl, rr}
|
|
334
|
-
|
|
335
|
-
matching = largs & rargs
|
|
336
|
-
columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
|
|
337
|
-
|
|
338
|
-
if matching and columns:
|
|
339
|
-
try:
|
|
340
|
-
l = first(largs - columns)
|
|
341
|
-
r = first(rargs - columns)
|
|
342
|
-
except StopIteration:
|
|
343
|
-
return expression
|
|
344
|
-
|
|
345
|
-
if l.is_number and r.is_number:
|
|
346
|
-
l = l.to_py()
|
|
347
|
-
r = r.to_py()
|
|
348
|
-
elif l.is_string and r.is_string:
|
|
349
|
-
l = l.name
|
|
350
|
-
r = r.name
|
|
351
|
-
else:
|
|
352
|
-
l = extract_date(l)
|
|
353
|
-
if not l:
|
|
354
|
-
return None
|
|
355
|
-
r = extract_date(r)
|
|
356
|
-
if not r:
|
|
357
|
-
return None
|
|
358
|
-
# python won't compare date and datetime, but many engines will upcast
|
|
359
|
-
l, r = cast_as_datetime(l), cast_as_datetime(r)
|
|
360
|
-
|
|
361
|
-
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
|
|
362
|
-
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
|
|
363
|
-
return left if (av > bv if or_ else av <= bv) else right
|
|
364
|
-
if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
|
|
365
|
-
return left if (av < bv if or_ else av >= bv) else right
|
|
366
|
-
|
|
367
|
-
# we can't ever shortcut to true because the column could be null
|
|
368
|
-
if not or_:
|
|
369
|
-
if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
|
|
370
|
-
if av <= bv:
|
|
371
|
-
return exp.false()
|
|
372
|
-
elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
|
|
373
|
-
if av >= bv:
|
|
374
|
-
return exp.false()
|
|
375
|
-
elif isinstance(a, exp.EQ):
|
|
376
|
-
if isinstance(b, exp.LT):
|
|
377
|
-
return exp.false() if av >= bv else a
|
|
378
|
-
if isinstance(b, exp.LTE):
|
|
379
|
-
return exp.false() if av > bv else a
|
|
380
|
-
if isinstance(b, exp.GT):
|
|
381
|
-
return exp.false() if av <= bv else a
|
|
382
|
-
if isinstance(b, exp.GTE):
|
|
383
|
-
return exp.false() if av < bv else a
|
|
384
|
-
if isinstance(b, exp.NEQ):
|
|
385
|
-
return exp.false() if av == bv else a
|
|
386
|
-
return None
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
def remove_complements(expression, root=True):
|
|
390
|
-
"""
|
|
391
|
-
Removing complements.
|
|
392
|
-
|
|
393
|
-
A AND NOT A -> FALSE
|
|
394
|
-
A OR NOT A -> TRUE
|
|
395
|
-
"""
|
|
396
|
-
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
|
|
397
|
-
ops = set(expression.flatten())
|
|
398
|
-
for op in ops:
|
|
399
|
-
if isinstance(op, exp.Not) and op.this in ops:
|
|
400
|
-
return exp.false() if isinstance(expression, exp.And) else exp.true()
|
|
401
|
-
|
|
402
|
-
return expression
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
def uniq_sort(expression, root=True):
|
|
406
|
-
"""
|
|
407
|
-
Uniq and sort a connector.
|
|
408
|
-
|
|
409
|
-
C AND A AND B AND B -> A AND B AND C
|
|
410
|
-
"""
|
|
411
|
-
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
|
412
|
-
flattened = tuple(expression.flatten())
|
|
413
|
-
|
|
414
|
-
if isinstance(expression, exp.Xor):
|
|
415
|
-
result_func = exp.xor
|
|
416
|
-
# Do not deduplicate XOR as A XOR A != A if A == True
|
|
417
|
-
deduped = None
|
|
418
|
-
arr = tuple((gen(e), e) for e in flattened)
|
|
419
|
-
else:
|
|
420
|
-
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
|
|
421
|
-
deduped = {gen(e): e for e in flattened}
|
|
422
|
-
arr = tuple(deduped.items())
|
|
423
|
-
|
|
424
|
-
# check if the operands are already sorted, if not sort them
|
|
425
|
-
# A AND C AND B -> A AND B AND C
|
|
426
|
-
for i, (sql, e) in enumerate(arr[1:]):
|
|
427
|
-
if sql < arr[i][0]:
|
|
428
|
-
expression = result_func(*(e for _, e in sorted(arr)), copy=False)
|
|
429
|
-
break
|
|
430
|
-
else:
|
|
431
|
-
# we didn't have to sort but maybe we need to dedup
|
|
432
|
-
if deduped and len(deduped) < len(flattened):
|
|
433
|
-
expression = result_func(*deduped.values(), copy=False)
|
|
434
|
-
|
|
435
|
-
return expression
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
def absorb_and_eliminate(expression, root=True):
|
|
439
|
-
"""
|
|
440
|
-
absorption:
|
|
441
|
-
A AND (A OR B) -> A
|
|
442
|
-
A OR (A AND B) -> A
|
|
443
|
-
A AND (NOT A OR B) -> A AND B
|
|
444
|
-
A OR (NOT A AND B) -> A OR B
|
|
445
|
-
elimination:
|
|
446
|
-
(A AND B) OR (A AND NOT B) -> A
|
|
447
|
-
(A OR B) AND (A OR NOT B) -> A
|
|
448
|
-
"""
|
|
449
|
-
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
|
|
450
|
-
kind = exp.Or if isinstance(expression, exp.And) else exp.And
|
|
451
|
-
|
|
452
|
-
ops = tuple(expression.flatten())
|
|
453
|
-
|
|
454
|
-
# Initialize lookup tables:
|
|
455
|
-
# Set of all operands, used to find complements for absorption.
|
|
456
|
-
op_set = set()
|
|
457
|
-
# Sub-operands, used to find subsets for absorption.
|
|
458
|
-
subops = defaultdict(list)
|
|
459
|
-
# Pairs of complements, used for elimination.
|
|
460
|
-
pairs = defaultdict(list)
|
|
461
|
-
|
|
462
|
-
# Populate the lookup tables
|
|
463
|
-
for op in ops:
|
|
464
|
-
op_set.add(op)
|
|
465
|
-
|
|
466
|
-
if not isinstance(op, kind):
|
|
467
|
-
# In cases like: A OR (A AND B)
|
|
468
|
-
# Subop will be: ^
|
|
469
|
-
subops[op].append({op})
|
|
470
|
-
continue
|
|
471
|
-
|
|
472
|
-
# In cases like: (A AND B) OR (A AND B AND C)
|
|
473
|
-
# Subops will be: ^ ^
|
|
474
|
-
subset = set(op.flatten())
|
|
475
|
-
for i in subset:
|
|
476
|
-
subops[i].append(subset)
|
|
131
|
+
def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
|
|
132
|
+
if not isinstance(expression, exp.Paren):
|
|
133
|
+
return expression
|
|
477
134
|
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
if isinstance(b, exp.Not):
|
|
482
|
-
pairs[frozenset((a, b.this))].append((op, a))
|
|
135
|
+
this = expression.this
|
|
136
|
+
parent = expression.parent
|
|
137
|
+
parent_is_predicate = isinstance(parent, exp.Predicate)
|
|
483
138
|
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
continue
|
|
139
|
+
if isinstance(this, exp.Select):
|
|
140
|
+
return expression
|
|
487
141
|
|
|
488
|
-
|
|
142
|
+
if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
|
|
143
|
+
return expression
|
|
489
144
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
continue
|
|
497
|
-
superset = set(op.flatten())
|
|
498
|
-
if any(any(subset < superset for subset in subops[i]) for i in superset):
|
|
499
|
-
op.replace(exp.false() if kind == exp.And else exp.true())
|
|
500
|
-
continue
|
|
145
|
+
if (
|
|
146
|
+
Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
|
|
147
|
+
and isinstance(parent, exp.Dot)
|
|
148
|
+
and (isinstance(parent.right, (exp.Identifier, exp.Star)))
|
|
149
|
+
):
|
|
150
|
+
return expression
|
|
501
151
|
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
152
|
+
if (
|
|
153
|
+
not isinstance(parent, (exp.Condition, exp.Binary))
|
|
154
|
+
or isinstance(parent, exp.Paren)
|
|
155
|
+
or (
|
|
156
|
+
not isinstance(this, exp.Binary)
|
|
157
|
+
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
|
158
|
+
)
|
|
159
|
+
or (
|
|
160
|
+
isinstance(this, exp.Predicate)
|
|
161
|
+
and not (parent_is_predicate or isinstance(parent, exp.Neg))
|
|
162
|
+
)
|
|
163
|
+
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
|
164
|
+
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
|
165
|
+
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
|
166
|
+
):
|
|
167
|
+
return this
|
|
506
168
|
|
|
507
169
|
return expression
|
|
508
170
|
|
|
@@ -546,20 +208,6 @@ def propagate_constants(expression, root=True):
|
|
|
546
208
|
return expression
|
|
547
209
|
|
|
548
210
|
|
|
549
|
-
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|
550
|
-
exp.DateAdd: exp.Sub,
|
|
551
|
-
exp.DateSub: exp.Add,
|
|
552
|
-
exp.DatetimeAdd: exp.Sub,
|
|
553
|
-
exp.DatetimeSub: exp.Add,
|
|
554
|
-
}
|
|
555
|
-
|
|
556
|
-
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|
557
|
-
**INVERSE_DATE_OPS,
|
|
558
|
-
exp.Add: exp.Sub,
|
|
559
|
-
exp.Sub: exp.Add,
|
|
560
|
-
}
|
|
561
|
-
|
|
562
|
-
|
|
563
211
|
def _is_number(expression: exp.Expression) -> bool:
|
|
564
212
|
return expression.is_number
|
|
565
213
|
|
|
@@ -568,207 +216,6 @@ def _is_interval(expression: exp.Expression) -> bool:
|
|
|
568
216
|
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
|
|
569
217
|
|
|
570
218
|
|
|
571
|
-
@catch(ModuleNotFoundError, UnsupportedUnit)
|
|
572
|
-
def simplify_equality(expression: exp.Expression) -> exp.Expression:
|
|
573
|
-
"""
|
|
574
|
-
Use the subtraction and addition properties of equality to simplify expressions:
|
|
575
|
-
|
|
576
|
-
x + 1 = 3 becomes x = 2
|
|
577
|
-
|
|
578
|
-
There are two binary operations in the above expression: + and =
|
|
579
|
-
Here's how we reference all the operands in the code below:
|
|
580
|
-
|
|
581
|
-
l r
|
|
582
|
-
x + 1 = 3
|
|
583
|
-
a b
|
|
584
|
-
"""
|
|
585
|
-
if isinstance(expression, COMPARISONS):
|
|
586
|
-
l, r = expression.left, expression.right
|
|
587
|
-
|
|
588
|
-
if l.__class__ not in INVERSE_OPS:
|
|
589
|
-
return expression
|
|
590
|
-
|
|
591
|
-
if r.is_number:
|
|
592
|
-
a_predicate = _is_number
|
|
593
|
-
b_predicate = _is_number
|
|
594
|
-
elif _is_date_literal(r):
|
|
595
|
-
a_predicate = _is_date_literal
|
|
596
|
-
b_predicate = _is_interval
|
|
597
|
-
else:
|
|
598
|
-
return expression
|
|
599
|
-
|
|
600
|
-
if l.__class__ in INVERSE_DATE_OPS:
|
|
601
|
-
l = t.cast(exp.IntervalOp, l)
|
|
602
|
-
a = l.this
|
|
603
|
-
b = l.interval()
|
|
604
|
-
else:
|
|
605
|
-
l = t.cast(exp.Binary, l)
|
|
606
|
-
a, b = l.left, l.right
|
|
607
|
-
|
|
608
|
-
if not a_predicate(a) and b_predicate(b):
|
|
609
|
-
pass
|
|
610
|
-
elif not a_predicate(b) and b_predicate(a):
|
|
611
|
-
a, b = b, a
|
|
612
|
-
else:
|
|
613
|
-
return expression
|
|
614
|
-
|
|
615
|
-
return expression.__class__(
|
|
616
|
-
this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
|
|
617
|
-
)
|
|
618
|
-
return expression
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
def simplify_literals(expression, root=True):
|
|
622
|
-
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
|
|
623
|
-
return _flat_simplify(expression, _simplify_binary, root)
|
|
624
|
-
|
|
625
|
-
if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
|
|
626
|
-
return expression.this.this
|
|
627
|
-
|
|
628
|
-
if type(expression) in INVERSE_DATE_OPS:
|
|
629
|
-
return _simplify_binary(expression, expression.this, expression.interval()) or expression
|
|
630
|
-
|
|
631
|
-
return expression
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
|
|
638
|
-
if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
|
|
639
|
-
this = _simplify_integer_cast(expr.this)
|
|
640
|
-
else:
|
|
641
|
-
this = expr.this
|
|
642
|
-
|
|
643
|
-
if isinstance(expr, exp.Cast) and this.is_int:
|
|
644
|
-
num = this.to_py()
|
|
645
|
-
|
|
646
|
-
# Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
|
|
647
|
-
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
|
|
648
|
-
# engine-dependent
|
|
649
|
-
if (
|
|
650
|
-
TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
|
|
651
|
-
) or (
|
|
652
|
-
UTINYINT_MIN <= num <= UTINYINT_MAX
|
|
653
|
-
and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
|
|
654
|
-
):
|
|
655
|
-
return this
|
|
656
|
-
|
|
657
|
-
return expr
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
def _simplify_binary(expression, a, b):
|
|
661
|
-
if isinstance(expression, COMPARISONS):
|
|
662
|
-
a = _simplify_integer_cast(a)
|
|
663
|
-
b = _simplify_integer_cast(b)
|
|
664
|
-
|
|
665
|
-
if isinstance(expression, exp.Is):
|
|
666
|
-
if isinstance(b, exp.Not):
|
|
667
|
-
c = b.this
|
|
668
|
-
not_ = True
|
|
669
|
-
else:
|
|
670
|
-
c = b
|
|
671
|
-
not_ = False
|
|
672
|
-
|
|
673
|
-
if is_null(c):
|
|
674
|
-
if isinstance(a, exp.Literal):
|
|
675
|
-
return exp.true() if not_ else exp.false()
|
|
676
|
-
if is_null(a):
|
|
677
|
-
return exp.false() if not_ else exp.true()
|
|
678
|
-
elif isinstance(expression, NULL_OK):
|
|
679
|
-
return None
|
|
680
|
-
elif is_null(a) or is_null(b):
|
|
681
|
-
return exp.null()
|
|
682
|
-
|
|
683
|
-
if a.is_number and b.is_number:
|
|
684
|
-
num_a = a.to_py()
|
|
685
|
-
num_b = b.to_py()
|
|
686
|
-
|
|
687
|
-
if isinstance(expression, exp.Add):
|
|
688
|
-
return exp.Literal.number(num_a + num_b)
|
|
689
|
-
if isinstance(expression, exp.Mul):
|
|
690
|
-
return exp.Literal.number(num_a * num_b)
|
|
691
|
-
|
|
692
|
-
# We only simplify Sub, Div if a and b have the same parent because they're not associative
|
|
693
|
-
if isinstance(expression, exp.Sub):
|
|
694
|
-
return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
|
|
695
|
-
if isinstance(expression, exp.Div):
|
|
696
|
-
# engines have differing int div behavior so intdiv is not safe
|
|
697
|
-
if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
|
|
698
|
-
return None
|
|
699
|
-
return exp.Literal.number(num_a / num_b)
|
|
700
|
-
|
|
701
|
-
boolean = eval_boolean(expression, num_a, num_b)
|
|
702
|
-
|
|
703
|
-
if boolean:
|
|
704
|
-
return boolean
|
|
705
|
-
elif a.is_string and b.is_string:
|
|
706
|
-
boolean = eval_boolean(expression, a.this, b.this)
|
|
707
|
-
|
|
708
|
-
if boolean:
|
|
709
|
-
return boolean
|
|
710
|
-
elif _is_date_literal(a) and isinstance(b, exp.Interval):
|
|
711
|
-
date, b = extract_date(a), extract_interval(b)
|
|
712
|
-
if date and b:
|
|
713
|
-
if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
|
|
714
|
-
return date_literal(date + b, extract_type(a))
|
|
715
|
-
if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
|
|
716
|
-
return date_literal(date - b, extract_type(a))
|
|
717
|
-
elif isinstance(a, exp.Interval) and _is_date_literal(b):
|
|
718
|
-
a, date = extract_interval(a), extract_date(b)
|
|
719
|
-
# you cannot subtract a date from an interval
|
|
720
|
-
if a and b and isinstance(expression, exp.Add):
|
|
721
|
-
return date_literal(a + date, extract_type(b))
|
|
722
|
-
elif _is_date_literal(a) and _is_date_literal(b):
|
|
723
|
-
if isinstance(expression, exp.Predicate):
|
|
724
|
-
a, b = extract_date(a), extract_date(b)
|
|
725
|
-
boolean = eval_boolean(expression, a, b)
|
|
726
|
-
if boolean:
|
|
727
|
-
return boolean
|
|
728
|
-
|
|
729
|
-
return None
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
|
|
733
|
-
if not isinstance(expression, exp.Paren):
|
|
734
|
-
return expression
|
|
735
|
-
|
|
736
|
-
this = expression.this
|
|
737
|
-
parent = expression.parent
|
|
738
|
-
parent_is_predicate = isinstance(parent, exp.Predicate)
|
|
739
|
-
|
|
740
|
-
if isinstance(this, exp.Select):
|
|
741
|
-
return expression
|
|
742
|
-
|
|
743
|
-
if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
|
|
744
|
-
return expression
|
|
745
|
-
|
|
746
|
-
# Handle risingwave struct columns
|
|
747
|
-
# see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
|
|
748
|
-
if (
|
|
749
|
-
dialect == "risingwave"
|
|
750
|
-
and isinstance(parent, exp.Dot)
|
|
751
|
-
and (isinstance(parent.right, (exp.Identifier, exp.Star)))
|
|
752
|
-
):
|
|
753
|
-
return expression
|
|
754
|
-
|
|
755
|
-
if (
|
|
756
|
-
not isinstance(parent, (exp.Condition, exp.Binary))
|
|
757
|
-
or isinstance(parent, exp.Paren)
|
|
758
|
-
or (
|
|
759
|
-
not isinstance(this, exp.Binary)
|
|
760
|
-
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
|
761
|
-
)
|
|
762
|
-
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
|
763
|
-
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
|
764
|
-
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
|
765
|
-
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
|
766
|
-
):
|
|
767
|
-
return this
|
|
768
|
-
|
|
769
|
-
return expression
|
|
770
|
-
|
|
771
|
-
|
|
772
219
|
def _is_nonnull_constant(expression: exp.Expression) -> bool:
|
|
773
220
|
return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
|
|
774
221
|
|
|
@@ -777,164 +224,6 @@ def _is_constant(expression: exp.Expression) -> bool:
|
|
|
777
224
|
return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
|
|
778
225
|
|
|
779
226
|
|
|
780
|
-
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
|
|
781
|
-
# COALESCE(x) -> x
|
|
782
|
-
if (
|
|
783
|
-
isinstance(expression, exp.Coalesce)
|
|
784
|
-
and (not expression.expressions or _is_nonnull_constant(expression.this))
|
|
785
|
-
# COALESCE is also used as a Spark partitioning hint
|
|
786
|
-
and not isinstance(expression.parent, exp.Hint)
|
|
787
|
-
):
|
|
788
|
-
return expression.this
|
|
789
|
-
|
|
790
|
-
# We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
|
|
791
|
-
# because they are not always equivalent. For example, if `x` is `NULL` and it comes
|
|
792
|
-
# from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
|
|
793
|
-
if dialect == "redshift":
|
|
794
|
-
return expression
|
|
795
|
-
|
|
796
|
-
if not isinstance(expression, COMPARISONS):
|
|
797
|
-
return expression
|
|
798
|
-
|
|
799
|
-
if isinstance(expression.left, exp.Coalesce):
|
|
800
|
-
coalesce = expression.left
|
|
801
|
-
other = expression.right
|
|
802
|
-
elif isinstance(expression.right, exp.Coalesce):
|
|
803
|
-
coalesce = expression.right
|
|
804
|
-
other = expression.left
|
|
805
|
-
else:
|
|
806
|
-
return expression
|
|
807
|
-
|
|
808
|
-
# This transformation is valid for non-constants,
|
|
809
|
-
# but it really only does anything if they are both constants.
|
|
810
|
-
if not _is_constant(other):
|
|
811
|
-
return expression
|
|
812
|
-
|
|
813
|
-
# Find the first constant arg
|
|
814
|
-
for arg_index, arg in enumerate(coalesce.expressions):
|
|
815
|
-
if _is_constant(arg):
|
|
816
|
-
break
|
|
817
|
-
else:
|
|
818
|
-
return expression
|
|
819
|
-
|
|
820
|
-
coalesce.set("expressions", coalesce.expressions[:arg_index])
|
|
821
|
-
|
|
822
|
-
# Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
|
|
823
|
-
# since we already remove COALESCE at the top of this function.
|
|
824
|
-
coalesce = coalesce if coalesce.expressions else coalesce.this
|
|
825
|
-
|
|
826
|
-
# This expression is more complex than when we started, but it will get simplified further
|
|
827
|
-
return exp.paren(
|
|
828
|
-
exp.or_(
|
|
829
|
-
exp.and_(
|
|
830
|
-
coalesce.is_(exp.null()).not_(copy=False),
|
|
831
|
-
expression.copy(),
|
|
832
|
-
copy=False,
|
|
833
|
-
),
|
|
834
|
-
exp.and_(
|
|
835
|
-
coalesce.is_(exp.null()),
|
|
836
|
-
type(expression)(this=arg.copy(), expression=other.copy()),
|
|
837
|
-
copy=False,
|
|
838
|
-
),
|
|
839
|
-
copy=False,
|
|
840
|
-
)
|
|
841
|
-
)
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
CONCATS = (exp.Concat, exp.DPipe)
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
def simplify_concat(expression):
|
|
848
|
-
"""Reduces all groups that contain string literals by concatenating them."""
|
|
849
|
-
if not isinstance(expression, CONCATS) or (
|
|
850
|
-
# We can't reduce a CONCAT_WS call if we don't statically know the separator
|
|
851
|
-
isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
|
|
852
|
-
):
|
|
853
|
-
return expression
|
|
854
|
-
|
|
855
|
-
if isinstance(expression, exp.ConcatWs):
|
|
856
|
-
sep_expr, *expressions = expression.expressions
|
|
857
|
-
sep = sep_expr.name
|
|
858
|
-
concat_type = exp.ConcatWs
|
|
859
|
-
args = {}
|
|
860
|
-
else:
|
|
861
|
-
expressions = expression.expressions
|
|
862
|
-
sep = ""
|
|
863
|
-
concat_type = exp.Concat
|
|
864
|
-
args = {
|
|
865
|
-
"safe": expression.args.get("safe"),
|
|
866
|
-
"coalesce": expression.args.get("coalesce"),
|
|
867
|
-
}
|
|
868
|
-
|
|
869
|
-
new_args = []
|
|
870
|
-
for is_string_group, group in itertools.groupby(
|
|
871
|
-
expressions or expression.flatten(), lambda e: e.is_string
|
|
872
|
-
):
|
|
873
|
-
if is_string_group:
|
|
874
|
-
new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
|
|
875
|
-
else:
|
|
876
|
-
new_args.extend(group)
|
|
877
|
-
|
|
878
|
-
if len(new_args) == 1 and new_args[0].is_string:
|
|
879
|
-
return new_args[0]
|
|
880
|
-
|
|
881
|
-
if concat_type is exp.ConcatWs:
|
|
882
|
-
new_args = [sep_expr] + new_args
|
|
883
|
-
elif isinstance(expression, exp.DPipe):
|
|
884
|
-
return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
|
|
885
|
-
|
|
886
|
-
return concat_type(expressions=new_args, **args)
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
def simplify_conditionals(expression):
|
|
890
|
-
"""Simplifies expressions like IF, CASE if their condition is statically known."""
|
|
891
|
-
if isinstance(expression, exp.Case):
|
|
892
|
-
this = expression.this
|
|
893
|
-
for case in expression.args["ifs"]:
|
|
894
|
-
cond = case.this
|
|
895
|
-
if this:
|
|
896
|
-
# Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
|
|
897
|
-
cond = cond.replace(this.pop().eq(cond))
|
|
898
|
-
|
|
899
|
-
if always_true(cond):
|
|
900
|
-
return case.args["true"]
|
|
901
|
-
|
|
902
|
-
if always_false(cond):
|
|
903
|
-
case.pop()
|
|
904
|
-
if not expression.args["ifs"]:
|
|
905
|
-
return expression.args.get("default") or exp.null()
|
|
906
|
-
elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
|
|
907
|
-
if always_true(expression.this):
|
|
908
|
-
return expression.args["true"]
|
|
909
|
-
if always_false(expression.this):
|
|
910
|
-
return expression.args.get("false") or exp.null()
|
|
911
|
-
|
|
912
|
-
return expression
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
|
|
916
|
-
"""
|
|
917
|
-
Reduces a prefix check to either TRUE or FALSE if both the string and the
|
|
918
|
-
prefix are statically known.
|
|
919
|
-
|
|
920
|
-
Example:
|
|
921
|
-
>>> from sqlglot import parse_one
|
|
922
|
-
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
|
|
923
|
-
'TRUE'
|
|
924
|
-
"""
|
|
925
|
-
if (
|
|
926
|
-
isinstance(expression, exp.StartsWith)
|
|
927
|
-
and expression.this.is_string
|
|
928
|
-
and expression.expression.is_string
|
|
929
|
-
):
|
|
930
|
-
return exp.convert(expression.name.startswith(expression.expression.name))
|
|
931
|
-
|
|
932
|
-
return expression
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
DateRange = t.Tuple[datetime.date, datetime.date]
|
|
936
|
-
|
|
937
|
-
|
|
938
227
|
def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
|
|
939
228
|
"""
|
|
940
229
|
Get the date range for a DATE_TRUNC equality comparison:
|
|
@@ -960,175 +249,45 @@ def _datetrunc_eq_expression(
|
|
|
960
249
|
return exp.and_(
|
|
961
250
|
left >= date_literal(drange[0], target_type),
|
|
962
251
|
left < date_literal(drange[1], target_type),
|
|
963
|
-
copy=False,
|
|
964
|
-
)
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
def _datetrunc_eq(
|
|
968
|
-
left: exp.Expression,
|
|
969
|
-
date: datetime.date,
|
|
970
|
-
unit: str,
|
|
971
|
-
dialect: Dialect,
|
|
972
|
-
target_type: t.Optional[exp.DataType],
|
|
973
|
-
) -> t.Optional[exp.Expression]:
|
|
974
|
-
drange = _datetrunc_range(date, unit, dialect)
|
|
975
|
-
if not drange:
|
|
976
|
-
return None
|
|
977
|
-
|
|
978
|
-
return _datetrunc_eq_expression(left, drange, target_type)
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
def _datetrunc_neq(
|
|
982
|
-
left: exp.Expression,
|
|
983
|
-
date: datetime.date,
|
|
984
|
-
unit: str,
|
|
985
|
-
dialect: Dialect,
|
|
986
|
-
target_type: t.Optional[exp.DataType],
|
|
987
|
-
) -> t.Optional[exp.Expression]:
|
|
988
|
-
drange = _datetrunc_range(date, unit, dialect)
|
|
989
|
-
if not drange:
|
|
990
|
-
return None
|
|
991
|
-
|
|
992
|
-
return exp.and_(
|
|
993
|
-
left < date_literal(drange[0], target_type),
|
|
994
|
-
left >= date_literal(drange[1], target_type),
|
|
995
|
-
copy=False,
|
|
996
|
-
)
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
|
|
1000
|
-
exp.LT: lambda l, dt, u, d, t: l
|
|
1001
|
-
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
|
|
1002
|
-
exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
|
|
1003
|
-
exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
|
|
1004
|
-
exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
|
|
1005
|
-
exp.EQ: _datetrunc_eq,
|
|
1006
|
-
exp.NEQ: _datetrunc_neq,
|
|
1007
|
-
}
|
|
1008
|
-
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
|
|
1009
|
-
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
|
|
1013
|
-
return isinstance(left, DATETRUNCS) and _is_date_literal(right)
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
@catch(ModuleNotFoundError, UnsupportedUnit)
|
|
1017
|
-
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
|
|
1018
|
-
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
|
|
1019
|
-
comparison = expression.__class__
|
|
1020
|
-
|
|
1021
|
-
if isinstance(expression, DATETRUNCS):
|
|
1022
|
-
this = expression.this
|
|
1023
|
-
trunc_type = extract_type(this)
|
|
1024
|
-
date = extract_date(this)
|
|
1025
|
-
if date and expression.unit:
|
|
1026
|
-
return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
|
|
1027
|
-
elif comparison not in DATETRUNC_COMPARISONS:
|
|
1028
|
-
return expression
|
|
1029
|
-
|
|
1030
|
-
if isinstance(expression, exp.Binary):
|
|
1031
|
-
l, r = expression.left, expression.right
|
|
1032
|
-
|
|
1033
|
-
if not _is_datetrunc_predicate(l, r):
|
|
1034
|
-
return expression
|
|
1035
|
-
|
|
1036
|
-
l = t.cast(exp.DateTrunc, l)
|
|
1037
|
-
trunc_arg = l.this
|
|
1038
|
-
unit = l.unit.name.lower()
|
|
1039
|
-
date = extract_date(r)
|
|
1040
|
-
|
|
1041
|
-
if not date:
|
|
1042
|
-
return expression
|
|
1043
|
-
|
|
1044
|
-
return (
|
|
1045
|
-
DATETRUNC_BINARY_COMPARISONS[comparison](
|
|
1046
|
-
trunc_arg, date, unit, dialect, extract_type(r)
|
|
1047
|
-
)
|
|
1048
|
-
or expression
|
|
1049
|
-
)
|
|
1050
|
-
|
|
1051
|
-
if isinstance(expression, exp.In):
|
|
1052
|
-
l = expression.this
|
|
1053
|
-
rs = expression.expressions
|
|
1054
|
-
|
|
1055
|
-
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
|
|
1056
|
-
l = t.cast(exp.DateTrunc, l)
|
|
1057
|
-
unit = l.unit.name.lower()
|
|
1058
|
-
|
|
1059
|
-
ranges = []
|
|
1060
|
-
for r in rs:
|
|
1061
|
-
date = extract_date(r)
|
|
1062
|
-
if not date:
|
|
1063
|
-
return expression
|
|
1064
|
-
drange = _datetrunc_range(date, unit, dialect)
|
|
1065
|
-
if drange:
|
|
1066
|
-
ranges.append(drange)
|
|
1067
|
-
|
|
1068
|
-
if not ranges:
|
|
1069
|
-
return expression
|
|
1070
|
-
|
|
1071
|
-
ranges = merge_ranges(ranges)
|
|
1072
|
-
target_type = extract_type(*rs)
|
|
1073
|
-
|
|
1074
|
-
return exp.or_(
|
|
1075
|
-
*[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
|
|
1076
|
-
)
|
|
1077
|
-
|
|
1078
|
-
return expression
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
def sort_comparison(expression: exp.Expression) -> exp.Expression:
|
|
1082
|
-
if expression.__class__ in COMPLEMENT_COMPARISONS:
|
|
1083
|
-
l, r = expression.this, expression.expression
|
|
1084
|
-
l_column = isinstance(l, exp.Column)
|
|
1085
|
-
r_column = isinstance(r, exp.Column)
|
|
1086
|
-
l_const = _is_constant(l)
|
|
1087
|
-
r_const = _is_constant(r)
|
|
1088
|
-
|
|
1089
|
-
if (
|
|
1090
|
-
(l_column and not r_column)
|
|
1091
|
-
or (r_const and not l_const)
|
|
1092
|
-
or isinstance(r, exp.SubqueryPredicate)
|
|
1093
|
-
):
|
|
1094
|
-
return expression
|
|
1095
|
-
if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
|
|
1096
|
-
return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
|
|
1097
|
-
this=r, expression=l
|
|
1098
|
-
)
|
|
1099
|
-
return expression
|
|
252
|
+
copy=False,
|
|
253
|
+
)
|
|
1100
254
|
|
|
1101
255
|
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
(
|
|
1110
|
-
|
|
256
|
+
def _datetrunc_eq(
|
|
257
|
+
left: exp.Expression,
|
|
258
|
+
date: datetime.date,
|
|
259
|
+
unit: str,
|
|
260
|
+
dialect: Dialect,
|
|
261
|
+
target_type: t.Optional[exp.DataType],
|
|
262
|
+
) -> t.Optional[exp.Expression]:
|
|
263
|
+
drange = _datetrunc_range(date, unit, dialect)
|
|
264
|
+
if not drange:
|
|
265
|
+
return None
|
|
266
|
+
|
|
267
|
+
return _datetrunc_eq_expression(left, drange, target_type)
|
|
1111
268
|
|
|
1112
269
|
|
|
1113
|
-
def
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
270
|
+
def _datetrunc_neq(
|
|
271
|
+
left: exp.Expression,
|
|
272
|
+
date: datetime.date,
|
|
273
|
+
unit: str,
|
|
274
|
+
dialect: Dialect,
|
|
275
|
+
target_type: t.Optional[exp.DataType],
|
|
276
|
+
) -> t.Optional[exp.Expression]:
|
|
277
|
+
drange = _datetrunc_range(date, unit, dialect)
|
|
278
|
+
if not drange:
|
|
279
|
+
return None
|
|
280
|
+
|
|
281
|
+
return exp.and_(
|
|
282
|
+
left < date_literal(drange[0], target_type),
|
|
283
|
+
left >= date_literal(drange[1], target_type),
|
|
284
|
+
copy=False,
|
|
285
|
+
)
|
|
1127
286
|
|
|
1128
287
|
|
|
1129
288
|
def always_true(expression):
|
|
1130
289
|
return (isinstance(expression, exp.Boolean) and expression.this) or (
|
|
1131
|
-
isinstance(expression, exp.Literal) and not is_zero(expression)
|
|
290
|
+
isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression)
|
|
1132
291
|
)
|
|
1133
292
|
|
|
1134
293
|
|
|
@@ -1310,30 +469,987 @@ def boolean_literal(condition):
|
|
|
1310
469
|
return exp.true() if condition else exp.false()
|
|
1311
470
|
|
|
1312
471
|
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
472
|
+
class Simplifier:
|
|
473
|
+
def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
|
|
474
|
+
self.dialect = Dialect.get_or_raise(dialect)
|
|
475
|
+
self.annotate_new_expressions = annotate_new_expressions
|
|
476
|
+
|
|
477
|
+
self._annotator: TypeAnnotator = TypeAnnotator(
|
|
478
|
+
schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
|
|
479
|
+
)
|
|
480
|
+
|
|
481
|
+
# Value ranges for byte-sized signed/unsigned integers
|
|
482
|
+
TINYINT_MIN = -128
|
|
483
|
+
TINYINT_MAX = 127
|
|
484
|
+
UTINYINT_MIN = 0
|
|
485
|
+
UTINYINT_MAX = 255
|
|
486
|
+
|
|
487
|
+
COMPLEMENT_COMPARISONS = {
|
|
488
|
+
exp.LT: exp.GTE,
|
|
489
|
+
exp.GT: exp.LTE,
|
|
490
|
+
exp.LTE: exp.GT,
|
|
491
|
+
exp.GTE: exp.LT,
|
|
492
|
+
exp.EQ: exp.NEQ,
|
|
493
|
+
exp.NEQ: exp.EQ,
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
COMPLEMENT_SUBQUERY_PREDICATES = {
|
|
497
|
+
exp.All: exp.Any,
|
|
498
|
+
exp.Any: exp.All,
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
LT_LTE = (exp.LT, exp.LTE)
|
|
502
|
+
GT_GTE = (exp.GT, exp.GTE)
|
|
503
|
+
|
|
504
|
+
COMPARISONS = (
|
|
505
|
+
*LT_LTE,
|
|
506
|
+
*GT_GTE,
|
|
507
|
+
exp.EQ,
|
|
508
|
+
exp.NEQ,
|
|
509
|
+
exp.Is,
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|
513
|
+
exp.LT: exp.GT,
|
|
514
|
+
exp.GT: exp.LT,
|
|
515
|
+
exp.LTE: exp.GTE,
|
|
516
|
+
exp.GTE: exp.LTE,
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
NONDETERMINISTIC = (exp.Rand, exp.Randn)
|
|
520
|
+
AND_OR = (exp.And, exp.Or)
|
|
521
|
+
|
|
522
|
+
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|
523
|
+
exp.DateAdd: exp.Sub,
|
|
524
|
+
exp.DateSub: exp.Add,
|
|
525
|
+
exp.DatetimeAdd: exp.Sub,
|
|
526
|
+
exp.DatetimeSub: exp.Add,
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
|
|
530
|
+
**INVERSE_DATE_OPS,
|
|
531
|
+
exp.Add: exp.Sub,
|
|
532
|
+
exp.Sub: exp.Add,
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
|
|
536
|
+
|
|
537
|
+
CONCATS = (exp.Concat, exp.DPipe)
|
|
538
|
+
|
|
539
|
+
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
|
|
540
|
+
exp.LT: lambda l, dt, u, d, t: l
|
|
541
|
+
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
|
|
542
|
+
exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
|
|
543
|
+
exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
|
|
544
|
+
exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
|
|
545
|
+
exp.EQ: _datetrunc_eq,
|
|
546
|
+
exp.NEQ: _datetrunc_neq,
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
|
|
550
|
+
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
|
|
551
|
+
|
|
552
|
+
SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean)
|
|
553
|
+
|
|
554
|
+
# CROSS joins result in an empty table if the right table is empty.
|
|
555
|
+
# So we can only simplify certain types of joins to CROSS.
|
|
556
|
+
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
|
|
557
|
+
JOINS = {
|
|
558
|
+
("", ""),
|
|
559
|
+
("", "INNER"),
|
|
560
|
+
("RIGHT", ""),
|
|
561
|
+
("RIGHT", "OUTER"),
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
def simplify(
|
|
565
|
+
self,
|
|
566
|
+
expression: exp.Expression,
|
|
567
|
+
constant_propagation: bool = False,
|
|
568
|
+
coalesce_simplification: bool = False,
|
|
569
|
+
):
|
|
570
|
+
wheres = []
|
|
571
|
+
joins = []
|
|
572
|
+
|
|
573
|
+
for node in expression.walk(
|
|
574
|
+
prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
|
|
575
|
+
):
|
|
576
|
+
if node.meta.get(FINAL):
|
|
577
|
+
continue
|
|
578
|
+
|
|
579
|
+
# group by expressions cannot be simplified, for example
|
|
580
|
+
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
|
581
|
+
# the projection must exactly match the group by key
|
|
582
|
+
group = node.args.get("group")
|
|
583
|
+
|
|
584
|
+
if group and hasattr(node, "selects"):
|
|
585
|
+
groups = set(group.expressions)
|
|
586
|
+
group.meta[FINAL] = True
|
|
587
|
+
|
|
588
|
+
for s in node.selects:
|
|
589
|
+
for n in s.walk(FINAL):
|
|
590
|
+
if n in groups:
|
|
591
|
+
s.meta[FINAL] = True
|
|
592
|
+
break
|
|
593
|
+
|
|
594
|
+
having = node.args.get("having")
|
|
595
|
+
|
|
596
|
+
if having:
|
|
597
|
+
for n in having.walk():
|
|
598
|
+
if n in groups:
|
|
599
|
+
having.meta[FINAL] = True
|
|
600
|
+
break
|
|
601
|
+
|
|
602
|
+
if isinstance(node, exp.Condition):
|
|
603
|
+
simplified = while_changing(
|
|
604
|
+
node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
if node is expression:
|
|
608
|
+
expression = simplified
|
|
609
|
+
elif isinstance(node, exp.Where):
|
|
610
|
+
wheres.append(node)
|
|
611
|
+
elif isinstance(node, exp.Join):
|
|
612
|
+
# snowflake match_conditions have very strict ordering rules
|
|
613
|
+
if match := node.args.get("match_condition"):
|
|
614
|
+
match.meta[FINAL] = True
|
|
615
|
+
|
|
616
|
+
joins.append(node)
|
|
617
|
+
|
|
618
|
+
for where in wheres:
|
|
619
|
+
if always_true(where.this):
|
|
620
|
+
where.pop()
|
|
621
|
+
for join in joins:
|
|
622
|
+
if (
|
|
623
|
+
always_true(join.args.get("on"))
|
|
624
|
+
and not join.args.get("using")
|
|
625
|
+
and not join.args.get("method")
|
|
626
|
+
and (join.side, join.kind) in self.JOINS
|
|
627
|
+
):
|
|
628
|
+
join.args["on"].pop()
|
|
629
|
+
join.set("side", None)
|
|
630
|
+
join.set("kind", "CROSS")
|
|
631
|
+
|
|
632
|
+
return expression
|
|
633
|
+
|
|
634
|
+
def _simplify(
|
|
635
|
+
self, expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool
|
|
636
|
+
):
|
|
637
|
+
pre_transformation_stack = [expression]
|
|
638
|
+
post_transformation_stack = []
|
|
639
|
+
|
|
640
|
+
while pre_transformation_stack:
|
|
641
|
+
original = pre_transformation_stack.pop()
|
|
642
|
+
node = original
|
|
643
|
+
|
|
644
|
+
if not isinstance(node, SIMPLIFIABLE):
|
|
645
|
+
if isinstance(node, exp.Query):
|
|
646
|
+
self.simplify(node, constant_propagation, coalesce_simplification)
|
|
647
|
+
continue
|
|
648
|
+
|
|
649
|
+
parent = node.parent
|
|
650
|
+
root = node is expression
|
|
651
|
+
|
|
652
|
+
node = self.rewrite_between(node)
|
|
653
|
+
node = self.uniq_sort(node, root)
|
|
654
|
+
node = self.absorb_and_eliminate(node, root)
|
|
655
|
+
node = self.simplify_concat(node)
|
|
656
|
+
node = self.simplify_conditionals(node)
|
|
657
|
+
|
|
658
|
+
if constant_propagation:
|
|
659
|
+
node = propagate_constants(node, root)
|
|
660
|
+
|
|
661
|
+
if node is not original:
|
|
662
|
+
original.replace(node)
|
|
663
|
+
|
|
664
|
+
for n in node.iter_expressions(reverse=True):
|
|
665
|
+
if n.meta.get(FINAL):
|
|
666
|
+
raise
|
|
667
|
+
pre_transformation_stack.extend(
|
|
668
|
+
n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
|
|
669
|
+
)
|
|
670
|
+
post_transformation_stack.append((node, parent))
|
|
671
|
+
|
|
672
|
+
while post_transformation_stack:
|
|
673
|
+
original, parent = post_transformation_stack.pop()
|
|
674
|
+
root = original is expression
|
|
675
|
+
|
|
676
|
+
# Resets parent, arg_key, index pointers– this is needed because some of the
|
|
677
|
+
# previous transformations mutate the AST, leading to an inconsistent state
|
|
678
|
+
for k, v in tuple(original.args.items()):
|
|
679
|
+
original.set(k, v)
|
|
680
|
+
|
|
681
|
+
# Post-order transformations
|
|
682
|
+
node = self.simplify_not(original)
|
|
683
|
+
node = flatten(node)
|
|
684
|
+
node = self.simplify_connectors(node, root)
|
|
685
|
+
node = self.remove_complements(node, root)
|
|
686
|
+
|
|
687
|
+
if coalesce_simplification:
|
|
688
|
+
node = self.simplify_coalesce(node)
|
|
689
|
+
node.parent = parent
|
|
690
|
+
|
|
691
|
+
node = self.simplify_literals(node, root)
|
|
692
|
+
node = self.simplify_equality(node)
|
|
693
|
+
node = simplify_parens(node, dialect=self.dialect)
|
|
694
|
+
node = self.simplify_datetrunc(node)
|
|
695
|
+
node = self.sort_comparison(node)
|
|
696
|
+
node = self.simplify_startswith(node)
|
|
697
|
+
|
|
698
|
+
if node is not original:
|
|
699
|
+
original.replace(node)
|
|
700
|
+
|
|
701
|
+
return node
|
|
702
|
+
|
|
703
|
+
@annotate_types_on_change
|
|
704
|
+
def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
|
|
705
|
+
"""Rewrite x between y and z to x >= y AND x <= z.
|
|
706
|
+
|
|
707
|
+
This is done because comparison simplification is only done on lt/lte/gt/gte.
|
|
708
|
+
"""
|
|
709
|
+
if isinstance(expression, exp.Between):
|
|
710
|
+
negate = isinstance(expression.parent, exp.Not)
|
|
711
|
+
|
|
712
|
+
expression = exp.and_(
|
|
713
|
+
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
|
|
714
|
+
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
|
|
715
|
+
copy=False,
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
if negate:
|
|
719
|
+
expression = exp.paren(expression, copy=False)
|
|
720
|
+
|
|
721
|
+
return expression
|
|
722
|
+
|
|
723
|
+
@annotate_types_on_change
|
|
724
|
+
def simplify_not(self, expression: exp.Expression) -> exp.Expression:
|
|
725
|
+
"""
|
|
726
|
+
Demorgan's Law
|
|
727
|
+
NOT (x OR y) -> NOT x AND NOT y
|
|
728
|
+
NOT (x AND y) -> NOT x OR NOT y
|
|
729
|
+
"""
|
|
730
|
+
if isinstance(expression, exp.Not):
|
|
731
|
+
this = expression.this
|
|
732
|
+
if is_null(this):
|
|
733
|
+
return exp.and_(exp.null(), exp.true(), copy=False)
|
|
734
|
+
if this.__class__ in self.COMPLEMENT_COMPARISONS:
|
|
735
|
+
right = this.expression
|
|
736
|
+
complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
|
|
737
|
+
right.__class__
|
|
738
|
+
)
|
|
739
|
+
if complement_subquery_predicate:
|
|
740
|
+
right = complement_subquery_predicate(this=right.this)
|
|
741
|
+
|
|
742
|
+
return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
|
|
743
|
+
if isinstance(this, exp.Paren):
|
|
744
|
+
condition = this.unnest()
|
|
745
|
+
if isinstance(condition, exp.And):
|
|
746
|
+
return exp.paren(
|
|
747
|
+
exp.or_(
|
|
748
|
+
exp.not_(condition.left, copy=False),
|
|
749
|
+
exp.not_(condition.right, copy=False),
|
|
750
|
+
copy=False,
|
|
751
|
+
),
|
|
752
|
+
copy=False,
|
|
753
|
+
)
|
|
754
|
+
if isinstance(condition, exp.Or):
|
|
755
|
+
return exp.paren(
|
|
756
|
+
exp.and_(
|
|
757
|
+
exp.not_(condition.left, copy=False),
|
|
758
|
+
exp.not_(condition.right, copy=False),
|
|
759
|
+
copy=False,
|
|
760
|
+
),
|
|
761
|
+
copy=False,
|
|
762
|
+
)
|
|
763
|
+
if is_null(condition):
|
|
764
|
+
return exp.and_(exp.null(), exp.true(), copy=False)
|
|
765
|
+
if always_true(this):
|
|
766
|
+
return exp.false()
|
|
767
|
+
if is_false(this):
|
|
768
|
+
return exp.true()
|
|
769
|
+
if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
|
|
770
|
+
inner = this.this
|
|
771
|
+
if inner.is_type(exp.DataType.Type.BOOLEAN):
|
|
772
|
+
# double negation
|
|
773
|
+
# NOT NOT x -> x, if x is BOOLEAN type
|
|
774
|
+
return inner
|
|
775
|
+
return expression
|
|
776
|
+
|
|
777
|
+
@annotate_types_on_change
|
|
778
|
+
def simplify_connectors(self, expression, root=True):
|
|
779
|
+
def _simplify_connectors(expression, left, right):
|
|
780
|
+
if isinstance(expression, exp.And):
|
|
781
|
+
if is_false(left) or is_false(right):
|
|
782
|
+
return exp.false()
|
|
783
|
+
if is_zero(left) or is_zero(right):
|
|
784
|
+
return exp.false()
|
|
785
|
+
if (
|
|
786
|
+
(is_null(left) and is_null(right))
|
|
787
|
+
or (is_null(left) and always_true(right))
|
|
788
|
+
or (always_true(left) and is_null(right))
|
|
789
|
+
):
|
|
790
|
+
return exp.null()
|
|
791
|
+
if always_true(left) and always_true(right):
|
|
792
|
+
return exp.true()
|
|
793
|
+
if always_true(left):
|
|
794
|
+
return right
|
|
795
|
+
if always_true(right):
|
|
796
|
+
return left
|
|
797
|
+
return self._simplify_comparison(expression, left, right)
|
|
798
|
+
elif isinstance(expression, exp.Or):
|
|
799
|
+
if always_true(left) or always_true(right):
|
|
800
|
+
return exp.true()
|
|
801
|
+
if (
|
|
802
|
+
(is_null(left) and is_null(right))
|
|
803
|
+
or (is_null(left) and always_false(right))
|
|
804
|
+
or (always_false(left) and is_null(right))
|
|
805
|
+
):
|
|
806
|
+
return exp.null()
|
|
807
|
+
if is_false(left):
|
|
808
|
+
return right
|
|
809
|
+
if is_false(right):
|
|
810
|
+
return left
|
|
811
|
+
return self._simplify_comparison(expression, left, right, or_=True)
|
|
812
|
+
|
|
813
|
+
if isinstance(expression, exp.Connector):
|
|
814
|
+
original_parent = expression.parent
|
|
815
|
+
expression = self._flat_simplify(expression, _simplify_connectors, root)
|
|
816
|
+
|
|
817
|
+
# If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
|
|
818
|
+
# to ensure that the resulting type is boolean. We know this is true only for connectors,
|
|
819
|
+
# boolean values and columns that are essentially operands to a connector:
|
|
820
|
+
#
|
|
821
|
+
# A AND (((B)))
|
|
822
|
+
# ~ this is safe to keep because it will eventually be part of another connector
|
|
823
|
+
if not isinstance(
|
|
824
|
+
expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
|
|
825
|
+
) and not expression.is_type(exp.DataType.Type.BOOLEAN):
|
|
826
|
+
while True:
|
|
827
|
+
if isinstance(original_parent, exp.Connector):
|
|
828
|
+
break
|
|
829
|
+
if not isinstance(original_parent, exp.Paren):
|
|
830
|
+
expression = expression.and_(exp.true(), copy=False)
|
|
831
|
+
break
|
|
832
|
+
|
|
833
|
+
original_parent = original_parent.parent
|
|
834
|
+
|
|
835
|
+
return expression
|
|
836
|
+
|
|
837
|
+
@annotate_types_on_change
|
|
838
|
+
def _simplify_comparison(self, expression, left, right, or_=False):
|
|
839
|
+
if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
|
|
840
|
+
ll, lr = left.args.values()
|
|
841
|
+
rl, rr = right.args.values()
|
|
842
|
+
|
|
843
|
+
largs = {ll, lr}
|
|
844
|
+
rargs = {rl, rr}
|
|
845
|
+
|
|
846
|
+
matching = largs & rargs
|
|
847
|
+
columns = {
|
|
848
|
+
m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
|
|
849
|
+
}
|
|
850
|
+
|
|
851
|
+
if matching and columns:
|
|
852
|
+
try:
|
|
853
|
+
l = first(largs - columns)
|
|
854
|
+
r = first(rargs - columns)
|
|
855
|
+
except StopIteration:
|
|
856
|
+
return expression
|
|
857
|
+
|
|
858
|
+
if l.is_number and r.is_number:
|
|
859
|
+
l = l.to_py()
|
|
860
|
+
r = r.to_py()
|
|
861
|
+
elif l.is_string and r.is_string:
|
|
862
|
+
l = l.name
|
|
863
|
+
r = r.name
|
|
864
|
+
else:
|
|
865
|
+
l = extract_date(l)
|
|
866
|
+
if not l:
|
|
867
|
+
return None
|
|
868
|
+
r = extract_date(r)
|
|
869
|
+
if not r:
|
|
870
|
+
return None
|
|
871
|
+
# python won't compare date and datetime, but many engines will upcast
|
|
872
|
+
l, r = cast_as_datetime(l), cast_as_datetime(r)
|
|
873
|
+
|
|
874
|
+
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
|
|
875
|
+
if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
|
|
876
|
+
return left if (av > bv if or_ else av <= bv) else right
|
|
877
|
+
if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
|
|
878
|
+
return left if (av < bv if or_ else av >= bv) else right
|
|
879
|
+
|
|
880
|
+
# we can't ever shortcut to true because the column could be null
|
|
881
|
+
if not or_:
|
|
882
|
+
if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
|
|
883
|
+
if av <= bv:
|
|
884
|
+
return exp.false()
|
|
885
|
+
elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
|
|
886
|
+
if av >= bv:
|
|
887
|
+
return exp.false()
|
|
888
|
+
elif isinstance(a, exp.EQ):
|
|
889
|
+
if isinstance(b, exp.LT):
|
|
890
|
+
return exp.false() if av >= bv else a
|
|
891
|
+
if isinstance(b, exp.LTE):
|
|
892
|
+
return exp.false() if av > bv else a
|
|
893
|
+
if isinstance(b, exp.GT):
|
|
894
|
+
return exp.false() if av <= bv else a
|
|
895
|
+
if isinstance(b, exp.GTE):
|
|
896
|
+
return exp.false() if av < bv else a
|
|
897
|
+
if isinstance(b, exp.NEQ):
|
|
898
|
+
return exp.false() if av == bv else a
|
|
899
|
+
return None
|
|
1318
900
|
|
|
1319
|
-
|
|
1320
|
-
|
|
901
|
+
@annotate_types_on_change
|
|
902
|
+
def remove_complements(self, expression, root=True):
|
|
903
|
+
"""
|
|
904
|
+
Removing complements.
|
|
905
|
+
|
|
906
|
+
A AND NOT A -> FALSE (only for non-NULL A)
|
|
907
|
+
A OR NOT A -> TRUE (only for non-NULL A)
|
|
908
|
+
"""
|
|
909
|
+
if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
|
|
910
|
+
ops = set(expression.flatten())
|
|
911
|
+
for op in ops:
|
|
912
|
+
if isinstance(op, exp.Not) and op.this in ops:
|
|
913
|
+
if expression.meta.get("nonnull") is True:
|
|
914
|
+
return exp.false() if isinstance(expression, exp.And) else exp.true()
|
|
1321
915
|
|
|
1322
|
-
|
|
1323
|
-
result = simplifier(expression, a, b)
|
|
916
|
+
return expression
|
|
1324
917
|
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
|
|
918
|
+
@annotate_types_on_change
def uniq_sort(self, expression, root=True):
    """
    Deduplicate and order the operands of a connector by their generated SQL.

        C AND A AND B AND B -> A AND B AND C

    XOR chains are sorted but never deduplicated.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    flattened = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        # Do not deduplicate XOR as A XOR A != A if A == True
        builder = exp.xor
        deduped = None
        keyed = tuple((gen(node), node) for node in flattened)
    else:
        builder = exp.and_ if isinstance(expression, exp.And) else exp.or_
        deduped = {gen(node): node for node in flattened}
        keyed = tuple(deduped.items())

    # A AND C AND B -> A AND B AND C: rebuild only if some neighbour pair is out of order
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))

    if out_of_order:
        return builder(*(node for _, node in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still have to be dropped
    if deduped and len(deduped) < len(flattened):
        if len(deduped) == 1:
            # Every operand was identical: keep one, padded with TRUE
            return flattened[0].and_(exp.true(), copy=False)
        return builder(*deduped.values(), copy=False)

    return expression
|
|
954
|
+
|
|
955
|
+
@annotate_types_on_change
def absorb_and_eliminate(self, expression, root=True):
    """
    Apply the absorption and elimination laws to an AND/OR chain, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Matching operands are replaced inside the tree with TRUE/FALSE constants
    or with the surviving operand; the (mutated) input expression is returned.
    """
    if not isinstance(expression, self.AND_OR):
        return expression
    if not root and expression.same_parent:
        return expression

    outer_is_and = isinstance(expression, exp.And)
    # Nested groups of interest use the connector opposite to the outer chain.
    kind = exp.Or if outer_is_and else exp.And

    operands = tuple(expression.flatten())

    # Lookup tables:
    # every top-level operand, used to find complements for absorption
    seen = set()
    # member -> operand sets containing it, used to find subsets for absorption
    groups_by_member = defaultdict(list)
    # {X, Y} -> (operand, survivor) for operands shaped like "NOT X <kind> Y", used for elimination
    complement_pairs = defaultdict(list)

    for operand in operands:
        seen.add(operand)

        if not isinstance(operand, kind):
            # In cases like: A OR (A AND B), the plain operand A forms a one-element group
            groups_by_member[operand].append({operand})
            continue

        # In cases like: (A AND B) OR (A AND B AND C), each nested chain forms a group
        members = set(operand.flatten())
        for member in members:
            groups_by_member[member].append(members)

        lhs, rhs = operand.unnest_operands()
        if isinstance(lhs, exp.Not):
            complement_pairs[frozenset((lhs.this, rhs))].append((operand, rhs))
        if isinstance(rhs, exp.Not):
            complement_pairs[frozenset((lhs, rhs.this))].append((operand, lhs))

    for operand in operands:
        if not isinstance(operand, kind):
            continue

        lhs, rhs = operand.unnest_operands()

        # Absorb: a negated side whose positive form is a sibling becomes a constant.
        # A new constant node is built for every replacement.
        if isinstance(lhs, exp.Not) and lhs.this in seen:
            lhs.replace(exp.false() if outer_is_and else exp.true())
            continue
        if isinstance(rhs, exp.Not) and rhs.this in seen:
            rhs.replace(exp.false() if outer_is_and else exp.true())
            continue

        # Absorb: a nested group that strictly contains another group is redundant.
        members = set(operand.flatten())
        if any(
            group < members for member in members for group in groups_by_member[member]
        ):
            operand.replace(exp.true() if outer_is_and else exp.false())
            continue

        # Eliminate: pair this operand with the ones that differ only by a negation.
        for partner, survivor in complement_pairs[frozenset((lhs, rhs))]:
            operand.replace(survivor)
            partner.replace(survivor)

    return expression
|
|
1026
|
+
|
|
1027
|
+
@annotate_types_on_change
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(self, expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    The constant may also appear on the left of the arithmetic operator
    (``1 + x = 3``); the operands are then swapped before rewriting. The swap
    is only performed for ``exp.Add``, because moving the constant across a
    non-commutative operator would change the meaning
    (``1 - x = 3`` is not ``x = 3 + 1``).

    Returns:
        A new comparison of the same class with the constant moved to the
        right-hand side, or the original expression if no rewrite applies.
    """
    if isinstance(expression, self.COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in self.INVERSE_OPS:
            return expression

        # Pick how to recognise the "variable" side (a) and the "constant" side (b)
        # based on what kind of literal the comparison's right-hand side is.
        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in self.INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # Only addition is commutative, so only then may the operands trade places.
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
        )

    return expression
|
|
1076
|
+
|
|
1077
|
+
@annotate_types_on_change
def simplify_literals(self, expression, root=True):
    """
    Fold literal operands of an expression.

    Non-connector binary expressions are folded pairwise through
    `_flat_simplify`, a doubly negated value is unwrapped, and the date
    arithmetic operations listed in `INVERSE_DATE_OPS` are folded directly.
    Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return self._flat_simplify(expression, self._simplify_binary, root)

    # -(-x) -> x
    if isinstance(expression, exp.Neg):
        inner = expression.this
        if isinstance(inner, exp.Neg):
            return inner.this

    if type(expression) in self.INVERSE_DATE_OPS:
        folded = self._simplify_binary(expression, expression.this, expression.interval())
        return folded if folded else expression

    return expression
|
|
1092
|
+
|
|
1093
|
+
def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression:
    """
    Strip a cast from a small integer literal when the cast cannot change its value.

    Nested casts are processed from the inside out. The inner integer is
    returned only when it fits the tiny-int bounds configured on the instance
    for the signedness of the target type; otherwise `expr` is returned.
    """
    is_cast = isinstance(expr, exp.Cast)

    inner = expr.this
    if is_cast and isinstance(inner, exp.Cast):
        inner = self._simplify_integer_cast(inner)

    if not (is_cast and inner.is_int):
        return expr

    value = inner.to_py()

    # Remove the (up)cast from small (byte-sized) integers in predicates which is
    # side-effect free. Downcasts on any integer type might cause overflow, thus the
    # cast cannot be eliminated and the behavior is engine-dependent
    if (
        self.TINYINT_MIN <= value <= self.TINYINT_MAX
        and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
    ):
        return inner

    if (
        self.UTINYINT_MIN <= value <= self.UTINYINT_MAX
        and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
    ):
        return inner

    return expr
|
|
1115
|
+
|
|
1116
|
+
def _simplify_binary(self, expression, a, b):
    """
    Try to fold the binary operation `expression` applied to operands `a` and `b`.

    Handles, in order: `IS [NOT] NULL` checks on literals and NULLs, NULL
    propagation inside an `IF`, numeric arithmetic and comparisons, string
    comparisons, date +/- interval arithmetic and date comparisons.

    Args:
        expression: the operation being simplified; only its class and parent are inspected.
        a: the left operand.
        b: the right operand.

    Returns:
        The folded expression, or None when nothing could be simplified.
    """
    if isinstance(expression, self.COMPARISONS):
        a = self._simplify_integer_cast(a)
        b = self._simplify_integer_cast(b)

    if isinstance(expression, exp.Is):
        # Normalise `x IS NOT y` and `x IS y` into the tested value plus a negation flag
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, self.NULL_OK):
        return None
    elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = a.to_py()
        num_b = b.to_py()

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        date, b = extract_date(a), extract_interval(b)
        if date and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(date + b, extract_type(a))
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(date - b, extract_type(a))
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        interval, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        # Guard on the extracted values (not on the unparsed expression `b`) so a
        # failed extraction can never reach the addition below.
        if interval and date and isinstance(expression, exp.Add):
            return date_literal(interval + date, extract_type(b))
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None
|
|
1186
|
+
|
|
1187
|
+
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
    """
    Simplify COALESCE calls and comparisons against them.

    A COALESCE with no fallback arguments, or whose first argument is a
    non-null constant, is reduced to that first argument. A comparison between
    a COALESCE and a constant is expanded into an explicit NULL / NOT NULL
    disjunction, cut at the first constant fallback argument.
    """
    # COALESCE(x) -> x
    if isinstance(expression, exp.Coalesce):
        trivially_first = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if trivially_first and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD or not isinstance(
        expression, self.COMPARISONS
    ):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        (
            (position, candidate)
            for position, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    # Everything from the first constant onwards is dropped from the COALESCE itself
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False), copy=False)
|
|
1248
|
+
|
|
1249
|
+
@annotate_types_on_change
def simplify_concat(self, expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, self.CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if with_separator and not expression.expressions[0].is_string:
        return expression

    if with_separator:
        sep_expr, *operands = expression.expressions
        sep = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        sep = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Expressions without an `expressions` list are walked via flatten() instead
    source = operands or expression.flatten()

    merged = []
    for literal_run, members in itertools.groupby(source, lambda e: e.is_string):
        if literal_run:
            merged.append(exp.Literal.string(sep.join(literal.name for literal in members)))
        else:
            merged.extend(members)

    # Everything collapsed into a single literal
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if with_separator:
        return exp.ConcatWs(expressions=[sep_expr] + merged, **extra_args)

    if isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return exp.Concat(expressions=merged, **extra_args)
|
|
1290
|
+
|
|
1291
|
+
@annotate_types_on_change
def simplify_conditionals(self, expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are examined in order: the first branch whose condition
    is always true yields its result, and branches whose condition is always
    false are removed. If no branch is left, the default (or NULL) is returned.
    For a standalone IF, the matching branch (or NULL) is returned.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below removes the branch from
        # `expression.args["ifs"]`, and mutating the list being iterated would
        # make the loop skip the branch that follows a removed one.
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
|
|
1316
|
+
|
|
1317
|
+
@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
    """
    Evaluate a STARTSWITH call at optimization time when both of its
    arguments are string literals; otherwise leave it untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    both_literal = expression.this.is_string and expression.expression.is_string
    if not both_literal:
        return expression

    prefix = expression.expression.name
    return exp.convert(expression.name.startswith(prefix))
|
|
1336
|
+
|
|
1337
|
+
def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool:
    """Tell whether `left` is a date truncation and `right` is a date literal."""
    if not isinstance(left, self.DATETRUNCS):
        return False
    return _is_date_literal(right)
|
|
1339
|
+
|
|
1340
|
+
@annotate_types_on_change
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression:
    """
    Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`.

    Three shapes are handled: a truncation of a date literal (folded to the
    floored literal), a binary comparison between a truncation and a date
    literal, and an IN list of date literals tested against a truncation.
    """
    comparison = expression.__class__

    if isinstance(expression, self.DATETRUNCS):
        operand = expression.this
        trunc_type = extract_type(operand)
        literal_date = extract_date(operand)
        if literal_date and expression.unit:
            floored = date_floor(literal_date, expression.unit.name.lower(), self.dialect)
            return date_literal(floored, trunc_type)
    elif comparison not in self.DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        lhs, rhs = expression.left, expression.right

        if not self._is_datetrunc_predicate(lhs, rhs):
            return expression

        lhs = t.cast(exp.DateTrunc, lhs)
        trunc_arg = lhs.this
        unit = lhs.unit.name.lower()
        date = extract_date(rhs)

        if not date:
            return expression

        rewrite = self.DATETRUNC_BINARY_COMPARISONS[comparison]
        rewritten = rewrite(trunc_arg, date, unit, self.dialect, extract_type(rhs))
        return rewritten if rewritten else expression

    if isinstance(expression, exp.In):
        lhs = expression.this
        candidates = expression.expressions

        if candidates and all(self._is_datetrunc_predicate(lhs, rhs) for rhs in candidates):
            lhs = t.cast(exp.DateTrunc, lhs)
            unit = lhs.unit.name.lower()

            ranges = []
            for rhs in candidates:
                date = extract_date(rhs)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, self.dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(*candidates)

            return exp.or_(
                *[_datetrunc_eq_expression(lhs, drange, target_type) for drange in ranges],
                copy=False,
            )

    return expression
|
|
1407
|
+
|
|
1408
|
+
@annotate_types_on_change
def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; when neither rule
    decides, operands are ordered by their generated SQL. Swapping the operands
    also swaps the operator for its inverse where one is registered.
    """
    comparison = expression.__class__
    if comparison not in self.COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # Already in canonical order, or not safe to flip
    keep_as_is = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if keep_as_is:
        return expression

    must_flip = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if must_flip:
        flipped = self.INVERSE_COMPARISONS.get(comparison, comparison)
        return flipped(this=rhs, expression=lhs)

    return expression
|
|
1428
|
+
|
|
1429
|
+
def _flat_simplify(self, expression, simplifier, root=True):
    """
    Repeatedly combine pairs of operands of a flattened binary chain.

    Each operand is offered to `simplifier` together with every operand that
    follows it. When a pair folds into a new node, that node replaces the pair
    and is retried; operands that fold with nothing are kept as they are. If
    at least one fold happened, the chain is rebuilt left to right with the
    class of `expression`; otherwise `expression` is returned unchanged.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    settled = []

    while pending:
        head = pending.popleft()

        combined = False
        for candidate in pending:
            outcome = simplifier(expression, head, candidate)

            if outcome and outcome is not expression:
                # Stop iterating right away: the deque is about to be modified
                combined = True
                break

        if combined:
            pending.remove(candidate)
            pending.appendleft(outcome)
        else:
            settled.append(head)

    if len(settled) < original_count:
        return functools.reduce(
            lambda left, right: expression.__class__(this=left, expression=right), settled
        )

    return expression
|
|
1337
1453
|
|
|
1338
1454
|
|
|
1339
1455
|
def gen(expression: t.Any, comments: bool = False) -> str:
|