PostBOUND 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- postbound/__init__.py +211 -0
- postbound/_base.py +6 -0
- postbound/_bench.py +1012 -0
- postbound/_core.py +1153 -0
- postbound/_hints.py +1373 -0
- postbound/_jointree.py +1079 -0
- postbound/_pipelines.py +1121 -0
- postbound/_qep.py +1986 -0
- postbound/_stages.py +876 -0
- postbound/_validation.py +734 -0
- postbound/db/__init__.py +72 -0
- postbound/db/_db.py +2348 -0
- postbound/db/_duckdb.py +785 -0
- postbound/db/mysql.py +1195 -0
- postbound/db/postgres.py +4216 -0
- postbound/experiments/__init__.py +12 -0
- postbound/experiments/analysis.py +674 -0
- postbound/experiments/benchmarking.py +54 -0
- postbound/experiments/ceb.py +877 -0
- postbound/experiments/interactive.py +105 -0
- postbound/experiments/querygen.py +334 -0
- postbound/experiments/workloads.py +980 -0
- postbound/optimizer/__init__.py +92 -0
- postbound/optimizer/__init__.pyi +73 -0
- postbound/optimizer/_cardinalities.py +369 -0
- postbound/optimizer/_joingraph.py +1150 -0
- postbound/optimizer/dynprog.py +1825 -0
- postbound/optimizer/enumeration.py +432 -0
- postbound/optimizer/native.py +539 -0
- postbound/optimizer/noopt.py +54 -0
- postbound/optimizer/presets.py +147 -0
- postbound/optimizer/randomized.py +650 -0
- postbound/optimizer/tonic.py +1479 -0
- postbound/optimizer/ues.py +1607 -0
- postbound/qal/__init__.py +343 -0
- postbound/qal/_qal.py +9678 -0
- postbound/qal/formatter.py +1089 -0
- postbound/qal/parser.py +2344 -0
- postbound/qal/relalg.py +4257 -0
- postbound/qal/transform.py +2184 -0
- postbound/shortcuts.py +70 -0
- postbound/util/__init__.py +46 -0
- postbound/util/_errors.py +33 -0
- postbound/util/collections.py +490 -0
- postbound/util/dataframe.py +71 -0
- postbound/util/dicts.py +330 -0
- postbound/util/jsonize.py +68 -0
- postbound/util/logging.py +106 -0
- postbound/util/misc.py +168 -0
- postbound/util/networkx.py +401 -0
- postbound/util/numbers.py +438 -0
- postbound/util/proc.py +107 -0
- postbound/util/stats.py +37 -0
- postbound/util/system.py +48 -0
- postbound/util/typing.py +35 -0
- postbound/vis/__init__.py +5 -0
- postbound/vis/fdl.py +69 -0
- postbound/vis/graphs.py +48 -0
- postbound/vis/optimizer.py +538 -0
- postbound/vis/plots.py +84 -0
- postbound/vis/tonic.py +70 -0
- postbound/vis/trees.py +105 -0
- postbound-0.19.0.dist-info/METADATA +355 -0
- postbound-0.19.0.dist-info/RECORD +67 -0
- postbound-0.19.0.dist-info/WHEEL +5 -0
- postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
- postbound-0.19.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from .._core import TableReference
|
|
7
|
+
from .._jointree import JoinTree, LogicalJoinTree
|
|
8
|
+
from ..db._db import Database, DatabasePool
|
|
9
|
+
from ..optimizer._joingraph import JoinGraph
|
|
10
|
+
from ..qal import transform
|
|
11
|
+
from ..qal._qal import ImplicitSqlQuery, SqlQuery
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclasses.dataclass(frozen=True)
class ManualJoinOrderSelection:
    """A hand-picked join order for a query, bound to the database that should execute it.

    Instances are immutable. They are produced by `InteractiveJoinOrderOptimizer.start`, but can also be created directly.
    """

    # The query that the join order applies to.
    query: SqlQuery
    # The tables in the order in which they are joined. The first entry is the table that is scanned first.
    join_order: tuple[TableReference, ...]
    # The database whose hinting service is used by `final_query`.
    database: Database

    def join_tree(self) -> LogicalJoinTree:
        """Converts the join order into a join tree.

        The tree starts with a scan of the first table. Every following table is added via `join_with`, in the order
        given by `join_order`.

        Returns
        -------
        LogicalJoinTree
            The join tree for the selected join order.

        Raises
        ------
        ValueError
            If the join order does not contain any tables.
        """
        if not self.join_order:
            # Unpacking an empty tuple would raise a ValueError as well, but with a message that does not tell the
            # caller what is actually wrong.
            raise ValueError("Cannot build a join tree for an empty join order")

        base_table, *joined_tables = self.join_order
        join_tree: JoinTree[None] = JoinTree.scan(base_table)
        for joined_table in joined_tables:
            join_tree = join_tree.join_with(joined_table)
        return join_tree

    def final_query(self) -> SqlQuery:
        """Asks the hinting service of the database to generate hints that apply the join order to the query.

        Returns
        -------
        SqlQuery
            The result of `generate_hints` for the query and the join tree of this selection.
        """
        return self.database.hinting().generate_hints(self.query, self.join_tree())
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InteractiveJoinOrderOptimizer:
    """Console-driven selection of a join order for a single query.

    The optimizer repeatedly prints the joins that are currently available in the join graph and reads the next action
    from standard input until all tables of the query are part of the join order.

    Parameters
    ----------
    query : ImplicitSqlQuery
        The query to select a join order for.
    database : Optional[Database], optional
        The database that provides the schema and executes the cardinality queries. If omitted, the current database
        of the `DatabasePool` is used.
    """

    def __init__(
        self, query: ImplicitSqlQuery, *, database: Optional[Database] = None
    ) -> None:
        self._query = query
        self._db = (
            database
            if database is not None
            else DatabasePool.get_instance().current_database()
        )

    def start(
        self, *, use_predicate_equivalence_classes: bool = False
    ) -> ManualJoinOrderSelection:
        """Runs the interactive selection loop.

        In each iteration, the current intermediate result is printed (including its cardinality, which is obtained by
        executing a *COUNT(\\*)* query on the database), followed by the available joins. The user either enters the
        index of the next join or ``b`` to undo the last join.

        Parameters
        ----------
        use_predicate_equivalence_classes : bool, optional
            Passed to the join graph as `include_predicate_equivalence_classes`. Defaults to *False*.

        Returns
        -------
        ManualJoinOrderSelection
            The selected join order along with the query and the database.
        """
        join_graph = JoinGraph(
            self._query,
            self._db.schema(),
            include_predicate_equivalence_classes=use_predicate_equivalence_classes,
        )
        # Each entry stores the join graph as it was before a join was applied, together with the length of the join
        # order at that point. The very first join adds two tables at once, so we cannot simply pop a single table
        # when backtracking.
        history: list[tuple[JoinGraph, int]] = []
        join_order: list[TableReference] = []
        n_tables = len(self._query.tables())

        while n_tables > len(join_order):
            intermediate_str = (
                " ⋈ ".join(str(tab.identifier()) for tab in join_graph.joined_tables())
                if join_order
                else "∅"
            )
            if join_order:
                query_fragment = transform.extract_query_fragment(
                    self._query, join_order
                )
                query_fragment = transform.as_count_star_query(query_fragment)
                current_card = self._db.execute_query(query_fragment)
                intermediate_str += f" (card = {current_card})"
            print("> Current intermediate:", intermediate_str)

            print("> Available actions:")
            available_joins = dict(enumerate(join_graph.available_join_paths()))
            for join_idx, join in available_joins.items():
                print(
                    f"[{join_idx}]\tJoin {join.start_table.identifier()} ⋈ {join.target_table.identifier()}"
                )
            print(" b\tbacktrack to last graph")
            action = input("> Select next join:").strip()

            if action == "b":
                if not history:
                    # Nothing to undo. We must not fall through to the pop() below.
                    print("Already at initial join graph")
                    print()
                    continue
                join_graph, previous_length = history.pop()
                del join_order[previous_length:]
                print()
                continue
            elif not action.isdecimal():
                # isdecimal() only accepts characters that int() can parse, in contrast to isdigit().
                print(f"Unknown action: '{action}'\n")
                continue

            next_join_idx = int(action)
            if next_join_idx not in available_joins:
                print(f"Unknown action: '{action}'\n")
                continue
            next_join = available_joins[next_join_idx]

            history.append((join_graph.clone(), len(join_order)))
            if join_graph.initial():
                # The first join has to introduce both of its tables.
                join_graph.mark_joined(next_join.start_table)
                join_order.append(next_join.start_table)
            join_graph.mark_joined(next_join.target_table)
            join_order.append(next_join.target_table)
            print()

        final_card = self._db.execute_query(transform.as_count_star_query(self._query))
        print(f"> Done. (final card = {final_card})")
        print("> Final join order: ", [tab.identifier() for tab in join_order])
        return ManualJoinOrderSelection(
            query=self._query, join_order=tuple(join_order), database=self._db
        )
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Simple randomized query generator."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import random
|
|
6
|
+
from collections.abc import Generator
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import networkx as nx
|
|
10
|
+
|
|
11
|
+
from .. import util
|
|
12
|
+
from .._core import ColumnReference, TableReference
|
|
13
|
+
from ..db import postgres
|
|
14
|
+
from ..db._db import Database, DatabasePool, ForeignKeyRef
|
|
15
|
+
from ..qal._qal import (
|
|
16
|
+
AbstractPredicate,
|
|
17
|
+
CompoundOperator,
|
|
18
|
+
CompoundPredicate,
|
|
19
|
+
ImplicitFromClause,
|
|
20
|
+
LogicalOperator,
|
|
21
|
+
Select,
|
|
22
|
+
SqlQuery,
|
|
23
|
+
Where,
|
|
24
|
+
as_predicate,
|
|
25
|
+
build_query,
|
|
26
|
+
)
|
|
27
|
+
from ..util import networkx as nx_utils
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _generate_join_predicates(
|
|
31
|
+
tables: list[TableReference], *, schema: nx.DiGraph
|
|
32
|
+
) -> AbstractPredicate:
|
|
33
|
+
"""Generates equi-join predicates for specific tables.
|
|
34
|
+
|
|
35
|
+
Between each pair of tables, a join predicate is generated if there is a foreign key relationship between the two tables.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
tables : list[TableReference]
|
|
40
|
+
The tables to join.
|
|
41
|
+
schema : nx.DiGraph
|
|
42
|
+
Graph for the entire schema. It must contain node for all `tables`, but can contain more nodes.
|
|
43
|
+
|
|
44
|
+
Returns
|
|
45
|
+
-------
|
|
46
|
+
AbstractPredicate
|
|
47
|
+
A compound predicate that represents all foreign key joins between the tables. Notice that if there are partitions in
|
|
48
|
+
the schema, the predicate will not span all partitions, leading to cross products between some tables.
|
|
49
|
+
|
|
50
|
+
Raises
|
|
51
|
+
------
|
|
52
|
+
ValueError
|
|
53
|
+
If no foreign key edges are found between the tables.
|
|
54
|
+
"""
|
|
55
|
+
predicates: list[AbstractPredicate] = []
|
|
56
|
+
|
|
57
|
+
for i, outer_tab in enumerate(tables):
|
|
58
|
+
# Make sure to check each pair of tables only once
|
|
59
|
+
for inner_tab in tables[i + 1 :]:
|
|
60
|
+
# Since we are working with a directed graph, we need to check for both FK reference "directions"
|
|
61
|
+
fk_edge = schema.get_edge_data(outer_tab, inner_tab)
|
|
62
|
+
if not fk_edge:
|
|
63
|
+
fk_edge = schema.get_edge_data(inner_tab, outer_tab)
|
|
64
|
+
if not fk_edge:
|
|
65
|
+
continue
|
|
66
|
+
|
|
67
|
+
candidate_keys: list[ForeignKeyRef] = fk_edge["foreign_keys"]
|
|
68
|
+
selected_key = random.choice(candidate_keys)
|
|
69
|
+
source_col, target_col = selected_key.referenced_col, selected_key.fk_col
|
|
70
|
+
join_predicate = as_predicate(source_col, LogicalOperator.Equal, target_col)
|
|
71
|
+
predicates.append(join_predicate)
|
|
72
|
+
|
|
73
|
+
if not predicates:
|
|
74
|
+
raise ValueError(
|
|
75
|
+
f"Found no suitable edges in the schema graph for tables {tables}."
|
|
76
|
+
)
|
|
77
|
+
return CompoundPredicate.create_and(predicates)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _generate_filter(
    column: ColumnReference, *, target_db: Database
) -> Optional[AbstractPredicate]:
    """Creates a random binary filter predicate on a column.

    The operator is one of equality, inequality, greater than, or less than. The comparison value is drawn at random
    from the distinct values of the column that are retrieved from the database.

    Parameters
    ----------
    column : ColumnReference
        The column to filter.
    target_db : Database
        The database that is queried for the candidate values of the column.

    Returns
    -------
    Optional[AbstractPredicate]
        The random filter predicate, or *None* if the database did not provide any candidate values.
    """

    # TODO: for text columns we should also generate LIKE predicates
    operators = (
        LogicalOperator.Equal,
        LogicalOperator.NotEqual,
        LogicalOperator.Greater,
        LogicalOperator.Less,
    )

    # The comparison value is picked from the distinct values of the column. On Postgres, tables with more than 1000
    # (estimated) rows are only sampled via TABLESAMPLE to cut down the number of rows that have to be materialized.
    # All other systems have to compute the entire result set.
    # The query is never cached, even though caching would speed up repeated sampling. The number of distinct values
    # can be huge and would overload the cache.
    row_estimate = target_db.statistics().total_rows(column.table, emulated=False)
    use_sampling = (
        isinstance(target_db, postgres.PostgresInterface) and row_estimate > 1000
    )
    if use_sampling:
        template = """SELECT DISTINCT {col} FROM {tab} TABLESAMPLE BERNOULLI(1)"""
    else:
        template = """SELECT DISTINCT {col} FROM {tab}"""

    values_query = template.format(col=column.name, tab=column.table.full_name)
    distinct_values = target_db.execute_query(values_query, cache_enabled=False)
    if not distinct_values:
        return None
    distinct_values = util.enlist(distinct_values)

    # Draw the value first and the operator second to keep the sequence of random decisions stable
    comparison_value = random.choice(distinct_values)
    comparison_operator = random.choice(operators)
    return as_predicate(column, comparison_operator, comparison_value)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _is_numeric(data_type: str) -> bool:
|
|
138
|
+
"""Checks, whether a data type is numeric."""
|
|
139
|
+
return data_type in {
|
|
140
|
+
"integer",
|
|
141
|
+
"smallint",
|
|
142
|
+
"bigint",
|
|
143
|
+
"date",
|
|
144
|
+
"double precision",
|
|
145
|
+
"real",
|
|
146
|
+
"numeric",
|
|
147
|
+
"decimal",
|
|
148
|
+
} or data_type.startswith("time")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def generate_query(
|
|
152
|
+
target_db: Optional[Database],
|
|
153
|
+
*,
|
|
154
|
+
count_star: bool = False,
|
|
155
|
+
ignore_tables: Optional[set[TableReference]] = None,
|
|
156
|
+
min_tables: Optional[int] = None,
|
|
157
|
+
max_tables: Optional[int] = None,
|
|
158
|
+
min_filters: Optional[int] = None,
|
|
159
|
+
max_filters: Optional[int] = None,
|
|
160
|
+
filter_key_columns: bool = True,
|
|
161
|
+
numeric_filters: bool = False,
|
|
162
|
+
) -> Generator[SqlQuery, None, None]:
|
|
163
|
+
"""A simple randomized query generator.
|
|
164
|
+
|
|
165
|
+
The generator selects a random subset of (connected) tables from the schema graph of the target database and builds a
|
|
166
|
+
random number of random filter predicates on the columns of the selected tables.
|
|
167
|
+
|
|
168
|
+
The generator will yield new queries until the user stops requesting them, there is no termination condition.
|
|
169
|
+
|
|
170
|
+
Parameters
|
|
171
|
+
----------
|
|
172
|
+
target_db : Optional[Database]
|
|
173
|
+
The database from which queries should be generated. The database is important for two main use-cases:
|
|
174
|
+
|
|
175
|
+
1. The schema graph is used to select a (connected) subset of tables
|
|
176
|
+
2. The column values are used to generate filter predicates
|
|
177
|
+
|
|
178
|
+
If no database is provided, the current database from the database pool is used.
|
|
179
|
+
|
|
180
|
+
count_star : bool, optional
|
|
181
|
+
Whether the resulting queries should contain a *COUNT(\\*)* instead of a plain * *SELECT* clause
|
|
182
|
+
ignore_tables : Optional[set[TableReference]], optional
|
|
183
|
+
An optional set of tables that should never be contained in the generated queries. For Postgres databases, internal
|
|
184
|
+
*pg_XXX* tables are ignored automatically.
|
|
185
|
+
min_tables : Optional[int], optional
|
|
186
|
+
The minimum number of tables that should be contained in each query. Default is 1.
|
|
187
|
+
max_tables : Optional[int], optional
|
|
188
|
+
The maximum number of tables that should be contained in each query. Default is the number of tables in the schema
|
|
189
|
+
graph (minus the ignored tables).
|
|
190
|
+
min_filters : Optional[int], optional
|
|
191
|
+
The minimum number of filter predicates that should be contained in each query. Default is 0.
|
|
192
|
+
max_filters : Optional[int], optional
|
|
193
|
+
The maximum number of filter predicates that should be contained in each query. By default, each column from the
|
|
194
|
+
selected tables can be filtered.
|
|
195
|
+
filter_key_columns : bool, optional
|
|
196
|
+
Whether primary key/foreign key columns should be considered for filtering. This is enabled by default.
|
|
197
|
+
numeric_filters : bool, optional
|
|
198
|
+
Whether only numeric columns should be considered for filtering (i.e. integer, float or time columns). This is disabled
|
|
199
|
+
by default.
|
|
200
|
+
|
|
201
|
+
Yields
|
|
202
|
+
------
|
|
203
|
+
Generator[SqlQuery, None, None]
|
|
204
|
+
A random SQL query
|
|
205
|
+
|
|
206
|
+
Examples
|
|
207
|
+
--------
|
|
208
|
+
>>> qgen = generate_query(some_database)
|
|
209
|
+
>>> next(qgen)
|
|
210
|
+
"""
|
|
211
|
+
|
|
212
|
+
target_db = target_db or DatabasePool.get_instance().current_database()
|
|
213
|
+
db_schema = target_db.schema()
|
|
214
|
+
|
|
215
|
+
#
|
|
216
|
+
# Our sampling algorithm is acutally pretty straightforward:
|
|
217
|
+
#
|
|
218
|
+
# 1. We select a random number of tables
|
|
219
|
+
# 2. We select a random number of columns to filter from the selected tables
|
|
220
|
+
# 3. We generate the join predicates between the tables
|
|
221
|
+
# 4. We generate random filter predicates for the selected columns
|
|
222
|
+
# 5. We build the query
|
|
223
|
+
#
|
|
224
|
+
# The hardest part (and the part that takes up the most LOCs), is making sure that we always select from the correct
|
|
225
|
+
# subset.
|
|
226
|
+
#
|
|
227
|
+
|
|
228
|
+
schema_graph = db_schema.as_graph()
|
|
229
|
+
if ignore_tables:
|
|
230
|
+
nodes_to_remove = [node for node in schema_graph.nodes if node in ignore_tables]
|
|
231
|
+
schema_graph.remove_nodes_from(nodes_to_remove)
|
|
232
|
+
if isinstance(target_db, postgres.PostgresInterface):
|
|
233
|
+
nodes_to_remove = [
|
|
234
|
+
node for node in schema_graph.nodes if node.full_name.startswith("pg_")
|
|
235
|
+
]
|
|
236
|
+
schema_graph.remove_nodes_from(nodes_to_remove)
|
|
237
|
+
|
|
238
|
+
min_tables = min_tables or 1
|
|
239
|
+
if not min_tables:
|
|
240
|
+
raise ValueError("min_tables must be at least 1")
|
|
241
|
+
max_tables = max_tables or len(schema_graph.nodes)
|
|
242
|
+
max_tables = min(max_tables, len(schema_graph.nodes))
|
|
243
|
+
if max_tables < min_tables:
|
|
244
|
+
raise ValueError(
|
|
245
|
+
f"max_tables must be at least as large as min_tables. Got {max_tables} (max) and {min_tables} (min)."
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
filter_columns = util.flatten(
|
|
249
|
+
cols for __, cols in schema_graph.nodes(data="columns")
|
|
250
|
+
)
|
|
251
|
+
min_filters = min_filters or 0
|
|
252
|
+
max_filters = max_filters or len(filter_columns)
|
|
253
|
+
|
|
254
|
+
select_clause = Select.count_star() if count_star else Select.star()
|
|
255
|
+
|
|
256
|
+
# We generate new queries until the user asks us to stop.
|
|
257
|
+
while True:
|
|
258
|
+
n_tables = random.randint(min_tables, max_tables)
|
|
259
|
+
|
|
260
|
+
# We ensure that we always generate a connected join graph by performing a random walk through the schema graph.
|
|
261
|
+
# This way, we can terminate the walk at any point if we have visited enough tables.
|
|
262
|
+
table_walk = nx_utils.nx_random_walk(schema_graph.to_undirected())
|
|
263
|
+
joined_tables: list[TableReference] = [
|
|
264
|
+
next(table_walk) for _ in range(n_tables)
|
|
265
|
+
]
|
|
266
|
+
|
|
267
|
+
available_columns: list[ColumnReference] = util.flatten(
|
|
268
|
+
[
|
|
269
|
+
cols
|
|
270
|
+
for tab, cols in schema_graph.nodes(data="columns")
|
|
271
|
+
if tab in set(joined_tables)
|
|
272
|
+
]
|
|
273
|
+
)
|
|
274
|
+
if not filter_key_columns:
|
|
275
|
+
available_columns = [
|
|
276
|
+
col
|
|
277
|
+
for col in available_columns
|
|
278
|
+
if not db_schema.is_primary_key(col)
|
|
279
|
+
and not db_schema.foreign_keys_on(col)
|
|
280
|
+
]
|
|
281
|
+
if numeric_filters:
|
|
282
|
+
available_columns = [
|
|
283
|
+
col
|
|
284
|
+
for col in available_columns
|
|
285
|
+
if _is_numeric(schema_graph.nodes[col.table]["data_type"][col])
|
|
286
|
+
]
|
|
287
|
+
|
|
288
|
+
from_clause = ImplicitFromClause.create_for(joined_tables)
|
|
289
|
+
join_predicates = (
|
|
290
|
+
_generate_join_predicates(joined_tables, schema=schema_graph)
|
|
291
|
+
if n_tables > 1
|
|
292
|
+
else None
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
current_max_filters = min(max_filters, len(available_columns))
|
|
296
|
+
if current_max_filters <= min_filters:
|
|
297
|
+
# too few columns available, let's just try again
|
|
298
|
+
continue
|
|
299
|
+
else:
|
|
300
|
+
n_filters = random.randint(min_filters, current_max_filters)
|
|
301
|
+
|
|
302
|
+
if n_filters > 0 and available_columns:
|
|
303
|
+
cols_to_filter = random.sample(available_columns, n_filters)
|
|
304
|
+
individual_filters = [
|
|
305
|
+
_generate_filter(col, target_db=target_db) for col in cols_to_filter
|
|
306
|
+
]
|
|
307
|
+
individual_filters = [
|
|
308
|
+
pred for pred in individual_filters if pred is not None
|
|
309
|
+
]
|
|
310
|
+
filter_predicates = (
|
|
311
|
+
CompoundPredicate.create_and(individual_filters)
|
|
312
|
+
if individual_filters
|
|
313
|
+
else None
|
|
314
|
+
)
|
|
315
|
+
else:
|
|
316
|
+
filter_predicates = None
|
|
317
|
+
|
|
318
|
+
# This is just a bit of optimization to avoid useless nesting inside the WHERE clause
|
|
319
|
+
where_parts: list[AbstractPredicate] = []
|
|
320
|
+
for predicates in (join_predicates, filter_predicates):
|
|
321
|
+
match predicates:
|
|
322
|
+
case None:
|
|
323
|
+
pass
|
|
324
|
+
case CompoundPredicate(op, children) if op == CompoundOperator.And:
|
|
325
|
+
where_parts.extend(children)
|
|
326
|
+
case _:
|
|
327
|
+
where_parts.append(join_predicates)
|
|
328
|
+
|
|
329
|
+
where_clause = (
|
|
330
|
+
Where(CompoundPredicate.create_and(where_parts)) if where_parts else None
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
query = build_query([select_clause, from_clause, where_clause])
|
|
334
|
+
yield query
|