sqlglot 27.29.0__py3-none-any.whl → 28.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. sqlglot/__main__.py +6 -4
  2. sqlglot/_version.py +2 -2
  3. sqlglot/dialects/bigquery.py +116 -295
  4. sqlglot/dialects/clickhouse.py +67 -2
  5. sqlglot/dialects/databricks.py +38 -1
  6. sqlglot/dialects/dialect.py +327 -286
  7. sqlglot/dialects/dremio.py +4 -1
  8. sqlglot/dialects/duckdb.py +718 -22
  9. sqlglot/dialects/exasol.py +243 -10
  10. sqlglot/dialects/hive.py +8 -8
  11. sqlglot/dialects/mysql.py +11 -2
  12. sqlglot/dialects/oracle.py +29 -0
  13. sqlglot/dialects/postgres.py +46 -24
  14. sqlglot/dialects/presto.py +47 -16
  15. sqlglot/dialects/redshift.py +16 -0
  16. sqlglot/dialects/risingwave.py +3 -0
  17. sqlglot/dialects/singlestore.py +12 -3
  18. sqlglot/dialects/snowflake.py +199 -271
  19. sqlglot/dialects/spark.py +2 -2
  20. sqlglot/dialects/spark2.py +11 -48
  21. sqlglot/dialects/sqlite.py +9 -0
  22. sqlglot/dialects/teradata.py +5 -8
  23. sqlglot/dialects/trino.py +6 -0
  24. sqlglot/dialects/tsql.py +61 -25
  25. sqlglot/diff.py +4 -2
  26. sqlglot/errors.py +69 -0
  27. sqlglot/expressions.py +484 -84
  28. sqlglot/generator.py +143 -41
  29. sqlglot/helper.py +2 -2
  30. sqlglot/optimizer/annotate_types.py +247 -140
  31. sqlglot/optimizer/canonicalize.py +6 -1
  32. sqlglot/optimizer/eliminate_joins.py +1 -1
  33. sqlglot/optimizer/eliminate_subqueries.py +2 -2
  34. sqlglot/optimizer/merge_subqueries.py +5 -5
  35. sqlglot/optimizer/normalize.py +20 -13
  36. sqlglot/optimizer/normalize_identifiers.py +17 -3
  37. sqlglot/optimizer/optimizer.py +4 -0
  38. sqlglot/optimizer/pushdown_predicates.py +1 -1
  39. sqlglot/optimizer/qualify.py +14 -6
  40. sqlglot/optimizer/qualify_columns.py +113 -352
  41. sqlglot/optimizer/qualify_tables.py +112 -70
  42. sqlglot/optimizer/resolver.py +374 -0
  43. sqlglot/optimizer/scope.py +27 -16
  44. sqlglot/optimizer/simplify.py +1074 -964
  45. sqlglot/optimizer/unnest_subqueries.py +12 -2
  46. sqlglot/parser.py +276 -160
  47. sqlglot/planner.py +2 -2
  48. sqlglot/schema.py +15 -4
  49. sqlglot/tokens.py +42 -7
  50. sqlglot/transforms.py +77 -22
  51. sqlglot/typing/__init__.py +316 -0
  52. sqlglot/typing/bigquery.py +376 -0
  53. sqlglot/typing/hive.py +12 -0
  54. sqlglot/typing/presto.py +24 -0
  55. sqlglot/typing/snowflake.py +505 -0
  56. sqlglot/typing/spark2.py +58 -0
  57. sqlglot/typing/tsql.py +9 -0
  58. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/METADATA +2 -2
  59. sqlglot-28.4.0.dist-info/RECORD +92 -0
  60. sqlglot-27.29.0.dist-info/RECORD +0 -84
  61. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/WHEEL +0 -0
  62. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/licenses/LICENSE +0 -0
  63. {sqlglot-27.29.0.dist-info → sqlglot-28.4.0.dist-info}/top_level.txt +0 -0
@@ -6,34 +6,37 @@ import functools
6
6
  import itertools
7
7
  import typing as t
8
8
  from collections import deque, defaultdict
9
- from functools import reduce
9
+ from functools import reduce, wraps
10
10
 
11
11
  import sqlglot
12
12
  from sqlglot import Dialect, exp
13
13
  from sqlglot.helper import first, merge_ranges, while_changing
14
+ from sqlglot.optimizer.annotate_types import TypeAnnotator
14
15
  from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
16
+ from sqlglot.schema import ensure_schema
15
17
 
16
18
  if t.TYPE_CHECKING:
17
19
  from sqlglot.dialects.dialect import DialectType
18
20
 
21
+ DateRange = t.Tuple[datetime.date, datetime.date]
19
22
  DateTruncBinaryTransform = t.Callable[
20
23
  [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
21
24
  ]
22
25
 
26
+
23
27
  logger = logging.getLogger("sqlglot")
24
28
 
29
+
25
30
  # Final means that an expression should not be simplified
26
31
  FINAL = "final"
27
32
 
28
- # Value ranges for byte-sized signed/unsigned integers
29
- TINYINT_MIN = -128
30
- TINYINT_MAX = 127
31
- UTINYINT_MIN = 0
32
- UTINYINT_MAX = 255
33
-
34
-
35
- class UnsupportedUnit(Exception):
36
- pass
33
+ SIMPLIFIABLE = (
34
+ exp.Binary,
35
+ exp.Func,
36
+ exp.Lambda,
37
+ exp.Predicate,
38
+ exp.Unary,
39
+ )
37
40
 
38
41
 
39
42
  def simplify(
@@ -60,96 +63,15 @@ def simplify(
60
63
  Returns:
61
64
  sqlglot.Expression: simplified expression
62
65
  """
66
+ return Simplifier(dialect=dialect).simplify(
67
+ expression,
68
+ constant_propagation=constant_propagation,
69
+ coalesce_simplification=coalesce_simplification,
70
+ )
63
71
 
64
- dialect = Dialect.get_or_raise(dialect)
65
-
66
- def _simplify(expression):
67
- pre_transformation_stack = [expression]
68
- post_transformation_stack = []
69
-
70
- while pre_transformation_stack:
71
- node = pre_transformation_stack.pop()
72
-
73
- if node.meta.get(FINAL):
74
- continue
75
-
76
- # group by expressions cannot be simplified, for example
77
- # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
78
- # the projection must exactly match the group by key
79
- group = node.args.get("group")
80
-
81
- if group and hasattr(node, "selects"):
82
- groups = set(group.expressions)
83
- group.meta[FINAL] = True
84
-
85
- for s in node.selects:
86
- for n in s.walk():
87
- if n in groups:
88
- s.meta[FINAL] = True
89
- break
90
-
91
- having = node.args.get("having")
92
- if having:
93
- for n in having.walk():
94
- if n in groups:
95
- having.meta[FINAL] = True
96
- break
97
-
98
- parent = node.parent
99
- root = node is expression
100
-
101
- new_node = rewrite_between(node)
102
- new_node = uniq_sort(new_node, root)
103
- new_node = absorb_and_eliminate(new_node, root)
104
- new_node = simplify_concat(new_node)
105
- new_node = simplify_conditionals(new_node)
106
-
107
- if constant_propagation:
108
- new_node = propagate_constants(new_node, root)
109
-
110
- if new_node is not node:
111
- node.replace(new_node)
112
-
113
- pre_transformation_stack.extend(
114
- n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
115
- )
116
- post_transformation_stack.append((new_node, parent))
117
-
118
- while post_transformation_stack:
119
- node, parent = post_transformation_stack.pop()
120
- root = node is expression
121
-
122
- # Resets parent, arg_key, index pointers– this is needed because some of the
123
- # previous transformations mutate the AST, leading to an inconsistent state
124
- for k, v in tuple(node.args.items()):
125
- node.set(k, v)
126
-
127
- # Post-order transformations
128
- new_node = simplify_not(node, dialect)
129
- new_node = flatten(new_node)
130
- new_node = simplify_connectors(new_node, root)
131
- new_node = remove_complements(new_node, root)
132
-
133
- if coalesce_simplification:
134
- new_node = simplify_coalesce(new_node, dialect)
135
-
136
- new_node.parent = parent
137
-
138
- new_node = simplify_literals(new_node, root)
139
- new_node = simplify_equality(new_node)
140
- new_node = simplify_parens(new_node, dialect)
141
- new_node = simplify_datetrunc(new_node, dialect)
142
- new_node = sort_comparison(new_node)
143
- new_node = simplify_startswith(new_node)
144
-
145
- if new_node is not node:
146
- node.replace(new_node)
147
-
148
- return new_node
149
72
 
150
- expression = while_changing(expression, _simplify)
151
- remove_where_true(expression)
152
- return expression
73
+ class UnsupportedUnit(Exception):
74
+ pass
153
75
 
154
76
 
155
77
  def catch(*exceptions):
@@ -167,89 +89,30 @@ def catch(*exceptions):
167
89
  return decorator
168
90
 
169
91
 
170
- def rewrite_between(expression: exp.Expression) -> exp.Expression:
171
- """Rewrite x between y and z to x >= y AND x <= z.
172
-
173
- This is done because comparison simplification is only done on lt/lte/gt/gte.
174
- """
175
- if isinstance(expression, exp.Between):
176
- negate = isinstance(expression.parent, exp.Not)
92
+ def annotate_types_on_change(func):
93
+ @wraps(func)
94
+ def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]:
95
+ new_expression = func(self, expression, *args, **kwargs)
177
96
 
178
- expression = exp.and_(
179
- exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
180
- exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
181
- copy=False,
182
- )
97
+ if new_expression is None:
98
+ return new_expression
183
99
 
184
- if negate:
185
- expression = exp.paren(expression, copy=False)
100
+ if self.annotate_new_expressions and expression != new_expression:
101
+ self._annotator.clear()
186
102
 
187
- return expression
103
+ # We annotate this to ensure new children nodes are also annotated
104
+ new_expression = self._annotator.annotate(
105
+ expression=new_expression,
106
+ annotate_scope=False,
107
+ )
188
108
 
109
+ # Whatever expression the original expression is transformed into needs to preserve
110
+ # the original type, otherwise the simplification could result in a different schema
111
+ new_expression.type = expression.type
189
112
 
190
- COMPLEMENT_COMPARISONS = {
191
- exp.LT: exp.GTE,
192
- exp.GT: exp.LTE,
193
- exp.LTE: exp.GT,
194
- exp.GTE: exp.LT,
195
- exp.EQ: exp.NEQ,
196
- exp.NEQ: exp.EQ,
197
- }
113
+ return new_expression
198
114
 
199
- COMPLEMENT_SUBQUERY_PREDICATES = {
200
- exp.All: exp.Any,
201
- exp.Any: exp.All,
202
- }
203
-
204
-
205
- def simplify_not(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
206
- """
207
- Demorgan's Law
208
- NOT (x OR y) -> NOT x AND NOT y
209
- NOT (x AND y) -> NOT x OR NOT y
210
- """
211
- if isinstance(expression, exp.Not):
212
- this = expression.this
213
- if is_null(this):
214
- return exp.null()
215
- if this.__class__ in COMPLEMENT_COMPARISONS:
216
- right = this.expression
217
- complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
218
- if complement_subquery_predicate:
219
- right = complement_subquery_predicate(this=right.this)
220
-
221
- return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
222
- if isinstance(this, exp.Paren):
223
- condition = this.unnest()
224
- if isinstance(condition, exp.And):
225
- return exp.paren(
226
- exp.or_(
227
- exp.not_(condition.left, copy=False),
228
- exp.not_(condition.right, copy=False),
229
- copy=False,
230
- )
231
- )
232
- if isinstance(condition, exp.Or):
233
- return exp.paren(
234
- exp.and_(
235
- exp.not_(condition.left, copy=False),
236
- exp.not_(condition.right, copy=False),
237
- copy=False,
238
- )
239
- )
240
- if is_null(condition):
241
- return exp.null()
242
- if always_true(this):
243
- return exp.false()
244
- if is_false(this):
245
- return exp.true()
246
- if isinstance(this, exp.Not) and dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
247
- inner = this.this
248
- if inner.is_type(exp.DataType.Type.BOOLEAN) or isinstance(inner, exp.Predicate):
249
- # double negation
250
- # NOT NOT x -> x, if x is BOOLEAN type
251
- return inner
252
- return expression
115
+ return _func
253
116
 
254
117
 
255
118
  def flatten(expression):
@@ -265,247 +128,43 @@ def flatten(expression):
265
128
  return expression
266
129
 
267
130
 
268
- def simplify_connectors(expression, root=True):
269
- def _simplify_connectors(expression, left, right):
270
- if isinstance(expression, exp.And):
271
- if is_false(left) or is_false(right):
272
- return exp.false()
273
- if is_zero(left) or is_zero(right):
274
- return exp.false()
275
- if (
276
- (is_null(left) and is_null(right))
277
- or (is_null(left) and always_true(right))
278
- or (always_true(left) and is_null(right))
279
- ):
280
- return exp.null()
281
- if always_true(left) and always_true(right):
282
- return exp.true()
283
- if always_true(left):
284
- return right
285
- if always_true(right):
286
- return left
287
- return _simplify_comparison(expression, left, right)
288
- elif isinstance(expression, exp.Or):
289
- if always_true(left) or always_true(right):
290
- return exp.true()
291
- if (
292
- (is_null(left) and is_null(right))
293
- or (is_null(left) and always_false(right))
294
- or (always_false(left) and is_null(right))
295
- ):
296
- return exp.null()
297
- if is_false(left):
298
- return right
299
- if is_false(right):
300
- return left
301
- return _simplify_comparison(expression, left, right, or_=True)
302
-
303
- if isinstance(expression, exp.Connector):
304
- return _flat_simplify(expression, _simplify_connectors, root)
305
- return expression
306
-
307
-
308
- LT_LTE = (exp.LT, exp.LTE)
309
- GT_GTE = (exp.GT, exp.GTE)
310
-
311
- COMPARISONS = (
312
- *LT_LTE,
313
- *GT_GTE,
314
- exp.EQ,
315
- exp.NEQ,
316
- exp.Is,
317
- )
318
-
319
- INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
320
- exp.LT: exp.GT,
321
- exp.GT: exp.LT,
322
- exp.LTE: exp.GTE,
323
- exp.GTE: exp.LTE,
324
- }
325
-
326
- NONDETERMINISTIC = (exp.Rand, exp.Randn)
327
- AND_OR = (exp.And, exp.Or)
328
-
329
-
330
- def _simplify_comparison(expression, left, right, or_=False):
331
- if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
332
- ll, lr = left.args.values()
333
- rl, rr = right.args.values()
334
-
335
- largs = {ll, lr}
336
- rargs = {rl, rr}
337
-
338
- matching = largs & rargs
339
- columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
340
-
341
- if matching and columns:
342
- try:
343
- l = first(largs - columns)
344
- r = first(rargs - columns)
345
- except StopIteration:
346
- return expression
347
-
348
- if l.is_number and r.is_number:
349
- l = l.to_py()
350
- r = r.to_py()
351
- elif l.is_string and r.is_string:
352
- l = l.name
353
- r = r.name
354
- else:
355
- l = extract_date(l)
356
- if not l:
357
- return None
358
- r = extract_date(r)
359
- if not r:
360
- return None
361
- # python won't compare date and datetime, but many engines will upcast
362
- l, r = cast_as_datetime(l), cast_as_datetime(r)
363
-
364
- for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
365
- if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
366
- return left if (av > bv if or_ else av <= bv) else right
367
- if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
368
- return left if (av < bv if or_ else av >= bv) else right
369
-
370
- # we can't ever shortcut to true because the column could be null
371
- if not or_:
372
- if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
373
- if av <= bv:
374
- return exp.false()
375
- elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
376
- if av >= bv:
377
- return exp.false()
378
- elif isinstance(a, exp.EQ):
379
- if isinstance(b, exp.LT):
380
- return exp.false() if av >= bv else a
381
- if isinstance(b, exp.LTE):
382
- return exp.false() if av > bv else a
383
- if isinstance(b, exp.GT):
384
- return exp.false() if av <= bv else a
385
- if isinstance(b, exp.GTE):
386
- return exp.false() if av < bv else a
387
- if isinstance(b, exp.NEQ):
388
- return exp.false() if av == bv else a
389
- return None
390
-
391
-
392
- def remove_complements(expression, root=True):
393
- """
394
- Removing complements.
395
-
396
- A AND NOT A -> FALSE
397
- A OR NOT A -> TRUE
398
- """
399
- if isinstance(expression, AND_OR) and (root or not expression.same_parent):
400
- ops = set(expression.flatten())
401
- for op in ops:
402
- if isinstance(op, exp.Not) and op.this in ops:
403
- return exp.false() if isinstance(expression, exp.And) else exp.true()
404
-
405
- return expression
406
-
407
-
408
- def uniq_sort(expression, root=True):
409
- """
410
- Uniq and sort a connector.
411
-
412
- C AND A AND B AND B -> A AND B AND C
413
- """
414
- if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
415
- flattened = tuple(expression.flatten())
416
-
417
- if isinstance(expression, exp.Xor):
418
- result_func = exp.xor
419
- # Do not deduplicate XOR as A XOR A != A if A == True
420
- deduped = None
421
- arr = tuple((gen(e), e) for e in flattened)
422
- else:
423
- result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
424
- deduped = {gen(e): e for e in flattened}
425
- arr = tuple(deduped.items())
426
-
427
- # check if the operands are already sorted, if not sort them
428
- # A AND C AND B -> A AND B AND C
429
- for i, (sql, e) in enumerate(arr[1:]):
430
- if sql < arr[i][0]:
431
- expression = result_func(*(e for _, e in sorted(arr)), copy=False)
432
- break
433
- else:
434
- # we didn't have to sort but maybe we need to dedup
435
- if deduped and len(deduped) < len(flattened):
436
- expression = result_func(*deduped.values(), copy=False)
437
-
438
- return expression
439
-
440
-
441
- def absorb_and_eliminate(expression, root=True):
442
- """
443
- absorption:
444
- A AND (A OR B) -> A
445
- A OR (A AND B) -> A
446
- A AND (NOT A OR B) -> A AND B
447
- A OR (NOT A AND B) -> A OR B
448
- elimination:
449
- (A AND B) OR (A AND NOT B) -> A
450
- (A OR B) AND (A OR NOT B) -> A
451
- """
452
- if isinstance(expression, AND_OR) and (root or not expression.same_parent):
453
- kind = exp.Or if isinstance(expression, exp.And) else exp.And
454
-
455
- ops = tuple(expression.flatten())
456
-
457
- # Initialize lookup tables:
458
- # Set of all operands, used to find complements for absorption.
459
- op_set = set()
460
- # Sub-operands, used to find subsets for absorption.
461
- subops = defaultdict(list)
462
- # Pairs of complements, used for elimination.
463
- pairs = defaultdict(list)
464
-
465
- # Populate the lookup tables
466
- for op in ops:
467
- op_set.add(op)
468
-
469
- if not isinstance(op, kind):
470
- # In cases like: A OR (A AND B)
471
- # Subop will be: ^
472
- subops[op].append({op})
473
- continue
474
-
475
- # In cases like: (A AND B) OR (A AND B AND C)
476
- # Subops will be: ^ ^
477
- subset = set(op.flatten())
478
- for i in subset:
479
- subops[i].append(subset)
131
+ def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
132
+ if not isinstance(expression, exp.Paren):
133
+ return expression
480
134
 
481
- a, b = op.unnest_operands()
482
- if isinstance(a, exp.Not):
483
- pairs[frozenset((a.this, b))].append((op, b))
484
- if isinstance(b, exp.Not):
485
- pairs[frozenset((a, b.this))].append((op, a))
135
+ this = expression.this
136
+ parent = expression.parent
137
+ parent_is_predicate = isinstance(parent, exp.Predicate)
486
138
 
487
- for op in ops:
488
- if not isinstance(op, kind):
489
- continue
139
+ if isinstance(this, exp.Select):
140
+ return expression
490
141
 
491
- a, b = op.unnest_operands()
142
+ if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
143
+ return expression
492
144
 
493
- # Absorb
494
- if isinstance(a, exp.Not) and a.this in op_set:
495
- a.replace(exp.true() if kind == exp.And else exp.false())
496
- continue
497
- if isinstance(b, exp.Not) and b.this in op_set:
498
- b.replace(exp.true() if kind == exp.And else exp.false())
499
- continue
500
- superset = set(op.flatten())
501
- if any(any(subset < superset for subset in subops[i]) for i in superset):
502
- op.replace(exp.false() if kind == exp.And else exp.true())
503
- continue
145
+ if (
146
+ Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
147
+ and isinstance(parent, exp.Dot)
148
+ and (isinstance(parent.right, (exp.Identifier, exp.Star)))
149
+ ):
150
+ return expression
504
151
 
505
- # Eliminate
506
- for other, complement in pairs[frozenset((a, b))]:
507
- op.replace(complement)
508
- other.replace(complement)
152
+ if (
153
+ not isinstance(parent, (exp.Condition, exp.Binary))
154
+ or isinstance(parent, exp.Paren)
155
+ or (
156
+ not isinstance(this, exp.Binary)
157
+ and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
158
+ )
159
+ or (
160
+ isinstance(this, exp.Predicate)
161
+ and not (parent_is_predicate or isinstance(parent, exp.Neg))
162
+ )
163
+ or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
164
+ or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
165
+ or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
166
+ ):
167
+ return this
509
168
 
510
169
  return expression
511
170
 
@@ -549,20 +208,6 @@ def propagate_constants(expression, root=True):
549
208
  return expression
550
209
 
551
210
 
552
- INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
553
- exp.DateAdd: exp.Sub,
554
- exp.DateSub: exp.Add,
555
- exp.DatetimeAdd: exp.Sub,
556
- exp.DatetimeSub: exp.Add,
557
- }
558
-
559
- INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
560
- **INVERSE_DATE_OPS,
561
- exp.Add: exp.Sub,
562
- exp.Sub: exp.Add,
563
- }
564
-
565
-
566
211
  def _is_number(expression: exp.Expression) -> bool:
567
212
  return expression.is_number
568
213
 
@@ -571,210 +216,6 @@ def _is_interval(expression: exp.Expression) -> bool:
571
216
  return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
572
217
 
573
218
 
574
- @catch(ModuleNotFoundError, UnsupportedUnit)
575
- def simplify_equality(expression: exp.Expression) -> exp.Expression:
576
- """
577
- Use the subtraction and addition properties of equality to simplify expressions:
578
-
579
- x + 1 = 3 becomes x = 2
580
-
581
- There are two binary operations in the above expression: + and =
582
- Here's how we reference all the operands in the code below:
583
-
584
- l r
585
- x + 1 = 3
586
- a b
587
- """
588
- if isinstance(expression, COMPARISONS):
589
- l, r = expression.left, expression.right
590
-
591
- if l.__class__ not in INVERSE_OPS:
592
- return expression
593
-
594
- if r.is_number:
595
- a_predicate = _is_number
596
- b_predicate = _is_number
597
- elif _is_date_literal(r):
598
- a_predicate = _is_date_literal
599
- b_predicate = _is_interval
600
- else:
601
- return expression
602
-
603
- if l.__class__ in INVERSE_DATE_OPS:
604
- l = t.cast(exp.IntervalOp, l)
605
- a = l.this
606
- b = l.interval()
607
- else:
608
- l = t.cast(exp.Binary, l)
609
- a, b = l.left, l.right
610
-
611
- if not a_predicate(a) and b_predicate(b):
612
- pass
613
- elif not a_predicate(b) and b_predicate(a):
614
- a, b = b, a
615
- else:
616
- return expression
617
-
618
- return expression.__class__(
619
- this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
620
- )
621
- return expression
622
-
623
-
624
- def simplify_literals(expression, root=True):
625
- if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
626
- return _flat_simplify(expression, _simplify_binary, root)
627
-
628
- if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
629
- return expression.this.this
630
-
631
- if type(expression) in INVERSE_DATE_OPS:
632
- return _simplify_binary(expression, expression.this, expression.interval()) or expression
633
-
634
- return expression
635
-
636
-
637
- NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
638
-
639
-
640
- def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
641
- if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
642
- this = _simplify_integer_cast(expr.this)
643
- else:
644
- this = expr.this
645
-
646
- if isinstance(expr, exp.Cast) and this.is_int:
647
- num = this.to_py()
648
-
649
- # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
650
- # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
651
- # engine-dependent
652
- if (
653
- TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
654
- ) or (
655
- UTINYINT_MIN <= num <= UTINYINT_MAX
656
- and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
657
- ):
658
- return this
659
-
660
- return expr
661
-
662
-
663
- def _simplify_binary(expression, a, b):
664
- if isinstance(expression, COMPARISONS):
665
- a = _simplify_integer_cast(a)
666
- b = _simplify_integer_cast(b)
667
-
668
- if isinstance(expression, exp.Is):
669
- if isinstance(b, exp.Not):
670
- c = b.this
671
- not_ = True
672
- else:
673
- c = b
674
- not_ = False
675
-
676
- if is_null(c):
677
- if isinstance(a, exp.Literal):
678
- return exp.true() if not_ else exp.false()
679
- if is_null(a):
680
- return exp.false() if not_ else exp.true()
681
- elif isinstance(expression, NULL_OK):
682
- return None
683
- elif is_null(a) or is_null(b):
684
- return exp.null()
685
-
686
- if a.is_number and b.is_number:
687
- num_a = a.to_py()
688
- num_b = b.to_py()
689
-
690
- if isinstance(expression, exp.Add):
691
- return exp.Literal.number(num_a + num_b)
692
- if isinstance(expression, exp.Mul):
693
- return exp.Literal.number(num_a * num_b)
694
-
695
- # We only simplify Sub, Div if a and b have the same parent because they're not associative
696
- if isinstance(expression, exp.Sub):
697
- return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
698
- if isinstance(expression, exp.Div):
699
- # engines have differing int div behavior so intdiv is not safe
700
- if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
701
- return None
702
- return exp.Literal.number(num_a / num_b)
703
-
704
- boolean = eval_boolean(expression, num_a, num_b)
705
-
706
- if boolean:
707
- return boolean
708
- elif a.is_string and b.is_string:
709
- boolean = eval_boolean(expression, a.this, b.this)
710
-
711
- if boolean:
712
- return boolean
713
- elif _is_date_literal(a) and isinstance(b, exp.Interval):
714
- date, b = extract_date(a), extract_interval(b)
715
- if date and b:
716
- if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
717
- return date_literal(date + b, extract_type(a))
718
- if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
719
- return date_literal(date - b, extract_type(a))
720
- elif isinstance(a, exp.Interval) and _is_date_literal(b):
721
- a, date = extract_interval(a), extract_date(b)
722
- # you cannot subtract a date from an interval
723
- if a and b and isinstance(expression, exp.Add):
724
- return date_literal(a + date, extract_type(b))
725
- elif _is_date_literal(a) and _is_date_literal(b):
726
- if isinstance(expression, exp.Predicate):
727
- a, b = extract_date(a), extract_date(b)
728
- boolean = eval_boolean(expression, a, b)
729
- if boolean:
730
- return boolean
731
-
732
- return None
733
-
734
-
735
- def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
736
- if not isinstance(expression, exp.Paren):
737
- return expression
738
-
739
- this = expression.this
740
- parent = expression.parent
741
- parent_is_predicate = isinstance(parent, exp.Predicate)
742
-
743
- if isinstance(this, exp.Select):
744
- return expression
745
-
746
- if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
747
- return expression
748
-
749
- # Handle risingwave struct columns
750
- # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
751
- if (
752
- dialect == "risingwave"
753
- and isinstance(parent, exp.Dot)
754
- and (isinstance(parent.right, (exp.Identifier, exp.Star)))
755
- ):
756
- return expression
757
-
758
- if (
759
- not isinstance(parent, (exp.Condition, exp.Binary))
760
- or isinstance(parent, exp.Paren)
761
- or (
762
- not isinstance(this, exp.Binary)
763
- and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
764
- )
765
- or (
766
- isinstance(this, exp.Predicate)
767
- and not (parent_is_predicate or isinstance(parent, exp.Neg))
768
- )
769
- or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
770
- or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
771
- or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
772
- ):
773
- return this
774
-
775
- return expression
776
-
777
-
778
219
  def _is_nonnull_constant(expression: exp.Expression) -> bool:
779
220
  return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
780
221
 
@@ -783,164 +224,6 @@ def _is_constant(expression: exp.Expression) -> bool:
783
224
  return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
784
225
 
785
226
 
786
- def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
787
- # COALESCE(x) -> x
788
- if (
789
- isinstance(expression, exp.Coalesce)
790
- and (not expression.expressions or _is_nonnull_constant(expression.this))
791
- # COALESCE is also used as a Spark partitioning hint
792
- and not isinstance(expression.parent, exp.Hint)
793
- ):
794
- return expression.this
795
-
796
- # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
797
- # because they are not always equivalent. For example, if `x` is `NULL` and it comes
798
- # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
799
- if dialect == "redshift":
800
- return expression
801
-
802
- if not isinstance(expression, COMPARISONS):
803
- return expression
804
-
805
- if isinstance(expression.left, exp.Coalesce):
806
- coalesce = expression.left
807
- other = expression.right
808
- elif isinstance(expression.right, exp.Coalesce):
809
- coalesce = expression.right
810
- other = expression.left
811
- else:
812
- return expression
813
-
814
- # This transformation is valid for non-constants,
815
- # but it really only does anything if they are both constants.
816
- if not _is_constant(other):
817
- return expression
818
-
819
- # Find the first constant arg
820
- for arg_index, arg in enumerate(coalesce.expressions):
821
- if _is_constant(arg):
822
- break
823
- else:
824
- return expression
825
-
826
- coalesce.set("expressions", coalesce.expressions[:arg_index])
827
-
828
- # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
829
- # since we already remove COALESCE at the top of this function.
830
- coalesce = coalesce if coalesce.expressions else coalesce.this
831
-
832
- # This expression is more complex than when we started, but it will get simplified further
833
- return exp.paren(
834
- exp.or_(
835
- exp.and_(
836
- coalesce.is_(exp.null()).not_(copy=False),
837
- expression.copy(),
838
- copy=False,
839
- ),
840
- exp.and_(
841
- coalesce.is_(exp.null()),
842
- type(expression)(this=arg.copy(), expression=other.copy()),
843
- copy=False,
844
- ),
845
- copy=False,
846
- )
847
- )
848
-
849
-
850
- CONCATS = (exp.Concat, exp.DPipe)
851
-
852
-
853
- def simplify_concat(expression):
854
- """Reduces all groups that contain string literals by concatenating them."""
855
- if not isinstance(expression, CONCATS) or (
856
- # We can't reduce a CONCAT_WS call if we don't statically know the separator
857
- isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
858
- ):
859
- return expression
860
-
861
- if isinstance(expression, exp.ConcatWs):
862
- sep_expr, *expressions = expression.expressions
863
- sep = sep_expr.name
864
- concat_type = exp.ConcatWs
865
- args = {}
866
- else:
867
- expressions = expression.expressions
868
- sep = ""
869
- concat_type = exp.Concat
870
- args = {
871
- "safe": expression.args.get("safe"),
872
- "coalesce": expression.args.get("coalesce"),
873
- }
874
-
875
- new_args = []
876
- for is_string_group, group in itertools.groupby(
877
- expressions or expression.flatten(), lambda e: e.is_string
878
- ):
879
- if is_string_group:
880
- new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
881
- else:
882
- new_args.extend(group)
883
-
884
- if len(new_args) == 1 and new_args[0].is_string:
885
- return new_args[0]
886
-
887
- if concat_type is exp.ConcatWs:
888
- new_args = [sep_expr] + new_args
889
- elif isinstance(expression, exp.DPipe):
890
- return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
891
-
892
- return concat_type(expressions=new_args, **args)
893
-
894
-
895
- def simplify_conditionals(expression):
896
- """Simplifies expressions like IF, CASE if their condition is statically known."""
897
- if isinstance(expression, exp.Case):
898
- this = expression.this
899
- for case in expression.args["ifs"]:
900
- cond = case.this
901
- if this:
902
- # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
903
- cond = cond.replace(this.pop().eq(cond))
904
-
905
- if always_true(cond):
906
- return case.args["true"]
907
-
908
- if always_false(cond):
909
- case.pop()
910
- if not expression.args["ifs"]:
911
- return expression.args.get("default") or exp.null()
912
- elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
913
- if always_true(expression.this):
914
- return expression.args["true"]
915
- if always_false(expression.this):
916
- return expression.args.get("false") or exp.null()
917
-
918
- return expression
919
-
920
-
921
- def simplify_startswith(expression: exp.Expression) -> exp.Expression:
922
- """
923
- Reduces a prefix check to either TRUE or FALSE if both the string and the
924
- prefix are statically known.
925
-
926
- Example:
927
- >>> from sqlglot import parse_one
928
- >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
929
- 'TRUE'
930
- """
931
- if (
932
- isinstance(expression, exp.StartsWith)
933
- and expression.this.is_string
934
- and expression.expression.is_string
935
- ):
936
- return exp.convert(expression.name.startswith(expression.expression.name))
937
-
938
- return expression
939
-
940
-
941
- DateRange = t.Tuple[datetime.date, datetime.date]
942
-
943
-
944
227
  def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
945
228
  """
946
229
  Get the date range for a DATE_TRUNC equality comparison:
@@ -966,170 +249,40 @@ def _datetrunc_eq_expression(
966
249
  return exp.and_(
967
250
  left >= date_literal(drange[0], target_type),
968
251
  left < date_literal(drange[1], target_type),
969
- copy=False,
970
- )
971
-
972
-
973
- def _datetrunc_eq(
974
- left: exp.Expression,
975
- date: datetime.date,
976
- unit: str,
977
- dialect: Dialect,
978
- target_type: t.Optional[exp.DataType],
979
- ) -> t.Optional[exp.Expression]:
980
- drange = _datetrunc_range(date, unit, dialect)
981
- if not drange:
982
- return None
983
-
984
- return _datetrunc_eq_expression(left, drange, target_type)
985
-
986
-
987
- def _datetrunc_neq(
988
- left: exp.Expression,
989
- date: datetime.date,
990
- unit: str,
991
- dialect: Dialect,
992
- target_type: t.Optional[exp.DataType],
993
- ) -> t.Optional[exp.Expression]:
994
- drange = _datetrunc_range(date, unit, dialect)
995
- if not drange:
996
- return None
997
-
998
- return exp.and_(
999
- left < date_literal(drange[0], target_type),
1000
- left >= date_literal(drange[1], target_type),
1001
- copy=False,
1002
- )
1003
-
1004
-
1005
- DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
1006
- exp.LT: lambda l, dt, u, d, t: l
1007
- < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
1008
- exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
1009
- exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
1010
- exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
1011
- exp.EQ: _datetrunc_eq,
1012
- exp.NEQ: _datetrunc_neq,
1013
- }
1014
- DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
1015
- DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
1016
-
1017
-
1018
- def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
1019
- return isinstance(left, DATETRUNCS) and _is_date_literal(right)
1020
-
1021
-
1022
- @catch(ModuleNotFoundError, UnsupportedUnit)
1023
- def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
1024
- """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
1025
- comparison = expression.__class__
1026
-
1027
- if isinstance(expression, DATETRUNCS):
1028
- this = expression.this
1029
- trunc_type = extract_type(this)
1030
- date = extract_date(this)
1031
- if date and expression.unit:
1032
- return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
1033
- elif comparison not in DATETRUNC_COMPARISONS:
1034
- return expression
1035
-
1036
- if isinstance(expression, exp.Binary):
1037
- l, r = expression.left, expression.right
1038
-
1039
- if not _is_datetrunc_predicate(l, r):
1040
- return expression
1041
-
1042
- l = t.cast(exp.DateTrunc, l)
1043
- trunc_arg = l.this
1044
- unit = l.unit.name.lower()
1045
- date = extract_date(r)
1046
-
1047
- if not date:
1048
- return expression
1049
-
1050
- return (
1051
- DATETRUNC_BINARY_COMPARISONS[comparison](
1052
- trunc_arg, date, unit, dialect, extract_type(r)
1053
- )
1054
- or expression
1055
- )
1056
-
1057
- if isinstance(expression, exp.In):
1058
- l = expression.this
1059
- rs = expression.expressions
1060
-
1061
- if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
1062
- l = t.cast(exp.DateTrunc, l)
1063
- unit = l.unit.name.lower()
1064
-
1065
- ranges = []
1066
- for r in rs:
1067
- date = extract_date(r)
1068
- if not date:
1069
- return expression
1070
- drange = _datetrunc_range(date, unit, dialect)
1071
- if drange:
1072
- ranges.append(drange)
1073
-
1074
- if not ranges:
1075
- return expression
1076
-
1077
- ranges = merge_ranges(ranges)
1078
- target_type = extract_type(*rs)
1079
-
1080
- return exp.or_(
1081
- *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
1082
- )
1083
-
1084
- return expression
1085
-
1086
-
1087
- def sort_comparison(expression: exp.Expression) -> exp.Expression:
1088
- if expression.__class__ in COMPLEMENT_COMPARISONS:
1089
- l, r = expression.this, expression.expression
1090
- l_column = isinstance(l, exp.Column)
1091
- r_column = isinstance(r, exp.Column)
1092
- l_const = _is_constant(l)
1093
- r_const = _is_constant(r)
1094
-
1095
- if (
1096
- (l_column and not r_column)
1097
- or (r_const and not l_const)
1098
- or isinstance(r, exp.SubqueryPredicate)
1099
- ):
1100
- return expression
1101
- if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1102
- return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1103
- this=r, expression=l
1104
- )
1105
- return expression
252
+ copy=False,
253
+ )
1106
254
 
1107
255
 
1108
- # CROSS joins result in an empty table if the right table is empty.
1109
- # So we can only simplify certain types of joins to CROSS.
1110
- # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
1111
- JOINS = {
1112
- ("", ""),
1113
- ("", "INNER"),
1114
- ("RIGHT", ""),
1115
- ("RIGHT", "OUTER"),
1116
- }
256
+ def _datetrunc_eq(
257
+ left: exp.Expression,
258
+ date: datetime.date,
259
+ unit: str,
260
+ dialect: Dialect,
261
+ target_type: t.Optional[exp.DataType],
262
+ ) -> t.Optional[exp.Expression]:
263
+ drange = _datetrunc_range(date, unit, dialect)
264
+ if not drange:
265
+ return None
266
+
267
+ return _datetrunc_eq_expression(left, drange, target_type)
1117
268
 
1118
269
 
1119
- def remove_where_true(expression):
1120
- for where in expression.find_all(exp.Where):
1121
- if always_true(where.this):
1122
- where.pop()
1123
- for join in expression.find_all(exp.Join):
1124
- if (
1125
- always_true(join.args.get("on"))
1126
- and not join.args.get("using")
1127
- and not join.args.get("method")
1128
- and (join.side, join.kind) in JOINS
1129
- ):
1130
- join.args["on"].pop()
1131
- join.set("side", None)
1132
- join.set("kind", "CROSS")
270
+ def _datetrunc_neq(
271
+ left: exp.Expression,
272
+ date: datetime.date,
273
+ unit: str,
274
+ dialect: Dialect,
275
+ target_type: t.Optional[exp.DataType],
276
+ ) -> t.Optional[exp.Expression]:
277
+ drange = _datetrunc_range(date, unit, dialect)
278
+ if not drange:
279
+ return None
280
+
281
+ return exp.and_(
282
+ left < date_literal(drange[0], target_type),
283
+ left >= date_literal(drange[1], target_type),
284
+ copy=False,
285
+ )
1133
286
 
1134
287
 
1135
288
  def always_true(expression):
@@ -1316,30 +469,987 @@ def boolean_literal(condition):
1316
469
  return exp.true() if condition else exp.false()
1317
470
 
1318
471
 
1319
- def _flat_simplify(expression, simplifier, root=True):
1320
- if root or not expression.same_parent:
1321
- operands = []
1322
- queue = deque(expression.flatten(unnest=False))
1323
- size = len(queue)
472
+ class Simplifier:
473
+ def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
474
+ self.dialect = Dialect.get_or_raise(dialect)
475
+ self.annotate_new_expressions = annotate_new_expressions
476
+
477
+ self._annotator: TypeAnnotator = TypeAnnotator(
478
+ schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
479
+ )
480
+
481
+ # Value ranges for byte-sized signed/unsigned integers
482
+ TINYINT_MIN = -128
483
+ TINYINT_MAX = 127
484
+ UTINYINT_MIN = 0
485
+ UTINYINT_MAX = 255
486
+
487
+ COMPLEMENT_COMPARISONS = {
488
+ exp.LT: exp.GTE,
489
+ exp.GT: exp.LTE,
490
+ exp.LTE: exp.GT,
491
+ exp.GTE: exp.LT,
492
+ exp.EQ: exp.NEQ,
493
+ exp.NEQ: exp.EQ,
494
+ }
495
+
496
+ COMPLEMENT_SUBQUERY_PREDICATES = {
497
+ exp.All: exp.Any,
498
+ exp.Any: exp.All,
499
+ }
500
+
501
+ LT_LTE = (exp.LT, exp.LTE)
502
+ GT_GTE = (exp.GT, exp.GTE)
503
+
504
+ COMPARISONS = (
505
+ *LT_LTE,
506
+ *GT_GTE,
507
+ exp.EQ,
508
+ exp.NEQ,
509
+ exp.Is,
510
+ )
511
+
512
+ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
513
+ exp.LT: exp.GT,
514
+ exp.GT: exp.LT,
515
+ exp.LTE: exp.GTE,
516
+ exp.GTE: exp.LTE,
517
+ }
518
+
519
+ NONDETERMINISTIC = (exp.Rand, exp.Randn)
520
+ AND_OR = (exp.And, exp.Or)
521
+
522
+ INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
523
+ exp.DateAdd: exp.Sub,
524
+ exp.DateSub: exp.Add,
525
+ exp.DatetimeAdd: exp.Sub,
526
+ exp.DatetimeSub: exp.Add,
527
+ }
528
+
529
+ INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
530
+ **INVERSE_DATE_OPS,
531
+ exp.Add: exp.Sub,
532
+ exp.Sub: exp.Add,
533
+ }
534
+
535
+ NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
536
+
537
+ CONCATS = (exp.Concat, exp.DPipe)
538
+
539
+ DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
540
+ exp.LT: lambda l, dt, u, d, t: l
541
+ < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
542
+ exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
543
+ exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
544
+ exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
545
+ exp.EQ: _datetrunc_eq,
546
+ exp.NEQ: _datetrunc_neq,
547
+ }
548
+
549
+ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
550
+ DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
551
+
552
+ SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean)
553
+
554
+ # CROSS joins result in an empty table if the right table is empty.
555
+ # So we can only simplify certain types of joins to CROSS.
556
+ # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
557
+ JOINS = {
558
+ ("", ""),
559
+ ("", "INNER"),
560
+ ("RIGHT", ""),
561
+ ("RIGHT", "OUTER"),
562
+ }
563
+
564
+ def simplify(
565
+ self,
566
+ expression: exp.Expression,
567
+ constant_propagation: bool = False,
568
+ coalesce_simplification: bool = False,
569
+ ):
570
+ wheres = []
571
+ joins = []
572
+
573
+ for node in expression.walk(
574
+ prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
575
+ ):
576
+ if node.meta.get(FINAL):
577
+ continue
578
+
579
+ # group by expressions cannot be simplified, for example
580
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
581
+ # the projection must exactly match the group by key
582
+ group = node.args.get("group")
583
+
584
+ if group and hasattr(node, "selects"):
585
+ groups = set(group.expressions)
586
+ group.meta[FINAL] = True
587
+
588
+ for s in node.selects:
589
+ for n in s.walk(FINAL):
590
+ if n in groups:
591
+ s.meta[FINAL] = True
592
+ break
593
+
594
+ having = node.args.get("having")
595
+
596
+ if having:
597
+ for n in having.walk():
598
+ if n in groups:
599
+ having.meta[FINAL] = True
600
+ break
601
+
602
+ if isinstance(node, exp.Condition):
603
+ simplified = while_changing(
604
+ node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
605
+ )
606
+
607
+ if node is expression:
608
+ expression = simplified
609
+ elif isinstance(node, exp.Where):
610
+ wheres.append(node)
611
+ elif isinstance(node, exp.Join):
612
+ # snowflake match_conditions have very strict ordering rules
613
+ if match := node.args.get("match_condition"):
614
+ match.meta[FINAL] = True
615
+
616
+ joins.append(node)
617
+
618
+ for where in wheres:
619
+ if always_true(where.this):
620
+ where.pop()
621
+ for join in joins:
622
+ if (
623
+ always_true(join.args.get("on"))
624
+ and not join.args.get("using")
625
+ and not join.args.get("method")
626
+ and (join.side, join.kind) in self.JOINS
627
+ ):
628
+ join.args["on"].pop()
629
+ join.set("side", None)
630
+ join.set("kind", "CROSS")
631
+
632
+ return expression
633
+
634
+ def _simplify(
635
+ self, expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool
636
+ ):
637
+ pre_transformation_stack = [expression]
638
+ post_transformation_stack = []
639
+
640
+ while pre_transformation_stack:
641
+ original = pre_transformation_stack.pop()
642
+ node = original
643
+
644
+ if not isinstance(node, SIMPLIFIABLE):
645
+ if isinstance(node, exp.Query):
646
+ self.simplify(node, constant_propagation, coalesce_simplification)
647
+ continue
648
+
649
+ parent = node.parent
650
+ root = node is expression
651
+
652
+ node = self.rewrite_between(node)
653
+ node = self.uniq_sort(node, root)
654
+ node = self.absorb_and_eliminate(node, root)
655
+ node = self.simplify_concat(node)
656
+ node = self.simplify_conditionals(node)
657
+
658
+ if constant_propagation:
659
+ node = propagate_constants(node, root)
660
+
661
+ if node is not original:
662
+ original.replace(node)
663
+
664
+ for n in node.iter_expressions(reverse=True):
665
+ if n.meta.get(FINAL):
666
+ raise
667
+ pre_transformation_stack.extend(
668
+ n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
669
+ )
670
+ post_transformation_stack.append((node, parent))
671
+
672
+ while post_transformation_stack:
673
+ original, parent = post_transformation_stack.pop()
674
+ root = original is expression
675
+
676
+ # Resets parent, arg_key, index pointers– this is needed because some of the
677
+ # previous transformations mutate the AST, leading to an inconsistent state
678
+ for k, v in tuple(original.args.items()):
679
+ original.set(k, v)
680
+
681
+ # Post-order transformations
682
+ node = self.simplify_not(original)
683
+ node = flatten(node)
684
+ node = self.simplify_connectors(node, root)
685
+ node = self.remove_complements(node, root)
686
+
687
+ if coalesce_simplification:
688
+ node = self.simplify_coalesce(node)
689
+ node.parent = parent
690
+
691
+ node = self.simplify_literals(node, root)
692
+ node = self.simplify_equality(node)
693
+ node = simplify_parens(node, dialect=self.dialect)
694
+ node = self.simplify_datetrunc(node)
695
+ node = self.sort_comparison(node)
696
+ node = self.simplify_startswith(node)
697
+
698
+ if node is not original:
699
+ original.replace(node)
700
+
701
+ return node
702
+
703
+ @annotate_types_on_change
704
+ def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
705
+ """Rewrite x between y and z to x >= y AND x <= z.
706
+
707
+ This is done because comparison simplification is only done on lt/lte/gt/gte.
708
+ """
709
+ if isinstance(expression, exp.Between):
710
+ negate = isinstance(expression.parent, exp.Not)
711
+
712
+ expression = exp.and_(
713
+ exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
714
+ exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
715
+ copy=False,
716
+ )
717
+
718
+ if negate:
719
+ expression = exp.paren(expression, copy=False)
720
+
721
+ return expression
722
+
723
+ @annotate_types_on_change
724
+ def simplify_not(self, expression: exp.Expression) -> exp.Expression:
725
+ """
726
+ Demorgan's Law
727
+ NOT (x OR y) -> NOT x AND NOT y
728
+ NOT (x AND y) -> NOT x OR NOT y
729
+ """
730
+ if isinstance(expression, exp.Not):
731
+ this = expression.this
732
+ if is_null(this):
733
+ return exp.and_(exp.null(), exp.true(), copy=False)
734
+ if this.__class__ in self.COMPLEMENT_COMPARISONS:
735
+ right = this.expression
736
+ complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
737
+ right.__class__
738
+ )
739
+ if complement_subquery_predicate:
740
+ right = complement_subquery_predicate(this=right.this)
741
+
742
+ return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
743
+ if isinstance(this, exp.Paren):
744
+ condition = this.unnest()
745
+ if isinstance(condition, exp.And):
746
+ return exp.paren(
747
+ exp.or_(
748
+ exp.not_(condition.left, copy=False),
749
+ exp.not_(condition.right, copy=False),
750
+ copy=False,
751
+ ),
752
+ copy=False,
753
+ )
754
+ if isinstance(condition, exp.Or):
755
+ return exp.paren(
756
+ exp.and_(
757
+ exp.not_(condition.left, copy=False),
758
+ exp.not_(condition.right, copy=False),
759
+ copy=False,
760
+ ),
761
+ copy=False,
762
+ )
763
+ if is_null(condition):
764
+ return exp.and_(exp.null(), exp.true(), copy=False)
765
+ if always_true(this):
766
+ return exp.false()
767
+ if is_false(this):
768
+ return exp.true()
769
+ if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
770
+ inner = this.this
771
+ if inner.is_type(exp.DataType.Type.BOOLEAN):
772
+ # double negation
773
+ # NOT NOT x -> x, if x is BOOLEAN type
774
+ return inner
775
+ return expression
776
+
777
+ @annotate_types_on_change
778
+ def simplify_connectors(self, expression, root=True):
779
+ def _simplify_connectors(expression, left, right):
780
+ if isinstance(expression, exp.And):
781
+ if is_false(left) or is_false(right):
782
+ return exp.false()
783
+ if is_zero(left) or is_zero(right):
784
+ return exp.false()
785
+ if (
786
+ (is_null(left) and is_null(right))
787
+ or (is_null(left) and always_true(right))
788
+ or (always_true(left) and is_null(right))
789
+ ):
790
+ return exp.null()
791
+ if always_true(left) and always_true(right):
792
+ return exp.true()
793
+ if always_true(left):
794
+ return right
795
+ if always_true(right):
796
+ return left
797
+ return self._simplify_comparison(expression, left, right)
798
+ elif isinstance(expression, exp.Or):
799
+ if always_true(left) or always_true(right):
800
+ return exp.true()
801
+ if (
802
+ (is_null(left) and is_null(right))
803
+ or (is_null(left) and always_false(right))
804
+ or (always_false(left) and is_null(right))
805
+ ):
806
+ return exp.null()
807
+ if is_false(left):
808
+ return right
809
+ if is_false(right):
810
+ return left
811
+ return self._simplify_comparison(expression, left, right, or_=True)
812
+
813
+ if isinstance(expression, exp.Connector):
814
+ original_parent = expression.parent
815
+ expression = self._flat_simplify(expression, _simplify_connectors, root)
816
+
817
+ # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
818
+ # to ensure that the resulting type is boolean. We know this is true only for connectors,
819
+ # boolean values and columns that are essentially operands to a connector:
820
+ #
821
+ # A AND (((B)))
822
+ # ~ this is safe to keep because it will eventually be part of another connector
823
+ if not isinstance(
824
+ expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
825
+ ) and not expression.is_type(exp.DataType.Type.BOOLEAN):
826
+ while True:
827
+ if isinstance(original_parent, exp.Connector):
828
+ break
829
+ if not isinstance(original_parent, exp.Paren):
830
+ expression = expression.and_(exp.true(), copy=False)
831
+ break
832
+
833
+ original_parent = original_parent.parent
834
+
835
+ return expression
836
+
837
+ @annotate_types_on_change
838
+ def _simplify_comparison(self, expression, left, right, or_=False):
839
+ if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
840
+ ll, lr = left.args.values()
841
+ rl, rr = right.args.values()
842
+
843
+ largs = {ll, lr}
844
+ rargs = {rl, rr}
845
+
846
+ matching = largs & rargs
847
+ columns = {
848
+ m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
849
+ }
850
+
851
+ if matching and columns:
852
+ try:
853
+ l = first(largs - columns)
854
+ r = first(rargs - columns)
855
+ except StopIteration:
856
+ return expression
857
+
858
+ if l.is_number and r.is_number:
859
+ l = l.to_py()
860
+ r = r.to_py()
861
+ elif l.is_string and r.is_string:
862
+ l = l.name
863
+ r = r.name
864
+ else:
865
+ l = extract_date(l)
866
+ if not l:
867
+ return None
868
+ r = extract_date(r)
869
+ if not r:
870
+ return None
871
+ # python won't compare date and datetime, but many engines will upcast
872
+ l, r = cast_as_datetime(l), cast_as_datetime(r)
873
+
874
+ for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
875
+ if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
876
+ return left if (av > bv if or_ else av <= bv) else right
877
+ if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
878
+ return left if (av < bv if or_ else av >= bv) else right
879
+
880
+ # we can't ever shortcut to true because the column could be null
881
+ if not or_:
882
+ if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
883
+ if av <= bv:
884
+ return exp.false()
885
+ elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
886
+ if av >= bv:
887
+ return exp.false()
888
+ elif isinstance(a, exp.EQ):
889
+ if isinstance(b, exp.LT):
890
+ return exp.false() if av >= bv else a
891
+ if isinstance(b, exp.LTE):
892
+ return exp.false() if av > bv else a
893
+ if isinstance(b, exp.GT):
894
+ return exp.false() if av <= bv else a
895
+ if isinstance(b, exp.GTE):
896
+ return exp.false() if av < bv else a
897
+ if isinstance(b, exp.NEQ):
898
+ return exp.false() if av == bv else a
899
+ return None
1324
900
 
1325
- while queue:
1326
- a = queue.popleft()
901
+ @annotate_types_on_change
902
+ def remove_complements(self, expression, root=True):
903
+ """
904
+ Removing complements.
905
+
906
+ A AND NOT A -> FALSE (only for non-NULL A)
907
+ A OR NOT A -> TRUE (only for non-NULL A)
908
+ """
909
+ if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
910
+ ops = set(expression.flatten())
911
+ for op in ops:
912
+ if isinstance(op, exp.Not) and op.this in ops:
913
+ if expression.meta.get("nonnull") is True:
914
+ return exp.false() if isinstance(expression, exp.And) else exp.true()
1327
915
 
1328
- for b in queue:
1329
- result = simplifier(expression, a, b)
916
+ return expression
1330
917
 
1331
- if result and result is not expression:
1332
- queue.remove(b)
1333
- queue.appendleft(result)
918
+ @annotate_types_on_change
919
+ def uniq_sort(self, expression, root=True):
920
+ """
921
+ Uniq and sort a connector.
922
+
923
+ C AND A AND B AND B -> A AND B AND C
924
+ """
925
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
926
+ flattened = tuple(expression.flatten())
927
+
928
+ if isinstance(expression, exp.Xor):
929
+ result_func = exp.xor
930
+ # Do not deduplicate XOR as A XOR A != A if A == True
931
+ deduped = None
932
+ arr = tuple((gen(e), e) for e in flattened)
933
+ else:
934
+ result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
935
+ deduped = {gen(e): e for e in flattened}
936
+ arr = tuple(deduped.items())
937
+
938
+ # check if the operands are already sorted, if not sort them
939
+ # A AND C AND B -> A AND B AND C
940
+ for i, (sql, e) in enumerate(arr[1:]):
941
+ if sql < arr[i][0]:
942
+ expression = result_func(*(e for _, e in sorted(arr)), copy=False)
1334
943
  break
1335
944
  else:
1336
- operands.append(a)
945
+ # we didn't have to sort but maybe we need to dedup
946
+ if deduped and len(deduped) < len(flattened):
947
+ unique_operand = flattened[0]
948
+ if len(deduped) == 1:
949
+ expression = unique_operand.and_(exp.true(), copy=False)
950
+ else:
951
+ expression = result_func(*deduped.values(), copy=False)
952
+
953
+ return expression
954
+
955
+ @annotate_types_on_change
956
+ def absorb_and_eliminate(self, expression, root=True):
957
+ """
958
+ absorption:
959
+ A AND (A OR B) -> A
960
+ A OR (A AND B) -> A
961
+ A AND (NOT A OR B) -> A AND B
962
+ A OR (NOT A AND B) -> A OR B
963
+ elimination:
964
+ (A AND B) OR (A AND NOT B) -> A
965
+ (A OR B) AND (A OR NOT B) -> A
966
+ """
967
+ if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
968
+ kind = exp.Or if isinstance(expression, exp.And) else exp.And
969
+
970
+ ops = tuple(expression.flatten())
971
+
972
+ # Initialize lookup tables:
973
+ # Set of all operands, used to find complements for absorption.
974
+ op_set = set()
975
+ # Sub-operands, used to find subsets for absorption.
976
+ subops = defaultdict(list)
977
+ # Pairs of complements, used for elimination.
978
+ pairs = defaultdict(list)
979
+
980
+ # Populate the lookup tables
981
+ for op in ops:
982
+ op_set.add(op)
983
+
984
+ if not isinstance(op, kind):
985
+ # In cases like: A OR (A AND B)
986
+ # Subop will be: ^
987
+ subops[op].append({op})
988
+ continue
989
+
990
+ # In cases like: (A AND B) OR (A AND B AND C)
991
+ # Subops will be: ^ ^
992
+ subset = set(op.flatten())
993
+ for i in subset:
994
+ subops[i].append(subset)
995
+
996
+ a, b = op.unnest_operands()
997
+ if isinstance(a, exp.Not):
998
+ pairs[frozenset((a.this, b))].append((op, b))
999
+ if isinstance(b, exp.Not):
1000
+ pairs[frozenset((a, b.this))].append((op, a))
1001
+
1002
+ for op in ops:
1003
+ if not isinstance(op, kind):
1004
+ continue
1005
+
1006
+ a, b = op.unnest_operands()
1007
+
1008
+ # Absorb
1009
+ if isinstance(a, exp.Not) and a.this in op_set:
1010
+ a.replace(exp.true() if kind == exp.And else exp.false())
1011
+ continue
1012
+ if isinstance(b, exp.Not) and b.this in op_set:
1013
+ b.replace(exp.true() if kind == exp.And else exp.false())
1014
+ continue
1015
+ superset = set(op.flatten())
1016
+ if any(any(subset < superset for subset in subops[i]) for i in superset):
1017
+ op.replace(exp.false() if kind == exp.And else exp.true())
1018
+ continue
1019
+
1020
+ # Eliminate
1021
+ for other, complement in pairs[frozenset((a, b))]:
1022
+ op.replace(complement)
1023
+ other.replace(complement)
1024
+
1025
+ return expression
1026
+
1027
+ @annotate_types_on_change
1028
+ @catch(ModuleNotFoundError, UnsupportedUnit)
1029
+ def simplify_equality(self, expression: exp.Expression) -> exp.Expression:
1030
+ """
1031
+ Use the subtraction and addition properties of equality to simplify expressions:
1032
+
1033
+ x + 1 = 3 becomes x = 2
1034
+
1035
+ There are two binary operations in the above expression: + and =
1036
+ Here's how we reference all the operands in the code below:
1037
+
1038
+ l r
1039
+ x + 1 = 3
1040
+ a b
1041
+ """
1042
+ if isinstance(expression, self.COMPARISONS):
1043
+ l, r = expression.left, expression.right
1044
+
1045
+ if l.__class__ not in self.INVERSE_OPS:
1046
+ return expression
1047
+
1048
+ if r.is_number:
1049
+ a_predicate = _is_number
1050
+ b_predicate = _is_number
1051
+ elif _is_date_literal(r):
1052
+ a_predicate = _is_date_literal
1053
+ b_predicate = _is_interval
1054
+ else:
1055
+ return expression
1056
+
1057
+ if l.__class__ in self.INVERSE_DATE_OPS:
1058
+ l = t.cast(exp.IntervalOp, l)
1059
+ a = l.this
1060
+ b = l.interval()
1061
+ else:
1062
+ l = t.cast(exp.Binary, l)
1063
+ a, b = l.left, l.right
1337
1064
 
1338
- if len(operands) < size:
1339
- return functools.reduce(
1340
- lambda a, b: expression.__class__(this=a, expression=b), operands
1065
+ if not a_predicate(a) and b_predicate(b):
1066
+ pass
1067
+ elif not a_predicate(b) and b_predicate(a):
1068
+ a, b = b, a
1069
+ else:
1070
+ return expression
1071
+
1072
+ return expression.__class__(
1073
+ this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
1341
1074
  )
1342
- return expression
1075
+ return expression
1076
+
1077
+ @annotate_types_on_change
1078
+ def simplify_literals(self, expression, root=True):
1079
+ if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
1080
+ return self._flat_simplify(expression, self._simplify_binary, root)
1081
+
1082
+ if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
1083
+ return expression.this.this
1084
+
1085
+ if type(expression) in self.INVERSE_DATE_OPS:
1086
+ return (
1087
+ self._simplify_binary(expression, expression.this, expression.interval())
1088
+ or expression
1089
+ )
1090
+
1091
+ return expression
1092
+
1093
+ def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression:
1094
+ if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
1095
+ this = self._simplify_integer_cast(expr.this)
1096
+ else:
1097
+ this = expr.this
1098
+
1099
+ if isinstance(expr, exp.Cast) and this.is_int:
1100
+ num = this.to_py()
1101
+
1102
+ # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
1103
+ # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
1104
+ # engine-dependent
1105
+ if (
1106
+ self.TINYINT_MIN <= num <= self.TINYINT_MAX
1107
+ and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
1108
+ ) or (
1109
+ self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
1110
+ and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
1111
+ ):
1112
+ return this
1113
+
1114
+ return expr
1115
+
1116
+ def _simplify_binary(self, expression, a, b):
1117
+ if isinstance(expression, self.COMPARISONS):
1118
+ a = self._simplify_integer_cast(a)
1119
+ b = self._simplify_integer_cast(b)
1120
+
1121
+ if isinstance(expression, exp.Is):
1122
+ if isinstance(b, exp.Not):
1123
+ c = b.this
1124
+ not_ = True
1125
+ else:
1126
+ c = b
1127
+ not_ = False
1128
+
1129
+ if is_null(c):
1130
+ if isinstance(a, exp.Literal):
1131
+ return exp.true() if not_ else exp.false()
1132
+ if is_null(a):
1133
+ return exp.false() if not_ else exp.true()
1134
+ elif isinstance(expression, self.NULL_OK):
1135
+ return None
1136
+ elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
1137
+ return exp.null()
1138
+
1139
+ if a.is_number and b.is_number:
1140
+ num_a = a.to_py()
1141
+ num_b = b.to_py()
1142
+
1143
+ if isinstance(expression, exp.Add):
1144
+ return exp.Literal.number(num_a + num_b)
1145
+ if isinstance(expression, exp.Mul):
1146
+ return exp.Literal.number(num_a * num_b)
1147
+
1148
+ # We only simplify Sub, Div if a and b have the same parent because they're not associative
1149
+ if isinstance(expression, exp.Sub):
1150
+ return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
1151
+ if isinstance(expression, exp.Div):
1152
+ # engines have differing int div behavior so intdiv is not safe
1153
+ if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
1154
+ return None
1155
+ return exp.Literal.number(num_a / num_b)
1156
+
1157
+ boolean = eval_boolean(expression, num_a, num_b)
1158
+
1159
+ if boolean:
1160
+ return boolean
1161
+ elif a.is_string and b.is_string:
1162
+ boolean = eval_boolean(expression, a.this, b.this)
1163
+
1164
+ if boolean:
1165
+ return boolean
1166
+ elif _is_date_literal(a) and isinstance(b, exp.Interval):
1167
+ date, b = extract_date(a), extract_interval(b)
1168
+ if date and b:
1169
+ if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
1170
+ return date_literal(date + b, extract_type(a))
1171
+ if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
1172
+ return date_literal(date - b, extract_type(a))
1173
+ elif isinstance(a, exp.Interval) and _is_date_literal(b):
1174
+ a, date = extract_interval(a), extract_date(b)
1175
+ # you cannot subtract a date from an interval
1176
+ if a and b and isinstance(expression, exp.Add):
1177
+ return date_literal(a + date, extract_type(b))
1178
+ elif _is_date_literal(a) and _is_date_literal(b):
1179
+ if isinstance(expression, exp.Predicate):
1180
+ a, b = extract_date(a), extract_date(b)
1181
+ boolean = eval_boolean(expression, a, b)
1182
+ if boolean:
1183
+ return boolean
1184
+
1185
+ return None
1186
+
1187
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
    """Simplify COALESCE calls and comparisons that have a COALESCE operand.

    Two rewrites are applied:

    * ``COALESCE(x)`` (or a COALESCE whose first argument is a non-null
      constant) collapses to its first argument, unless the node is the
      child of a hint.
    * A comparison between a COALESCE and a constant, where the COALESCE
      has a constant argument ``c``, is split into a NULL / NOT NULL case:
      ``COALESCE(x, c) op k`` becomes
      ``(x IS NOT NULL AND x op k) OR (x IS NULL AND c op k)``.
      This second rewrite is skipped when the dialect sets
      ``COALESCE_COMPARISON_NON_STANDARD``.

    Args:
        expression: The node to simplify. The COALESCE operand is truncated
            in place when the comparison rewrite applies.

    Returns:
        The simplified node, or ``expression`` itself when nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
        return expression

    if not isinstance(expression, self.COMPARISONS):
        return expression

    coalesce_on_left = isinstance(expression.left, exp.Coalesce)
    if coalesce_on_left:
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the constant fallback on the same side the COALESCE was on. Swapping the
    # operands would change the meaning of any non-commutative comparison,
    # e.g. `1 < COALESCE(x, 2)` must fall back to `1 < 2`, not `2 < 1`.
    if coalesce_on_left:
        null_case = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        null_case = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                null_case,
                copy=False,
            ),
            copy=False,
        ),
        copy=False,
    )
1248
+
1249
@annotate_types_on_change
def simplify_concat(self, expression):
    """Fold runs of adjacent string literals inside a concatenation.

    Handles the node types listed in ``self.CONCATS``. For ``CONCAT_WS`` the
    literals of a run are joined with the separator, which must itself be a
    string literal; otherwise the node is returned untouched. If everything
    folds into a single string literal, that literal is returned on its own.
    """
    if not isinstance(expression, self.CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        joiner = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        joiner = ""
        extra_args = {key: expression.args.get(key) for key in ("safe", "coalesce")}

    folded = []
    runs = itertools.groupby(operands or expression.flatten(), lambda node: node.is_string)
    for literal_run, members in runs:
        if literal_run:
            text = joiner.join(member.name for member in members)
            folded.append(exp.Literal.string(text))
        else:
            folded.extend(members)

    # Everything collapsed into one literal: the concatenation itself is redundant
    if len(folded) == 1 and folded[0].is_string:
        return folded[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[sep_expr, *folded])

    if isinstance(expression, exp.DPipe):
        # Rebuild the left-deep chain of binary `||` nodes
        return reduce(lambda lhs, rhs: exp.DPipe(this=lhs, expression=rhs), folded)

    return exp.Concat(expressions=folded, **extra_args)
1290
+
1291
@annotate_types_on_change
def simplify_conditionals(self, expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For a CASE node, branches whose condition is always false are removed, and
    a branch whose condition is always true replaces the whole CASE, but only
    when no earlier branch is still undecided, since CASE picks the first
    matching branch. ``CASE x WHEN v ...`` is first rewritten to
    ``CASE WHEN x = v ...``. If every branch is removed, the default (or NULL)
    is returned.

    For a standalone IF node (one whose parent is not a CASE), the true or
    false result is returned when the condition is statically known.

    Returns:
        The replacement node, or ``expression`` (possibly mutated in place)
        when it cannot be fully resolved.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # True once a branch is kept whose condition is not statically known
        has_undecided_branch = False

        # Iterate over a snapshot: `case.pop()` removes the branch from the live
        # list, which would otherwise make the loop skip the following branch.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                # Only hoist the result if no earlier branch could match first
                if not has_undecided_branch:
                    return case.args["true"]
                has_undecided_branch = True
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                has_undecided_branch = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
1316
+
1317
@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
    """
    Folds a STARTSWITH call into a boolean constant when both of its operands
    are string literals; any other node is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))
1336
+
1337
def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool:
    """Tell whether `left` is one of `self.DATETRUNCS` and `right` is a date literal."""
    if not isinstance(left, self.DATETRUNCS):
        return False
    return _is_date_literal(right)
1339
+
1340
@annotate_types_on_change
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled:

    * a date truncation applied to a date literal is folded into the floored literal;
    * a binary comparison between a date truncation and a date literal is rewritten
      through the matching entry of ``self.DATETRUNC_BINARY_COMPARISONS``;
    * ``DATE_TRUNC(...) IN (<date literals>)`` is rewritten into an OR over the
      merged ranges covered by those literals.

    Any other node, or one that cannot be rewritten, is returned as is.
    """
    node_type = expression.__class__

    if isinstance(expression, self.DATETRUNCS):
        operand = expression.this
        result_type = extract_type(operand)
        constant = extract_date(operand)
        if constant and expression.unit:
            floored = date_floor(constant, expression.unit.name.lower(), self.dialect)
            return date_literal(floored, result_type)
    elif node_type not in self.DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        lhs, rhs = expression.left, expression.right

        if not self._is_datetrunc_predicate(lhs, rhs):
            return expression

        trunc = t.cast(exp.DateTrunc, lhs)
        truncated = trunc.this
        unit = trunc.unit.name.lower()
        boundary = extract_date(rhs)

        if not boundary:
            return expression

        rewrite = self.DATETRUNC_BINARY_COMPARISONS[node_type]
        # Fall back to the original node when the rewrite yields nothing
        return rewrite(truncated, boundary, unit, self.dialect, extract_type(rhs)) or expression

    if not isinstance(expression, exp.In):
        return expression

    lhs = expression.this
    candidates = expression.expressions

    if not candidates or not all(self._is_datetrunc_predicate(lhs, c) for c in candidates):
        return expression

    trunc = t.cast(exp.DateTrunc, lhs)
    unit = trunc.unit.name.lower()

    ranges = []
    for candidate in candidates:
        value = extract_date(candidate)
        if not value:
            return expression

        span = _datetrunc_range(value, unit, self.dialect)
        # Literals for which no range is produced are left out of the result
        if span:
            ranges.append(span)

    if not ranges:
        return expression

    merged = merge_ranges(ranges)
    target_type = extract_type(*candidates)

    return exp.or_(
        *[_datetrunc_eq_expression(trunc, span, target_type) for span in merged],
        copy=False,
    )
1407
+
1408
@annotate_types_on_change
def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Only node types found in ``self.COMPLEMENT_COMPARISONS`` are considered.
    The comparison is left alone when it already has a column on the left only,
    a constant on the right only, or a subquery predicate on the right.
    Otherwise the operands are swapped when the right side is the only column,
    the left side is the only constant, or the generated text of the left side
    sorts after that of the right side. The swapped node uses the class mapped
    in ``self.INVERSE_COMPARISONS``, or the same class if there is no mapping.
    """
    comparison = expression.__class__
    if comparison not in self.COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    should_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not should_swap:
        return expression

    flipped = self.INVERSE_COMPARISONS.get(comparison, comparison)
    return flipped(this=rhs, expression=lhs)
1428
+
1429
+ def _flat_simplify(self, expression, simplifier, root=True):
1430
+ if root or not expression.same_parent:
1431
+ operands = []
1432
+ queue = deque(expression.flatten(unnest=False))
1433
+ size = len(queue)
1434
+
1435
+ while queue:
1436
+ a = queue.popleft()
1437
+
1438
+ for b in queue:
1439
+ result = simplifier(expression, a, b)
1440
+
1441
+ if result and result is not expression:
1442
+ queue.remove(b)
1443
+ queue.appendleft(result)
1444
+ break
1445
+ else:
1446
+ operands.append(a)
1447
+
1448
+ if len(operands) < size:
1449
+ return functools.reduce(
1450
+ lambda a, b: expression.__class__(this=a, expression=b), operands
1451
+ )
1452
+ return expression
1343
1453
 
1344
1454
 
1345
1455
  def gen(expression: t.Any, comments: bool = False) -> str: