PostBOUND 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. postbound/__init__.py +211 -0
  2. postbound/_base.py +6 -0
  3. postbound/_bench.py +1012 -0
  4. postbound/_core.py +1153 -0
  5. postbound/_hints.py +1373 -0
  6. postbound/_jointree.py +1079 -0
  7. postbound/_pipelines.py +1121 -0
  8. postbound/_qep.py +1986 -0
  9. postbound/_stages.py +876 -0
  10. postbound/_validation.py +734 -0
  11. postbound/db/__init__.py +72 -0
  12. postbound/db/_db.py +2348 -0
  13. postbound/db/_duckdb.py +785 -0
  14. postbound/db/mysql.py +1195 -0
  15. postbound/db/postgres.py +4216 -0
  16. postbound/experiments/__init__.py +12 -0
  17. postbound/experiments/analysis.py +674 -0
  18. postbound/experiments/benchmarking.py +54 -0
  19. postbound/experiments/ceb.py +877 -0
  20. postbound/experiments/interactive.py +105 -0
  21. postbound/experiments/querygen.py +334 -0
  22. postbound/experiments/workloads.py +980 -0
  23. postbound/optimizer/__init__.py +92 -0
  24. postbound/optimizer/__init__.pyi +73 -0
  25. postbound/optimizer/_cardinalities.py +369 -0
  26. postbound/optimizer/_joingraph.py +1150 -0
  27. postbound/optimizer/dynprog.py +1825 -0
  28. postbound/optimizer/enumeration.py +432 -0
  29. postbound/optimizer/native.py +539 -0
  30. postbound/optimizer/noopt.py +54 -0
  31. postbound/optimizer/presets.py +147 -0
  32. postbound/optimizer/randomized.py +650 -0
  33. postbound/optimizer/tonic.py +1479 -0
  34. postbound/optimizer/ues.py +1607 -0
  35. postbound/qal/__init__.py +343 -0
  36. postbound/qal/_qal.py +9678 -0
  37. postbound/qal/formatter.py +1089 -0
  38. postbound/qal/parser.py +2344 -0
  39. postbound/qal/relalg.py +4257 -0
  40. postbound/qal/transform.py +2184 -0
  41. postbound/shortcuts.py +70 -0
  42. postbound/util/__init__.py +46 -0
  43. postbound/util/_errors.py +33 -0
  44. postbound/util/collections.py +490 -0
  45. postbound/util/dataframe.py +71 -0
  46. postbound/util/dicts.py +330 -0
  47. postbound/util/jsonize.py +68 -0
  48. postbound/util/logging.py +106 -0
  49. postbound/util/misc.py +168 -0
  50. postbound/util/networkx.py +401 -0
  51. postbound/util/numbers.py +438 -0
  52. postbound/util/proc.py +107 -0
  53. postbound/util/stats.py +37 -0
  54. postbound/util/system.py +48 -0
  55. postbound/util/typing.py +35 -0
  56. postbound/vis/__init__.py +5 -0
  57. postbound/vis/fdl.py +69 -0
  58. postbound/vis/graphs.py +48 -0
  59. postbound/vis/optimizer.py +538 -0
  60. postbound/vis/plots.py +84 -0
  61. postbound/vis/tonic.py +70 -0
  62. postbound/vis/trees.py +105 -0
  63. postbound-0.19.0.dist-info/METADATA +355 -0
  64. postbound-0.19.0.dist-info/RECORD +67 -0
  65. postbound-0.19.0.dist-info/WHEEL +5 -0
  66. postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
  67. postbound-0.19.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,877 @@
1
+ """
2
+ Implementation of the Cardinality Estimation Benchmark (CEB) algorithm to automatically generate workload queries based on
3
+ different templates.
4
+
5
+ While the original CEB was introduced in [1]_, the allowed template settings and their interaction was barely documented.
6
+ Therefore, we provide our own implementation that is (hopefully) better documented and understandable.
7
+
8
+ Generally speaking you do not need to interact with this module directly. Instead, there exists a high-level CLI tool
9
+ called ``ceb-generator.py`` in the ``tools`` directory. It exposes the most relevant options and can be used instead.
10
+
11
+ References
12
+ ----------
13
+ [1] Parimarjan Negi et al.: "Flow-Loss: Learning Cardinality Estimates That Matter" (PVLDB 2021)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import collections
19
+ import pathlib
20
+ import random
21
+ import tomllib
22
+ from collections.abc import Iterable
23
+ from typing import Any, Literal, NewType, Optional
24
+
25
+ import numpy as np
26
+
27
+ from ..db import postgres
28
+ from ..db._db import Database
29
+ from ..qal import formatter, parser
30
+ from ..qal._qal import ColumnReference, SqlQuery, TableReference
31
+ from ..util._errors import StateError
32
+ from ..util.misc import DependencyGraph
33
+ from .workloads import Workload
34
+
35
# We introduce a bunch of type aliases to avoid opaque signatures such as dict[str, str].
TemplatedQuery = NewType("TemplatedQuery", str)
"""The raw query text with placeholders for the predicate values. Due to the placeholders, this won't be a valid SQL query."""

ColumnName = NewType("ColumnName", str)
"""Reference to a column with a placeholder predicate, either as ``column`` or as ``table_or_alias.column``."""

PredicateName = NewType("PredicateName", str)
"""Name of a predicate generator that produces values for a subset of the placeholders."""

PredicateType = Literal["=", "<", ">", "<=", ">=", "<>", "LIKE", "ILIKE", "IN"]
"""The comparison operator of a placeholder predicate."""

PlaceholderName = NewType("PlaceholderName", str)
"""The actual placeholder key that will be replaced in a `TemplatedQuery`."""

PlaceHolderValue = Any
"""The selected value to bind and replace a placeholder."""
53
+
54
+
55
+ def _tuplelize_value(val: Any) -> Any:
56
+ """Leaves scalar values as they are, but converts lists to tuples. This ensures that we can hash them."""
57
+ if isinstance(val, list):
58
+ return tuple(val)
59
+ return val
60
+
61
+
62
+ def _make_options_list(
63
+ options: list[PlaceHolderValue],
64
+ ) -> list[tuple[PlaceHolderValue]]:
65
+ """Transforms explitic options list into our standardized format: a list of tuples of individual option values."""
66
+ if not options:
67
+ raise ValueError("Must provide at least one option")
68
+ initial_opt = options[0]
69
+ if isinstance(initial_opt, tuple):
70
+ return options
71
+ return [(opt,) for opt in options]
72
+
73
+
74
+ def _remove_weight_col(
75
+ val: tuple[PlaceHolderValue], col_idx: int
76
+ ) -> tuple[PlaceHolderValue]:
77
+ """Drops the weight column from a pre-weighted tuple."""
78
+ return tuple([elem for i, elem in enumerate(val) if i != col_idx])
79
+
80
+
81
class PredicateGenerator:
    """The predicate generator handles the selection of substitution values for a subset of all placeholders in a template.

    Parameters
    ----------
    name : PredicateName
        Alias of the predicate. This might be referenced as a dependency in other predicates.
    provided_keys : list[PlaceholderName]
        The placeholders for which this predicate calculates substitution values.
    template_type : Literal["sql", "list"]
        The inference strategy for the predicate. If "sql", the values are fetched from the `db_connection` using an actual SQL
        query. If "list", the values must be provided as a static list.
    sampling_method : Literal["uniform", "weighted"]
        How to select the final value from a list of candidates. If "uniform", all distinct values have the same
        probability. If "weighted", values that occur multiple times (or that carry a larger pre-computed weight) have a
        higher chance of being selected.
    target_columns : list[ColumnName]
        Columns corresponding to the placeholders. This has to be indexed in the same order as `provided_keys`. Columns must
        be listed explicitly to ensure that values are escaped properly in the final query. For example, text values require
        surrounding quotes, etc.
    pred_type : list[PredicateType]
        Operators corresponding to the columns and placeholders. This has to be indexed in the same order as
        `provided_keys`. Operators must be listed explicitly in order to ensure that values are formatted and prepared
        correctly for the final query. For example, *LIKE* predicates require insertion of wildcard characters, *IN*
        predicates require surrounding parens, etc.

        Notice that *IN* predicates require that the predicate only computes a single placeholder value. Otherwise it is
        unclear how scalar values from one predicate should correlate to values for the *IN* predicate.
    sql_query : Optional[str], optional
        The actual SQL query to compute the selected values. This query must compute the values in the same order as the
        `provided_keys`. It can contain placeholders that are computed by the predicates listed in `dependencies`. This
        parameter is required for ``template_type="sql"`` and ignored otherwise.
    list_allowed_values : Optional[list[PlaceHolderValue]], optional
        The options to choose from to select a placeholder value. This parameter is required for ``template_type="list"`` and
        ignored otherwise.
    in_pred_min_samples : int, optional
        For *IN* predicates, this designates the minimum number of values that must be included in the final *IN* predicate.
        Set to 1 by default. If fewer distinct candidates are available, the current attempt fails and is retried (with
        fresh dependent values).
    in_pred_max_samples : Optional[int], optional
        For *IN* predicates this designates the maximum number of values that might be included in the final *IN* predicate.
        This defaults to the total number of candidate values. Since this might be quite a lot, it is recommended to set this
        value to something reasonable.
    count_column_idx : Optional[int], optional
        For weighted sampling this denotes the column that contains pre-calculated weights for each candidate value. The
        index is 1-based and relative to the *SELECT* clause of the SQL query or the entries of the allowed values. If this
        is omitted, the candidate values are assumed to contain duplicates and weights are inferred based on the number of
        occurrences of each value. The weight column never becomes part of the selected placeholder values. For *IN*
        predicates, each candidate must consist of exactly the value column and the weight column.
    dependencies : Optional[list[PredicateName]], optional
        Predicates that compute values referenced in this predicate's SQL query.
    max_tries : Optional[int], optional
        The maximum number of attempts to select a valid value for the predicate. It is necessary to re-try if the selected
        value fails some constraints (e.g. the number of allowed values in an *IN* predicate). In such a case, all dependent
        values are re-drawn as well. If omitted, the setting of the parent `QueryTemplate` is used.
    db_connection : Optional[Database], optional
        The database containing the values to sample from.
    """

    def __init__(
        self,
        name: PredicateName,
        *,
        provided_keys: list[PlaceholderName],
        template_type: Literal["sql", "list"],
        sampling_method: Literal["uniform", "weighted"],
        target_columns: list[ColumnName],
        pred_type: list[PredicateType],
        sql_query: Optional[str] = None,
        list_allowed_values: Optional[list[PlaceHolderValue]] = None,
        in_pred_min_samples: int = 1,
        in_pred_max_samples: Optional[int] = None,
        count_column_idx: Optional[int] = None,
        dependencies: Optional[list[PredicateName]] = None,
        max_tries: Optional[int] = None,
        db_connection: Optional[Database] = None,
    ) -> None:
        self.name = name

        if "IN" in pred_type and len(provided_keys) > 1:
            raise ValueError(
                "IN predicates must only compute a single placeholder value"
            )
        self.pred_type = pred_type

        self.dependencies = dependencies

        if not (len(provided_keys) == len(target_columns) == len(pred_type)):
            raise ValueError(
                "The number of provided keys, target columns, and predicate types must match"
            )
        self._key_lookup = dict(((k, i) for i, k in enumerate(provided_keys)))

        self._parent_generator: QueryTemplate | None = None
        self._template_type = template_type
        self._sampling_method = sampling_method
        self._target_columns = target_columns

        if template_type == "sql" and not sql_query:
            raise ValueError(
                f"SQL query must be provided for sql-typed predicate '{name}'"
            )
        if template_type == "list" and not list_allowed_values:
            raise ValueError(
                f"Option values must be provided for list-typed predicate '{name}'"
            )
        self._sql_query = sql_query
        self._list_allowed_values = (
            _make_options_list(list_allowed_values) if list_allowed_values else None
        )

        # the public index is 1-based, internally we work with 0-based tuple positions
        self._count_col_idx = (
            count_column_idx - 1 if count_column_idx is not None else None
        )

        self._in_pred_min_samples = in_pred_min_samples
        self._in_pred_max_samples = in_pred_max_samples

        self._max_tries = max_tries
        self._db_connection = db_connection

        self._selected_values: list = []

    @property
    def placeholders(self) -> Iterable[PlaceholderName]:
        """Provides all placeholder keys that are computed by this predicate."""
        return self._key_lookup.keys()

    def choose_predicate_values(self) -> None:
        """Draws a valid value for the predicate based on the provided strategy.

        Raises
        ------
        SamplingError
            If no valid value passing all constraints could be found within the maximum number of tries.
        StateError
            If no `max_tries` was configured and the predicate has not been registered at a `QueryTemplate` yet.
        """
        if self._max_tries is None:
            if self._parent_generator is None:
                raise StateError(
                    f"Predicate '{self.name}' has no max_tries and is not registered at a query template"
                )
            self._max_tries = self._parent_generator.max_tries

        selected_value: list[PlaceHolderValue] | None = None
        redraw_dependent_values = False
        for _ in range(self._max_tries):
            try:
                candidate = self._next_predicate_value(redraw_dependent_values)
            except SamplingError:
                # if we did not find any value, retry but make sure to also refresh all dependent values.
                redraw_dependent_values = True
                continue

            if all(
                self._value_passes_constraints(key, candidate)
                for key in self._key_lookup
            ):
                selected_value = candidate
                break

            # constraints failed: the dependent values might be the culprit, so re-draw them as well
            redraw_dependent_values = True

        if selected_value is None:
            raise SamplingError(
                f"Did not find a valid value for predicate '{self.name}'"
            )
        self._selected_values = selected_value

    def fetch_value(self, key: PlaceholderName) -> PlaceHolderValue:
        """Provides the selected value for a specific placeholder key.

        Raises
        ------
        StateError
            If no value has been selected yet via `choose_predicate_values()`.
        """
        self._assert_values_available()
        self._assert_valid_key(key)
        value_idx = self._key_lookup[key]
        return self._selected_values[value_idx]

    def selected_values(self) -> dict[PlaceholderName, PlaceHolderValue]:
        """Provides all selected values for the placeholders computed by this predicate.

        Raises
        ------
        StateError
            If no value has been selected yet via `choose_predicate_values()`.
        """
        self._assert_values_available()
        return {k: self._selected_values[i] for k, i in self._key_lookup.items()}

    def column_for(self, key: PlaceholderName) -> ColumnName:
        """Provides the column that is computed by a specific placeholder key."""
        self._assert_valid_key(key)
        value_idx = self._key_lookup[key]
        return self._target_columns[value_idx]

    def predicate_for(self, key: PlaceholderName) -> PredicateType:
        """Provides the operator of the predicate computed by a specific placeholder key."""
        self._assert_valid_key(key)
        value_idx = self._key_lookup[key]
        return self.pred_type[value_idx]

    def _next_predicate_value(
        self, redraw_dependent_values: bool
    ) -> tuple[PlaceHolderValue]:
        """Calculates the next (tuple of) placeholder values based on the specified selection strategy.

        This is the main workhorse method which delegates to all further more specialized methods, e.g. to actually draw
        a value from SQL, or to select a value with weighted probabilities, etc.

        Parameters
        ----------
        redraw_dependent_values : bool
            Whether all predicates that supply dependent values to this predicate should select new values. This is important
            if we already tried to select a new value but it did not pass all of our constraints.
        """
        if self._template_type == "list" and self._list_allowed_values is not None:
            candidate_values = self._list_allowed_values
        elif self._template_type == "sql":
            candidate_values = self._collect_candidate_values_from_sql(
                redraw_dependent_values
            )
        else:
            raise ValueError(f"Unknown template type: '{self._template_type}'")

        if self.pred_type == ["IN"]:
            selected_value = self._draw_multi_values(candidate_values)
        else:
            selected_value = self._draw_scalar_value(candidate_values)

        if not isinstance(selected_value, (list, tuple)):
            selected_value = [selected_value]
        return selected_value

    def _collect_candidate_values_from_sql(
        self, redraw_dependent_values: bool
    ) -> list[tuple[PlaceHolderValue]]:
        """Provides all possible candidate values based on an SQL query.

        This method is also responsible for generating an adequate SQL query by substituting all dependent values.

        Parameters
        ----------
        redraw_dependent_values : bool
            Whether all predicates that supply dependent values to this predicate should select new values. This is important
            if we already tried to select a new value but it did not pass all of our constraints.

        Raises
        ------
        SamplingError
            If the query did not provide any results. This might happen if we have an unlucky selection of dependent values
            and is no big deal since we can simply try again (as controlled by the higher-up methods).
        """
        sql_query = self._sql_query

        # dependencies default to None in the constructor, which simply means "no dependencies"
        for dep in self.dependencies or []:
            dependent_values = self._parent_generator.selected_values(
                dep, refresh=redraw_dependent_values
            )
            sql_query = self._parent_generator.substitute_placeholders(
                sql_query, dependent_values
            )

        candidate_values = self._db_connection.execute_query(sql_query, raw=True)
        if not candidate_values:
            raise SamplingError(f"No values found for predicate '{self.name}'")
        return [tuple(candidate) for candidate in candidate_values]

    def _draw_scalar_value(
        self, candidate_values: list[tuple[PlaceHolderValue]]
    ) -> tuple[PlaceHolderValue]:
        """Selects a single value from the candidates according to the specified sampling strategy.

        The returned tuple never contains the weight column (if one was configured).
        """
        if self._sampling_method == "uniform":
            # Pre-computed weights are meaningless for uniform sampling, but the weight column must still be removed from
            # the payload. Otherwise the placeholder values would be misaligned with their keys.
            if self._count_col_idx is not None:
                candidate_values = [
                    _remove_weight_col(val, self._count_col_idx)
                    for val in candidate_values
                ]
            # For uniform selection duplicate occurrences of the same value should not increase their chance of
            # selection. Therefore we need to make sure that no value is present more than once.
            unique_values = list(set(candidate_values))
            return random.choice(unique_values)

        elif self._sampling_method == "weighted":
            # For weighted mode, we can either receive the desired weights along with the candidate values, or we might need
            # to calculate them ourselves.
            # In the first case, each candidate value tuple also contains a weight entry that is designated by the
            # `count_col_idx` attribute.
            # In the latter case, each occurrence of the same candidate value counts as a weight increase, hence we can just
            # select one of the values at uniform probability without eliminating duplicates.
            weights: list[int] | None = (
                [val[self._count_col_idx] for val in candidate_values]
                if self._count_col_idx is not None
                else None
            )
            selected_val = random.choices(candidate_values, weights=weights, k=1)[
                0
            ]  # choices always returns a list!

            if self._count_col_idx is not None:
                # for pre-weighted lists our selected value does not only contain the actual data, but also the weight column
                # since the weight column can be located at an arbitrary position and should not become part of the actual
                # "payload", we need to filter the selected value before returning it
                selected_val = _remove_weight_col(selected_val, self._count_col_idx)
            return selected_val

        else:
            raise ValueError(f"Unknown sampling method: '{self._sampling_method}'")

    def _draw_multi_values(
        self, candidate_values: list[tuple[PlaceHolderValue]]
    ) -> list[tuple[PlaceHolderValue]]:
        """Selects placeholder values for *IN* predicates according to the specified sampling strategy.

        Returns a one-element list whose only entry is the tuple of distinct-position values for the *IN* placeholder.

        Raises
        ------
        ValueError
            If the candidates do not consist of exactly one value column (plus the weight column, if configured), or if the
            sampling method is unknown.
        SamplingError
            If there are fewer candidates than the minimum number of *IN* values requires.
        """
        # without a weight column each candidate is a 1-tuple, otherwise a (value, weight) pair in some order
        expected_width = 1 if self._count_col_idx is None else 2
        if any(len(val) != expected_width for val in candidate_values):
            raise ValueError(
                "IN predicates must only compute a single placeholder value"
            )

        if self._count_col_idx is None:
            val_idx = 0
        else:
            val_idx = 0 if self._count_col_idx == 1 else 1

        if self._sampling_method == "uniform":
            # Uniform sampling is easy: we just need to determine which unique values are available and then choose a
            # correctly-sized subset from them. The values are unwrapped from their candidate tuples so that the final
            # IN list contains plain values rather than nested tuples.
            population = list(dict.fromkeys(val[val_idx] for val in candidate_values))
            n_values = self._draw_in_sample_size(len(population))
            return [tuple(random.sample(population, k=n_values))]

        if self._sampling_method != "weighted":
            raise ValueError(f"Unknown sampling method: '{self._sampling_method}'")

        if self._count_col_idx is not None:
            # If weights are already supplied, we just need to extract them
            population = [val[val_idx] for val in candidate_values]
            weights = [val[self._count_col_idx] for val in candidate_values]
        else:
            # Otherwise we calculate our own weights based on the number of occurrences of each value
            counter = collections.Counter(val[val_idx] for val in candidate_values)
            population = list(counter.keys())
            weights = list(counter.values())

        n_values = self._draw_in_sample_size(len(population))

        # We use numpy's random module here because it supports sampling from a population with custom weights as well as
        # without replacement. numpy expects probabilities instead of weights, so we need to normalize them first.
        probabilities = np.asarray(weights, dtype=float)
        probabilities = probabilities / probabilities.sum()

        # We sample positions rather than the values themselves: handing the values to numpy would convert them into numpy
        # scalars (or coerce mixed types into strings), which changes how they are rendered in the final query.
        rng = np.random.default_rng()
        try:
            chosen_indices = rng.choice(
                len(population), size=n_values, p=probabilities, replace=False
            )
        except ValueError as err:
            # e.g. fewer candidates with non-zero weight than requested values
            raise SamplingError(
                f"Could not sample {n_values} values for predicate '{self.name}'"
            ) from err
        return [tuple(population[idx] for idx in chosen_indices)]

    def _draw_in_sample_size(self, n_candidates: int) -> int:
        """Randomly determines how many values an *IN* predicate should contain, given the number of candidates.

        Raises
        ------
        SamplingError
            If fewer candidates are available than the minimum number of *IN* values.
        """
        max_values = (
            n_candidates
            if self._in_pred_max_samples is None
            else min(self._in_pred_max_samples, n_candidates)
        )
        if max_values < self._in_pred_min_samples:
            raise SamplingError(
                f"Predicate '{self.name}' requires at least {self._in_pred_min_samples} values, "
                f"but only {n_candidates} candidates are available"
            )
        return random.randint(self._in_pred_min_samples, max_values)

    def _value_passes_constraints(
        self, key: PlaceholderName, value: list[PlaceHolderValue]
    ) -> bool:
        """Checks, whether the value selected for a specific placeholder passes all constraints attached to it.

        `value` is the complete selection of this predicate (one entry per provided key). For *IN* placeholders the
        number of values in the entry belonging to `key` must lie within the configured sample bounds.
        """
        if self.predicate_for(key) != "IN":
            return True

        in_values = value[self._key_lookup[key]]
        if not isinstance(in_values, (list, tuple)):
            in_values = (in_values,)

        max_allowed_values = (
            self._in_pred_max_samples
            if self._in_pred_max_samples is not None
            else len(in_values)
        )
        return self._in_pred_min_samples <= len(in_values) <= max_allowed_values

    def _assert_values_available(self) -> None:
        """Raises an error if no values have been selected yet."""
        if not self._selected_values:
            raise StateError(
                "Must first call choose_predicate_values() to select values"
            )

    def _assert_valid_key(self, key: PlaceholderName) -> None:
        """Raises an error if a placeholder is not computed by the current predicate."""
        if key not in self._key_lookup:
            raise KeyError(f"Key '{key}' is not provided by template '{self.name}'")

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, value: object) -> bool:
        return isinstance(value, type(self)) and value.name == self.name

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        if self.dependencies:
            deps_str = ", ".join(self.dependencies)
            return f"{self.name}({deps_str})"
        return self.name
484
+
485
+
486
class QueryTemplate:
    """A query template handles generation of all placeholder values and the replacement process in the final query.

    Parameters
    ----------
    base_query : TemplatedQuery
        The actual query to format. Its placeholders will be replaced by the selected values obtained from the predicate
        generators.
    label : str
        The label of the current query template. Mostly used for debugging purposes.
    table_aliases : dict[str, str]
        A map from alias to physical table name. This is necessary to determine to which table the columns belong without
        parsing the query (which is impossible since the query is not yet valid SQL due to the placeholders). If a table does
        not have an alias, the table name itself can be used as a key. The value has to be empty in this case.
    db_connection : Database
        The database providing the actual candidate values for the placeholders
    max_tries : int, optional
        How often a predicate generator may attempt to obtain a valid placeholder value, unless the generator defines its
        own limit. Defaults to 10.
    """

    # Column types whose values have to be rendered as quoted SQL literals. Type names are compared after lower-casing
    # and stripping any length/precision suffix, e.g. "character varying(255)" becomes "character varying".
    _QUOTED_DTYPES = frozenset(
        {
            "text",
            "varchar",
            "char",
            "character",
            "character varying",
            "bpchar",
            "date",
            "datetime",
            "time",
            "time without time zone",
            "time with time zone",
            "timetz",
            "timestamp",
            "timestamp without time zone",
            "timestamp with time zone",
            "timestamptz",
        }
    )

    def __init__(
        self,
        base_query: TemplatedQuery,
        *,
        label: str,
        table_aliases: dict[str, str],
        db_connection: Database,
        max_tries: int = 10,
    ) -> None:
        if max_tries < 1:
            raise ValueError(f"max_tries must be at least 1, not {max_tries}")

        self.label = label
        self.base_query = base_query

        # The table aliases should map alias => fully-qualified table name. However, if a table has no alias, its
        # fully-qualified name becomes the key instead. This makes the weird calculation of the target value necessary.
        # The root cause is that a physical table can be referenced with multiple aliases in the same query and other tables
        # might be referenced without any alias, still within the same query. SQL is weird, man!
        self._table_aliases = {
            alias: (tab if tab else alias) for alias, tab in table_aliases.items()
        }

        self._predicate_generators: dict[PredicateName, PredicateGenerator] = {}
        self._generator_lookup: dict[PlaceholderName, PredicateGenerator] = {}

        self._db_conn = db_connection
        self._max_tries = max_tries

    @property
    def max_tries(self) -> int:
        """How often a predicate generator may attempt to obtain a valid placeholder value."""
        return self._max_tries

    def register_generator(self, generator: PredicateGenerator) -> None:
        """Adds a new predicate generator to the template.

        Raises
        ------
        KeyError
            If another generator has already been registered for one of the provided placeholder keys or the same generator
            name.
        """
        if generator.name in self._predicate_generators:
            raise KeyError(
                f"Predicate '{generator.name}' already registered in template '{self.label}'"
            )

        self._predicate_generators[generator.name] = generator

        for key in generator.placeholders:
            if key in self._generator_lookup:
                raise KeyError(
                    f"Key '{key}' already registered in template '{self.label}'"
                )
            self._generator_lookup[key] = generator

        generator._parent_generator = self

    def selected_values(
        self, predicate: PredicateName, *, refresh: bool = False
    ) -> dict[PlaceholderName, PlaceHolderValue]:
        """Provides the values selected by a specific predicate generator.

        Parameters
        ----------
        predicate : PredicateName
            The generator name.
        refresh : bool, optional
            Whether the generator should re-draw all values. This is useful if some dependent predicate cannot satisfy its
            constraints with the current values.
        """
        if predicate not in self._predicate_generators:
            raise KeyError(
                f"Predicate '{predicate}' not found in template '{self.label}'"
            )

        generator = self._predicate_generators[predicate]
        if refresh:
            generator.choose_predicate_values()
        return generator.selected_values()

    def substitute_placeholders(
        self,
        query: TemplatedQuery,
        selected_values: dict[PlaceholderName, PlaceHolderValue],
    ) -> TemplatedQuery:
        """Replaces all placeholders with their selected values in a specific query.

        The query does not have to be the `base_query`. For example, it can also be a dependent SQL query of a predicate
        generator. Placeholders are written as ``<<key>>`` in the query text.
        """
        for key, value in selected_values.items():
            generator = self._generator_lookup[key]

            target_column = self._lookup_column(generator.column_for(key))
            column_dtype = self._db_conn.schema().datatype(target_column)
            pred_type = generator.predicate_for(key)

            escaped_placeholder = self._escape_col_value(value, pred_type, column_dtype)
            query = query.replace(f"<<{key}>>", escaped_placeholder)

        return query

    def generate_raw_query(self) -> str:
        """Creates a new SQL query by replacing all placeholders in the base query with appropriate values."""
        dep_graph: DependencyGraph[PredicateGenerator] = DependencyGraph()
        for generator in self._predicate_generators.values():
            dependencies = (
                [self._predicate_generators[dep] for dep in generator.dependencies]
                if generator.dependencies
                else []
            )
            dep_graph.add_task(generator, depends_on=dependencies)

        # the dependency graph yields each generator only after all generators it depends on
        for generator in dep_graph:
            generator.choose_predicate_values()

        final_query = self.base_query
        for generator in self._predicate_generators.values():
            selected_values = generator.selected_values()
            final_query = self.substitute_placeholders(final_query, selected_values)

        return str(final_query)

    def generate_query(self) -> SqlQuery:
        """Creates a new SQL query by replacing all placeholders in the base query with appropriate values."""
        final_query = self.generate_raw_query()
        return parser.parse_query(final_query)

    def _lookup_column(self, colname: ColumnName) -> ColumnReference:
        """Generates an actual column reference for a specific column name.

        Unqualified names are resolved against all tables that are referenced without an alias. Qualified names are
        split at the last dot, so the qualifier may itself be a (schema-qualified) table name or an alias.
        """
        if "." not in colname:
            # unaliased tables are stored as name => name (see __init__)
            tables_without_alias = [
                TableReference(tab)
                for alias, tab in self._table_aliases.items()
                if tab == alias
            ]
            target_table = self._db_conn.schema().lookup_column(
                colname, tables_without_alias
            )
            return ColumnReference(colname, target_table)

        table, column = colname.rsplit(".", 1)
        target_table = self._table_aliases[table]
        table_ref = TableReference(target_table, table)
        return ColumnReference(column, table_ref)

    def _escape_col_value(self, value: PlaceHolderValue, pred_type, dtype: str) -> str:
        """Creates an appropriately escaped string for a placeholder.

        Depending on the predicate type the value might be processed further, e.g. by adding wildcard operators for *LIKE*
        predicates. Likewise, the values might be wrapped by parens for *IN* predicates. Values of textual and temporal
        columns are rendered as single-quoted SQL literals with embedded quotes doubled; all other values are rendered via
        `str()`.

        Raises
        ------
        ValueError
            If a tuple of values is supplied for a predicate other than *IN*.
        """
        if isinstance(value, tuple):
            if pred_type != "IN":
                raise ValueError(
                    f"Multiple values are only supported for IN predicates, not '{pred_type}'"
                )
            escaped_values = [self._escape_col_value(v, "=", dtype) for v in value]
            value_text = ", ".join(escaped_values)
            return f"({value_text})"

        base_dtype = (
            dtype.split("(", 1)[0].strip().lower() if isinstance(dtype, str) else ""
        )
        if base_dtype not in self._QUOTED_DTYPES:
            return str(value)

        text = str(value)
        if pred_type == "LIKE" or pred_type == "ILIKE":
            text = f"%{text}%"
        # single quotes inside SQL string literals are escaped by doubling them
        text = text.replace("'", "''")
        return f"'{text}'"

    def __hash__(self) -> int:
        return hash(self.label)

    def __eq__(self, value: object) -> bool:
        return isinstance(value, type(self)) and value.label == self.label

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        predicates_str = ", ".join(self._predicate_generators.keys())
        return f"{self.label}({predicates_str})"
677
+
678
+
679
def _parse_template_toml(
    path: str | pathlib.Path, db_connection: Database
) -> QueryTemplate:
    """Builds a complete query template from its TOML description.

    The TOML document must provide a ``title``, a ``base_sql`` table with ``sql`` and ``table_aliases`` entries, and a
    list of ``predicates``. Each predicate entry is turned into a `PredicateGenerator` and registered at the template.

    Parameters
    ----------
    path : str | pathlib.Path
        Location of the TOML file.
    db_connection : Database
        Database that is handed to the template as well as to every predicate generator.

    Returns
    -------
    QueryTemplate
        The template with all predicate generators registered.
    """
    with open(path, "rb") as toml_file:
        spec = tomllib.load(toml_file)

    base_sql = spec["base_sql"]
    template = QueryTemplate(
        TemplatedQuery(base_sql["sql"]),
        label=spec["title"],
        table_aliases=base_sql["table_aliases"],
        db_connection=db_connection,
    )

    for entry in spec["predicates"]:
        name = PredicateName(entry["name"])
        # placeholder keys may be written with or without their surrounding <<...>> markers
        keys = [
            PlaceholderName(raw_key.removeprefix("<<").removesuffix(">>"))
            for raw_key in entry["keys"]
        ]
        template_type = entry["type"]
        sampling_method = entry["sampling_method"]
        pred_types = entry["pred_type"]
        columns = [ColumnName(raw_col) for raw_col in entry["columns"]]
        sql = entry.get("sql")
        # TOML arrays arrive as lists, which are turned into hashable tuples
        options = list(map(_tuplelize_value, entry.get("options", [])))
        deps = [PredicateName(raw_dep) for raw_dep in entry.get("dependencies", [])]

        generator = PredicateGenerator(
            name,
            provided_keys=keys,
            template_type=template_type,
            sampling_method=sampling_method,
            pred_type=pred_types,
            target_columns=columns,
            sql_query=sql,
            list_allowed_values=options,
            in_pred_min_samples=entry.get("min_samples", 1),
            in_pred_max_samples=entry.get("max_samples"),
            dependencies=deps,
            db_connection=db_connection,
        )
        template.register_generator(generator)

    return template
719
+
720
+
721
def generate_raw_workload(
    path: str | pathlib.Path,
    *,
    queries_per_template: int,
    template_pattern: str = "*.toml",
    db_connection: Optional[Database] = None,
    max_tries: Optional[int] = None,
) -> dict[str, str]:
    """Produces an unoptimized workload based on a number of CEB templates.

    In contrast to `generate_workload`, generated queries are not parsed into actual query objects. Instead, the raw query
    text is provided. This function is intended for situations where the parser or the query abstraction cannot yet be used
    to represent the desired structure of the templates.

    Parameters
    ----------
    path : str | pathlib.Path
        The directory containing the template files.
    queries_per_template : int
        The number of queries that should be generated for each template. Queries will be distinguished by increasing label
        numbers.
    template_pattern : str, optional
        A GLOB pattern that all template files must match to be recognized as such. Defaults to *\\*.toml*.
    db_connection : Optional[Database], optional
        The database to use for fetching appropriate candidate values for the placeholders. If omitted, a default Postgres
        connection will be opened.
    max_tries : Optional[int], optional
        The maximum number of sampling attempts per template. If omitted, this defaults to
        ``10 * queries_per_template * <number of templates>``.

    Returns
    -------
    dict[str, str]
        The generated workload. Maps query labels to the raw query text. Templates are processed in sorted file name
        order.

    Raises
    ------
    FileNotFoundError
        If `path` is not an existing directory.
    SamplingError
        If the sampling algorithm could not satisfy all constraints of its predicates, e.g. if a template did not yield
        `queries_per_template` unique queries within `max_tries` attempts.

    See Also
    --------
    generate_workload
    """
    template_dir = path if isinstance(path, pathlib.Path) else pathlib.Path(path)
    if not template_dir.is_dir():
        raise FileNotFoundError(f"Directory '{template_dir}' does not exist")
    # Only connect once the input is known to be valid, so a bad path does not open a connection needlessly.
    db_connection = postgres.connect() if db_connection is None else db_connection

    # Sorting makes the workload order independent of the directory listing order of the file system.
    templates: list[QueryTemplate] = [
        _parse_template_toml(template_file, db_connection)
        for template_file in sorted(template_dir.glob(template_pattern))
    ]

    if max_tries is None:
        max_tries = len(templates) * queries_per_template * 10

    # Uniqueness is enforced across all templates, not just within a single one.
    generated_queries: set[str] = set()
    workload_queries: dict[str, str] = {}
    for template in templates:
        generated_count, num_tries = 0, 0
        while generated_count < queries_per_template and num_tries < max_tries:
            num_tries += 1
            query = template.generate_raw_query()
            if query in generated_queries:
                continue

            generated_queries.add(query)
            generated_count += 1
            # Labels are 1-based because the counter has already been incremented for this query.
            workload_queries[f"{template.label}-{generated_count}"] = query

        if generated_count < queries_per_template:
            raise SamplingError(
                f"Could not generate enough unique queries for template {template.label}"
            )

    return workload_queries
797
+
798
+
799
def generate_workload(
    path: str | pathlib.Path,
    *,
    queries_per_template: int,
    name: Optional[str] = None,
    template_pattern: str = "*.toml",
    db_connection: Optional[Database] = None,
) -> Workload[str]:
    """Produces a full workload based on a number of CEB templates.

    The raw query texts are generated by `generate_raw_workload` and afterwards parsed into query objects.

    Parameters
    ----------
    path : str | pathlib.Path
        The directory containing the template files.
    queries_per_template : int
        The number of queries that should be generated for each template. Queries will be distinguished by increasing label
        numbers.
    name : Optional[str], optional
        The name of the resulting workload. An empty name is used if this is omitted or empty.
    template_pattern : str, optional
        A GLOB pattern that all template files must match to be recognized as such. Defaults to *\\*.toml*.
    db_connection : Optional[Database], optional
        The database to use for fetching appropriate candidate values for the placeholders. If omitted, a default Postgres
        connection will be opened.

    Returns
    -------
    Workload[str]
        The generated workload. Queries are differentiated by labels based on the template names. The template directory
        serves as the workload root.

    Raises
    ------
    SamplingError
        If the sampling algorithm could not satisfy all constraints of its predicates.
    """
    if isinstance(path, pathlib.Path):
        template_dir = path
    else:
        template_dir = pathlib.Path(path)

    raw_workload = generate_raw_workload(
        template_dir,
        queries_per_template=queries_per_template,
        template_pattern=template_pattern,
        db_connection=db_connection,
    )

    parsed_queries = {}
    for label, raw_query in raw_workload.items():
        parsed_queries[label] = parser.parse_query(raw_query)

    workload_name = name or ""
    return Workload(parsed_queries, name=workload_name, root=template_dir)
846
+
847
+
848
def persist_workload(
    path: str | pathlib.Path, workload: Workload[str] | dict[str, str]
) -> None:
    """Stores all queries of a workload with one query per file in a specific directory.

    Files are named according to the query labels, i.e. each query ends up in ``<path>/<label>.sql``, followed by a
    trailing newline.

    Parameters
    ----------
    path : str | pathlib.Path
        The target directory. It is created (including missing parents) if it does not exist yet.
    workload : Workload[str] | dict[str, str]
        The queries to store. Queries of a `Workload` are rendered via `formatter.format_quick`. The values of a plain
        dictionary are written verbatim.
    """
    target_dir = pathlib.Path(path)
    target_dir.mkdir(parents=True, exist_ok=True)

    if isinstance(workload, Workload):
        query_iter = workload.entries()
        query_formatter = formatter.format_quick
    else:
        query_iter = workload.items()
        query_formatter = lambda x: x  # noqa: E731 - raw dictionaries already contain the SQL text

    for label, query in query_iter:
        query_file = target_dir / f"{label}.sql"
        # An explicit encoding keeps the output independent of the platform's locale settings.
        query_file.write_text(query_formatter(query) + "\n", encoding="utf-8")
866
+
867
+
868
class SamplingError(RuntimeError):
    """Signals a failure while sampling placeholder values or generating queries.

    There are two audiences for this error. Exposed to the user, it hints at problematic templates, e.g. constraints that
    are too restrictive or sampling that is too random. Raised within the sampling process itself, it marks situations
    that the generation process resolves automatically, without any user intervention.
    """

    def __init__(self, message) -> None:
        super(SamplingError, self).__init__(message)