sqlglot 27.27.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. sqlglot/__init__.py +1 -0
  2. sqlglot/__main__.py +6 -4
  3. sqlglot/_version.py +2 -2
  4. sqlglot/dialects/bigquery.py +118 -279
  5. sqlglot/dialects/clickhouse.py +73 -5
  6. sqlglot/dialects/databricks.py +38 -1
  7. sqlglot/dialects/dialect.py +354 -275
  8. sqlglot/dialects/dremio.py +4 -1
  9. sqlglot/dialects/duckdb.py +754 -25
  10. sqlglot/dialects/exasol.py +243 -10
  11. sqlglot/dialects/hive.py +8 -8
  12. sqlglot/dialects/mysql.py +14 -4
  13. sqlglot/dialects/oracle.py +29 -0
  14. sqlglot/dialects/postgres.py +60 -26
  15. sqlglot/dialects/presto.py +47 -16
  16. sqlglot/dialects/redshift.py +16 -0
  17. sqlglot/dialects/risingwave.py +3 -0
  18. sqlglot/dialects/singlestore.py +12 -3
  19. sqlglot/dialects/snowflake.py +239 -218
  20. sqlglot/dialects/spark.py +15 -4
  21. sqlglot/dialects/spark2.py +11 -48
  22. sqlglot/dialects/sqlite.py +10 -0
  23. sqlglot/dialects/starrocks.py +3 -0
  24. sqlglot/dialects/teradata.py +5 -8
  25. sqlglot/dialects/trino.py +6 -0
  26. sqlglot/dialects/tsql.py +61 -22
  27. sqlglot/diff.py +4 -2
  28. sqlglot/errors.py +69 -0
  29. sqlglot/executor/__init__.py +5 -10
  30. sqlglot/executor/python.py +1 -29
  31. sqlglot/expressions.py +637 -100
  32. sqlglot/generator.py +160 -43
  33. sqlglot/helper.py +2 -44
  34. sqlglot/lineage.py +10 -4
  35. sqlglot/optimizer/annotate_types.py +247 -140
  36. sqlglot/optimizer/canonicalize.py +6 -1
  37. sqlglot/optimizer/eliminate_joins.py +1 -1
  38. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  39. sqlglot/optimizer/merge_subqueries.py +5 -5
  40. sqlglot/optimizer/normalize.py +20 -13
  41. sqlglot/optimizer/normalize_identifiers.py +17 -3
  42. sqlglot/optimizer/optimizer.py +4 -0
  43. sqlglot/optimizer/pushdown_predicates.py +1 -1
  44. sqlglot/optimizer/qualify.py +18 -10
  45. sqlglot/optimizer/qualify_columns.py +122 -275
  46. sqlglot/optimizer/qualify_tables.py +128 -76
  47. sqlglot/optimizer/resolver.py +374 -0
  48. sqlglot/optimizer/scope.py +27 -16
  49. sqlglot/optimizer/simplify.py +1075 -959
  50. sqlglot/optimizer/unnest_subqueries.py +12 -2
  51. sqlglot/parser.py +296 -170
  52. sqlglot/planner.py +2 -2
  53. sqlglot/schema.py +15 -4
  54. sqlglot/tokens.py +42 -7
  55. sqlglot/transforms.py +77 -22
  56. sqlglot/typing/__init__.py +316 -0
  57. sqlglot/typing/bigquery.py +376 -0
  58. sqlglot/typing/hive.py +12 -0
  59. sqlglot/typing/presto.py +24 -0
  60. sqlglot/typing/snowflake.py +505 -0
  61. sqlglot/typing/spark2.py +58 -0
  62. sqlglot/typing/tsql.py +9 -0
  63. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  64. sqlglot-28.4.0.dist-info/RECORD +92 -0
  65. sqlglot-27.27.0.dist-info/RECORD +0 -84
  66. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  67. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  68. {sqlglot-27.27.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -6,34 +6,37 @@ import functools
6
6
  import itertools
7
7
  import typing as t
8
8
  from collections import deque, defaultdict
9
- from functools import reduce
9
+ from functools import reduce, wraps
10
10
 
11
11
  import sqlglot
12
12
  from sqlglot import Dialect, exp
13
13
  from sqlglot.helper import first, merge_ranges, while_changing
14
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
14
15
  from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
16
+ from sqlglot.schema import ensure_schema
15
17
 
16
18
  if t.TYPE_CHECKING:
17
19
  from sqlglot.dialects.dialect import DialectType
18
20
 
21
+ DateRange = t.Tuple[datetime.date, datetime.date]
19
22
  DateTruncBinaryTransform = t.Callable[
20
23
  [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
21
24
  ]
22
25
 
26
+
23
27
  logger = logging.getLogger("sqlglot")
24
28
 
29
+
25
30
  # Final means that an expression should not be simplified
26
31
  FINAL = "final"
27
32
 
28
- # Value ranges for byte-sized signed/unsigned integers
29
- TINYINT_MIN = -128
30
- TINYINT_MAX = 127
31
- UTINYINT_MIN = 0
32
- UTINYINT_MAX = 255
33
-
34
-
35
- class UnsupportedUnit(Exception):
36
- pass
33
+ SIMPLIFIABLE = (
34
+ exp.Binary,
35
+ exp.Func,
36
+ exp.Lambda,
37
+ exp.Predicate,
38
+ exp.Unary,
39
+ )
37
40
 
38
41
 
39
42
  def simplify(
@@ -60,96 +63,15 @@ def simplify(
60
63
  Returns:
61
64
  sqlglot.Expression: simplified expression
62
65
  """
66
+ return Simplifier(dialect=dialect).simplify(
67
+ expression,
68
+ constant_propagation=constant_propagation,
69
+ coalesce_simplification=coalesce_simplification,
70
+ )
63
71
 
64
- dialect = Dialect.get_or_raise(dialect)
65
-
66
- def _simplify(expression):
67
- pre_transformation_stack = [expression]
68
- post_transformation_stack = []
69
-
70
- while pre_transformation_stack:
71
- node = pre_transformation_stack.pop()
72
-
73
- if node.meta.get(FINAL):
74
- continue
75
-
76
- # group by expressions cannot be simplified, for example
77
- # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
78
- # the projection must exactly match the group by key
79
- group = node.args.get("group")
80
-
81
- if group and hasattr(node, "selects"):
82
- groups = set(group.expressions)
83
- group.meta[FINAL] = True
84
-
85
- for s in node.selects:
86
- for n in s.walk():
87
- if n in groups:
88
- s.meta[FINAL] = True
89
- break
90
-
91
- having = node.args.get("having")
92
- if having:
93
- for n in having.walk():
94
- if n in groups:
95
- having.meta[FINAL] = True
96
- break
97
-
98
- parent = node.parent
99
- root = node is expression
100
-
101
- new_node = rewrite_between(node)
102
- new_node = uniq_sort(new_node, root)
103
- new_node = absorb_and_eliminate(new_node, root)
104
- new_node = simplify_concat(new_node)
105
- new_node = simplify_conditionals(new_node)
106
-
107
- if constant_propagation:
108
- new_node = propagate_constants(new_node, root)
109
-
110
- if new_node is not node:
111
- node.replace(new_node)
112
-
113
- pre_transformation_stack.extend(
114
- n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
115
- )
116
- post_transformation_stack.append((new_node, parent))
117
-
118
- while post_transformation_stack:
119
- node, parent = post_transformation_stack.pop()
120
- root = node is expression
121
-
122
- # Resets parent, arg_key, index pointers– this is needed because some of the
123
- # previous transformations mutate the AST, leading to an inconsistent state
124
- for k, v in tuple(node.args.items()):
125
- node.set(k, v)
126
-
127
- # Post-order transformations
128
- new_node = simplify_not(node)
129
- new_node = flatten(new_node)
130
- new_node = simplify_connectors(new_node, root)
131
- new_node = remove_complements(new_node, root)
132
-
133
- if coalesce_simplification:
134
- new_node = simplify_coalesce(new_node, dialect)
135
-
136
- new_node.parent = parent
137
-
138
- new_node = simplify_literals(new_node, root)
139
- new_node = simplify_equality(new_node)
140
- new_node = simplify_parens(new_node, dialect)
141
- new_node = simplify_datetrunc(new_node, dialect)
142
- new_node = sort_comparison(new_node)
143
- new_node = simplify_startswith(new_node)
144
-
145
- if new_node is not node:
146
- node.replace(new_node)
147
-
148
- return new_node
149
72
 
150
- expression = while_changing(expression, _simplify)
151
- remove_where_true(expression)
152
- return expression
73
+ class UnsupportedUnit(Exception):
74
+ pass
153
75
 
154
76
 
155
77
  def catch(*exceptions):
@@ -167,87 +89,30 @@ def catch(*exceptions):
167
89
  return decorator
168
90
 
169
91
 
170
- def rewrite_between(expression: exp.Expression) -> exp.Expression:
171
- """Rewrite x between y and z to x >= y AND x <= z.
172
-
173
- This is done because comparison simplification is only done on lt/lte/gt/gte.
174
- """
175
- if isinstance(expression, exp.Between):
176
- negate = isinstance(expression.parent, exp.Not)
177
-
178
- expression = exp.and_(
179
- exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
180
- exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
181
- copy=False,
182
- )
183
-
184
- if negate:
185
- expression = exp.paren(expression, copy=False)
92
+ def annotate_types_on_change(func):
93
+ @wraps(func)
94
+ def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]:
95
+ new_expression = func(self, expression, *args, **kwargs)
186
96
 
187
- return expression
97
+ if new_expression is None:
98
+ return new_expression
188
99
 
100
+ if self.annotate_new_expressions and expression != new_expression:
101
+ self._annotator.clear()
189
102
 
190
- COMPLEMENT_COMPARISONS = {
191
- exp.LT: exp.GTE,
192
- exp.GT: exp.LTE,
193
- exp.LTE: exp.GT,
194
- exp.GTE: exp.LT,
195
- exp.EQ: exp.NEQ,
196
- exp.NEQ: exp.EQ,
197
- }
103
+ # We annotate this to ensure new children nodes are also annotated
104
+ new_expression = self._annotator.annotate(
105
+ expression=new_expression,
106
+ annotate_scope=False,
107
+ )
198
108
 
199
- COMPLEMENT_SUBQUERY_PREDICATES = {
200
- exp.All: exp.Any,
201
- exp.Any: exp.All,
202
- }
109
+ # Whatever expression the original expression is transformed into needs to preserve
110
+ # the original type, otherwise the simplification could result in a different schema
111
+ new_expression.type = expression.type
203
112
 
113
+ return new_expression
204
114
 
205
- def simplify_not(expression):
206
- """
207
- Demorgan's Law
208
- NOT (x OR y) -> NOT x AND NOT y
209
- NOT (x AND y) -> NOT x OR NOT y
210
- """
211
- if isinstance(expression, exp.Not):
212
- this = expression.this
213
- if is_null(this):
214
- return exp.null()
215
- if this.__class__ in COMPLEMENT_COMPARISONS:
216
- right = this.expression
217
- complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
218
- if complement_subquery_predicate:
219
- right = complement_subquery_predicate(this=right.this)
220
-
221
- return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
222
- if isinstance(this, exp.Paren):
223
- condition = this.unnest()
224
- if isinstance(condition, exp.And):
225
- return exp.paren(
226
- exp.or_(
227
- exp.not_(condition.left, copy=False),
228
- exp.not_(condition.right, copy=False),
229
- copy=False,
230
- )
231
- )
232
- if isinstance(condition, exp.Or):
233
- return exp.paren(
234
- exp.and_(
235
- exp.not_(condition.left, copy=False),
236
- exp.not_(condition.right, copy=False),
237
- copy=False,
238
- )
239
- )
240
- if is_null(condition):
241
- return exp.null()
242
- if always_true(this):
243
- return exp.false()
244
- if is_false(this):
245
- return exp.true()
246
- if isinstance(this, exp.Not):
247
- # double negation
248
- # NOT NOT x -> x
249
- return this.this
250
- return expression
115
+ return _func
251
116
 
252
117
 
253
118
  def flatten(expression):
@@ -263,246 +128,43 @@ def flatten(expression):
263
128
  return expression
264
129
 
265
130
 
266
- def simplify_connectors(expression, root=True):
267
- def _simplify_connectors(expression, left, right):
268
- if isinstance(expression, exp.And):
269
- if is_false(left) or is_false(right):
270
- return exp.false()
271
- if is_zero(left) or is_zero(right):
272
- return exp.false()
273
- if is_null(left) or is_null(right):
274
- return exp.null()
275
- if always_true(left) and always_true(right):
276
- return exp.true()
277
- if always_true(left):
278
- return right
279
- if always_true(right):
280
- return left
281
- return _simplify_comparison(expression, left, right)
282
- elif isinstance(expression, exp.Or):
283
- if always_true(left) or always_true(right):
284
- return exp.true()
285
- if (
286
- (is_null(left) and is_null(right))
287
- or (is_null(left) and always_false(right))
288
- or (always_false(left) and is_null(right))
289
- ):
290
- return exp.null()
291
- if is_false(left):
292
- return right
293
- if is_false(right):
294
- return left
295
- return _simplify_comparison(expression, left, right, or_=True)
296
- elif isinstance(expression, exp.Xor):
297
- if left == right:
298
- return exp.false()
299
-
300
- if isinstance(expression, exp.Connector):
301
- return _flat_simplify(expression, _simplify_connectors, root)
302
- return expression
303
-
304
-
305
- LT_LTE = (exp.LT, exp.LTE)
306
- GT_GTE = (exp.GT, exp.GTE)
307
-
308
- COMPARISONS = (
309
- *LT_LTE,
310
- *GT_GTE,
311
- exp.EQ,
312
- exp.NEQ,
313
- exp.Is,
314
- )
315
-
316
- INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
317
- exp.LT: exp.GT,
318
- exp.GT: exp.LT,
319
- exp.LTE: exp.GTE,
320
- exp.GTE: exp.LTE,
321
- }
322
-
323
- NONDETERMINISTIC = (exp.Rand, exp.Randn)
324
- AND_OR = (exp.And, exp.Or)
325
-
326
-
327
- def _simplify_comparison(expression, left, right, or_=False):
328
- if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
329
- ll, lr = left.args.values()
330
- rl, rr = right.args.values()
331
-
332
- largs = {ll, lr}
333
- rargs = {rl, rr}
334
-
335
- matching = largs & rargs
336
- columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
337
-
338
- if matching and columns:
339
- try:
340
- l = first(largs - columns)
341
- r = first(rargs - columns)
342
- except StopIteration:
343
- return expression
344
-
345
- if l.is_number and r.is_number:
346
- l = l.to_py()
347
- r = r.to_py()
348
- elif l.is_string and r.is_string:
349
- l = l.name
350
- r = r.name
351
- else:
352
- l = extract_date(l)
353
- if not l:
354
- return None
355
- r = extract_date(r)
356
- if not r:
357
- return None
358
- # python won't compare date and datetime, but many engines will upcast
359
- l, r = cast_as_datetime(l), cast_as_datetime(r)
360
-
361
- for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
362
- if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
363
- return left if (av > bv if or_ else av <= bv) else right
364
- if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
365
- return left if (av < bv if or_ else av >= bv) else right
366
-
367
- # we can't ever shortcut to true because the column could be null
368
- if not or_:
369
- if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
370
- if av <= bv:
371
- return exp.false()
372
- elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
373
- if av >= bv:
374
- return exp.false()
375
- elif isinstance(a, exp.EQ):
376
- if isinstance(b, exp.LT):
377
- return exp.false() if av >= bv else a
378
- if isinstance(b, exp.LTE):
379
- return exp.false() if av > bv else a
380
- if isinstance(b, exp.GT):
381
- return exp.false() if av <= bv else a
382
- if isinstance(b, exp.GTE):
383
- return exp.false() if av < bv else a
384
- if isinstance(b, exp.NEQ):
385
- return exp.false() if av == bv else a
386
- return None
387
-
388
-
389
- def remove_complements(expression, root=True):
390
- """
391
- Removing complements.
392
-
393
- A AND NOT A -> FALSE
394
- A OR NOT A -> TRUE
395
- """
396
- if isinstance(expression, AND_OR) and (root or not expression.same_parent):
397
- ops = set(expression.flatten())
398
- for op in ops:
399
- if isinstance(op, exp.Not) and op.this in ops:
400
- return exp.false() if isinstance(expression, exp.And) else exp.true()
401
-
402
- return expression
403
-
404
-
405
- def uniq_sort(expression, root=True):
406
- """
407
- Uniq and sort a connector.
408
-
409
- C AND A AND B AND B -> A AND B AND C
410
- """
411
- if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
412
- flattened = tuple(expression.flatten())
413
-
414
- if isinstance(expression, exp.Xor):
415
- result_func = exp.xor
416
- # Do not deduplicate XOR as A XOR A != A if A == True
417
- deduped = None
418
- arr = tuple((gen(e), e) for e in flattened)
419
- else:
420
- result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
421
- deduped = {gen(e): e for e in flattened}
422
- arr = tuple(deduped.items())
423
-
424
- # check if the operands are already sorted, if not sort them
425
- # A AND C AND B -> A AND B AND C
426
- for i, (sql, e) in enumerate(arr[1:]):
427
- if sql < arr[i][0]:
428
- expression = result_func(*(e for _, e in sorted(arr)), copy=False)
429
- break
430
- else:
431
- # we didn't have to sort but maybe we need to dedup
432
- if deduped and len(deduped) < len(flattened):
433
- expression = result_func(*deduped.values(), copy=False)
434
-
435
- return expression
436
-
437
-
438
- def absorb_and_eliminate(expression, root=True):
439
- """
440
- absorption:
441
- A AND (A OR B) -> A
442
- A OR (A AND B) -> A
443
- A AND (NOT A OR B) -> A AND B
444
- A OR (NOT A AND B) -> A OR B
445
- elimination:
446
- (A AND B) OR (A AND NOT B) -> A
447
- (A OR B) AND (A OR NOT B) -> A
448
- """
449
- if isinstance(expression, AND_OR) and (root or not expression.same_parent):
450
- kind = exp.Or if isinstance(expression, exp.And) else exp.And
451
-
452
- ops = tuple(expression.flatten())
453
-
454
- # Initialize lookup tables:
455
- # Set of all operands, used to find complements for absorption.
456
- op_set = set()
457
- # Sub-operands, used to find subsets for absorption.
458
- subops = defaultdict(list)
459
- # Pairs of complements, used for elimination.
460
- pairs = defaultdict(list)
461
-
462
- # Populate the lookup tables
463
- for op in ops:
464
- op_set.add(op)
465
-
466
- if not isinstance(op, kind):
467
- # In cases like: A OR (A AND B)
468
- # Subop will be: ^
469
- subops[op].append({op})
470
- continue
471
-
472
- # In cases like: (A AND B) OR (A AND B AND C)
473
- # Subops will be: ^ ^
474
- subset = set(op.flatten())
475
- for i in subset:
476
- subops[i].append(subset)
131
+ def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
132
+ if not isinstance(expression, exp.Paren):
133
+ return expression
477
134
 
478
- a, b = op.unnest_operands()
479
- if isinstance(a, exp.Not):
480
- pairs[frozenset((a.this, b))].append((op, b))
481
- if isinstance(b, exp.Not):
482
- pairs[frozenset((a, b.this))].append((op, a))
135
+ this = expression.this
136
+ parent = expression.parent
137
+ parent_is_predicate = isinstance(parent, exp.Predicate)
483
138
 
484
- for op in ops:
485
- if not isinstance(op, kind):
486
- continue
139
+ if isinstance(this, exp.Select):
140
+ return expression
487
141
 
488
- a, b = op.unnest_operands()
142
+ if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
143
+ return expression
489
144
 
490
- # Absorb
491
- if isinstance(a, exp.Not) and a.this in op_set:
492
- a.replace(exp.true() if kind == exp.And else exp.false())
493
- continue
494
- if isinstance(b, exp.Not) and b.this in op_set:
495
- b.replace(exp.true() if kind == exp.And else exp.false())
496
- continue
497
- superset = set(op.flatten())
498
- if any(any(subset < superset for subset in subops[i]) for i in superset):
499
- op.replace(exp.false() if kind == exp.And else exp.true())
500
- continue
145
+ if (
146
+ Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
147
+ and isinstance(parent, exp.Dot)
148
+ and (isinstance(parent.right, (exp.Identifier, exp.Star)))
149
+ ):
150
+ return expression
501
151
 
502
- # Eliminate
503
- for other, complement in pairs[frozenset((a, b))]:
504
- op.replace(complement)
505
- other.replace(complement)
152
+ if (
153
+ not isinstance(parent, (exp.Condition, exp.Binary))
154
+ or isinstance(parent, exp.Paren)
155
+ or (
156
+ not isinstance(this, exp.Binary)
157
+ and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
158
+ )
159
+ or (
160
+ isinstance(this, exp.Predicate)
161
+ and not (parent_is_predicate or isinstance(parent, exp.Neg))
162
+ )
163
+ or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
164
+ or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
165
+ or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
166
+ ):
167
+ return this
506
168
 
507
169
  return expression
508
170
 
@@ -546,20 +208,6 @@ def propagate_constants(expression, root=True):
546
208
  return expression
547
209
 
548
210
 
549
- INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
550
- exp.DateAdd: exp.Sub,
551
- exp.DateSub: exp.Add,
552
- exp.DatetimeAdd: exp.Sub,
553
- exp.DatetimeSub: exp.Add,
554
- }
555
-
556
- INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
557
- **INVERSE_DATE_OPS,
558
- exp.Add: exp.Sub,
559
- exp.Sub: exp.Add,
560
- }
561
-
562
-
563
211
  def _is_number(expression: exp.Expression) -> bool:
564
212
  return expression.is_number
565
213
 
@@ -568,207 +216,6 @@ def _is_interval(expression: exp.Expression) -> bool:
568
216
  return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
569
217
 
570
218
 
571
- @catch(ModuleNotFoundError, UnsupportedUnit)
572
- def simplify_equality(expression: exp.Expression) -> exp.Expression:
573
- """
574
- Use the subtraction and addition properties of equality to simplify expressions:
575
-
576
- x + 1 = 3 becomes x = 2
577
-
578
- There are two binary operations in the above expression: + and =
579
- Here's how we reference all the operands in the code below:
580
-
581
- l r
582
- x + 1 = 3
583
- a b
584
- """
585
- if isinstance(expression, COMPARISONS):
586
- l, r = expression.left, expression.right
587
-
588
- if l.__class__ not in INVERSE_OPS:
589
- return expression
590
-
591
- if r.is_number:
592
- a_predicate = _is_number
593
- b_predicate = _is_number
594
- elif _is_date_literal(r):
595
- a_predicate = _is_date_literal
596
- b_predicate = _is_interval
597
- else:
598
- return expression
599
-
600
- if l.__class__ in INVERSE_DATE_OPS:
601
- l = t.cast(exp.IntervalOp, l)
602
- a = l.this
603
- b = l.interval()
604
- else:
605
- l = t.cast(exp.Binary, l)
606
- a, b = l.left, l.right
607
-
608
- if not a_predicate(a) and b_predicate(b):
609
- pass
610
- elif not a_predicate(b) and b_predicate(a):
611
- a, b = b, a
612
- else:
613
- return expression
614
-
615
- return expression.__class__(
616
- this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
617
- )
618
- return expression
619
-
620
-
621
- def simplify_literals(expression, root=True):
622
- if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
623
- return _flat_simplify(expression, _simplify_binary, root)
624
-
625
- if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
626
- return expression.this.this
627
-
628
- if type(expression) in INVERSE_DATE_OPS:
629
- return _simplify_binary(expression, expression.this, expression.interval()) or expression
630
-
631
- return expression
632
-
633
-
634
- NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
635
-
636
-
637
- def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
638
- if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
639
- this = _simplify_integer_cast(expr.this)
640
- else:
641
- this = expr.this
642
-
643
- if isinstance(expr, exp.Cast) and this.is_int:
644
- num = this.to_py()
645
-
646
- # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
647
- # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
648
- # engine-dependent
649
- if (
650
- TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
651
- ) or (
652
- UTINYINT_MIN <= num <= UTINYINT_MAX
653
- and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
654
- ):
655
- return this
656
-
657
- return expr
658
-
659
-
660
- def _simplify_binary(expression, a, b):
661
- if isinstance(expression, COMPARISONS):
662
- a = _simplify_integer_cast(a)
663
- b = _simplify_integer_cast(b)
664
-
665
- if isinstance(expression, exp.Is):
666
- if isinstance(b, exp.Not):
667
- c = b.this
668
- not_ = True
669
- else:
670
- c = b
671
- not_ = False
672
-
673
- if is_null(c):
674
- if isinstance(a, exp.Literal):
675
- return exp.true() if not_ else exp.false()
676
- if is_null(a):
677
- return exp.false() if not_ else exp.true()
678
- elif isinstance(expression, NULL_OK):
679
- return None
680
- elif is_null(a) or is_null(b):
681
- return exp.null()
682
-
683
- if a.is_number and b.is_number:
684
- num_a = a.to_py()
685
- num_b = b.to_py()
686
-
687
- if isinstance(expression, exp.Add):
688
- return exp.Literal.number(num_a + num_b)
689
- if isinstance(expression, exp.Mul):
690
- return exp.Literal.number(num_a * num_b)
691
-
692
- # We only simplify Sub, Div if a and b have the same parent because they're not associative
693
- if isinstance(expression, exp.Sub):
694
- return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
695
- if isinstance(expression, exp.Div):
696
- # engines have differing int div behavior so intdiv is not safe
697
- if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
698
- return None
699
- return exp.Literal.number(num_a / num_b)
700
-
701
- boolean = eval_boolean(expression, num_a, num_b)
702
-
703
- if boolean:
704
- return boolean
705
- elif a.is_string and b.is_string:
706
- boolean = eval_boolean(expression, a.this, b.this)
707
-
708
- if boolean:
709
- return boolean
710
- elif _is_date_literal(a) and isinstance(b, exp.Interval):
711
- date, b = extract_date(a), extract_interval(b)
712
- if date and b:
713
- if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
714
- return date_literal(date + b, extract_type(a))
715
- if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
716
- return date_literal(date - b, extract_type(a))
717
- elif isinstance(a, exp.Interval) and _is_date_literal(b):
718
- a, date = extract_interval(a), extract_date(b)
719
- # you cannot subtract a date from an interval
720
- if a and b and isinstance(expression, exp.Add):
721
- return date_literal(a + date, extract_type(b))
722
- elif _is_date_literal(a) and _is_date_literal(b):
723
- if isinstance(expression, exp.Predicate):
724
- a, b = extract_date(a), extract_date(b)
725
- boolean = eval_boolean(expression, a, b)
726
- if boolean:
727
- return boolean
728
-
729
- return None
730
-
731
-
732
- def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
733
- if not isinstance(expression, exp.Paren):
734
- return expression
735
-
736
- this = expression.this
737
- parent = expression.parent
738
- parent_is_predicate = isinstance(parent, exp.Predicate)
739
-
740
- if isinstance(this, exp.Select):
741
- return expression
742
-
743
- if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
744
- return expression
745
-
746
- # Handle risingwave struct columns
747
- # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
748
- if (
749
- dialect == "risingwave"
750
- and isinstance(parent, exp.Dot)
751
- and (isinstance(parent.right, (exp.Identifier, exp.Star)))
752
- ):
753
- return expression
754
-
755
- if (
756
- not isinstance(parent, (exp.Condition, exp.Binary))
757
- or isinstance(parent, exp.Paren)
758
- or (
759
- not isinstance(this, exp.Binary)
760
- and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
761
- )
762
- or (isinstance(this, exp.Predicate) and not parent_is_predicate)
763
- or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
764
- or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
765
- or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
766
- ):
767
- return this
768
-
769
- return expression
770
-
771
-
772
219
  def _is_nonnull_constant(expression: exp.Expression) -> bool:
773
220
  return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
774
221
 
@@ -777,164 +224,6 @@ def _is_constant(expression: exp.Expression) -> bool:
777
224
  return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
778
225
 
779
226
 
780
- def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
781
- # COALESCE(x) -> x
782
- if (
783
- isinstance(expression, exp.Coalesce)
784
- and (not expression.expressions or _is_nonnull_constant(expression.this))
785
- # COALESCE is also used as a Spark partitioning hint
786
- and not isinstance(expression.parent, exp.Hint)
787
- ):
788
- return expression.this
789
-
790
- # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
791
- # because they are not always equivalent. For example, if `x` is `NULL` and it comes
792
- # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
793
- if dialect == "redshift":
794
- return expression
795
-
796
- if not isinstance(expression, COMPARISONS):
797
- return expression
798
-
799
- if isinstance(expression.left, exp.Coalesce):
800
- coalesce = expression.left
801
- other = expression.right
802
- elif isinstance(expression.right, exp.Coalesce):
803
- coalesce = expression.right
804
- other = expression.left
805
- else:
806
- return expression
807
-
808
- # This transformation is valid for non-constants,
809
- # but it really only does anything if they are both constants.
810
- if not _is_constant(other):
811
- return expression
812
-
813
- # Find the first constant arg
814
- for arg_index, arg in enumerate(coalesce.expressions):
815
- if _is_constant(arg):
816
- break
817
- else:
818
- return expression
819
-
820
- coalesce.set("expressions", coalesce.expressions[:arg_index])
821
-
822
- # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
823
- # since we already remove COALESCE at the top of this function.
824
- coalesce = coalesce if coalesce.expressions else coalesce.this
825
-
826
- # This expression is more complex than when we started, but it will get simplified further
827
- return exp.paren(
828
- exp.or_(
829
- exp.and_(
830
- coalesce.is_(exp.null()).not_(copy=False),
831
- expression.copy(),
832
- copy=False,
833
- ),
834
- exp.and_(
835
- coalesce.is_(exp.null()),
836
- type(expression)(this=arg.copy(), expression=other.copy()),
837
- copy=False,
838
- ),
839
- copy=False,
840
- )
841
- )
842
-
843
-
844
- CONCATS = (exp.Concat, exp.DPipe)
845
-
846
-
847
- def simplify_concat(expression):
848
- """Reduces all groups that contain string literals by concatenating them."""
849
- if not isinstance(expression, CONCATS) or (
850
- # We can't reduce a CONCAT_WS call if we don't statically know the separator
851
- isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
852
- ):
853
- return expression
854
-
855
- if isinstance(expression, exp.ConcatWs):
856
- sep_expr, *expressions = expression.expressions
857
- sep = sep_expr.name
858
- concat_type = exp.ConcatWs
859
- args = {}
860
- else:
861
- expressions = expression.expressions
862
- sep = ""
863
- concat_type = exp.Concat
864
- args = {
865
- "safe": expression.args.get("safe"),
866
- "coalesce": expression.args.get("coalesce"),
867
- }
868
-
869
- new_args = []
870
- for is_string_group, group in itertools.groupby(
871
- expressions or expression.flatten(), lambda e: e.is_string
872
- ):
873
- if is_string_group:
874
- new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
875
- else:
876
- new_args.extend(group)
877
-
878
- if len(new_args) == 1 and new_args[0].is_string:
879
- return new_args[0]
880
-
881
- if concat_type is exp.ConcatWs:
882
- new_args = [sep_expr] + new_args
883
- elif isinstance(expression, exp.DPipe):
884
- return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
885
-
886
- return concat_type(expressions=new_args, **args)
887
-
888
-
889
- def simplify_conditionals(expression):
890
- """Simplifies expressions like IF, CASE if their condition is statically known."""
891
- if isinstance(expression, exp.Case):
892
- this = expression.this
893
- for case in expression.args["ifs"]:
894
- cond = case.this
895
- if this:
896
- # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
897
- cond = cond.replace(this.pop().eq(cond))
898
-
899
- if always_true(cond):
900
- return case.args["true"]
901
-
902
- if always_false(cond):
903
- case.pop()
904
- if not expression.args["ifs"]:
905
- return expression.args.get("default") or exp.null()
906
- elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
907
- if always_true(expression.this):
908
- return expression.args["true"]
909
- if always_false(expression.this):
910
- return expression.args.get("false") or exp.null()
911
-
912
- return expression
913
-
914
-
915
- def simplify_startswith(expression: exp.Expression) -> exp.Expression:
916
- """
917
- Reduces a prefix check to either TRUE or FALSE if both the string and the
918
- prefix are statically known.
919
-
920
- Example:
921
- >>> from sqlglot import parse_one
922
- >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
923
- 'TRUE'
924
- """
925
- if (
926
- isinstance(expression, exp.StartsWith)
927
- and expression.this.is_string
928
- and expression.expression.is_string
929
- ):
930
- return exp.convert(expression.name.startswith(expression.expression.name))
931
-
932
- return expression
933
-
934
-
935
- DateRange = t.Tuple[datetime.date, datetime.date]
936
-
937
-
938
227
  def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
939
228
  """
940
229
  Get the date range for a DATE_TRUNC equality comparison:
@@ -960,175 +249,45 @@ def _datetrunc_eq_expression(
960
249
  return exp.and_(
961
250
  left >= date_literal(drange[0], target_type),
962
251
  left < date_literal(drange[1], target_type),
963
- copy=False,
964
- )
965
-
966
-
967
- def _datetrunc_eq(
968
- left: exp.Expression,
969
- date: datetime.date,
970
- unit: str,
971
- dialect: Dialect,
972
- target_type: t.Optional[exp.DataType],
973
- ) -> t.Optional[exp.Expression]:
974
- drange = _datetrunc_range(date, unit, dialect)
975
- if not drange:
976
- return None
977
-
978
- return _datetrunc_eq_expression(left, drange, target_type)
979
-
980
-
981
- def _datetrunc_neq(
982
- left: exp.Expression,
983
- date: datetime.date,
984
- unit: str,
985
- dialect: Dialect,
986
- target_type: t.Optional[exp.DataType],
987
- ) -> t.Optional[exp.Expression]:
988
- drange = _datetrunc_range(date, unit, dialect)
989
- if not drange:
990
- return None
991
-
992
- return exp.and_(
993
- left < date_literal(drange[0], target_type),
994
- left >= date_literal(drange[1], target_type),
995
- copy=False,
996
- )
997
-
998
-
999
- DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
1000
- exp.LT: lambda l, dt, u, d, t: l
1001
- < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
1002
- exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
1003
- exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
1004
- exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
1005
- exp.EQ: _datetrunc_eq,
1006
- exp.NEQ: _datetrunc_neq,
1007
- }
1008
- DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
1009
- DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
1010
-
1011
-
1012
- def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
1013
- return isinstance(left, DATETRUNCS) and _is_date_literal(right)
1014
-
1015
-
1016
- @catch(ModuleNotFoundError, UnsupportedUnit)
1017
- def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
1018
- """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
1019
- comparison = expression.__class__
1020
-
1021
- if isinstance(expression, DATETRUNCS):
1022
- this = expression.this
1023
- trunc_type = extract_type(this)
1024
- date = extract_date(this)
1025
- if date and expression.unit:
1026
- return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
1027
- elif comparison not in DATETRUNC_COMPARISONS:
1028
- return expression
1029
-
1030
- if isinstance(expression, exp.Binary):
1031
- l, r = expression.left, expression.right
1032
-
1033
- if not _is_datetrunc_predicate(l, r):
1034
- return expression
1035
-
1036
- l = t.cast(exp.DateTrunc, l)
1037
- trunc_arg = l.this
1038
- unit = l.unit.name.lower()
1039
- date = extract_date(r)
1040
-
1041
- if not date:
1042
- return expression
1043
-
1044
- return (
1045
- DATETRUNC_BINARY_COMPARISONS[comparison](
1046
- trunc_arg, date, unit, dialect, extract_type(r)
1047
- )
1048
- or expression
1049
- )
1050
-
1051
- if isinstance(expression, exp.In):
1052
- l = expression.this
1053
- rs = expression.expressions
1054
-
1055
- if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
1056
- l = t.cast(exp.DateTrunc, l)
1057
- unit = l.unit.name.lower()
1058
-
1059
- ranges = []
1060
- for r in rs:
1061
- date = extract_date(r)
1062
- if not date:
1063
- return expression
1064
- drange = _datetrunc_range(date, unit, dialect)
1065
- if drange:
1066
- ranges.append(drange)
1067
-
1068
- if not ranges:
1069
- return expression
1070
-
1071
- ranges = merge_ranges(ranges)
1072
- target_type = extract_type(*rs)
1073
-
1074
- return exp.or_(
1075
- *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
1076
- )
1077
-
1078
- return expression
1079
-
1080
-
1081
- def sort_comparison(expression: exp.Expression) -> exp.Expression:
1082
- if expression.__class__ in COMPLEMENT_COMPARISONS:
1083
- l, r = expression.this, expression.expression
1084
- l_column = isinstance(l, exp.Column)
1085
- r_column = isinstance(r, exp.Column)
1086
- l_const = _is_constant(l)
1087
- r_const = _is_constant(r)
1088
-
1089
- if (
1090
- (l_column and not r_column)
1091
- or (r_const and not l_const)
1092
- or isinstance(r, exp.SubqueryPredicate)
1093
- ):
1094
- return expression
1095
- if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1096
- return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1097
- this=r, expression=l
1098
- )
1099
- return expression
252
+ copy=False,
253
+ )
1100
254
 
1101
255
 
1102
- # CROSS joins result in an empty table if the right table is empty.
1103
- # So we can only simplify certain types of joins to CROSS.
1104
- # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
1105
- JOINS = {
1106
- ("", ""),
1107
- ("", "INNER"),
1108
- ("RIGHT", ""),
1109
- ("RIGHT", "OUTER"),
1110
- }
256
+ def _datetrunc_eq(
257
+ left: exp.Expression,
258
+ date: datetime.date,
259
+ unit: str,
260
+ dialect: Dialect,
261
+ target_type: t.Optional[exp.DataType],
262
+ ) -> t.Optional[exp.Expression]:
263
+ drange = _datetrunc_range(date, unit, dialect)
264
+ if not drange:
265
+ return None
266
+
267
+ return _datetrunc_eq_expression(left, drange, target_type)
1111
268
 
1112
269
 
1113
- def remove_where_true(expression):
1114
- for where in expression.find_all(exp.Where):
1115
- if always_true(where.this):
1116
- where.pop()
1117
- for join in expression.find_all(exp.Join):
1118
- if (
1119
- always_true(join.args.get("on"))
1120
- and not join.args.get("using")
1121
- and not join.args.get("method")
1122
- and (join.side, join.kind) in JOINS
1123
- ):
1124
- join.args["on"].pop()
1125
- join.set("side", None)
1126
- join.set("kind", "CROSS")
270
+ def _datetrunc_neq(
271
+ left: exp.Expression,
272
+ date: datetime.date,
273
+ unit: str,
274
+ dialect: Dialect,
275
+ target_type: t.Optional[exp.DataType],
276
+ ) -> t.Optional[exp.Expression]:
277
+ drange = _datetrunc_range(date, unit, dialect)
278
+ if not drange:
279
+ return None
280
+
281
+ return exp.and_(
282
+ left < date_literal(drange[0], target_type),
283
+ left >= date_literal(drange[1], target_type),
284
+ copy=False,
285
+ )
1127
286
 
1128
287
 
1129
288
  def always_true(expression):
1130
289
  return (isinstance(expression, exp.Boolean) and expression.this) or (
1131
- isinstance(expression, exp.Literal) and not is_zero(expression)
290
+ isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression)
1132
291
  )
1133
292
 
1134
293
 
@@ -1310,30 +469,987 @@ def boolean_literal(condition):
1310
469
  return exp.true() if condition else exp.false()
1311
470
 
1312
471
 
1313
- def _flat_simplify(expression, simplifier, root=True):
1314
- if root or not expression.same_parent:
1315
- operands = []
1316
- queue = deque(expression.flatten(unnest=False))
1317
- size = len(queue)
472
+ class Simplifier:
473
+ def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
474
+ self.dialect = Dialect.get_or_raise(dialect)
475
+ self.annotate_new_expressions = annotate_new_expressions
476
+
477
+ self._annotator: TypeAnnotator = TypeAnnotator(
478
+ schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
479
+ )
480
+
481
+ # Value ranges for byte-sized signed/unsigned integers
482
+ TINYINT_MIN = -128
483
+ TINYINT_MAX = 127
484
+ UTINYINT_MIN = 0
485
+ UTINYINT_MAX = 255
486
+
487
+ COMPLEMENT_COMPARISONS = {
488
+ exp.LT: exp.GTE,
489
+ exp.GT: exp.LTE,
490
+ exp.LTE: exp.GT,
491
+ exp.GTE: exp.LT,
492
+ exp.EQ: exp.NEQ,
493
+ exp.NEQ: exp.EQ,
494
+ }
495
+
496
+ COMPLEMENT_SUBQUERY_PREDICATES = {
497
+ exp.All: exp.Any,
498
+ exp.Any: exp.All,
499
+ }
500
+
501
+ LT_LTE = (exp.LT, exp.LTE)
502
+ GT_GTE = (exp.GT, exp.GTE)
503
+
504
+ COMPARISONS = (
505
+ *LT_LTE,
506
+ *GT_GTE,
507
+ exp.EQ,
508
+ exp.NEQ,
509
+ exp.Is,
510
+ )
511
+
512
+ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
513
+ exp.LT: exp.GT,
514
+ exp.GT: exp.LT,
515
+ exp.LTE: exp.GTE,
516
+ exp.GTE: exp.LTE,
517
+ }
518
+
519
+ NONDETERMINISTIC = (exp.Rand, exp.Randn)
520
+ AND_OR = (exp.And, exp.Or)
521
+
522
+ INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
523
+ exp.DateAdd: exp.Sub,
524
+ exp.DateSub: exp.Add,
525
+ exp.DatetimeAdd: exp.Sub,
526
+ exp.DatetimeSub: exp.Add,
527
+ }
528
+
529
+ INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
530
+ **INVERSE_DATE_OPS,
531
+ exp.Add: exp.Sub,
532
+ exp.Sub: exp.Add,
533
+ }
534
+
535
+ NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
536
+
537
+ CONCATS = (exp.Concat, exp.DPipe)
538
+
539
+ DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
540
+ exp.LT: lambda l, dt, u, d, t: l
541
+ < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
542
+ exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
543
+ exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
544
+ exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
545
+ exp.EQ: _datetrunc_eq,
546
+ exp.NEQ: _datetrunc_neq,
547
+ }
548
+
549
+ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
550
+ DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
551
+
552
+ SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean)
553
+
554
+ # CROSS joins result in an empty table if the right table is empty.
555
+ # So we can only simplify certain types of joins to CROSS.
556
+ # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
557
+ JOINS = {
558
+ ("", ""),
559
+ ("", "INNER"),
560
+ ("RIGHT", ""),
561
+ ("RIGHT", "OUTER"),
562
+ }
563
+
564
+ def simplify(
565
+ self,
566
+ expression: exp.Expression,
567
+ constant_propagation: bool = False,
568
+ coalesce_simplification: bool = False,
569
+ ):
570
+ wheres = []
571
+ joins = []
572
+
573
+ for node in expression.walk(
574
+ prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
575
+ ):
576
+ if node.meta.get(FINAL):
577
+ continue
578
+
579
+ # group by expressions cannot be simplified, for example
580
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
581
+ # the projection must exactly match the group by key
582
+ group = node.args.get("group")
583
+
584
+ if group and hasattr(node, "selects"):
585
+ groups = set(group.expressions)
586
+ group.meta[FINAL] = True
587
+
588
+ for s in node.selects:
589
+ for n in s.walk(FINAL):
590
+ if n in groups:
591
+ s.meta[FINAL] = True
592
+ break
593
+
594
+ having = node.args.get("having")
595
+
596
+ if having:
597
+ for n in having.walk():
598
+ if n in groups:
599
+ having.meta[FINAL] = True
600
+ break
601
+
602
+ if isinstance(node, exp.Condition):
603
+ simplified = while_changing(
604
+ node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
605
+ )
606
+
607
+ if node is expression:
608
+ expression = simplified
609
+ elif isinstance(node, exp.Where):
610
+ wheres.append(node)
611
+ elif isinstance(node, exp.Join):
612
+ # snowflake match_conditions have very strict ordering rules
613
+ if match := node.args.get("match_condition"):
614
+ match.meta[FINAL] = True
615
+
616
+ joins.append(node)
617
+
618
+ for where in wheres:
619
+ if always_true(where.this):
620
+ where.pop()
621
+ for join in joins:
622
+ if (
623
+ always_true(join.args.get("on"))
624
+ and not join.args.get("using")
625
+ and not join.args.get("method")
626
+ and (join.side, join.kind) in self.JOINS
627
+ ):
628
+ join.args["on"].pop()
629
+ join.set("side", None)
630
+ join.set("kind", "CROSS")
631
+
632
+ return expression
633
+
634
+ def _simplify(
635
+ self, expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool
636
+ ):
637
+ pre_transformation_stack = [expression]
638
+ post_transformation_stack = []
639
+
640
+ while pre_transformation_stack:
641
+ original = pre_transformation_stack.pop()
642
+ node = original
643
+
644
+ if not isinstance(node, SIMPLIFIABLE):
645
+ if isinstance(node, exp.Query):
646
+ self.simplify(node, constant_propagation, coalesce_simplification)
647
+ continue
648
+
649
+ parent = node.parent
650
+ root = node is expression
651
+
652
+ node = self.rewrite_between(node)
653
+ node = self.uniq_sort(node, root)
654
+ node = self.absorb_and_eliminate(node, root)
655
+ node = self.simplify_concat(node)
656
+ node = self.simplify_conditionals(node)
657
+
658
+ if constant_propagation:
659
+ node = propagate_constants(node, root)
660
+
661
+ if node is not original:
662
+ original.replace(node)
663
+
664
+ for n in node.iter_expressions(reverse=True):
665
+ if n.meta.get(FINAL):
666
+ raise
667
+ pre_transformation_stack.extend(
668
+ n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
669
+ )
670
+ post_transformation_stack.append((node, parent))
671
+
672
+ while post_transformation_stack:
673
+ original, parent = post_transformation_stack.pop()
674
+ root = original is expression
675
+
676
+ # Resets parent, arg_key, index pointers– this is needed because some of the
677
+ # previous transformations mutate the AST, leading to an inconsistent state
678
+ for k, v in tuple(original.args.items()):
679
+ original.set(k, v)
680
+
681
+ # Post-order transformations
682
+ node = self.simplify_not(original)
683
+ node = flatten(node)
684
+ node = self.simplify_connectors(node, root)
685
+ node = self.remove_complements(node, root)
686
+
687
+ if coalesce_simplification:
688
+ node = self.simplify_coalesce(node)
689
+ node.parent = parent
690
+
691
+ node = self.simplify_literals(node, root)
692
+ node = self.simplify_equality(node)
693
+ node = simplify_parens(node, dialect=self.dialect)
694
+ node = self.simplify_datetrunc(node)
695
+ node = self.sort_comparison(node)
696
+ node = self.simplify_startswith(node)
697
+
698
+ if node is not original:
699
+ original.replace(node)
700
+
701
+ return node
702
+
703
+ @annotate_types_on_change
704
+ def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
705
+ """Rewrite x between y and z to x >= y AND x <= z.
706
+
707
+ This is done because comparison simplification is only done on lt/lte/gt/gte.
708
+ """
709
+ if isinstance(expression, exp.Between):
710
+ negate = isinstance(expression.parent, exp.Not)
711
+
712
+ expression = exp.and_(
713
+ exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
714
+ exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
715
+ copy=False,
716
+ )
717
+
718
+ if negate:
719
+ expression = exp.paren(expression, copy=False)
720
+
721
+ return expression
722
+
723
+ @annotate_types_on_change
724
+ def simplify_not(self, expression: exp.Expression) -> exp.Expression:
725
+ """
726
+ Demorgan's Law
727
+ NOT (x OR y) -> NOT x AND NOT y
728
+ NOT (x AND y) -> NOT x OR NOT y
729
+ """
730
+ if isinstance(expression, exp.Not):
731
+ this = expression.this
732
+ if is_null(this):
733
+ return exp.and_(exp.null(), exp.true(), copy=False)
734
+ if this.__class__ in self.COMPLEMENT_COMPARISONS:
735
+ right = this.expression
736
+ complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
737
+ right.__class__
738
+ )
739
+ if complement_subquery_predicate:
740
+ right = complement_subquery_predicate(this=right.this)
741
+
742
+ return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
743
+ if isinstance(this, exp.Paren):
744
+ condition = this.unnest()
745
+ if isinstance(condition, exp.And):
746
+ return exp.paren(
747
+ exp.or_(
748
+ exp.not_(condition.left, copy=False),
749
+ exp.not_(condition.right, copy=False),
750
+ copy=False,
751
+ ),
752
+ copy=False,
753
+ )
754
+ if isinstance(condition, exp.Or):
755
+ return exp.paren(
756
+ exp.and_(
757
+ exp.not_(condition.left, copy=False),
758
+ exp.not_(condition.right, copy=False),
759
+ copy=False,
760
+ ),
761
+ copy=False,
762
+ )
763
+ if is_null(condition):
764
+ return exp.and_(exp.null(), exp.true(), copy=False)
765
+ if always_true(this):
766
+ return exp.false()
767
+ if is_false(this):
768
+ return exp.true()
769
+ if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
770
+ inner = this.this
771
+ if inner.is_type(exp.DataType.Type.BOOLEAN):
772
+ # double negation
773
+ # NOT NOT x -> x, if x is BOOLEAN type
774
+ return inner
775
+ return expression
776
+
777
+ @annotate_types_on_change
778
+ def simplify_connectors(self, expression, root=True):
779
+ def _simplify_connectors(expression, left, right):
780
+ if isinstance(expression, exp.And):
781
+ if is_false(left) or is_false(right):
782
+ return exp.false()
783
+ if is_zero(left) or is_zero(right):
784
+ return exp.false()
785
+ if (
786
+ (is_null(left) and is_null(right))
787
+ or (is_null(left) and always_true(right))
788
+ or (always_true(left) and is_null(right))
789
+ ):
790
+ return exp.null()
791
+ if always_true(left) and always_true(right):
792
+ return exp.true()
793
+ if always_true(left):
794
+ return right
795
+ if always_true(right):
796
+ return left
797
+ return self._simplify_comparison(expression, left, right)
798
+ elif isinstance(expression, exp.Or):
799
+ if always_true(left) or always_true(right):
800
+ return exp.true()
801
+ if (
802
+ (is_null(left) and is_null(right))
803
+ or (is_null(left) and always_false(right))
804
+ or (always_false(left) and is_null(right))
805
+ ):
806
+ return exp.null()
807
+ if is_false(left):
808
+ return right
809
+ if is_false(right):
810
+ return left
811
+ return self._simplify_comparison(expression, left, right, or_=True)
812
+
813
+ if isinstance(expression, exp.Connector):
814
+ original_parent = expression.parent
815
+ expression = self._flat_simplify(expression, _simplify_connectors, root)
816
+
817
+ # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
818
+ # to ensure that the resulting type is boolean. We know this is true only for connectors,
819
+ # boolean values and columns that are essentially operands to a connector:
820
+ #
821
+ # A AND (((B)))
822
+ # ~ this is safe to keep because it will eventually be part of another connector
823
+ if not isinstance(
824
+ expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
825
+ ) and not expression.is_type(exp.DataType.Type.BOOLEAN):
826
+ while True:
827
+ if isinstance(original_parent, exp.Connector):
828
+ break
829
+ if not isinstance(original_parent, exp.Paren):
830
+ expression = expression.and_(exp.true(), copy=False)
831
+ break
832
+
833
+ original_parent = original_parent.parent
834
+
835
+ return expression
836
+
837
+ @annotate_types_on_change
838
+ def _simplify_comparison(self, expression, left, right, or_=False):
839
+ if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
840
+ ll, lr = left.args.values()
841
+ rl, rr = right.args.values()
842
+
843
+ largs = {ll, lr}
844
+ rargs = {rl, rr}
845
+
846
+ matching = largs & rargs
847
+ columns = {
848
+ m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
849
+ }
850
+
851
+ if matching and columns:
852
+ try:
853
+ l = first(largs - columns)
854
+ r = first(rargs - columns)
855
+ except StopIteration:
856
+ return expression
857
+
858
+ if l.is_number and r.is_number:
859
+ l = l.to_py()
860
+ r = r.to_py()
861
+ elif l.is_string and r.is_string:
862
+ l = l.name
863
+ r = r.name
864
+ else:
865
+ l = extract_date(l)
866
+ if not l:
867
+ return None
868
+ r = extract_date(r)
869
+ if not r:
870
+ return None
871
+ # python won't compare date and datetime, but many engines will upcast
872
+ l, r = cast_as_datetime(l), cast_as_datetime(r)
873
+
874
+ for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
875
+ if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
876
+ return left if (av > bv if or_ else av <= bv) else right
877
+ if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
878
+ return left if (av < bv if or_ else av >= bv) else right
879
+
880
+ # we can't ever shortcut to true because the column could be null
881
+ if not or_:
882
+ if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
883
+ if av <= bv:
884
+ return exp.false()
885
+ elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
886
+ if av >= bv:
887
+ return exp.false()
888
+ elif isinstance(a, exp.EQ):
889
+ if isinstance(b, exp.LT):
890
+ return exp.false() if av >= bv else a
891
+ if isinstance(b, exp.LTE):
892
+ return exp.false() if av > bv else a
893
+ if isinstance(b, exp.GT):
894
+ return exp.false() if av <= bv else a
895
+ if isinstance(b, exp.GTE):
896
+ return exp.false() if av < bv else a
897
+ if isinstance(b, exp.NEQ):
898
+ return exp.false() if av == bv else a
899
+ return None
1318
900
 
1319
- while queue:
1320
- a = queue.popleft()
901
+ @annotate_types_on_change
902
+ def remove_complements(self, expression, root=True):
903
+ """
904
+ Removing complements.
905
+
906
+ A AND NOT A -> FALSE (only for non-NULL A)
907
+ A OR NOT A -> TRUE (only for non-NULL A)
908
+ """
909
+ if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
910
+ ops = set(expression.flatten())
911
+ for op in ops:
912
+ if isinstance(op, exp.Not) and op.this in ops:
913
+ if expression.meta.get("nonnull") is True:
914
+ return exp.false() if isinstance(expression, exp.And) else exp.true()
1321
915
 
1322
- for b in queue:
1323
- result = simplifier(expression, a, b)
916
+ return expression
1324
917
 
1325
- if result and result is not expression:
1326
- queue.remove(b)
1327
- queue.appendleft(result)
918
+ @annotate_types_on_change
919
+ def uniq_sort(self, expression, root=True):
920
+ """
921
+ Uniq and sort a connector.
922
+
923
+ C AND A AND B AND B -> A AND B AND C
924
+ """
925
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
926
+ flattened = tuple(expression.flatten())
927
+
928
+ if isinstance(expression, exp.Xor):
929
+ result_func = exp.xor
930
+ # Do not deduplicate XOR as A XOR A != A if A == True
931
+ deduped = None
932
+ arr = tuple((gen(e), e) for e in flattened)
933
+ else:
934
+ result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
935
+ deduped = {gen(e): e for e in flattened}
936
+ arr = tuple(deduped.items())
937
+
938
+ # check if the operands are already sorted, if not sort them
939
+ # A AND C AND B -> A AND B AND C
940
+ for i, (sql, e) in enumerate(arr[1:]):
941
+ if sql < arr[i][0]:
942
+ expression = result_func(*(e for _, e in sorted(arr)), copy=False)
1328
943
  break
1329
944
  else:
1330
- operands.append(a)
945
+ # we didn't have to sort but maybe we need to dedup
946
+ if deduped and len(deduped) < len(flattened):
947
+ unique_operand = flattened[0]
948
+ if len(deduped) == 1:
949
+ expression = unique_operand.and_(exp.true(), copy=False)
950
+ else:
951
+ expression = result_func(*deduped.values(), copy=False)
952
+
953
+ return expression
954
+
955
+ @annotate_types_on_change
956
+ def absorb_and_eliminate(self, expression, root=True):
957
+ """
958
+ absorption:
959
+ A AND (A OR B) -> A
960
+ A OR (A AND B) -> A
961
+ A AND (NOT A OR B) -> A AND B
962
+ A OR (NOT A AND B) -> A OR B
963
+ elimination:
964
+ (A AND B) OR (A AND NOT B) -> A
965
+ (A OR B) AND (A OR NOT B) -> A
966
+ """
967
+ if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
968
+ kind = exp.Or if isinstance(expression, exp.And) else exp.And
969
+
970
+ ops = tuple(expression.flatten())
971
+
972
+ # Initialize lookup tables:
973
+ # Set of all operands, used to find complements for absorption.
974
+ op_set = set()
975
+ # Sub-operands, used to find subsets for absorption.
976
+ subops = defaultdict(list)
977
+ # Pairs of complements, used for elimination.
978
+ pairs = defaultdict(list)
979
+
980
+ # Populate the lookup tables
981
+ for op in ops:
982
+ op_set.add(op)
983
+
984
+ if not isinstance(op, kind):
985
+ # In cases like: A OR (A AND B)
986
+ # Subop will be: ^
987
+ subops[op].append({op})
988
+ continue
989
+
990
+ # In cases like: (A AND B) OR (A AND B AND C)
991
+ # Subops will be: ^ ^
992
+ subset = set(op.flatten())
993
+ for i in subset:
994
+ subops[i].append(subset)
995
+
996
+ a, b = op.unnest_operands()
997
+ if isinstance(a, exp.Not):
998
+ pairs[frozenset((a.this, b))].append((op, b))
999
+ if isinstance(b, exp.Not):
1000
+ pairs[frozenset((a, b.this))].append((op, a))
1001
+
1002
+ for op in ops:
1003
+ if not isinstance(op, kind):
1004
+ continue
1005
+
1006
+ a, b = op.unnest_operands()
1007
+
1008
+ # Absorb
1009
+ if isinstance(a, exp.Not) and a.this in op_set:
1010
+ a.replace(exp.true() if kind == exp.And else exp.false())
1011
+ continue
1012
+ if isinstance(b, exp.Not) and b.this in op_set:
1013
+ b.replace(exp.true() if kind == exp.And else exp.false())
1014
+ continue
1015
+ superset = set(op.flatten())
1016
+ if any(any(subset < superset for subset in subops[i]) for i in superset):
1017
+ op.replace(exp.false() if kind == exp.And else exp.true())
1018
+ continue
1019
+
1020
+ # Eliminate
1021
+ for other, complement in pairs[frozenset((a, b))]:
1022
+ op.replace(complement)
1023
+ other.replace(complement)
1024
+
1025
+ return expression
1026
+
1027
+ @annotate_types_on_change
1028
+ @catch(ModuleNotFoundError, UnsupportedUnit)
1029
+ def simplify_equality(self, expression: exp.Expression) -> exp.Expression:
1030
+ """
1031
+ Use the subtraction and addition properties of equality to simplify expressions:
1032
+
1033
+ x + 1 = 3 becomes x = 2
1034
+
1035
+ There are two binary operations in the above expression: + and =
1036
+ Here's how we reference all the operands in the code below:
1037
+
1038
+ l r
1039
+ x + 1 = 3
1040
+ a b
1041
+ """
1042
+ if isinstance(expression, self.COMPARISONS):
1043
+ l, r = expression.left, expression.right
1044
+
1045
+ if l.__class__ not in self.INVERSE_OPS:
1046
+ return expression
1047
+
1048
+ if r.is_number:
1049
+ a_predicate = _is_number
1050
+ b_predicate = _is_number
1051
+ elif _is_date_literal(r):
1052
+ a_predicate = _is_date_literal
1053
+ b_predicate = _is_interval
1054
+ else:
1055
+ return expression
1056
+
1057
+ if l.__class__ in self.INVERSE_DATE_OPS:
1058
+ l = t.cast(exp.IntervalOp, l)
1059
+ a = l.this
1060
+ b = l.interval()
1061
+ else:
1062
+ l = t.cast(exp.Binary, l)
1063
+ a, b = l.left, l.right
1331
1064
 
1332
- if len(operands) < size:
1333
- return functools.reduce(
1334
- lambda a, b: expression.__class__(this=a, expression=b), operands
1065
+ if not a_predicate(a) and b_predicate(b):
1066
+ pass
1067
+ elif not a_predicate(b) and b_predicate(a):
1068
+ a, b = b, a
1069
+ else:
1070
+ return expression
1071
+
1072
+ return expression.__class__(
1073
+ this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
1335
1074
  )
1336
- return expression
1075
+ return expression
1076
+
1077
+ @annotate_types_on_change
1078
+ def simplify_literals(self, expression, root=True):
1079
+ if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
1080
+ return self._flat_simplify(expression, self._simplify_binary, root)
1081
+
1082
+ if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
1083
+ return expression.this.this
1084
+
1085
+ if type(expression) in self.INVERSE_DATE_OPS:
1086
+ return (
1087
+ self._simplify_binary(expression, expression.this, expression.interval())
1088
+ or expression
1089
+ )
1090
+
1091
+ return expression
1092
+
1093
+ def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression:
1094
+ if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
1095
+ this = self._simplify_integer_cast(expr.this)
1096
+ else:
1097
+ this = expr.this
1098
+
1099
+ if isinstance(expr, exp.Cast) and this.is_int:
1100
+ num = this.to_py()
1101
+
1102
+ # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
1103
+ # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
1104
+ # engine-dependent
1105
+ if (
1106
+ self.TINYINT_MIN <= num <= self.TINYINT_MAX
1107
+ and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
1108
+ ) or (
1109
+ self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
1110
+ and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
1111
+ ):
1112
+ return this
1113
+
1114
+ return expr
1115
+
1116
+ def _simplify_binary(self, expression, a, b):
1117
+ if isinstance(expression, self.COMPARISONS):
1118
+ a = self._simplify_integer_cast(a)
1119
+ b = self._simplify_integer_cast(b)
1120
+
1121
+ if isinstance(expression, exp.Is):
1122
+ if isinstance(b, exp.Not):
1123
+ c = b.this
1124
+ not_ = True
1125
+ else:
1126
+ c = b
1127
+ not_ = False
1128
+
1129
+ if is_null(c):
1130
+ if isinstance(a, exp.Literal):
1131
+ return exp.true() if not_ else exp.false()
1132
+ if is_null(a):
1133
+ return exp.false() if not_ else exp.true()
1134
+ elif isinstance(expression, self.NULL_OK):
1135
+ return None
1136
+ elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
1137
+ return exp.null()
1138
+
1139
+ if a.is_number and b.is_number:
1140
+ num_a = a.to_py()
1141
+ num_b = b.to_py()
1142
+
1143
+ if isinstance(expression, exp.Add):
1144
+ return exp.Literal.number(num_a + num_b)
1145
+ if isinstance(expression, exp.Mul):
1146
+ return exp.Literal.number(num_a * num_b)
1147
+
1148
+ # We only simplify Sub, Div if a and b have the same parent because they're not associative
1149
+ if isinstance(expression, exp.Sub):
1150
+ return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
1151
+ if isinstance(expression, exp.Div):
1152
+ # engines have differing int div behavior so intdiv is not safe
1153
+ if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
1154
+ return None
1155
+ return exp.Literal.number(num_a / num_b)
1156
+
1157
+ boolean = eval_boolean(expression, num_a, num_b)
1158
+
1159
+ if boolean:
1160
+ return boolean
1161
+ elif a.is_string and b.is_string:
1162
+ boolean = eval_boolean(expression, a.this, b.this)
1163
+
1164
+ if boolean:
1165
+ return boolean
1166
+ elif _is_date_literal(a) and isinstance(b, exp.Interval):
1167
+ date, b = extract_date(a), extract_interval(b)
1168
+ if date and b:
1169
+ if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
1170
+ return date_literal(date + b, extract_type(a))
1171
+ if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
1172
+ return date_literal(date - b, extract_type(a))
1173
+ elif isinstance(a, exp.Interval) and _is_date_literal(b):
1174
+ a, date = extract_interval(a), extract_date(b)
1175
+ # you cannot subtract a date from an interval
1176
+ if a and b and isinstance(expression, exp.Add):
1177
+ return date_literal(a + date, extract_type(b))
1178
+ elif _is_date_literal(a) and _is_date_literal(b):
1179
+ if isinstance(expression, exp.Predicate):
1180
+ a, b = extract_date(a), extract_date(b)
1181
+ boolean = eval_boolean(expression, a, b)
1182
+ if boolean:
1183
+ return boolean
1184
+
1185
+ return None
1186
+
1187
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
    """
    Simplify COALESCE calls and comparisons that have a COALESCE operand.

    A COALESCE with no extra arguments, or whose first argument is a non-null
    constant, collapses to its first argument. A comparison between a COALESCE
    and a constant is expanded into an OR of two AND branches keyed on whether
    the (truncated) COALESCE is NULL.
    """
    # COALESCE(x) -> x
    if isinstance(expression, exp.Coalesce):
        reducible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if reducible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD or not isinstance(
        expression, self.COMPARISONS
    ):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Locate the first constant argument among the fallback arguments
    first_constant = next(
        (
            (position, candidate)
            for position, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if first_constant is None:
        return expression

    arg_index, arg = first_constant

    # Keep only the fallback arguments that precede the constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Unwrap the COALESCE when no fallback arguments remain. This skips a simplify
    # iteration, since a bare COALESCE is already removed at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    # The result is more complex than the input, but it will get simplified further
    when_not_null = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False), copy=False)
1248
+
1249
@annotate_types_on_change
def simplify_concat(self, expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, self.CONCATS):
        return expression

    has_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if has_separator and not expression.expressions[0].is_string:
        return expression

    if has_separator:
        sep_expr, *operands = expression.expressions
        separator = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    # Fall back to the flattened operands when there is no explicit argument list
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for literal_run, members in runs:
        if literal_run:
            joined = separator.join(member.name for member in members)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(members)

    # Everything folded into a single string literal
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if has_separator:
        return exp.ConcatWs(expressions=[sep_expr] + merged, **extra_args)

    if isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return exp.Concat(expressions=merged, **extra_args)
1290
+
1291
@annotate_types_on_change
def simplify_conditionals(self, expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For a CASE, branches are examined in order: the result of the first branch
    whose condition is always true is returned, and branches whose condition is
    always false are removed. If every branch is removed, the default (or NULL)
    is returned. For a standalone IF (one that is not a branch of a CASE), the
    true or false result is returned when the condition is statically known.
    Otherwise the (possibly pruned) expression is returned.
    """
    if isinstance(expression, exp.Case):
        this = expression.this

        # Iterate over a snapshot of the branches: `case.pop()` below appears to
        # remove the branch from `expression.args["ifs"]` (the emptiness check
        # relies on that), and removing items from a list while iterating over
        # it makes the loop skip the branch that follows each removed one.
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
1316
+
1317
@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into a boolean when both of its operands are string literals.

    Any other expression is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    subject, prefix = expression.this, expression.expression
    if not (subject.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
1336
+
1337
def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool:
    """Check that `left` is one of `self.DATETRUNCS` and `right` is a date literal."""
    if not isinstance(left, self.DATETRUNCS):
        return False
    return _is_date_literal(right)
1339
+
1340
@annotate_types_on_change
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, self.DATETRUNCS):
        # A truncation applied directly to a date literal can be folded in place
        operand = expression.this
        trunc_type = extract_type(operand)
        literal_date = extract_date(operand)
        if literal_date and expression.unit:
            floored = date_floor(literal_date, expression.unit.name.lower(), self.dialect)
            return date_literal(floored, trunc_type)
    elif comparison not in self.DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        lhs, rhs = expression.left, expression.right

        if not self._is_datetrunc_predicate(lhs, rhs):
            return expression

        lhs = t.cast(exp.DateTrunc, lhs)
        trunc_arg = lhs.this
        unit = lhs.unit.name.lower()
        target_date = extract_date(rhs)

        if not target_date:
            return expression

        rewrite = self.DATETRUNC_BINARY_COMPARISONS[comparison]
        # Keep the original expression when the rewrite yields nothing
        return rewrite(trunc_arg, target_date, unit, self.dialect, extract_type(rhs)) or expression

    if isinstance(expression, exp.In):
        lhs = expression.this
        candidates = expression.expressions

        if not candidates or not all(
            self._is_datetrunc_predicate(lhs, candidate) for candidate in candidates
        ):
            return expression

        lhs = t.cast(exp.DateTrunc, lhs)
        unit = lhs.unit.name.lower()

        ranges = []
        for candidate in candidates:
            candidate_date = extract_date(candidate)
            if not candidate_date:
                return expression
            drange = _datetrunc_range(candidate_date, unit, self.dialect)
            if drange:
                ranges.append(drange)

        if not ranges:
            return expression

        ranges = merge_ranges(ranges)
        target_type = extract_type(*candidates)

        # One range check per merged range, OR-ed together
        range_checks = [_datetrunc_eq_expression(lhs, drange, target_type) for drange in ranges]
        return exp.or_(*range_checks, copy=False)

    return expression
1407
+
1408
@annotate_types_on_change
def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Only classes listed in `self.COMPLEMENT_COMPARISONS` are considered. The
    operands are swapped (using the operator from `self.INVERSE_COMPARISONS`
    when one is registered) if the right side is a column and the left is not,
    the left side is a constant and the right is not, or the generated text of
    the left side sorts after that of the right side.
    """
    kind = expression.__class__
    if kind not in self.COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    should_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not should_swap:
        return expression

    swapped_kind = self.INVERSE_COMPARISONS.get(kind, kind)
    return swapped_kind(this=rhs, expression=lhs)
1428
+
1429
def _flat_simplify(self, expression, simplifier, root=True):
    """
    Pairwise-combine the flattened operands of `expression` using `simplifier`.

    `simplifier(expression, a, b)` is called for an operand `a` and each operand
    `b` still waiting behind it. A truthy result other than `expression` itself
    replaces both operands and is retried from the front. When at least one pair
    was combined, the surviving operands are folded back into a chain of
    `expression.__class__` nodes; otherwise `expression` is returned unchanged.
    Nothing is done when `root` is false and `expression.same_parent` is truthy.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    initial_count = len(pending)
    settled = []

    while pending:
        head = pending.popleft()

        for candidate in pending:
            combined = simplifier(expression, head, candidate)

            if combined and combined is not expression:
                # Leave the loop immediately: the deque was just modified
                pending.remove(candidate)
                pending.appendleft(combined)
                break
        else:
            # `head` could not be combined with any remaining operand
            settled.append(head)

    if len(settled) >= initial_count:
        return expression

    node_type = expression.__class__
    return functools.reduce(lambda a, b: node_type(this=a, expression=b), settled)
1337
1453
 
1338
1454
 
1339
1455
  def gen(expression: t.Any, comments: bool = False) -> str: