pytrilogy 0.0.1.118__py3-none-any.whl → 0.0.2.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in the public registry.

Potentially problematic release.

Files changed (45)
  1. {pytrilogy-0.0.1.118.dist-info → pytrilogy-0.0.2.2.dist-info}/METADATA +1 -1
  2. pytrilogy-0.0.2.2.dist-info/RECORD +82 -0
  3. {pytrilogy-0.0.1.118.dist-info → pytrilogy-0.0.2.2.dist-info}/WHEEL +1 -1
  4. trilogy/__init__.py +1 -1
  5. trilogy/constants.py +6 -0
  6. trilogy/core/enums.py +7 -2
  7. trilogy/core/env_processor.py +43 -19
  8. trilogy/core/functions.py +1 -0
  9. trilogy/core/models.py +674 -146
  10. trilogy/core/optimization.py +31 -28
  11. trilogy/core/optimizations/inline_constant.py +4 -1
  12. trilogy/core/optimizations/inline_datasource.py +25 -4
  13. trilogy/core/optimizations/predicate_pushdown.py +94 -54
  14. trilogy/core/processing/concept_strategies_v3.py +69 -39
  15. trilogy/core/processing/graph_utils.py +3 -3
  16. trilogy/core/processing/node_generators/__init__.py +0 -2
  17. trilogy/core/processing/node_generators/basic_node.py +30 -17
  18. trilogy/core/processing/node_generators/filter_node.py +3 -1
  19. trilogy/core/processing/node_generators/node_merge_node.py +345 -96
  20. trilogy/core/processing/node_generators/rowset_node.py +18 -16
  21. trilogy/core/processing/node_generators/select_node.py +45 -85
  22. trilogy/core/processing/nodes/__init__.py +2 -0
  23. trilogy/core/processing/nodes/base_node.py +22 -5
  24. trilogy/core/processing/nodes/filter_node.py +3 -0
  25. trilogy/core/processing/nodes/group_node.py +20 -2
  26. trilogy/core/processing/nodes/merge_node.py +32 -18
  27. trilogy/core/processing/nodes/select_node_v2.py +17 -3
  28. trilogy/core/processing/utility.py +100 -8
  29. trilogy/core/query_processor.py +77 -24
  30. trilogy/dialect/base.py +11 -46
  31. trilogy/dialect/bigquery.py +1 -1
  32. trilogy/dialect/common.py +11 -0
  33. trilogy/dialect/duckdb.py +1 -1
  34. trilogy/dialect/presto.py +1 -0
  35. trilogy/hooks/graph_hook.py +50 -5
  36. trilogy/hooks/query_debugger.py +1 -0
  37. trilogy/parsing/common.py +8 -5
  38. trilogy/parsing/parse_engine.py +52 -27
  39. trilogy/parsing/render.py +20 -9
  40. trilogy/parsing/trilogy.lark +13 -8
  41. pytrilogy-0.0.1.118.dist-info/RECORD +0 -83
  42. trilogy/core/processing/node_generators/concept_merge_node.py +0 -214
  43. {pytrilogy-0.0.1.118.dist-info → pytrilogy-0.0.2.2.dist-info}/LICENSE.md +0 -0
  44. {pytrilogy-0.0.1.118.dist-info → pytrilogy-0.0.2.2.dist-info}/entry_points.txt +0 -0
  45. {pytrilogy-0.0.1.118.dist-info → pytrilogy-0.0.2.2.dist-info}/top_level.txt +0 -0
trilogy/core/processing/utility.py CHANGED
@@ -7,6 +7,7 @@ from trilogy.core.models import (
     Concept,
     QueryDatasource,
     LooseConceptList,
+    Environment,
 )
 
 from trilogy.core.enums import Purpose, Granularity
@@ -123,9 +124,23 @@ def resolve_join_order(joins: List[BaseJoin]) -> List[BaseJoin]:
     return final_joins
 
 
+def add_node_join_concept(graph, concept, datasource, concepts):
+    # we don't need to join on a concept if all of the keys exist in the grain
+    # if concept.keys and all([x in grain for x in concept.keys]):
+    #     continue
+    concepts.append(concept)
+
+    graph.add_node(concept.address, type=NodeType.CONCEPT)
+    graph.add_edge(datasource.identifier, concept.address)
+    for k, v in concept.pseudonyms.items():
+        if v.address != concept.address:
+            add_node_join_concept(graph, v, datasource, concepts)
+
+
 def get_node_joins(
     datasources: List[QueryDatasource],
     grain: List[Concept],
+    environment: Environment,
     # concepts:List[Concept],
 ) -> List[BaseJoin]:
     graph = nx.Graph()
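The new add_node_join_concept helper registers a concept and, recursively, every distinct pseudonym of it against the owning datasource node, so pseudonyms become join candidates. A minimal sketch of the resulting graph shape, using networkx directly and hypothetical addresses:

import networkx as nx

graph = nx.Graph()
graph.add_node("ds~orders", type="node")
# registering "order_id" also wires up its pseudonym "order_key"
for address in ("order_id", "order_key"):
    graph.add_node(address, type="concept")
    graph.add_edge("ds~orders", address)

# both addresses are now reachable join candidates from the datasource
assert set(graph.neighbors("ds~orders")) == {"order_id", "order_key"}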
@@ -133,12 +148,14 @@ def get_node_joins(
     for datasource in datasources:
         graph.add_node(datasource.identifier, type=NodeType.NODE)
         for concept in datasource.output_concepts:
+            add_node_join_concept(graph, concept, datasource, concepts)
             # we don't need to join on a concept if all of the keys exist in the grain
             # if concept.keys and all([x in grain for x in concept.keys]):
             #     continue
-            concepts.append(concept)
-            graph.add_node(concept.address, type=NodeType.CONCEPT)
-            graph.add_edge(datasource.identifier, concept.address)
+            # concepts.append(concept)
+
+            # graph.add_node(concept.address, type=NodeType.CONCEPT)
+            # graph.add_edge(datasource.identifier, concept.address)
 
     # add edges for every constant to every datasource
     for datasource in datasources:
@@ -149,15 +166,55 @@ def get_node_joins(
             graph.add_edge(node, concept.address)
 
     joins: defaultdict[str, set] = defaultdict(set)
-    identifier_map = {x.identifier: x for x in datasources}
+    identifier_map: dict[str, Datasource | QueryDatasource] = {
+        x.identifier: x for x in datasources
+    }
+
+    grain_pseudonyms: set[str] = set()
+    for g in grain:
+        env_lookup = environment.concepts[g.address]
+        # if we're looking up a pseudonym, we would have gotten the remapped value
+        # so double check we got what we were looking for
+        if env_lookup.address == g.address:
+            grain_pseudonyms.update(env_lookup.pseudonyms.keys())
 
     node_list = sorted(
         [x for x in graph.nodes if graph.nodes[x]["type"] == NodeType.NODE],
         # sort so that anything with a partial match on the target is later
         key=lambda x: len(
-            [x for x in identifier_map[x].partial_concepts if x in grain]
+            [
+                partial
+                for partial in identifier_map[x].partial_concepts
+                if partial in grain
+            ]
+            + [
+                output
+                for output in identifier_map[x].output_concepts
+                if output.address in grain_pseudonyms
+            ]
         ),
     )
+
+    node_map = {
+        x[0:20]: len(
+            [
+                partial
+                for partial in identifier_map[x].partial_concepts
+                if partial in grain
+            ]
+            + [
+                output
+                for output in identifier_map[x].output_concepts
+                if output.address in grain_pseudonyms
+            ]
+        )
+        for x in node_list
+    }
+    print("NODE MAP")
+    print(node_map)
+    print([x.address for x in grain])
+    print(grain_pseudonyms)
+
     for left in node_list:
         # the constant dataset is a special case
         # and can never be on the left of a join
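The revised sort key pushes a datasource later in join order for each partial concept it has in the grain and for each output that is merely a pseudonym of a grain concept. A toy illustration of the ordering behavior (names hypothetical):

# score = count of partial-in-grain concepts + pseudonym-only outputs
scores = {"ds~orders": 0, "ds~customers": 1, "ds~orders_partial": 2}
node_list = sorted(scores, key=lambda x: scores[x])
# clean matches join first; partial and pseudonym matches join last
assert node_list == ["ds~orders", "ds~customers", "ds~orders_partial"]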
@@ -203,14 +260,49 @@ def get_node_joins(
                 c for c in local_concepts if c.granularity != Granularity.SINGLE_ROW
             ]
 
-            # if concept.keys and all([x in grain for x in concept.keys]):
-            #     continue
+            relevant = concept_to_relevant_joins(local_concepts)
+            left_datasource = identifier_map[left]
+            right_datasource = identifier_map[right]
+            join_tuples = []
+            for joinc in relevant:
+                left_arg = joinc
+                right_arg = joinc
+                if joinc.address not in [
+                    c.address for c in left_datasource.output_concepts
+                ]:
+                    try:
+                        left_arg = [
+                            x
+                            for x in left_datasource.output_concepts
+                            if x.address in joinc.pseudonyms
+                            or joinc.address in x.pseudonyms
+                        ].pop()
+                    except IndexError:
+                        raise SyntaxError(
+                            f"Could not find {joinc.address} in {left_datasource.identifier} output {[c.address for c in left_datasource.output_concepts]}"
+                        )
+                if joinc.address not in [
+                    c.address for c in right_datasource.output_concepts
+                ]:
+                    try:
+                        right_arg = [
+                            x
+                            for x in right_datasource.output_concepts
+                            if x.address in joinc.pseudonyms
+                            or joinc.address in x.pseudonyms
+                        ].pop()
+                    except IndexError:
+                        raise SyntaxError(
+                            f"Could not find {joinc.address} in {right_datasource.identifier} output {[c.address for c in right_datasource.output_concepts]}"
+                        )
+                join_tuples.append((left_arg, right_arg))
             final_joins_pre.append(
                 BaseJoin(
                     left_datasource=identifier_map[left],
                     right_datasource=identifier_map[right],
                     join_type=join_type,
-                    concepts=concept_to_relevant_joins(local_concepts),
+                    concepts=[],
+                    concept_pairs=join_tuples,
                 )
             )
     final_joins = resolve_join_order(final_joins_pre)
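Join arguments are now resolved per side: if the join concept is not in a side's outputs, the code substitutes an output concept related to it through pseudonyms, and fails loudly otherwise. A condensed sketch of that fallback, with plain dicts standing in for concept objects:

def resolve_join_arg(join_address: str, join_pseudonyms: set[str],
                     outputs: dict[str, set[str]]) -> str:
    # outputs maps each output concept address to its pseudonym addresses
    if join_address in outputs:
        return join_address
    candidates = [
        address for address, pseudonyms in outputs.items()
        if address in join_pseudonyms or join_address in pseudonyms
    ]
    if not candidates:
        raise SyntaxError(f"Could not find {join_address} in outputs")
    return candidates.pop()

# "order_key" is not output directly, but "order_id" is a pseudonym of it
assert resolve_join_arg("order_key", {"order_id"}, {"order_id": set()}) == "order_id"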
trilogy/core/query_processor.py CHANGED
@@ -4,8 +4,11 @@ from trilogy.core.env_processor import generate_graph
 from trilogy.core.graph_models import ReferenceGraph
 from trilogy.core.constants import CONSTANT_DATASET
 from trilogy.core.processing.concept_strategies_v3 import source_query_concepts
+from trilogy.core.enums import SelectFiltering
 from trilogy.constants import CONFIG, DEFAULT_NAMESPACE
+from trilogy.core.processing.nodes import GroupNode, SelectNode, StrategyNode
 from trilogy.core.models import (
+    Concept,
     Environment,
     PersistStatement,
     SelectStatement,
@@ -24,13 +27,14 @@ from trilogy.core.models import (
 )
 
 from trilogy.utility import unique
-from collections import defaultdict
+
 from trilogy.hooks.base_hook import BaseHook
 from trilogy.constants import logger
-from random import shuffle
 from trilogy.core.ergonomics import CTE_NAMES
 from trilogy.core.optimization import optimize_ctes
 from math import ceil
+from collections import defaultdict
+from random import shuffle
 
 LOGGER_PREFIX = "[QUERY BUILD]"
 
@@ -79,6 +83,7 @@ def base_join_to_join(
         right_cte=right_cte,
         joinkeys=[JoinKey(concept=concept) for concept in base_join.concepts],
         jointype=base_join.join_type,
+        joinkey_pairs=base_join.concept_pairs if base_join.concept_pairs else None,
     )
 
 
@@ -107,7 +112,9 @@ def generate_source_map(
         matches = [cte for cte in all_new_ctes if cte.source.name in names]
 
         if not matches and names:
-            raise SyntaxError(query_datasource.source_map)
+            raise SyntaxError(
+                f"Missing parent CTEs for source map; expecting {names}, have {[cte.source.name for cte in all_new_ctes]}"
+            )
         for cte in matches:
             output_address = [
                 x.address
@@ -260,7 +267,7 @@ def datasource_to_ctes(
 
     human_id = generate_cte_name(query_datasource.full_name, name_map)
     logger.info(
-        f"Finished building source map for {human_id} with {len(parents)} parents, have {source_map}, parent had non-empty keys {[k for k, v in query_datasource.source_map.items() if v]} "
+        f"Finished building source map for {human_id} with {len(parents)} parents, have {source_map}, query_datasource had non-empty keys {[k for k, v in query_datasource.source_map.items() if v]} "
     )
     final_joins = [
         x
@@ -306,6 +313,28 @@ def datasource_to_ctes(
     return output
 
 
+def append_existence_check(
+    node: StrategyNode, environment: Environment, graph: ReferenceGraph
+):
+    # if we have a where clause doing an existence check,
+    # treat that as a separate subquery
+    if (where := node.conditions) and where.existence_arguments:
+        for subselect in where.existence_arguments:
+            if not subselect:
+                continue
+            logger.info(
+                f"{LOGGER_PREFIX} fetching existence clause inputs {[str(c) for c in subselect]}"
+            )
+            eds = source_query_concepts([*subselect], environment=environment, g=graph)
+
+            final_eds = eds.resolve()
+            first_parent = node.resolve()
+            first_parent.datasources.append(final_eds)
+            for x in final_eds.output_concepts:
+                if x.address not in first_parent.existence_source_map:
+                    first_parent.existence_source_map[x.address] = {final_eds}
+
+
 def get_query_datasources(
     environment: Environment,
     statement: SelectStatement | MultiSelectStatement,
@@ -318,33 +347,57 @@
     )
     if not statement.output_components:
         raise ValueError(f"Statement has no output components {statement}")
-    ds = source_query_concepts(
-        statement.output_components,
+
+    search_concepts: list[Concept] = statement.output_components
+    nest_where = statement.where_clause_category == SelectFiltering.IMPLICIT
+    if nest_where and statement.where_clause:
+        search_concepts = unique(
+            statement.where_clause.row_arguments + search_concepts, "address"
+        )
+        nest_where = True
+
+    ods = source_query_concepts(
+        search_concepts,
         environment=environment,
         g=graph,
     )
+    ds: GroupNode | SelectNode
+    if nest_where and statement.where_clause:
+        ods.conditions = statement.where_clause.conditional
+        ods.output_concepts = search_concepts
+        # ods.hidden_concepts = where_delta
+        ods.rebuild_cache()
+        append_existence_check(ods, environment, graph)
+        ds = GroupNode(
+            output_concepts=statement.output_components,
+            input_concepts=search_concepts,
+            parents=[ods],
+            environment=ods.environment,
+            g=ods.g,
+            partial_concepts=ods.partial_concepts,
+        )
+        # we can still check existence here.
+
+    elif statement.where_clause:
+        ds = SelectNode(
+            output_concepts=statement.output_components,
+            input_concepts=ods.input_concepts,
+            parents=[ods],
+            environment=ods.environment,
+            g=ods.g,
+            partial_concepts=ods.partial_concepts,
+            conditions=statement.where_clause.conditional,
+        )
+        append_existence_check(ds, environment, graph)
+
+    else:
+        ds = ods
+
+    final_qds = ds.resolve()
     if hooks:
         for hook in hooks:
             hook.process_root_strategy_node(ds)
-    final_qds = ds.resolve()
 
-    # we if we have a where clause doing an existence check
-    # treat that as separate subquery
-    if (where := statement.where_clause) and where.existence_arguments:
-        for subselect in where.existence_arguments:
-            if not subselect:
-                continue
-            logger.info(
-                f"{LOGGER_PREFIX} fetching existance clause inputs {[str(c) for c in subselect]}"
-            )
-            eds = source_query_concepts([*subselect], environment=environment, g=graph)
-
-            final_eds = eds.resolve()
-            first_parent = final_qds
-            first_parent.datasources.append(final_eds)
-            for x in final_eds.output_concepts:
-                if x.address not in first_parent.existence_source_map:
-                    first_parent.existence_source_map[x.address] = {final_eds}
     return final_qds
 
 
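Net effect of the refactor: implicit where clauses are evaluated in the sourced parent and re-aggregated through a GroupNode, any other where clause is applied by a wrapping SelectNode, and unfiltered statements resolve directly. A rough decision sketch with stand-in types, not the real node constructors:

from enum import Enum

class SelectFiltering(Enum):
    NONE = "none"
    IMPLICIT = "implicit"
    EXPLICIT = "explicit"

def choose_wrapper(filtering: SelectFiltering) -> str:
    if filtering == SelectFiltering.IMPLICIT:
        return "GroupNode"   # filter in the parent, regroup to output grain
    if filtering == SelectFiltering.EXPLICIT:
        return "SelectNode"  # apply conditions in a wrapping select
    return "parent"          # no where clause: resolve the sourced node as-is

assert choose_wrapper(SelectFiltering.IMPLICIT) == "GroupNode"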
trilogy/dialect/base.py CHANGED
@@ -5,7 +5,6 @@ from jinja2 import Template
 from trilogy.constants import CONFIG, logger, MagicConstants
 from trilogy.core.internal import DEFAULT_CONCEPTS
 from trilogy.core.enums import (
-    Purpose,
     FunctionType,
     WindowType,
     DatePart,
@@ -40,18 +39,17 @@ from trilogy.core.models import (
     ShowStatement,
     RowsetItem,
     MultiSelectStatement,
-    MergeStatement,
     RowsetDerivationStatement,
     ConceptDeclarationStatement,
     ImportStatement,
     RawSQLStatement,
     ProcessedRawSQLStatement,
     NumericType,
+    MergeStatementV2,
 )
 from trilogy.core.query_processor import process_query, process_persist
 from trilogy.dialect.common import render_join
 from trilogy.hooks.base_hook import BaseHook
-from trilogy.utility import unique
 from trilogy.core.enums import UnnestMode
 
 LOGGER_PREFIX = "[RENDERING]"
@@ -264,7 +262,10 @@ class BaseDialect:
             rval = f"{self.WINDOW_FUNCTION_MAP[c.lineage.type](concept = self.render_concept_sql(c.lineage.content, cte=cte, alias=False), window=','.join(rendered_over_components), sort=','.join(rendered_order_components))}"  # noqa: E501
         elif isinstance(c.lineage, FilterItem):
             # for cases when we've optimized this
-            if len(cte.output_columns) == 1:
+            if (
+                len(cte.output_columns) == 1
+                and cte.condition == c.lineage.where.conditional
+            ):
                 rval = self.render_expr(c.lineage.content, cte=cte)
             else:
                 rval = f"CASE WHEN {self.render_expr(c.lineage.where.conditional, cte=cte)} THEN {self.render_concept_sql(c.lineage.content, cte=cte, alias=False)} ELSE NULL END"
@@ -272,8 +273,6 @@ class BaseDialect:
             rval = f"{self.render_concept_sql(c.lineage.content, cte=cte, alias=False)}"
         elif isinstance(c.lineage, MultiSelectStatement):
             rval = f"{self.render_concept_sql(c.lineage.find_source(c, cte), cte=cte, alias=False)}"
-        elif isinstance(c.lineage, MergeStatement):
-            rval = f"{self.render_concept_sql(c.lineage.find_source(c, cte), cte=cte, alias=False)}"
         elif isinstance(c.lineage, AggregateWrapper):
             args = [
                 self.render_expr(v, cte)  # , alias=False)
@@ -509,35 +508,7 @@ class BaseDialect:
                 set(
                     [
                         self.render_concept_sql(c, cte, alias=False)
-                        for c in unique(
-                            cte.grain.components
-                            + [
-                                c
-                                for c in cte.output_columns
-                                if c.purpose in (Purpose.PROPERTY, Purpose.KEY)
-                                and c.address
-                                not in [x.address for x in cte.grain.components]
-                            ]
-                            + [
-                                c
-                                for c in cte.output_columns
-                                if c.purpose == Purpose.METRIC
-                                and any(
-                                    [
-                                        c.with_grain(cte.grain)
-                                        in cte.output_columns
-                                        for cte in cte.parent_ctes
-                                    ]
-                                )
-                            ]
-                            + [
-                                c
-                                for c in cte.output_columns
-                                if c.purpose == Purpose.CONSTANT
-                                and cte.source_map[c.address] != []
-                            ],
-                            "address",
-                        )
+                        for c in cte.group_concepts
                     ]
                 )
             )
@@ -563,9 +534,9 @@ class BaseDialect:
             | ShowStatement
             | ConceptDeclarationStatement
             | RowsetDerivationStatement
-            | MergeStatement
             | ImportStatement
             | RawSQLStatement
+            | MergeStatementV2
         ],
         hooks: Optional[List[BaseHook]] = None,
     ) -> List[
@@ -626,7 +597,7 @@ class BaseDialect:
                 statement,
                 (
                     ConceptDeclarationStatement,
-                    MergeStatement,
+                    MergeStatementV2,
                     ImportStatement,
                     RowsetDerivationStatement,
                 ),
@@ -675,7 +646,7 @@ class BaseDialect:
         # where assignment
         output_where = False
         if query.where_clause:
-            found = False
+            # found = False
             filter = set(
                 [
                     str(x.address)
@@ -684,16 +655,10 @@ class BaseDialect:
                 ]
             )
             query_output = set([str(z.address) for z in query.output_columns])
+            # if it wasn't an output,
+            # we would have forced it up earlier and we don't need to render at this point
             if filter.issubset(query_output):
                 output_where = True
-                found = True
-
-            if not found:
-                raise NotImplementedError(
-                    f"Cannot generate query with filtering on row arguments {filter} that is"
-                    f" not a subset of the query output grain {query_output}. Try a"
-                    " filtered concept instead, or include it in the select clause"
-                )
             for ex_set in query.where_clause.existence_arguments:
                 for c in ex_set:
                     if c.address not in cte_output_map:
trilogy/dialect/bigquery.py CHANGED
@@ -43,7 +43,7 @@ CREATE OR REPLACE TABLE {{ output.address.location }} AS
 {% endif %}{%- if ctes %}
 WITH {% for cte in ctes %}
 {{cte.name}} as ({{cte.statement}}){% if not loop.last %},{% endif %}{% endfor %}{% endif %}
-{%- if full_select -%}
+{%- if full_select %}
 {{full_select}}
 {% else -%}
 
trilogy/dialect/common.py CHANGED
@@ -39,6 +39,17 @@ def render_join(
         )
         for key in join.joinkeys
     ]
+    if join.joinkey_pairs:
+        base_joinkeys.extend(
+            [
+                null_wrapper(
+                    f"{left_name}.{quote_character}{join.left_cte.get_alias(left_concept) if isinstance(join.left_cte, Datasource) else left_concept.safe_address}{quote_character}",
+                    f"{right_name}.{quote_character}{join.right_cte.get_alias(right_concept) if isinstance(join.right_cte, Datasource) else right_concept.safe_address}{quote_character}",
+                    left_concept,
+                )
+                for left_concept, right_concept in join.joinkey_pairs
+            ]
+        )
     if not base_joinkeys:
         base_joinkeys = ["1=1"]
     joinkeys = " AND ".join(base_joinkeys)
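With concept_pairs flowing through from get_node_joins, each pair can render an equality between two differently named columns. Roughly, the emitted fragment looks like this (hypothetical CTE names and quote character; null_wrapper may add null-safe handling on top):

left_name, right_name, quote = "cte_orders", "cte_order_keys", '"'
left_concept, right_concept = "order_id", "order_key"  # safe addresses
clause = (
    f"{left_name}.{quote}{left_concept}{quote}"
    f" = {right_name}.{quote}{right_concept}{quote}"
)
assert clause == 'cte_orders."order_id" = cte_order_keys."order_key"'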
trilogy/dialect/duckdb.py CHANGED
@@ -59,7 +59,7 @@ SELECT
     {{ join }}{% endfor %}{% endif %}
 {% if where %}WHERE
     {{ where }}
-{% endif %}
+{% endif -%}
 {%- if group_by %}
 GROUP BY {% for group in group_by %}
     {{group}}{% if not loop.last %},{% endif %}{% endfor %}{% endif %}
trilogy/dialect/presto.py CHANGED
@@ -14,6 +14,7 @@ FUNCTION_MAP = {
     FunctionType.SUM: lambda x: f"sum({x[0]})",
     FunctionType.LENGTH: lambda x: f"length({x[0]})",
     FunctionType.AVG: lambda x: f"avg({x[0]})",
+    FunctionType.INDEX_ACCESS: lambda x: f"element_at({x[0]},{x[1]})",
     FunctionType.LIKE: lambda x: (
         f" CASE WHEN {x[0]} like {x[1]} THEN True ELSE False END"
     ),
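The new Presto mapping renders index access through element_at, which is 1-indexed and returns NULL for out-of-range positions. The lambda simply interpolates pre-rendered argument strings:

INDEX_ACCESS = lambda x: f"element_at({x[0]},{x[1]})"
assert INDEX_ACCESS(["split(name, '-')", "2"]) == "element_at(split(name, '-'),2)"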
trilogy/hooks/graph_hook.py CHANGED
@@ -1,5 +1,5 @@
 from trilogy.hooks.base_hook import BaseHook
-from networkx import DiGraph
+import networkx as nx
 
 
 class GraphHook(BaseHook):
@@ -10,8 +10,12 @@ class GraphHook(BaseHook):
         except ImportError:
             raise ImportError("GraphHook requires matplotlib and scipy to be installed")
 
-    def query_graph_built(self, graph: DiGraph):
-        from networkx import draw_kamada_kawai
+    def query_graph_built(
+        self,
+        graph: nx.DiGraph,
+        target: str | None = None,
+        highlight_nodes: list[str] | None = None,
+    ):
         from matplotlib import pyplot as plt
 
         graph = graph.copy()
@@ -19,6 +23,47 @@ class GraphHook(BaseHook):
         for node in nodes:
             if "__preql_internal" in node:
                 graph.remove_node(node)
-        draw_kamada_kawai(graph, with_labels=True, connectionstyle="arc3, rad = 0.1")
-        # draw_spring(graph, with_labels=True, connectionstyle='arc3, rad = 0.1')
+        graph.remove_nodes_from(list(nx.isolates(graph)))
+        color_map = []
+        highlight_nodes = highlight_nodes or []
+        for node in graph:
+            if node in highlight_nodes:
+                color_map.append("orange")
+            elif str(node).startswith("ds"):
+                color_map.append("blue")
+            else:
+                color_map.append("green")
+        # pos = nx.kamada_kawai_layout(graph, scale=2)
+        pos = nx.spring_layout(graph)
+        kwargs = {}
+        if target:
+            edge_colors = []
+            descendents = nx.descendants(graph, target)
+            for edge in graph.edges():
+                if edge[0] == target:
+                    edge_colors.append("blue")
+                elif edge[1] == target:
+                    edge_colors.append("blue")
+                elif edge[1] in descendents:
+                    edge_colors.append("green")
+                else:
+                    edge_colors.append("black")
+            kwargs["edge_color"] = edge_colors
+        nx.draw(
+            graph,
+            pos=pos,
+            node_color=color_map,
+            connectionstyle="arc3, rad = 0.1",
+            **kwargs,
+        )  # draw the base graph
+        # recalculate a dictionary of adjusted label positions per node
+        pos_labels = {}
+        for aNode in graph.nodes():
+            # get the node's position from the layout
+            x, y = pos[aNode]
+            # pos_labels[aNode] = (x + slopeX * label_ratio, y + slopeY * label_ratio)
+            pos_labels[aNode] = (x, y)
+        # finally, redraw the labels at their new positions
+        nx.draw_networkx_labels(graph, pos=pos_labels, font_size=10)
         plt.show()
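Usage sketch for the extended hook: pass a target node to color the edges around it and a list of nodes to highlight in orange. The node-name formats here are assumptions; rendering still requires matplotlib and scipy, as the constructor's ImportError indicates.

import networkx as nx
from trilogy.hooks.graph_hook import GraphHook

graph = nx.DiGraph()
graph.add_edge("ds~orders", "c~local.order_id")  # hypothetical addresses

GraphHook().query_graph_built(
    graph,
    target="c~local.order_id",             # color edges touching this node
    highlight_nodes=["c~local.order_id"],  # draw these nodes in orange
)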
trilogy/hooks/query_debugger.py CHANGED
@@ -130,4 +130,5 @@ class DebuggingHook(BaseHook):
         if self.process_nodes != PrintMode.OFF:
             printed = print_recursive_nodes(node, mode=self.process_nodes)
             for row in printed:
+                logger.info("".join([str(v) for v in row]))
                 print("".join([str(v) for v in row]))
trilogy/parsing/common.py CHANGED
@@ -53,6 +53,7 @@ def constant_to_concept(
         output_purpose=Purpose.CONSTANT,
         arguments=[parent],
     )
+    fmetadata = metadata or Metadata()
     return Concept(
         name=name,
         datatype=const_function.output_datatype,
@@ -60,7 +61,7 @@
         lineage=const_function,
         grain=const_function.output_grain,
         namespace=namespace,
-        metadata=metadata,
+        metadata=fmetadata,
     )
 
 
@@ -105,13 +106,13 @@ def filter_item_to_concept(
     purpose: Purpose | None = None,
     metadata: Metadata | None = None,
 ) -> Concept:
-
+    fmetadata = metadata or Metadata()
     return Concept(
         name=name,
         datatype=parent.content.datatype,
         purpose=parent.content.purpose,
         lineage=parent,
-        metadata=metadata,
+        metadata=fmetadata,
         namespace=namespace,
         # filtered copies cannot inherit keys
         keys=None,
@@ -130,6 +131,7 @@ def window_item_to_concept(
     purpose: Purpose | None = None,
     metadata: Metadata | None = None,
 ) -> Concept:
+    fmetadata = metadata or Metadata()
     local_purpose, keys = get_purpose_and_keys(purpose, (parent.content,))
     if parent.order_by:
         grain = parent.over + [parent.content.output]
@@ -142,7 +144,7 @@ def window_item_to_concept(
         datatype=parent.content.datatype,
         purpose=local_purpose,
         lineage=parent,
-        metadata=metadata,
+        metadata=fmetadata,
         # filters are implicitly at the grain of the base item
         grain=Grain(components=grain),
         namespace=namespace,
@@ -162,12 +164,13 @@ def agg_wrapper_to_concept(
     )
     # anything grouped to a grain should be a property
     # at that grain
+    fmetadata = metadata or Metadata()
     aggfunction = parent.function
     out = Concept(
         name=name,
         datatype=aggfunction.output_datatype,
         purpose=Purpose.METRIC,
-        metadata=metadata,
+        metadata=fmetadata,
         lineage=parent,
         grain=Grain(components=parent.by) if parent.by else Grain(),
         namespace=namespace,