@altimateai/altimate-code 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. package/CHANGELOG.md +35 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,1587 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import logging
5
+ import functools
6
+ import itertools
7
+ import typing as t
8
+ from collections import deque, defaultdict
9
+ from functools import reduce
10
+
11
+ import sqlglot
12
+ from sqlglot import Dialect, exp
13
+ from sqlglot.helper import first, merge_ranges, while_changing
14
+ from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
15
+
16
+ if t.TYPE_CHECKING:
17
+ from sqlglot.dialects.dialect import DialectType
18
+
19
# Callable alias: (expression, date, str, dialect, data type) -> optional expression.
# NOTE(review): presumably the signature shared by the DATE_TRUNC comparison rewrites
# defined further down in this module, with the str being the truncation unit — confirm.
DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]

# Uses the logger named "sqlglot" rather than a per-module logger name
logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified.
# It is the key stored in `node.meta`; `simplify` skips any node whose meta has it set.
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
# (used by `_simplify_integer_cast` to decide whether a cast of an integer literal is removable)
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
33
+
34
+
35
class UnsupportedUnit(Exception):
    """Signals a unit that a simplification rule cannot handle.

    `simplify_equality` is decorated with `catch(..., UnsupportedUnit)`, so when this
    is raised inside that rule the expression is returned unsimplified.
    """
37
+
38
+
39
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        coalesce_simplification: whether the simplify coalesce rule should be used.
            This rule tries to remove coalesce functions, which can be useful in certain analyses but
            can leave the query more verbose.
        dialect: dialect (or dialect name) resolved through `Dialect.get_or_raise`; it is
            handed to the coalesce and date-trunc rules.
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression):
        # One full pass: a pre-order sweep (rules that look at a node before its children),
        # followed by a post-order sweep (rules that need already-simplified children).
        pre_transformation_stack = [expression]
        post_transformation_stack = []

        # Default result when the root itself is marked FINAL: nothing gets pushed on the
        # post-order stack in that case, so without this binding the final `return`
        # would raise UnboundLocalError.
        new_node = expression

        while pre_transformation_stack:
            node = pre_transformation_stack.pop()

            if node.meta.get(FINAL):
                continue

            # group by expressions cannot be simplified, for example
            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
            # the projection must exactly match the group by key
            group = node.args.get("group")

            if group and hasattr(node, "selects"):
                groups = set(group.expressions)
                group.meta[FINAL] = True

                # Freeze every projection that contains a group-by key
                for s in node.selects:
                    for n in s.walk():
                        if n in groups:
                            s.meta[FINAL] = True
                            break

                # Same for HAVING: it is frozen as a whole if it mentions a group-by key
                having = node.args.get("having")
                if having:
                    for n in having.walk():
                        if n in groups:
                            having.meta[FINAL] = True
                            break

            # Remember the parent before any rule swaps the node out
            parent = node.parent
            root = node is expression

            # Pre-order transformations
            new_node = rewrite_between(node)
            new_node = uniq_sort(new_node, root)
            new_node = absorb_and_eliminate(new_node, root)
            new_node = simplify_concat(new_node)
            new_node = simplify_conditionals(new_node)

            if constant_propagation:
                new_node = propagate_constants(new_node, root)

            if new_node is not node:
                node.replace(new_node)

            # Children are pushed in reverse so they are popped in their original order
            pre_transformation_stack.extend(
                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
            )
            post_transformation_stack.append((new_node, parent))

        # Popping yields nodes in reverse visiting order, so the root (pushed first) is
        # handled last and its result is what `new_node` holds when the loop ends.
        while post_transformation_stack:
            node, parent = post_transformation_stack.pop()
            root = node is expression

            # Resets parent, arg_key, index pointers– this is needed because some of the
            # previous transformations mutate the AST, leading to an inconsistent state
            for k, v in tuple(node.args.items()):
                node.set(k, v)

            # Post-order transformations
            new_node = simplify_not(node)
            new_node = flatten(new_node)
            new_node = simplify_connectors(new_node, root)
            new_node = remove_complements(new_node, root)

            if coalesce_simplification:
                new_node = simplify_coalesce(new_node, dialect)

            # The rules below inspect `parent`, and the node may be freshly built by now
            new_node.parent = parent

            new_node = simplify_literals(new_node, root)
            new_node = simplify_equality(new_node)
            new_node = simplify_parens(new_node)
            new_node = simplify_datetrunc(new_node, dialect)
            new_node = sort_comparison(new_node)
            new_node = simplify_startswith(new_node)

            if new_node is not node:
                node.replace(new_node)

        return new_node

    # Repeat whole passes until `while_changing` reports a fixed point
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
153
+
154
+
155
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function returns its first positional argument
        (the expression) unchanged whenever one of `exceptions` is raised; any other
        exception propagates.
    """

    def decorator(func):
        # Preserve the wrapped rule's __name__/__doc__ so tracebacks, logs and
        # introspection still point at the real simplification function.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
168
+
169
+
170
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    The comparison simplifications only understand lt/lte/gt/gte, so BETWEEN has to
    be spelled out first. Anything that is not a BETWEEN is returned untouched.
    When the BETWEEN sits directly under a NOT, the conjunction is parenthesized.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    operand = expression.this

    # Each bound gets its own copy of the operand; the bounds themselves are reused
    lower_bound = exp.GTE(this=operand.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=operand.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction
188
+
189
+
190
# Maps a comparison class to the class expressing its logical negation.
# `simplify_not` uses it to rewrite NOT (a <op> b) as a <complement-op> b.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}

# Swaps the ALL / ANY quantifier found on the right-hand side of a comparison
# when `simplify_not` complements that comparison.
COMPLEMENT_SUBQUERY_PREDICATES = {
    exp.All: exp.Any,
    exp.Any: exp.All,
}
203
+
204
+
205
def simplify_not(expression):
    """
    Push a NOT inwards or evaluate it outright.

    Rules, tried in this order on the operand of the NOT:
        NOT NULL                -> NULL
        NOT (a < b)             -> a >= b   (ALL/ANY on the right side are swapped too)
        NOT (x AND y)           -> (NOT x OR NOT y)      [De Morgan]
        NOT (x OR y)            -> (NOT x AND NOT y)     [De Morgan]
        NOT (NULL)              -> NULL
        NOT <always true>       -> FALSE
        NOT FALSE               -> TRUE
        NOT NOT x               -> x

    Anything else (including non-NOT input) is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    complement = COMPLEMENT_COMPARISONS.get(operand.__class__)
    if complement is not None:
        rhs = operand.expression
        swapped_quantifier = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if swapped_quantifier:
            rhs = swapped_quantifier(this=rhs.this)
        return complement(this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # De Morgan: negate both sides and flip the connector
        if isinstance(inner, exp.And):
            dual = exp.or_
        elif isinstance(inner, exp.Or):
            dual = exp.and_
        else:
            dual = None

        if dual is not None:
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation cancels out
        return operand.this

    return expression
251
+
252
+
253
def flatten(expression):
    """
    Drop parentheses around a nested connector of the same kind.

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C

    The connector is modified in place and returned; other expressions pass through.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
264
+
265
+
266
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors whose operands are boolean constants, NULL, or
    comparisons that `_simplify_comparison` can merge.

    Non-connector expressions are returned unchanged. For connectors, the pairwise
    rule below is applied through `_flat_simplify`.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    def _reduce_pair(connector, lhs, rhs):
        # Returns the replacement for `lhs <connector> rhs`, or None when no rule applies
        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs) or is_zero(lhs) or is_zero(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs):
                return exp.true() if always_true(rhs) else rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_null, rhs_null = is_null(lhs), is_null(rhs)
            if (lhs_null and (rhs_null or always_false(rhs))) or (
                rhs_null and always_false(lhs)
            ):
                return exp.null()

            if is_false(lhs):
                return rhs
            if is_false(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        if isinstance(connector, exp.Xor) and lhs == rhs:
            return exp.false()

        return None

    return _flat_simplify(expression, _reduce_pair, root)
303
+
304
+
305
# "Less than" and "greater than" families, used for isinstance checks
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every binary predicate the comparison / equality simplifications look at
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an ordering comparison to the one with the opposite direction
# (LT <-> GT, LTE <-> GTE); strictness is kept, unlike COMPLEMENT_COMPARISONS.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Expressions whose presence disqualifies an operand from being treated as a
# shared "column" in `_simplify_comparison`
NONDETERMINISTIC = (exp.Rand, exp.Randn)
# Connectors handled by the complement / absorption rules (XOR is excluded)
AND_OR = (exp.And, exp.Or)
325
+
326
+
327
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that share an operand, e.g. `x < 1 AND x < 2`.

    Args:
        expression: the connector joining `left` and `right`.
        left: left operand of the connector.
        right: right operand of the connector.
        or_: True when the connector is an OR, False for AND.

    Returns:
        `left`, `right`, one of the comparisons, or FALSE when a rule applies;
        `expression` when no non-shared operand can be picked out;
        None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        # The unpacking requires each comparison to hold exactly two args
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands appearing on both sides; only the non-constant, deterministic ones
        # count as the shared "column"
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                # The remaining operand of each comparison is the value compared against
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Turn both values into comparable Python objects: numbers, strings or dates
            if l.is_number and r.is_number:
                l = l.to_py()
                r = r.to_py()
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try the pair in both orders: (left, right) and then (right, left)
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Same direction: one bound is kept, chosen by comparing the values
                # (the comparison flips between AND and OR)
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Opposite directions with an empty range -> FALSE
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # Equality combined with another predicate: either the two
                        # contradict (FALSE) or the equality alone (`a`) is returned
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
387
+
388
+
389
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both a term and its negation.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE

    Only AND / OR connectors are considered, and only when `root` is set or the
    connector's `same_parent` is falsy; otherwise the input is returned as is.
    """
    if not isinstance(expression, AND_OR):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(term, exp.Not) and term.this in operands for term in operands
    )
    if not has_complement:
        return expression

    return exp.false() if isinstance(expression, exp.And) else exp.true()
403
+
404
+
405
def uniq_sort(expression, root=True):
    """
    Put the operands of a connector into a canonical order and drop duplicates.

        C AND A AND B AND B -> A AND B AND C

    Operands are keyed by `gen(operand)`. XOR operands are sorted but never
    deduplicated, because A XOR A is not A when A is TRUE.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        combine = exp.xor
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
        # dict keeps the last operand seen for each key
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # Rebuild in sorted order as soon as any adjacent pair is out of order
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))
    if out_of_order:
        return combine(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still need to go
    if unique and len(unique) < len(operands):
        return combine(*unique.values(), copy=False)

    return expression
436
+
437
+
438
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Args:
        expression: candidate connector; only AND / OR are processed.
        root: whether `expression` is the root being simplified. A non-root connector
            is only processed when its `same_parent` is falsy.

    Returns:
        The same `expression` object. Rewrites happen in place by calling `replace`
        on its operands; the absorbed operands are swapped for TRUE / FALSE
        literals rather than removed here.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the *other* connector: inside an AND we look at OR operands and vice versa
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # the plain operand A is registered as the one-element set {A}
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # each member of a nested connector (A, B, ...) indexes that connector's operand set
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Index the operand under the un-negated pair, remembering which side survives
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A OR (NOT A AND B): the NOT A becomes the neutral element of the nested connector
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # A OR (A AND B): a strict subset of this operand exists elsewhere, so this
            # operand becomes the neutral element of the outer connector
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # (A AND B) OR (A AND NOT B): both operands are replaced by the shared side
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
508
+
509
+
510
def propagate_constants(expression, root=True):
    """
    Substitute known column = literal equalities into the rest of a conjunction in DNF:

        SELECT * FROM t WHERE a = b AND b = 5 becomes
        SELECT * FROM t WHERE a = 5 AND b = 5

    The column occurrence that defines the constant is left alone, and so is a
    column that is the subject of `IS NULL`.

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (identity of the defining occurrence, literal it equals)
    known = {}
    for candidate in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            known[lhs] = (id(lhs), rhs)

    if not known:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        owner = column.parent
        entry = known.get(column)
        if entry is None:
            continue

        defining_id, literal = entry
        if id(column) == defining_id:
            # this is the `column = literal` occurrence itself
            continue
        if isinstance(owner, exp.Is) and isinstance(owner.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression
547
+
548
+
549
# Date/datetime interval operations mapped to the arithmetic operator that undoes them.
# `simplify_equality` uses the mapped class to move the interval to the other side of
# a comparison.
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All invertible operations: the date operations above plus plain + and -
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
561
+
562
+
563
def _is_number(expression: exp.Expression) -> bool:
    """Predicate form of `expression.is_number`, so it can be passed around as a callable."""
    numeric = expression.is_number
    return numeric
565
+
566
+
567
def _is_interval(expression: exp.Expression) -> bool:
    """Tell whether `expression` is an Interval for which `extract_interval` yields a value."""
    if not isinstance(expression, exp.Interval):
        return False
    return extract_interval(expression) is not None
569
+
570
+
571
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Isolate the non-constant side of a comparison by moving a constant across it:

        x + 1 = 3 becomes x = 3 - 1

    The left side must be an invertible operation (see INVERSE_OPS) and the right
    side a number or a date literal. Exactly one operand of the left side must be
    movable (a number, or an interval when dates are involved) while the operand
    that stays must not itself be a constant of the right-hand kind.

    Naming used below, for `x + 1 = 3`:
        lhs = `x + 1`, rhs = `3`, kept = `x`, moved = `1`
    """
    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    lhs_type = lhs.__class__

    if lhs_type not in INVERSE_OPS:
        return expression

    # Pick the predicates matching the kind of constant on the right-hand side
    if rhs.is_number:
        is_constant, is_movable = _is_number, _is_number
    elif _is_date_literal(rhs):
        is_constant, is_movable = _is_date_literal, _is_interval
    else:
        return expression

    if lhs_type in INVERSE_DATE_OPS:
        lhs = t.cast(exp.IntervalOp, lhs)
        kept, moved = lhs.this, lhs.interval()
    else:
        lhs = t.cast(exp.Binary, lhs)
        kept, moved = lhs.left, lhs.right

    # Either the operands are already in (kept, moved) order, or swapping them
    # gives that order; otherwise nothing can be moved.
    if is_constant(kept) or not is_movable(moved):
        if is_constant(moved) or not is_movable(kept):
            return expression
        kept, moved = moved, kept

    inverse = INVERSE_OPS[lhs.__class__]
    return expression.__class__(this=kept, expression=inverse(this=rhs, expression=moved))
619
+
620
+
621
def simplify_literals(expression, root=True):
    """
    Fold operations on literals.

    - Binary expressions other than connectors go through `_flat_simplify` with
      `_simplify_binary` as the pairwise rule.
    - A double negation `-(-x)` yields `x`.
    - Date/datetime add/sub operations are folded by `_simplify_binary` applied to
      their operand and interval; if that yields nothing the input is returned.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        negated = expression.this
        if isinstance(negated, exp.Neg):
            return negated.this

    if type(expression) not in INVERSE_DATE_OPS:
        return expression

    folded = _simplify_binary(expression, expression.this, expression.interval())
    return folded or expression
632
+
633
+
634
# Operators for which `_simplify_binary` never folds: it returns None (no simplification)
# before the "NULL operand -> NULL" rule is reached.
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
635
+
636
+
637
def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
    """
    Strip a cast from a byte-sized integer literal when the cast cannot change the value.

    Nested casts are unwrapped from the inside out. Returns the inner expression when
    the cast is removable, otherwise `expr` itself.
    """
    is_cast = isinstance(expr, exp.Cast)

    operand = expr.this
    if is_cast and isinstance(operand, exp.Cast):
        operand = _simplify_integer_cast(operand)

    if not (is_cast and operand.is_int):
        return expr

    value = operand.to_py()

    # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect
    # free. Downcasts on any integer type might cause overflow, thus the cast cannot be
    # eliminated and the behavior is engine-dependent
    if TINYINT_MIN <= value <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES:
        return operand
    if (
        UTINYINT_MIN <= value <= UTINYINT_MAX
        and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
    ):
        return operand

    return expr
658
+
659
+
660
def _simplify_binary(expression, a, b):
    """
    Fold a binary operation whose two operands are constants.

    Args:
        expression: the binary (or date add/sub) expression being simplified.
        a: left operand.
        b: right operand (for date add/sub operations, the interval).

    Returns:
        The folded expression (a literal, boolean, NULL or date literal), or None
        when nothing can be simplified.
    """
    if isinstance(expression, COMPARISONS):
        a = _simplify_integer_cast(a)
        b = _simplify_integer_cast(b)

    if isinstance(expression, exp.Is):
        # `x IS NOT NULL` arrives as Is(x, Not(Null)); peel the NOT off
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = a.to_py()
        num_b = b.to_py()

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Folding a division by zero would raise inside the optimizer; leave the
            # expression as written so the engine decides what it means.
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        date, b = extract_date(a), extract_interval(b)
        if date and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(date + b, extract_type(a))
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(date - b, extract_type(a))
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        # Both extracted values must be usable: test `date`, not the raw expression `b`,
        # mirroring the `date and b` check of the date + interval branch above.
        if a and date and isinstance(expression, exp.Add):
            return date_literal(a + date, extract_type(b))
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None
730
+
731
+
732
def simplify_parens(expression):
    """Return the content of a `Paren` node when the parentheses are redundant.

    Non-`Paren` inputs, parenthesized SELECTs and parens directly under a
    subquery predicate or bracket are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent
    outer_is_predicate = isinstance(outer, exp.Predicate)

    if isinstance(inner, exp.Select):
        return expression
    if isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
758
+
759
+
760
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Whether `expression` is one of `exp.NONNULL_CONSTANTS` or a date literal."""
    if isinstance(expression, exp.NONNULL_CONSTANTS):
        return True
    return _is_date_literal(expression)
762
+
763
+
764
def _is_constant(expression: exp.Expression) -> bool:
    """Whether `expression` is one of `exp.CONSTANTS` or a date literal."""
    if isinstance(expression, exp.CONSTANTS):
        return True
    return _is_date_literal(expression)
766
+
767
+
768
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Simplify COALESCE calls and comparisons against COALESCE.

    Two rewrites are performed:

    * ``COALESCE(x)`` (or a COALESCE whose first argument is a non-null
      constant) becomes its first argument.
    * ``COALESCE(x, 1) = 2`` becomes
      ``(NOT x IS NULL AND x = 2) OR (x IS NULL AND 1 = 2)``, which later
      simplification passes can reduce further.

    Args:
        expression: the node to simplify.
        dialect: the dialect in use; the comparison rewrite is skipped for redshift.

    Returns:
        The simplified expression, or `expression` itself when nothing applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example, if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the operands of the NULL-case comparison on the same sides as in the
    # original comparison; swapping them would be wrong for <, <=, > and >=.
    if coalesce_on_left:
        null_case = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        null_case = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                null_case,
                copy=False,
            ),
            copy=False,
        )
    )
830
+
831
+
832
# Expression classes that simplify_concat treats as string concatenation.
CONCATS = (exp.Concat, exp.DPipe)
833
+
834
+
835
def simplify_concat(expression):
    """Merge runs of adjacent string literals inside a concatenation."""
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if with_separator and not expression.expressions[0].is_string:
        return expression

    if with_separator:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    operands = expressions or expression.flatten()
    new_args = []
    for literal_run, members in itertools.groupby(operands, lambda e: e.is_string):
        if not literal_run:
            new_args.extend(members)
            continue
        joined = sep.join(member.name for member in members)
        new_args.append(exp.Literal.string(joined))

    # Everything collapsed into a single literal.
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild the left-nested chain of `||` operators.
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
875
+
876
+
877
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    Returns the selected branch when a condition is always true, drops CASE
    branches whose condition is always false, and otherwise returns
    `expression` (possibly modified in place).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # NOTE(review): `case.pop()` below is called while iterating over
        # `expression.args["ifs"]`; if pop removes the branch from that same
        # list, the following branch is skipped in this pass -- confirm.
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                # Every branch was removed: only the default (or NULL) remains.
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
901
+
902
+
903
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into TRUE or FALSE when both the string and the
    prefix are string literals; any other input is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression
    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    prefix = expression.expression.name
    return exp.convert(expression.name.startswith(prefix))
921
+
922
+
923
# A pair of dates; _datetrunc_range documents it as a half-open [min, max) range.
DateRange = t.Tuple[datetime.date, datetime.date]
924
+
925
+
926
def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Compute the date range matched by a DATE_TRUNC equality comparison.

    Example:
        _datetrunc_range(date(2021, 1, 1), 'year', dialect) == (date(2021, 1, 1), date(2022, 1, 1))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # A truncated value can never equal a date that is not itself a unit
        # boundary, so the comparison is always False (except for NULL values).
        return None

    upper_bound = floor + interval(unit)
    return floor, upper_bound
942
+
943
+
944
def _datetrunc_eq_expression(
    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
) -> exp.Expression:
    """Build `left >= start AND left < end` for the date range `drange`."""
    start, end = drange[0], drange[1]
    at_or_after_start = left >= date_literal(start, target_type)
    before_end = left < date_literal(end, target_type)
    return exp.and_(at_or_after_start, before_end, copy=False)
953
+
954
+
955
def _datetrunc_eq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check on `left`.

    Returns None when `_datetrunc_range` yields no range for `date`.
    """
    drange = _datetrunc_range(date, unit, dialect)
    return _datetrunc_eq_expression(left, drange, target_type) if drange else None
967
+
968
+
969
def _datetrunc_neq(
    left: exp.Expression,
    date: datetime.date,
    unit: str,
    dialect: Dialect,
    target_type: t.Optional[exp.DataType],
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as a range check on `left`.

    The truncated value differs from `date` exactly when `left` lies outside
    the half-open range [start, end), i.e. `left < start OR left >= end`.

    Returns:
        The parenthesized disjunction, or None when `_datetrunc_range` yields
        no range for `date`.
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # The two bounds must be OR-ed: `left < start AND left >= end` can never
    # hold. The result is parenthesized so it keeps its meaning when the
    # original comparison sits inside an AND.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0], target_type),
            left >= date_literal(drange[1], target_type),
            copy=False,
        )
    )
985
+
986
+
987
# Maps a comparison class to a callable taking
# (truncated operand, date, unit, dialect, target type) and returning the
# equivalent comparison on the untruncated operand (or None).
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda operand, date, unit, dialect, target: (
        operand < date_literal(date_ceil(date, unit, dialect), target)
    ),
    exp.GT: lambda operand, date, unit, dialect, target: (
        operand >= date_literal(date_floor(date, unit, dialect) + interval(unit), target)
    ),
    exp.LTE: lambda operand, date, unit, dialect, target: (
        operand < date_literal(date_floor(date, unit, dialect) + interval(unit), target)
    ),
    exp.GTE: lambda operand, date, unit, dialect, target: (
        operand >= date_literal(date_ceil(date, unit, dialect), target)
    ),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every comparison class that simplify_datetrunc handles.
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Truncation expression classes recognized on the left-hand side.
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
998
+
999
+
1000
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Whether `left` is a date truncation and `right` is a date literal."""
    if not isinstance(left, DATETRUNCS):
        return False
    return _is_date_literal(right)
1002
+
1003
+
1004
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled: a truncation of a date literal (folded into a
    date literal), a binary comparison between a truncation and a date
    literal, and `truncation IN (date literals)`. Anything else is returned
    unchanged.
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncating a literal date can be evaluated right away.
        this = expression.this
        trunc_type = extract_type(this)
        date = extract_date(this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        trunc_arg = l.this
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The transform may return None, in which case the expression is kept.
        return (
            DATETRUNC_BINARY_COMPARISONS[comparison](
                trunc_arg, date, unit, dialect, extract_type(r)
            )
            or expression
        )

    if isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates for which no range exists are skipped.
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(*rs)

            # One range check per merged range, OR-ed together.
            return exp.or_(
                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
            )

    return expression
1067
+
1068
+
1069
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; otherwise the
    operands are ordered by their `gen` string. Swapping the operands uses the
    class from INVERSE_COMPARISONS when one is registered.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not needs_swap:
        return expression

    klass = expression.__class__
    return INVERSE_COMPARISONS.get(klass, klass)(this=right, expression=left)
1088
+
1089
+
1090
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# Each entry is a (side, kind) pair, matched against (join.side, join.kind)
# in remove_where_true.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1099
+
1100
+
1101
def remove_where_true(expression):
    """Remove always-true WHERE clauses and turn always-true joins into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        # Only the join types listed in JOINS may become CROSS joins.
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1115
+
1116
+
1117
def always_true(expression):
    """Truthy for a TRUE boolean or for any literal that is not zero."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
1121
+
1122
+
1123
def always_false(expression):
    """Whether `expression` is FALSE, NULL or a zero literal."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
1125
+
1126
+
1127
def is_zero(expression):
    """Whether `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
1129
+
1130
+
1131
def is_complement(a, b):
    """Whether `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1133
+
1134
+
1135
def is_false(a: exp.Expression) -> bool:
    """Whether `a` is exactly an `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1137
+
1138
+
1139
def is_null(a: exp.Expression) -> bool:
    """Whether `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1141
+
1142
+
1143
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal expression, or None when `expression` is not one
    of the supported comparison classes.
    """
    # Checked in order; the first matching class wins.
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in comparisons:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
1157
+
1158
+
1159
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Args:
        value: a datetime, a date, or an ISO-8601 formatted string.

    Returns:
        The date, or None when `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: not a string at all; ValueError: not an ISO-8601 string.
        return None
1168
+
1169
+
1170
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime, a date (promoted to midnight of that day), or an
            ISO-8601 formatted string.

    Returns:
        The datetime, or None when `value` cannot be interpreted as one.
    """
    # datetime is a subclass of date, so it has to be checked first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: not a string at all; ValueError: not an ISO-8601 string.
        return None
1179
+
1180
+
1181
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date type matching the data type `to`.

    Args:
        value: the value to convert; falsy values yield None.
        to: the target data type.

    Returns:
        A date for DATE targets, a datetime for the other temporal types, and
        None for falsy values, non-temporal targets or unparseable values.
    """
    if not value:
        return None
    # DATE is tested first so that it is not handled as a generic temporal type.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1189
+
1190
+
1191
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a cast of a literal (or of a nested cast) to a Python date value.

    Args:
        cast: a `Cast`, or a `TsOrDsToDate` without a format argument.

    Returns:
        The date or datetime the expression denotes, or None when the
        expression has another shape or its value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated from the inside out.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1206
+
1207
+
1208
def _is_date_literal(expression: exp.Expression) -> bool:
    """Whether `extract_date` can evaluate `expression` to a date value."""
    extracted = extract_date(expression)
    return extracted is not None
1210
+
1211
+
1212
def extract_interval(expression):
    """Convert an interval expression into the value returned by `interval`.

    Returns None when the unit is unsupported, the quantity is not an integer,
    or the module `interval` imports is missing.
    """
    try:
        quantity = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), quantity)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1219
+
1220
+
1221
def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    A cast contributes its target type (`to`), any other expression its `type`
    attribute. When no truthy type exists, the last value looked at is
    returned (None if `expressions` is empty).
    """
    found = None
    for node in expressions:
        found = node.to if isinstance(node, exp.Cast) else node.type
        if found:
            return found
    return found
1229
+
1230
+
1231
def date_literal(date, target_type=None):
    """Build a cast of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type; otherwise DATETIME is
    used for datetime values and DATE for everything else.
    """
    usable = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not usable:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
1240
+
1241
+
1242
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times the given `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, amount of that keyword per unit)
    supported = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, amount in supported:
        if unit == name:
            return relativedelta(**{keyword: amount * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1263
+
1264
+
1265
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of its `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week or day.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        days_back = d.weekday() - dialect.WEEK_OFFSET
        return d - datetime.timedelta(days=days_back)
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1286
+
1287
+
1288
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; boundaries are returned as is."""
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1295
+
1296
+
1297
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, else the FALSE one."""
    if condition:
        return exp.true()
    return exp.false()
1299
+
1300
+
1301
+ def _flat_simplify(expression, simplifier, root=True):
1302
+ if root or not expression.same_parent:
1303
+ operands = []
1304
+ queue = deque(expression.flatten(unnest=False))
1305
+ size = len(queue)
1306
+
1307
+ while queue:
1308
+ a = queue.popleft()
1309
+
1310
+ for b in queue:
1311
+ result = simplifier(expression, a, b)
1312
+
1313
+ if result and result is not expression:
1314
+ queue.remove(b)
1315
+ queue.appendleft(result)
1316
+ break
1317
+ else:
1318
+ operands.append(a)
1319
+
1320
+ if len(operands) < size:
1321
+ return functools.reduce(
1322
+ lambda a, b: expression.__class__(this=a, expression=b), operands
1323
+ )
1324
+ return expression
1325
+
1326
+
1327
def gen(expression: t.Any, comments: bool = False) -> str:
    """Cheap pseudo-SQL generator producing sortable, unique strings.

    Optimization needs to sort and dedupe SQL, and calling the actual
    generator for that is expensive, so a bare-minimum generator is used.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
1338
+
1339
+
1340
class Gen:
    """Minimal stack-based pseudo-SQL generator.

    Handlers push the pieces of a node onto `self.stack` in reverse output
    order, because the stack is popped from the end. Plain values popped from
    the stack are appended to `self.sqls`, which is joined into the result.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Generate the pseudo-SQL string for `expression`.

        Args:
            expression: the expression to convert.
            comments: whether to include the comments attached to the nodes.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                # Dedicated handler first, then the generic function form, then
                # a generic "KEY args" form.
                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                self._push_list(node)
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def _push_list(self, items: list) -> None:
        """Push the non-None `items` onto the stack, separated by commas."""
        pushed = False
        for item in reversed(items):
            if item is not None:
                self.stack.extend((item, ","))
                pushed = True
        # Drop the trailing separator, but only if one was pushed: for a list
        # holding nothing but None values the top of the stack belongs to
        # another node and must stay.
        if pushed:
            self.stack.pop()

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted identifiers keep their case, unquoted ones are uppercased.
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression` (in reverse, as the stack is LIFO)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this` (in reverse, as the stack is LIFO)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(arg, ...)` built from all the argument values of `e`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set arguments of `node` as a list of `[":name", value]` pairs.

        Args:
            node: the expression whose arguments are pushed.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False