PostBOUND 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. postbound/__init__.py +211 -0
  2. postbound/_base.py +6 -0
  3. postbound/_bench.py +1012 -0
  4. postbound/_core.py +1153 -0
  5. postbound/_hints.py +1373 -0
  6. postbound/_jointree.py +1079 -0
  7. postbound/_pipelines.py +1121 -0
  8. postbound/_qep.py +1986 -0
  9. postbound/_stages.py +876 -0
  10. postbound/_validation.py +734 -0
  11. postbound/db/__init__.py +72 -0
  12. postbound/db/_db.py +2348 -0
  13. postbound/db/_duckdb.py +785 -0
  14. postbound/db/mysql.py +1195 -0
  15. postbound/db/postgres.py +4216 -0
  16. postbound/experiments/__init__.py +12 -0
  17. postbound/experiments/analysis.py +674 -0
  18. postbound/experiments/benchmarking.py +54 -0
  19. postbound/experiments/ceb.py +877 -0
  20. postbound/experiments/interactive.py +105 -0
  21. postbound/experiments/querygen.py +334 -0
  22. postbound/experiments/workloads.py +980 -0
  23. postbound/optimizer/__init__.py +92 -0
  24. postbound/optimizer/__init__.pyi +73 -0
  25. postbound/optimizer/_cardinalities.py +369 -0
  26. postbound/optimizer/_joingraph.py +1150 -0
  27. postbound/optimizer/dynprog.py +1825 -0
  28. postbound/optimizer/enumeration.py +432 -0
  29. postbound/optimizer/native.py +539 -0
  30. postbound/optimizer/noopt.py +54 -0
  31. postbound/optimizer/presets.py +147 -0
  32. postbound/optimizer/randomized.py +650 -0
  33. postbound/optimizer/tonic.py +1479 -0
  34. postbound/optimizer/ues.py +1607 -0
  35. postbound/qal/__init__.py +343 -0
  36. postbound/qal/_qal.py +9678 -0
  37. postbound/qal/formatter.py +1089 -0
  38. postbound/qal/parser.py +2344 -0
  39. postbound/qal/relalg.py +4257 -0
  40. postbound/qal/transform.py +2184 -0
  41. postbound/shortcuts.py +70 -0
  42. postbound/util/__init__.py +46 -0
  43. postbound/util/_errors.py +33 -0
  44. postbound/util/collections.py +490 -0
  45. postbound/util/dataframe.py +71 -0
  46. postbound/util/dicts.py +330 -0
  47. postbound/util/jsonize.py +68 -0
  48. postbound/util/logging.py +106 -0
  49. postbound/util/misc.py +168 -0
  50. postbound/util/networkx.py +401 -0
  51. postbound/util/numbers.py +438 -0
  52. postbound/util/proc.py +107 -0
  53. postbound/util/stats.py +37 -0
  54. postbound/util/system.py +48 -0
  55. postbound/util/typing.py +35 -0
  56. postbound/vis/__init__.py +5 -0
  57. postbound/vis/fdl.py +69 -0
  58. postbound/vis/graphs.py +48 -0
  59. postbound/vis/optimizer.py +538 -0
  60. postbound/vis/plots.py +84 -0
  61. postbound/vis/tonic.py +70 -0
  62. postbound/vis/trees.py +105 -0
  63. postbound-0.19.0.dist-info/METADATA +355 -0
  64. postbound-0.19.0.dist-info/RECORD +67 -0
  65. postbound-0.19.0.dist-info/WHEEL +5 -0
  66. postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
  67. postbound-0.19.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2184 @@
1
+ """This module provides tools to modify the contents of existing `SqlQuery` instances.
2
+
3
+ Since queries are designed as immutable data objects, these transformations operate by creating new query instances.
4
+
5
+ The tools differ in their granularity, ranging from utilities that swap out individual expressions and predicates, to tools
6
+ that change the entire structure of the query.
7
+
8
+ Some important transformations include:
9
+
10
+ - `flatten_and_predicate`: Simplifies the predicate structure by moving all nested ``AND`` predicates to their parent ``AND``
11
+ - `extract_query_fragment`: Extracts parts of an original query based on a subset of its tables (i.e. induced join graph and
12
+ filter predicates)
13
+ - `add_ec_predicates`: Expands a query's **WHERE** clause to include all join predicates that are implied by other (equi-)
14
+ joins
15
+ - `as_count_star_query` and `as_explain_analyze` change the query to be executed as **COUNT(*)** or **EXPLAIN ANALYZE**
16
+ respectively
17
+
18
+ In addition to these frequently-used transformations, there are also lots of utilities that add, remove, or modify specific
19
+ parts of queries, such as individual clauses or expressions.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import typing
25
+ import warnings
26
+ from collections.abc import Callable, Iterable, Sequence
27
+ from typing import Optional, overload
28
+
29
+ from .. import util
30
+ from ._qal import (
31
+ AbstractPredicate,
32
+ BaseClause,
33
+ BaseProjection,
34
+ BetweenPredicate,
35
+ BinaryPredicate,
36
+ CaseExpression,
37
+ CastExpression,
38
+ ClauseVisitor,
39
+ ColumnExpression,
40
+ ColumnReference,
41
+ CommonTableExpression,
42
+ CompoundOperator,
43
+ CompoundPredicate,
44
+ DirectTableSource,
45
+ ExceptClause,
46
+ Explain,
47
+ ExplicitFromClause,
48
+ ExplicitSqlQuery,
49
+ From,
50
+ FunctionExpression,
51
+ FunctionTableSource,
52
+ GroupBy,
53
+ Having,
54
+ Hint,
55
+ ImplicitFromClause,
56
+ ImplicitSqlQuery,
57
+ InPredicate,
58
+ IntersectClause,
59
+ JoinTableSource,
60
+ JoinType,
61
+ Limit,
62
+ MathExpression,
63
+ MixedSqlQuery,
64
+ OrderBy,
65
+ OrderByExpression,
66
+ PredicateVisitor,
67
+ Select,
68
+ SelectStatement,
69
+ SetQuery,
70
+ SqlExpression,
71
+ SqlExpressionVisitor,
72
+ SqlQuery,
73
+ StarExpression,
74
+ StaticValueExpression,
75
+ SubqueryExpression,
76
+ SubqueryTableSource,
77
+ TableReference,
78
+ TableSource,
79
+ UnaryPredicate,
80
+ UnionClause,
81
+ ValuesTableSource,
82
+ ValuesWithQuery,
83
+ Where,
84
+ WindowExpression,
85
+ WithQuery,
86
+ build_query,
87
+ determine_join_equivalence_classes,
88
+ generate_predicates_for_equivalence_classes,
89
+ )
90
+
91
+ # TODO: at a later point in time, the entire query traversal/modification logic could be refactored to use unified
92
+ # access instead of implementing the same pattern matching and traversal logic all over again
93
+
94
QueryType = typing.TypeVar("QueryType", bound=SqlQuery)
"""The concrete class of a query.

This generic type is used for transformations that accept any kind of query and return a query of the same concrete type
as their input.
"""

SelectQueryType = typing.TypeVar("SelectQueryType", bound=SelectStatement)
"""The concrete class of a select query.

This generic type is used for transformations that accept any kind of select statement and return a statement of the same
concrete type as their input.
"""

ClauseType = typing.TypeVar("ClauseType", bound=BaseClause)
"""The concrete class of a clause.

This generic type is used for transformations that accept any kind of clause and return a clause of the same concrete type
as their input.
"""

PredicateType = typing.TypeVar("PredicateType", bound=AbstractPredicate)
"""The concrete type of a predicate.

This generic type is used for transformations that accept any kind of predicate and return a predicate of the same
concrete type as their input.
"""
121
+
122
+
123
def flatten_and_predicate(predicate: AbstractPredicate) -> AbstractPredicate:
    """Simplifies the predicate structure by moving all nested ``AND`` predicates to their parent ``AND`` predicate.

    For example, consider the predicate ``(R.a = S.b AND R.a = 42) AND S.b = 24``. This is transformed into the flattened
    equivalent conjunction ``R.a = S.b AND R.a = 42 AND S.b = 24``.

    This procedure continues in a recursive manner, until the first disjunction or negation is encountered. All predicates
    below that point are left as-is for the current branch of the predicate hierarchy.

    Duplicate conjuncts are removed. The remaining conjuncts keep the order in which they first appear in the original
    predicate, which makes the result deterministic across interpreter runs.

    Parameters
    ----------
    predicate : AbstractPredicate
        The predicate to simplify

    Returns
    -------
    AbstractPredicate
        An equivalent version of the given `predicate`, with all conjunctions unnested
    """
    if not isinstance(predicate, CompoundPredicate):
        return predicate
    if predicate.operation in (CompoundOperator.Not, CompoundOperator.Or):
        return predicate

    # A dict serves as an insertion-ordered set: duplicates are removed, but the order does not depend on hashing.
    flattened_children: dict[AbstractPredicate, None] = {}
    for child in predicate.children:
        is_nested_and = (
            isinstance(child, CompoundPredicate)
            and child.operation == CompoundOperator.And
        )
        if not is_nested_and:
            flattened_children[child] = None
            continue

        flattened_child = flatten_and_predicate(child)
        # The recursive call may collapse a nested conjunction into a single child (e.g. if all of its children are
        # duplicates). That child may be a disjunction or a negation, whose children must NOT be spliced into this
        # conjunction - doing so would turn an OR into an AND or drop the negation.
        if (
            isinstance(flattened_child, CompoundPredicate)
            and flattened_child.operation == CompoundOperator.And
        ):
            for grandchild in flattened_child.children:
                flattened_children[grandchild] = None
        else:
            flattened_children[flattened_child] = None

    unique_children = list(flattened_children)
    if len(unique_children) == 1:
        return unique_children[0]

    return CompoundPredicate.create_and(unique_children)
168
+
169
+
170
+ def explicit_to_implicit(source_query: ExplicitSqlQuery) -> ImplicitSqlQuery:
171
+ """Transforms a query with an explicit ``FROM`` clause to a query with an implicit ``FROM`` clause.
172
+
173
+ Currently, this process is only supported for explicit queries that do not contain subqueries in their ``FROM`` clause.
174
+
175
+ Parameters
176
+ ----------
177
+ source_query : ExplicitSqlQuery
178
+ The query that should be transformed
179
+
180
+ Returns
181
+ -------
182
+ ImplicitSqlQuery
183
+ An equivalent version of the given query, using an implicit ``FROM`` clause
184
+
185
+ Raises
186
+ ------
187
+ ValueError
188
+ If the `source_query` contains subquery table sources
189
+ """
190
+ additional_predicates: list[AbstractPredicate] = []
191
+ complete_from_tables: list[TableReference] = []
192
+ join_working_set: list[TableSource] = list(source_query.from_clause.items)
193
+
194
+ while join_working_set:
195
+ current_table_source = join_working_set.pop()
196
+
197
+ match current_table_source:
198
+ case DirectTableSource():
199
+ complete_from_tables.append(current_table_source.table)
200
+ case SubqueryTableSource():
201
+ raise ValueError(
202
+ "Transforming subqueries to implicit table references is not supported yet"
203
+ )
204
+ case JoinTableSource() if (
205
+ current_table_source.join_type == JoinType.InnerJoin
206
+ ):
207
+ join_working_set.append(current_table_source.left)
208
+ join_working_set.append(current_table_source.right)
209
+ additional_predicates.append(current_table_source.join_condition)
210
+ case _:
211
+ raise ValueError(
212
+ "Unsupported table source type: " + str(type(current_table_source))
213
+ )
214
+
215
+ final_from_clause = ImplicitFromClause.create_for(complete_from_tables)
216
+
217
+ if source_query.where_clause:
218
+ final_predicate = CompoundPredicate.create_and(
219
+ [source_query.where_clause.predicate] + additional_predicates
220
+ )
221
+ else:
222
+ final_predicate = CompoundPredicate.create_and(additional_predicates)
223
+
224
+ final_predicate = flatten_and_predicate(final_predicate)
225
+ final_where_clause = Where(final_predicate)
226
+
227
+ return ImplicitSqlQuery(
228
+ select_clause=source_query.select_clause,
229
+ from_clause=final_from_clause,
230
+ where_clause=final_where_clause,
231
+ groupby_clause=source_query.groupby_clause,
232
+ having_clause=source_query.having_clause,
233
+ orderby_clause=source_query.orderby_clause,
234
+ limit_clause=source_query.limit_clause,
235
+ cte_clause=source_query.cte_clause,
236
+ )
237
+
238
+
239
def _get_predicate_fragment(
    predicate: AbstractPredicate, referenced_tables: set[TableReference]
) -> AbstractPredicate | None:
    """Prunes a predicate hierarchy down to the base predicates that only touch the given tables.

    The `referenced_tables` act as a superset: a part of the predicate is kept if every table it references is contained in
    `referenced_tables`.

    In general, the result is not equivalent to the input predicate, because whole branches of the hierarchy may be pruned.
    Structural simplifications are applied along the way. For example, when only one child of a conjunction survives, that
    child replaces the conjunction. Negations are always kept as negations.

    No logical simplification takes place. A surviving conjunction ``R.a < 42 AND R.a < 84`` stays as it is and is not
    reduced to ``R.a < 42``.

    Parameters
    ----------
    predicate : AbstractPredicate
        The predicate to prune
    referenced_tables : set[TableReference]
        The superset of all allowed tables. Parts of the `predicate` whose tables are not a subset of the
        `referenced_tables` are pruned.

    Returns
    -------
    AbstractPredicate | None
        The largest fragment of the original predicate that references only a subset of the `referenced_tables`, or
        ``None`` if the entire predicate was pruned.

    Examples
    --------
    For the predicate ``R.a > 100 AND (S.b = 42 OR R.a = 42)``, the fragment for table ``R`` is
    ``R.a > 100 AND R.a = 42``.
    """
    if isinstance(predicate, CompoundPredicate):
        surviving_children = []
        for sub_predicate in predicate.children:
            sub_fragment = _get_predicate_fragment(sub_predicate, referenced_tables)
            if sub_fragment:
                surviving_children.append(sub_fragment)

        if not surviving_children:
            return None
        if len(surviving_children) == 1 and predicate.operation != CompoundOperator.Not:
            # A single surviving AND/OR operand does not need its compound wrapper anymore.
            return surviving_children[0]
        return CompoundPredicate(predicate.operation, surviving_children)

    return predicate if predicate.tables().issubset(referenced_tables) else None
291
+
292
+
293
def extract_query_fragment(
    source_query: SelectQueryType,
    referenced_tables: TableReference | Iterable[TableReference],
) -> Optional[SelectQueryType]:
    """Filters a query to only include parts that reference specific tables.

    This builds a new query from the given query. It contains exactly those parts of the original query's clauses that
    reference only the given tables or a subset of them.

    For example, consider the query ``SELECT * FROM R, S, T WHERE R.a = S.b AND S.c = T.d AND R.a = 42 ORDER BY S.b``.
    The query fragment for tables ``R`` and ``S`` would look like this:
    ``SELECT * FROM R, S WHERE R.a = S.b AND R.a = 42 ORDER BY S.b``. The query fragment for table ``S`` would
    look like ``SELECT * FROM S ORDER BY S.b``.

    Projections follow the same rule. A ``SELECT`` target is retained if it references a subset of the `referenced_tables`,
    or no columns at all (e.g. ``COUNT(*)``). If no target is retained, a ``SELECT *`` projection is used instead.

    Notice that this can break disjunctions: the fragment for table ``R`` of query
    ``SELECT * FROM R, S WHERE R.a < 100 AND (R.a = 42 OR S.b = 42)`` is ``SELECT * FROM R WHERE R.a < 100 AND R.a = 42``.
    This also indicates that the fragment extraction does not perform any logical pruning of superfluous predicates.

    Parameters
    ----------
    source_query : SelectQueryType
        The query that should be transformed
    referenced_tables : TableReference | Iterable[TableReference]
        The tables that should be extracted

    Returns
    -------
    Optional[SelectQueryType]
        A query that only consists of those parts of the `source_query` that reference (a subset of) the `referenced_tables`.
        If there is no such subset, ``None`` is returned. For set queries, ``None`` is also returned if either operand of
        the set operation does not contain all of the `referenced_tables`.

    Warnings
    --------
    The current implementation only works for `SetQuery` and `ImplicitSqlQuery` instances. If the `source_query` is not of one
    of these types (or contains subqueries that are not of these types), a `ValueError` is raised.
    """
    if not isinstance(source_query, (ImplicitSqlQuery, SetQuery)):
        raise ValueError(
            "Fragment extraction only works for implicit queries and set queries"
        )

    referenced_tables: set[TableReference] = set(util.enlist(referenced_tables))
    if not referenced_tables.issubset(source_query.tables()):
        return None

    cte_fragment = (
        [
            with_query
            for with_query in source_query.cte_clause.queries
            if with_query.target_table in referenced_tables
        ]
        if source_query.cte_clause
        else []
    )
    cte_clause = (
        CommonTableExpression(cte_fragment, recursive=source_query.cte_clause.recursive)
        if cte_fragment
        else None
    )

    if source_query.orderby_clause:
        order_fragment = [
            order
            for order in source_query.orderby_clause.expressions
            if order.column.tables().issubset(referenced_tables)
        ]
        orderby_clause = OrderBy(order_fragment) if order_fragment else None
    else:
        orderby_clause = None

    if isinstance(source_query, SetQuery):
        left_query = extract_query_fragment(source_query.left_query, referenced_tables)
        right_query = extract_query_fragment(
            source_query.right_query, referenced_tables
        )
        if left_query is None or right_query is None:
            # A set operation requires both of its operands. If one side does not contain all referenced tables, there
            # is no fragment for that side and the set query cannot be built.
            return None
        return SetQuery(
            left_query,
            right_query,
            source_query.set_operation,
            cte_clause=cte_clause,
            orderby_clause=orderby_clause,
            limit_clause=source_query.limit_clause,
            hints=source_query.hints,
            explain=source_query.explain,
        )

    # Use subset semantics like all other clauses, so that e.g. ``R.a`` survives in the fragment for ``{R, S}``. Targets
    # without any columns (such as ``COUNT(*)``) are always retained.
    select_fragment = [
        target
        for target in source_query.select_clause
        if target.tables().issubset(referenced_tables) or not target.columns()
    ]

    if select_fragment:
        select_clause = Select(
            select_fragment, distinct=source_query.select_clause.is_distinct()
        )
    else:
        select_clause = Select.star(distinct=source_query.select_clause.is_distinct())

    if source_query.from_clause:
        from_clause = ImplicitFromClause(
            [
                DirectTableSource(tab)
                for tab in source_query.tables()
                if tab in referenced_tables
            ]
        )
    else:
        from_clause = None

    if source_query.where_clause:
        predicate_fragment = _get_predicate_fragment(
            source_query.where_clause.predicate, referenced_tables
        )
        where_clause = Where(predicate_fragment) if predicate_fragment else None
    else:
        where_clause = None

    if source_query.groupby_clause:
        group_column_fragment = [
            col
            for col in source_query.groupby_clause.group_columns
            if col.tables().issubset(referenced_tables)
        ]
        if group_column_fragment:
            groupby_clause = GroupBy(
                group_column_fragment, source_query.groupby_clause.distinct
            )
        else:
            groupby_clause = None
    else:
        groupby_clause = None

    if source_query.having_clause:
        having_fragment = _get_predicate_fragment(
            source_query.having_clause.condition, referenced_tables
        )
        having_clause = Having(having_fragment) if having_fragment else None
    else:
        having_clause = None

    return ImplicitSqlQuery(
        select_clause=select_clause,
        from_clause=from_clause,
        where_clause=where_clause,
        groupby_clause=groupby_clause,
        having_clause=having_clause,
        orderby_clause=orderby_clause,
        limit_clause=source_query.limit_clause,
        cte_clause=cte_clause,
    )
443
+
444
+
445
+ def _default_subquery_name(tables: Iterable[TableReference]) -> str:
446
+ """Constructs a valid SQL table name for a subquery consisting of specific tables.
447
+
448
+ Parameters
449
+ ----------
450
+ tables : Iterable[TableReference]
451
+ The tables that should be represented by a subquery.
452
+
453
+ Returns
454
+ -------
455
+ str
456
+ The target name of the subquery
457
+ """
458
+ return "_".join(table.identifier() for table in tables)
459
+
460
+
461
def expand_to_query(predicate: AbstractPredicate) -> ImplicitSqlQuery:
    """Provides a ``SELECT *`` query that computes the result set of a specific predicate.

    Parameters
    ----------
    predicate : AbstractPredicate
        The predicate to expand

    Returns
    -------
    ImplicitSqlQuery
        An SQL query of the form ``SELECT * FROM <predicate tables> WHERE <predicate>``.
    """
    select_clause = Select.star()
    # The predicate yields plain table references. Elsewhere in this module, the ImplicitFromClause constructor only
    # receives table sources, while create_for is used to build a clause from table references.
    from_clause = ImplicitFromClause.create_for(predicate.tables())
    where_clause = Where(predicate)
    return build_query([select_clause, from_clause, where_clause])
478
+
479
+
480
def move_into_subquery(
    query: SqlQuery, tables: Iterable[TableReference], subquery_name: str = ""
) -> SqlQuery:
    """Transforms a specific query by moving some of its tables into a subquery.

    This transformation renames all usages of columns that are now produced by the subquery to references to the virtual
    subquery table instead. The filter predicates of the subquery tables and the joins among them are moved into the
    subquery.

    Notice that this transformation currently only really works for "good natured" queries, i.e. mostly implicit SPJ+ queries.
    Notably, the transformation is only supported for queries that do not already contain subqueries, since moving tables
    between subqueries is a quite tricky process. Likewise, the renaming is only applied at a table-level. If the tables export
    columns of the same name, these are not renamed and the transformation fails as well. If in doubt, you should definitely
    check the output of this method for your complicated queries to prevent bad surprises!

    Parameters
    ----------
    query : SqlQuery
        The query to transform
    tables : Iterable[TableReference]
        The tables that should be placed into a subquery
    subquery_name : str, optional
        The target name of the virtual subquery table. If empty, a default name (consisting of all the subquery tables) is
        generated

    Returns
    -------
    SqlQuery
        The transformed query

    Raises
    ------
    ValueError
        If the `query` does not contain a ``FROM`` clause.
    ValueError
        If the query contains virtual tables
    ValueError
        If `tables` contains less than 2 entries. In this case, using a subquery is completely pointless.
    ValueError
        If the tables that should become part of the subquery both provide columns of the same name, and these columns are used
        in the rest of the query. This level of renaming is currently not accounted for.
    """
    if not query.from_clause:
        raise ValueError("Cannot create a subquery for a query without a FROM clause")
    if any(table.virtual for table in query.tables()):
        raise ValueError("Cannot move into subquery for queries with virtual tables")

    tables = set(tables)

    # No explicit CTE check is required: a CTE produces a virtual table, which already fails the previous test
    if len(tables) < 2:
        raise ValueError("At least two tables required")

    predicates = query.predicates()
    all_referenced_columns = util.set_union(
        clause.columns() for clause in query.clauses() if not isinstance(clause, From)
    )
    columns_from_subquery_tables = {
        column for column in all_referenced_columns if column.table in tables
    }
    if len({column.name for column in columns_from_subquery_tables}) < len(
        columns_from_subquery_tables
    ):
        raise ValueError(
            "Cannot create subquery: subquery tables export columns of the same name"
        )

    subquery_name = subquery_name if subquery_name else _default_subquery_name(tables)
    subquery_table = TableReference.create_virtual(subquery_name)
    renamed_columns = {
        column: ColumnReference(column.name, subquery_table)
        for column in columns_from_subquery_tables
    }

    subquery_predicates: list[AbstractPredicate] = []
    for table in tables:
        filter_predicate = predicates.filters_for(table)
        if not filter_predicate:
            continue
        subquery_predicates.append(filter_predicate)
    join_predicates = predicates.joins_between(tables, tables)
    if join_predicates:
        subquery_predicates.append(join_predicates)

    subquery_select = Select.create_for(columns_from_subquery_tables)
    subquery_from = ImplicitFromClause.create_for(tables)
    subquery_where = (
        Where(CompoundPredicate.create_and(subquery_predicates))
        if subquery_predicates
        else None
    )
    subquery_clauses = [subquery_select, subquery_from, subquery_where]
    subquery = build_query(clause for clause in subquery_clauses if clause)
    subquery_table_source = SubqueryTableSource(subquery, subquery_name)

    # Drop every table source that is fully covered by the subquery. This must be a non-strict subset test: a source that
    # spans exactly the subquery tables (e.g. a join of all of them) would otherwise be kept and the tables would appear
    # twice in the resulting FROM clause.
    updated_from_sources = [
        table_source
        for table_source in query.from_clause.items
        if not table_source.tables() <= tables
    ]
    update_from_clause = ImplicitFromClause.create_for(
        updated_from_sources + [subquery_table_source]
    )
    updated_predicate = query.where_clause.predicate if query.where_clause else None
    for predicate in subquery_predicates:
        updated_predicate = remove_predicate(updated_predicate, predicate)
    updated_where_clause = Where(updated_predicate) if updated_predicate else None

    updated_query = drop_clause(query, [From, Where])
    updated_query = add_clause(
        updated_query, [update_from_clause, updated_where_clause]
    )

    # Columns of the subquery tables are now exported by the virtual subquery table, so references to them are renamed
    # everywhere outside the FROM clause.
    updated_other_clauses: list[BaseClause] = []
    for clause in updated_query.clauses():
        if isinstance(clause, From):
            continue
        renamed_clause = rename_columns_in_clause(clause, renamed_columns)
        updated_other_clauses.append(renamed_clause)
    final_query = replace_clause(updated_query, updated_other_clauses)
    return final_query
600
+
601
+
602
def add_ec_predicates(query: ImplicitSqlQuery) -> ImplicitSqlQuery:
    """Rewrites the ``WHERE`` clause so that it explicitly contains every predicate implied by the join equivalence classes.

    The new ``WHERE`` clause is a conjunction of the predicates generated for the join equivalence classes, followed by all
    filter predicates of the original query. Queries without a ``WHERE`` clause are returned unchanged.

    NOTE(review): join predicates are only re-added through the generated equivalence-class predicates. Joins that do not
    take part in an equivalence class (e.g. presumably non-equi joins) would not appear in the new ``WHERE`` clause -
    confirm against `determine_join_equivalence_classes`.

    Parameters
    ----------
    query : ImplicitSqlQuery
        The query to analyze

    Returns
    -------
    ImplicitSqlQuery
        An equivalent query that explicitly contains all predicates from join equivalence classes.

    See Also
    --------
    determine_join_equivalence_classes
    generate_predicates_for_equivalence_classes
    """
    if not query.where_clause:
        return query

    query_predicates = query.predicates()
    equivalence_classes = determine_join_equivalence_classes(query_predicates.joins())
    expanded_joins = generate_predicates_for_equivalence_classes(equivalence_classes)

    conjuncts = [*expanded_joins, *query_predicates.filters()]
    expanded_where = Where(CompoundPredicate.create_and(conjuncts))
    return replace_clause(query, expanded_where)
631
+
632
+
633
def as_star_query(source_query: QueryType) -> QueryType:
    """Replaces the projection of a query by ``SELECT *``.

    All other clauses are kept as they are. Be aware that this may break queries where an alias introduced in the
    ``SELECT`` clause is used elsewhere, e.g. in an ``ORDER BY`` clause (``SELECT SUM(foo) AS f FROM bar ORDER BY f``).
    Such aliases are not resolved.

    Parameters
    ----------
    source_query : QueryType
        The query to transform

    Returns
    -------
    QueryType
        A variant of the input query that uses a ``SELECT *`` projection.
    """
    star_projection = Select.star()
    non_select_clauses = (
        clause for clause in source_query.clauses() if not isinstance(clause, Select)
    )
    return build_query([*non_select_clauses, star_projection])
655
+
656
+
657
def as_count_star_query(source_query: QueryType) -> QueryType:
    """Replaces the projection of a query by ``SELECT COUNT(*)``.

    All other clauses are kept as they are. Be aware that this may break queries where an alias introduced in the
    ``SELECT`` clause is used elsewhere, e.g. in an ``ORDER BY`` clause (``SELECT SUM(foo) AS f FROM bar ORDER BY f``).
    Such aliases are not resolved.

    Parameters
    ----------
    source_query : QueryType
        The query to transform

    Returns
    -------
    QueryType
        A variant of the input query that uses a ``SELECT COUNT(*)`` projection.
    """
    count_projection = Select.count_star()
    non_select_clauses = (
        clause for clause in source_query.clauses() if not isinstance(clause, Select)
    )
    return build_query([*non_select_clauses, count_projection])
679
+
680
+
681
def drop_hints(
    query: SelectQueryType, preparatory_statements_only: bool = False
) -> SelectQueryType:
    """Strips (parts of) the hint block from a query.

    Parameters
    ----------
    query : SelectQueryType
        The query to transform
    preparatory_statements_only : bool, optional
        If enabled, only the preparatory statements of the hint block are dropped and the actual query hints are kept.
        By default (``False``), the whole hint block is removed, regardless of its contents.

    Returns
    -------
    SelectQueryType
        The query without the (selected parts of the) hint block
    """
    if preparatory_statements_only and query.hints:
        # Keep only the query hints; the (presumably preparatory) first argument of the hint block is emptied.
        replacement_hint = Hint("", query.hints.query_hints)
    else:
        replacement_hint = None

    remaining_clauses = [
        clause for clause in query.clauses() if not isinstance(clause, Hint)
    ]
    remaining_clauses.append(replacement_hint)
    return build_query(remaining_clauses)
708
+
709
+
710
def as_explain(
    query: SelectQueryType, explain: Explain = Explain.plan()
) -> SelectQueryType:
    """Turns a query into an ``EXPLAIN`` query.

    Any ``EXPLAIN`` block that the query already has is replaced by the given one.

    Parameters
    ----------
    query : SelectQueryType
        The query to transform
    explain : Explain, optional
        The ``EXPLAIN`` block to attach. Defaults to a standard ``Explain.plan()`` block.

    Returns
    -------
    SelectQueryType
        The transformed query
    """
    clauses_without_explain = [
        clause for clause in query.clauses() if not isinstance(clause, Explain)
    ]
    clauses_without_explain.append(explain)
    return build_query(clauses_without_explain)
731
+
732
+
733
def as_explain_analyze(query: SelectQueryType) -> SelectQueryType:
    """Turns a query into an ``EXPLAIN ANALYZE`` query.

    This is a shortcut for `as_explain` with an ``Explain.explain_analyze()`` block.

    Parameters
    ----------
    query : SelectQueryType
        The query to transform

    Returns
    -------
    SelectQueryType
        The transformed query. It uses an ``EXPLAIN ANALYZE`` block with the default output format. If a different format
        is desired, call `as_explain` directly with an explicit ``EXPLAIN`` block.
    """
    analyze_block = Explain.explain_analyze()
    return as_explain(query, explain=analyze_block)
748
+
749
+
750
def remove_predicate(
    predicate: Optional[AbstractPredicate], predicate_to_remove: AbstractPredicate
) -> Optional[AbstractPredicate]:
    """Drops a specific predicate from the predicate hierarchy.

    If necessary, the hierarchy will be simplified. For example, if the `predicate_to_remove` is one of two children of a
    conjunction, the removal would leave a conjunction of just a single predicate. In this case, the conjunction can be dropped
    altogether, leaving just the other child predicate. The same also applies to disjunctions and negations. A negation whose
    negated predicate is removed entirely is dropped as well.

    Parameters
    ----------
    predicate : Optional[AbstractPredicate]
        The predicate hierarchy from which the predicate should be removed. If this is ``None``, no removal is attempted.
    predicate_to_remove : AbstractPredicate
        The predicate that should be removed.

    Returns
    -------
    Optional[AbstractPredicate]
        The resulting (simplified) predicate hierarchy. Will be ``None`` if there are no meaningful predicates left after
        removal, or if the `predicate` equals the `predicate_to_remove`.
    """
    if not predicate or predicate == predicate_to_remove:
        return None
    if not isinstance(predicate, CompoundPredicate):
        return predicate

    if predicate.operation == CompoundOperator.Not:
        # The negated predicate must be unwrapped before recursing. Passing the children collection itself never matches
        # `predicate_to_remove` and is never descended into, so removals below a negation would silently do nothing.
        negated = predicate.children
        if isinstance(negated, Sequence):
            negated = negated[0]
        updated_child = remove_predicate(negated, predicate_to_remove)
        return CompoundPredicate.create_not(updated_child) if updated_child else None

    updated_children = [
        remove_predicate(child_pred, predicate_to_remove)
        for child_pred in predicate.children
    ]
    updated_children = [child_pred for child_pred in updated_children if child_pred]
    if not updated_children:
        return None
    elif len(updated_children) == 1:
        # A conjunction/disjunction with a single operand is just that operand.
        return updated_children[0]
    else:
        return CompoundPredicate(predicate.operation, updated_children)
792
+
793
+
794
+ def add_clause(
795
+ query: SelectQueryType, clauses_to_add: BaseClause | Iterable[BaseClause]
796
+ ) -> SelectQueryType:
797
+ """Creates a new SQL query, potentailly with additional clauses.
798
+
799
+ No validation is performed. Conflicts are resolved according to the rules of `build_query`. This means that the query
800
+ can potentially be switched from an implicit query to an explicit one and vice-versa.
801
+
802
+ Parameters
803
+ ----------
804
+ query : SqlQuery
805
+ The query to which the clause(s) should be added
806
+ clauses_to_add : BaseClause | Iterable[BaseClause]
807
+ The new clauses
808
+
809
+ Returns
810
+ -------
811
+ SqlQuery
812
+ A new clauses consisting of the old query's clauses and the `clauses_to_add`. Duplicate clauses are overwritten by
813
+ the `clauses_to_add`.
814
+ """
815
+ clauses_to_add = util.enlist(clauses_to_add)
816
+ new_clause_types = {type(clause) for clause in clauses_to_add}
817
+ remaining_clauses = [
818
+ clause for clause in query.clauses() if type(clause) not in new_clause_types
819
+ ]
820
+ return build_query(remaining_clauses + list(clauses_to_add))
821
+
822
+
823
+ ClauseDescription = typing.Union[
824
+ typing.Type, BaseClause, Iterable[typing.Type | BaseClause]
825
+ ]
826
+ """Denotes different ways clauses to remove can be denoted.
827
+
828
+ See Also
829
+ --------
830
+ drop_clause
831
+ """
832
+
833
+
834
+ def drop_clause(
835
+ query: SelectQueryType, clauses_to_drop: ClauseDescription
836
+ ) -> SelectQueryType:
837
+ """Removes specific clauses from a query.
838
+
839
+ The clauses can be denoted in two different ways: either as the raw type of the clause, or as an instance of the same
840
+ clause type as the one that should be removed. Notice that the instance of the clause does not need to be equal to the
841
+ clause of the query. It just needs to be the same type of clause.
842
+
843
+ This method does not perform any validation, other than the rules described in `build_query`.
844
+
845
+ Parameters
846
+ ----------
847
+ query : SelectQueryType
848
+ The query to remove clauses from
849
+ clauses_to_drop : ClauseDescription
850
+ The clause(s) to remove. This can be a single clause type or clause instance, or an iterable of clauses types,
851
+ intermixed with clause instances. In either way clauses of the desired types are dropped from the query.
852
+
853
+ Returns
854
+ -------
855
+ SelectQueryType
856
+ A query without the specified clauses
857
+
858
+ Examples
859
+ --------
860
+
861
+ The following two calls achieve exactly the same thing: getting rid of the ``LIMIT`` clause.
862
+
863
+ .. code-block:: python
864
+
865
+ drop_clause(query, Limit)
866
+ drop_clause(query, query.limit_clause)
867
+ """
868
+ clauses_to_drop = set(util.enlist(clauses_to_drop))
869
+ clauses_to_drop = {
870
+ drop if isinstance(drop, typing.Type) else type(drop)
871
+ for drop in clauses_to_drop
872
+ }
873
+ remaining_clauses = [
874
+ clause for clause in query.clauses() if type(clause) not in clauses_to_drop
875
+ ]
876
+ return build_query(remaining_clauses)
877
+
878
+
879
+ def replace_clause(
880
+ query: SelectQueryType, replacements: BaseClause | Iterable[BaseClause]
881
+ ) -> SelectQueryType:
882
+ """Creates a new SQL query with the replacements being used instead of the original clauses.
883
+
884
+ Clauses are matched on a per-type basis (including subclasses, i.e. a replacement can be a subclass of an existing clause).
885
+ Therefore, this function does not switch a query from implicit to explicit or vice-versa. Use a combination of
886
+ `drop_clause` and `add_clause` for that. If a replacement is not present in the original query, it is simply ignored.
887
+
888
+ No validation other than the rules of `build_query` is performed.
889
+
890
+ Parameters
891
+ ----------
892
+ query : SelectQueryType
893
+ The query to update
894
+ replacements : BaseClause | Iterable[BaseClause]
895
+ The new clause instances that should be used instead of the old ones.
896
+
897
+ Returns
898
+ -------
899
+ SelectQueryType
900
+ An updated query where the matching `replacements` clauses are used in place of the clause instances that were
901
+ originally present in the query
902
+ """
903
+ available_replacements: set[BaseClause] = set(util.enlist(replacements))
904
+
905
+ replaced_clauses: list[BaseClause] = []
906
+ for current_clause in query.clauses():
907
+ if not available_replacements:
908
+ replaced_clauses.append(current_clause)
909
+ continue
910
+
911
+ final_clause = current_clause
912
+ for replacement_clause in available_replacements:
913
+ if isinstance(replacement_clause, type(current_clause)):
914
+ final_clause = replacement_clause
915
+ break
916
+ replaced_clauses.append(final_clause)
917
+ if final_clause in available_replacements:
918
+ available_replacements.remove(final_clause)
919
+
920
+ return build_query(replaced_clauses)
921
+
922
+
923
+ def _replace_expression_in_predicate(
924
+ predicate: Optional[PredicateType],
925
+ replacement: Callable[[SqlExpression], SqlExpression],
926
+ ) -> Optional[PredicateType]:
927
+ """Handler to update all expressions in a specific predicate.
928
+
929
+ This method does not perform any sanity checks on the new predicate.
930
+
931
+
932
+ Parameters
933
+ ----------
934
+ predicate : PredicateType
935
+ The predicate to update. Can be ``None``, in which case no replacement is performed.
936
+ replacement : Callable[[SqlExpression], SqlExpression]
937
+ A function mapping each expression to a (potentially updated) expression
938
+
939
+ Returns
940
+ -------
941
+ Optional[PredicateType]
942
+ The updated predicate. Can be ``None``, if `predicate` already was.
943
+
944
+ Raises
945
+ ------
946
+ ValueError
947
+ If the predicate is of no known type. This indicates that this method is missing a handler for a specific predicate
948
+ type that was added later on.
949
+ """
950
+ if not predicate:
951
+ return None
952
+
953
+ if isinstance(predicate, BinaryPredicate):
954
+ renamed_first_arg = replacement(predicate.first_argument)
955
+ renamed_second_arg = replacement(predicate.second_argument)
956
+ return BinaryPredicate(
957
+ predicate.operation, renamed_first_arg, renamed_second_arg
958
+ )
959
+ elif isinstance(predicate, BetweenPredicate):
960
+ renamed_col = replacement(predicate.column)
961
+ renamed_interval_start = replacement(predicate.interval_start)
962
+ renamed_interval_end = replacement(predicate.interval_end)
963
+ return BetweenPredicate(
964
+ renamed_col, (renamed_interval_start, renamed_interval_end)
965
+ )
966
+ elif isinstance(predicate, InPredicate):
967
+ renamed_col = replacement(predicate.column)
968
+ renamed_vals = [replacement(val) for val in predicate.values]
969
+ return InPredicate(renamed_col, renamed_vals)
970
+ elif isinstance(predicate, UnaryPredicate):
971
+ return UnaryPredicate(replacement(predicate.column), predicate.operation)
972
+ elif isinstance(predicate, CompoundPredicate):
973
+ if predicate.operation == CompoundOperator.Not:
974
+ renamed_children = [
975
+ _replace_expression_in_predicate(predicate.children, replacement)
976
+ ]
977
+ else:
978
+ renamed_children = [
979
+ _replace_expression_in_predicate(child, replacement)
980
+ for child in predicate.children
981
+ ]
982
+ return CompoundPredicate(predicate.operation, renamed_children)
983
+ else:
984
+ raise ValueError("Unknown predicate type: " + str(predicate))
985
+
986
+
987
+ def _replace_expression_in_table_source(
988
+ table_source: Optional[TableSource],
989
+ replacement: Callable[[SqlExpression], SqlExpression],
990
+ ) -> Optional[TableSource]:
991
+ """Handler to update all expressions in a table source.
992
+
993
+ This method does not perform any sanity checks on the updated sources.
994
+
995
+ Parameters
996
+ ----------
997
+ table_source : TableSource
998
+ The source to update. Can be ``None``, in which case no replacement is performed.
999
+ replacement : Callable[[SqlExpression], SqlExpression]
1000
+ A function mapping each expression to a (potentially updated) expression
1001
+
1002
+ Returns
1003
+ -------
1004
+ Optional[TableSource]
1005
+ The updated table source. Can be ``None``, if `table_source` already was.
1006
+
1007
+ Raises
1008
+ ------
1009
+ ValueError
1010
+ If the table source is of no known type. This indicates that this method is missing a handler for a specific source
1011
+ type that was added later on.
1012
+ """
1013
+ if table_source is None:
1014
+ return None
1015
+
1016
+ match table_source:
1017
+ case DirectTableSource():
1018
+ # no expressions in a plain table reference, we are done here
1019
+ return table_source
1020
+ case SubqueryTableSource():
1021
+ replaced_subquery = replacement(table_source.expression)
1022
+ assert isinstance(replaced_subquery, SubqueryExpression)
1023
+ replaced_subquery = replace_expressions(
1024
+ replaced_subquery.query, replacement
1025
+ )
1026
+ return SubqueryTableSource(
1027
+ replaced_subquery,
1028
+ table_source.target_name,
1029
+ lateral=table_source.lateral,
1030
+ )
1031
+ case JoinTableSource():
1032
+ replaced_left = _replace_expression_in_table_source(
1033
+ table_source.left, replacement
1034
+ )
1035
+ replaced_right = _replace_expression_in_table_source(
1036
+ table_source.right, replacement
1037
+ )
1038
+ replaced_condition = _replace_expression_in_predicate(
1039
+ table_source.join_condition, replacement
1040
+ )
1041
+ return JoinTableSource(
1042
+ replaced_left,
1043
+ replaced_right,
1044
+ join_condition=replaced_condition,
1045
+ join_type=table_source.join_type,
1046
+ )
1047
+ case ValuesTableSource():
1048
+ replaced_values = [
1049
+ tuple([replacement(val) for val in row]) for row in table_source.rows
1050
+ ]
1051
+ return ValuesTableSource(
1052
+ replaced_values,
1053
+ alias=table_source.table.identifier(),
1054
+ columns=table_source.cols,
1055
+ )
1056
+ case FunctionTableSource():
1057
+ replaced_function = replacement(table_source.function)
1058
+ return FunctionTableSource(replaced_function, table_source.target_table)
1059
+ case _:
1060
+ raise TypeError("Unknown table source type: " + str(table_source))
1061
+
1062
+
1063
+ def _replace_expressions_in_clause(
1064
+ clause: Optional[ClauseType], replacement: Callable[[SqlExpression], SqlExpression]
1065
+ ) -> Optional[ClauseType]:
1066
+ """Handler to update all expressions in a clause.
1067
+
1068
+ This method does not perform any sanity checks on the updated clauses.
1069
+
1070
+ Parameters
1071
+ ----------
1072
+ clause : ClauseType
1073
+ The clause to update. Can be ``None``, in which case no replacement is performed.
1074
+ replacement : Callable[[SqlExpression], SqlExpression]
1075
+ A function mapping each expression to a (potentially updated) expression
1076
+
1077
+ Returns
1078
+ -------
1079
+ Optional[ClauseType]
1080
+ The updated clause. Can be ``None``, if `clause` already was.
1081
+
1082
+ Raises
1083
+ ------
1084
+ ValueError
1085
+ If the clause is of no known type. This indicates that this method is missing a handler for a specific clause type that
1086
+ was added later on.
1087
+ """
1088
+ if not clause:
1089
+ return None
1090
+
1091
+ if isinstance(clause, Hint) or isinstance(clause, Explain):
1092
+ return clause
1093
+ elif isinstance(clause, CommonTableExpression):
1094
+ replaced_queries: list[WithQuery] = []
1095
+
1096
+ for cte in clause.queries:
1097
+ if isinstance(cte, ValuesWithQuery):
1098
+ replaced_values = [
1099
+ tuple([replacement(val) for val in row]) for row in cte.rows
1100
+ ]
1101
+ replaced_cte = ValuesWithQuery(
1102
+ replaced_values,
1103
+ target_name=cte.target_table,
1104
+ columns=cte.cols,
1105
+ materialized=cte.materialized,
1106
+ )
1107
+ replaced_queries.append(replaced_cte)
1108
+ continue
1109
+
1110
+ replaced_cte = WithQuery(
1111
+ replace_expressions(cte.query, replacement),
1112
+ cte.target_table,
1113
+ materialized=cte.materialized,
1114
+ )
1115
+ replaced_queries.append(replaced_cte)
1116
+
1117
+ return CommonTableExpression(replaced_queries, recursive=clause.recursive)
1118
+ elif isinstance(clause, Select):
1119
+ replaced_targets = [
1120
+ BaseProjection(replacement(proj.expression), proj.target_name)
1121
+ for proj in clause.targets
1122
+ ]
1123
+ return Select(replaced_targets, distinct=clause.distinct_specifier())
1124
+ elif isinstance(clause, ImplicitFromClause):
1125
+ return clause
1126
+ elif isinstance(clause, ExplicitFromClause):
1127
+ replaced_joins = [
1128
+ _replace_expression_in_table_source(join, replacement)
1129
+ for join in clause.items
1130
+ ]
1131
+ return ExplicitFromClause(replaced_joins)
1132
+ elif isinstance(clause, From):
1133
+ replaced_contents = [
1134
+ _replace_expression_in_table_source(target, replacement)
1135
+ for target in clause.items
1136
+ ]
1137
+ return From(replaced_contents)
1138
+ elif isinstance(clause, Where):
1139
+ return Where(_replace_expression_in_predicate(clause.predicate, replacement))
1140
+ elif isinstance(clause, GroupBy):
1141
+ replaced_cols = [replacement(col) for col in clause.group_columns]
1142
+ return GroupBy(replaced_cols, clause.distinct)
1143
+ elif isinstance(clause, Having):
1144
+ return Having(_replace_expression_in_predicate(clause.condition, replacement))
1145
+ elif isinstance(clause, OrderBy):
1146
+ replaced_cols = [
1147
+ OrderByExpression(replacement(col.column), col.ascending, col.nulls_first)
1148
+ for col in clause.expressions
1149
+ ]
1150
+ return OrderBy(replaced_cols)
1151
+ elif isinstance(clause, Limit):
1152
+ return clause
1153
+ elif isinstance(clause, UnionClause):
1154
+ replaced_left = replace_expressions(clause.left_query, replacement)
1155
+ replaced_right = replace_expressions(clause.right_query, replacement)
1156
+ return UnionClause(replaced_left, replaced_right, union_all=clause.union_all)
1157
+ elif isinstance(clause, IntersectClause):
1158
+ replaced_left = replace_expressions(clause.left_query, replacement)
1159
+ replaced_right = replace_expressions(clause.right_query, replacement)
1160
+ return IntersectClause(replaced_left, replaced_right)
1161
+ elif isinstance(clause, ExceptClause):
1162
+ replaced_left = replace_expressions(clause.left_query, replacement)
1163
+ replaced_right = replace_expressions(clause.right_query, replacement)
1164
+ return ExceptClause(replaced_left, replaced_right)
1165
+ else:
1166
+ raise ValueError("Unknown clause: " + str(clause))
1167
+
1168
+
1169
+ def replace_expressions(
1170
+ query: SelectQueryType, replacement: Callable[[SqlExpression], SqlExpression]
1171
+ ) -> SelectQueryType:
1172
+ """Updates all expressions in a query.
1173
+
1174
+ The replacement handler can either produce entirely new expressions, or simply return the current expression instance if
1175
+ no update should be performed. Be very careful with this method since no sanity checks are performed, other than the rules
1176
+ of `build_query`.
1177
+
1178
+ Parameters
1179
+ ----------
1180
+ query : SelectQueryType
1181
+ The query to update
1182
+ replacement : Callable[[SqlExpression], SqlExpression]
1183
+ A function mapping each of the current expressions in the `query` to potentially updated expressions.
1184
+
1185
+ Returns
1186
+ -------
1187
+ SelectQueryType
1188
+ The updated query
1189
+ """
1190
+ replaced_clauses = [
1191
+ _replace_expressions_in_clause(clause, replacement)
1192
+ for clause in query.clauses()
1193
+ ]
1194
+ return build_query(replaced_clauses)
1195
+
1196
+
1197
+ def _perform_predicate_replacement(
1198
+ current_predicate: AbstractPredicate,
1199
+ target_predicate: AbstractPredicate,
1200
+ new_predicate: AbstractPredicate,
1201
+ ) -> AbstractPredicate:
1202
+ """Handler to change specific predicates in a predicate hierarchy to other predicates.
1203
+
1204
+ This does not perform any sanity checks on the updated predicate hierarchy, nor is the hierarchy simplified.
1205
+
1206
+ Parameters
1207
+ ----------
1208
+ current_predicate : AbstractPredicate
1209
+ The predicate hierarchy in which the updates should occur
1210
+ target_predicate : AbstractPredicate
1211
+ The predicate that should be replaced
1212
+ new_predicate : AbstractPredicate
1213
+ The new predicate that should be used instead of the `target_predicate`
1214
+
1215
+ Returns
1216
+ -------
1217
+ AbstractPredicate
1218
+ The updated predicate
1219
+ """
1220
+ if current_predicate == target_predicate:
1221
+ return new_predicate
1222
+
1223
+ if isinstance(current_predicate, CompoundPredicate):
1224
+ if current_predicate.operation == CompoundOperator.Not:
1225
+ replaced_children = [
1226
+ _perform_predicate_replacement(
1227
+ current_predicate.children, target_predicate, new_predicate
1228
+ )
1229
+ ]
1230
+ else:
1231
+ replaced_children = [
1232
+ _perform_predicate_replacement(
1233
+ child_pred, target_predicate, new_predicate
1234
+ )
1235
+ for child_pred in current_predicate.children
1236
+ ]
1237
+ return CompoundPredicate(current_predicate.operation, replaced_children)
1238
+ else:
1239
+ return current_predicate
1240
+
1241
+
1242
+ def replace_predicate(
1243
+ query: ImplicitSqlQuery,
1244
+ predicate_to_replace: AbstractPredicate,
1245
+ new_predicate: AbstractPredicate,
1246
+ ) -> ImplicitSqlQuery:
1247
+ """Rewrites a specific query to use a new predicate in place of an old one.
1248
+
1249
+ In the current implementation this does only work for top-level predicates, i.e. subqueries and CTEs are not considered.
1250
+ Furthermore, only the ``WHERE`` clause and the ``HAVING`` clause are modified, since these should be the only ones that
1251
+ contain predicates.
1252
+
1253
+ If the predicate to replace is not found, nothing happens. In the same vein, no sanity checks are performed on the updated
1254
+ query.
1255
+
1256
+ Parameters
1257
+ ----------
1258
+ query : ImplicitSqlQuery
1259
+ The query update
1260
+ predicate_to_replace : AbstractPredicate
1261
+ The old predicate that should be dropped
1262
+ new_predicate : AbstractPredicate
1263
+ The predicate that should be used in place of `predicate_to_replace`. This can be an entirely different type of
1264
+ predicate, e.g. a conjunction of join conditions that replace a single join predicate.
1265
+
1266
+ Returns
1267
+ -------
1268
+ ImplicitSqlQuery
1269
+ The updated query
1270
+ """
1271
+ # TODO: also allow replacement in explicit SQL queries
1272
+ # TODO: allow predicate replacement in subqueries / CTEs
1273
+ # TODO: allow replacement in set queries
1274
+ if not query.where_clause and not query.having_clause:
1275
+ return query
1276
+
1277
+ if query.where_clause:
1278
+ replaced_predicate = _perform_predicate_replacement(
1279
+ query.where_clause.predicate, predicate_to_replace, new_predicate
1280
+ )
1281
+ replaced_where = Where(replaced_predicate)
1282
+ else:
1283
+ replaced_where = None
1284
+
1285
+ if query.having_clause:
1286
+ replaced_predicate = _perform_predicate_replacement(
1287
+ query.having_clause.condition, predicate_to_replace, new_predicate
1288
+ )
1289
+ replaced_having = Having(replaced_predicate)
1290
+ else:
1291
+ replaced_having = None
1292
+
1293
+ candidate_clauses = [replaced_where, replaced_having]
1294
+ return replace_clause(query, [clause for clause in candidate_clauses if clause])
1295
+
1296
+
1297
+ def rename_columns_in_query(
1298
+ query: SelectQueryType, available_renamings: dict[ColumnReference, ColumnReference]
1299
+ ) -> SelectQueryType:
1300
+ """Replaces specific column references by new references for an entire query.
1301
+
1302
+ Parameters
1303
+ ----------
1304
+ query : SelectQueryType
1305
+ The query to update
1306
+ available_renamings : dict[ColumnReference, ColumnReference]
1307
+ A dictionary mapping each of the old column values to the values that should be used instead.
1308
+
1309
+ Returns
1310
+ -------
1311
+ SelectQueryType
1312
+ The updated query
1313
+
1314
+ Raises
1315
+ ------
1316
+ TypeError
1317
+ If the query is of no known type. This indicates that this method is missing a handler for a specific query type that
1318
+ was added later on.
1319
+ """
1320
+ renamed_cte = rename_columns_in_clause(query.cte_clause, available_renamings)
1321
+ renamed_having = rename_columns_in_clause(query.having_clause, available_renamings)
1322
+ renamed_orderby = rename_columns_in_clause(
1323
+ query.orderby_clause, available_renamings
1324
+ )
1325
+
1326
+ if isinstance(query, SetQuery):
1327
+ renamed_left = rename_columns_in_query(query.left_query, available_renamings)
1328
+ renamed_right = rename_columns_in_query(query.right_query, available_renamings)
1329
+ return SetQuery(
1330
+ renamed_left,
1331
+ renamed_right,
1332
+ query.set_operation,
1333
+ cte_clause=renamed_cte,
1334
+ orderby_clause=renamed_orderby,
1335
+ limit_clause=query.limit_clause,
1336
+ hints=query.hints,
1337
+ explain_clause=query.explain,
1338
+ )
1339
+
1340
+ renamed_select = rename_columns_in_clause(query.select_clause, available_renamings)
1341
+ renamed_from = rename_columns_in_clause(query.from_clause, available_renamings)
1342
+ renamed_where = rename_columns_in_clause(query.where_clause, available_renamings)
1343
+ renamed_groupby = rename_columns_in_clause(
1344
+ query.groupby_clause, available_renamings
1345
+ )
1346
+
1347
+ if isinstance(query, ImplicitSqlQuery):
1348
+ return ImplicitSqlQuery(
1349
+ select_clause=renamed_select,
1350
+ from_clause=renamed_from,
1351
+ where_clause=renamed_where,
1352
+ groupby_clause=renamed_groupby,
1353
+ having_clause=renamed_having,
1354
+ orderby_clause=renamed_orderby,
1355
+ limit_clause=query.limit_clause,
1356
+ cte_clause=renamed_cte,
1357
+ hints=query.hints,
1358
+ explain_clause=query.explain,
1359
+ )
1360
+ elif isinstance(query, ExplicitSqlQuery):
1361
+ return ExplicitSqlQuery(
1362
+ select_clause=renamed_select,
1363
+ from_clause=renamed_from,
1364
+ where_clause=renamed_where,
1365
+ groupby_clause=renamed_groupby,
1366
+ having_clause=renamed_having,
1367
+ orderby_clause=renamed_orderby,
1368
+ limit_clause=query.limit_clause,
1369
+ cte_clause=renamed_cte,
1370
+ hints=query.hints,
1371
+ explain_clause=query.explain,
1372
+ )
1373
+ elif isinstance(query, MixedSqlQuery):
1374
+ return MixedSqlQuery(
1375
+ select_clause=renamed_select,
1376
+ from_clause=renamed_from,
1377
+ where_clause=renamed_where,
1378
+ groupby_clause=renamed_groupby,
1379
+ having_clause=renamed_having,
1380
+ orderby_clause=renamed_orderby,
1381
+ limit_clause=query.limit_clause,
1382
+ cte_clause=renamed_cte,
1383
+ hints=query.hints,
1384
+ explain_clause=query.explain,
1385
+ )
1386
+ else:
1387
+ raise TypeError("Unknown query type: " + str(query))
1388
+
1389
+
1390
+ def rename_columns_in_expression(
1391
+ expression: Optional[SqlExpression],
1392
+ available_renamings: dict[ColumnReference, ColumnReference],
1393
+ ) -> Optional[SqlExpression]:
1394
+ """Replaces references to specific columns in an expression.
1395
+
1396
+ Parameters
1397
+ ----------
1398
+ expression : Optional[SqlExpression]
1399
+ The expression to update. If ``None``, no renaming is performed.
1400
+ available_renamings : dict[ColumnReference, ColumnReference]
1401
+ A dictionary mapping each of the old column values to the values that should be used instead.
1402
+
1403
+ Returns
1404
+ -------
1405
+ Optional[SqlExpression]
1406
+ The updated expression. Can be ``None``, if `expression` already was.
1407
+
1408
+ Raises
1409
+ ------
1410
+ ValueError
1411
+ If the expression is of no known type. This indicates that this method is missing a handler for a specific expressoin
1412
+ type that was added later on.
1413
+ """
1414
+ if expression is None:
1415
+ return None
1416
+
1417
+ if isinstance(expression, StaticValueExpression) or isinstance(
1418
+ expression, StarExpression
1419
+ ):
1420
+ return expression
1421
+ elif isinstance(expression, ColumnExpression):
1422
+ return (
1423
+ ColumnExpression(available_renamings[expression.column])
1424
+ if expression.column in available_renamings
1425
+ else expression
1426
+ )
1427
+ elif isinstance(expression, CastExpression):
1428
+ renamed_child = rename_columns_in_expression(
1429
+ expression.casted_expression, available_renamings
1430
+ )
1431
+ return CastExpression(renamed_child, expression.target_type)
1432
+ elif isinstance(expression, MathExpression):
1433
+ renamed_first_arg = rename_columns_in_expression(
1434
+ expression.first_arg, available_renamings
1435
+ )
1436
+ renamed_second_arg = rename_columns_in_expression(
1437
+ expression.second_arg, available_renamings
1438
+ )
1439
+ return MathExpression(
1440
+ expression.operator, renamed_first_arg, renamed_second_arg
1441
+ )
1442
+ elif isinstance(expression, FunctionExpression):
1443
+ renamed_arguments = [
1444
+ rename_columns_in_expression(arg, available_renamings)
1445
+ for arg in expression.arguments
1446
+ ]
1447
+ return FunctionExpression(
1448
+ expression.function, renamed_arguments, distinct=expression.distinct
1449
+ )
1450
+ elif isinstance(expression, SubqueryExpression):
1451
+ return SubqueryExpression(
1452
+ rename_columns_in_query(expression.query, available_renamings)
1453
+ )
1454
+ elif isinstance(expression, WindowExpression):
1455
+ renamed_function = rename_columns_in_expression(
1456
+ expression.window_function, available_renamings
1457
+ )
1458
+ renamed_partition = [
1459
+ rename_columns_in_expression(part, available_renamings)
1460
+ for part in expression.partitioning
1461
+ ]
1462
+ renamed_orderby = (
1463
+ rename_columns_in_clause(expression.ordering, available_renamings)
1464
+ if expression.ordering
1465
+ else None
1466
+ )
1467
+ renamed_filter = (
1468
+ rename_columns_in_predicate(
1469
+ expression.filter_condition, available_renamings
1470
+ )
1471
+ if expression.filter_condition
1472
+ else None
1473
+ )
1474
+ return WindowExpression(
1475
+ renamed_function,
1476
+ partitioning=renamed_partition,
1477
+ ordering=renamed_orderby,
1478
+ filter_condition=renamed_filter,
1479
+ )
1480
+ elif isinstance(expression, CaseExpression):
1481
+ renamed_cases = [
1482
+ (
1483
+ rename_columns_in_predicate(condition, available_renamings),
1484
+ rename_columns_in_expression(result, available_renamings),
1485
+ )
1486
+ for condition, result in expression.cases
1487
+ ]
1488
+ renamed_else = (
1489
+ rename_columns_in_expression(
1490
+ expression.else_expression, available_renamings
1491
+ )
1492
+ if expression.else_expression
1493
+ else None
1494
+ )
1495
+ return CaseExpression(renamed_cases, else_expr=renamed_else)
1496
+ else:
1497
+ raise ValueError("Unknown expression type: " + str(expression))
1498
+
1499
+
1500
+ def _rename_columns_in_expression(
1501
+ expression: Optional[SqlExpression],
1502
+ available_renamings: dict[ColumnReference, ColumnReference],
1503
+ ) -> Optional[SqlExpression]:
1504
+ """See `rename_columns_in_expression` for details.
1505
+
1506
+ See Also
1507
+ --------
1508
+ rename_columns_in_expression
1509
+ """
1510
+ warnings.warn(
1511
+ "This method is deprecated and will be removed in the future. Use `rename_columns_in_expression` instead.",
1512
+ FutureWarning,
1513
+ )
1514
+ return rename_columns_in_expression(expression, available_renamings)
1515
+
1516
+
1517
+ def rename_columns_in_predicate(
1518
+ predicate: Optional[AbstractPredicate],
1519
+ available_renamings: dict[ColumnReference, ColumnReference],
1520
+ ) -> Optional[AbstractPredicate]:
1521
+ """Replaces all references to specific columns in a predicate by new references.
1522
+
1523
+ Parameters
1524
+ ----------
1525
+ predicate : Optional[AbstractPredicate]
1526
+ The predicate to update. Can be ``None``, in which case no update is performed.
1527
+ available_renamings : dict[ColumnReference, ColumnReference]
1528
+ A dictionary mapping each of the old column values to the values that should be used instead.
1529
+
1530
+ Returns
1531
+ -------
1532
+ Optional[AbstractPredicate]
1533
+ The updated predicate. Can be ``None``, if `predicate` already was.
1534
+
1535
+ Raises
1536
+ ------
1537
+ ValueError
1538
+ If the query is of no known type. This indicates that this method is missing a handler for a specific query type that
1539
+ was added later on.
1540
+ """
1541
+ if not predicate:
1542
+ return None
1543
+
1544
+ if isinstance(predicate, BinaryPredicate):
1545
+ renamed_first_arg = rename_columns_in_expression(
1546
+ predicate.first_argument, available_renamings
1547
+ )
1548
+ renamed_second_arg = rename_columns_in_expression(
1549
+ predicate.second_argument, available_renamings
1550
+ )
1551
+ return BinaryPredicate(
1552
+ predicate.operation, renamed_first_arg, renamed_second_arg
1553
+ )
1554
+ elif isinstance(predicate, BetweenPredicate):
1555
+ renamed_col = rename_columns_in_expression(
1556
+ predicate.column, available_renamings
1557
+ )
1558
+ renamed_interval_start = rename_columns_in_expression(
1559
+ predicate.interval_start, available_renamings
1560
+ )
1561
+ renamed_interval_end = rename_columns_in_expression(
1562
+ predicate.interval_end, available_renamings
1563
+ )
1564
+ return BetweenPredicate(
1565
+ renamed_col, (renamed_interval_start, renamed_interval_end)
1566
+ )
1567
+ elif isinstance(predicate, InPredicate):
1568
+ renamed_col = rename_columns_in_expression(
1569
+ predicate.column, available_renamings
1570
+ )
1571
+ renamed_vals = [
1572
+ rename_columns_in_expression(val, available_renamings)
1573
+ for val in predicate.values
1574
+ ]
1575
+ return InPredicate(renamed_col, renamed_vals)
1576
+ elif isinstance(predicate, UnaryPredicate):
1577
+ return UnaryPredicate(
1578
+ rename_columns_in_expression(predicate.column, available_renamings),
1579
+ predicate.operation,
1580
+ )
1581
+ elif isinstance(predicate, CompoundPredicate):
1582
+ renamed_children = (
1583
+ [rename_columns_in_predicate(predicate.children, available_renamings)]
1584
+ if predicate.operation == CompoundOperator.Not
1585
+ else [
1586
+ rename_columns_in_predicate(child, available_renamings)
1587
+ for child in predicate.children
1588
+ ]
1589
+ )
1590
+ return CompoundPredicate(predicate.operation, renamed_children)
1591
+ else:
1592
+ raise ValueError("Unknown predicate type: " + str(predicate))
1593
+
1594
+
1595
+ def _rename_columns_in_table_source(
1596
+ table_source: TableSource,
1597
+ available_renamings: dict[ColumnReference, ColumnReference],
1598
+ ) -> Optional[TableSource]:
1599
+ """Handler method to replace all references to specific columns by new columns.
1600
+
1601
+ Parameters
1602
+ ----------
1603
+ table_source : TableSource
1604
+ The source that should be updated
1605
+ available_renamings : dict[ColumnReference, ColumnReference]
1606
+ A dictionary mapping each of the old column values to the values that should be used instead.
1607
+
1608
+ Returns
1609
+ -------
1610
+ Optional[TableSource]
1611
+ The updated source. Can be ``None``, if `table_source` already was.
1612
+
1613
+ Raises
1614
+ ------
1615
+ TypeError
1616
+ If the source is of no known type. This indicates that this method is missing a handler for a specific source type that
1617
+ was added later on.
1618
+ """
1619
+ if table_source is None:
1620
+ return None
1621
+
1622
+ match table_source:
1623
+ case DirectTableSource():
1624
+ # no columns in a plain table reference, we are done here
1625
+ return table_source
1626
+
1627
+ case SubqueryTableSource():
1628
+ renamed_subquery = rename_columns_in_query(
1629
+ table_source.query, available_renamings
1630
+ )
1631
+ return SubqueryTableSource(
1632
+ renamed_subquery, table_source.target_name, lateral=table_source.lateral
1633
+ )
1634
+
1635
+ case JoinTableSource():
1636
+ renamed_left = _rename_columns_in_table_source(
1637
+ table_source.left, available_renamings
1638
+ )
1639
+ renamed_right = _rename_columns_in_table_source(
1640
+ table_source.right, available_renamings
1641
+ )
1642
+ renamed_condition = rename_columns_in_predicate(
1643
+ table_source.join_condition, available_renamings
1644
+ )
1645
+ return JoinTableSource(
1646
+ renamed_left,
1647
+ renamed_right,
1648
+ join_condition=renamed_condition,
1649
+ join_type=table_source.join_type,
1650
+ )
1651
+
1652
+ case ValuesTableSource():
1653
+ if not any(
1654
+ col.belongs_to(table_source.table) for col in available_renamings
1655
+ ):
1656
+ return table_source
1657
+
1658
+ for current_col, target_col in available_renamings.items():
1659
+ if not current_col.belongs_to(table_source.table):
1660
+ continue
1661
+ if current_col.table != target_col.table:
1662
+ raise ValueError(
1663
+ "Cannot rename columns in a VALUES table source to a different table"
1664
+ )
1665
+
1666
+ # if we found a column that should be renamed, we need to replace the whole column specification
1667
+ # this process might be repeated multiple times, if multiple appropriate renamings exist
1668
+ current_col_spec = table_source.cols
1669
+ new_col_spec = [
1670
+ (col if col.name != current_col.name else target_col.name)
1671
+ for col in current_col_spec
1672
+ ]
1673
+ table_source = ValuesTableSource(
1674
+ table_source.rows,
1675
+ alias=table_source.table.identifier(),
1676
+ columns=new_col_spec,
1677
+ )
1678
+ return table_source
1679
+
1680
+ case FunctionTableSource():
1681
+ renamed_function = rename_columns_in_expression(
1682
+ table_source.function, available_renamings
1683
+ )
1684
+ return FunctionTableSource(renamed_function, table_source.target_table)
1685
+
1686
+ case _:
1687
+ raise TypeError("Unknown table source type: " + str(table_source))
1688
+
1689
+
1690
def rename_columns_in_clause(
    clause: Optional[ClauseType],
    available_renamings: dict[ColumnReference, ColumnReference],
) -> Optional[ClauseType]:
    """Replaces all references to specific columns in a clause by new columns.

    Parameters
    ----------
    clause : Optional[ClauseType]
        The clause to update. Can be ``None``, in which case no update is performed.
    available_renamings : dict[ColumnReference, ColumnReference]
        A dictionary mapping each of the old column values to the values that should be used instead.

    Returns
    -------
    Optional[ClauseType]
        The updated clause. Can be ``None``, if `clause` already was.

    Raises
    ------
    ValueError
        If the clause is of no known type. This indicates that this method is missing a handler for a specific clause type that
        was added later on. Also raised if a column of a *VALUES* CTE should be renamed to a column of a different table.
    """
    if not clause:
        return None

    if isinstance(clause, (Hint, Explain)):
        return clause
    if isinstance(clause, CommonTableExpression):
        renamed_ctes: list[WithQuery] = []

        for cte in clause.queries:
            if not isinstance(cte, ValuesWithQuery):
                new_query = rename_columns_in_query(cte.query, available_renamings)
                renamed_ctes.append(
                    WithQuery(
                        new_query, cte.target_table, materialized=cte.materialized
                    )
                )
                continue

            if not any(col.belongs_to(cte.target_table) for col in available_renamings):
                # Nothing to rename in this CTE, but it must still remain part of the resulting clause.
                renamed_ctes.append(cte)
                continue

            # VALUES CTEs carry their own column specification. We operate on the plain column names and apply every
            # matching renaming on top of the previous ones, so that multiple renamings for the same CTE accumulate.
            col_names = [col.name for col in cte.cols]
            for current_col, target_col in available_renamings.items():
                if not current_col.belongs_to(cte.target_table):
                    continue
                if current_col.table != target_col.table:
                    raise ValueError(
                        "Cannot rename columns in a VALUES table source to a different table"
                    )
                col_names = [
                    target_col.name if name == current_col.name else name
                    for name in col_names
                ]

            renamed_ctes.append(
                ValuesWithQuery(
                    cte.rows,
                    target_name=cte.target_table,
                    columns=col_names,
                    materialized=cte.materialized,
                )
            )

        return CommonTableExpression(renamed_ctes, recursive=clause.recursive)
    if isinstance(clause, Select):
        renamed_targets = [
            BaseProjection(
                rename_columns_in_expression(proj.expression, available_renamings),
                proj.target_name,
            )
            for proj in clause.targets
        ]
        return Select(renamed_targets, distinct=clause.distinct_specifier())
    elif isinstance(clause, ImplicitFromClause):
        return clause
    elif isinstance(clause, ExplicitFromClause):
        renamed_joins = [
            _rename_columns_in_table_source(join, available_renamings)
            for join in clause.items
        ]
        # NOTE(review): this passes a list of sources to ExplicitFromClause - confirm the constructor accepts this form.
        return ExplicitFromClause(renamed_joins)
    elif isinstance(clause, From):
        renamed_sources = [
            _rename_columns_in_table_source(table_source, available_renamings)
            for table_source in clause.items
        ]
        return From(renamed_sources)
    elif isinstance(clause, Where):
        return Where(rename_columns_in_predicate(clause.predicate, available_renamings))
    elif isinstance(clause, GroupBy):
        renamed_cols = [
            rename_columns_in_expression(col, available_renamings)
            for col in clause.group_columns
        ]
        return GroupBy(renamed_cols, clause.distinct)
    elif isinstance(clause, Having):
        return Having(
            rename_columns_in_predicate(clause.condition, available_renamings)
        )
    elif isinstance(clause, OrderBy):
        renamed_cols = [
            OrderByExpression(
                rename_columns_in_expression(col.column, available_renamings),
                col.ascending,
                col.nulls_first,
            )
            for col in clause.expressions
        ]
        return OrderBy(renamed_cols)
    elif isinstance(clause, Limit):
        return clause
    elif isinstance(clause, UnionClause):
        renamed_left = rename_columns_in_query(clause.left_query, available_renamings)
        renamed_right = rename_columns_in_query(clause.right_query, available_renamings)
        return UnionClause(renamed_left, renamed_right, union_all=clause.union_all)
    elif isinstance(clause, IntersectClause):
        renamed_left = rename_columns_in_query(clause.left_query, available_renamings)
        renamed_right = rename_columns_in_query(clause.right_query, available_renamings)
        return IntersectClause(renamed_left, renamed_right)
    elif isinstance(clause, ExceptClause):
        renamed_left = rename_columns_in_query(clause.left_query, available_renamings)
        renamed_right = rename_columns_in_query(clause.right_query, available_renamings)
        return ExceptClause(renamed_left, renamed_right)
    else:
        raise ValueError("Unknown clause: " + str(clause))
1822
+
1823
+
1824
+ class _TableReferenceRenamer(
1825
+ ClauseVisitor[BaseClause],
1826
+ PredicateVisitor[AbstractPredicate],
1827
+ SqlExpressionVisitor[SqlExpression],
1828
+ ):
1829
+ """Visitor to replace all references to specific tables to refer to new tables instead.
1830
+
1831
+ Parameters
1832
+ ----------
1833
+ renamings : Optional[dict[TableReference, TableReference]], optional
1834
+ Map from old table name to new name. If this is given, `source_table` and `target_table` are ignored.
1835
+ source_table : Optional[TableReference], optional
1836
+ Create a visitor for a single renaming operation. This parameter specifies the old table name that should be replaced.
1837
+ This parameter is ignored if `renamings` is given.
1838
+ target_table : Optional[TableReference], optional
1839
+ Create a visitor for a single renaming operation. This parameter specifies the new table name that should be used
1840
+ instead of `source_table`. This parameter is ignored if `renamings` is given.
1841
+ """
1842
+
1843
+ def __init__(
1844
+ self,
1845
+ renamings: Optional[dict[TableReference, TableReference]] = None,
1846
+ *,
1847
+ source_table: Optional[TableReference] = None,
1848
+ target_table: Optional[TableReference] = None,
1849
+ ) -> None:
1850
+ if renamings is not None:
1851
+ self._renamings = renamings
1852
+
1853
+ if source_table is None or target_table is None:
1854
+ raise ValueError(
1855
+ "Both source_table and target_table must be provided if renamings are not given explicitly"
1856
+ )
1857
+ self._renamings = {source_table: target_table}
1858
+
1859
+ def visit_hint_clause(self, clause: Hint) -> Hint:
1860
+ return clause
1861
+
1862
+ def visit_explain_clause(self, clause: Explain) -> Explain:
1863
+ return clause
1864
+
1865
+ def visit_cte_clause(self, clause: CommonTableExpression) -> CommonTableExpression:
1866
+ ctes: list[WithQuery] = []
1867
+
1868
+ for cte in clause.queries:
1869
+ nested_renamings = cte.query.accept_visitor(self)
1870
+ nested_query = build_query(nested_renamings.values())
1871
+ target_table = self._rename_table(cte.target_table)
1872
+
1873
+ ctes.append(
1874
+ WithQuery(nested_query, target_table, materialized=cte.materialized)
1875
+ )
1876
+
1877
+ return CommonTableExpression(ctes, recursive=clause.recursive)
1878
+
1879
+ def visit_select_clause(self, clause) -> Select:
1880
+ projections: list[BaseProjection] = []
1881
+
1882
+ for proj in clause:
1883
+ renamed_expression = proj.expression.accept_visitor(self)
1884
+ projections.append(BaseProjection(renamed_expression, proj.target_name))
1885
+
1886
+ return Select(projections, distinct=clause.distinct_specifier())
1887
+
1888
+ def visit_from_clause(self, clause: From) -> From:
1889
+ match clause:
1890
+ case ImplicitFromClause(tables):
1891
+ renamed_tables = [self._rename_table_source(src) for src in tables]
1892
+ return ImplicitFromClause.create_for(renamed_tables)
1893
+
1894
+ case ExplicitFromClause(join):
1895
+ renamed_join = self._rename_table_source(join)
1896
+ return ExplicitFromClause(renamed_join)
1897
+
1898
+ case From(items):
1899
+ renamed_items = [self._rename_table_source(item) for item in items]
1900
+ return From(renamed_items)
1901
+
1902
+ case _:
1903
+ raise ValueError("Unknown from clause type: " + str(clause))
1904
+
1905
+ def visit_where_clause(self, clause: Where) -> Where:
1906
+ renamed_predicate = clause.predicate.accept_visitor(self)
1907
+ return Where(renamed_predicate)
1908
+
1909
+ def visit_groupby_clause(self, clause: GroupBy) -> GroupBy:
1910
+ renamed_groupings = [
1911
+ grouping.accept_visitor(self) for grouping in clause.group_columns
1912
+ ]
1913
+ return GroupBy(renamed_groupings, clause.distinct)
1914
+
1915
+ def visit_having_clause(self, clause: Having) -> Having:
1916
+ renamed_predicate = clause.condition.accept_visitor(self)
1917
+ return Having(renamed_predicate)
1918
+
1919
+ def visit_orderby_clause(self, clause: OrderBy) -> OrderBy:
1920
+ renamed_orderings: list[OrderByExpression] = []
1921
+
1922
+ for ordering in clause:
1923
+ renamed_expression = ordering.column.accept_visitor(self)
1924
+ renamed_orderings.append(
1925
+ OrderByExpression(
1926
+ renamed_expression, ordering.ascending, ordering.nulls_first
1927
+ )
1928
+ )
1929
+
1930
+ return renamed_orderings
1931
+
1932
+ def visit_limit_clause(self, clause: Limit) -> Limit:
1933
+ return clause
1934
+
1935
+ def visit_union_clause(self, clause: UnionClause) -> UnionClause:
1936
+ renamed_lhs = clause.left_query.accept_visitor(self)
1937
+ renamed_rhs = clause.right_query.accept_visitor(self)
1938
+ return UnionClause(renamed_lhs, renamed_rhs, union_all=clause.union_all)
1939
+
1940
+ def visit_except_clause(self, clause: ExceptClause) -> ExceptClause:
1941
+ renamed_lhs = clause.left_query.accept_visitor(self)
1942
+ renamed_rhs = clause.right_query.accept_visitor(self)
1943
+ return ExceptClause(renamed_lhs, renamed_rhs)
1944
+
1945
+ def visit_intersect_clause(self, clause: IntersectClause) -> IntersectClause:
1946
+ renamed_lhs = clause.left_query.accept_visitor(self)
1947
+ renamed_rhs = clause.right_query.accept_visitor(self)
1948
+ return IntersectClause(renamed_lhs, renamed_rhs)
1949
+
1950
+ def visit_binary_predicate(self, predicate: BinaryPredicate) -> BinaryPredicate:
1951
+ renamed_lhs = predicate.first_argument.accept_visitor(self)
1952
+ renamed_rhs = predicate.second_argument.accept_visitor(self)
1953
+ return BinaryPredicate(predicate.operation, renamed_lhs, renamed_rhs)
1954
+
1955
+ def visit_between_predicate(self, predicate: BetweenPredicate) -> BetweenPredicate:
1956
+ renamed_col = predicate.column.accept_visitor(self)
1957
+ renamed_start = predicate.interval_start.accept_visitor(self)
1958
+ renamed_end = predicate.interval_end.accept_visitor(self)
1959
+ return BetweenPredicate(renamed_col, (renamed_start, renamed_end))
1960
+
1961
+ def visit_in_predicate(self, predicate: InPredicate) -> InPredicate:
1962
+ renamed_col = predicate.column.accept_visitor(self)
1963
+ renamed_vals = [val.accept_visitor(self) for val in predicate.values]
1964
+ return InPredicate(renamed_col, renamed_vals)
1965
+
1966
+ def visit_unary_predicate(self, predicate: UnaryPredicate) -> UnaryPredicate:
1967
+ renamed_col = predicate.column.accept_visitor(self)
1968
+ return UnaryPredicate(renamed_col, predicate.operation)
1969
+
1970
+ def visit_not_predicate(
1971
+ self, predicate: CompoundPredicate, child_predicate: AbstractPredicate
1972
+ ) -> CompoundPredicate:
1973
+ renamed_child = child_predicate.accept_visitor(self)
1974
+ return CompoundPredicate(CompoundOperator.Not, [renamed_child])
1975
+
1976
+ def visit_or_predicate(
1977
+ self,
1978
+ predicate: CompoundPredicate,
1979
+ child_predicates: Sequence[AbstractPredicate],
1980
+ ) -> CompoundPredicate:
1981
+ renamed_children = [child.accept_visitor(self) for child in child_predicates]
1982
+ return CompoundPredicate(CompoundOperator.Or, renamed_children)
1983
+
1984
+ def visit_and_predicate(
1985
+ self,
1986
+ predicate: CompoundPredicate,
1987
+ child_predicates: Sequence[AbstractPredicate],
1988
+ ) -> CompoundPredicate:
1989
+ renamed_children = [child.accept_visitor(self) for child in child_predicates]
1990
+ return CompoundPredicate(CompoundOperator.And, renamed_children)
1991
+
1992
+ def visit_static_value_expr(
1993
+ self, expr: StaticValueExpression
1994
+ ) -> StaticValueExpression:
1995
+ return expr
1996
+
1997
+ def visit_cast_expr(self, expr: CastExpression) -> CastExpression:
1998
+ renamed_child = expr.casted_expression.accept_visitor(self)
1999
+ return CastExpression(renamed_child, expr.target_type)
2000
+
2001
+ def visit_math_expr(self, expr: MathExpression) -> MathExpression:
2002
+ renamed_lhs = expr.first_arg.accept_visitor(self)
2003
+
2004
+ if isinstance(expr.second_arg, SqlExpression):
2005
+ renamed_rhs = expr.second_arg.accept_visitor(self)
2006
+ elif expr.second_arg is not None:
2007
+ renamed_rhs = [
2008
+ nested_expr.accept_visitor(self) for nested_expr in expr.second_arg
2009
+ ]
2010
+ else:
2011
+ renamed_rhs = None
2012
+
2013
+ return MathExpression(expr.operator, renamed_lhs, renamed_rhs)
2014
+
2015
+ def visit_column_expr(self, expr: ColumnExpression) -> ColumnExpression:
2016
+ return self._rename_column(expr)
2017
+
2018
+ def visit_function_expr(self, expr: FunctionExpression) -> FunctionExpression:
2019
+ renamed_args = [arg.accept_visitor(self) for arg in expr.arguments]
2020
+ renamed_filter = (
2021
+ expr.filter_where.accept_visitor(self) if expr.filter_where else None
2022
+ )
2023
+ return FunctionExpression(
2024
+ expr.function,
2025
+ renamed_args,
2026
+ distinct=expr.distinct,
2027
+ filter_where=renamed_filter,
2028
+ )
2029
+
2030
+ def visit_subquery_expr(self, expr: SubqueryExpression) -> SubqueryExpression:
2031
+ renamed_subquery = expr.query.accept_visitor(self)
2032
+ return SubqueryExpression(renamed_subquery)
2033
+
2034
+ def visit_star_expr(self, expr: StarExpression) -> StarExpression:
2035
+ renamed_table = self._rename_table(expr.from_table) if expr.from_table else None
2036
+ return StarExpression(from_table=renamed_table)
2037
+
2038
+ def visit_window_expr(self, expr: WindowExpression) -> WindowExpression:
2039
+ renamed_window_func = expr.window_function.accept_visitor(self)
2040
+ renamed_partition = [part.accept_visitor(self) for part in expr.partitioning]
2041
+ renamed_ordering = expr.ordering.accept_visitor(self) if expr.ordering else None
2042
+ renamed_filter = (
2043
+ expr.filter_condition.accept_visitor(self)
2044
+ if expr.filter_condition
2045
+ else None
2046
+ )
2047
+ return WindowExpression(
2048
+ renamed_window_func,
2049
+ partitioning=renamed_partition,
2050
+ ordering=renamed_ordering,
2051
+ filter_condition=renamed_filter,
2052
+ )
2053
+
2054
+ def visit_case_expr(self, expr: CaseExpression) -> CaseExpression:
2055
+ renamed_cases: list[tuple[AbstractPredicate, SqlExpression]] = []
2056
+
2057
+ for condition, value in expr.cases:
2058
+ renamed_condition = condition.accept_visitor(self)
2059
+ renamed_value = value.accept_visitor(self)
2060
+ renamed_cases.append((renamed_condition, renamed_value))
2061
+
2062
+ renamed_default = (
2063
+ expr.default_case.accept_visitor(self) if expr.default_case else None
2064
+ )
2065
+ return CaseExpression(renamed_cases, else_expr=renamed_default)
2066
+
2067
+ def visit_predicate_expr(self, expr: AbstractPredicate) -> AbstractPredicate:
2068
+ return expr.accept_visitor(self)
2069
+
2070
+ def _rename_table_source(self, source: TableSource) -> JoinTableSource:
2071
+ """Helper method to traverse and rename (the contents of) an arbitrary *FROM* item."""
2072
+ match source:
2073
+ case DirectTableSource(tab):
2074
+ renamed_table = self._rename_table(tab)
2075
+ return DirectTableSource(renamed_table)
2076
+
2077
+ case SubqueryTableSource(subquery, target_name, lateral):
2078
+ nested_renamings = subquery.accept_visitor(self)
2079
+ nested_query = build_query(nested_renamings.values())
2080
+ target_table = self._rename_table(target_name)
2081
+ return SubqueryTableSource(nested_query, target_table, lateral=lateral)
2082
+
2083
+ case JoinTableSource(lhs, rhs, join_condition, join_type):
2084
+ renamed_lhs = self._rename_table_source(lhs)
2085
+ renamed_rhs = self._rename_table_source(rhs)
2086
+ renamed_condition = (
2087
+ join_condition.accept_visitor(self) if join_condition else None
2088
+ )
2089
+ return JoinTableSource(
2090
+ renamed_lhs,
2091
+ renamed_rhs,
2092
+ join_condition=renamed_condition,
2093
+ join_type=join_type,
2094
+ )
2095
+
2096
+ case ValuesTableSource(rows, alias, columns):
2097
+ renamed_alias = self._rename_table(alias)
2098
+ return ValuesTableSource(rows, alias=renamed_alias, columns=columns)
2099
+
2100
+ case FunctionTableSource(function, alias):
2101
+ renamed_function = function.accept_visitor(self)
2102
+ renamed_alias = self._rename_table(alias) if alias else ""
2103
+ return FunctionTableSource(renamed_function, renamed_alias)
2104
+
2105
+ case _:
2106
+ raise ValueError("Unknown table source type: " + str(source))
2107
+
2108
+ @overload
2109
+ def _rename_table(self, table: TableReference) -> TableReference: ...
2110
+
2111
+ @overload
2112
+ def _rename_table(self, table: str) -> str: ...
2113
+
2114
+ def _rename_table(self, table: str | TableReference) -> str | TableReference:
2115
+ """Helper method to rename a specific table reference independent of its specific representation."""
2116
+ if isinstance(table, TableReference):
2117
+ return self._renamings.get(table, table)
2118
+
2119
+ return next(
2120
+ (
2121
+ target_tab.identifier()
2122
+ for orig_tab, target_tab in self._renamings.items()
2123
+ if orig_tab.identifier() == table
2124
+ ),
2125
+ table,
2126
+ )
2127
+
2128
+ def _rename_column(self, column: ColumnReference) -> ColumnReference:
2129
+ """Helper method to rename a specific column reference."""
2130
+ if not column.is_bound():
2131
+ return column
2132
+
2133
+ target_table = self._renamings.get(column.table)
2134
+ return column.bind_to(target_table) if target_table else column
2135
+
2136
+
2137
def rename_table(
    source_query: SelectQueryType,
    from_table: TableReference,
    target_table: TableReference,
    *,
    prefix_column_names: bool = False,
) -> SelectQueryType:
    """Changes all references to a specific table to refer to another table instead.

    The update happens in two passes: first all columns of `from_table` are re-bound to `target_table` (optionally
    receiving a name prefix), afterwards all remaining table references are replaced.

    Parameters
    ----------
    source_query : SelectQueryType
        The query that should be updated
    from_table : TableReference
        The table that should be replaced
    target_table : TableReference
        The table that should be used instead
    prefix_column_names : bool, optional
        Whether a prefix should be added to column names. If this is ``True``, column references will be changed in two ways:

        1. if they belonged to the `from_table`, they will now belong to the `target_table` after the renaming
        2. The column names will be changed to include the identifier of the `from_table` as a prefix.

    Returns
    -------
    SelectQueryType
        The updated query
    """
    # The table renamer only re-binds columns, it cannot change their names. Therefore, the column renamings (including
    # optional prefixes) are computed and applied explicitly before the table references are updated.
    column_renamings: dict[ColumnReference, ColumnReference] = {
        col: ColumnReference(
            f"{col.table.identifier()}_{col.name}" if prefix_column_names else col.name,
            target_table,
        )
        for col in source_query.columns()
        if col.table == from_table
    }
    query_with_renamed_columns = rename_columns_in_query(source_query, column_renamings)

    renamer = _TableReferenceRenamer(source_table=from_table, target_table=target_table)
    clauses = query_with_renamed_columns.accept_visitor(renamer)
    return build_query(clauses.values())