PostBOUND 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. postbound/__init__.py +211 -0
  2. postbound/_base.py +6 -0
  3. postbound/_bench.py +1012 -0
  4. postbound/_core.py +1153 -0
  5. postbound/_hints.py +1373 -0
  6. postbound/_jointree.py +1079 -0
  7. postbound/_pipelines.py +1121 -0
  8. postbound/_qep.py +1986 -0
  9. postbound/_stages.py +876 -0
  10. postbound/_validation.py +734 -0
  11. postbound/db/__init__.py +72 -0
  12. postbound/db/_db.py +2348 -0
  13. postbound/db/_duckdb.py +785 -0
  14. postbound/db/mysql.py +1195 -0
  15. postbound/db/postgres.py +4216 -0
  16. postbound/experiments/__init__.py +12 -0
  17. postbound/experiments/analysis.py +674 -0
  18. postbound/experiments/benchmarking.py +54 -0
  19. postbound/experiments/ceb.py +877 -0
  20. postbound/experiments/interactive.py +105 -0
  21. postbound/experiments/querygen.py +334 -0
  22. postbound/experiments/workloads.py +980 -0
  23. postbound/optimizer/__init__.py +92 -0
  24. postbound/optimizer/__init__.pyi +73 -0
  25. postbound/optimizer/_cardinalities.py +369 -0
  26. postbound/optimizer/_joingraph.py +1150 -0
  27. postbound/optimizer/dynprog.py +1825 -0
  28. postbound/optimizer/enumeration.py +432 -0
  29. postbound/optimizer/native.py +539 -0
  30. postbound/optimizer/noopt.py +54 -0
  31. postbound/optimizer/presets.py +147 -0
  32. postbound/optimizer/randomized.py +650 -0
  33. postbound/optimizer/tonic.py +1479 -0
  34. postbound/optimizer/ues.py +1607 -0
  35. postbound/qal/__init__.py +343 -0
  36. postbound/qal/_qal.py +9678 -0
  37. postbound/qal/formatter.py +1089 -0
  38. postbound/qal/parser.py +2344 -0
  39. postbound/qal/relalg.py +4257 -0
  40. postbound/qal/transform.py +2184 -0
  41. postbound/shortcuts.py +70 -0
  42. postbound/util/__init__.py +46 -0
  43. postbound/util/_errors.py +33 -0
  44. postbound/util/collections.py +490 -0
  45. postbound/util/dataframe.py +71 -0
  46. postbound/util/dicts.py +330 -0
  47. postbound/util/jsonize.py +68 -0
  48. postbound/util/logging.py +106 -0
  49. postbound/util/misc.py +168 -0
  50. postbound/util/networkx.py +401 -0
  51. postbound/util/numbers.py +438 -0
  52. postbound/util/proc.py +107 -0
  53. postbound/util/stats.py +37 -0
  54. postbound/util/system.py +48 -0
  55. postbound/util/typing.py +35 -0
  56. postbound/vis/__init__.py +5 -0
  57. postbound/vis/fdl.py +69 -0
  58. postbound/vis/graphs.py +48 -0
  59. postbound/vis/optimizer.py +538 -0
  60. postbound/vis/plots.py +84 -0
  61. postbound/vis/tonic.py +70 -0
  62. postbound/vis/trees.py +105 -0
  63. postbound-0.19.0.dist-info/METADATA +355 -0
  64. postbound-0.19.0.dist-info/RECORD +67 -0
  65. postbound-0.19.0.dist-info/WHEEL +5 -0
  66. postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
  67. postbound-0.19.0.dist-info/top_level.txt +1 -0
postbound/vis/fdl.py ADDED
@@ -0,0 +1,69 @@
1
+ """Force-directed layout algorithms"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import collections
6
+ import random
7
+ import typing
8
+ from collections.abc import Callable, Hashable, Iterable
9
+
10
+ import networkx as nx
11
+ import numpy as np
12
+
13
# Type of the elements being laid out. They are used as graph nodes and dictionary keys, hence must be hashable.
T = typing.TypeVar("T", bound=Hashable)

# Switch for reproducible layouts: if enabled, the random number generators are seeded when the module is imported.
Debug = True

if Debug:
    # `random.seed` has to be *called*. Assigning to it would replace the seeding function with an integer, leaving the
    # generator unseeded and breaking every later `random.seed(...)` call in the process.
    random.seed(321)
    np.random.seed(321)
19
+
20
+
21
def force_directed_layout(
    elements: Iterable[T], difference_score: Callable[[T, T], float]
) -> dict[T, np.ndarray]:
    """Computes 2D positions for the elements based on their pairwise difference scores.

    The actual work is delegated to the module's `DefaultLayoutEngine`. Elements with a large difference score are
    placed further apart from each other than elements with a low difference score.

    Parameters
    ----------
    elements : Iterable[T]
        The items to position
    difference_score : Callable[[T, T], float]
        Function that provides the difference between two elements

    Returns
    -------
    dict[T, np.ndarray]
        A mapping from each input element to its (x, y) coordinates
    """
    positions = DefaultLayoutEngine(elements, difference_score)
    return positions
31
+
32
+
33
def kamada_kawai_layout(
    elements: Iterable[T], difference_score: Callable[[T, T], float]
) -> dict[T, np.ndarray]:
    """Positions the elements using the Kamada-Kawai layout of networkx.

    The elements become the nodes of a complete graph. Their pairwise difference scores are handed to networkx as the
    `dist` argument and randomly chosen coordinates serve as initial positions.

    Parameters
    ----------
    elements : Iterable[T]
        The items to position. The iterable is materialized once.
    difference_score : Callable[[T, T], float]
        Function that provides the difference between two elements. It is called once per unordered pair (including
        the pairing of each element with itself) and the result is stored for both directions.

    Returns
    -------
    dict[T, np.ndarray]
        The result of `nx.kamada_kawai_layout`, i.e. the position of each element
    """
    nodes = list(elements)
    graph = nx.complete_graph(nodes)

    pairwise_dist = collections.defaultdict(dict)
    for idx, first in enumerate(nodes):
        for second in nodes[idx:]:
            score = difference_score(first, second)
            pairwise_dist[first][second] = score
            pairwise_dist[second][first] = score

    # The starting coordinates are drawn from a square whose side length matches the number of elements
    spread = len(nodes)
    start_pos = {}
    for node in nodes:
        x_coord = random.random() * spread
        y_coord = random.random() * spread
        start_pos[node] = (x_coord, y_coord)

    return nx.kamada_kawai_layout(graph, dist=pairwise_dist, pos=start_pos)
52
+
53
+
54
def fruchterman_reingold_layout(
    elements: Iterable[T],
    similarity_score: Callable[[T, T], float],
    *,
    n_iter: int = 100,
) -> dict[T, np.ndarray]:
    """Positions the elements using the spring layout of networkx.

    Each pair of elements (including the pairing of an element with itself) is connected by an edge whose
    ``attraction`` attribute stores the similarity score. This attribute is used as the edge weight of the layout.

    Parameters
    ----------
    elements : Iterable[T]
        The items to position. The iterable is materialized once.
    similarity_score : Callable[[T, T], float]
        Function that provides the similarity of two elements. It is called once per unordered pair.
    n_iter : int, optional
        Passed to `nx.spring_layout` as its `iterations` argument. Defaults to 100.

    Returns
    -------
    dict[T, np.ndarray]
        The result of `nx.spring_layout`, i.e. the position of each element
    """
    nodes = list(elements)
    graph = nx.Graph()
    graph.add_nodes_from(nodes)

    for idx, first in enumerate(nodes):
        remaining = nodes[idx:]
        for second in remaining:
            weight = similarity_score(first, second)
            graph.add_edge(first, second, attraction=weight)

    return nx.spring_layout(graph, weight="attraction", iterations=n_iter)
67
+
68
+
69
+ DefaultLayoutEngine = kamada_kawai_layout  # layout function that force_directed_layout() delegates to
@@ -0,0 +1,48 @@
1
+ from __future__ import annotations
2
+
3
+ import typing
4
+ from typing import Optional
5
+
6
+ import graphviz as gv
7
+ import matplotlib as mpl
8
+ import networkx as nx
9
+
10
+
11
# The DiGraph overload has to come first: nx.DiGraph is a subclass of nx.Graph, so the more general overload would
# otherwise shadow it. All overloads also need to expose the keyword arguments of the implementation, because type
# checkers only consider the overloads when checking calls.
@typing.overload
def draw_graph(
    graph: nx.DiGraph, *, directed: None = None, color: str = ""
) -> gv.Digraph: ...


@typing.overload
def draw_graph(
    graph: nx.Graph, *, directed: None = None, color: str = ""
) -> gv.Graph: ...


@typing.overload
def draw_graph(
    graph: nx.Graph | nx.DiGraph, *, directed: bool, color: str = ""
) -> gv.Graph | gv.Digraph: ...


def draw_graph(
    graph: nx.Graph | nx.DiGraph, *, directed: Optional[bool] = None, color: str = ""
) -> gv.Graph | gv.Digraph:
    """Converts a networkx graph into a Graphviz graph.

    Nodes and edges are identified by the string representation of the networkx nodes.

    Parameters
    ----------
    graph : nx.Graph | nx.DiGraph
        The graph to draw
    directed : Optional[bool], optional
        Whether a directed Graphviz graph should be created. If omitted, this is inferred from the type of `graph`.
    color : str, optional
        Name of a node attribute that determines the node colors. Each distinct attribute value is assigned its own
        color from the viridis color map. If used, every node has to provide the attribute (a `KeyError` is raised
        otherwise). By default, nodes are not colored.

    Returns
    -------
    gv.Graph | gv.Digraph
        The Graphviz representation of the graph
    """
    if directed is None:
        directed = isinstance(graph, nx.DiGraph)
    gv_graph = gv.Digraph() if directed else gv.Graph()

    color_mapping: dict = {}
    if color:
        # dict.fromkeys() removes duplicates but keeps the first-seen order. Iterating over a set instead would make
        # the label-to-color assignment depend on the hash order, i.e. it could differ between interpreter runs.
        color_labels = list(dict.fromkeys(d[color] for _, d in graph.nodes.data()))
        viridis = mpl.cm.viridis
        normalized_colors = mpl.colors.Normalize(vmin=0, vmax=len(color_labels) - 1)
        for i, color_label in enumerate(color_labels):
            color_mapping[color_label] = mpl.colors.rgb2hex(
                viridis(normalized_colors(i))
            )

    for n, d in graph.nodes.data():
        atts = {"color": color_mapping[d[color]]} if color_mapping else {}
        gv_graph.node(str(n), label=gv.escape(str(n)), style="bold", **atts)
    for s, t in graph.edges:
        gv_graph.edge(str(s), str(t))

    return gv_graph
@@ -0,0 +1,538 @@
1
+ """Utilities to visualize different aspects of query optimization, namely join trees and join graphs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ from collections.abc import Callable, Sequence
7
+ from typing import Literal, Optional, overload
8
+
9
+ import graphviz as gv
10
+ import networkx as nx
11
+
12
+ from .. import util
13
+ from .._core import TableReference
14
+ from .._jointree import JoinTree, LogicalJoinTree
15
+ from .._qep import QueryPlan
16
+ from ..db._db import Database, DatabasePool
17
+ from ..optimizer._joingraph import JoinGraph
18
+ from ..qal import relalg, transform
19
+ from ..qal._qal import SqlQuery
20
+ from . import trees
21
+
22
+
23
def _join_tree_labels(node: JoinTree) -> tuple[str, dict]:
    """Provides the label text and the Graphviz node attributes for a join tree node.

    Joins are shown as a bold join symbol, all other nodes have to be scans and are shown with their base table in
    grey. For logical join trees, the cardinality is appended to the label.
    """
    if not node.is_join():
        assert node.is_scan()
        text, style = str(node.base_table), {"color": "grey"}
    else:
        text, style = "⋈", {"style": "bold"}

    if isinstance(node, LogicalJoinTree):
        text = f"{text}\n Card = {node.cardinality}"

    return text, style
36
+
37
+
38
def _join_tree_traversal(node: JoinTree) -> Sequence[JoinTree]:
    """Provides the child nodes of a join tree node (used as traversal callback for `trees.plot_tree`)."""
    child_nodes = node.children
    return child_nodes
40
+
41
+
42
def plot_join_tree(join_tree: JoinTree) -> gv.Graph:
    """Creates a Graphviz visualization of a join tree.

    A falsy join tree results in an empty graph.
    """
    if join_tree:
        return trees.plot_tree(join_tree, _join_tree_labels, _join_tree_traversal)
    return gv.Graph()
47
+
48
+
49
def _fallback_default_join_edge(
    graph: gv.Digraph, join_table: TableReference, partner_table: TableReference
) -> None:
    """Adds an edge between the two tables with ``dir="none"``.

    This is used for joins that cannot be rendered as a primary key/foreign key edge.
    """
    tail_name, head_name = str(join_table), str(partner_table)
    graph.edge(tail_name, head_name, dir="none")
53
+
54
+
55
def _render_pk_fk_join_edge(
    graph: gv.Digraph,
    query: SqlQuery,
    join_table: TableReference,
    partner_table: TableReference,
) -> None:
    """Adds the join edge between two tables, pointing from the foreign key table to the primary key table if possible.

    The schema of the current database from the `DatabasePool` is used to classify the join: if one join column is a
    primary key and the other column has a secondary index, the edge points from the table of the indexed column to the
    table of the primary key column. In all other cases (no join predicate, not exactly one pair of join columns, no
    such key/index combination) an edge with ``dir="none"`` is added instead.
    """
    schema = DatabasePool.get_instance().current_database().schema()

    predicate = query.predicates().joins_between(join_table, partner_table)
    if not predicate:
        return _fallback_default_join_edge(graph, join_table, partner_table)

    column_pairs = predicate.join_partners()
    if len(column_pairs) != 1:
        return _fallback_default_join_edge(graph, join_table, partner_table)

    first_col, second_col = next(iter(column_pairs))
    if schema.is_primary_key(first_col) and schema.has_secondary_index(second_col):
        fk_table, pk_table = second_col.table, first_col.table
    elif schema.is_primary_key(second_col) and schema.has_secondary_index(first_col):
        fk_table, pk_table = first_col.table, second_col.table
    else:
        _fallback_default_join_edge(graph, join_table, partner_table)
        return

    graph.edge(str(fk_table), str(pk_table))
81
+
82
+
83
def _plot_join_graph_from_query(
    query: SqlQuery,
    table_annotations: Optional[Callable[[TableReference], str]] = None,
    include_pk_fk_joins: bool = False,
) -> gv.Graph | gv.Digraph:
    """Creates the Graphviz join graph for the join predicates of a query.

    Parameters
    ----------
    query : SqlQuery
        The query whose join graph should be plotted
    table_annotations : Optional[Callable[[TableReference], str]], optional
        Function that generates additional text for each table. The text is placed below the table name.
    include_pk_fk_joins : bool, optional
        Whether primary key/foreign key joins should be highlighted. If enabled, a directed graph is created and the
        edges are added via `_render_pk_fk_join_edge`. Otherwise, a plain undirected graph is created.

    Returns
    -------
    gv.Graph | gv.Digraph
        The join graph. Queries without predicates result in an empty undirected graph.
    """
    predicates = query.predicates()
    if not predicates:
        return gv.Graph()
    join_graph: nx.Graph = predicates.join_graph()

    # Both branches must create an *instance*. Using the bare gv.Graph class would break the node()/edge() calls below.
    gv_graph = gv.Digraph() if include_pk_fk_joins else gv.Graph()

    for table in join_graph.nodes:
        node_label = str(table)
        if table_annotations is not None:
            node_label += "\n" + table_annotations(table)
        gv_graph.node(str(table), label=node_label)

    for start, target in join_graph.edges:
        if include_pk_fk_joins:
            _render_pk_fk_join_edge(gv_graph, query, start, target)
        else:
            gv_graph.edge(str(start), str(target))
    return gv_graph
104
+
105
+
106
def _plot_join_graph_directly(
    join_graph: JoinGraph,
    table_annotations: Optional[Callable[[TableReference], str]] = None,
) -> gv.Digraph:
    """Creates a directed Graphviz graph for a `JoinGraph` instance.

    Free tables of the join graph are drawn in black, all other tables in blue. The output of the optional
    `table_annotations` function is placed below the table name. Primary key/foreign key joins are drawn as arrows
    from the foreign key table to the primary key table, all other joins as edges with ``dir="none"``.
    """
    plot = gv.Digraph()

    for table in join_graph:
        color = "black" if join_graph.is_free_table(table) else "blue"
        label = str(table)
        if table_annotations is not None:
            label = label + "\n" + table_annotations(table)
        plot.node(str(table), label=label, color=color)

    for start, target in join_graph.all_joins():
        if join_graph.is_pk_fk_join(start, target):
            # start is the foreign key, target the primary key
            tail, head, edge_attrs = start, target, {}
        elif join_graph.is_pk_fk_join(target, start):
            # target is the foreign key, start the primary key
            tail, head, edge_attrs = target, start, {}
        else:
            tail, head, edge_attrs = start, target, {"dir": "none"}
        plot.edge(str(tail), str(head), **edge_attrs)

    return plot
130
+
131
+
132
def plot_join_graph(
    query_or_join_graph: SqlQuery | JoinGraph,
    table_annotations: Optional[Callable[[TableReference], str]] = None,
    *,
    include_pk_fk_joins: bool = False,
    out_path: str = "",
    out_format: str = "svg",
) -> gv.Graph | gv.Digraph:
    """Creates a Graphviz visualization of a join graph.

    The join graph can be either supplied directly (in which case it will be visualized as a directed graph), or implicitly
    through its SQL query. In this case, the join graph is inferred based on the join conditions. Such a graph can be further
    customized to also highlight primary-key/foreign-key relationships as a directed graph.

    The directed graph variants will point from the foreign key table to the primary key table.

    To customize the information shown on each table node, a custom `table_annotations` function can be provided. Several such
    functions for common annotations are already provided in this module. Annotation functions have a very simple signature:
    they take the table currently being rendered as input and return a string containing the metadata to be shown on the node.
    To add additional context to these methods, it is advisable to use `functools.partial` to bind additional parameters.

    Parameters
    ----------
    query_or_join_graph : SqlQuery | JoinGraph
        The query or join graph to visualize
    table_annotations : Optional[Callable[[TableReference], str]], optional
        Generates additional text to show on each table node. By default, only the table itself is shown.
    include_pk_fk_joins : bool, optional
        Whether primary key/foreign key joins should be highlighted. This setting is only used if a query is supplied.
    out_path : str, optional
        If given, the graph is additionally rendered to this path via the `render` method of the Graphviz graph.
    out_format : str, optional
        The format used when rendering to `out_path`. Defaults to *svg*.

    Returns
    -------
    gv.Graph | gv.Digraph
        The visualization of the join graph

    Raises
    ------
    TypeError
        If the input is neither an SQL query nor a join graph

    See Also
    --------
    estimated_cards
    annotate_filter_cards
    annotate_cards
    merged_annotation
    """
    if isinstance(query_or_join_graph, SqlQuery):
        graph = _plot_join_graph_from_query(
            query_or_join_graph, table_annotations, include_pk_fk_joins
        )
    elif isinstance(query_or_join_graph, JoinGraph):
        graph = _plot_join_graph_directly(query_or_join_graph, table_annotations)
    else:
        # The original message was missing the blank between "not" and the type
        raise TypeError(
            f"Argument must be either SqlQuery or JoinGraph, not {type(query_or_join_graph)}"
        )

    if out_path:
        graph.render(out_path, format=out_format, cleanup=True)
    return graph
175
+
176
+
177
def estimated_cards(
    table: TableReference, *, query: SqlQuery, database: Optional[Database] = None
) -> str:
    """Annotation function for join graphs that shows the estimated cardinality of a table.

    The fragment of the `query` that only references the `table` is extracted, transformed into a star query and then
    passed to the optimizer interface of the `database` to obtain a cardinality estimate.

    Parameters
    ----------
    table : TableReference
        The table to estimate
    query : SqlQuery
        The SQL query being optimized. The query fragment for the table is extracted from this query.
    database : Optional[Database], optional
        The database whose optimizer is used to estimate the cardinalities. If `None`, the current database from the
        `DatabasePool` is used.

    Returns
    -------
    str
        The annotation text

    See Also
    --------
    plot_join_graph
    """
    if database is None:
        database = DatabasePool.get_instance().current_database()

    table_fragment = transform.extract_query_fragment(query, [table])
    star_query = transform.as_star_query(table_fragment)
    estimate = database.optimizer().cardinality_estimate(star_query)
    return f"[{estimate} rows estimated]"
208
+
209
+
210
def annotate_filter_cards(
    table: TableReference, *, query: SqlQuery, database: Optional[Database] = None
) -> str:
    """Annotation function for join graphs that shows the true cardinality of a table *after* filters.

    The fragment of the `query` that only references the `table` is extracted and transformed into a *count(\\*)* query.
    This query is executed on the `database` (with caching enabled) and its result is used for the annotation.

    Parameters
    ----------
    table : TableReference
        The table to count
    query : SqlQuery
        The SQL query being optimized. The query fragment for the table is extracted from this query.
    database : Optional[Database], optional
        The database to calculate the cardinalities on. If `None`, the current database from the `DatabasePool` is used.

    Returns
    -------
    str
        The annotation text

    See Also
    --------
    plot_join_graph
    """
    if database is None:
        database = DatabasePool.get_instance().current_database()

    table_fragment = transform.extract_query_fragment(query, [table])
    counting_query = transform.as_count_star_query(table_fragment)
    row_count = database.execute_query(counting_query, cache_enabled=True)
    return f"[{row_count} rows]"
240
+
241
+
242
def annotate_cards(
    table: TableReference, *, query: SqlQuery, database: Optional[Database] = None
) -> str:
    """Annotation function for join graphs that shows the true cardinalities of a table before and after filters.

    Two values are reported: the total number of rows in the table according to the statistics interface of the
    `database` (using emulated statistics and caching) as well as the result of a *count(\\*)* query for the fragment of
    the `query` that only references the `table`.

    Parameters
    ----------
    table : TableReference
        The table to count
    query : SqlQuery
        The SQL query being optimized. The query fragment for the table is extracted from this query.
    database : Optional[Database], optional
        The database to calculate the cardinalities on. If `None`, the current database from the `DatabasePool` is used.

    Returns
    -------
    str
        The annotation text

    See Also
    --------
    plot_join_graph
    """
    if database is None:
        database = DatabasePool.get_instance().current_database()

    table_fragment = transform.extract_query_fragment(query, [table])
    counting_query = transform.as_count_star_query(table_fragment)
    filtered_rows = database.execute_query(counting_query, cache_enabled=True)

    stats = database.statistics()
    all_rows = stats.total_rows(table, emulated=True, cache_enabled=True)
    return f"|R| = {all_rows} |σ(R)| = {filtered_rows}"
277
+
278
+
279
@overload
def merged_annotation(
    *annotations: Callable[[TableReference], str],
) -> Callable[[TableReference], str]:
    """Combines multiple annotation functions for join graphs into a single one."""
    ...


@overload
def merged_annotation(
    *annotations: Callable[[QueryPlan], str],
) -> Callable[[QueryPlan], str]:
    """Combines multiple annotation functions for query plans into a single one."""
    ...


def merged_annotation(
    *annotations: Callable[[TableReference], str] | Callable[[QueryPlan], str],
) -> Callable[[TableReference], str] | Callable[[QueryPlan], str]:
    """Combines multiple annotation functions into a single one.

    The resulting function calls all annotation functions in the supplied order and places each output on its own line.
    """

    def _merger(node: TableReference | QueryPlan) -> str:
        lines = [annotate(node) for annotate in annotations]
        return "\n".join(lines)

    return _merger
302
+
303
+
304
def setup_annotations(
    *annotations: Literal["estimated-cards", "filter-cards", "true-cards"],
    query: SqlQuery,
    database: Optional[Database] = None,
) -> Callable[[TableReference], str]:
    """Creates an annotation function for join graphs based on the names of the desired annotations.

    Parameters
    ----------
    *annotations : Literal["estimated-cards", "filter-cards", "true-cards"]
        The annotations to include. *estimated-cards* maps to `estimated_cards`, *filter-cards* maps to
        `annotate_filter_cards` and *true-cards* maps to `annotate_cards`.
    query : SqlQuery
        The query that is bound to all annotation functions
    database : Optional[Database], optional
        The database that is bound to all annotation functions

    Returns
    -------
    Callable[[TableReference], str]
        The annotation function. Multiple annotations are combined via `merged_annotation`.

    Raises
    ------
    ValueError
        If an unknown annotation is requested or if no annotation is requested at all
    """
    known_annotations = {
        "estimated-cards": estimated_cards,
        "filter-cards": annotate_filter_cards,
        "true-cards": annotate_cards,
    }

    bound_annotators: list[Callable[[TableReference], str]] = []
    for name in annotations:
        if name not in known_annotations:
            raise ValueError(f"Unknown annotation: {name}")
        annotation_fn = known_annotations[name]
        bound_annotators.append(
            functools.partial(annotation_fn, query=query, database=database)
        )

    if not bound_annotators:
        raise ValueError("No annotator given")
    if len(bound_annotators) == 1:
        return bound_annotators[0]
    return merged_annotation(*bound_annotators)
326
+
327
+
328
def _query_plan_labels(
    node: QueryPlan,
    *,
    annotation_generator: Optional[Callable[[QueryPlan], str]],
    subplan_target: str = "",
) -> tuple[str, dict]:
    """Provides the label text and the Graphviz node attributes for a query plan node.

    Joins are labelled with their node type in bold, scans additionally show their base table in grey and all other
    operators are drawn dashed and grey. For nodes with a subplan, the label of the subplan's root node is embedded in
    a dashed node. If an `annotation_generator` is given and produces a non-empty text, it is appended on a new line.

    NOTE(review): the ``<<SubPlan>>`` header uses the `subplan_target` of the *current* call (which is empty unless
    this call was made for the root of another subplan), not the target name of the node's own subplan. Likewise, the
    annotation is generated for the subplan root as well as for the node itself. Confirm that this is intended.
    """
    subplan = node.subplan
    if subplan:
        inner_label, params = _query_plan_labels(
            subplan.root,
            annotation_generator=annotation_generator,
            subplan_target=subplan.target_name,
        )
        params["style"] = "dashed"
        label = f"<<SubPlan>> {subplan_target}\n{inner_label}"
        label = f"{label}\nSubplan: {subplan.target_name}"
    elif node.is_join():
        label = node.node_type
        params = {"style": "bold"}
    elif node.is_scan():
        label = f"<<{node.node_type}>>\n{node.base_table}"
        params = {"color": "grey"}
    else:
        label = node.node_type
        params = {"style": "dashed", "color": "grey"}

    annotation = annotation_generator(node) if annotation_generator else ""
    if annotation:
        label = f"{label}\n{annotation}"
    return label, params
355
+
356
+
357
def _query_plan_traversal(
    node: QueryPlan, *, skip_intermediates: bool = False
) -> list[QueryPlan]:
    """Provides the nodes that should be drawn below a query plan node.

    These are the children of the node, followed by the root of its subplan (if there is one). If `skip_intermediates`
    is enabled, each auxiliary node among them is replaced by its own (recursively determined) successors.
    """
    successors = list(node.children)
    if node.subplan:
        successors.append(node.subplan.root)

    if not skip_intermediates:
        return successors

    expanded = []
    for child in successors:
        if child.is_auxiliary():
            expanded.append(_query_plan_traversal(child, skip_intermediates=True))
        else:
            expanded.append([child])
    return util.flatten(expanded)
373
+
374
+
375
def annotate_estimates(node: QueryPlan) -> str:
    """Annotation function for query plans that shows the estimated cost and cardinality of a node.

    See Also
    --------
    plot_query_plan
    """
    cost, cardinality = node.estimated_cost, node.estimated_cardinality
    return f"cost={cost} cardinality={cardinality}"
383
+
384
+
385
def plot_query_plan(
    plan: QueryPlan,
    annotation_generator: Optional[Callable[[QueryPlan], str]] = None,
    *,
    skip_intermediates: bool = False,
    **kwargs,
) -> gv.Graph:
    """Creates a Graphviz visualization of a query plan.

    By default, each node is just annotated with its operator type and base table (for scans). To add additional information
    (e.g., estimated or true cardinality), a custom `annotation_generator` function can be provided. Such a function takes the
    current `QueryPlan` node as input and returns a string containing the metadata to be shown on the node. Since this
    signature is quite simple, it is advisable to use `functools.partial` to bind additional parameters to the function.
    This module already provides the most common annotation function: `annotate_estimates`, which adds estimated cost and
    cardinality.

    For EXPLAIN ANALYZE plans (i.e. plans containing runtime information), the `plot_analyze_plan` function has a meaningful
    default annotation generator.

    Parameters
    ----------
    plan : QueryPlan
        The plan to visualize. A falsy plan results in an empty graph.
    annotation_generator : Optional[Callable[[QueryPlan], str]], optional
        Generates additional text for each node
    skip_intermediates : bool, optional
        Whether auxiliary nodes should be left out of the visualization
    **kwargs
        Additional arguments that are passed to `trees.plot_tree`

    See Also
    --------
    annotate_estimates
    plot_analyze_plan
    """
    if not plan:
        return gv.Graph()

    label_fn = functools.partial(
        _query_plan_labels, annotation_generator=annotation_generator
    )
    child_fn = functools.partial(
        _query_plan_traversal, skip_intermediates=skip_intermediates
    )
    return trees.plot_tree(plan, label_fn, child_fn, **kwargs)
419
+
420
+
421
def _explain_analyze_annotations(node: QueryPlan) -> str:
    """Annotation function for query plans that shows estimated and actual cardinality as well as the execution time.

    The execution time is rounded to four decimal places.
    """
    cardinalities = (
        f"[Rows expected={node.estimated_cardinality} actual={node.actual_cardinality}]"
    )
    rounded_time = round(node.execution_time, 4)
    return "\n".join([cardinalities, f"[Exec time={rounded_time}s]"])
428
+
429
+
430
def plot_analyze_plan(
    plan: QueryPlan, *, skip_intermediates: bool = False, **kwargs
) -> gv.Graph:
    """Creates a Graphviz visualization of an EXPLAIN ANALYZE query plan.

    The plan is drawn in the same way as by `plot_query_plan`, but with a fixed annotation generator that shows the
    runtime information contained in EXPLAIN ANALYZE plans (estimated and actual cardinality, execution time).

    Parameters
    ----------
    plan : QueryPlan
        The plan to visualize. A falsy plan results in an empty graph.
    skip_intermediates : bool, optional
        Whether auxiliary nodes should be left out of the visualization
    **kwargs
        Additional arguments that are passed to `trees.plot_tree`
    """
    if not plan:
        return gv.Graph()

    label_fn = functools.partial(
        _query_plan_labels, annotation_generator=_explain_analyze_annotations
    )
    child_fn = functools.partial(
        _query_plan_traversal, skip_intermediates=skip_intermediates
    )
    return trees.plot_tree(plan, label_fn, child_fn, **kwargs)
448
+
449
+
450
def _escape_label(text: str) -> str:
    """Escapes literal text for use inside a Graphviz HTML-like label.

    The characters ``&``, ``<`` and ``>`` are replaced by their corresponding entities.
    """
    # The ampersand has to be handled first. Otherwise, the ampersands of the entities for "<" and ">" would be escaped
    # a second time.
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
452
+
453
+
454
def _make_sub(text: str) -> str:
    """Wraps the escaped text in ``<sub>`` and ``<font>`` tags with a point size of 10."""
    escaped = _escape_label(text)
    return "<sub><font point-size='10.0'>{}</font></sub>".format(escaped)
456
+
457
+
458
def _make_label(text: str) -> str:
    """Wraps the text in ``<b>`` tags. The text is inserted as-is, i.e. without escaping."""
    return "<b>{}</b>".format(text)
460
+
461
+
462
+ def _relalg_node_labels(node: relalg.RelNode) -> tuple[str, dict]:
463
+ node_params = {}
464
+ match node:
465
+ case relalg.Projection():
466
+ projection_targets = ", ".join(str(t) for t in node.columns)
467
+ node_str = f"{_make_label('π')} {_make_sub(projection_targets)}"
468
+ case relalg.Selection():
469
+ predicate = str(node.predicate)
470
+ node_str = f"{_make_label('σ')} {_make_sub(predicate)}"
471
+ case relalg.ThetaJoin():
472
+ predicate = str(node.predicate)
473
+ node_str = f"{_make_label('⋈')} {_make_sub(predicate)}"
474
+ case relalg.SemiJoin():
475
+ predicate = str(node.predicate)
476
+ node_str = f"{_make_label('⋉')} {_make_sub(predicate)}"
477
+ case relalg.AntiJoin():
478
+ predicate = str(node.predicate)
479
+ node_str = f"{_make_label('▷')} {_make_sub(predicate)}"
480
+ case relalg.Grouping():
481
+ columns_str = ", ".join(str(c) for c in node.group_columns)
482
+ aggregates: list[str] = []
483
+ for group_columns, agg_func in node.aggregates.items():
484
+ if len(group_columns) == 1:
485
+ group_str = str(util.simplify(group_columns))
486
+ else:
487
+ group_str = "(" + ", ".join(str(c) for c in group_columns) + ")"
488
+
489
+ if len(agg_func) == 1:
490
+ func_str = str(util.simplify(agg_func))
491
+ else:
492
+ func_str = "(" + ", ".join(str(agg) for agg in agg_func) + ")"
493
+ aggregates.append(f"{group_str}: {func_str}")
494
+ agg_str = ", ".join(agg for agg in aggregates)
495
+ prefix = f"{_make_sub(columns_str)} " if columns_str else ""
496
+ suffix = f" {_make_sub(agg_str)}" if agg_str else ""
497
+ node_str = "".join([prefix, _make_label("γ"), suffix])
498
+ case relalg.Map():
499
+ pretty_mapping: dict[str, str] = {}
500
+ for target_col, expression in node.mapping.items():
501
+ if len(target_col) == 1:
502
+ target_col = util.simplify(target_col)
503
+ target_str = str(target_col)
504
+ else:
505
+ target_str = "(" + ", ".join(str(t) for t in target_col) + ")"
506
+ if len(expression) == 1:
507
+ expression = util.simplify(expression)
508
+ expr_str = str(expression)
509
+ else:
510
+ expr_str = "(" + ", ".join(str(e) for e in expression) + ")"
511
+ pretty_mapping[target_str] = expr_str
512
+ mapping_str = ", ".join(
513
+ f"{target_col}: {expr}" for target_col, expr in pretty_mapping.items()
514
+ )
515
+ node_str = f"{_make_label('χ')} {_make_sub(mapping_str)}"
516
+ case _:
517
+ node_str = _escape_label(str(node))
518
+ return f"<{node_str}>", node_params
519
+
520
+
521
def _relalg_child_traversal(node: relalg.RelNode) -> Sequence[relalg.RelNode]:
    """Provides the children of a relational algebra operator (used as traversal callback for `trees.plot_tree`)."""
    operands = node.children()
    return operands
523
+
524
+
525
def plot_relalg(relnode: relalg.RelNode, **kwargs) -> gv.Graph:
    """Creates a Graphviz visualization of a relational algebra expression tree.

    The operator labels are generated as markup by `_relalg_node_labels` and label escaping by `plot_tree` is switched
    off. Nodes are identified by the `id` of the operator objects.

    Additional keyword arguments are passed to `plot_tree`.
    """
    label_fn, child_fn = _relalg_node_labels, _relalg_child_traversal
    return trees.plot_tree(
        relnode,
        label_fn,
        child_fn,
        escape_labels=False,
        node_id_generator=id,
        strict=True,  # gv.Graph arguments
        **kwargs,
    )