PostBOUND 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. postbound/__init__.py +211 -0
  2. postbound/_base.py +6 -0
  3. postbound/_bench.py +1012 -0
  4. postbound/_core.py +1153 -0
  5. postbound/_hints.py +1373 -0
  6. postbound/_jointree.py +1079 -0
  7. postbound/_pipelines.py +1121 -0
  8. postbound/_qep.py +1986 -0
  9. postbound/_stages.py +876 -0
  10. postbound/_validation.py +734 -0
  11. postbound/db/__init__.py +72 -0
  12. postbound/db/_db.py +2348 -0
  13. postbound/db/_duckdb.py +785 -0
  14. postbound/db/mysql.py +1195 -0
  15. postbound/db/postgres.py +4216 -0
  16. postbound/experiments/__init__.py +12 -0
  17. postbound/experiments/analysis.py +674 -0
  18. postbound/experiments/benchmarking.py +54 -0
  19. postbound/experiments/ceb.py +877 -0
  20. postbound/experiments/interactive.py +105 -0
  21. postbound/experiments/querygen.py +334 -0
  22. postbound/experiments/workloads.py +980 -0
  23. postbound/optimizer/__init__.py +92 -0
  24. postbound/optimizer/__init__.pyi +73 -0
  25. postbound/optimizer/_cardinalities.py +369 -0
  26. postbound/optimizer/_joingraph.py +1150 -0
  27. postbound/optimizer/dynprog.py +1825 -0
  28. postbound/optimizer/enumeration.py +432 -0
  29. postbound/optimizer/native.py +539 -0
  30. postbound/optimizer/noopt.py +54 -0
  31. postbound/optimizer/presets.py +147 -0
  32. postbound/optimizer/randomized.py +650 -0
  33. postbound/optimizer/tonic.py +1479 -0
  34. postbound/optimizer/ues.py +1607 -0
  35. postbound/qal/__init__.py +343 -0
  36. postbound/qal/_qal.py +9678 -0
  37. postbound/qal/formatter.py +1089 -0
  38. postbound/qal/parser.py +2344 -0
  39. postbound/qal/relalg.py +4257 -0
  40. postbound/qal/transform.py +2184 -0
  41. postbound/shortcuts.py +70 -0
  42. postbound/util/__init__.py +46 -0
  43. postbound/util/_errors.py +33 -0
  44. postbound/util/collections.py +490 -0
  45. postbound/util/dataframe.py +71 -0
  46. postbound/util/dicts.py +330 -0
  47. postbound/util/jsonize.py +68 -0
  48. postbound/util/logging.py +106 -0
  49. postbound/util/misc.py +168 -0
  50. postbound/util/networkx.py +401 -0
  51. postbound/util/numbers.py +438 -0
  52. postbound/util/proc.py +107 -0
  53. postbound/util/stats.py +37 -0
  54. postbound/util/system.py +48 -0
  55. postbound/util/typing.py +35 -0
  56. postbound/vis/__init__.py +5 -0
  57. postbound/vis/fdl.py +69 -0
  58. postbound/vis/graphs.py +48 -0
  59. postbound/vis/optimizer.py +538 -0
  60. postbound/vis/plots.py +84 -0
  61. postbound/vis/tonic.py +70 -0
  62. postbound/vis/trees.py +105 -0
  63. postbound-0.19.0.dist-info/METADATA +355 -0
  64. postbound-0.19.0.dist-info/RECORD +67 -0
  65. postbound-0.19.0.dist-info/WHEEL +5 -0
  66. postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
  67. postbound-0.19.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,4257 @@
1
+ """relalg provides fundamental building blocks of relational algebra and a converter from SQL to algebra.
2
+
3
+ The central component of our algebra implementation is the `RelNode` class. All relational operators inherit from this abstract
4
+ class. Following the design of the expression and predicate models, all algebraic trees are immutable data structures. Once a
5
+ tree has been generated, it can no longer be modified.
6
+
7
+ One important aspect of our relational algebra design is how to model arbitrary expressions and projections, mappings, etc. on
8
+ these expressions. Some systems introduce temporary variables for mapping targets, e.g. ``arg0 <- R.a + 42`` and then base all
9
+ further accesses on these temporary variables. In this school of thought, an algebra tree for the query
10
+ ``SELECT R.a + 42 FROM R`` would look like this:
11
+
12
+ .. math:: \\pi_{arg_0}(\\chi_{arg_0 \\leftarrow R.a + 42}(R))
13
+
14
+ This representation is especially useful for a code-generation or physical optimization scenario because it enables a
15
+ straightforward creation of additional (temporary) columns. At the same time, it makes the translation of SQL queries to
16
+ relational algebra more challenging, since re-writes have to be applied during parsing. Since we are not concerned with
17
+ code-generation in our algebra representation and focus more on structural properties, we take a different approach: all
18
+ expressions (as defined in the `expressions` module) are contained as-is in the operators. However, we make sure that necessary
19
+ pre-processing actions are included as required. For example, if a complex expression is included in a predicate or a
20
+ projection, we generate the appropriate mapping operation beforehand and use it as an input for the consuming operator.
21
+
22
+ In addition to the conventional operators of relational algebra, we introduce a couple of additional operators that either
23
+ mimic features from SQL, or that make working with the algebra much easier from a technical point-of-view. The first category
24
+ includes operators such as `Sort` or `DuplicateElimination` and `Limit`, whereas the second category includes the
25
+ `SubqueryScan`.
26
+
27
+ Notice that while most algebraic expressions correspond to tree structures, there might be cases where a directed, acyclic
28
+ graph is generated. This is especially the case when a base relation is used as part of subqueries. Nevertheless, there will
29
+ always be only one root (sink) node.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import abc
35
+ import collections
36
+ import copy
37
+ import dataclasses
38
+ import enum
39
+ import functools
40
+ import operator
41
+ import typing
42
+ from collections.abc import Generator, Iterable, Sequence
43
+ from typing import Optional
44
+
45
+ from .. import util
46
+ from . import transform
47
+ from ._qal import (
48
+ AbstractPredicate,
49
+ BetweenPredicate,
50
+ BinaryPredicate,
51
+ CaseExpression,
52
+ CastExpression,
53
+ ColumnExpression,
54
+ ColumnReference,
55
+ CompoundOperator,
56
+ CompoundPredicate,
57
+ DirectTableSource,
58
+ ExpressionCollector,
59
+ FunctionExpression,
60
+ ImplicitSqlQuery,
61
+ InPredicate,
62
+ JoinTableSource,
63
+ LogicalOperator,
64
+ MathExpression,
65
+ OrderByExpression,
66
+ PredicateVisitor,
67
+ SelectStatement,
68
+ SetOperator,
69
+ SetQuery,
70
+ SqlExpression,
71
+ SqlExpressionVisitor,
72
+ SqlQuery,
73
+ StarExpression,
74
+ StaticValueExpression,
75
+ SubqueryExpression,
76
+ SubqueryTableSource,
77
+ TableReference,
78
+ TableSource,
79
+ UnaryPredicate,
80
+ WindowExpression,
81
+ )
82
+
83
+ # TODO: the creation and mutation of different relnodes should be handled by a dedicated factory class. This solves all issues
84
+ # with mutability/immutability and automated linking to parent nodes.
85
+
86
+
87
class RelNode(abc.ABC):
    """Models a fundamental operator in relational algebra. All specific operators like selection or theta join inherit from it.

    On construction, a node caches its type name and hash value and registers itself with its children: a child without
    a parent receives the new node as parent, a child that already has a parent records the new node as an additional
    consumer (see `sideways_pass`).

    Parameters
    ----------
    parent_node : Optional[RelNode]
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    See Also
    --------
    parse_relalg
    """

    def __init__(self, parent_node: Optional[RelNode]) -> None:
        self._parent = parent_node
        self._sideways_pass: set[RelNode] = set()
        self._node_type = type(self).__name__
        # Subclasses must have assigned all attributes that `_recalc_hash_val()` and `children()` read *before* calling
        # this initializer, because both are evaluated right here.
        self._hash_val = hash((self._node_type, self._recalc_hash_val()))
        self._maintain_child_links()

    @property
    def node_type(self) -> str:
        """Get the current operator as a string.

        Returns
        -------
        str
            The operator name. This is the name of the concrete node class.
        """
        return self._node_type

    @property
    def parent_node(self) -> Optional[RelNode]:
        """Get the parent node of the current operator, if it exists.

        Returns
        -------
        Optional[RelNode]
            The parent is the operator that receives the output relation of the current operator. If the current operator is
            the root and (currently) does not have a parent, *None* is returned.
        """
        return self._parent

    @property
    def sideways_pass(self) -> frozenset[RelNode]:
        """Get all nodes that receive the output of the current operator in addition to the parent node.

        Returns
        -------
        frozenset[RelNode]
            The sideways pass nodes. This is an immutable snapshot of the internal set.
        """
        return frozenset(self._sideways_pass)

    def root(self) -> RelNode:
        """Traverses the algebra tree upwards until the root node is found.

        Returns
        -------
        RelNode
            The root node of the algebra expression. Can be the current node if it does not have a parent.
        """
        if self._parent is None:
            return self
        return self._parent.root()

    def leaf(self) -> Optional[RelNode]:
        """Traverses the algebra tree downwards until a unique leaf node is found.

        Returns
        -------
        Optional[RelNode]
            The leaf node. If multiple leaf nodes exist (e.g. for join nodes), *None* is returned. A node without children
            is its own leaf.
        """
        children = self.children()
        if not children:
            return self
        if len(children) > 1:
            return None
        # Keep descending: the direct child is only the answer if it happens to be a leaf itself.
        return children[0].leaf()

    @abc.abstractmethod
    def children(self) -> Sequence[RelNode]:
        """Provides all input nodes of the current operator.

        Returns
        -------
        Sequence[RelNode]
            The input nodes. For leaf nodes such as table scans, the sequence will be usually empty (except for subquery
            aliases), otherwise the children are provided from left to right.
        """
        raise NotImplementedError

    def tables(self, *, ignore_subqueries: bool = False) -> frozenset[TableReference]:
        """Provides all relations that are contained in the current node.

        Consider the following algebraic expression: *π(⋈(σ(R), S))*. This expression contains two relations: *R* and *S*.

        Parameters
        ----------
        ignore_subqueries : bool, optional
            Whether relations that are only referenced in subquery subtrees should be excluded. Off by default.

        Returns
        -------
        frozenset[TableReference]
            The tables. The default implementation merges the tables of all child nodes.
        """
        return frozenset(
            util.set_union(
                child.tables(ignore_subqueries=ignore_subqueries)
                for child in self.children()
            )
        )

    def provided_expressions(self) -> frozenset[SqlExpression]:
        """Collects all expressions that are available to parent nodes.

        These expressions will contain all expressions that are provided by child nodes as well as all expressions that are
        calculated by the current node itself. The default implementation only merges the expressions of the child nodes.

        Returns
        -------
        frozenset[SqlExpression]
            The expressions
        """
        return util.set_union(child.provided_expressions() for child in self.children())

    @abc.abstractmethod
    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Enables processing of the current algebraic expression by an expression visitor.

        Parameters
        ----------
        visitor : RelNodeVisitor[VisitorResult]
            The visitor

        Returns
        -------
        VisitorResult
            Whatever the visitor produces for the current node
        """
        raise NotImplementedError

    def dfs_walk(self) -> Generator[RelNode, None, None]:
        """Performs a depth-first search on the algebraic expression.

        This produces the subtree induced by the current node. The current node is also included in the output and is
        yielded before any of its children (pre-order).

        Yields
        ------
        Generator[RelNode, None, None]
            All nodes of the subtree induced by the current node.
        """
        yield self
        for child in self.children():
            yield from child.dfs_walk()

    def mutate(self, *, as_root: bool = False, **kwargs) -> RelNode:
        """Creates a new instance of the current operator with modified attributes.

        The specific parameters depend on the concrete operator type. Calling `mutate()` on *any* node generates a copy of the
        *entire* tree, so make sure to use this operation sparingly.

        See *Notes* for important remarks regarding the mutate implementation for custom node types.

        Parameters
        ----------
        as_root : bool, optional
            Whether the current node should become the new root node of the tree. Defaults to *False*, which leaves the node
            at its current position in the tree.
        **kwargs
            The attributes to modify. The specific attributes depend on the concrete operator type.

        Returns
        -------
        RelNode
            The modified node

        Notes
        -----
        Concrete node types should implement their own `mutate()` function which accepts parameters specific to the node type
        (in addition to the required `as_root` parameter). Other than the function prototype, the methods can share the same
        default implementation:

        .. code-block:: python

            params = {param: val for param, val in locals().items() if param != "self" and not param.startswith("__")}
            return super().mutate(**params)

        In order for the method to work properly, all field-specific parameters *must* use the common property/attribute
        conventions. For example, to update a property `input_node` of the operator, its value must be stored in a "private"
        attribute `_input_node`. At the same time, the `mutate()` method must accept a parameter `input_node` (without the
        leading underscore). The method will then automatically update the internal attribute with the new value.

        See the implementation of `Selection` for an example.
        """
        # The update always starts at the root, because the entire tree is copied (see above).
        update_service = _RelNodeUpdateManager(self.root(), initiator=self)
        relalg_copy = update_service.make_relalg_copy(as_root=as_root, **kwargs)
        return relalg_copy.updated_initiator

    def clone(self) -> RelNode:
        """Obtains a 1:1 copy of the current node.

        The copy is shallow: child and parent references point to the same objects as in the original node. Only the set
        of sideways pass nodes is duplicated, so that later changes to it do not leak between the original and the copy.

        Returns
        -------
        RelNode
            The cloned node
        """
        copied = copy.copy(self)
        copied._sideways_pass = set(copied.sideways_pass)
        return copied

    def inspect(self, *, _indentation: int = 0) -> str:
        """Provides a nice hierarchical string representation of the algebraic expression.

        The representation typically spans multiple lines and uses indentation to separate parent nodes from their
        children.

        Parameters
        ----------
        _indentation : int, optional
            Internal parameter to the `inspect` function. Should not be modified by the user. Denotes how deeply
            recursed we are in the plan tree. This enables the correct calculation of the current indentation level.
            Defaults to 0 for the root node.

        Returns
        -------
        str
            A string representation of the algebraic expression
        """
        padding = " " * _indentation
        # The node on which inspect() was called has no padding and is printed without an arrow.
        prefix = f"{padding}<- " if padding else ""
        inspections = [prefix + str(self)]
        for child in self.children():
            inspections.append(child.inspect(_indentation=_indentation + 2))
        return "\n".join(inspections)

    def _maintain_child_links(self, *, recursive: bool = False) -> None:
        """Ensures that all child nodes of the current node *A* have *A* set as their parent.

        Children that already have a parent keep it. For those, *A* is recorded as a sideways pass node instead.

        Parameters
        ----------
        recursive : bool, optional
            Whether the child links should be maintained recursively. Defaults to *False*.
        """
        for child in self.children():
            if child._parent is None:
                child._parent = self
                continue
            child._sideways_pass.add(self)

        if not recursive:
            return
        for child in self.children():
            child._maintain_child_links(recursive=True)

    def _update_child_nodes(self, children: Sequence[RelNode]) -> None:
        """Updates the child nodes of the current operator.

        This method uses two heuristics: children of nodes with a single input should be called `input_node` (with its internal
        attribute `_input_node`), whereas nodes with two inputs should be called `left_input` and `right_input` (with internal
        attributes `_left_input` and `_right_input`, respectively).

        If none of these fields are found, a `NotImplementedError` is raised.

        Parameters
        ----------
        children : Sequence[RelNode]
            The new children

        Raises
        ------
        ValueError
            If the number of supplied children does not match the expected number of child nodes.
        NotImplementedError
            If the current node does not have the required attributes to update its child nodes.
        """
        self._assert_correct_update_child_count(children)
        # vars() hands out the live instance dictionary, so the assignments below modify the node itself.
        attrs = vars(self)
        if "_input_node" in attrs and len(children) == 1:
            attrs["_input_node"] = children[0]
        elif "_left_input" in attrs and "_right_input" in attrs and len(children) == 2:
            attrs["_left_input"] = children[0]
            attrs["_right_input"] = children[1]
        else:
            raise NotImplementedError(
                f"Cannot use the default implementation of _update_child_nodes for node '{self}'."
            )

    def _clear_parent_links(self, *, recursive: bool = False) -> None:
        """Removes all references to parent nodes from the current operator.

        This resets both the parent node and the sideways pass nodes.

        Parameters
        ----------
        recursive : bool, optional
            Whether the parent links should be cleared recursively. Defaults to *False*.
        """
        self._parent = None
        self._sideways_pass.clear()

        if not recursive:
            return
        for child in self.children():
            child._clear_parent_links(recursive=True)

    def _rebuild_linkage(self) -> None:
        """Ensures that the subtree induced by the current node is correctly linked."""
        # Clear first: _maintain_child_links() only assigns a parent to children that do not have one yet.
        self._clear_parent_links(recursive=True)
        self._maintain_child_links(recursive=True)

    def _rehash(self) -> None:
        """Re-calculates the hash value of the current node and all its children."""
        self._hash_val = hash((self._node_type, self._recalc_hash_val()))
        for child in self.children():
            child._rehash()

    def _assert_correct_update_child_count(self, children: Sequence[RelNode]) -> None:
        """Ensures that the correct number of child nodes is supplied for updating the current node.

        Parameters
        ----------
        children : Sequence[RelNode]
            The new children

        Raises
        ------
        ValueError
            If the number of supplied children does not match the expected number of child nodes.
        """
        n_current_children = len(self.children())
        n_new_children = len(children)
        if n_current_children != n_new_children:
            raise ValueError(
                f"Cannot update a node containing {n_current_children} child nodes "
                f"with {n_new_children} children."
            )

    @abc.abstractmethod
    def _recalc_hash_val(self) -> int:
        """Calculates the hash value of the current node.

        This method only needs to consider attributes that are unique to the node. The final hash value will incorporate
        additional information that is shared between nodes, such as the node type.

        Returns
        -------
        int
            The current hash value
        """
        raise NotImplementedError

    def __hash__(self) -> int:
        # The hash is cached. It is computed in __init__() and refreshed by _rehash().
        return self._hash_val

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool:
        raise NotImplementedError

    def __repr__(self) -> str:
        child_reprs = ", ".join(repr(child) for child in self.children())
        return f"{self.node_type}({child_reprs})"

    @abc.abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError
446
+
447
+
448
class Selection(RelNode):
    """A selection keeps only those tuples of its input relation that satisfy a predicate.

    Parameters
    ----------
    input_node : RelNode
        The operator that provides the tuples to filter
    predicate : AbstractPredicate
        The condition that each output tuple has to satisfy. Arbitrary predicates are supported.
    parent_node : Optional[RelNode], optional
        The operator that consumes the output of the selection. Use *None* (the default) if the selection is
        (currently) the root of the tree.

    Notes
    -----
    A selection is defined as

    .. math:: \\sigma_\\theta(R) := \\{ r \\in R | \\theta(r) \\}
    """

    def __init__(
        self,
        input_node: RelNode,
        predicate: AbstractPredicate,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # Both attributes have to exist before the base initializer runs, since it hashes the node and links its children.
        self._predicate = predicate
        self._input_node = input_node
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the operator whose output is filtered.

        Returns
        -------
        RelNode
            The input relation
        """
        return self._input_node

    @property
    def predicate(self) -> AbstractPredicate:
        """Get the condition that is checked for each input tuple.

        Returns
        -------
        AbstractPredicate
            The filter predicate
        """
        return self._predicate

    def children(self) -> Sequence[RelNode]:
        return [self._input_node]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        return visitor.visit_selection(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        predicate: Optional[AbstractPredicate] = None,
        as_root: bool = False,
    ) -> Selection:
        """Creates a new selection with modified attributes.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            Replacement for the input node. *None* keeps the current input.
        predicate : Optional[AbstractPredicate], optional
            Replacement for the filter predicate. *None* keeps the current predicate.
        as_root : bool, optional
            Whether the selection should become the new root node of the tree. *False* by default.

        Returns
        -------
        Selection
            The modified selection node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All arguments are forwarded, including those that are still None.
        return super().mutate(
            input_node=input_node,
            predicate=predicate,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        own_attributes = (self._input_node, self._predicate)
        return hash(own_attributes)

    # Defining __eq__ would otherwise disable hashing for this class, so the cached hash is re-exported explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._input_node == other._input_node
            and self._predicate == other._predicate
        )

    def __str__(self) -> str:
        return f"σ ({self._predicate})"
556
+
557
+
558
class CrossProduct(RelNode):
    """A cross product pairs each tuple of one relation with each tuple of another relation (cartesian product).

    Parameters
    ----------
    left_input : RelNode
        The operator that provides the first set of tuples
    right_input : RelNode
        The operator that provides the second set of tuples
    parent_node : Optional[RelNode], optional
        The operator that consumes the output of the cross product. Use *None* (the default) if the cross product is
        (currently) the root of the tree.

    Notes
    -----
    A cross product is defined as

    .. math:: R \\times S := \\{ r \\circ s | r \\in R, s \\in S \\}
    """

    def __init__(
        self,
        left_input: RelNode,
        right_input: RelNode,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # Both inputs have to exist before the base initializer runs, since it hashes the node and links its children.
        self._right_input = right_input
        self._left_input = left_input
        super().__init__(parent_node)

    @property
    def left_input(self) -> RelNode:
        """Get the operator that supplies the first set of tuples.

        Returns
        -------
        RelNode
            The left relation
        """
        return self._left_input

    @property
    def right_input(self) -> RelNode:
        """Get the operator that supplies the second set of tuples.

        Returns
        -------
        RelNode
            The right relation
        """
        return self._right_input

    def children(self) -> Sequence[RelNode]:
        return [self._left_input, self._right_input]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        return visitor.visit_cross_product(self)

    def mutate(
        self,
        *,
        left_input: Optional[RelNode] = None,
        right_input: Optional[RelNode] = None,
        as_root: bool = False,
    ) -> CrossProduct:
        """Creates a new cross product with modified attributes.

        Parameters
        ----------
        left_input : Optional[RelNode], optional
            Replacement for the left child node. *None* keeps the current left input.
        right_input : Optional[RelNode], optional
            Replacement for the right child node. *None* keeps the current right input.
        as_root : bool, optional
            Whether the cross product should become the new root node of the tree. *False* by default.

        Returns
        -------
        CrossProduct
            The modified cross product node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All arguments are forwarded, including those that are still None.
        return super().mutate(
            left_input=left_input,
            right_input=right_input,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        inputs = (self._left_input, self._right_input)
        return hash(inputs)

    # Defining __eq__ would otherwise disable hashing for this class, so the cached hash is re-exported explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._left_input == other._left_input
            and self._right_input == other._right_input
        )

    def __str__(self) -> str:
        return "⨯"
666
+
667
+
668
class Union(RelNode):
    """A union combines the tuple sets of two relations into a single output relation.

    In order for a union to work, both relations must have the same structure.

    Parameters
    ----------
    left_input : RelNode
        The first relation
    right_input : RelNode
        The second relation
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Notes
    -----
    The union is defined as

    .. math:: R \\cup S := \\{ t | t \\in R \\lor t \\in S \\}
    """

    def __init__(
        self,
        left_input: RelNode,
        right_input: RelNode,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # Both inputs have to exist before the base initializer runs, since it hashes the node and links its children.
        self._left_input = left_input
        self._right_input = right_input
        super().__init__(parent_node)

    @property
    def left_input(self) -> RelNode:
        """Get the operator providing the first relation's tuples.

        Returns
        -------
        RelNode
            A relation
        """
        return self._left_input

    @property
    def right_input(self) -> RelNode:
        """Get the operator providing the second relation's tuples.

        Returns
        -------
        RelNode
            A relation
        """
        return self._right_input

    def children(self) -> Sequence[RelNode]:
        return [self._left_input, self._right_input]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        # The visitor has to receive the node that is being visited, not itself.
        return visitor.visit_union(self)

    def mutate(
        self,
        *,
        left_input: Optional[RelNode] = None,
        right_input: Optional[RelNode] = None,
        as_root: bool = False,
    ) -> Union:
        """Creates a new union with modified attributes.

        Parameters
        ----------
        left_input : Optional[RelNode], optional
            The new left child node to use. If *None*, the current left input node is re-used.
        right_input : Optional[RelNode], optional
            The new right child node to use. If *None*, the current right input node is re-used.
        as_root : bool, optional
            Whether the union should become the new root node of the tree. *False* by default.

        Returns
        -------
        Union
            The modified union node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All arguments are forwarded, including those that are still None.
        return super().mutate(
            left_input=left_input,
            right_input=right_input,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        return hash((self._left_input, self._right_input))

    # Defining __eq__ would otherwise disable hashing for this class, so the cached hash is re-exported explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self._left_input == other._left_input
            and self._right_input == other._right_input
        )

    def __str__(self) -> str:
        return "∪"
777
+
778
+
779
+ class Intersection(RelNode):
780
+ """An intersection provides all tuples that are contained in both of its input operators.
781
+
782
+ In order for an intersection to work, both relations must have the same structure.
783
+
784
+ Parameters
785
+ ----------
786
+ left_input : RelNode
787
+ The first relation.
788
+ right_input : RelNode
789
+ The second relation.
790
+ parent_node : Optional[RelNode], optional
791
+ The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
792
+ current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.
793
+
794
+ Notes
795
+ -----
796
+ The difference is defined as
797
+
798
+ .. math:: R \\cap S := \\{ t | t \\in R \\land t \\in S \\}
799
+ """
800
+
801
+ def __init__(
802
+ self,
803
+ left_input: RelNode,
804
+ right_input: RelNode,
805
+ *,
806
+ parent_node: Optional[RelNode] = None,
807
+ ) -> None:
808
+ self._left_input = left_input
809
+ self._right_input = right_input
810
+ super().__init__(parent_node)
811
+
812
+ @property
813
+ def left_input(self) -> RelNode:
814
+ """Get the operator providing the first relation's tuples.
815
+
816
+ Returns
817
+ -------
818
+ RelNode
819
+ A relation
820
+ """
821
+ return self._left_input
822
+
823
+ @property
824
+ def right_input(self) -> RelNode:
825
+ """Get the operator providing the second relation's tuples.
826
+
827
+ Returns
828
+ -------
829
+ RelNode
830
+ A relation
831
+ """
832
+ return self._right_input
833
+
834
+ def children(self) -> Sequence[RelNode]:
835
+ return [self._left_input, self._right_input]
836
+
837
+ def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
838
+ return visitor.visit_intersection(visitor)
839
+
840
def mutate(
    self,
    *,
    left_input: Optional[RelNode] = None,
    right_input: Optional[RelNode] = None,
    as_root: bool = False,
) -> Intersection:
    """Creates a new intersection with modified attributes.

    Parameters
    ----------
    left_input : Optional[RelNode], optional
        Replacement for the left child node. If *None*, the current left input node is re-used.
    right_input : Optional[RelNode], optional
        Replacement for the right child node. If *None*, the current right input node is re-used.
    as_root : bool, optional
        Whether the intersection should become the new root node of the tree. This overwrites any value passed to
        `parent` (note that this method itself does not accept a `parent` argument).

    Returns
    -------
    Intersection
        The modified intersection node

    See Also
    --------
    RelNode.mutate : for safety considerations and calling conventions
    """
    # All keyword arguments of this method are handed to the generic implementation as-is.
    return super().mutate(
        left_input=left_input,
        right_input=right_input,
        as_root=as_root,
    )
874
+
875
def _recalc_hash_val(self) -> int:
    """Compute the hash value of this node from its two input operators."""
    hash_components = (self._left_input, self._right_input)
    return hash(hash_components)
877
+
878
+ __hash__ = RelNode.__hash__
879
+
880
def __eq__(self, other: object) -> bool:
    """Two intersections are equal if they are of compatible type and both of their inputs are equal."""
    if not isinstance(other, type(self)):
        return False
    return (
        self._left_input == other._left_input
        and self._right_input == other._right_input
    )
886
+
887
def __str__(self) -> str:
    """Render the node as the set intersection symbol."""
    operator_symbol = "∩"
    return operator_symbol
889
+
890
+
891
class Difference(RelNode):
    """A difference returns all tuples from one relation that are not present in another relation.

    In order for the difference to work, both input relations must share the same structure.

    Parameters
    ----------
    left_input : RelNode
        The first relation. This is the relation to remove tuples from.
    right_input : RelNode
        The second relation. This is the relation containing the tuples that should be removed from the `left_input`.
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Notes
    -----
    The difference is defined as

    .. math:: R \\setminus S := \\{ r \\in R | r \\notin S \\}
    """

    def __init__(
        self,
        left_input: RelNode,
        right_input: RelNode,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        self._left_input = left_input
        self._right_input = right_input
        super().__init__(parent_node)

    @property
    def left_input(self) -> RelNode:
        """Get the operator providing the relation to remove tuples from.

        Returns
        -------
        RelNode
            A relation
        """
        return self._left_input

    @property
    def right_input(self) -> RelNode:
        """Get the operator providing the tuples to remove.

        Returns
        -------
        RelNode
            A relation
        """
        return self._right_input

    def children(self) -> Sequence[RelNode]:
        """Provide the two input operators, left input first."""
        return [self._left_input, self._right_input]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the difference-specific handler of the visitor.

        Parameters
        ----------
        visitor : RelNodeVisitor[VisitorResult]
            The visitor that should process this node.

        Returns
        -------
        VisitorResult
            Whatever the visitor's `visit_difference` method produces for this node.
        """
        # Double dispatch: the node itself (and not the visitor) must be handed to the visit method.
        return visitor.visit_difference(self)

    def mutate(
        self,
        *,
        left_input: Optional[RelNode] = None,
        right_input: Optional[RelNode] = None,
        as_root: bool = False,
    ) -> Difference:
        """Creates a new difference with modified attributes.

        Parameters
        ----------
        left_input : Optional[RelNode], optional
            The new left child node to use. If *None*, the current left input node is re-used.
        right_input : Optional[RelNode], optional
            The new right child node to use. If *None*, the current right input node is re-used.
        as_root : bool, optional
            Whether the difference should become the new root node of the tree. This overwrites any value passed to
            `parent` (note that this method itself does not accept a `parent` argument).

        Returns
        -------
        Difference
            The modified difference node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # Forward exactly the keyword arguments of this method to the generic implementation.
        return super().mutate(
            left_input=left_input,
            right_input=right_input,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        """Compute the hash value of this node from its two input operators."""
        return hash((self._left_input, self._right_input))

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self._left_input == other._left_input
            and self._right_input == other._right_input
        )

    def __str__(self) -> str:
        return "\\"
1001
+
1002
+
1003
class Relation(RelNode):
    """A relation provides the tuples ("rows") contained in a table.

    Each relation can correspond to a physical table contained in some relational schema, or it can represent the result of a
    subquery operation.

    Parameters
    ----------
    table : TableReference
        The table that is represented by this relation.
    provided_columns : Iterable[ColumnReference | ColumnExpression]
        The columns that are contained in the table.
    subquery_input : Optional[RelNode], optional
        For subquery relations, this is the algebraic expression that computes the results of the subquery. Relations that
        correspond to base tables do not have this attribute set.
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.
    """

    def __init__(
        self,
        table: TableReference,
        provided_columns: Iterable[ColumnReference | ColumnExpression],
        *,
        subquery_input: Optional[RelNode] = None,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        self._table = table
        # Plain column references are wrapped, such that all provided columns are stored as column expressions.
        self._provided_cols = frozenset(
            col if isinstance(col, ColumnExpression) else ColumnExpression(col)
            for col in provided_columns
        )

        self._subquery_input = subquery_input
        if self._subquery_input is not None:
            # We need to set the parent node explicitly here in order to prevent infinite recursion
            self._subquery_input._parent = self

        super().__init__(parent_node)

    @property
    def table(self) -> TableReference:
        """Get the table that is represented by this relation.

        Returns
        -------
        TableReference
            A table. Usually this will correspond to an actual physical database table, but for subqueries this might also
            be a virtual table.
        """
        return self._table

    @property
    def subquery_input(self) -> Optional[RelNode]:
        """Get the root node of the subquery that produces the input tuples for this relation.

        Returns
        -------
        Optional[RelNode]
            The root node if it exists, or *None* for actual base table relations.
        """
        return self._subquery_input

    def children(self) -> Sequence[RelNode]:
        """Provide the subquery input as the only child, or no children at all if there is no subquery input."""
        return [self._subquery_input] if self._subquery_input else []

    def tables(self, *, ignore_subqueries: bool = False) -> frozenset[TableReference]:
        """Provide the tables of this relation.

        Parameters
        ----------
        ignore_subqueries : bool, optional
            If set, only the table of this relation itself is returned. Otherwise, the tables provided by the parent class
            implementation are included as well.

        Returns
        -------
        frozenset[TableReference]
            The tables
        """
        if ignore_subqueries:
            return frozenset((self._table,))
        return super().tables() | {self._table}

    def provided_expressions(self) -> frozenset[SqlExpression]:
        """Provide the columns of this relation in addition to the expressions of the parent class implementation."""
        return super().provided_expressions() | self._provided_cols

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the `visit_base_relation` handler of the visitor."""
        return visitor.visit_base_relation(self)

    def mutate(
        self,
        *,
        table: Optional[TableReference] = None,
        provided_columns: Optional[Iterable[ColumnReference | ColumnExpression]] = None,
        subquery_input: Optional[RelNode] = None,
        as_root: bool = False,
    ) -> Relation:
        """Creates a new relation with modified attributes.

        Parameters
        ----------
        table : Optional[TableReference], optional
            The new table to use. If *None*, the current table is re-used.
        provided_columns : Optional[Iterable[ColumnReference | ColumnExpression]], optional
            The new columns to use. If *None*, the current columns are re-used.
        subquery_input : Optional[RelNode], optional
            The new subquery input to use. If *None*, the current subquery input is re-used.
        as_root : bool, optional
            Whether the relation should become the new root node of the tree. This overwrites any value passed to `parent`
            (note that this method itself does not accept a `parent` argument).

        Returns
        -------
        Relation
            The modified relation node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # Forward exactly the keyword arguments of this method to the generic implementation.
        return super().mutate(
            table=table,
            provided_columns=provided_columns,
            subquery_input=subquery_input,
            as_root=as_root,
        )

    def _update_child_nodes(self, children: Sequence[RelNode]) -> None:
        """Replace the subquery input with the first of the given children, if this relation has a subquery input."""
        self._assert_correct_update_child_count(children)
        if self._subquery_input:
            self._subquery_input = children[0]

    def _recalc_hash_val(self) -> int:
        """Compute the hash value of this relation.

        The hash is based on the table only. This mirrors `__eq__`, which also just compares the tables: objects that
        compare equal are required to have the same hash value.
        """
        return hash(self._table)

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        return isinstance(other, type(self)) and self._table == other._table

    def __repr__(self) -> str:
        return self._table.identifier()

    def __str__(self) -> str:
        return self._table.identifier()
1136
+
1137
+
1138
class ThetaJoin(RelNode):
    """A theta join combines individual tuples from two input relations if they match a specific predicate.

    Parameters
    ----------
    left_input : RelNode
        Relation containing the first set of tuples.
    right_input : RelNode
        Relation containing the second set of tuples.
    predicate : AbstractPredicate
        A predicate that must be satisfied by all joined tuples.
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Notes
    -----
    A theta join is defined as

    .. math:: \\bowtie_\\theta(R, S) := \\{ r \\circ s | r \\in R \\land s \\in S \\land \\theta(r, s) \\}
    """

    def __init__(
        self,
        left_input: RelNode,
        right_input: RelNode,
        predicate: AbstractPredicate,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # All attributes are stored before the base initializer runs.
        self._predicate = predicate
        self._left_input, self._right_input = left_input, right_input
        super().__init__(parent_node)

    @property
    def left_input(self) -> RelNode:
        """Provide the operator that supplies the first set of tuples.

        Returns
        -------
        RelNode
            A relation
        """
        first_relation = self._left_input
        return first_relation

    @property
    def right_input(self) -> RelNode:
        """Provide the operator that supplies the second set of tuples.

        Returns
        -------
        RelNode
            A relation
        """
        second_relation = self._right_input
        return second_relation

    @property
    def predicate(self) -> AbstractPredicate:
        """Provide the condition that the joined tuples have to satisfy.

        Returns
        -------
        AbstractPredicate
            A predicate
        """
        join_condition = self._predicate
        return join_condition

    def children(self) -> Sequence[RelNode]:
        """Provide the two input operators, left input first."""
        first_relation, second_relation = self._left_input, self._right_input
        return [first_relation, second_relation]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the `visit_theta_join` handler of the visitor."""
        handler = visitor.visit_theta_join
        return handler(self)

    def mutate(
        self,
        *,
        left_input: Optional[RelNode] = None,
        right_input: Optional[RelNode] = None,
        predicate: Optional[AbstractPredicate] = None,
        as_root: bool = False,
    ) -> ThetaJoin:
        """Creates a new theta join with modified attributes.

        Parameters
        ----------
        left_input : Optional[RelNode], optional
            Replacement for the left child node. If *None*, the current left input node is re-used.
        right_input : Optional[RelNode], optional
            Replacement for the right child node. If *None*, the current right input node is re-used.
        predicate : Optional[AbstractPredicate], optional
            Replacement for the join predicate. If *None*, the current predicate is re-used.
        as_root : bool, optional
            Whether the theta join should become the new root node of the tree. This overwrites any value passed to
            `parent` (note that this method itself does not accept a `parent` argument).

        Returns
        -------
        ThetaJoin
            The modified theta join node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All keyword arguments of this method are handed to the generic implementation as-is.
        return super().mutate(
            left_input=left_input,
            right_input=right_input,
            predicate=predicate,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        """Compute the hash value from both inputs and the join predicate."""
        hash_components = (self._left_input, self._right_input, self._predicate)
        return hash(hash_components)

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._left_input == other._left_input
            and self._right_input == other._right_input
            and self._predicate == other._predicate
        )

    def __str__(self) -> str:
        return f"⋈ ϴ=({self._predicate})"
1265
+
1266
+
1267
class Projection(RelNode):
    """A projection selects individual attributes from the tuples of an input relation.

    The output relation will contain exactly the same tuples as the input, but each tuple will potentially contain fewer
    attributes.

    Parameters
    ----------
    input_node : RelNode
        The tuples to process
    targets : Sequence[SqlExpression]
        The attributes that should still be contained in the output relation
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.
    """

    def __init__(
        self,
        input_node: RelNode,
        targets: Sequence[SqlExpression],
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # The targets are stored as a tuple, independent of the sequence type that was passed in.
        projected = tuple(targets)
        self._input_node, self._targets = input_node, projected
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Provide the operator that supplies the tuples to project.

        Returns
        -------
        RelNode
            A relation
        """
        source = self._input_node
        return source

    @property
    def columns(self) -> Sequence[SqlExpression]:
        """Provide the attributes that are included in the tuples of the output relation.

        Returns
        -------
        Sequence[SqlExpression]
            The projected attributes.
        """
        projected = self._targets
        return projected

    def children(self) -> Sequence[RelNode]:
        """Provide the input operator as the only child."""
        source = self._input_node
        return [source]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the `visit_projection` handler of the visitor."""
        handler = visitor.visit_projection
        return handler(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        targets: Optional[Sequence[SqlExpression]] = None,
        as_root: bool = False,
    ) -> Projection:
        """Creates a new projection with modified attributes.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            Replacement for the input node. If *None*, the current input node is re-used.
        targets : Optional[Sequence[SqlExpression]], optional
            Replacement for the projection targets. If *None*, the current targets are re-used.
        as_root : bool, optional
            Whether the projection should become the new root node of the tree. This overwrites any value passed to
            `parent` (note that this method itself does not accept a `parent` argument).

        Returns
        -------
        Projection
            The modified projection node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All keyword arguments of this method are handed to the generic implementation as-is.
        return super().mutate(
            input_node=input_node,
            targets=targets,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        """Compute the hash value from the input operator and the projection targets."""
        hash_components = (self._input_node, self._targets)
        return hash(hash_components)

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._input_node == other._input_node
            and self._targets == other._targets
        )

    def __str__(self) -> str:
        col_str = ", ".join(map(str, self._targets))
        return f"π ({col_str})"
1373
+
1374
+
1375
class Grouping(RelNode):
    """Grouping partitions input tuples according to specific attributes and calculates aggregated values.

    Parameters
    ----------
    input_node : RelNode
        The tuples to process
    group_columns : Sequence[SqlExpression]
        The expressions that should be used to partition the input tuples. Can be empty if only aggregations over all input
        tuples should be computed.
    aggregates : Optional[dict[frozenset[SqlExpression], frozenset[FunctionExpression]]], optional
        The aggregates that should be computed. This is a mapping from the input expressions to the desired aggregate. Can be
        empty if only a grouping should be performed. In this case, the grouping operates as a duplicate-elimination mechanism.
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Raises
    ------
    ValueError
        If neither group columns nor aggregates are given.
    """

    def __init__(
        self,
        input_node: RelNode,
        group_columns: Sequence[SqlExpression],
        *,
        aggregates: Optional[
            dict[frozenset[SqlExpression], frozenset[FunctionExpression]]
        ] = None,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        if not group_columns and not aggregates:
            raise ValueError(
                "Either group columns or aggregation functions must be specified!"
            )
        self._input_node = input_node
        self._group_columns = tuple(group_columns)
        # NOTE(review): `aggregates` can still be None at this point (if group columns are given). This relies on
        # util.frozendict accepting None - confirm.
        self._aggregates: util.frozendict[
            frozenset[SqlExpression], frozenset[FunctionExpression]
        ] = util.frozendict(aggregates)
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the operator that provides the tuples to group.

        Returns
        -------
        RelNode
            A relation
        """
        return self._input_node

    @property
    def group_columns(self) -> Sequence[SqlExpression]:
        """Get the expressions that should be used to partition the input tuples.

        Returns
        -------
        Sequence[SqlExpression]
            The group columns. Can be empty if only aggregations over all input tuples should be computed.
        """
        return self._group_columns

    @property
    def aggregates(
        self,
    ) -> util.frozendict[frozenset[SqlExpression], frozenset[FunctionExpression]]:
        """Get the aggregates that should be computed.

        Aggregates map from the input expressions to the desired aggregation functions.

        Returns
        -------
        util.frozendict[frozenset[SqlExpression], frozenset[FunctionExpression]]
            The aggregations. Can be empty if only a grouping should be performed.
        """
        return self._aggregates

    def children(self) -> Sequence[RelNode]:
        """Provide the input operator as the only child."""
        return [self._input_node]

    def provided_expressions(self) -> frozenset[SqlExpression]:
        """Provide the group columns along with all computed aggregate functions."""
        aggregate_expressions = util.set_union(self._aggregates.values())
        return frozenset(set(self._group_columns) | aggregate_expressions)

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the `visit_grouping` handler of the visitor."""
        return visitor.visit_grouping(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        group_columns: Optional[Sequence[SqlExpression]] = None,
        aggregates: Optional[
            dict[frozenset[SqlExpression], frozenset[FunctionExpression]]
        ] = None,
        parent: Optional[RelNode] = None,
        as_root: bool = False,
    ) -> Grouping:
        """Creates a new group by with modified attributes.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            The new input node to use. If *None*, the current input node is re-used.
        group_columns : Optional[Sequence[SqlExpression]], optional
            The new group columns to use. If *None*, the current group columns are re-used.
        aggregates : Optional[dict[frozenset[SqlExpression], frozenset[FunctionExpression]]], optional
            The new aggregates to use. If *None*, the current aggregates are re-used.
        parent : Optional[RelNode], optional
            Passed on to `RelNode.mutate` without modification. NOTE(review): this parameter is presumably the new parent
            node - confirm its handling against `RelNode.mutate`.
        as_root : bool, optional
            Whether the group by should become the new root node of the tree. This overwrites any value passed to `parent`.

        Returns
        -------
        Grouping
            The modified group by node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # Forward exactly the keyword arguments of this method to the generic implementation.
        return super().mutate(
            input_node=input_node,
            group_columns=group_columns,
            aggregates=aggregates,
            parent=parent,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        """Compute the hash value from the input operator, the group columns and the aggregates."""
        return hash((self._input_node, self._group_columns, self._aggregates))

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self._input_node == other._input_node
            and self._group_columns == other._group_columns
            and self._aggregates == other._aggregates
        )

    def __str__(self) -> str:
        # Maps the textual form of the aggregated expressions to the textual form of their aggregate functions.
        pretty_aggregations: dict[str, str] = {}
        for cols, agg_funcs in self._aggregates.items():
            # Singleton sets are shown without the surrounding parentheses.
            if len(cols) == 1:
                col_str = str(util.simplify(cols))
            else:
                col_str = "(" + ", ".join(str(c) for c in cols) + ")"
            if len(agg_funcs) == 1:
                agg_str = str(util.simplify(agg_funcs))
            else:
                agg_str = "(" + ", ".join(str(agg) for agg in agg_funcs) + ")"
            pretty_aggregations[col_str] = agg_str

        agg_str = ", ".join(
            f"{col}: {agg_func}" for col, agg_func in pretty_aggregations.items()
        )
        if not self._group_columns:
            return f"γ ({agg_str})"
        group_str = ", ".join(str(col) for col in self._group_columns)
        return f"{group_str} γ ({agg_str})"
1532
+
1533
+
1534
class Rename(RelNode):
    """Rename remaps column names to different names.

    Parameters
    ----------
    input_node : RelNode
        The tuples to modify
    mapping : dict[ColumnReference, ColumnReference]
        A map from current column name to new column name.
    parent_node : Optional[RelNode]
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Warnings
    --------
    This node is currently not used since we do not support natural joins.
    """

    def __init__(
        self,
        input_node: RelNode,
        mapping: dict[ColumnReference, ColumnReference],
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # TODO: check types + add provided / required expressions method
        renamings = util.frozendict(mapping)
        self._input_node, self._mapping = input_node, renamings
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Provide the operator that supplies the tuples to modify.

        Returns
        -------
        RelNode
            A relation
        """
        source = self._input_node
        return source

    @property
    def mapping(self) -> util.frozendict[ColumnReference, ColumnReference]:
        """Provide the renamings that are performed by this node.

        Returns
        -------
        util.frozendict[ColumnReference, ColumnReference]
            A map from current column name to new column name.
        """
        renamings = self._mapping
        return renamings

    def provided_expressions(self) -> frozenset[SqlExpression]:
        """Provide the expressions of the input node with the renamings applied to them."""
        input_expressions = self._input_node.provided_expressions()
        return frozenset(
            transform.rename_columns_in_expression(expression, self._mapping)
            for expression in input_expressions
        )

    def children(self) -> Sequence[RelNode]:
        """Provide the input operator as the only child."""
        source = self._input_node
        return [source]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the `visit_rename` handler of the visitor."""
        handler = visitor.visit_rename
        return handler(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        mapping: Optional[dict[ColumnReference, ColumnReference]] = None,
        as_root: bool = False,
    ) -> Rename:
        """Creates a new rename with modified attributes.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            Replacement for the input node. If *None*, the current input node is re-used.
        mapping : Optional[dict[ColumnReference, ColumnReference]], optional
            Replacement for the column mapping. If *None*, the current mapping is re-used.
        as_root : bool, optional
            Whether the rename should become the new root node of the tree. This overwrites any value passed to `parent`
            (note that this method itself does not accept a `parent` argument).

        Returns
        -------
        Rename
            The modified rename node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All keyword arguments of this method are handed to the generic implementation as-is.
        return super().mutate(
            input_node=input_node,
            mapping=mapping,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        """Compute the hash value from the input operator and the column mapping."""
        hash_components = (self._input_node, self._mapping)
        return hash(hash_components)

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._input_node == other._input_node
            and self._mapping == other._mapping
        )

    def __str__(self) -> str:
        renamings = (
            f"{current}: {renamed}" for current, renamed in self._mapping.items()
        )
        map_str = ", ".join(renamings)
        return f"ϱ ({map_str})"
1647
+
1648
+
1649
# Type alias for the two literal strings that the `Sort` operator accepts as sort direction.
SortDirection = typing.Literal["asc", "desc"]
"""Describes whether tuples should be sorted in ascending or descending order."""
1651
+
1652
+
1653
class Sort(RelNode):
    """Sort modifies the order in which tuples are provided.

    Parameters
    ----------
    input_node : RelNode
        The tuples to order
    sorting : Sequence[tuple[SqlExpression, SortDirection] | SqlExpression]
        The expressions that should be used to determine the sorting. For expressions that do not specify any particular
        direction, ascending order is assumed. Later expressions are used to solve ties among tuples with the same expression
        values in the first couple of expressions.
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Notes
    -----
    Strictly speaking, this operator is not part of traditional relational algebra. This is because the algebra uses
    set-semantics which do not supply any ordering. However, due to SQL's *ORDER BY* clause, most relational algebra dialects
    support ordering nevertheless.

    However, we do not support special placement of *NULL* column values, i.e. no *ORDER BY R.a NULLS LAST*, etc.
    """

    # TODO: support NULLS FIRST/NULLS LAST
    def __init__(
        self,
        input_node: RelNode,
        sorting: Sequence[tuple[SqlExpression, SortDirection] | SqlExpression],
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # Sort items without an explicit direction are normalized to ascending order.
        normalized = []
        for sort_item in sorting:
            if not isinstance(sort_item, tuple):
                sort_item = (sort_item, "asc")
            normalized.append(sort_item)

        self._input_node = input_node
        self._sorting = tuple(normalized)
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Provide the operator that supplies the tuples to sort.

        Returns
        -------
        RelNode
            A relation
        """
        source = self._input_node
        return source

    @property
    def sorting(self) -> Sequence[tuple[SqlExpression, SortDirection]]:
        """Provide the desired ordering.

        Later expressions are used to solve ties among tuples with the same expression values in the first couple of
        expressions.

        Returns
        -------
        Sequence[tuple[SqlExpression, SortDirection]]
            The expressions to order, most significant orders coming first.
        """
        ordering = self._sorting
        return ordering

    def children(self) -> Sequence[RelNode]:
        """Provide the input operator as the only child."""
        source = self._input_node
        return [source]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatch to the `visit_sort` handler of the visitor."""
        handler = visitor.visit_sort
        return handler(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        sorting: Optional[
            Sequence[tuple[SqlExpression, SortDirection] | SqlExpression]
        ] = None,
        as_root: bool = False,
    ) -> Sort:
        """Creates a new sort with modified attributes.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            Replacement for the input node. If *None*, the current input node is re-used.
        sorting : Optional[Sequence[tuple[SqlExpression, SortDirection] | SqlExpression]], optional
            Replacement for the sorting. If *None*, the current sorting is re-used.
        as_root : bool, optional
            Whether the sort should become the new root node of the tree. This overwrites any value passed to `parent`
            (note that this method itself does not accept a `parent` argument).

        Returns
        -------
        Sort
            The modified sort node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        # All keyword arguments of this method are handed to the generic implementation as-is.
        return super().mutate(
            input_node=input_node,
            sorting=sorting,
            as_root=as_root,
        )

    def _recalc_hash_val(self) -> int:
        """Compute the hash value from the input operator and the sorting."""
        hash_components = (self._input_node, self._sorting)
        return hash(hash_components)

    # Defining __eq__ would otherwise disable the inherited hash implementation.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._input_node == other._input_node
            and self._sorting == other._sorting
        )

    def __str__(self) -> str:
        rendered = []
        for sort_col, sort_dir in self._sorting:
            # NOTE(review): ascending order is rendered with a downwards arrow - confirm that this is the intended notation.
            arrow = "↓" if sort_dir == "asc" else "↑"
            rendered.append(f"{sort_col}{arrow}")
        sorting_str = ", ".join(rendered)
        return f"τ ({sorting_str})"
1778
+
1779
+
1780
class Map(RelNode):
    """Computes derived expressions from the expressions that are already available.

    A typical example is the term *R.a + 42*, which a mapping can calculate based on the output of a relation node *R*.

    Parameters
    ----------
    input_node : RelNode
        The operator that provides the tuples to process
    mapping : dict[frozenset[SqlExpression | ColumnReference], frozenset[SqlExpression]]
        The computations to perform. Keys are the arguments, values are the target expressions that are derived from them.
        Arguments can be produced by the very same mapping operation, or they can be supplied by the `input_node`.
    parent_node : Optional[RelNode], optional
        The operator that receives the output relation of this node, if one exists. *None* can be used if the node is the
        root and (currently) does not have a parent.
    """

    def __init__(
        self,
        input_node: RelNode,
        mapping: dict[
            frozenset[SqlExpression | ColumnReference], frozenset[SqlExpression]
        ],
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        self._input_node = input_node

        # Keys that are plain column references are wrapped into column expressions, all other keys are kept untouched.
        normalized: dict = {}
        for arguments, targets in mapping.items():
            if isinstance(arguments, ColumnReference):
                arguments = ColumnExpression(arguments)
            normalized[arguments] = targets
        self._mapping = util.frozendict(normalized)

        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the operator that provides the tuples to map.

        Returns
        -------
        RelNode
            A relation
        """
        return self._input_node

    @property
    def mapping(
        self,
    ) -> util.frozendict[frozenset[SqlExpression], frozenset[SqlExpression]]:
        """Get the expressions to compute, keyed by their arguments and pointing to the target expressions.

        Arguments can be produced by the very same mapping operation, or they can be supplied by the input node.

        Returns
        -------
        util.frozendict[frozenset[SqlExpression], frozenset[SqlExpression]]
            The expressions
        """
        return self._mapping

    def children(self) -> Sequence[RelNode]:
        """Provides the input node as the only child."""
        return [self._input_node]

    def provided_expressions(self) -> frozenset[SqlExpression]:
        """Provides the expressions of the parent implementation plus all target expressions of the mapping."""
        inherited = super().provided_expressions()
        computed = util.set_union(targets for targets in self._mapping.values())
        return inherited | computed

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatches to `visitor.visit_map`."""
        return visitor.visit_map(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        mapping: Optional[
            dict[frozenset[SqlExpression | ColumnReference], frozenset[SqlExpression]]
        ] = None,
        as_root: bool = False,
    ) -> Map:
        """Creates a new map with modified attributes.

        All arguments are forwarded by keyword to the `mutate` implementation of the parent class.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            The new input node to use. If *None*, the current input node is re-used.
        mapping : Optional[dict[frozenset[SqlExpression | ColumnReference], frozenset[SqlExpression]]], optional
            The new mapping to use. If *None*, the current mapping is re-used.
        as_root : bool, optional
            Whether the map should become the new root node of the tree.

        Returns
        -------
        Map
            The modified map node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        return super().mutate(input_node=input_node, mapping=mapping, as_root=as_root)

    def _recalc_hash_val(self) -> int:
        """Computes the hash value from the input node and the mapping."""
        return hash((self._input_node, self._mapping))

    # Defining __eq__ would otherwise reset __hash__ to None, so the inherited implementation is re-bound explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return self._input_node == other._input_node and self._mapping == other._mapping

    def __str__(self) -> str:
        def render(expressions) -> str:
            # A single expression is printed as-is, multiple expressions are wrapped in parentheses.
            if len(expressions) == 1:
                return str(util.simplify(expressions))
            return "(" + ", ".join(str(expr) for expr in expressions) + ")"

        pretty_mapping: dict[str, str] = {}
        for arguments, targets in self._mapping.items():
            arguments_str = render(arguments)
            targets_str = render(targets)
            pretty_mapping[arguments_str] = targets_str

        mapping_str = ", ".join(
            f"{arguments_str}: {targets_str}"
            for arguments_str, targets_str in pretty_mapping.items()
        )
        return f"χ ({mapping_str})"
1923
+
1924
+
1925
class DuplicateElimination(RelNode):
    """Duplicate elimination ensures that all attribute combinations of all tuples are unique.

    Parameters
    ----------
    input_node : RelNode
        The tuples that should be unique
    parent_node : Optional[RelNode], optional
        The operator that receives the output relation of this node, if one exists. *None* can be used if the node is the
        root and (currently) does not have a parent.

    Notes
    -----
    Strictly speaking, this operator is not part of traditional relational algebra. This is because the algebra uses
    set-semantics which do not contain duplicates in the first place. However, due to SQL's usage of multi-sets which allow
    duplicates, most relational algebra dialects support duplicate elimination nevertheless.
    """

    def __init__(
        self, input_node: RelNode, *, parent_node: Optional[RelNode] = None
    ) -> None:
        self._input_node = input_node
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the operator that provides the tuples to de-duplicate.

        Returns
        -------
        RelNode
            A relation
        """
        return self._input_node

    def children(self) -> Sequence[RelNode]:
        """Provides the input node as the only child."""
        return [self._input_node]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatches to `visitor.visit_duplicate_elim`."""
        return visitor.visit_duplicate_elim(self)

    def mutate(
        self, *, input_node: Optional[RelNode] = None, as_root: bool = False
    ) -> DuplicateElimination:
        """Creates a new duplicate elimination with modified attributes.

        All arguments are forwarded by keyword to the `mutate` implementation of the parent class.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            The new input node to use. If *None*, the current input node is re-used.
        as_root : bool, optional
            Whether the duplicate elimination should become the new root node of the tree.

        Returns
        -------
        DuplicateElimination
            The modified duplicate elimination node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        return super().mutate(input_node=input_node, as_root=as_root)

    def _recalc_hash_val(self) -> int:
        """Computes the hash value, which only depends on the input node."""
        return hash(self._input_node)

    # Defining __eq__ would otherwise reset __hash__ to None, so the inherited implementation is re-bound explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return self._input_node == other._input_node

    def __str__(self) -> str:
        return "δ"
1998
+
1999
+
2000
class SemiJoin(RelNode):
    """A semi join provides all tuples from one relation with a matching partner tuple from another relation.

    Parameters
    ----------
    input_node : RelNode
        The tuples to "filter"
    subquery_node : SubqueryScan
        The relation that provides the candidate partner tuples for the tuples of the `input_node`. The node is copied
        via `mutate()` and the copy is attached to the semi join.
    predicate : Optional[AbstractPredicate], optional
        An optional predicate that is used to determine a match.
    parent_node : Optional[RelNode], optional
        The operator that receives the output relation of this node, if one exists. *None* can be used if the node is the
        root and (currently) does not have a parent.

    Notes
    -----
    A semi join is defined as

    .. math:: ⋉_\\theta(R, S) := \\{ r | r \\in R \\land \\exists s \\in S: \\theta(r, s) \\}
    """

    def __init__(
        self,
        input_node: RelNode,
        subquery_node: SubqueryScan,
        predicate: Optional[AbstractPredicate] = None,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # TODO: dependent iff predicate is None
        self._input_node = input_node
        self._predicate = predicate

        # we need to set the parent of the copied subquery manually to prevent infinite recursion
        subquery_copy = subquery_node.mutate()
        subquery_copy._parent = self
        self._subquery_node = subquery_copy

        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the operator providing the tuples to filter.

        Returns
        -------
        RelNode
            A relation
        """
        return self._input_node

    @property
    def subquery_node(self) -> SubqueryScan:
        """Get the operator providing the filtering tuples.

        Returns
        -------
        SubqueryScan
            A relation
        """
        return self._subquery_node

    @property
    def predicate(self) -> Optional[AbstractPredicate]:
        """Get the match condition to determine the join partners.

        If there is no dedicated predicate, tuples from the `input_node` match, if any tuple is emitted by the
        `subquery_node`.

        Returns
        -------
        Optional[AbstractPredicate]
            The condition
        """
        return self._predicate

    def is_dependent(self) -> bool:
        """Checks, whether the subquery relation is dependent (sometimes also called correlated) with the input relation.

        The join is currently treated as dependent exactly if it has no predicate.

        Returns
        -------
        bool
            Whether the subquery is correlated with the input query

        See Also
        --------
        SqlQuery.is_dependent
        """
        return self._predicate is None

    def children(self) -> Sequence[RelNode]:
        """Provides the input node followed by the subquery node."""
        return [self._input_node, self._subquery_node]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatches to `visitor.visit_semijoin`."""
        return visitor.visit_semijoin(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        subquery_node: Optional[SubqueryScan] = None,
        predicate: Optional[AbstractPredicate] = None,
        as_root: bool = False,
    ) -> SemiJoin:
        """Creates a new semi join with modified attributes.

        All arguments are forwarded by keyword to the `mutate` implementation of the parent class.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            The new input node to use. If *None*, the current input node is re-used.
        subquery_node : Optional[SubqueryScan], optional
            The new subquery node to use. If *None*, the current subquery node is re-used.
        predicate : Optional[AbstractPredicate], optional
            The new predicate to use. If *None*, the current predicate is re-used.
        as_root : bool, optional
            Whether the semi join should become the new root node of the tree.

        Returns
        -------
        SemiJoin
            The modified semi join node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        return super().mutate(
            input_node=input_node,
            subquery_node=subquery_node,
            predicate=predicate,
            as_root=as_root,
        )

    def _update_child_nodes(self, children: Sequence[RelNode]) -> None:
        """Replaces the input node and the subquery node, in the same order that `children()` reports them."""
        self._assert_correct_update_child_count(children)
        self._input_node, self._subquery_node = children[0], children[1]

    def _recalc_hash_val(self) -> int:
        """Computes the hash value from the input node, the subquery node and the predicate."""
        return hash((self._input_node, self._subquery_node, self._predicate))

    # Defining __eq__ would otherwise reset __hash__ to None, so the inherited implementation is re-bound explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._input_node == other._input_node
            and self._subquery_node == other._subquery_node
            and self._predicate == other._predicate
        )

    def __str__(self) -> str:
        if self._predicate is None:
            return "⋉"
        return f"⋉ ({self._predicate})"
2155
+
2156
+
2157
class AntiJoin(RelNode):
    """An anti join provides all tuples from one relation with no matching partner tuple from another relation.

    Parameters
    ----------
    input_node : RelNode
        The tuples to "filter"
    subquery_node : SubqueryScan
        The relation that provides the candidate partner tuples. Tuples of the `input_node` are only kept if they have no
        partner in this relation. The node is copied via `mutate()` and the copy is attached to the anti join.
    predicate : Optional[AbstractPredicate], optional
        An optional predicate that is used to determine a match.
    parent_node : Optional[RelNode], optional
        The operator that receives the output relation of this node, if one exists. *None* can be used if the node is the
        root and (currently) does not have a parent.

    Notes
    -----
    An anti join is defined as

    .. math:: ▷_\\theta(R, S) := \\{ r | r \\in R \\land \\nexists s \\in S: \\theta(r, s) \\}

    """

    def __init__(
        self,
        input_node: RelNode,
        subquery_node: SubqueryScan,
        predicate: Optional[AbstractPredicate] = None,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        # TODO: dependent iff predicate is None
        self._input_node = input_node
        self._predicate = predicate

        # we need to set the parent of the copied subquery manually to prevent infinite recursion
        subquery_copy = subquery_node.mutate()
        subquery_copy._parent = self
        self._subquery_node = subquery_copy

        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the operator providing the tuples to filter.

        Returns
        -------
        RelNode
            A relation
        """
        return self._input_node

    @property
    def subquery_node(self) -> SubqueryScan:
        """Get the operator providing the filtering tuples.

        Returns
        -------
        SubqueryScan
            A relation
        """
        return self._subquery_node

    @property
    def predicate(self) -> Optional[AbstractPredicate]:
        """Get the match condition to determine the join partners.

        If there is no dedicated predicate, tuples from the `input_node` count as matched, if any tuple is emitted by the
        `subquery_node`.

        Returns
        -------
        Optional[AbstractPredicate]
            The condition
        """
        return self._predicate

    def is_dependent(self) -> bool:
        """Checks, whether the subquery relation is dependent (sometimes also called correlated) with the input relation.

        The join is currently treated as dependent exactly if it has no predicate.

        Returns
        -------
        bool
            Whether the subquery is correlated with the input query

        See Also
        --------
        SqlQuery.is_dependent
        """
        return self._predicate is None

    def children(self) -> Sequence[RelNode]:
        """Provides the input node followed by the subquery node."""
        return [self._input_node, self._subquery_node]

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatches to `visitor.visit_antijoin`."""
        return visitor.visit_antijoin(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        subquery_node: Optional[SubqueryScan] = None,
        predicate: Optional[AbstractPredicate] = None,
        as_root: bool = False,
    ) -> AntiJoin:
        """Creates a new anti join with modified attributes.

        All arguments are forwarded by keyword to the `mutate` implementation of the parent class.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            The new input node to use. If *None*, the current input node is re-used.
        subquery_node : Optional[SubqueryScan], optional
            The new subquery node to use. If *None*, the current subquery node is re-used.
        predicate : Optional[AbstractPredicate], optional
            The new predicate to use. If *None*, the current predicate is re-used.
        as_root : bool, optional
            Whether the anti join should become the new root node of the tree.

        Returns
        -------
        AntiJoin
            The modified anti join node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        return super().mutate(
            input_node=input_node,
            subquery_node=subquery_node,
            predicate=predicate,
            as_root=as_root,
        )

    def _update_child_nodes(self, children: Sequence[RelNode]) -> None:
        """Replaces the input node and the subquery node, in the same order that `children()` reports them."""
        self._assert_correct_update_child_count(children)
        self._input_node, self._subquery_node = children[0], children[1]

    def _recalc_hash_val(self) -> int:
        """Computes the hash value from the input node, the subquery node and the predicate."""
        return hash((self._input_node, self._subquery_node, self._predicate))

    # Defining __eq__ would otherwise reset __hash__ to None, so the inherited implementation is re-bound explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return (
            self._input_node == other._input_node
            and self._subquery_node == other._subquery_node
            and self._predicate == other._predicate
        )

    def __str__(self) -> str:
        if self._predicate is None:
            return "▷"
        return f"▷ ({self._predicate})"
2313
+
2314
+
2315
class SubqueryScan(RelNode):
    """A meta node to designate a subtree that originated from a subquery.

    Parameters
    ----------
    input_node : RelNode
        The relation that identifies the subquery result
    subquery : SqlQuery
        The query that actually calculates the subquery
    parent_node : Optional[RelNode], optional
        The parent node of the operator, if one exists. The parent is the operator that receives the output relation of the
        current operator. If the current operator is the root and (currently) does not have a parent, *None* can be used.

    Notes
    -----
    This node is not part of traditional relational algebra, nor do many other systems make use of it. For our purposes it
    serves as a marker node to quickly designate subqueries and to operate on the original queries or their algebraic
    representation in a convenient manner.
    """

    def __init__(
        self,
        input_node: RelNode,
        subquery: SqlQuery,
        *,
        parent_node: Optional[RelNode] = None,
    ) -> None:
        self._input_node = input_node
        self._subquery = subquery
        super().__init__(parent_node)

    @property
    def input_node(self) -> RelNode:
        """Get the result node of the subquery

        Returns
        -------
        RelNode
            A relation
        """
        return self._input_node

    @property
    def subquery(self) -> SqlQuery:
        """Get the actual subquery.

        Returns
        -------
        SqlQuery
            A query
        """
        return self._subquery

    def tables(self, *, ignore_subqueries: bool = False) -> frozenset[TableReference]:
        """Provides the tables below this node, or no tables at all if subqueries should be ignored.

        Parameters
        ----------
        ignore_subqueries : bool, optional
            If *True*, an empty set is returned. Otherwise the parent implementation is used.

        Returns
        -------
        frozenset[TableReference]
            The tables
        """
        if ignore_subqueries:
            return frozenset()
        return super().tables(ignore_subqueries=ignore_subqueries)

    def children(self) -> Sequence[RelNode]:
        """Provides the input node as the only child."""
        return [self._input_node]

    def provided_expressions(self) -> frozenset[SqlExpression]:
        """Provides the subquery expression of this node in addition to the expressions of the parent implementation.

        Returns
        -------
        frozenset[SqlExpression]
            The expressions. The result is always an immutable (and therefore hashable) frozenset.
        """
        # A binary operation between a set and a frozenset returns the type of its *left* operand. Starting with a
        # frozenset guarantees that the declared return type is honored, no matter what the parent implementation returns.
        own_expressions = frozenset({SubqueryExpression(self._subquery)})
        return own_expressions | super().provided_expressions()

    def accept_visitor(self, visitor: RelNodeVisitor[VisitorResult]) -> VisitorResult:
        """Dispatches to `visitor.visit_subquery`."""
        return visitor.visit_subquery(self)

    def mutate(
        self,
        *,
        input_node: Optional[RelNode] = None,
        subquery: Optional[SqlQuery] = None,
        as_root: bool = False,
    ) -> SubqueryScan:
        """Creates a new subquery scan with modified attributes.

        All arguments are forwarded by keyword to the `mutate` implementation of the parent class.

        Parameters
        ----------
        input_node : Optional[RelNode], optional
            The new input node to use. If *None*, the current input node is re-used.
        subquery : Optional[SqlQuery], optional
            The new subquery to use. If *None*, the current subquery is re-used.
        as_root : bool, optional
            Whether the subquery scan should become the new root node of the tree.

        Returns
        -------
        SubqueryScan
            The modified subquery scan node

        See Also
        --------
        RelNode.mutate : for safety considerations and calling conventions
        """
        return super().mutate(input_node=input_node, subquery=subquery, as_root=as_root)

    def _recalc_hash_val(self) -> int:
        """Computes the hash value from the input node and the subquery."""
        return hash((self._input_node, self._subquery))

    # Defining __eq__ would otherwise reset __hash__ to None, so the inherited implementation is re-bound explicitly.
    __hash__ = RelNode.__hash__

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self._input_node == other._input_node
            and self._subquery == other._subquery
        )

    def __str__(self) -> str:
        return (
            "<<Scalar Subquery Scan>>"
            if self._subquery.is_scalar()
            else "<<Subquery Scan>>"
        )
2437
+
2438
+
2439
# Type variable for the value that the `visit_*` methods of a `RelNodeVisitor` return.
VisitorResult = typing.TypeVar("VisitorResult")
"""Result type of visitor processes."""
2441
+
2442
+
2443
class RelNodeVisitor(abc.ABC, typing.Generic[VisitorResult]):
    """Basic visitor to operate on arbitrary relational algebra trees.

    There is one abstract handler per node type. Nodes select the matching handler in their `accept_visitor` method.

    See Also
    --------
    RelNode

    References
    ----------

    .. [1] Visitor pattern: https://en.wikipedia.org/wiki/Visitor_pattern
    """

    @abc.abstractmethod
    def visit_selection(self, selection: Selection) -> VisitorResult:
        """Handles a selection node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_cross_product(self, cross_product: CrossProduct) -> VisitorResult:
        """Handles a cross product node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_union(self, union: Union) -> VisitorResult:
        """Handles a union node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_intersection(self, intersection: Intersection) -> VisitorResult:
        """Handles an intersection node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_difference(self, difference: Difference) -> VisitorResult:
        """Handles a difference node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_base_relation(self, base_table: Relation) -> VisitorResult:
        """Handles a relation node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_theta_join(self, join: ThetaJoin) -> VisitorResult:
        """Handles a theta join node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_projection(self, projection: Projection) -> VisitorResult:
        """Handles a projection node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_grouping(self, grouping: Grouping) -> VisitorResult:
        """Handles a grouping node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_rename(self, rename: Rename) -> VisitorResult:
        """Handles a rename node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_sort(self, sorting: Sort) -> VisitorResult:
        """Handles a sort node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_map(self, mapping: Map) -> VisitorResult:
        """Handles a map node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_duplicate_elim(
        self, duplicate_elim: DuplicateElimination
    ) -> VisitorResult:
        """Handles a duplicate elimination node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_semijoin(self, join: SemiJoin) -> VisitorResult:
        """Handles a semi join node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_antijoin(self, join: AntiJoin) -> VisitorResult:
        """Handles an anti join node."""
        raise NotImplementedError

    @abc.abstractmethod
    def visit_subquery(self, subquery: SubqueryScan) -> VisitorResult:
        """Handles a subquery scan node."""
        raise NotImplementedError
2521
+
2522
+
2523
def _collect_leaf_nodes(root: RelNode) -> set[Relation]:
    """Provides all leaf nodes of a relational algebra tree.

    A relation node without subquery input terminates the search. For all other nodes (including relation nodes with
    subquery input), the leaf nodes of their children are collected recursively.

    Parameters
    ----------
    root : RelNode
        The root node of the tree

    Returns
    -------
    set[Relation]
        The leaf nodes
    """
    is_plain_relation = isinstance(root, Relation) and not root.subquery_input
    if is_plain_relation:
        return {root}
    return util.set_union(_collect_leaf_nodes(child) for child in root.children())
2543
+
2544
+
2545
@dataclasses.dataclass
class _RelTreeUpdateSet:
    """Result of a tree update: the root of the updated tree and the updated version of the node that asked for it."""

    # root node of the updated tree
    updated_root: RelNode
    # updated counterpart of the node that requested the tree update in the first place
    updated_initiator: RelNode
2551
+
2552
+
2553
class _RelNodeUpdateManager:
    """Handles the process of copying and modifying an entire relational algebra tree.

    The tree is copied bottom-up: starting at the leaf nodes, a node is cloned once all of its children have been cloned.
    The clone of the initiator additionally receives the requested modifications.

    Parameters
    ----------
    root : RelNode
        The current root node of the tree
    initiator : RelNode
        The node that asked for the tree update in the first place
    """

    def __init__(self, root: RelNode, *, initiator: RelNode) -> None:
        self._root = root
        self._initiator = initiator
        # maps original nodes to their copies
        self._updated_nodes: dict[RelNode, RelNode] = {}
        # nodes that still have to be processed, seeded with the leaves of the tree
        self._node_working_set: list[RelNode] = list(_collect_leaf_nodes(root))

    def make_relalg_copy(self, **kwargs) -> _RelTreeUpdateSet:
        """Creates a copy of the relational algebra tree and applies modifications to it.

        Parameters
        ----------
        **kwargs
            Arbitrary keyword arguments to apply to the tree nodes. These are the arguments to the `mutate` call.
            If `as_root` is contained and truthy, the copy stops at the initiator, which becomes the new root.

        Returns
        -------
        _RelTreeUpdateSet
            The updated relalg tree.
        """
        as_root: bool = kwargs.get("as_root", False)
        updated_root: Optional[RelNode] = None
        updated_initiator: Optional[RelNode] = None

        while self._node_working_set:
            current_node = self._node_working_set.pop(0)
            # schedule the consumers of the current node, no matter whether the node itself can be handled right now
            self._update_node_working_set(current_node)

            already_updated = current_node in self._updated_nodes
            pending_child_update = any(
                child not in self._updated_nodes for child in current_node.children()
            )
            if already_updated or pending_child_update:
                # nodes with pending children are skipped for now
                # NOTE(review): presumably they are re-scheduled once the remaining children are processed - confirm
                continue

            updated_node = current_node.clone()
            updated_node._update_child_nodes(
                [self._updated_nodes[child] for child in current_node.children()]
            )
            updated_node._clear_parent_links()

            # NOTE(review): nodes are compared via ==, i.e. using the __eq__ implementation of the node classes
            if current_node == self._initiator:
                self._perform_node_update(updated_node, **kwargs)
                updated_initiator = updated_node
                if as_root:
                    # the initiator becomes the new root, nodes above it are not copied
                    updated_root = updated_node
                    break

            self._updated_nodes[current_node] = updated_node

            if not as_root and current_node == self._root:
                updated_root = updated_node

        assert updated_initiator is not None
        assert updated_root is not None

        # linkage and hash values are rebuilt for the copied tree as well as for the original tree
        updated_root._rebuild_linkage()
        updated_root._rehash()
        self._root._rebuild_linkage()
        self._root._rehash()

        return _RelTreeUpdateSet(updated_root, updated_initiator)

    def _perform_node_update(self, node: RelNode, **kwargs) -> None:
        """Updates the initiator of a tree mutation with node-specific arguments.

        Each argument *foo* is written to the instance attribute *_foo*. Arguments with a *None* value and arguments without
        a matching instance attribute are ignored.

        Parameters
        ----------
        node : RelNode
            The node to update. Must be the *copied* version of the initiator node.
        **kwargs
            The arguments to apply to the node
        """
        fields: set[str] = set(vars(node).keys())
        for attr, value in kwargs.items():
            if value is None:
                # None means "keep the current value"
                continue
            actual_attr_name = f"_{attr}"
            if actual_attr_name not in fields:
                # presumably this is how control arguments such as as_root are filtered out - verify
                continue

            if isinstance(value, RelNode):
                value = self._merge_new_child(value)
            # NOTE(review): the value is assigned as-is, constructor-level normalization is not repeated here
            # (e.g. Map.__init__ wraps its mapping in a util.frozendict) - confirm that callers pass normalized values
            setattr(node, actual_attr_name, value)

    def _merge_new_child(self, child_node: RelNode) -> RelNode:
        """Merges a new child node into the tree update process.

        This ensures that all child nodes of the new node correctly reference their updated counterparts.

        Parameters
        ----------
        child_node : RelNode
            The new child node to merge. Must be the *original* version of the child node.

        Returns
        -------
        RelNode
            The updated child node
        """
        # the subtree below the new child is traversed with its own working set, but shares the node mapping
        internal_node_working_set: list[RelNode] = list(_collect_leaf_nodes(child_node))
        updated_child: Optional[RelNode] = None

        while internal_node_working_set:
            current_node = internal_node_working_set.pop(0)
            already_updated = current_node in self._updated_nodes
            if already_updated and current_node == child_node:
                # the child has already been copied, its copy can be re-used
                updated_child = self._updated_nodes[current_node]
                break
            if already_updated:
                self._update_node_working_set(
                    current_node, working_set=internal_node_working_set
                )
                continue
            if any(
                child not in self._updated_nodes for child in current_node.children()
            ):
                continue

            updated_node = current_node.clone()
            updated_node._update_child_nodes(
                [self._updated_nodes[child] for child in current_node.children()]
            )
            updated_node._clear_parent_links()

            self._updated_nodes[current_node] = updated_node
            if current_node == child_node:
                updated_child = updated_node
                break

            self._update_node_working_set(
                current_node, working_set=internal_node_working_set
            )

        assert updated_child is not None
        return updated_child

    def _update_node_working_set(
        self, node: RelNode, *, working_set: Optional[list[RelNode]] = None
    ) -> None:
        """Utility method to quickly include all relevant parent nodes into a node working set.

        The parent node (if there is one) and all nodes of the sideways pass are appended to the working set.

        Parameters
        ----------
        node : RelNode
            The node whose upwards links should be added to the working set
        working_set : Optional[list[RelNode]], optional
            The working set to update. If *None*, the internal working set of the update manager is used.
        """
        working_set = self._node_working_set if working_set is None else working_set
        if node.parent_node:
            working_set.append(node.parent_node)
        working_set.extend(node.sideways_pass)
2716
+
2717
+
2718
def _is_aggregation(expression: SqlExpression) -> bool:
    """Utility to check whether an arbitrary SQL expression is an aggregation function.

    Only the expression itself is inspected. Nested expressions are not considered.

    Parameters
    ----------
    expression : SqlExpression
        The expression to check

    Returns
    -------
    bool
        *True* if the expression is a function expression that reports itself as an aggregate, *False* otherwise
    """
    if not isinstance(expression, FunctionExpression):
        return False
    return expression.is_aggregate()
2732
+
2733
+
2734
def _requires_aggregation(expression: SqlExpression) -> bool:
    """Checks, whether any of the (transitively) nested children of an expression aggregates input tuples.

    Notice that the expression itself is not tested, only its children are. Use `_is_aggregation` to check the expression
    itself.

    Parameters
    ----------
    expression : SqlExpression
        The expression to check

    Returns
    -------
    bool
        *True* if an aggregation was detected in a nested expression or *False* otherwise.
    """
    for child_expr in expression.iterchildren():
        if _is_aggregation(child_expr) or _requires_aggregation(child_expr):
            return True
    return False
2751
+
2752
+
2753
def _needs_mapping(expression: SqlExpression) -> bool:
    """Decide whether an expression must be computed through a mapping step.

    Static values and star expressions are the only expressions that are treated as directly available. Everything
    else, including plain column references, is reported as requiring a mapping.

    Parameters
    ----------
    expression : SqlExpression
        The expression to classify

    Returns
    -------
    bool
        *False* for static value and star expressions, *True* for all other expressions.
    """
    directly_available = (StaticValueExpression, StarExpression)
    if isinstance(expression, directly_available):
        return False
    return True
2770
+
2771
+
2772
def _generate_expression_mapping_dict(
    expressions: list[SqlExpression],
) -> dict[frozenset[SqlExpression], frozenset[SqlExpression]]:
    """Group expressions by the set of direct child expressions they are derived from.

    For an expression such as *CAST(R.a + 42 AS int)*, the argument *R.a + 42* has to be available before the cast
    can be evaluated. The resulting dictionary encodes this as ``{R.a + 42: CAST(...)}``.

    Only the direct children of each expression are considered, the lookup is not recursive. In the example above
    there is no entry ``R.a: R.a + 42`` unless *R.a + 42* is itself contained in `expressions`. Children that do not
    need a mapping (see `_needs_mapping`) are left out of the key.

    Parameters
    ----------
    expressions : list[SqlExpression]
        The target expressions to resolve

    Returns
    -------
    dict[frozenset[SqlExpression], frozenset[SqlExpression]]
        Maps each distinct set of arguments to all target expressions that are derived from exactly this set.
    """
    grouped: dict[frozenset[SqlExpression], set[SqlExpression]] = {}
    for target in expressions:
        arguments = frozenset(filter(_needs_mapping, target.iterchildren()))
        grouped.setdefault(arguments, set()).add(target)

    return {arguments: frozenset(targets) for arguments, targets in grouped.items()}
2810
+
2811
+
2812
class EvaluationPhase(enum.IntEnum):
    """Earliest point during logical query execution at which an expression or predicate can be evaluated.

    The members are ordered: a phase that happens later compares greater than a phase that happens earlier.
    """

    BaseTable = 1
    """Only tuples of a single base table are required, e.g. for base table filters."""

    Join = 2
    """Evaluation happens while the required base tables are being joined, e.g. for join predicates."""

    PostJoin = 3
    """All base tables must have been joined first, e.g. for mappings over joined columns."""

    PostAggregation = 4
    """All aggregations must have been computed first, e.g. for filters over aggregated columns."""
2826
+
2827
+
2828
+ @dataclasses.dataclass(frozen=True)
2829
+ class _SubquerySet:
2830
+ """More expressive wrapper to collect subqueries from SQL queries.
2831
+
2832
+ Two subquery sets can be merged using the addition operator. Boolean tests on the subquery set succeed, if the set contains
2833
+ at least one subquery.
2834
+
2835
+ Attributes
2836
+ ----------
2837
+ subqueries : frozenset[SqlQuery]
2838
+ The subqueries that are currently in the set. Can be empty if there are no subqueries.
2839
+ """
2840
+
2841
+ subqueries: frozenset[SqlQuery]
2842
+
2843
+ @staticmethod
2844
+ def empty() -> _SubquerySet:
2845
+ """Generates a new subquery set without any entries."""
2846
+ return _SubquerySet(frozenset())
2847
+
2848
+ @staticmethod
2849
+ def of(subqueries: Iterable[SqlQuery]) -> _SubquerySet:
2850
+ """Generates a new subquery set containing specific subqueries.
2851
+
2852
+ This factory handles the generation of an appropriate frozenset.
2853
+ """
2854
+ return _SubquerySet(frozenset([subqueries]))
2855
+
2856
+ def __add__(self, other: _SubquerySet) -> _SubquerySet:
2857
+ if not isinstance(other, type(self)):
2858
+ return NotImplemented
2859
+ return _SubquerySet(self.subqueries | other.subqueries)
2860
+
2861
+ def __bool__(self) -> bool:
2862
+ return bool(self.subqueries)
2863
+
2864
+
2865
class _SubqueryDetector(
    SqlExpressionVisitor[_SubquerySet], PredicateVisitor[_SubquerySet]
):
    """Collects all subqueries from SQL expressions or predicates.

    Every visitor method returns a `_SubquerySet`. Leaf expressions (static values, columns, star expressions)
    contribute an empty set, subquery expressions contribute their query and all remaining expressions and predicates
    merge the sets of their nested expressions.
    """

    def visit_and_predicate(
        self, predicate: CompoundPredicate, components: Sequence[AbstractPredicate]
    ) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_or_predicate(
        self, predicate: CompoundPredicate, components: Sequence[AbstractPredicate]
    ) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_not_predicate(
        self, predicate: CompoundPredicate, child_predicate: AbstractPredicate
    ) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_binary_predicate(self, predicate: BinaryPredicate) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_between_predicate(self, predicate: BetweenPredicate) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_in_predicate(self, predicate: InPredicate) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_unary_predicate(self, predicate: UnaryPredicate) -> _SubquerySet:
        return self._traverse_predicate_expressions(predicate)

    def visit_static_value_expr(
        self, expression: StaticValueExpression
    ) -> _SubquerySet:
        return _SubquerySet.empty()

    def visit_column_expr(self, expression: ColumnExpression) -> _SubquerySet:
        return _SubquerySet.empty()

    def visit_cast_expr(self, expression: CastExpression) -> _SubquerySet:
        return self._traverse_nested_expressions(expression)

    def visit_function_expr(self, expression: FunctionExpression) -> _SubquerySet:
        return self._traverse_nested_expressions(expression)

    def visit_math_expr(self, expression: MathExpression) -> _SubquerySet:
        return self._traverse_nested_expressions(expression)

    def visit_star_expr(self, expression: StarExpression) -> _SubquerySet:
        return _SubquerySet.empty()

    def visit_subquery_expr(self, expression: SubqueryExpression) -> _SubquerySet:
        return _SubquerySet.of(expression.query)

    def visit_window_expr(self, expression: WindowExpression) -> _SubquerySet:
        return self._traverse_nested_expressions(expression)

    def visit_case_expr(self, expression: CaseExpression) -> _SubquerySet:
        return self._traverse_nested_expressions(expression)

    def visit_predicate_expr(self, expression: AbstractPredicate) -> _SubquerySet:
        return self._traverse_nested_expressions(expression)

    def _traverse_predicate_expressions(
        self, predicate: AbstractPredicate
    ) -> _SubquerySet:
        """Handler to collect subqueries from predicates.

        Merges the subquery sets of all expressions provided by ``predicate.iterexpressions()``. If the predicate does
        not provide any expression, an empty set is returned.
        """
        # The explicit initial value keeps reduce() from raising a TypeError on an empty sequence.
        return functools.reduce(
            operator.add,
            (
                expression.accept_visitor(self)
                for expression in predicate.iterexpressions()
            ),
            _SubquerySet.empty(),
        )

    def _traverse_nested_expressions(self, expression: SqlExpression) -> _SubquerySet:
        """Handler to collect subqueries from SQL expressions.

        Merges the subquery sets of all expressions provided by ``expression.iterchildren()``. If there are no child
        expressions (e.g. for a function call without arguments), an empty set is returned.
        """
        # The explicit initial value keeps reduce() from raising a TypeError on an empty sequence.
        return functools.reduce(
            operator.add,
            (
                nested_expression.accept_visitor(self)
                for nested_expression in expression.iterchildren()
            ),
            _SubquerySet.empty(),
        )
2950
+
2951
+
2952
class _BaseTableLookup(
    SqlExpressionVisitor[Optional[TableReference]], PredicateVisitor[TableReference]
):
    """Handler to determine the base table in an arbitrarily deep predicate or expression hierarchy.

    This service is designed to traverse filter predicates or expressions operating on a single base table and
    provides exactly this table. As a special case, it also traverses dependent subqueries. In this case, the base
    table is the outer table.

    The lookup may be started directly by calling the instantiated service with a predicate or an expression as
    argument.

    Notes
    -----
    In cases where multiple applicable base tables are detected, a ``ValueError`` is raised. Therefore, using the
    lookup on a predicate is always guaranteed to provide a table (or raise an error) since all predicates operate on
    tables. On the other hand, checking an arbitrary SQL expression may or may not contain a base table (e.g.
    *CAST(42 AS float)*). Hence, an optional is returned for expressions.
    """

    def visit_and_predicate(
        self, predicate: CompoundPredicate, components: Sequence[AbstractPredicate]
    ) -> TableReference:
        base_tables = {child_pred.accept_visitor(self) for child_pred in components}
        return self._fetch_valid_base_tables(base_tables)

    def visit_or_predicate(
        self, predicate: CompoundPredicate, components: Sequence[AbstractPredicate]
    ) -> TableReference:
        base_tables = {child_pred.accept_visitor(self) for child_pred in components}
        return self._fetch_valid_base_tables(base_tables)

    def visit_not_predicate(
        self, predicate: CompoundPredicate, child_predicate: AbstractPredicate
    ) -> TableReference:
        return child_predicate.accept_visitor(self)

    def visit_binary_predicate(self, predicate: BinaryPredicate) -> TableReference:
        base_tables = {
            predicate.first_argument.accept_visitor(self),
            predicate.second_argument.accept_visitor(self),
        }
        return self._fetch_valid_base_tables(base_tables)

    def visit_between_predicate(self, predicate: BetweenPredicate) -> TableReference:
        base_tables = {
            predicate.column.accept_visitor(self),
            predicate.interval_start.accept_visitor(self),
            predicate.interval_end.accept_visitor(self),
        }
        return self._fetch_valid_base_tables(base_tables)

    def visit_in_predicate(self, predicate: InPredicate) -> TableReference:
        base_tables = {predicate.column.accept_visitor(self)}
        base_tables |= {val.accept_visitor(self) for val in predicate.values}
        return self._fetch_valid_base_tables(base_tables)

    def visit_unary_predicate(self, predicate: UnaryPredicate) -> TableReference:
        return predicate.column.accept_visitor(self)

    def visit_static_value_expr(
        self, expression: StaticValueExpression
    ) -> Optional[TableReference]:
        return None

    def visit_column_expr(
        self, expression: ColumnExpression
    ) -> Optional[TableReference]:
        return expression.column.table

    def visit_cast_expr(self, expression: CastExpression) -> Optional[TableReference]:
        return expression.casted_expression.accept_visitor(self)

    def visit_function_expr(
        self, expression: FunctionExpression
    ) -> Optional[TableReference]:
        referenced_tables = {
            argument.accept_visitor(self) for argument in expression.arguments
        }
        return self._fetch_valid_base_tables(referenced_tables, accept_empty=True)

    def visit_math_expr(self, expression: MathExpression) -> Optional[TableReference]:
        base_tables = {
            child.accept_visitor(self) for child in expression.iterchildren()
        }
        # Arithmetic over static values only (e.g. 1 + 2) does not reference any table. Like all other expression
        # handlers, this has to yield None rather than an error.
        return self._fetch_valid_base_tables(base_tables, accept_empty=True)

    def visit_star_expr(self, expression: StarExpression) -> Optional[TableReference]:
        return expression.from_table

    def visit_subquery_expr(
        self, expression: SubqueryExpression
    ) -> Optional[TableReference]:
        subquery = expression.query
        if not subquery.is_dependent():
            return None
        dependent_tables = subquery.unbound_tables()
        return self._fetch_valid_base_tables(dependent_tables, accept_empty=True)

    def visit_window_expr(
        self, expression: WindowExpression
    ) -> Optional[TableReference]:
        # base tables can only appear in predicates and window functions are limited to SELECT statements
        return None

    def visit_case_expr(self, expression: CaseExpression) -> Optional[TableReference]:
        # base tables can only appear in predicates and we only support case expressions in SELECT statements
        return None

    def visit_predicate_expr(self, expression: AbstractPredicate) -> TableReference:
        return expression.accept_visitor(self)

    def _fetch_valid_base_tables(
        self, base_tables: set[TableReference | None], *, accept_empty: bool = False
    ) -> Optional[TableReference]:
        """Handler to extract the actual base table from a set of candidate tables.

        The supplied collection is not modified.

        Parameters
        ----------
        base_tables : set[TableReference | None]
            The candidate tables. Usually, this should be a set containing exactly one base table and potentially an
            additional *None* value. In all other situations an error is raised (see below)
        accept_empty : bool, optional
            Whether an empty set of actual candidate tables (i.e. excluding *None* values) is an acceptable argument.
            If that is the case, *None* is returned in such a situation. Empty candidate sets raise an error by
            default.

        Returns
        -------
        Optional[TableReference]
            The base table, or *None* if there is no candidate and `accept_empty` is set.

        Raises
        ------
        ValueError
            If the candidate tables either contain more than one actual table instance, or if empty candidate sets are
            not accepted and the set does not contain any *non-None* entries.
        """
        # Work on a copy: the candidates may be owned by the caller (or be an immutable frozenset).
        candidates = {table for table in base_tables if table is not None}
        if len(candidates) > 1 or (not candidates and not accept_empty):
            raise ValueError(
                f"Expected exactly one base predicate but found {candidates}"
            )
        return util.simplify(candidates) if candidates else None

    def __call__(self, elem: AbstractPredicate | SqlExpression) -> TableReference:
        """Determines the base table of a predicate or expression.

        Raises
        ------
        ValueError
            If `elem` is a join predicate, if it references multiple base tables, or if no base table can be found.
        """
        if isinstance(elem, AbstractPredicate) and elem.is_join():
            raise ValueError(f"Cannot determine base table for join predicate '{elem}'")
        tables = elem.tables()
        if len(tables) == 1:
            # fast path: no need to traverse the hierarchy if there is just a single table anyway
            return util.simplify(tables)
        base_table = elem.accept_visitor(self)
        if base_table is None:
            raise ValueError(f"No base table found in '{elem}'")
        return base_table
3105
+
3106
+
3107
def _collect_all_expressions(
    expression: SqlExpression, *, traverse_aggregations: bool = False
) -> frozenset[SqlExpression]:
    """Provides all expressions in a specific expression tree, including the root expression.

    Expressions that do not need a mapping (see `_needs_mapping`) are removed from the result.

    Parameters
    ----------
    expression : SqlExpression
        The root expression
    traverse_aggregations : bool, optional
        Whether expressions nested in aggregation functions should be included. Disabled by default. If disabled,
        an aggregation function is contained in the result, but its arguments are not. The setting applies to the
        entire expression tree, not just to the root expression.

    Returns
    -------
    frozenset[SqlExpression]
        The expression as well as all child expressions, including deeply nested children.
    """
    if (
        not traverse_aggregations
        and isinstance(expression, FunctionExpression)
        and expression.is_aggregate()
    ):
        return frozenset({expression})

    # The flag has to be handed down explicitly. Otherwise, aggregations below the root would always be treated as
    # leaves, no matter what the caller requested.
    child_expressions = util.set_union(
        _collect_all_expressions(
            child_expr, traverse_aggregations=traverse_aggregations
        )
        for child_expr in expression.iterchildren()
    )
    all_expressions = {expression} | child_expressions
    return frozenset(expr for expr in all_expressions if _needs_mapping(expr))
3137
+
3138
+
3139
+ def _determine_expression_phase(expression: SqlExpression) -> EvaluationPhase:
3140
+ """Calculates the evaluation phase during which an expression can be evaluated at the earliest."""
3141
+ match expression:
3142
+ case ColumnExpression():
3143
+ return EvaluationPhase.BaseTable
3144
+ case FunctionExpression() if expression.is_aggregate():
3145
+ return EvaluationPhase.PostAggregation
3146
+ case FunctionExpression() | MathExpression() | CastExpression():
3147
+ own_phase = (
3148
+ EvaluationPhase.Join
3149
+ if len(expression.tables()) > 1
3150
+ else EvaluationPhase.BaseTable
3151
+ )
3152
+ child_phase = max(
3153
+ _determine_expression_phase(child_expr)
3154
+ for child_expr in expression.iterchildren()
3155
+ )
3156
+ return max(own_phase, child_phase)
3157
+ case SubqueryExpression():
3158
+ return (
3159
+ EvaluationPhase.BaseTable
3160
+ if len(expression.query.unbound_tables()) < 2
3161
+ else EvaluationPhase.PostJoin
3162
+ )
3163
+ case StarExpression() | StaticValueExpression():
3164
+ # TODO: should we rather raise an error in this case?
3165
+ return EvaluationPhase.BaseTable
3166
+ case WindowExpression() | CaseExpression() | StaticValueExpression():
3167
+ # these expressions can currently only appear within SELECT clauses
3168
+ return EvaluationPhase.PostAggregation
3169
+ case _:
3170
+ raise ValueError(f"Unknown expression type: '{expression}'")
3171
+
3172
+
3173
def _determine_predicate_phase(predicate: AbstractPredicate) -> EvaluationPhase:
    """Determine the earliest evaluation phase of a predicate.

    Tables that are bound within nested subqueries do not count towards the tables of the predicate.

    See Also
    --------
    _determine_expression_phase
    """
    detected = predicate.accept_visitor(_SubqueryDetector())
    tables_bound_in_subqueries = util.set_union(
        subquery.bound_tables() for subquery in detected.subqueries
    )
    n_subquery_tables = len(tables_bound_in_subqueries)
    n_outer_tables = len(predicate.tables()) - n_subquery_tables

    # The difference can even become negative, e.g. for HAVING count(*) < (SELECT min(r_a) FROM R).
    # This is why the check tests for exactly one table instead of "at most one".
    if n_outer_tables == 1:
        return EvaluationPhase.BaseTable

    # Subqueries in combination with multiple base tables indicate a predicate such as
    # R.a = (SELECT min(S.b) FROM S WHERE S.b = T.c), i.e. a dependent subquery on T. By default, such predicates are
    # scheduled after the join phase.
    if n_subquery_tables:
        return EvaluationPhase.PostJoin

    ignored_types = {StarExpression, StaticValueExpression}
    phases = [
        _determine_expression_phase(expression)
        for expression in predicate.iterexpressions()
        if type(expression) not in ignored_types
    ]
    latest_phase = max(phases)
    if latest_phase > EvaluationPhase.Join:
        return latest_phase

    if isinstance(predicate, BinaryPredicate):
        return EvaluationPhase.Join
    return EvaluationPhase.PostJoin
3210
+
3211
+
3212
def _filter_eval_phase(
    predicate: AbstractPredicate, expected_eval_phase: EvaluationPhase
) -> Optional[AbstractPredicate]:
    """Extract those parts of a predicate that are evaluated during a specific logical execution phase.

    Matching (sub-)predicates are determined as follows:

    - Conjunctions are split: all children whose phase equals the expected phase are combined into a (smaller)
      conjunction. If no child matches, *None* is returned.
    - All other predicates (base predicates, disjunctions, negations) are atomic: either the entire predicate matches
      the expected phase, or nothing does.
    - If the predicate as a whole can be evaluated before the expected phase, *None* is returned right away.

    Parameters
    ----------
    predicate : AbstractPredicate
        The predicate to check
    expected_eval_phase : EvaluationPhase
        The desired evaluation phase

    Returns
    -------
    Optional[AbstractPredicate]
        A predicate composed of the matching (sub-)predicates, or *None* if there is no match whatsoever.

    See Also
    --------
    _determine_predicate_phase
    """
    overall_phase = _determine_predicate_phase(predicate)
    if overall_phase < expected_eval_phase:
        return None

    is_conjunction = (
        isinstance(predicate, CompoundPredicate)
        and predicate.operation == CompoundOperator.And
    )
    if not is_conjunction:
        return predicate if overall_phase == expected_eval_phase else None

    matching_children = [
        child
        for child in predicate.children
        if _determine_predicate_phase(child) == expected_eval_phase
    ]
    if not matching_children:
        return None
    return CompoundPredicate.create_and(matching_children)
3259
+
3260
+
3261
+ class _ImplicitRelalgParser:
3262
+ """Parser to generate a `RelNode` tree from `SqlQuery` instances.
3263
+
3264
+ Parameters
3265
+ ----------
3266
+ query : ImplicitSqlQuery
3267
+ The query to parse
3268
+ provided_base_tables : Optional[dict[TableReference, RelNode]], optional
3269
+ When parsing subqueries, these are the tables that are provided by the outer query and their corresponding relational
3270
+ algebra fragments.
3271
+
3272
+ Notes
3273
+ -----
3274
+ Our parser operates in four strictly sequential stages. These stages approximately correspond to the logical evaluation
3275
+ stages of an SQL query. For reference, see [eder-sql-eval-order]_. At the same time, the entire process is loosely oriented
3276
+ on the query execution strategy applied by PostgreSQL.
3277
+
3278
+ During the initial stage, all base tables are processed. This includes all filters that may be executed directly on the
3279
+ base table, as well as mappings that enable the filters, e.g. for predicates such as *CAST(R.a AS int) = 42*. As a special
3280
+ case, we also handle *EXISTS* and *MISSING* predicates during this stage. The end result of the initial stage is a
3281
+ dictionary that maps each base table to a relational algebra fragment corresponds to the scan of the base table as well as
3282
+ all filters, etc..
3283
+
3284
+ Secondly, all joins are processed. This process builds on the initial dictionary of base tables and iteratively combines
3285
+ pairs of fragments according to their join predicate. In the end, usually just a single fragment remains. This fragment has
3286
+ exactly one root node that corresponds to the final joined relation. If there are multiple fragments (and hence root nodes)
3287
+ remaining, we need to use cross products to ensure that we just have a single relation in the end. A post-processing step
3288
+ applies all filter predicates that were not recognized as joins but that require columns from multiple base tables, e.g.
3289
+ *R.a + S.b < 42*.
3290
+
3291
+ The third stage is concerned with grouping and aggregation. If the query contains aggregates or a *GROUP BY* clause, these
3292
+ operations are now inserted. Since we produced just a single root node during the second stage, our grouping uses this
3293
+ node as input. Once again, during a post-processing step all filter predicates from the *HAVING* clause are inserted as
3294
+ selections.
3295
+
3296
+ The fourth and final phase executes all "cleanup" actions such as sorting, duplicate removal and projection.
3297
+
3298
+ Notice that during each stage, additional mapping steps can be required if some expression requires input that has not
3299
+ yet been calculated.
3300
+
3301
+ References
3302
+ ----------
3303
+
3304
+ .. [eder-sql-eval-order] https://blog.jooq.org/a-beginners-guide-to-the-true-order-of-sql-operations/
3305
+
3306
+ """
3307
+
3308
+ def __init__(
3309
+ self,
3310
+ query: ImplicitSqlQuery,
3311
+ *,
3312
+ provided_base_tables: Optional[dict[TableReference, RelNode]] = None,
3313
+ ) -> None:
3314
+ self._query = query
3315
+ self._base_table_fragments: dict[TableReference, RelNode] = {}
3316
+ self._required_columns: dict[TableReference, set[ColumnReference]] = (
3317
+ collections.defaultdict(set)
3318
+ )
3319
+ self._provided_base_tables: dict[TableReference, RelNode] = (
3320
+ provided_base_tables if provided_base_tables else {}
3321
+ )
3322
+
3323
+ if query:
3324
+ query_cols = self._query.columns()
3325
+ util.collections.foreach(
3326
+ query_cols, lambda col: self._required_columns[col.table].add(col)
3327
+ )
3328
+
3329
def generate_relnode(self) -> RelNode:
    """Produces a relational algebra tree for the current query.

    Set queries are delegated to `_parse_set_query`. For all other queries, the clauses are processed in a fixed
    order: CTEs, FROM items, base table filters, joins, post-join filters, aggregation, HAVING filters, ordering
    and the final projection. Each step receives the fragment produced by the previous step.

    Returns
    -------
    RelNode
        Root node of the algebraic expression
    """
    if isinstance(self._query, SetQuery):
        return self._parse_set_query(self._query)

    # TODO: robustness: query without FROM clause

    if self._query.cte_clause:
        # CTEs are registered first, such that FROM items can refer to their target tables later on
        for cte in self._query.cte_clause.queries:
            cte_root = self._add_subquery(cte.query)
            self._add_table(cte.target_table, input_node=cte_root)

    # we add the WHERE clause before all explicit JOIN statements to make sure filters are already present and we can
    # stitch together the correct fragments in OUTER JOINs
    # Once the explicit JOINs have been processed, we continue with all remaining implicit joins
    # TODO: since the implementation of JOIN statements is currently undergoing a major rework, we don't process such
    # statements at all

    util.collections.foreach(self._query.from_clause.items, self._add_table_source)

    if self._query.where_clause:
        # first pass over the WHERE clause: only the parts that belong to the base table phase
        self._add_predicate(
            self._query.where_clause.predicate, eval_phase=EvaluationPhase.BaseTable
        )

    # the join phase of the WHERE clause is handled as part of the join order generation
    final_fragment = self._generate_initial_join_order()

    if self._query.where_clause:
        # add all post-join filters here
        final_fragment = self._add_predicate(
            self._query.where_clause.predicate,
            input_node=final_fragment,
            eval_phase=EvaluationPhase.PostJoin,
        )

    final_fragment = self._add_aggregation(final_fragment)
    if self._query.having_clause:
        # HAVING filters operate on the output of the aggregation
        final_fragment = self._add_predicate(
            self._query.having_clause.condition,
            input_node=final_fragment,
            eval_phase=EvaluationPhase.PostAggregation,
        )

    if self._query.orderby_clause:
        final_fragment = self._add_ordering(
            self._query.orderby_clause.expressions, input_node=final_fragment
        )

    final_fragment = self._add_final_projection(final_fragment)
    return final_fragment
3385
+
3386
def _update_query(self, query: SelectStatement) -> None:
    """Replace the query of the parser and register all of its columns as required columns.

    Columns that have been registered for earlier queries are retained.
    """

    def register(column: ColumnReference) -> None:
        self._required_columns[column.table].add(column)

    self._query = query
    util.collections.foreach(query.columns(), register)
3392
+
3393
def _parse_set_query(self, query: SetQuery) -> RelNode:
    """Handler method to translate a set query into a relational algebra fragment.

    Both input queries are translated by one fresh helper parser and the resulting fragments are combined according
    to the set operation. An *ORDER BY* clause of the set query is applied on top of the combined fragment.

    Parameters
    ----------
    query : SetQuery
        The set query to translate. Its CTEs (if any), left query and right query are parsed in this order.

    Returns
    -------
    RelNode
        The root node of the fragment that computes the set query.

    Raises
    ------
    ValueError
        If the set operation of the query is not one of *UNION*, *UNION ALL*, *INTERSECT* or *EXCEPT*.
    """
    # The helper parser starts without a query. The actual queries are swapped in one after another below.
    # NOTE(review): the helper does not receive the provided base tables of this parser - confirm this is intended.
    parser = _ImplicitRelalgParser(None)

    if query.cte_clause:
        for cte in query.cte_clause.queries:
            parser._update_query(cte.query)
            parser.generate_relnode()  # we don't care about the result, we just want to add the CTE to the base tables

    parser._update_query(query.left_query)
    left_relalg = parser.generate_relnode()
    parser._update_query(query.right_query)
    right_relalg = parser.generate_relnode()

    match query.set_operation:
        case SetOperator.Union:
            # a plain UNION is expressed as a union followed by a duplicate elimination
            final_fragment = Union(left_relalg, right_relalg)
            final_fragment = DuplicateElimination(final_fragment)
        case SetOperator.UnionAll:
            final_fragment = Union(left_relalg, right_relalg)
        case SetOperator.Intersect:
            final_fragment = Intersection(left_relalg, right_relalg)
        case SetOperator.Except:
            final_fragment = Difference(left_relalg, right_relalg)
        case _:
            raise ValueError(f"Unknown set operation: '{query.set_operation}'")

    if query.orderby_clause:
        # the ordering is added by this parser, not by the helper parser
        final_fragment = self._add_ordering(
            query.orderby_clause.expressions, input_node=final_fragment
        )

    return final_fragment
3442
+
3443
+ def _resolve(self, table: TableReference) -> RelNode:
3444
+ """Provides the algebra fragment for a specific base table, resorting to outer query tables if necessary."""
3445
+ if table in self._base_table_fragments:
3446
+ return self._base_table_fragments[table]
3447
+ return self._provided_base_tables[table]
3448
+
3449
def _add_table(
    self, table: TableReference, *, input_node: Optional[RelNode] = None
) -> RelNode:
    """Create the relation node for a base table and register it in `self._base_table_fragments`.

    An existing fragment for the same table is replaced.

    Parameters
    ----------
    table : TableReference
        The base table
    input_node : Optional[RelNode], optional
        Root node of the fragment that computes the table's contents, in case the table is the target of a subquery
        or CTE.

    Returns
    -------
    RelNode
        The newly created relation node
    """
    relation = Relation(
        table, self._required_columns[table], subquery_input=input_node
    )
    self._base_table_fragments[table] = relation
    return relation
3473
+
3474
def _add_table_source(self, table_source: TableSource) -> RelNode:
    """Build the algebra fragment for one item of the *FROM* clause.

    The fragment is available in `self._base_table_fragments` afterwards.

    Raises
    ------
    ValueError
        For explicit *JOIN* sources (which are currently not supported) and for unknown kinds of table sources.
    """
    if isinstance(table_source, DirectTableSource):
        referenced_table = table_source.table
        if not referenced_table.virtual:
            return self._add_table(referenced_table)
        # Virtual tables in direct table sources are only created through references to CTEs. However, these CTEs
        # have already been included in the base table fragments.
        return self._base_table_fragments[referenced_table]

    if isinstance(table_source, SubqueryTableSource):
        subquery_root = self._add_subquery(table_source.query)
        target = table_source.target_table
        self._base_table_fragments[target] = subquery_root
        return self._add_table(target, input_node=subquery_root)

    if isinstance(table_source, JoinTableSource):
        raise ValueError(
            f"Explicit JOIN syntax is currently not supported: '{table_source}'"
        )

    raise ValueError(f"Unknown table source: '{table_source}'")
3498
+
3499
def _generate_initial_join_order(self) -> RelNode:
    """Merges all base table fragments into one root relation that joins every fragment.

    Join predicates from the *WHERE* clause are applied first. Fragments that are still unconnected afterwards are
    combined via cross products.

    Returns
    -------
    RelNode
        The root operator of the relation that joins all base table fragments.
    """
    # TODO: figure out the interaction between implicit and explicit joins, especially regarding their timing

    joined_tables: set[TableReference] = set()
    for from_item in self._query.from_clause.items:
        # TODO: determine correct join partners for explicit JOINs
        joined_tables.update(from_item.tables())

    where_clause = self._query.where_clause
    if where_clause:
        self._add_predicate(where_clause.predicate, eval_phase=EvaluationPhase.Join)

    # Joined tables share the same fragment object, the set therefore contains one entry per connected component.
    distinct_fragments = set(self._base_table_fragments.values())
    if len(distinct_fragments) == 1:
        return util.simplify(distinct_fragments)

    root, *unconnected = distinct_fragments
    for fragment in unconnected:
        root = CrossProduct(root, fragment)
    return root
3529
+
3530
def _add_aggregation(self, input_node: RelNode) -> RelNode:
    """Generates all necessary aggregation operations for the SQL query.

    If the query neither groups nor aggregates, the current algebra tree is returned unmodified.

    Parameters
    ----------
    input_node : RelNode
        The root of the current algebra tree

    Returns
    -------
    RelNode
        The algebra tree, potentially expanded by a mapping node (for aggregate arguments that are not yet provided)
        and a grouping node.
    """

    def _is_aggregate(expr) -> bool:
        return isinstance(expr, FunctionExpression) and expr.is_aggregate()

    collector = ExpressionCollector(_is_aggregate)
    query = self._query

    # Aggregates can appear in the SELECT clause as well as in the HAVING clause
    aggregation_functions: set[FunctionExpression] = util.set_union(
        expr.accept_visitor(collector) for expr in query.select_clause.iterexpressions()
    )
    if query.having_clause:
        aggregation_functions |= util.set_union(
            expr.accept_visitor(collector)
            for expr in query.having_clause.iterexpressions()
        )

    if not (query.groupby_clause or aggregation_functions):
        return input_node

    # Make sure that everything the aggregates consume is computed before the grouping takes place
    aggregation_arguments: set[SqlExpression] = set()
    for aggregate in aggregation_functions:
        aggregation_arguments |= util.set_union(
            _collect_all_expressions(argument, traverse_aggregations=True)
            for argument in aggregate.arguments
        )
    not_yet_provided = aggregation_arguments - input_node.provided_expressions()
    if not_yet_provided:
        input_node = Map(input_node, _generate_expression_mapping_dict(not_yet_provided))

    if query.groupby_clause:
        group_cols = query.groupby_clause.group_columns
    else:
        group_cols = []

    # Aggregates that operate on the same arguments are collected under a common key
    aggregates_by_input: dict = {}
    for aggregate in aggregation_functions:
        aggregates_by_input.setdefault(aggregate.arguments, set()).add(aggregate)

    return Grouping(
        input_node,
        group_columns=group_cols,
        aggregates={
            arguments: frozenset(functions)
            for arguments, functions in aggregates_by_input.items()
        },
    )
3592
+
3593
def _add_final_projection(self, input_node: RelNode) -> RelNode:
    """Generates the output preparation nodes of the query.

    Currently, this only covers the final projection (and a mapping for expressions that still need to be computed).
    Sorting, duplicate elimination and limits are not handled here yet.

    Parameters
    ----------
    input_node : RelNode
        The root of the current algebra tree

    Returns
    -------
    RelNode
        The algebra tree, potentially expanded by some final nodes. For *SELECT \\** queries the input is returned as-is.
    """
    # TODO: Sorting, Duplicate elimination, limit
    select_clause = self._query.select_clause
    if select_clause.is_star():
        return input_node

    required = util.set_union(
        _collect_all_expressions(target.expression) for target in select_clause.targets
    )
    still_missing = required - input_node.provided_expressions()

    projection_input = input_node
    if still_missing:
        projection_input = Map(
            input_node, _generate_expression_mapping_dict(still_missing)
        )

    output_columns = [target.expression for target in select_clause.targets]
    return Projection(projection_input, output_columns)
3625
+
3626
def _add_ordering(
    self, ordering: Sequence[OrderByExpression], *, input_node: RelNode
) -> RelNode:
    """Appends a sort operator for the given *ORDER BY* expressions to the current fragment.

    Parameters
    ----------
    ordering : Sequence[OrderByExpression]
        The sort keys, in the order in which they should be applied
    input_node : RelNode
        The root of the current algebra tree

    Returns
    -------
    RelNode
        A sort node on top of the input. The input is first expanded by whatever `_add_expression` requires to make
        the sort columns available.
    """
    sort_input = input_node
    for sort_key in ordering:
        sort_input = self._add_expression(sort_key.column, input_node=sort_input)

    sort_spec: list[tuple[SqlExpression, SortDirection]] = [
        (sort_key.column, "asc" if sort_key.ascending else "desc")
        for sort_key in ordering
    ]
    return Sort(sort_input, sort_spec)
3640
+
3641
def _add_predicate(
    self,
    predicate: AbstractPredicate,
    *,
    input_node: Optional[RelNode] = None,
    eval_phase: EvaluationPhase = EvaluationPhase.BaseTable,
) -> RelNode:
    """Inserts a selection (or join) into the corresponding relational algebra fragment.

    Parameters
    ----------
    predicate : AbstractPredicate
        The entire predicate. Only those parts of the predicate that match the expected `eval_phase` are actually
        included, everything else is ignored.
    input_node : Optional[RelNode], optional
        The current fragment. The base table and join phases look up their inputs in `self._base_table_fragments`
        and do not need this parameter. It is required for the post-join and post-aggregation phases.
    eval_phase : EvaluationPhase, optional
        The current evaluation phase, by default `EvaluationPhase.BaseTable`

    Returns
    -------
    RelNode
        The complete fragment. If no part of the predicate belongs to the evaluation phase, `input_node` is returned
        unmodified. For the base table and join phases, the fragments are also stored in
        `self._base_table_fragments` and only the fragment that was generated last is returned.

    Raises
    ------
    ValueError
        If the evaluation phase is unknown

    See Also
    --------
    _filter_eval_phase
    """
    predicate = _filter_eval_phase(predicate, eval_phase)
    if predicate is None:
        return input_node

    if eval_phase == EvaluationPhase.BaseTable:
        filters_per_table = self._split_filter_predicate(predicate)
        for base_table, table_filter in filters_per_table.items():
            filtered_fragment = self._convert_predicate(
                table_filter, input_node=self._base_table_fragments[base_table]
            )
            self._base_table_fragments[base_table] = filtered_fragment
        return filtered_fragment

    if eval_phase == EvaluationPhase.Join:
        for join_component in self._split_join_predicate(predicate):
            joined_fragment = self._convert_join_predicate(join_component)
            # from now on, all joined tables are represented by the join node
            for joined_table in joined_fragment.tables():
                self._base_table_fragments[joined_table] = joined_fragment
        return joined_fragment

    if eval_phase in (EvaluationPhase.PostJoin, EvaluationPhase.PostAggregation):
        assert input_node is not None
        # Generally speaking, when consuming a post-join predicate, all required tables should be available by now.
        # However, there is one caveat: a complex predicate that can currently only be executed after the join phase
        # (e.g. disjunctions of join predicates) could contain a correlated scalar subquery. In this case, some of the
        # required tables might not be available yet (more precisely, the native tables from the dependent subquery
        # are not available). We solve this situation by introducing cross products between the dependent subquery
        # and the outer table _within the subquery_. This is because the subquery needs to reference the outer table
        # in its join predicate.
        unavailable_tables = predicate.tables() - input_node.tables()
        for outer_table in unavailable_tables:
            # tables that we do not know about will be supplied by a subquery
            if outer_table in self._provided_base_tables:
                input_node = CrossProduct(
                    input_node, self._provided_base_tables[outer_table]
                )
        return self._convert_predicate(predicate, input_node=input_node)

    raise ValueError(
        f"Unknown evaluation phase '{eval_phase}' for predicate '{predicate}'"
    )
3714
+
3715
+ def _convert_predicate(
3716
+ self, predicate: AbstractPredicate, *, input_node: RelNode
3717
+ ) -> RelNode:
3718
+ """Generates the appropriate selection nodes for a specific predicate.
3719
+
3720
+ Depending on the specific predicate, operations other than a plain old selection might be required. For example,
3721
+ for disjunctions involving subqueries, a union is necessary. Therefore, the conversion might actually force a deviation
3722
+ from pure algebra trees and require an directed, acyclic graph instead.
3723
+
3724
+ Likewise, applying a predicate can make a preparatory mapping operation necessary, if the expressions required by the
3725
+ predicate are not yet produced by the input node.
3726
+
3727
+ Parameters
3728
+ ----------
3729
+ predicate : AbstractPredicate
3730
+ The predicate that should be converted
3731
+ input_node : RelNode
3732
+ The operator after which the predicate is required. It is assumed that the input node is actually capable of
3733
+ producing the required attributes in order to apply the predicate. For example, if the predicate consumes
3734
+ attributes from multiple base relations, it is assumed that the input node provides tuples that already contain
3735
+ these nodes.
3736
+
3737
+ Returns
3738
+ -------
3739
+ RelNode
3740
+ The algebra fragment
3741
+ """
3742
+ contains_subqueries = _SubqueryDetector()
3743
+ final_fragment = input_node
3744
+
3745
+ if isinstance(predicate, UnaryPredicate) and not predicate.accept_visitor(
3746
+ contains_subqueries
3747
+ ):
3748
+ final_fragment = self._ensure_predicate_applicability(
3749
+ predicate, final_fragment
3750
+ )
3751
+ final_fragment = Selection(final_fragment, predicate)
3752
+ return final_fragment
3753
+ elif isinstance(predicate, UnaryPredicate):
3754
+ subquery_target = (
3755
+ "semijoin"
3756
+ if predicate.operation == LogicalOperator.Exists
3757
+ else "antijoin"
3758
+ )
3759
+ return self._add_expression(
3760
+ predicate.column,
3761
+ input_node=final_fragment,
3762
+ subquery_target=subquery_target,
3763
+ )
3764
+
3765
+ if isinstance(predicate, BetweenPredicate) and not predicate.accept_visitor(
3766
+ contains_subqueries
3767
+ ):
3768
+ final_fragment = self._ensure_predicate_applicability(
3769
+ predicate, final_fragment
3770
+ )
3771
+ final_fragment = Selection(final_fragment, predicate)
3772
+ return final_fragment
3773
+ elif isinstance(predicate, BetweenPredicate):
3774
+ # BETWEEN predicate with scalar subquery
3775
+ final_fragment = self._add_expression(
3776
+ predicate.column, input_node=final_fragment
3777
+ )
3778
+ final_fragment = self._add_expression(
3779
+ predicate.interval_start, input_node=final_fragment
3780
+ )
3781
+ final_fragment = self._add_expression(
3782
+ predicate.interval_end, input_node=final_fragment
3783
+ )
3784
+ final_fragment = Selection(final_fragment, predicate)
3785
+ return final_fragment
3786
+
3787
+ if isinstance(predicate, InPredicate) and not predicate.accept_visitor(
3788
+ contains_subqueries
3789
+ ):
3790
+ # we need to determine the required expressions due to IN predicates like "r_a + 42 IN (1, 2, 3)"
3791
+ # or "r_a IN (r_b + 42, 42)"
3792
+ final_fragment = self._ensure_predicate_applicability(
3793
+ predicate, final_fragment
3794
+ )
3795
+ final_fragment = Selection(final_fragment, predicate)
3796
+ return final_fragment
3797
+ elif isinstance(predicate, InPredicate):
3798
+ # TODO: test weird IN predicates like r_a IN (1, 2, (SELECT min(...)), 4)
3799
+ # or even r_a IN ((SELECT r_a FROM ...) + (SELECT min(...)))
3800
+ pure_in_values: list[SqlExpression] = []
3801
+ subquery_in_values: list[tuple[SqlExpression, _SubquerySet]] = []
3802
+ for value in predicate.values:
3803
+ detected_subqueries = value.accept_visitor(contains_subqueries)
3804
+ if detected_subqueries and not all(
3805
+ subquery.is_scalar() for subquery in detected_subqueries.subqueries
3806
+ ):
3807
+ subquery_in_values.append((value, detected_subqueries))
3808
+ else:
3809
+ final_fragment = self._add_expression(
3810
+ value, input_node=final_fragment
3811
+ )
3812
+ pure_in_values.append(value)
3813
+ final_fragment = self._add_expression(
3814
+ predicate.column, input_node=final_fragment
3815
+ )
3816
+ if pure_in_values:
3817
+ reduced_predicate = InPredicate(predicate.column, pure_in_values)
3818
+ final_fragment = Selection(final_fragment, reduced_predicate)
3819
+ for subquery_value, detected_subqueries in subquery_in_values:
3820
+ final_fragment = self._add_expression(
3821
+ subquery_value,
3822
+ input_node=final_fragment,
3823
+ subquery_target="in",
3824
+ in_column=predicate.column,
3825
+ )
3826
+ return final_fragment
3827
+
3828
+ if isinstance(predicate, BinaryPredicate) and not predicate.accept_visitor(
3829
+ contains_subqueries
3830
+ ):
3831
+ final_fragment = self._ensure_predicate_applicability(
3832
+ predicate, final_fragment
3833
+ )
3834
+ final_fragment = Selection(final_fragment, predicate)
3835
+ return final_fragment
3836
+ elif isinstance(predicate, BinaryPredicate):
3837
+ if predicate.first_argument.accept_visitor(contains_subqueries):
3838
+ final_fragment = self._add_expression(
3839
+ predicate.first_argument,
3840
+ input_node=final_fragment,
3841
+ subquery_target="scalar",
3842
+ )
3843
+ if predicate.second_argument.accept_visitor(contains_subqueries):
3844
+ final_fragment = self._add_expression(
3845
+ predicate.second_argument,
3846
+ input_node=final_fragment,
3847
+ subquery_target="scalar",
3848
+ )
3849
+ final_fragment = self._ensure_predicate_applicability(
3850
+ predicate, final_fragment
3851
+ )
3852
+ final_fragment = Selection(final_fragment, predicate)
3853
+ return final_fragment
3854
+
3855
+ if not isinstance(predicate, CompoundPredicate):
3856
+ raise ValueError(f"Unknown predicate type: '{predicate}'")
3857
+ match predicate.operation:
3858
+ case CompoundOperator.And | CompoundOperator.Or:
3859
+ regular_predicates: list[AbstractPredicate] = []
3860
+ subquery_predicates: list[AbstractPredicate] = []
3861
+ for child_pred in predicate.iterchildren():
3862
+ if child_pred.accept_visitor(contains_subqueries):
3863
+ subquery_predicates.append(child_pred)
3864
+ else:
3865
+ regular_predicates.append(child_pred)
3866
+ if regular_predicates:
3867
+ simplified_composite = CompoundPredicate.create(
3868
+ predicate.operation, regular_predicates
3869
+ )
3870
+ final_fragment = self._ensure_predicate_applicability(
3871
+ simplified_composite, final_fragment
3872
+ )
3873
+ final_fragment = Selection(final_fragment, simplified_composite)
3874
+ for subquery_pred in subquery_predicates:
3875
+ if predicate.operation == CompoundOperator.And:
3876
+ final_fragment = self._convert_predicate(
3877
+ subquery_pred, input_node=final_fragment
3878
+ )
3879
+ continue
3880
+ subquery_branch = self._convert_predicate(
3881
+ subquery_pred, input_node=input_node
3882
+ )
3883
+ final_fragment = Union(final_fragment, subquery_branch)
3884
+ return final_fragment
3885
+
3886
+ case CompoundOperator.Not:
3887
+ if not predicate.children.accept_visitor(contains_subqueries):
3888
+ final_fragment = self._ensure_predicate_applicability(
3889
+ predicate, final_fragment
3890
+ )
3891
+ final_fragment = Selection(final_fragment, predicate)
3892
+ return final_fragment
3893
+ subquery_branch = self._convert_predicate(
3894
+ predicate.children, input_node=input_node
3895
+ )
3896
+ final_fragment = Difference(final_fragment, subquery_branch)
3897
+ return final_fragment
3898
+
3899
+ case _:
3900
+ raise ValueError(
3901
+ f"Unknown operation for composite predicate '{predicate}'"
3902
+ )
3903
+
3904
+ def _convert_join_predicate(self, predicate: AbstractPredicate) -> RelNode:
3905
+ """Generates the appropriate join nodes for a specific predicate.
3906
+
3907
+ Most of the implementation is structurally similar to `_convert_predicate`, so take a look at its documentation for
3908
+ details.
3909
+
3910
+ See Also
3911
+ --------
3912
+ _ImplicitRelalgParser._convert_predicate
3913
+ """
3914
+ contains_subqueries = _SubqueryDetector()
3915
+ nested_subqueries = predicate.accept_visitor(contains_subqueries)
3916
+ subquery_tables = util.set_union(
3917
+ subquery.bound_tables() for subquery in nested_subqueries.subqueries
3918
+ )
3919
+ table_fragments = {
3920
+ self._resolve(join_partner)
3921
+ for join_partner in predicate.tables() - subquery_tables
3922
+ }
3923
+ if len(table_fragments) == 1:
3924
+ input_node = util.simplify(table_fragments)
3925
+ provided_expressions = self._collect_provided_expressions(input_node)
3926
+ required_expressions = util.set_union(
3927
+ _collect_all_expressions(e) for e in predicate.iterexpressions()
3928
+ )
3929
+ missing_expressions = required_expressions - provided_expressions
3930
+ if missing_expressions:
3931
+ final_fragment = Map(
3932
+ input_node, _generate_expression_mapping_dict(missing_expressions)
3933
+ )
3934
+ else:
3935
+ final_fragment = input_node
3936
+ return Selection(final_fragment, predicate)
3937
+ if len(table_fragments) != 2:
3938
+ raise ValueError(
3939
+ "Expected exactly two base table fragments for join predicate "
3940
+ f"'{predicate}', but found {table_fragments}"
3941
+ )
3942
+
3943
+ required_expressions = util.set_union(
3944
+ _collect_all_expressions(e) for e in predicate.iterexpressions()
3945
+ )
3946
+ if isinstance(predicate, BinaryPredicate):
3947
+ first_input, second_input = table_fragments
3948
+ first_arg, second_arg = predicate.first_argument, predicate.second_argument
3949
+ if first_arg.tables() <= first_input.tables(
3950
+ ignore_subqueries=True
3951
+ ) and second_arg.tables() <= second_input.tables(ignore_subqueries=True):
3952
+ left_input, right_input = first_input, second_input
3953
+ elif first_arg.tables() <= second_input.tables(
3954
+ ignore_subqueries=True
3955
+ ) and second_arg.tables() <= first_input.tables(ignore_subqueries=True):
3956
+ left_input, right_input = second_input, first_input
3957
+ else:
3958
+ raise ValueError(f"Unsupported join predicate '{predicate}'")
3959
+
3960
+ left_input = self._add_expression(first_arg, input_node=left_input)
3961
+ right_input = self._add_expression(second_arg, input_node=right_input)
3962
+
3963
+ provided_expressions = self._collect_provided_expressions(
3964
+ left_input, right_input
3965
+ )
3966
+ missing_expressions = required_expressions - provided_expressions
3967
+ left_mappings: list[SqlExpression] = []
3968
+ right_mappings: list[SqlExpression] = []
3969
+ for missing_expr in missing_expressions:
3970
+ if missing_expr.tables() <= left_input.tables():
3971
+ left_mappings.append(missing_expr)
3972
+ elif missing_expr.tables() <= right_input.tables():
3973
+ right_mappings.append(missing_expr)
3974
+ else:
3975
+ raise ValueError(
3976
+ "Cannot calculate expression on left or right input: "
3977
+ f"'{missing_expr}' for predicate '{predicate}'"
3978
+ )
3979
+ if left_mappings:
3980
+ left_input = Map(
3981
+ left_input, _generate_expression_mapping_dict(left_mappings)
3982
+ )
3983
+ if right_mappings:
3984
+ right_input = Map(
3985
+ right_input, _generate_expression_mapping_dict(right_mappings)
3986
+ )
3987
+ return ThetaJoin(left_input, right_input, predicate)
3988
+
3989
+ if not isinstance(predicate, CompoundPredicate):
3990
+ raise ValueError(
3991
+ f"Unsupported join predicate '{predicate}'. Perhaps this should be a post-join filter?"
3992
+ )
3993
+
3994
+ match predicate.operation:
3995
+ case CompoundOperator.And | CompoundOperator.Or:
3996
+ regular_predicates: list[AbstractPredicate] = []
3997
+ subquery_predicates: list[AbstractPredicate] = []
3998
+ for child_pred in predicate.children:
3999
+ if predicate.accept_visitor(contains_subqueries):
4000
+ subquery_predicates.append(child_pred)
4001
+ else:
4002
+ regular_predicates.append(child_pred)
4003
+ if regular_predicates:
4004
+ simplified_composite = CompoundPredicate(
4005
+ predicate.operation, regular_predicates
4006
+ )
4007
+ final_fragment = self._convert_join_predicate(simplified_composite)
4008
+ else:
4009
+ first_input, second_input = table_fragments
4010
+ final_fragment = CrossProduct(first_input, second_input)
4011
+ for subquery_pred in subquery_predicates:
4012
+ final_fragment = self._convert_predicate(
4013
+ subquery_pred, input_node=final_fragment
4014
+ )
4015
+ return final_fragment
4016
+
4017
+ case CompoundOperator.Not:
4018
+ pass
4019
+
4020
+ case _:
4021
+ raise ValueError(
4022
+ f"Unknown operation for composite predicate '{predicate}'"
4023
+ )
4024
+
4025
def _add_expression(
    self,
    expression: SqlExpression,
    *,
    input_node: RelNode,
    subquery_target: typing.Literal[
        "semijoin", "antijoin", "scalar", "in"
    ] = "scalar",
    in_column: Optional[SqlExpression] = None,
) -> RelNode:
    """Generates the appropriate algebra fragment to execute a specific expression.

    Depending on the specific expression, simple mappings or even join nodes might be included in the fragment. If
    the expression is already provided by the input fragment, the fragment will be returned unmodified.

    Parameters
    ----------
    expression : SqlExpression
        The expression to include
    input_node : RelNode
        The operator that provides the input tuples for the expression
    subquery_target : typing.Literal["semijoin", "antijoin", "scalar", "in"], optional
        How the subquery results should be handled. This parameter is only used if the expression is a subquery and
        depends on the context in which the expression is used, such as the owning predicate. *semijoin* is used
        for *EXISTS* predicates and *antijoin* for their negated counterpart (*MISSING*). *in* corresponds to *IN*
        predicates that could either contain scalar subqueries that produce just a single value, or subqueries that
        produce an entire column of values. The appropriate handling is determined by this method automatically.
        Lastly, *scalar* indicates that the subquery is scalar and should produce just a single value, for usage
        e.g. in binary predicates or *SELECT* clauses.
    in_column : Optional[SqlExpression], optional
        For *IN* predicates that contain a subquery producing multiple rows, e.g. *R.a IN (SELECT S.b FROM S)*, this
        is the column that is compared to the subquery tuples (*R.a* in the example). For all other cases, this
        parameter is ignored.

    Returns
    -------
    RelNode
        The expanded algebra fragment

    Raises
    ------
    ValueError
        If the expression type is unknown
    """
    if expression in input_node.provided_expressions():
        return input_node

    # Columns and static values never require additional operators
    if isinstance(expression, (ColumnExpression, StaticValueExpression)):
        return input_node

    if isinstance(expression, SubqueryExpression):
        subquery_root = self._add_subquery(expression.query)
        if subquery_target == "semijoin":
            return SemiJoin(input_node, subquery_root)
        if subquery_target == "antijoin":
            return AntiJoin(input_node, subquery_root)
        if subquery_target == "scalar":
            return CrossProduct(input_node, subquery_root)
        if subquery_target == "in":
            if expression.query.is_scalar():
                # a single value can be made available just like a scalar subquery
                return CrossProduct(input_node, subquery_root)
            # The subquery produces an entire column: compare it to the IN column via a semi join
            unwrapped_scan = subquery_root.input_node
            assert (
                isinstance(unwrapped_scan, Projection)
                and len(unwrapped_scan.columns) == 1
            )
            in_predicate = BinaryPredicate.equal(
                in_column, unwrapped_scan.columns[0]
            )
            return SemiJoin(input_node, subquery_root, in_predicate)
        # unknown subquery targets do not produce a fragment
        return None

    computed_expression_types = (
        CastExpression,
        FunctionExpression,
        MathExpression,
        WindowExpression,
        CaseExpression,
    )
    if isinstance(expression, computed_expression_types):
        return self._ensure_expression_applicability(expression, input_node)

    raise ValueError(f"Did not expect expression '{expression}'")
4096
+
4097
def _add_subquery(self, subquery: SqlQuery) -> SubqueryScan:
    """Parses a nested query with a dedicated parser and wraps its algebra tree in a subquery scan.

    The nested parser receives our current base table fragments as its provided base tables. The columns required
    by the subquery are merged into our own required columns.
    """
    nested_parser = _ImplicitRelalgParser(
        subquery, provided_base_tables=self._base_table_fragments
    )
    nested_root = nested_parser.generate_relnode()

    merged_columns = util.dicts.merge(
        nested_parser._required_columns, self._required_columns
    )
    self._required_columns = merged_columns

    # We do not include the subquery base tables in our _base_table_fragments since the subquery base tables are
    # already processed completely and this would contradict the interpretation of the _base_table_fragments in
    # _generate_initial_join_order()
    return SubqueryScan(nested_root, subquery)
4110
+
4111
def _split_filter_predicate(
    self, pred: AbstractPredicate
) -> dict[TableReference, AbstractPredicate]:
    """Extracts the applicable filter predicates for various base tables.

    This method splits conjunctive filters consisting of individual predicates for multiple base tables into an
    explicit mapping from base table to its filters.

    For example, consider a predicate *R.a = 42 AND S.b < 101 AND S.c LIKE 'foo%'*. The split would provide a
    dictionary ``{R: R.a = 42, S: S.b < 101 AND S.c LIKE 'foo%'}``

    Everything that is not a conjunction is assigned to a single base table as a whole.

    Raises
    ------
    ValueError
        If the predicate is not a filter predicate
    """
    if not pred.is_filter():
        raise ValueError(f"Not a filter predicate: '{pred}'")

    is_conjunction = (
        isinstance(pred, CompoundPredicate)
        and pred.operation == CompoundOperator.And
    )
    if not is_conjunction:
        return {_BaseTableLookup()(pred): pred}

    # Split each part of the conjunction recursively and collect the results per base table
    filters_by_table: dict[TableReference, set[AbstractPredicate]] = {}
    for conjunct in pred.children:
        for table, table_filter in self._split_filter_predicate(conjunct).items():
            filters_by_table.setdefault(table, set()).add(table_filter)

    return {
        table: CompoundPredicate.create_and(table_filters)
        for table, table_filters in filters_by_table.items()
    }
4146
+
4147
def _split_join_predicate(
    self, predicate: AbstractPredicate
) -> set[AbstractPredicate]:
    """Provides all individual join predicates that have to be evaluated.

    A conjunction is broken up into its direct children. Every other predicate is returned as the only element of
    the set.

    Raises
    ------
    ValueError
        If the predicate is not a join predicate
    """
    if not predicate.is_join():
        raise ValueError(f"Not a join predicate: '{predicate}'")

    is_conjunction = (
        isinstance(predicate, CompoundPredicate)
        and predicate.operation == CompoundOperator.And
    )
    return set(predicate.children) if is_conjunction else {predicate}
4163
+
4164
def _ensure_predicate_applicability(
    self, predicate: AbstractPredicate, input_node: RelNode
) -> RelNode:
    """Computes all required mappings that have to be executed before a predicate can be evaluated.

    If the predicate needs expressions that are not yet available, the input relation is wrapped in a new mapping
    operation. Otherwise the relation is provided as-is.

    Parameters
    ----------
    predicate : AbstractPredicate
        The predicate to evaluate
    input_node : RelNode
        An operator providing the expressions that are already available.

    Returns
    -------
    RelNode
        An algebra fragment
    """
    available = self._collect_provided_expressions(input_node)
    needed = util.set_union(
        _collect_all_expressions(expr) for expr in predicate.iterexpressions()
    )
    to_compute = needed - available
    if not to_compute:
        return input_node
    return Map(input_node, _generate_expression_mapping_dict(to_compute))
4195
+
4196
def _ensure_expression_applicability(
    self, expression: SqlExpression, input_node: RelNode
) -> RelNode:
    """Computes all required mappings that have to be executed before an expression can be evaluated.

    This is pretty much the equivalent to `_ensure_predicate_applicability` but for expressions: the requirements
    are derived from the child expressions of the given expression.

    Parameters
    ----------
    expression : SqlExpression
        The expression to evaluate
    input_node : RelNode
        An operator providing the expressions that are already available.

    Returns
    -------
    RelNode
        An algebra fragment
    """
    available = self._collect_provided_expressions(input_node)
    needed = util.set_union(
        _collect_all_expressions(child) for child in expression.iterchildren()
    )
    to_compute = needed - available
    if not to_compute:
        return input_node
    return Map(input_node, _generate_expression_mapping_dict(to_compute))
4226
+
4227
def _collect_provided_expressions(self, *nodes: RelNode) -> set[SqlExpression]:
    """Collects all expressions that are provided by a set of algebra nodes.

    Expressions of the fragments in `self._provided_base_tables` are always included in the result.
    """
    from_outer_tables = util.set_union(
        outer_fragment.provided_expressions()
        for outer_fragment in self._provided_base_tables.values()
    )
    from_nodes = util.set_union(node.provided_expressions() for node in nodes)
    return from_nodes | from_outer_tables
4237
+
4238
+
4239
def parse_relalg(query: ImplicitSqlQuery) -> RelNode:
    """Converts an SQL query to a representation in relational algebra.

    Parameters
    ----------
    query : ImplicitSqlQuery
        The query to convert

    Returns
    -------
    RelNode
        The root node of the relational algebra tree. Notice that in some cases the algebraic expression might not be
        a tree but a directed, acyclic graph instead. However, in this case there still is a single root node.
    """
    parser = _ImplicitRelalgParser(query)
    initial_root = parser.generate_relnode()

    # We perform a final mutation to ensure that sideways passes are generated correctly. This removes redundant
    # subtrees that could be left over from the initial parsing.
    return initial_root.mutate()