PostBOUND 0.19.0 (postbound-0.19.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. postbound/__init__.py +211 -0
  2. postbound/_base.py +6 -0
  3. postbound/_bench.py +1012 -0
  4. postbound/_core.py +1153 -0
  5. postbound/_hints.py +1373 -0
  6. postbound/_jointree.py +1079 -0
  7. postbound/_pipelines.py +1121 -0
  8. postbound/_qep.py +1986 -0
  9. postbound/_stages.py +876 -0
  10. postbound/_validation.py +734 -0
  11. postbound/db/__init__.py +72 -0
  12. postbound/db/_db.py +2348 -0
  13. postbound/db/_duckdb.py +785 -0
  14. postbound/db/mysql.py +1195 -0
  15. postbound/db/postgres.py +4216 -0
  16. postbound/experiments/__init__.py +12 -0
  17. postbound/experiments/analysis.py +674 -0
  18. postbound/experiments/benchmarking.py +54 -0
  19. postbound/experiments/ceb.py +877 -0
  20. postbound/experiments/interactive.py +105 -0
  21. postbound/experiments/querygen.py +334 -0
  22. postbound/experiments/workloads.py +980 -0
  23. postbound/optimizer/__init__.py +92 -0
  24. postbound/optimizer/__init__.pyi +73 -0
  25. postbound/optimizer/_cardinalities.py +369 -0
  26. postbound/optimizer/_joingraph.py +1150 -0
  27. postbound/optimizer/dynprog.py +1825 -0
  28. postbound/optimizer/enumeration.py +432 -0
  29. postbound/optimizer/native.py +539 -0
  30. postbound/optimizer/noopt.py +54 -0
  31. postbound/optimizer/presets.py +147 -0
  32. postbound/optimizer/randomized.py +650 -0
  33. postbound/optimizer/tonic.py +1479 -0
  34. postbound/optimizer/ues.py +1607 -0
  35. postbound/qal/__init__.py +343 -0
  36. postbound/qal/_qal.py +9678 -0
  37. postbound/qal/formatter.py +1089 -0
  38. postbound/qal/parser.py +2344 -0
  39. postbound/qal/relalg.py +4257 -0
  40. postbound/qal/transform.py +2184 -0
  41. postbound/shortcuts.py +70 -0
  42. postbound/util/__init__.py +46 -0
  43. postbound/util/_errors.py +33 -0
  44. postbound/util/collections.py +490 -0
  45. postbound/util/dataframe.py +71 -0
  46. postbound/util/dicts.py +330 -0
  47. postbound/util/jsonize.py +68 -0
  48. postbound/util/logging.py +106 -0
  49. postbound/util/misc.py +168 -0
  50. postbound/util/networkx.py +401 -0
  51. postbound/util/numbers.py +438 -0
  52. postbound/util/proc.py +107 -0
  53. postbound/util/stats.py +37 -0
  54. postbound/util/system.py +48 -0
  55. postbound/util/typing.py +35 -0
  56. postbound/vis/__init__.py +5 -0
  57. postbound/vis/fdl.py +69 -0
  58. postbound/vis/graphs.py +48 -0
  59. postbound/vis/optimizer.py +538 -0
  60. postbound/vis/plots.py +84 -0
  61. postbound/vis/tonic.py +70 -0
  62. postbound/vis/trees.py +105 -0
  63. postbound-0.19.0.dist-info/METADATA +355 -0
  64. postbound-0.19.0.dist-info/RECORD +67 -0
  65. postbound-0.19.0.dist-info/WHEEL +5 -0
  66. postbound-0.19.0.dist-info/licenses/LICENSE.txt +202 -0
  67. postbound-0.19.0.dist-info/top_level.txt +1 -0
postbound/db/_duckdb.py
@@ -0,0 +1,785 @@
+ # We name this file _duckdb instead of duckdb to avoid conflicts with the official duckdb package. Do not change this!
+ # The module is available in the __init__ of the db package under the duckdb name. This solves our problems for now.
+ from __future__ import annotations
+
+ import concurrent.futures
+ import json
+ import math
+ import textwrap
+ import time
+ import warnings
+ from collections import UserString
+ from collections.abc import Iterable, Sequence
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Optional
+
+ from .. import optimizer, qal
+ from .._core import (
+     Cardinality,
+     ColumnReference,
+     Cost,
+     JoinOperator,
+     PhysicalOperator,
+     ScanOperator,
+     TableReference,
+ )
+ from .._hints import (
+     HintType,
+     PhysicalOperatorAssignment,
+     PlanParameterization,
+ )
+ from .._jointree import JoinTree
+ from .._qep import QueryPlan
+ from ..qal import transform
+ from ..qal._qal import SqlQuery
+ from ..util import Version, jsondict
+ from ._db import (
+     Cursor,
+     Database,
+     DatabasePool,
+     DatabaseSchema,
+     DatabaseStatistics,
+     HintService,
+     OptimizerInterface,
+     ResultSet,
+     UnsupportedDatabaseFeatureError,
+     simplify_result_set,
+ )
+ from .postgres import PostgresLimitClause
+
+
+ class DuckDBInterface(Database):
+     def __init__(
+         self,
+         db: Path,
+         *,
+         system_name: str = "DuckDB",
+         cache_enabled: bool = False,
+     ) -> None:
+         import duckdb
+
+         super().__init__(system_name=system_name, cache_enabled=cache_enabled)
+
+         self._dbfile = db
+
+         self._cur = duckdb.connect(db)
+         self._last_query_runtime = math.nan
+
+         self._stats = DuckDBStatistics(self)
+         self._schema = DuckDBSchema(self)
+         self._optimizer = DuckDBOptimizer(self)
+         self._hinting = DuckDBHintService(self)
+
+         self._timeout_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+     def schema(self) -> DuckDBSchema:
+         return self._schema
+
+     def statistics(self) -> DatabaseStatistics:
+         return self._stats
+
+     def hinting(self) -> HintService:
+         return self._hinting
+
+     def optimizer(self) -> OptimizerInterface:
+         return self._optimizer
+
+     def execute_query(
+         self,
+         query: SqlQuery | str,
+         *,
+         cache_enabled: Optional[bool] = None,
+         raw: bool = False,
+         timeout: Optional[float] = None,
+     ) -> Any:
+         if (
+             isinstance(query, SqlQuery)
+             and query.hints
+             and query.hints.preparatory_statements
+         ):
+             for preparatory_statement in query.hints.preparatory_statements:
+                 self._cur.execute(preparatory_statement)
+             query = transform.drop_hints(query, preparatory_statements_only=True)
+         if isinstance(query, SqlQuery):
+             query = self._hinting.format_query(query)
+         elif isinstance(query, UserString):
+             query = str(query)
+
+         if cache_enabled:
+             cached_res = self._query_cache.get(query)
+             if cached_res is not None:
+                 return cached_res if raw else simplify_result_set(cached_res)
+
+         if timeout is not None:
+             raw_result = self.execute_with_timeout(query, timeout=timeout)
+             if raw_result is None:
+                 raise TimeoutError(query)
+         else:
+             start_time = time.perf_counter_ns()
+             self._cur.execute(query)
+             end_time = time.perf_counter_ns()
+
+             raw_result = self._cur.fetchall()
+             self._last_query_runtime = (
+                 end_time - start_time
+             ) / 10**9  # convert to seconds
+
+         if cache_enabled:
+             self._query_cache[query] = raw_result
+
+         return raw_result if raw else simplify_result_set(raw_result)
+
+     def execute_with_timeout(
+         self, query: SqlQuery | str, *, timeout: float = 60.0
+     ) -> Optional[ResultSet]:
+         if isinstance(query, SqlQuery):
+             query = self._hinting.format_query(query)
+
+         promise = self._timeout_executor.submit(self._execute_worker, query)
+         try:
+             result_set = promise.result(timeout=timeout)
+             return result_set
+         except concurrent.futures.TimeoutError:
+             self._cur.interrupt()
+             promise.result()
+
+             # Only update the last query runtime now: the worker might terminate while we try to cancel it and update
+             # the runtime itself. This way we overwrite any measurement that the worker might have produced.
+             self._last_query_runtime = math.inf
+             return None
+
+     def last_query_runtime(self) -> float:
+         return self._last_query_runtime
+
+     def time_query(self, query: SqlQuery, *, timeout: Optional[float] = None) -> float:
+         self.execute_query(query, cache_enabled=False, raw=True, timeout=timeout)
+         return self.last_query_runtime()
+
+     def database_name(self) -> str:
+         self._cur.execute("SELECT CURRENT_DATABASE();")
+         db_name = self._cur.fetchone()[0]
+         return db_name
+
+     def database_system_version(self) -> Version:
+         self._cur.execute("SELECT version();")
+         raw_ver: str = self._cur.fetchone()[0]
+         raw_ver = raw_ver.removeprefix("v")
+         raw_ver = raw_ver.split("-")[0]  # remove the build information
+         return Version(raw_ver)
+
+     def cursor(self) -> Cursor:
+         return self._cur
+
+     def close(self) -> None:
+         self._cur.close()
+
+     def reconnect(self) -> None:
+         import duckdb
+
+         self._cur = duckdb.connect(self._dbfile)
+
+     def reset_connection(self) -> None:
+         import duckdb
+
+         try:
+             self.close()
+         except Exception:
+             pass
+
+         self._cur = duckdb.connect(self._dbfile)
+
+     def describe(self) -> jsondict:
+         base_info = {
+             "system_name": self.database_system_name(),
+             "system_version": self.database_system_version(),
+             "database": self.database_name(),
+         }
+
+         schema_info: list[jsondict] = []
+         for table in self._schema.tables():
+             column_info: list[jsondict] = []
+
+             for column in self._schema.columns(table):
+                 column_info.append(
+                     {
+                         "column": str(column),
+                         "indexed": self._schema.has_index(column),
+                         "foreign_keys": self._schema.foreign_keys_on(column),
+                     }
+                 )
+
+             schema_info.append(
+                 {
+                     "table": str(table),
+                     "n_rows": self.statistics().total_rows(table, emulated=True),
+                     "columns": column_info,
+                     "primary_key": self._schema.primary_key_column(table),
+                 }
+             )
+
+         base_info["schema_info"] = schema_info
+         return base_info
+
+     def _execute_worker(self, query: str) -> ResultSet:
+         start_time = time.perf_counter_ns()
+         self._cur.execute(query)
+         end_time = time.perf_counter_ns()
+         result_set = self._cur.fetchall()
+         self._last_query_runtime = (end_time - start_time) / 10**9  # convert to seconds
+         return result_set
+
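In practice, a DuckDBInterface is usually obtained through the connect() helper defined at the end of this module. A minimal usage sketch, assuming an example database file and table name that are not part of the package:

    db = connect("imdb.duckdb")  # illustrative file name
    try:
        result = db.execute_query("SELECT count(*) FROM title", timeout=30.0)
        print(result, db.last_query_runtime())
    except TimeoutError:
        # after a timeout, last_query_runtime() reports math.inf
        pass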
+
+ class DuckDBSchema(DatabaseSchema):
+     def __init__(self, db: DuckDBInterface) -> None:
+         super().__init__(db, prep_placeholder="?")
+
+     def has_secondary_index(self, column: ColumnReference) -> bool:
+         if not column.is_bound():
+             raise ValueError(
+                 f"Cannot check index status for {column}: Column is not bound to a table"
+             )
+
+         schema_placeholder = "?" if column.table.schema else "current_schema()"
+
+         query_template = textwrap.dedent(f"""
+             SELECT ddbi.index_name
+             FROM duckdb_indexes() ddbi
+             WHERE ddbi.table_name = ?
+                 AND ltrim(rtrim(ddbi.expressions, ']'), '[') = ?
+                 AND ddbi.database_name = current_database()
+                 AND ddbi.schema_name = {schema_placeholder}
+             """)
+
+         params = [column.table.full_name, column.name]
+         if column.table.schema:
+             params.append(column.table.schema)
+
+         cur = self._db.cursor()
+         cur.execute(query_template, parameters=params)
+         result_set = cur.fetchone()
+
+         return result_set is not None
+
+     def indexes_on(self, column: ColumnReference) -> set[str]:
+         if not column.is_bound():
+             raise ValueError(
+                 f"Cannot retrieve indexes for {column}: Column is not bound to a table"
+             )
+
+         schema_placeholder = "?" if column.table.schema else "current_schema()"
+
+         # The query template is much more complicated here, due to the different semantics of the constraint_column_usage
+         # view. For UNIQUE constraints, the column is the column that is constrained. However, for foreign keys, the column
+         # is the column that is being referenced.
+         # Notice that this template is different from the vanilla template provided by the default implementation: in the
+         # first part, we query from duckdb_indexes() instead of information_schema.key_column_usage!
+
+         query_template = textwrap.dedent(f"""
+             SELECT ddbi.index_name
+             FROM duckdb_indexes() ddbi
+             WHERE ddbi.table_name = ?
+                 AND ltrim(rtrim(ddbi.expressions, ']'), '[') = ?
+                 AND ddbi.database_name = current_database()
+                 AND ddbi.schema_name = {schema_placeholder}
+             UNION
+             SELECT tc.constraint_name
+             FROM information_schema.table_constraints tc
+                 JOIN information_schema.constraint_column_usage ccu
+                     ON tc.constraint_name = ccu.constraint_name
+                         AND tc.table_name = ccu.table_name
+                         AND tc.table_schema = ccu.table_schema
+                         AND tc.table_catalog = ccu.table_catalog
+             WHERE tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
+                 AND ccu.table_name = ?
+                 AND ccu.column_name = ?
+                 AND ccu.table_catalog = current_database()
+                 AND ccu.table_schema = {schema_placeholder}
+             """)
+
+         # Due to the UNION query, we need to repeat the placeholders. While the implementation is definitely not elegant,
+         # this solution is arguably better than relying on named parameters which might or might not be supported by the
+         # target database.
+         params = [column.table.full_name, column.name]
+         if column.table.schema:
+             params.append(column.table.schema)
+         params.extend([column.table.full_name, column.name])
+         if column.table.schema:
+             params.append(column.table.schema)
+
+         cur = self._db.cursor()
+         cur.execute(query_template, params)
+         result_set = cur.fetchall()
+
+         return {row[0] for row in result_set}
+
+
+ class DuckDBStatistics(DatabaseStatistics):
+     def __init__(
+         self,
+         db: DuckDBInterface,
+         *,
+         emulated: bool = False,
+         enable_emulation_fallback: bool = True,
+         cache_enabled: Optional[bool] = True,
+     ) -> None:
+         super().__init__(
+             db,
+             emulated=emulated,
+             enable_emulation_fallback=enable_emulation_fallback,
+             cache_enabled=cache_enabled,
+         )
+
+     def _retrieve_total_rows_from_stats(self, table: TableReference) -> Optional[int]:
+         schema_placeholder = "?" if table.schema else "current_schema()"
+
+         query_template = textwrap.dedent(f"""
+             SELECT estimated_size
+             FROM duckdb_tables()
+             WHERE table_name = ?
+                 AND database_name = current_database()
+                 AND schema_name = {schema_placeholder}
+             """)
+
+         params = [table.full_name]
+         if table.schema:
+             params.append(table.schema)
+
+         self._db.cursor().execute(query_template, parameters=params)
+         result_set = self._db.cursor().fetchone()
+
+         if not result_set:
+             return None
+
+         return result_set[0] if result_set[0] is not None else None
+
+     def _retrieve_distinct_values_from_stats(
+         self, column: ColumnReference
+     ) -> Optional[int]:
+         raise UnsupportedDatabaseFeatureError(
+             self._db, "distinct value count statistics."
+         )
+
+     def _retrieve_min_max_values_from_stats(
+         self, column: ColumnReference
+     ) -> Optional[tuple[Any, Any]]:
+         raise UnsupportedDatabaseFeatureError(self._db, "min/max value statistics.")
+
+     def _retrieve_most_common_values_from_stats(
+         self, column: ColumnReference, k: int
+     ) -> Sequence[tuple[Any, int]]:
+         raise UnsupportedDatabaseFeatureError(self._db, "most common value statistics.")
+
+
+ def parse_duckdb_plan(raw_plan: dict) -> QueryPlan:
+     node_type = raw_plan.get("name") or raw_plan.get("operator_name")
+     if not node_type:
+         assert len(raw_plan["children"]) == 1, (
+             "Expected a single child for the root operator"
+         )
+         return parse_duckdb_plan(raw_plan["children"][0])
+
+     if node_type == "EXPLAIN" or node_type == "EXPLAIN_ANALYZE":
+         assert len(raw_plan["children"]) == 1, (
+             "Expected a single child for EXPLAIN operator"
+         )
+         return parse_duckdb_plan(raw_plan["children"][0])
+
+     extras: dict = raw_plan.get("extra_info", {})
+     match node_type:
+         case "HASH_JOIN":
+             operator = JoinOperator.HashJoin
+         case (
+             "SEQ_SCAN" | "SEQ_SCAN "  # DuckDB has a weird typo in the SEQ_SCAN label
+         ) if extras.get("Type", "") == "Sequential Scan":
+             operator = ScanOperator.SequentialScan
+         case (
+             "SEQ_SCAN" | "SEQ_SCAN "  # DuckDB has a weird typo in the SEQ_SCAN label
+         ) if extras.get("Type", "") == "Index Scan":
+             operator = ScanOperator.IndexScan
+         case "PROJECTION" | "FILTER" | "UNGROUPED_AGGREGATE" | "PERFECT_HASH_GROUP_BY":
+             operator = None
+         case _:
+             warnings.warn(f"Unknown node type: {node_type}, ({extras})")
+             operator = None
+     if operator is not None:
+         node_type = operator
+
+     base_table = None
+     if operator and operator in ScanOperator:
+         tab = extras.get("Table", "")
+         if tab:
+             base_table = TableReference(tab)
+
+     card_est = float(
+         extras.get("Estimated Cardinality", math.nan)
+     )  # Estimated Cardinality is a string for some reason..
+     card_act = raw_plan.get("operator_cardinality", math.nan)
+
+     children = [parse_duckdb_plan(child) for child in raw_plan.get("children", [])]
+
+     own_runtime = extras.get("operator_timing", math.nan)
+     total_runtime = own_runtime + sum(child.execution_time for child in children)
+
+     return QueryPlan(
+         node_type,
+         operator=operator,
+         children=children,
+         base_table=base_table,
+         estimated_cardinality=card_est,
+         actual_cardinality=card_act,
+         execution_time=total_runtime,
+     )
+
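For orientation, parse_duckdb_plan() only reads a handful of keys from DuckDB's JSON EXPLAIN output. A sketch of the node shape it expects, with purely illustrative values (the exact JSON emitted by DuckDB differs between versions and between plain EXPLAIN and EXPLAIN ANALYZE):

    raw_node = {
        "name": "SEQ_SCAN ",  # note the trailing space in DuckDB's label
        "extra_info": {
            "Type": "Sequential Scan",
            "Table": "title",                    # illustrative table name
            "Estimated Cardinality": "2500000",  # reported as a string
            "operator_timing": 0.012,            # actual values come from EXPLAIN ANALYZE
        },
        "operator_cardinality": 2528312,         # actual values come from EXPLAIN ANALYZE
        "children": [],
    }
    plan = parse_duckdb_plan(raw_node)  # QueryPlan with a SequentialScan of "title"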
+
+ class DuckDBOptimizer(OptimizerInterface):
+     def __init__(self, db: DuckDBInterface) -> None:
+         self._db = db
+
+     def query_plan(self, query: SqlQuery | str) -> QueryPlan:
+         if isinstance(query, SqlQuery):
+             query = qal.transform.as_explain(query)
+             query = self._db.hinting().format_query(query)
+         else:
+             # only check for an EXPLAIN prefix case-insensitively, but keep the original query text intact
+             normalized = query.strip()
+             if not normalized.upper().startswith("EXPLAIN"):
+                 normalized = f"EXPLAIN (FORMAT JSON) {normalized}"
+             query = normalized
+
+         self._db.cursor().execute(query)
+         result_set = self._db.cursor().fetchone()
+         assert len(result_set) == 2
+
+         raw_explain = result_set[1]
+         parsed = json.loads(raw_explain)
+         return parse_duckdb_plan(parsed[0])
+
+     def analyze_plan(
+         self, query: SqlQuery, *, timeout: Optional[float] = None
+     ) -> Optional[QueryPlan]:
+         query = qal.transform.as_explain_analyze(query, qal.Explain)
+
+         try:
+             result_set = self._db.execute_query(
+                 query, cache_enabled=False, raw=True, timeout=timeout
+             )[0]
+         except TimeoutError:
+             return None
+         assert len(result_set) == 2
+
+         raw_explain = result_set[1]
+         parsed = json.loads(raw_explain)
+         return parse_duckdb_plan(parsed[0])
+
+     def cardinality_estimate(self, query: SqlQuery | str) -> Cardinality:
+         plan = self.query_plan(query)
+         if "AGGREGATE" in plan.node_type:
+             warnings.warn(
+                 "Plan could have an aggregate node as root. DuckDB does not estimate cardinalities for aggregations."
+             )
+         return plan.estimated_cardinality
+
+     def cost_estimate(self, query: SqlQuery | str) -> Cost:
+         raise UnsupportedDatabaseFeatureError(self._db, "cost estimates")
+
+
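A short sketch of how this optimizer interface might be used; the database path and SQL text are illustrative, not part of the package:

    db = connect("imdb.duckdb")
    plan = db.optimizer().query_plan(
        "SELECT * FROM title t, movie_companies mc WHERE t.id = mc.movie_id"
    )
    card = db.optimizer().cardinality_estimate("SELECT * FROM title")
    # analyze_plan() runs the query as EXPLAIN ANALYZE and returns None if the timeout fires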
+ @dataclass
+ class HintParts:
+     """Models the different kinds of optimizer hints that are supported by the DuckDB backend.
+
+     HintParts are designed to conveniently collect all kinds of hints in order to prepare the generation of a proper
+     `Hint` clause.
+
+     See Also
+     --------
+     Hint
+     """
+
+     settings: list[str]
+     """Settings are global to the current database connection and influence the selection of operators for all queries.
+
+     Typical examples include ``SET enable_nestloop = 'off'``, which disables the usage of nested loop joins for all
+     queries.
+     """
+
+     hints: list[str]
+     """Hints are supplied by the *quack_lab* extension and influence optimizer decisions on a per-query basis.
+
+     Typical examples include the selection of a specific join order as well as the assignment of join operators to
+     individual joins.
+     """
+
+     @staticmethod
+     def empty() -> HintParts:
+         """Creates a new hint parts object without any contents.
+
+         Returns
+         -------
+         HintParts
+             A fresh plain hint parts object
+         """
+         return HintParts([], [])
+
+     def add(self, hint: str) -> None:
+         """Adds a new hint.
+
+         This modifies the current object.
+
+         Parameters
+         ----------
+         hint : str
+             The hint to add
+         """
+         self.hints.append(hint)
+
+     def merge_with(self, other: HintParts) -> HintParts:
+         """Combines the hints that are contained in this hint parts object with all hints in the other object.
+
+         This constructs new hint parts and leaves the current objects unmodified.
+
+         Parameters
+         ----------
+         other : HintParts
+             The additional hints to incorporate
+
+         Returns
+         -------
+         HintParts
+             A new hint parts object that contains the hints from both source objects
+         """
+         merged_settings = self.settings + [
+             setting for setting in other.settings if setting not in self.settings
+         ]
+         merged_hints = (
+             self.hints + [""] + [hint for hint in other.hints if hint not in self.hints]
+         )
+         return HintParts(merged_settings, merged_hints)
+
+     def __bool__(self) -> bool:
+         return bool(self.settings or self.hints)
+
+
+ class DuckDBHintService(HintService):
+     def __init__(self, db: DuckDBInterface) -> None:
+         self._db = db
+
+     def generate_hints(
+         self,
+         query: SqlQuery,
+         plan: Optional[QueryPlan] = None,
+         *,
+         join_order: Optional[JoinTree] = None,
+         physical_operators: Optional[PhysicalOperatorAssignment] = None,
+         plan_parameters: Optional[PlanParameterization] = None,
+     ) -> SqlQuery:
+         adapted_query = query
+         if adapted_query.limit_clause and not isinstance(
+             adapted_query.limit_clause, PostgresLimitClause
+         ):
+             adapted_query = qal.transform.replace_clause(
+                 adapted_query, PostgresLimitClause(adapted_query.limit_clause)
+             )
+
+         has_partial_hints = any(
+             param is not None
+             for param in (join_order, physical_operators, plan_parameters)
+         )
+         if plan is not None and has_partial_hints:
+             raise ValueError(
+                 "Can only hint an entire query plan, or individual parts, not both."
+             )
+
+         if plan is not None:
+             join_order = optimizer.jointree_from_plan(plan)
+             physical_operators = optimizer.operators_from_plan(plan)
+             plan_parameters = optimizer.parameters_from_plan(plan)
+
+         if join_order is not None:
+             hint_parts = self._generate_join_order_hint(join_order)
+         else:
+             hint_parts = HintParts.empty()
+
+         hint_parts = hint_parts if hint_parts else HintParts.empty()
+
+         if physical_operators:
+             operator_hints = self._generate_operator_hints(physical_operators)
+             hint_parts = hint_parts.merge_with(operator_hints)
+
+         if plan_parameters:
+             plan_hints = self._generate_parameter_hints(plan_parameters)
+             hint_parts = hint_parts.merge_with(plan_hints)
+
+         if hint_parts:
+             adapted_query = self._add_hint_block(adapted_query, hint_parts)
+
+         return adapted_query
+
+     def format_query(self, query: SqlQuery) -> str:
+         # DuckDB uses the Postgres SQL dialect, so this part is easy..
+         return qal.format_quick(query, flavor="postgres")
+
+     def supports_hint(self, hint: PhysicalOperator | HintType) -> bool:
+         return hint in {
+             ScanOperator.SequentialScan,
+             ScanOperator.IndexScan,
+             JoinOperator.NestedLoopJoin,
+             JoinOperator.HashJoin,
+             JoinOperator.SortMergeJoin,
+             HintType.LinearJoinOrder,
+             HintType.BushyJoinOrder,
+             HintType.Cardinality,
+             HintType.Operator,
+         }
+
+     def _generate_join_order_hint(self, join_tree: JoinTree) -> HintParts:
+         if len(join_tree) < 3:
+             # we can't force the join direction anyway, so there's no point in generating a hint if there is just a single join
+             return HintParts.empty()
+
+         def recurse(join_tree: JoinTree) -> str:
+             if join_tree.is_scan():
+                 return join_tree.base_table.identifier()
+
+             lhs = recurse(join_tree.outer_child)
+             rhs = recurse(join_tree.inner_child)
+
+             return f"({lhs} {rhs})"
+
+         join_order = recurse(join_tree)
+         hint_parts = HintParts([], [f"JoinOrder({join_order})"])
+         return hint_parts
+
+     def _generate_operator_hints(self, ops: PhysicalOperatorAssignment) -> HintParts:
+         if not ops:
+             return HintParts.empty()
+
+         hints: list[str] = []
+         for tab, scan in ops.scan_operators.items():
+             match scan.operator:
+                 case ScanOperator.SequentialScan:
+                     op_txt = "SeqScan"
+                 case ScanOperator.IndexScan:
+                     op_txt = "IdxScan"
+                 case _:
+                     raise UnsupportedDatabaseFeatureError(self._db, scan.operator)
+             tab_txt = tab.identifier()
+             hints.append(f"{op_txt}({tab_txt})")
+
+         for intermediate, join in ops.join_operators.items():
+             match join.operator:
+                 case JoinOperator.NestedLoopJoin:
+                     op_txt = "NestLoop"
+                 case JoinOperator.HashJoin:
+                     op_txt = "HashJoin"
+                 case JoinOperator.SortMergeJoin:
+                     op_txt = "MergeJoin"
+                 case _:
+                     raise UnsupportedDatabaseFeatureError(self._db, join.operator)
+
+             intermediate_txt = self._intermediate_to_hint(intermediate)
+             hints.append(f"{op_txt}({intermediate_txt})")
+
+         if ops.intermediate_operators:
+             raise UnsupportedDatabaseFeatureError(
+                 self._db,
+                 "intermediate operators",
+             )
+
+         for param, val in ops.global_settings.items():
+             match param:
+                 case ScanOperator.SequentialScan:
+                     param_txt = "enable_seqscan"
+                 case ScanOperator.IndexScan:
+                     param_txt = "enable_indexscan"
+                 case JoinOperator.NestedLoopJoin:
+                     param_txt = "enable_nestloop"
+                 case JoinOperator.HashJoin:
+                     param_txt = "enable_hashjoin"
+                 case JoinOperator.SortMergeJoin:
+                     param_txt = "enable_mergejoin"
+                 case _:
+                     raise UnsupportedDatabaseFeatureError(self._db, param)
+
+             val_txt = "'on'" if val else "'off'"
+             hints.append(f"Set({param_txt} = {val_txt})")
+
+         return HintParts([], hints)
+
+     def _generate_parameter_hints(self, parameters: PlanParameterization) -> HintParts:
+         if not parameters:
+             return HintParts.empty()
+
+         hints: list[str] = []
+         for intermediate, card in parameters.cardinalities.items():
+             if not card.is_valid():
+                 continue
+             intermediate_txt = self._intermediate_to_hint(intermediate)
+             hints.append(f"Card({intermediate_txt} #{card})")
+
+         if parameters.parallel_workers:
+             raise UnsupportedDatabaseFeatureError(self._db, "parallel worker hints")
+
+         global_settings: list[str] = []
+         for param, val in parameters.system_settings.items():
+             if isinstance(val, str):
+                 val_txt = f"'{val}'"
+             else:
+                 val_txt = str(val)
+
+             global_settings.append(f"{param} = {val_txt};")
+
+         return HintParts(global_settings, hints)
+
+     def _add_hint_block(self, query: SqlQuery, hint_parts: HintParts) -> SqlQuery:
+         if not hint_parts:
+             return query
+
+         local_hints = ["/*=quack_lab="]
+         for local_hint in hint_parts.hints:
+             local_hints.append(f" {local_hint}")
+         local_hints.append(" */")
+
+         hints = qal.Hint("\n".join(hint_parts.settings), "\n".join(local_hints))
+         return qal.transform.add_clause(query, hints)
+
+     def _intermediate_to_hint(self, intermediate: Iterable[TableReference]) -> str:
+         """Convert an iterable of TableReferences to a string representation."""
+         return " ".join(table.identifier() for table in intermediate)
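Taken together, the hint service wraps all per-query hints in a quack_lab comment block (built in _add_hint_block() above). For a three-way join the generated block would look roughly like the following; the table aliases and the cardinality value are illustrative, and the exact placement of the block inside the formatted SQL is left to qal:

    /*=quack_lab=
     JoinOrder(((t mc) cn))
     HashJoin(t mc)
     SeqScan(t)
     Card(t mc #42000)
     */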
+
+
+ def _reconnect(name: str, *, pool: DatabasePool) -> DuckDBInterface:
+     import duckdb
+
+     current_conn: DuckDBInterface = pool.retrieve_database(name)
+
+     try:
+         # check if the connection is still active
+         current_conn.execute_query("SELECT 1", cache_enabled=False, raw=True)
+     except duckdb.ConnectionException as e:
+         # otherwise re-establish the connection
+         if "Connection Error: Connection already closed!" in e.args:
+             current_conn.reconnect()
+         else:
+             raise e
+
+     return current_conn
+
+
+ def connect(
+     db: str | Path,
+     *,
+     name: str = "duckdb",
+     refresh: bool = False,
+     private: bool = False,
+ ) -> DuckDBInterface:
+     db_pool = DatabasePool.get_instance()
+     if name in db_pool and not refresh:
+         return _reconnect(name, pool=db_pool)
+
+     db = Path(db)
+     duckdb_instance = DuckDBInterface(db)
+
+     if not private:
+         db_pool.register_database(name, duckdb_instance)
+
+     return duckdb_instance
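Finally, a sketch of how connect() interacts with the global DatabasePool; the file names are illustrative:

    db = connect("imdb.duckdb")    # creates the interface and registers it under the name "duckdb"
    same = connect("imdb.duckdb")  # returns the pooled instance, reconnecting it if the connection was closed
    tmp = connect("scratch.duckdb", name="scratch", private=True)  # private connections are not registered in the pool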