pytrilogy 0.0.1.104__py3-none-any.whl → 0.0.1.106__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytrilogy might be problematic. Click here for more details.

Files changed (32)
  1. {pytrilogy-0.0.1.104.dist-info → pytrilogy-0.0.1.106.dist-info}/METADATA +1 -1
  2. {pytrilogy-0.0.1.104.dist-info → pytrilogy-0.0.1.106.dist-info}/RECORD +32 -31
  3. trilogy/__init__.py +3 -2
  4. trilogy/constants.py +1 -0
  5. trilogy/core/models.py +226 -49
  6. trilogy/core/optimization.py +141 -0
  7. trilogy/core/processing/concept_strategies_v3.py +1 -0
  8. trilogy/core/processing/node_generators/common.py +19 -7
  9. trilogy/core/processing/node_generators/filter_node.py +37 -10
  10. trilogy/core/processing/node_generators/merge_node.py +11 -1
  11. trilogy/core/processing/nodes/base_node.py +4 -2
  12. trilogy/core/processing/nodes/group_node.py +5 -2
  13. trilogy/core/processing/nodes/merge_node.py +13 -8
  14. trilogy/core/query_processor.py +5 -2
  15. trilogy/dialect/base.py +85 -54
  16. trilogy/dialect/bigquery.py +6 -4
  17. trilogy/dialect/common.py +8 -6
  18. trilogy/dialect/config.py +69 -1
  19. trilogy/dialect/duckdb.py +5 -4
  20. trilogy/dialect/enums.py +40 -19
  21. trilogy/dialect/postgres.py +4 -2
  22. trilogy/dialect/presto.py +6 -4
  23. trilogy/dialect/snowflake.py +6 -4
  24. trilogy/dialect/sql_server.py +4 -1
  25. trilogy/executor.py +18 -5
  26. trilogy/parsing/common.py +30 -0
  27. trilogy/parsing/parse_engine.py +43 -83
  28. trilogy/parsing/render.py +0 -122
  29. {pytrilogy-0.0.1.104.dist-info → pytrilogy-0.0.1.106.dist-info}/LICENSE.md +0 -0
  30. {pytrilogy-0.0.1.104.dist-info → pytrilogy-0.0.1.106.dist-info}/WHEEL +0 -0
  31. {pytrilogy-0.0.1.104.dist-info → pytrilogy-0.0.1.106.dist-info}/entry_points.txt +0 -0
  32. {pytrilogy-0.0.1.104.dist-info → pytrilogy-0.0.1.106.dist-info}/top_level.txt +0 -0
trilogy/dialect/base.py CHANGED
@@ -22,6 +22,7 @@ from trilogy.core.models import (
22
22
  CompiledCTE,
23
23
  Conditional,
24
24
  Comparison,
25
+ SubselectComparison,
25
26
  OrderItem,
26
27
  WindowItem,
27
28
  FilterItem,
@@ -169,6 +170,9 @@ GENERIC_SQL_TEMPLATE = Template(
169
170
  """{%- if ctes %}
170
171
  WITH {% for cte in ctes %}
171
172
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
173
+ {%- if full_select -%}
174
+ {{full_select}}
175
+ {%- else -%}
172
176
  SELECT
173
177
  {%- if limit is not none %}
174
178
  TOP {{ limit }}{% endif %}
@@ -183,8 +187,8 @@ TOP {{ limit }}{% endif %}
183
187
  \t{{group}}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
184
188
  {%- if order_by %}
185
189
  ORDER BY {% for order in order_by %}
186
- {{ order }}{% if not loop.last %},{% endif %}
187
- {% endfor %}{% endif %}
190
+ {{ order }}{% if not loop.last %},{% endif %}{% endfor %}
191
+ {% endif %}{% endif %}
188
192
  """
189
193
  )
190
194
 
@@ -217,15 +221,19 @@ def safe_quote(string: str, quote_char: str):
217
221
  return ".".join([f"{quote_char}{string}{quote_char}" for string in components])
218
222
 
219
223
 
220
- def safe_get_cte_value(coalesce, cte: CTE, address: str, rendered: str):
224
+ def safe_get_cte_value(coalesce, cte: CTE, c: Concept, quote_char: str):
225
+ address = c.address
221
226
  raw = cte.source_map.get(address, None)
227
+
222
228
  if not raw:
223
229
  return INVALID_REFERENCE_STRING("Missing source reference")
224
230
  if isinstance(raw, str):
225
- return f"{raw}.{rendered}"
231
+ rendered = cte.get_alias(c, raw)
232
+ return f"{raw}.{quote_char}{rendered}{quote_char}"
226
233
  if isinstance(raw, list) and len(raw) == 1:
227
- return f"{raw[0]}.{rendered}"
228
- return coalesce([f"{x}.{rendered}" for x in raw])
234
+ rendered = cte.get_alias(c, raw[0])
235
+ return f"{raw[0]}.{quote_char}{rendered}{quote_char}"
236
+ return coalesce([f"{x}.{quote_char}{cte.get_alias(c, x)}{quote_char}" for x in raw])
229
237
 
230
238
 
231
239
  class BaseDialect:
@@ -237,21 +245,13 @@ class BaseDialect:
237
245
  DATATYPE_MAP = DATATYPE_MAP
238
246
  UNNEST_MODE = UnnestMode.CROSS_APPLY
239
247
 
240
- def render_order_item(self, order_item: OrderItem, ctes: List[CTE]) -> str:
241
- matched_ctes = [
242
- cte
243
- for cte in ctes
244
- if order_item.expr.address in [a.address for a in cte.output_columns]
245
- ]
246
- if not matched_ctes:
247
- all_outputs = set()
248
- for cte in ctes:
249
- all_outputs.update([a.address for a in cte.output_columns])
250
- raise ValueError(
251
- f"No source found for concept {order_item.expr}, have {all_outputs}"
252
- )
253
- selected = matched_ctes[0]
254
- return f"{selected.name}.{self.QUOTE_CHARACTER}{order_item.expr.safe_address}{self.QUOTE_CHARACTER} {order_item.order.value}"
248
+ def render_order_item(
249
+ self, order_item: OrderItem, cte: CTE, final: bool = False
250
+ ) -> str:
251
+ if final:
252
+ return f"{cte.name}.{self.QUOTE_CHARACTER}{order_item.expr.safe_address}{self.QUOTE_CHARACTER} {order_item.order.value}"
253
+
254
+ return f"{self.render_concept_sql(order_item.expr, cte=cte, alias=False)} {order_item.order.value}"
255
255
 
256
256
  def render_concept_sql(self, c: Concept, cte: CTE, alias: bool = True) -> str:
257
257
  # only recurse while it's in sources of the current cte
@@ -273,14 +273,13 @@ class BaseDialect:
273
273
  ]
274
274
  rval = f"{self.WINDOW_FUNCTION_MAP[c.lineage.type](concept = self.render_concept_sql(c.lineage.content, cte=cte, alias=False), window=','.join(rendered_over_components), sort=','.join(rendered_order_components))}" # noqa: E501
275
275
  elif isinstance(c.lineage, FilterItem):
276
- rval = f"CASE WHEN {self.render_expr(c.lineage.where.conditional)} THEN {self.render_concept_sql(c.lineage.content, cte=cte, alias=False)} ELSE NULL END"
276
+ rval = f"CASE WHEN {self.render_expr(c.lineage.where.conditional, cte=cte)} THEN {self.render_concept_sql(c.lineage.content, cte=cte, alias=False)} ELSE NULL END"
277
277
  elif isinstance(c.lineage, RowsetItem):
278
278
  rval = f"{self.render_concept_sql(c.lineage.content, cte=cte, alias=False)}"
279
279
  elif isinstance(c.lineage, MultiSelectStatement):
280
280
  rval = f"{self.render_concept_sql(c.lineage.find_source(c, cte), cte=cte, alias=False)}"
281
281
  elif isinstance(c.lineage, MergeStatement):
282
282
  rval = f"{self.render_concept_sql(c.lineage.find_source(c, cte), cte=cte, alias=False)}"
283
- # rval = f"{self.FUNCTION_MAP[FunctionType.COALESCE](*[self.render_concept_sql(parent, cte=cte, alias=False) for parent in c.lineage.find_sources(c, cte)])}"
284
283
  elif isinstance(c.lineage, AggregateWrapper):
285
284
  args = [
286
285
  self.render_expr(v, cte) # , alias=False)
@@ -310,13 +309,14 @@ class BaseDialect:
310
309
  logger.debug(
311
310
  f"{LOGGER_PREFIX} [{c.address}] Rendering basic lookup from {cte.source_map.get(c.address, INVALID_REFERENCE_STRING('Missing source reference'))}"
312
311
  )
312
+
313
313
  raw_content = cte.get_alias(c)
314
314
  if isinstance(raw_content, RawColumnExpr):
315
315
  rval = raw_content.text
316
316
  elif isinstance(raw_content, Function):
317
317
  rval = self.render_expr(raw_content, cte=cte)
318
318
  else:
319
- rval = f"{safe_get_cte_value(self.FUNCTION_MAP[FunctionType.COALESCE], cte, c.address, rendered=safe_quote(raw_content, self.QUOTE_CHARACTER))}"
319
+ rval = f"{safe_get_cte_value(self.FUNCTION_MAP[FunctionType.COALESCE], cte, c, self.QUOTE_CHARACTER)}"
320
320
  if alias:
321
321
  return (
322
322
  f"{rval} as"
@@ -330,6 +330,7 @@ class BaseDialect:
330
330
  Function,
331
331
  Conditional,
332
332
  Comparison,
333
+ SubselectComparison,
333
334
  Concept,
334
335
  str,
335
336
  int,
@@ -358,7 +359,15 @@ class BaseDialect:
358
359
  # if isinstance(e, Concept):
359
360
  # cte = cte or cte_map.get(e.address, None)
360
361
 
361
- if isinstance(e, Comparison):
362
+ if isinstance(e, SubselectComparison):
363
+ assert cte, "Subselects must be rendered with a CTE in context"
364
+ if isinstance(e.right, Concept):
365
+ return f"{self.render_expr(e.left, cte=cte, cte_map=cte_map)} {e.operator.value} (select {self.render_expr(e.right, cte=cte, cte_map=cte_map)} from {cte.source_map[e.right.address][0]})"
366
+ else:
367
+ raise NotImplementedError(
368
+ f"Subselects must be a concept, got {e.right}"
369
+ )
370
+ elif isinstance(e, Comparison):
362
371
  return f"{self.render_expr(e.left, cte=cte, cte_map=cte_map)} {e.operator.value} {self.render_expr(e.right, cte=cte, cte_map=cte_map)}"
363
372
  elif isinstance(e, Conditional):
364
373
  # conditions need to be nested in parentheses
@@ -447,7 +456,7 @@ class BaseDialect:
447
456
  else None
448
457
  ),
449
458
  grain=cte.grain,
450
- limit=None,
459
+ limit=cte.limit,
451
460
  # some joins may not need to be rendered
452
461
  joins=[
453
462
  j
@@ -466,9 +475,11 @@ class BaseDialect:
466
475
  where=(
467
476
  self.render_expr(cte.condition, cte) if cte.condition else None
468
477
  ), # source_map=cte_output_map)
469
- # where=self.render_expr(where_assignment[cte.name], cte)
470
- # if cte.name in where_assignment
471
- # else None,
478
+ order_by=(
479
+ [self.render_order_item(i, cte) for i in cte.order_by.items]
480
+ if cte.order_by
481
+ else None
482
+ ),
472
483
  group_by=(
473
484
  list(
474
485
  set(
@@ -513,7 +524,8 @@ class BaseDialect:
513
524
  )
514
525
 
515
526
  def generate_ctes(
516
- self, query: ProcessedQuery, where_assignment: Dict[str, Conditional]
527
+ self,
528
+ query: ProcessedQuery,
517
529
  ):
518
530
  return [self.render_cte(cte) for cte in query.ctes]
519
531
 
@@ -640,35 +652,54 @@ class BaseDialect:
640
652
  " filtered concept instead."
641
653
  )
642
654
 
643
- compiled_ctes = self.generate_ctes(query, {})
655
+ compiled_ctes = self.generate_ctes(query)
644
656
 
645
657
  # restort selections by the order they were written in
646
658
  sorted_select: List[str] = []
647
659
  for output_c in output_addresses:
648
660
  sorted_select.append(select_columns[output_c])
649
- final = self.SQL_TEMPLATE.render(
650
- output=(
651
- query.output_to if isinstance(query, ProcessedQueryPersist) else None
652
- ),
653
- select_columns=sorted_select,
654
- base=query.base.name,
655
- joins=[
656
- render_join(join, self.QUOTE_CHARACTER, None) for join in query.joins
657
- ],
658
- ctes=compiled_ctes,
659
- limit=query.limit,
660
- # move up to CTEs
661
- where=(
662
- self.render_expr(query.where_clause.conditional, cte_map=cte_output_map)
663
- if query.where_clause and output_where
664
- else None
665
- ),
666
- order_by=(
667
- [self.render_order_item(i, [query.base]) for i in query.order_by.items]
668
- if query.order_by
669
- else None
670
- ),
671
- )
661
+ if not query.base.requires_nesting:
662
+ final = self.SQL_TEMPLATE.render(
663
+ output=(
664
+ query.output_to
665
+ if isinstance(query, ProcessedQueryPersist)
666
+ else None
667
+ ),
668
+ full_select=compiled_ctes[-1].statement,
669
+ ctes=compiled_ctes[:-1],
670
+ )
671
+ else:
672
+ final = self.SQL_TEMPLATE.render(
673
+ output=(
674
+ query.output_to
675
+ if isinstance(query, ProcessedQueryPersist)
676
+ else None
677
+ ),
678
+ select_columns=sorted_select,
679
+ base=query.base.name,
680
+ joins=[
681
+ render_join(join, self.QUOTE_CHARACTER, None)
682
+ for join in query.joins
683
+ ],
684
+ ctes=compiled_ctes,
685
+ limit=query.limit,
686
+ # move up to CTEs
687
+ where=(
688
+ self.render_expr(
689
+ query.where_clause.conditional, cte_map=cte_output_map
690
+ )
691
+ if query.where_clause and output_where
692
+ else None
693
+ ),
694
+ order_by=(
695
+ [
696
+ self.render_order_item(i, query.base, final=True)
697
+ for i in query.order_by.items
698
+ ]
699
+ if query.order_by
700
+ else None
701
+ ),
702
+ )
672
703
 
673
704
  if CONFIG.strict_mode and INVALID_REFERENCE_STRING(1) in final:
674
705
  raise ValueError(
trilogy/dialect/bigquery.py CHANGED
@@ -43,8 +43,11 @@ CREATE OR REPLACE TABLE {{ output.address.location }} AS
43
43
  {% endif %}{%- if ctes %}
44
44
  WITH {% for cte in ctes %}
45
45
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
46
- SELECT
46
+ {%- if full_select -%}
47
+ {{full_select}}
48
+ {%- else -%}
47
49
 
50
+ SELECT
48
51
  {%- for select in select_columns %}
49
52
  {{ select }}{% if not loop.last %},{% endif %}{% endfor %}
50
53
  {% if base %}FROM
@@ -59,10 +62,9 @@ SELECT
59
62
  {{group}}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
60
63
  {%- if order_by %}
61
64
  ORDER BY {% for order in order_by %}
62
- {{ order }}{% if not loop.last %},{% endif %}
63
- {% endfor %}{% endif %}
65
+ {{ order }}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
64
66
  {%- if limit is not none %}
65
- LIMIT {{ limit }}{% endif %}
67
+ LIMIT {{ limit }}{% endif %}{% endif %}
66
68
  """
67
69
  )
68
70
  MAX_IDENTIFIER_LENGTH = 50
trilogy/dialect/common.py CHANGED
@@ -1,4 +1,4 @@
1
- from trilogy.core.models import Join, InstantiatedUnnestJoin, CTE, Concept
1
+ from trilogy.core.models import Join, InstantiatedUnnestJoin, CTE, Concept, Datasource
2
2
  from trilogy.core.enums import UnnestMode, Modifier
3
3
  from typing import Optional, Callable
4
4
 
@@ -21,18 +21,20 @@ def render_join(
21
21
  if unnest_mode == UnnestMode.DIRECT:
22
22
  return None
23
23
  if not render_func:
24
- raise ValueError("must provide a render func to build an unnest joins")
24
+ raise ValueError("must provide a render function to build an unnest joins")
25
25
  if not cte:
26
26
  raise ValueError("must provide a cte to build an unnest joins")
27
27
  if unnest_mode == UnnestMode.CROSS_JOIN:
28
28
  return f"CROSS JOIN {render_func(join.concept, cte, False)} as {quote_character}{join.concept.safe_address}{quote_character}"
29
29
 
30
30
  return f"FULL JOIN {render_func(join.concept, cte, False)} as unnest_wrapper({quote_character}{join.concept.safe_address}{quote_character})"
31
-
31
+ left_name = join.left_name
32
+ right_name = join.right_name
33
+ right_base = join.right_ref
32
34
  base_joinkeys = [
33
35
  null_wrapper(
34
- f"{join.left_cte.name}.{quote_character}{key.concept.safe_address}{quote_character}",
35
- f"{join.right_cte.name}.{quote_character}{key.concept.safe_address}{quote_character}",
36
+ f"{left_name}.{quote_character}{join.left_cte.get_alias(key.concept) if isinstance(join.left_cte, Datasource) else key.concept.safe_address}{quote_character}",
37
+ f"{right_name}.{quote_character}{join.right_cte.get_alias(key.concept) if isinstance(join.right_cte, Datasource) else key.concept.safe_address}{quote_character}",
36
38
  key.concept,
37
39
  )
38
40
  for key in join.joinkeys
@@ -40,4 +42,4 @@ def render_join(
40
42
  if not base_joinkeys:
41
43
  base_joinkeys = ["1=1"]
42
44
  joinkeys = " AND ".join(base_joinkeys)
43
- return f"{join.jointype.value.upper()} JOIN {join.right_cte.name} on {joinkeys}"
45
+ return f"{join.jointype.value.upper()} JOIN {right_base} on {joinkeys}"
trilogy/dialect/config.py CHANGED
@@ -1,5 +1,27 @@
1
1
  class DialectConfig:
2
- pass
2
+
3
+ def __init__(self):
4
+ pass
5
+
6
+ def connection_string(self) -> str:
7
+ raise NotImplementedError
8
+
9
+ @property
10
+ def connect_args(self) -> dict:
11
+ return {}
12
+
13
+
14
+ class BigQueryConfig(DialectConfig):
15
+ def __init__(self, project: str, client):
16
+ self.project = project
17
+ self.client = client
18
+
19
+ def connection_string(self) -> str:
20
+ return f"bigquery://{self.project}?user_supplied_client=True"
21
+
22
+ @property
23
+ def connect_args(self) -> dict:
24
+ return {"client": self.client}
3
25
 
4
26
 
5
27
  class DuckDBConfig(DialectConfig):
@@ -53,3 +75,49 @@ class SnowflakeConfig(DialectConfig):
53
75
 
54
76
  def connection_string(self) -> str:
55
77
  return f"snowflake://{self.username}:{self.password}@{self.account}"
78
+
79
+
80
+ class PrestoConfig(DialectConfig):
81
+ def __init__(
82
+ self,
83
+ host: str,
84
+ port: int,
85
+ username: str,
86
+ password: str,
87
+ catalog: str,
88
+ schema: str | None = None,
89
+ ):
90
+ self.host = host
91
+ self.port = port
92
+ self.username = username
93
+ self.password = password
94
+ self.catalog = catalog
95
+ self.schema = schema
96
+
97
+ def connection_string(self) -> str:
98
+ if self.schema:
99
+ return f"presto://{self.username}:{self.password}@{self.host}:{self.port}/{self.catalog}/{self.schema}"
100
+ return f"presto://{self.username}:{self.password}@{self.host}:{self.port}/{self.catalog}"
101
+
102
+
103
+ class TrinoConfig(DialectConfig):
104
+ def __init__(
105
+ self,
106
+ host: str,
107
+ port: int,
108
+ username: str,
109
+ password: str,
110
+ catalog: str,
111
+ schema: str | None = None,
112
+ ):
113
+ self.host = host
114
+ self.port = port
115
+ self.username = username
116
+ self.password = password
117
+ self.catalog = catalog
118
+ self.schema = schema
119
+
120
+ def connection_string(self) -> str:
121
+ if self.schema:
122
+ return f"trino://{self.username}:{self.password}@{self.host}:{self.port}/{self.catalog}/{self.schema}"
123
+ return f"trino://{self.username}:{self.password}@{self.host}:{self.port}/{self.catalog}"
trilogy/dialect/duckdb.py CHANGED
@@ -47,8 +47,10 @@ CREATE OR REPLACE TABLE {{ output.address.location }} AS
47
47
  {% endif %}{%- if ctes %}
48
48
  WITH {% for cte in ctes %}
49
49
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
50
- SELECT
50
+ {% if full_select -%}{{full_select}}
51
+ {% else -%}
51
52
 
53
+ SELECT
52
54
  {%- for select in select_columns %}
53
55
  {{ select }}{% if not loop.last %},{% endif %}{% endfor %}
54
56
  {% if base %}FROM
@@ -63,10 +65,9 @@ GROUP BY {% for group in group_by %}
63
65
  {{group}}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
64
66
  {%- if order_by %}
65
67
  ORDER BY {% for order in order_by %}
66
- {{ order }}{% if not loop.last %},{% endif %}
67
- {% endfor %}{% endif %}
68
+ {{ order }}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
68
69
  {%- if limit is not none %}
69
- LIMIT ({{ limit }}){% endif %}
70
+ LIMIT ({{ limit }}){% endif %}{% endif %}
70
71
  """
71
72
  )
72
73
 
trilogy/dialect/enums.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from enum import Enum
2
- from typing import List, TYPE_CHECKING, Optional
2
+ from typing import List, TYPE_CHECKING, Optional, Callable
3
3
 
4
4
  if TYPE_CHECKING:
5
5
  from trilogy.hooks.base_hook import BaseHook
@@ -9,6 +9,20 @@ from trilogy.dialect.config import DialectConfig
9
9
  from trilogy.constants import logger
10
10
 
11
11
 
12
+ def default_factory(conf: DialectConfig, config_type):
13
+ from sqlalchemy import create_engine
14
+
15
+ if not isinstance(conf, config_type):
16
+ raise TypeError(
17
+ f"Invalid dialect configuration for type {type(config_type).__name__}"
18
+ )
19
+ if conf.connect_args:
20
+ return create_engine(
21
+ conf.connection_string(), future=True, connect_args=conf.connect_args
22
+ )
23
+ return create_engine(conf.connection_string(), future=True)
24
+
25
+
12
26
  class Dialects(Enum):
13
27
  BIGQUERY = "bigquery"
14
28
  SQL_SERVER = "sql_server"
@@ -24,38 +38,32 @@ class Dialects(Enum):
24
38
  return cls.DUCK_DB
25
39
  return super()._missing_(value)
26
40
 
27
- def default_engine(self, conf=None):
41
+ def default_engine(self, conf=None, _engine_factory: Callable = default_factory):
28
42
  if self == Dialects.BIGQUERY:
29
- from sqlalchemy import create_engine
30
43
  from google.auth import default
31
44
  from google.cloud import bigquery
45
+ from trilogy.dialect.config import BigQueryConfig
32
46
 
33
47
  credentials, project = default()
34
48
  client = bigquery.Client(credentials=credentials, project=project)
35
- return create_engine(
36
- f"bigquery://{project}?user_supplied_client=True",
37
- connect_args={"client": client},
49
+ conf = conf or BigQueryConfig(project=project, client=client)
50
+ return _engine_factory(
51
+ conf,
52
+ BigQueryConfig,
38
53
  )
39
54
  elif self == Dialects.SQL_SERVER:
40
- from sqlalchemy import create_engine
41
55
 
42
56
  raise NotImplementedError()
43
57
  elif self == Dialects.DUCK_DB:
44
- from sqlalchemy import create_engine
45
58
  from trilogy.dialect.config import DuckDBConfig
46
59
 
47
60
  if not conf:
48
61
  conf = DuckDBConfig()
49
- if not isinstance(conf, DuckDBConfig):
50
- raise TypeError("Invalid dialect configuration for type duck_db")
51
- return create_engine(conf.connection_string(), future=True)
62
+ return _engine_factory(conf, DuckDBConfig)
52
63
  elif self == Dialects.SNOWFLAKE:
53
- from sqlalchemy import create_engine
54
64
  from trilogy.dialect.config import SnowflakeConfig
55
65
 
56
- if not isinstance(conf, SnowflakeConfig):
57
- raise TypeError("Invalid dialect configuration for type snowflake")
58
- return create_engine(conf.connection_string(), future=True)
66
+ return _engine_factory(conf, SnowflakeConfig)
59
67
  elif self == Dialects.POSTGRES:
60
68
  logger.warn(
61
69
  "WARN: Using experimental postgres dialect. Most functionality will not work."
@@ -67,13 +75,17 @@ class Dialects(Enum):
67
75
  raise ImportError(
68
76
  "postgres driver not installed. python -m pip install pypreql[postgres]"
69
77
  )
70
- from sqlalchemy import create_engine
71
78
  from trilogy.dialect.config import PostgresConfig
72
79
 
73
- if not isinstance(conf, PostgresConfig):
74
- raise TypeError("Invalid dialect configuration for type postgres")
80
+ return _engine_factory(conf, PostgresConfig)
81
+ elif self == Dialects.PRESTO:
82
+ from trilogy.dialect.config import PrestoConfig
83
+
84
+ return _engine_factory(conf, PrestoConfig)
85
+ elif self == Dialects.TRINO:
86
+ from trilogy.dialect.config import TrinoConfig
75
87
 
76
- return create_engine(conf.connection_string(), future=True)
88
+ return _engine_factory(conf, TrinoConfig)
77
89
  else:
78
90
  raise ValueError(
79
91
  f"Unsupported dialect {self} for default engine creation; create one explicitly."
@@ -84,9 +96,18 @@ class Dialects(Enum):
84
96
  environment: Optional["Environment"] = None,
85
97
  hooks: List["BaseHook"] | None = None,
86
98
  conf: DialectConfig | None = None,
99
+ _engine_factory: Callable | None = None,
87
100
  ) -> "Executor":
88
101
  from trilogy import Executor, Environment
89
102
 
103
+ if _engine_factory is not None:
104
+ return Executor(
105
+ engine=self.default_engine(conf=conf, _engine_factory=_engine_factory),
106
+ environment=environment or Environment(),
107
+ dialect=self,
108
+ hooks=hooks,
109
+ )
110
+
90
111
  return Executor(
91
112
  engine=self.default_engine(conf=conf),
92
113
  environment=environment or Environment(),
trilogy/dialect/postgres.py CHANGED
@@ -49,8 +49,10 @@ CREATE TABLE {{ output.address.location }} AS
49
49
  {% endif %}{%- if ctes %}
50
50
  WITH {% for cte in ctes %}
51
51
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
52
+ {%- if full_select -%}
53
+ {{full_select}}
54
+ {%- else -%}
52
55
  SELECT
53
-
54
56
  {%- for select in select_columns %}
55
57
  {{ select }}{% if not loop.last %},{% endif %}{% endfor %}
56
58
  {% if base %}FROM
@@ -68,7 +70,7 @@ ORDER BY {% for order in order_by %}
68
70
  {{ order }}{% if not loop.last %},{% endif %}
69
71
  {% endfor %}{% endif %}
70
72
  {%- if limit is not none %}
71
- LIMIT {{ limit }}{% endif %}
73
+ LIMIT {{ limit }}{% endif %}{% endif %}
72
74
  """
73
75
  )
74
76
 
trilogy/dialect/presto.py CHANGED
@@ -42,8 +42,11 @@ CREATE OR REPLACE TABLE {{ output.address }} AS
42
42
  {% endif %}{%- if ctes %}
43
43
  WITH {% for cte in ctes %}
44
44
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
45
- SELECT
46
45
 
46
+ SELECT
47
+ {%- if full_select -%}
48
+ {{full_select}}
49
+ {%- else -%}
47
50
  {%- for select in select_columns %}
48
51
  {{ select }}{% if not loop.last %},{% endif %}{% endfor %}
49
52
  {% if base %}FROM
@@ -58,10 +61,9 @@ SELECT
58
61
  {{group}}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
59
62
  {%- if order_by %}
60
63
  ORDER BY {% for order in order_by %}
61
- {{ order }}{% if not loop.last %},{% endif %}
62
- {% endfor %}{% endif %}
64
+ {{ order }}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
63
65
  {%- if limit is not none %}
64
- LIMIT {{ limit }}{% endif %}
66
+ LIMIT {{ limit }}{% endif %}{% endif %}
65
67
  """
66
68
  )
67
69
  MAX_IDENTIFIER_LENGTH = 50
trilogy/dialect/snowflake.py CHANGED
@@ -45,8 +45,11 @@ CREATE OR REPLACE TABLE {{ output.address.location }} AS
45
45
  {% endif %}{%- if ctes %}
46
46
  WITH {% for cte in ctes %}
47
47
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
48
- SELECT
48
+ {%- if full_select -%}
49
+ {{full_select}}
50
+ {%- else -%}
49
51
 
52
+ SELECT
50
53
  {%- for select in select_columns %}
51
54
  {{ select }}{% if not loop.last %},{% endif %}{% endfor %}
52
55
  {% if base %}FROM
@@ -61,10 +64,9 @@ SELECT
61
64
  {{group}}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
62
65
  {%- if order_by %}
63
66
  ORDER BY {% for order in order_by %}
64
- {{ order }}{% if not loop.last %},{% endif %}
65
- {% endfor %}{% endif %}
67
+ {{ order }}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
66
68
  {%- if limit is not none %}
67
- LIMIT {{ limit }}{% endif %}
69
+ LIMIT {{ limit }}{% endif %}{% endif %}
68
70
  """
69
71
  )
70
72
  MAX_IDENTIFIER_LENGTH = 50
trilogy/dialect/sql_server.py CHANGED
@@ -40,6 +40,9 @@ TSQL_TEMPLATE = Template(
40
40
  """{%- if ctes %}
41
41
  WITH {% for cte in ctes %}
42
42
  {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
43
+ {%- if full_select -%}
44
+ {{full_select}}
45
+ {%- else -%}
43
46
  SELECT
44
47
  {%- if limit is not none %}
45
48
  TOP {{ limit }}{% endif %}
@@ -60,7 +63,7 @@ GROUP BY {% for group in group_by %}
60
63
  {%- if order_by %}
61
64
  ORDER BY {% for order in order_by %}
62
65
  {{ order }}{% if not loop.last %},{% endif %}
63
- {% endfor %}{% endif %}
66
+ {% endfor %}{% endif %}{% endif %}
64
67
  """
65
68
  )
66
69
 
trilogy/executor.py CHANGED
@@ -99,7 +99,14 @@ class Executor(object):
99
99
  raise NotImplementedError("Cannot execute type {}".format(type(query)))
100
100
 
101
101
  @execute_query.register
102
- def _(self, query: SelectStatement | PersistStatement) -> CursorResult:
102
+ def _(self, query: SelectStatement) -> CursorResult:
103
+ sql = self.generator.generate_queries(
104
+ self.environment, [query], hooks=self.hooks
105
+ )
106
+ return self.execute_query(sql[0])
107
+
108
+ @execute_query.register
109
+ def _(self, query: PersistStatement) -> CursorResult:
103
110
  sql = self.generator.generate_queries(
104
111
  self.environment, [query], hooks=self.hooks
105
112
  )
@@ -117,16 +124,22 @@ class Executor(object):
117
124
  )
118
125
 
119
126
  @execute_query.register
120
- def _(self, query: ProcessedQuery | ProcessedQueryPersist) -> CursorResult:
127
+ def _(self, query: ProcessedQuery) -> CursorResult:
128
+ sql = self.generator.compile_statement(query)
129
+ # connection = self.engine.connect()
130
+ output = self.connection.execute(text(sql))
131
+ return output
132
+
133
+ @execute_query.register
134
+ def _(self, query: ProcessedQueryPersist) -> CursorResult:
121
135
  sql = self.generator.compile_statement(query)
122
136
  # connection = self.engine.connect()
123
137
  output = self.connection.execute(text(sql))
124
- if isinstance(query, ProcessedQueryPersist):
125
- self.environment.add_datasource(query.datasource)
138
+ self.environment.add_datasource(query.datasource)
126
139
  return output
127
140
 
128
141
  @singledispatchmethod
129
- def generate_sql(self, command: ProcessedQuery | str) -> list[str]:
142
+ def generate_sql(self, command) -> list[str]:
130
143
  raise NotImplementedError(
131
144
  "Cannot generate sql for type {}".format(type(command))
132
145
  )