PostBOUND 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. postbound/__init__.py +211 -0
  2. postbound/_base.py +6 -0
  3. postbound/_bench.py +1012 -0
  4. postbound/_core.py +1153 -0
  5. postbound/_hints.py +1373 -0
  6. postbound/_jointree.py +1079 -0
  7. postbound/_pipelines.py +1121 -0
  8. postbound/_qep.py +1986 -0
  9. postbound/_stages.py +876 -0
  10. postbound/_validation.py +734 -0
  11. postbound/db/__init__.py +72 -0
  12. postbound/db/_db.py +2348 -0
  13. postbound/db/_duckdb.py +785 -0
  14. postbound/db/mysql.py +1195 -0
  15. postbound/db/postgres.py +4216 -0
  16. postbound/experiments/__init__.py +12 -0
  17. postbound/experiments/analysis.py +674 -0
  18. postbound/experiments/benchmarking.py +54 -0
  19. postbound/experiments/ceb.py +877 -0
  20. postbound/experiments/interactive.py +105 -0
  21. postbound/experiments/querygen.py +334 -0
  22. postbound/experiments/workloads.py +980 -0
  23. postbound/optimizer/__init__.py +92 -0
  24. postbound/optimizer/__init__.pyi +73 -0
  25. postbound/optimizer/_cardinalities.py +369 -0
  26. postbound/optimizer/_joingraph.py +1150 -0
  27. postbound/optimizer/dynprog.py +1825 -0
  28. postbound/optimizer/enumeration.py +432 -0
  29. postbound/optimizer/native.py +539 -0
  30. postbound/optimizer/noopt.py +54 -0
  31. postbound/optimizer/presets.py +147 -0
  32. postbound/optimizer/randomized.py +650 -0
  33. postbound/optimizer/tonic.py +1479 -0
  34. postbound/optimizer/ues.py +1607 -0
  35. postbound/qal/__init__.py +343 -0
  36. postbound/qal/_qal.py +9678 -0
  37. postbound/qal/formatter.py +1089 -0
  38. postbound/qal/parser.py +2344 -0
  39. postbound/qal/relalg.py +4257 -0
  40. postbound/qal/transform.py +2184 -0
  41. postbound/shortcuts.py +70 -0
  42. postbound/util/__init__.py +46 -0
  43. postbound/util/_errors.py +33 -0
  44. postbound/util/collections.py +490 -0
  45. postbound/util/dataframe.py +71 -0
  46. postbound/util/dicts.py +330 -0
  47. postbound/util/jsonize.py +68 -0
  48. postbound/util/logging.py +106 -0
  49. postbound/util/misc.py +168 -0
  50. postbound/util/networkx.py +401 -0
  51. postbound/util/numbers.py +438 -0
  52. postbound/util/proc.py +107 -0
  53. postbound/util/stats.py +37 -0
  54. postbound/util/system.py +48 -0
  55. postbound/util/typing.py +35 -0
  56. postbound/vis/__init__.py +5 -0
  57. postbound/vis/fdl.py +69 -0
  58. postbound/vis/graphs.py +48 -0
  59. postbound/vis/optimizer.py +538 -0
  60. postbound/vis/plots.py +84 -0
  61. postbound/vis/tonic.py +70 -0
  62. postbound/vis/trees.py +105 -0
  63. postbound-0.19.0.dist-info/METADATA +355 -0
  64. postbound-0.19.0.dist-info/RECORD +67 -0
  65. postbound-0.19.0.dist-info/WHEEL +5 -0
  66. postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
  67. postbound-0.19.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,734 @@
1
+ """Pre-checks make sure that optimization strategies and input can be optimized as indicated.
2
+
3
+ These checks should prevent the optimization of queries that contain features that the optimization algorithm does not support,
4
+ as well as the usage of optimization algorithms that make decisions that the target database cannot enforce.
5
+
6
+ The `OptimizationPreCheck` defines the abstract interface that all checks should adhere to.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import abc
12
+ from collections.abc import Callable, Iterable
13
+ from dataclasses import dataclass
14
+ from typing import Optional
15
+
16
+ import networkx as nx
17
+
18
+ from . import util
19
+ from ._core import PhysicalOperator
20
+ from ._hints import HintType
21
+ from .db._db import Database
22
+ from .qal._qal import (
23
+ AbstractPredicate,
24
+ BasePredicate,
25
+ BinaryPredicate,
26
+ ColumnExpression,
27
+ CompoundOperator,
28
+ CompoundPredicate,
29
+ DirectTableSource,
30
+ ExplicitFromClause,
31
+ From,
32
+ ImplicitFromClause,
33
+ ImplicitSqlQuery,
34
+ JoinTableSource,
35
+ LogicalOperator,
36
+ SqlQuery,
37
+ SubqueryTableSource,
38
+ TableSource,
39
+ ValuesTableSource,
40
+ )
41
+
42
# Identifiers that checks put into `PreCheckResult.failure_reason` to describe why a query was rejected.

# Structure of the FROM clause
ImplicitFromClauseFailure = "NO_IMPLICIT_FROM_CLAUSE"
CrossProductFailure = "CROSS_PRODUCT"
VirtualTablesFailure = "VIRTUAL_TABLES"

# Shape of the join predicates / join types
EquiJoinFailure = "NON_EQUI_JOIN"
InnerJoinFailure = "NON_INNER_JOIN"
ConjunctiveJoinFailure = "NON_CONJUNCTIVE_JOIN"
JoinPredicateFailure = "BAD_JOIN_PREDICATE"

# Nested queries
SubqueryFailure = "SUBQUERY"
DependentSubqueryFailure = "DEPENDENT_SUBQUERY"
51
+
52
+
53
@dataclass
class PreCheckResult:
    """Wrapper for a validation result.

    The result is used in two different ways: to model the check for supported database systems for optimization
    strategies and to model the check for supported queries for optimization strategies.

    The `ensure_all_passed` method can be used to quickly assert that no problems occurred.

    Attributes
    ----------
    passed : bool
        Indicates whether the check passed, i.e. ``True`` if no problems were detected
    failure_reason : str | list[str], optional
        Gives details about the problem(s) that were detected. Empty if the check passed.
    """

    passed: bool = True
    failure_reason: str | list[str] = ""

    @staticmethod
    def with_all_passed() -> PreCheckResult:
        """Generates a check result without any problems.

        Returns
        -------
        PreCheckResult
            The check result
        """
        return PreCheckResult()

    @staticmethod
    def merge(checks: Iterable[PreCheckResult]) -> PreCheckResult:
        """Merges multiple check results into a single result.

        The result is passed if all input checks are passed. If any of the checks failed, the failure reasons of all
        failed checks are merged into a single list.

        Parameters
        ----------
        checks : Iterable[PreCheckResult]
            The check results to merge

        Returns
        -------
        PreCheckResult
            The merged check result
        """
        failures: list[str] = []
        any_failed = False
        for check in checks:
            if check.passed:
                continue
            # Track failure explicitly: a failed check may carry an empty reason list, which must not turn the
            # merged result into a success.
            any_failed = True
            failures.extend(util.enlist(check.failure_reason))
        return (
            PreCheckResult.with_failure(failures)
            if any_failed
            else PreCheckResult.with_all_passed()
        )

    @staticmethod
    def with_failure(failure: str | list[str]) -> PreCheckResult:
        """Generates a check result for a specific failure.

        Parameters
        ----------
        failure : str | list[str]
            The failure message(s)

        Returns
        -------
        PreCheckResult
            The check result
        """
        return PreCheckResult(False, failure)

    def ensure_all_passed(self, context: SqlQuery | Database | None = None) -> None:
        """Raises an error if the check contains any failures.

        Depending on the context, a more specific error can be raised. The context is used to infer whether an
        optimization strategy does not work on a database system, or whether an input query is not supported by an
        optimization strategy.

        Parameters
        ----------
        context : SqlQuery | Database | None, optional
            An indicator of the kind of check that was performed. This influences the kind of error that will be
            raised in case of failure. Defaults to ``None`` if no further context is available.

        Raises
        ------
        util.StateError
            In case of failure if there is no additional context available, or if the context is neither an SQL
            query nor a database interface
        UnsupportedQueryError
            In case of failure if the context is an SQL query
        UnsupportedSystemError
            In case of failure if the context is a database interface
        """
        if self.passed:
            return
        if context is not None:
            if isinstance(context, SqlQuery):
                raise UnsupportedQueryError(context, self.failure_reason)
            if isinstance(context, Database):
                raise UnsupportedSystemError(context, self.failure_reason)
        # No (recognized) context: raise a generic error rather than silently ignoring the failure.
        raise util.StateError(f"Pre check failed {self._generate_failure_str()}")

    def _generate_failure_str(self) -> str:
        """Creates a nice string of the failure messages from `failure_reason`s.

        Returns
        -------
        str
            The failure message, enclosed in brackets. Empty if there are no failure reasons.

        Raises
        ------
        ValueError
            If `failure_reason` is neither a string nor an iterable
        """
        if not self.failure_reason:
            return ""
        elif isinstance(self.failure_reason, str):
            inner_contents = self.failure_reason
        elif isinstance(self.failure_reason, Iterable):
            # str() each entry: some checks report non-string objects (e.g. unsupported hints) as reasons
            inner_contents = " | ".join(str(reason) for reason in self.failure_reason)
        else:
            raise ValueError(
                "Unexpected failure reason type: " + str(self.failure_reason)
            )
        return f"[{inner_contents}]"
176
+
177
+
178
class UnsupportedQueryError(RuntimeError):
    """Error to indicate that a specific query cannot be optimized by a selected algorithm.

    Parameters
    ----------
    query : SqlQuery
        The unsupported query
    features : str | list[str], optional
        The features of the query that are unsupported. Defaults to an empty string. A list of features is joined
        into a single comma-separated string.

    Attributes
    ----------
    query : SqlQuery
        The unsupported query
    features : str
        The unsupported features as a single (possibly empty) string
    """

    def __init__(self, query: SqlQuery, features: str | list[str] = "") -> None:
        if isinstance(features, list):
            # str() each entry so that non-string failure reasons do not crash the error construction
            features = ", ".join(str(feature) for feature in features)
        features_str = f" [{features}]" if features else ""

        super().__init__(f"Query contains unsupported features{features_str}: {query}")
        self.query = query
        self.features = features
197
+
198
+
199
class UnsupportedSystemError(RuntimeError):
    """Error to indicate that a selected query plan cannot be enforced on a target system.

    Parameters
    ----------
    db_instance : Database
        The database system without a required feature
    reason : str | list[str], optional
        The features that are not supported. Defaults to an empty string. A list of reasons is rendered as a
        comma-separated string in the error message.

    Attributes
    ----------
    db_system : Database
        The database system without a required feature
    reason : str | list[str]
        The reason(s) exactly as they were supplied
    """

    def __init__(self, db_instance: Database, reason: str | list[str] = "") -> None:
        if isinstance(reason, list):
            # `PreCheckResult.ensure_all_passed` forwards list-valued failure reasons; render them readably
            # instead of as a Python list repr. str() each entry since reasons may be e.g. hint objects.
            reason_str = ", ".join(str(entry) for entry in reason)
        else:
            reason_str = reason

        error_msg = f"Unsupported database system: {db_instance}"
        if reason_str:
            error_msg += f" ({reason_str})"
        super().__init__(error_msg)
        self.db_system = db_instance
        self.reason = reason
217
+
218
+
219
class OptimizationPreCheck(abc.ABC):
    """Common interface of all pre-checks.

    A pre-check can validate two different things: whether an input query only uses features that an optimization
    strategy can handle (`check_supported_query`), and whether a target database system provides the features that an
    optimization strategy requires (`check_supported_database_system`). Both hooks accept all input by default, so
    concrete checks only override the hooks they need. Every concrete check must implement `describe`.

    Parameters
    ----------
    name : str
        The name of the check. It should describe which features the check tests and is used to represent the checks
        that are present in an optimization pipeline. Checks of the same type with the same name compare equal.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        """Validates that a query does not contain features that an optimization strategy cannot handle.

        Typical examples of such features are non-equi join predicates, dependent subqueries or aggregations. The
        default implementation accepts every query.

        Parameters
        ----------
        query : SqlQuery
            The query to check

        Returns
        -------
        PreCheckResult
            Whether the check passed, along with the failures if it did not.
        """
        return PreCheckResult.with_all_passed()

    def check_supported_database_system(
        self, database_instance: Database
    ) -> PreCheckResult:
        """Validates that a database system provides all features that an optimization strategy requires.

        Typical examples of such features are support for cardinality hints or for specific operators. The default
        implementation accepts every database system.

        Parameters
        ----------
        database_instance : Database
            The database to check

        Returns
        -------
        PreCheckResult
            Whether the check passed, along with the failures if it did not.
        """
        return PreCheckResult.with_all_passed()

    @abc.abstractmethod
    def describe(self) -> dict:
        """Provides a JSON-serializable representation of this check and its important parameters.

        Returns
        -------
        dict
            The description

        See Also
        --------
        OptimizationPipeline.describe
        """
        raise NotImplementedError

    def __contains__(self, item: object) -> bool:
        # An atomic check contains nothing but itself.
        return item == self

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return self.name == other.name

    def __repr__(self) -> str:
        return f"OptimizationPreCheck [{self.name}]"

    def __str__(self) -> str:
        return self.name
301
+
302
+
303
class EmptyPreCheck(OptimizationPreCheck):
    """No-op check: every query (and, via the base class default, every database system) is accepted."""

    def __init__(self) -> None:
        super().__init__("empty")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        # Nothing to validate, always succeed.
        return PreCheckResult.with_all_passed()

    def describe(self) -> dict:
        return {"name": "no_check"}
314
+
315
+
316
class CompoundCheck(OptimizationPreCheck):
    """A compound check combines an arbitrary number of base checks and asserts that all of them are satisfied.

    Both the query check and the database system check are delegated to all child checks. If multiple checks fail,
    the `failure_reason` of the result contains all individual failure reasons.

    Nested compound checks are flattened into their children and empty checks are dropped.

    Parameters
    ----------
    checks : Iterable[OptimizationPreCheck]
        The checks that must all be passed.
    """

    def __init__(self, checks: Iterable[OptimizationPreCheck]) -> None:
        super().__init__("compound-check")
        flattened: list[OptimizationPreCheck] = []
        for check in checks:
            if isinstance(check, EmptyPreCheck):
                continue
            if isinstance(check, CompoundCheck):
                # The children of an existing compound check are already flat and free of empty checks.
                flattened.extend(check.checks)
            else:
                flattened.append(check)
        self.checks = flattened

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        # merge() only collects reasons of failed checks and wraps single-string reasons into a list, so a reason
        # like "NON_EQUI_JOIN" is never split into individual characters.
        return PreCheckResult.merge(
            check.check_supported_query(query) for check in self.checks
        )

    def check_supported_database_system(
        self, database_instance: Database
    ) -> PreCheckResult:
        # Without this override, database checks of the children (e.g. hint support) would be silently skipped.
        return PreCheckResult.merge(
            check.check_supported_database_system(database_instance)
            for check in self.checks
        )

    def describe(self) -> dict:
        return {"multiple_checks": [check.describe() for check in self.checks]}

    def __contains__(self, item: object) -> bool:
        return super().__contains__(item) or any(
            item in child_check for child_check in self.checks
        )

    def __hash__(self) -> int:
        # Order-insensitive: merge_checks() builds compound checks from sets, so the child order is arbitrary.
        return hash(frozenset(self.checks))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, type(self)) and set(self.checks) == set(other.checks)

    def __str__(self) -> str:
        child_checks_str = "|".join(str(child_check) for child_check in self.checks)
        return f"CompoundCheck [{child_checks_str}]"
367
+
368
+
369
def merge_checks(
    checks: OptimizationPreCheck | Iterable[OptimizationPreCheck] | None,
    *more_checks,
) -> OptimizationPreCheck:
    """Combines all of the supplied checks into one compound check.

    This method is smarter than creating a compound check directly. It eliminates duplicate checks as far as possible,
    flattens compound checks into their children and ignores empty checks and ``None`` values.

    If there is only a single (unique) check, this is returned directly. If there is no check at all, an
    `EmptyPreCheck` is returned.

    Parameters
    ----------
    checks : OptimizationPreCheck | Iterable[OptimizationPreCheck] | None
        The checks to combine
    *more_checks
        Additional checks that should also be included

    Returns
    -------
    OptimizationPreCheck
        A check that combines all of the given checks.
    """
    # Collect candidates from both arguments. An empty `checks` must not short-circuit, otherwise `more_checks`
    # would be discarded.
    if checks is None:
        candidates: set[OptimizationPreCheck] = set()
    elif isinstance(checks, OptimizationPreCheck):
        candidates = {checks}
    else:
        candidates = set(checks)
    candidates.update(more_checks)
    candidates.discard(None)

    merged: set[OptimizationPreCheck] = set()
    for check in candidates:
        if isinstance(check, CompoundCheck):
            merged.update(check.checks)
        else:
            merged.add(check)
    merged = {check for check in merged if not isinstance(check, EmptyPreCheck)}

    if not merged:
        return EmptyPreCheck()
    if len(merged) == 1:
        return next(iter(merged))
    return CompoundCheck(merged)
417
+
418
+
419
class ImplicitQueryPreCheck(OptimizationPreCheck):
    """Check to assert that an input query is an `ImplicitSqlQuery`."""

    def __init__(self) -> None:
        super().__init__("implicit-query")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        if isinstance(query, ImplicitSqlQuery):
            return PreCheckResult(True, "")
        return PreCheckResult(False, ImplicitFromClauseFailure)

    def describe(self) -> dict:
        return {"name": "implicit_query"}
432
+
433
+
434
class CrossProductPreCheck(OptimizationPreCheck):
    """Check to assert that a query does not contain any cross products.

    A query passes if its join graph is connected. A query whose join graph has no nodes at all is treated as free of
    cross products.
    """

    def __init__(self) -> None:
        super().__init__("no-cross-products")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        join_graph = query.predicates().join_graph()
        # networkx raises NetworkXPointlessConcept for the null graph instead of returning a boolean, so the empty
        # join graph has to be handled explicitly.
        no_cross_products = join_graph.number_of_nodes() == 0 or nx.is_connected(
            join_graph
        )
        failure_reason = "" if no_cross_products else CrossProductFailure
        return PreCheckResult(no_cross_products, failure_reason)

    def describe(self) -> dict:
        return {"name": "no_cross_products"}
447
+
448
+
449
class VirtualTablesPreCheck(OptimizationPreCheck):
    """Check to assert that a query does not reference any virtual tables."""

    def __init__(self) -> None:
        super().__init__("no-virtual-tables")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        has_virtual_table = any(tab.virtual for tab in query.tables())
        if has_virtual_table:
            return PreCheckResult(False, VirtualTablesFailure)
        return PreCheckResult(True, "")

    def describe(self) -> dict:
        return {"name": "no_virtual_tables"}
462
+
463
+
464
+ class EquiJoinPreCheck(OptimizationPreCheck):
465
+ """Check to assert that a query only contains equi-joins.
466
+
467
+ This does not restrict the filters in any way. The determination of joins is based on `QueryPredicates.joins`.
468
+ """
469
+
470
+ def __init__(
471
+ self, *, allow_conjunctions: bool = False, allow_nesting: bool = False
472
+ ) -> None:
473
+ super().__init__("equi-joins-only")
474
+ self._allow_conjunctions = allow_conjunctions
475
+ self._allow_nesting = allow_nesting
476
+
477
+ def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
478
+ join_predicates = query.predicates().joins()
479
+ all_passed = all(
480
+ self._perform_predicate_check(join_pred) for join_pred in join_predicates
481
+ )
482
+ failure_reason = "" if all_passed else EquiJoinFailure
483
+ return PreCheckResult(all_passed, failure_reason)
484
+
485
+ def describe(self) -> dict:
486
+ return {
487
+ "name": "equi_joins_only",
488
+ "allow_conjunctions": self._allow_conjunctions,
489
+ "allow_nesting": self._allow_nesting,
490
+ }
491
+
492
+ def _perform_predicate_check(self, predicate: AbstractPredicate) -> bool:
493
+ """Handler method to dispatch to the appropriate check utility depending on the predicate type.
494
+
495
+ Parameters
496
+ ----------
497
+ predicate : AbstractPredicate
498
+ The predicate to check
499
+
500
+ Returns
501
+ -------
502
+ bool
503
+ Whether the predicate passed the check
504
+ """
505
+ if isinstance(predicate, BasePredicate):
506
+ return self._perform_base_predicate_check(predicate)
507
+ elif isinstance(predicate, CompoundPredicate):
508
+ return self._perform_compound_predicate_check(predicate)
509
+ else:
510
+ return False
511
+
512
+ def _perform_base_predicate_check(self, predicate: BasePredicate) -> bool:
513
+ """Handler method to check a single base predicate.
514
+
515
+ Parameters
516
+ ----------
517
+ predicate : BasePredicate
518
+ The predicate to check
519
+
520
+ Returns
521
+ -------
522
+ bool
523
+ Whether the predicate passed the check
524
+ """
525
+ if not isinstance(predicate, BinaryPredicate) or len(predicate.columns()) != 2:
526
+ return False
527
+ if predicate.operation != LogicalOperator.Equal:
528
+ return False
529
+
530
+ if self._allow_nesting:
531
+ return True
532
+ first_is_col = isinstance(predicate.first_argument, ColumnExpression)
533
+ second_is_col = isinstance(predicate.second_argument, ColumnExpression)
534
+ return first_is_col and second_is_col
535
+
536
+ def _perform_compound_predicate_check(self, predicate: CompoundPredicate) -> bool:
537
+ """Handler method to check a compound predicate.
538
+
539
+ Parameters
540
+ ----------
541
+ predicate : CompoundPredicate
542
+ The predicate to check
543
+
544
+ Returns
545
+ -------
546
+ bool
547
+ Whether the predicate passed the check
548
+ """
549
+ if not self._allow_conjunctions:
550
+ return False
551
+ elif predicate.operation != CompoundOperator.And:
552
+ return False
553
+ return all(
554
+ self._perform_predicate_check(child_pred)
555
+ for child_pred in predicate.children
556
+ )
557
+
558
+ def __eq__(self, other: object) -> bool:
559
+ return (
560
+ isinstance(other, type(self))
561
+ and self._allow_conjunctions == other._allow_conjunctions
562
+ and self._allow_nesting == other._allow_nesting
563
+ )
564
+
565
+ def __hash__(self) -> int:
566
+ return hash((self.name, self._allow_conjunctions, self._allow_nesting))
567
+
568
+
569
class InnerJoinPreCheck(OptimizationPreCheck):
    """Check to assert that a query only contains inner joins.

    Queries without a FROM clause and queries with an implicit FROM clause pass without further inspection. For the
    other FROM clause types, the table sources are traversed recursively (including subqueries that appear as table
    sources) and every join whose type is not ``"INNER"`` is reported as `InnerJoinFailure`.
    """

    def __init__(self) -> None:
        super().__init__("inner-joins-only")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        # No FROM clause: nothing to inspect
        if not query.from_clause:
            return PreCheckResult.with_all_passed()

        # NOTE(review): if ImplicitFromClause / ExplicitFromClause are subclasses of From, the case order matters:
        # the specific clause types must be matched before the generic From case. Confirm against the QAL hierarchy.
        match query.from_clause:
            case ImplicitFromClause():
                # Implicit FROM clauses are accepted without further inspection
                return PreCheckResult.with_all_passed()
            case ExplicitFromClause(join):
                return self._check_table_source(join)
            case From(items):
                # Every entry is checked individually, failures of all entries are collected
                checks = [self._check_table_source(entry) for entry in items]
                return PreCheckResult.merge(checks)
            case _:
                raise ValueError(f"Unknown FROM clause type: {query.from_clause}")

    def describe(self) -> dict:
        return {"name": "inner_joins_only"}

    def _check_table_source(self, source: TableSource) -> PreCheckResult:
        """Handler method to check a single table source.

        Parameters
        ----------
        source : TableSource
            The table source to check. Join sources are checked recursively on both sides.

        Returns
        -------
        PreCheckResult
            The merged result for the source and all nested sources

        Raises
        ------
        ValueError
            If the table source has an unknown type
        """
        match source:
            case DirectTableSource() | ValuesTableSource():
                # Plain tables and VALUES lists cannot contain joins
                return PreCheckResult.with_all_passed()
            case SubqueryTableSource(subquery):
                # Subqueries in the FROM clause are validated like a regular query
                return self.check_supported_query(subquery)
            case JoinTableSource(left, right, _, join_type):
                # NOTE(review): join_type is compared against the string literal "INNER"; confirm this matches the
                # representation used by JoinTableSource (e.g. if it is an enum, the comparison may never be equal).
                checks = (
                    [PreCheckResult.with_failure(InnerJoinFailure)]
                    if join_type != "INNER"
                    else []
                )
                checks.extend(
                    [self._check_table_source(left), self._check_table_source(right)]
                )
                return PreCheckResult.merge(checks)
            case _:
                raise ValueError(f"Unknown table source type: {source}")
612
+
613
+
614
class SubqueryPreCheck(OptimizationPreCheck):
    """Check to assert that a query contains no subqueries at all."""

    def __init__(self) -> None:
        super().__init__("no-subqueries")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        if query.subqueries():
            return PreCheckResult.with_failure(SubqueryFailure)
        return PreCheckResult.with_all_passed()

    def describe(self) -> dict:
        return {"name": "no_subqueries"}
629
+
630
+
631
class DependentSubqueryPreCheck(OptimizationPreCheck):
    """Check to assert that none of the subqueries of a query is dependent."""

    def __init__(self) -> None:
        super().__init__("no-dependent-subquery")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        for nested in query.subqueries():
            if nested.is_dependent():
                return PreCheckResult(False, DependentSubqueryFailure)
        return PreCheckResult(True, "")

    def describe(self) -> dict:
        return {"name": "no_dependent_subquery"}
644
+
645
+
646
class SetOperationsPreCheck(OptimizationPreCheck):
    """Check to assert that a query is not a set query (**UNION**, **EXCEPT**, etc.)."""

    def __init__(self) -> None:
        super().__init__("no-set-operations")

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        if query.is_set_query():
            return PreCheckResult(False, "SET_OPERATION")
        return PreCheckResult(True, "")

    def describe(self) -> dict:
        return {"name": "no_set_operations"}
659
+
660
+
661
class SupportedHintCheck(OptimizationPreCheck):
    """Check to assert that a number of operators are supported by a database system.

    Two hint checks are only considered equal if they require the same set of hints. Otherwise, merging several hint
    checks (e.g. via `merge_checks`) would deduplicate them by their shared name and drop required hints.

    Parameters
    ----------
    hints : HintType | PhysicalOperator | Iterable[HintType | PhysicalOperator]
        The operators and hints that have to be supported by the database system. Can be either a single hint, or an
        iterable of hints.

    See Also
    --------
    HintService.supports_hint
    """

    def __init__(
        self, hints: HintType | PhysicalOperator | Iterable[HintType | PhysicalOperator]
    ) -> None:
        super().__init__("database-check")
        self._features = util.enlist(hints)

    def check_supported_database_system(
        self, database_instance: Database
    ) -> PreCheckResult:
        # Fetch the hint service once instead of once per feature.
        hint_service = database_instance.hinting()
        failures = [
            hint for hint in self._features if not hint_service.supports_hint(hint)
        ]
        return PreCheckResult(not failures, failures)

    def describe(self) -> dict:
        # Copy so that callers cannot mutate the check's internal state through the description.
        return {"name": "database_operator_support", "features": list(self._features)}

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self.name == other.name
            and set(self._features) == set(other._features)
        )

    def __hash__(self) -> int:
        return hash((self.name, frozenset(self._features)))
694
+
695
+
696
class CustomCheck(OptimizationPreCheck):
    """Check to quickly implement arbitrary one-off checks.

    The custom check somewhat clashes with directly implementing the `OptimizationPreCheck` interface. The latter is
    generally preferred since it is more readable and easier to understand. However, the custom check can be useful for
    checks that will not be used in multiple places and are not worth the effort of creating a separate class.

    Two custom checks are only equal if they share the name and the very same check callables. This prevents distinct
    checks that use the default name from being deduplicated when checks are merged.

    Parameters
    ----------
    name : str, optional
        The name of the check. It is heavily recommended to supply a descriptive name, even though a default value
        exists.
    query_check : Optional[Callable[[SqlQuery], PreCheckResult]], optional
        Check to apply to each query. If omitted, all queries pass.
    db_check : Optional[Callable[[Database], PreCheckResult]], optional
        Check to apply to the database. If omitted, all database systems pass.
    """

    def __init__(
        self,
        name: str = "custom-check",
        *,
        query_check: Optional[Callable[[SqlQuery], PreCheckResult]] = None,
        db_check: Optional[Callable[[Database], PreCheckResult]] = None,
    ) -> None:
        super().__init__(name)
        self._query_check = query_check
        self._db_check = db_check

    def check_supported_query(self, query: SqlQuery) -> PreCheckResult:
        if self._query_check is None:
            return PreCheckResult.with_all_passed()
        return self._query_check(query)

    def check_supported_database_system(
        self, database_instance: Database
    ) -> PreCheckResult:
        if self._db_check is None:
            return PreCheckResult.with_all_passed()
        return self._db_check(database_instance)

    def describe(self) -> dict:
        # `describe` is abstract in OptimizationPreCheck; without this implementation the class cannot be
        # instantiated at all.
        return {"name": self.name}

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, type(self))
            and self.name == other.name
            and self._query_check == other._query_check
            and self._db_check == other._db_check
        )

    def __hash__(self) -> int:
        return hash((self.name, self._query_check, self._db_check))