relationalai 0.13.4__py3-none-any.whl → 0.13.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. relationalai/clients/exec_txn_poller.py +51 -20
  2. relationalai/clients/local.py +15 -7
  3. relationalai/clients/resources/snowflake/__init__.py +2 -2
  4. relationalai/clients/resources/snowflake/direct_access_resources.py +8 -4
  5. relationalai/clients/resources/snowflake/snowflake.py +16 -11
  6. relationalai/experimental/solvers.py +8 -0
  7. relationalai/semantics/lqp/executor.py +3 -3
  8. relationalai/semantics/lqp/model2lqp.py +34 -28
  9. relationalai/semantics/lqp/passes.py +6 -3
  10. relationalai/semantics/lqp/result_helpers.py +76 -12
  11. relationalai/semantics/lqp/rewrite/__init__.py +2 -0
  12. relationalai/semantics/lqp/rewrite/extract_common.py +3 -1
  13. relationalai/semantics/lqp/rewrite/extract_keys.py +85 -20
  14. relationalai/semantics/lqp/rewrite/flatten_script.py +301 -0
  15. relationalai/semantics/lqp/rewrite/functional_dependencies.py +12 -7
  16. relationalai/semantics/lqp/rewrite/quantify_vars.py +12 -3
  17. relationalai/semantics/lqp/rewrite/unify_definitions.py +9 -3
  18. relationalai/semantics/metamodel/dependency.py +9 -0
  19. relationalai/semantics/metamodel/executor.py +17 -10
  20. relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
  21. relationalai/semantics/metamodel/rewrite/flatten.py +1 -2
  22. relationalai/semantics/metamodel/rewrite/format_outputs.py +131 -46
  23. relationalai/semantics/metamodel/rewrite/handle_aggregations_and_ranks.py +237 -0
  24. relationalai/semantics/metamodel/typer/typer.py +1 -1
  25. relationalai/semantics/reasoners/optimization/solvers_pb.py +101 -107
  26. relationalai/semantics/rel/compiler.py +7 -3
  27. relationalai/semantics/rel/executor.py +1 -1
  28. relationalai/tools/txn_progress.py +188 -0
  29. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/METADATA +1 -1
  30. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/RECORD +33 -30
  31. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/WHEEL +0 -0
  32. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/entry_points.txt +0 -0
  33. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/licenses/LICENSE +0 -0
@@ -14,9 +14,11 @@ import textwrap
  import time
  import uuid
  from typing import Any, Optional
+ from concurrent.futures import ThreadPoolExecutor

  from relationalai.experimental.solvers import Solver
  from relationalai.semantics.internal import internal as b
+ from relationalai.semantics import std
  from relationalai.semantics.rel.executor import RelExecutor
  from relationalai.tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
  from relationalai.util.timeout import calc_remaining_timeout_minutes
@@ -126,6 +128,9 @@ class SolverModelPB:
  self._model = model
  self._num_type = num_type
  self._id = next(b._global_id)
+ self._model_id = str(uuid.uuid4()).upper().replace('-', '_')
+ self._load_point_counter = 0
+
  # Maps relationships to their corresponding variable concepts
  self._variable_relationships: dict[b.Relationship, b.Concept] = {}
  prefix_uppercase = f"SolverModel_{self._id}_"
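For reference, a quick illustration of the `_model_id` format introduced above: an uppercase UUID with dashes replaced by underscores. The sample value below is hypothetical; later hunks embed this id in Snowflake result-table names such as `RESULTS.SOLVERS_{model_id}_POINTS`.

```python
# Shape of the _model_id built above: an uppercase UUID with '-' replaced
# by '_' so it can be embedded in Snowflake object names.
import uuid

model_id = str(uuid.uuid4()).upper().replace('-', '_')
print(model_id)  # e.g. 0A1B2C3D_4E5F_6789_ABCD_EF0123456789  (example value)
```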
@@ -160,6 +165,17 @@ class SolverModelPB:
  f"{{var}} has {{value:{result_type}}}",
  short_name=(prefix_lowercase + "point"),
  )
+ # TODO (dba) Remove in favor of Loopy when possible.
+
+ # This is a helper relation that allows us to model
+ # `.load_point()` behavior. Results are loaded from SF views
+ # and we cannot swap which solution `.point` currently points
+ # to. So we keep a trace of updates. Each time a user calls
+ # `.load_point()` we derive a new tuple into this relation and
+ # use `max` to select the latest update.
+ self._point_updates = model.Property(
+ "at {update:int} we load {point:int}"
+ )
  self.points = model.Property(
  f"point {{i:int}} for {{var}} has {{value:{result_type}}}",
  short_name=(prefix_lowercase + "points"),
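The comment above describes an append-only trace of `.load_point()` calls, with `max` picking the latest update. A minimal plain-Python sketch of that selection rule (not the RAI API; the names here are illustrative only):

```python
point_updates: list[tuple[int, int]] = []  # (update counter, point index)

def load_point(point_index: int) -> None:
    # Each call appends a new tuple; earlier updates are never overwritten.
    point_updates.append((len(point_updates), point_index))

def current_point() -> int | None:
    # The active point comes from the update with the highest counter,
    # mirroring the max-based selection described in the comment.
    return max(point_updates)[1] if point_updates else None

load_point(1)  # default: solution 1
load_point(3)  # user switches to solution 3
assert current_point() == 3
```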
@@ -190,7 +206,6 @@ class SolverModelPB:
  """
  b.define(b.RawSource("rel", textwrap.dedent(install_rel)))

-
  # -------------------------------------------------------------------------
  # Variable Handling
  # -------------------------------------------------------------------------
@@ -601,101 +616,69 @@ class SolverModelPB:

  executor.execute_raw(export_rel, query_timeout_mins=query_timeout_mins)

- def _import_solver_results_from_csv(
- self,
- model_id: str,
- executor: RelExecutor,
- prefix_lowercase: str,
- query_timeout_mins: Optional[int] = None
- ) -> None:
- """Import solver results from CSV files in Snowflake stage.
-
- Loads and extracts CSV files in a single transaction to minimize overhead.
-
- Args:
- model_id: Unique model identifier for stage paths.
- executor: RelExecutor instance.
- prefix_lowercase: Prefix for relation names.
- query_timeout_mins: Query timeout in minutes.
- """
- result_stage_base = f"snowflake://APP_STATE.RAI_INTERNAL_STAGE/SOLVERS/job_{model_id}/results"
- value_parse_fn = "parse_int" if self._num_type == "int" else "parse_float"
-
- # Single transaction: Load CSV files and extract/map results
- # Use inline definitions to avoid needing declared relations
- load_and_extract_rel = textwrap.dedent(f"""
- // Define CSV loading inline (no declare needed)
- // Load ancillary.csv - contains solver metadata (NAME, VALUE columns)
- def ancillary_config[:path]: "{result_stage_base}/ancillary.csv.gz"
- def ancillary_config[:syntax, :header_row]: 1
- def ancillary_config[:schema, :NAME]: "string"
- def ancillary_config[:schema, :VALUE]: "string"
- def {prefix_lowercase}solver_ancillary_raw {{load_csv[ancillary_config]}}
-
- // Load objective_values.csv - contains objective values (SOL_INDEX, VALUE columns)
- def objective_values_config[:path]: "{result_stage_base}/objective_values.csv.gz"
- def objective_values_config[:syntax, :header_row]: 1
- def objective_values_config[:schema, :SOL_INDEX]: "string"
- def objective_values_config[:schema, :VALUE]: "string"
- def {prefix_lowercase}solver_objective_values_raw {{load_csv[objective_values_config]}}
-
- // Load points.csv.gz - contains solution points (SOL_INDEX, VAR_HASH, VALUE columns)
- def points_config[:path]: "{result_stage_base}/points.csv.gz"
- def points_config[:syntax, :header_row]: 1
- def points_config[:schema, :SOL_INDEX]: "string"
- def points_config[:schema, :VAR_HASH]: "string"
- def points_config[:schema, :VALUE]: "string"
- def {prefix_lowercase}solver_points_raw {{load_csv[points_config]}}
-
- // Clear existing result data
- def delete[:{self.result_info._name}]: {self.result_info._name}
- def delete[:{self.point._name}]: {self.point._name}
- def delete[:{self.points._name}]: {self.points._name}
-
- // Extract ancillary data (result info) - NAME and VALUE columns
- def insert(:{self.result_info._name}, key, val): {{
- exists((row) |
- {prefix_lowercase}solver_ancillary_raw(:NAME, row, key) and
- {prefix_lowercase}solver_ancillary_raw(:VALUE, row, val))
- }}
-
- // Extract objective value from objective_values CSV (first solution)
- def insert(:{self.result_info._name}, "objective_value", val): {{
- exists((row) |
- {prefix_lowercase}solver_objective_values_raw(:SOL_INDEX, row, "1") and
- {prefix_lowercase}solver_objective_values_raw(:VALUE, row, val))
- }}
-
- // Extract solution points from points.csv.gz into points property
- // This file has SOL_INDEX, VAR_HASH, VALUE columns
- // Convert CSV string index to Int128 for points property signature
- // Convert value to Int128 (for int) or Float64 (for float)
- def insert(:{self.points._name}, sol_idx_int128, var, val_converted): {{
- exists((row, sol_idx_str, var_hash_str, val_str, sol_idx_int, val) |
- {prefix_lowercase}solver_points_raw(:SOL_INDEX, row, sol_idx_str) and
- {prefix_lowercase}solver_points_raw(:VAR_HASH, row, var_hash_str) and
- {prefix_lowercase}solver_points_raw(:VALUE, row, val_str) and
- parse_int(sol_idx_str, sol_idx_int) and
- parse_uuid(var_hash_str, var) and
- {value_parse_fn}(val_str, val) and
- ::std::mirror::convert(std::mirror::typeof[Int128], sol_idx_int, sol_idx_int128) and
- {'::std::mirror::convert(std::mirror::typeof[Int128], val, val_converted)' if self._num_type == 'int' else '::std::mirror::convert(std::mirror::typeof[Float64], val, val_converted)'})
- }}
-
- // Extract first solution into point property (default solution)
- // Filter to SOL_INDEX = 1
- def insert(:{self.point._name}, var, val_converted): {{
- exists((row, var_hash_str, val_str, val) |
- {prefix_lowercase}solver_points_raw(:SOL_INDEX, row, "1") and
- {prefix_lowercase}solver_points_raw(:VAR_HASH, row, var_hash_str) and
- {prefix_lowercase}solver_points_raw(:VALUE, row, val_str) and
- parse_uuid(var_hash_str, var) and
- {value_parse_fn}(val_str, val) and
- {'::std::mirror::convert(std::mirror::typeof[Int128], val, val_converted)' if self._num_type == 'int' else '::std::mirror::convert(std::mirror::typeof[Float64], val, val_converted)'})
- }}
- """)
-
- executor.execute_raw(load_and_extract_rel, query_timeout_mins=query_timeout_mins)
+ def _import_solver_results_from_views(self) -> None:
+ from relationalai.semantics.internal.snowflake import Table
+
+ resources = self._model._to_executor().resources
+ app_name = resources.get_app_name()
+ logic_engine_name = resources.get_default_engine_name()
+ model_id = self._model_id
+
+ query = f"select {resources.get_app_name()}.api.get_index_db_name(?)"
+
+ # To prevent sql injection, pass the variable as a separate argument to the executor
+ index_db_name = resources._exec(query, [self._model.name])[0][0]
+
+ # Result tables.
+ ancillary_tbl = Table(f"{resources.get_app_name()}.RESULTS.SOLVERS_{model_id}_ANCILLARY")
+ obj_vals_tbl = Table(f"{resources.get_app_name()}.RESULTS.SOLVERS_{model_id}_OBJECTIVE_VALUES")
+ points_tbl = Table(f"{resources.get_app_name()}.RESULTS.SOLVERS_{model_id}_POINTS")
+
+ # Skip cdc as we already make sure the tables are loaded.
+ ancillary_tbl._skip_cdc = True
+ obj_vals_tbl._skip_cdc = True
+ points_tbl._skip_cdc = True
+
+ # Load all tables into the GI database in parallel.
+ load_statements = []
+ for tbl in [ancillary_tbl, obj_vals_tbl, points_tbl]:
+ rel_name = tbl._fqn.replace(".", "_").lower()
+ load_statements.append(f"call {app_name}.api.load_data('{index_db_name}', '{logic_engine_name}', '{rel_name}', '{tbl._fqn}', '')")
+
+ with ThreadPoolExecutor(max_workers=3) as executor:
+ futures = [executor.submit(resources._exec, stmt) for stmt in load_statements]
+ for future in futures:
+ future.result()
+
+ NumericType = b.Integer if self._num_type == "int" else b.Float
+
+ # Metadata.
+ key, val = b.String.ref(), b.String.ref()
+ b.where(ancillary_tbl.name == key, ancillary_tbl.value == val).define(
+ self.result_info(key, val)
+ )
+
+ b.where(obj_vals_tbl.sol_index == 1, std.strings.string(obj_vals_tbl.value) == val).define(
+ self.result_info("objective_value", val)
+ )
+
+ # Points.
+ val_ref = NumericType.ref()
+ b.where(
+ points_tbl.sol_index == b.Integer,
+ points_tbl.var_hash == b.String,
+ points_tbl.value == val_ref,
+ ).define(self.points(b.Integer, std.parse_uuid(b.String), val_ref))
+
+ # Default to loading solution `1`.
+ b.define(self._point_updates(self._load_point_counter, 1))
+
+ # Derive into `point` the latest user selection from
+ # `_point_updates`.
+ point_index = b.Integer.ref()
+ b.where(self._point_updates(b.Integer, point_index), b.Integer == b.max(b.Integer),
+ self.points(point_index, b.Hash, val_ref))\
+ .define(self.point(b.Hash, val_ref))

  def _export_model_to_protobuf(
  self,
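The new `_import_solver_results_from_views` fans the three `load_data` calls out over a thread pool and binds user-supplied values as query parameters rather than interpolating them into SQL. A self-contained sketch of that fan-out pattern, with a stand-in `run_statement` in place of `resources._exec` and made-up statement text:

```python
from concurrent.futures import ThreadPoolExecutor

def run_statement(stmt: str) -> str:
    # Stand-in for resources._exec(stmt); the real call goes to Snowflake.
    return f"ran: {stmt}"

# Hypothetical load statements, one per result table.
statements = [
    "call app.api.load_data('index_db', 'engine', 'ancillary', 'APP.RESULTS.ANCILLARY', '')",
    "call app.api.load_data('index_db', 'engine', 'objective_values', 'APP.RESULTS.OBJECTIVE_VALUES', '')",
    "call app.api.load_data('index_db', 'engine', 'points', 'APP.RESULTS.POINTS', '')",
]

# Submit all loads at once; future.result() blocks until each one finishes
# and re-raises any exception thrown in its worker thread.
with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(run_statement, stmt) for stmt in statements]
    results = [future.result() for future in futures]
```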
@@ -805,9 +788,11 @@ class SolverModelPB:
  **kwargs: Solver options and parameters.
  """

+ self._solver = solver
+
  use_csv_store = solver.engine_settings.get("store", {})\
  .get("csv", {})\
- .get("enabled", False)
+ .get("enabled", True)

  print(f"Using {'csv' if use_csv_store else 'protobuf'} store...")

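The chained `.get()` lookup above now defaults to `True`, so the CSV store is used unless it is explicitly disabled. A small sketch of how the fallback behaves, using hypothetical settings dicts:

```python
def use_csv_store(engine_settings: dict) -> bool:
    # Any missing level falls through to the new default of True.
    return engine_settings.get("store", {}).get("csv", {}).get("enabled", True)

assert use_csv_store({}) is True                                        # nothing configured -> CSV store
assert use_csv_store({"store": {"csv": {}}}) is True                    # flag absent -> CSV store
assert use_csv_store({"store": {"csv": {"enabled": False}}}) is False   # explicit opt-out -> protobuf
```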
@@ -853,15 +838,13 @@ class SolverModelPB:
  payload: dict[str, Any] = {"solver": solver.solver_name.lower(), "options": options}

  if use_csv_store:
- # CSV format: model and results are exchanged via CSV files
- model_id = str(uuid.uuid4()).upper().replace('-', '_')
- payload["model_uri"] = f"snowflake://SOLVERS/job_{model_id}/model"
+ payload["model_uri"] = f"snowflake://SOLVERS/job_{self._model_id}/model"

  print("Exporting model to CSV...")
  remaining_timeout_minutes = calc_remaining_timeout_minutes(
  start_time, query_timeout_mins, config_file_path=config_file_path
  )
- self._export_model_to_csv(model_id, executor, prefix_lowercase, remaining_timeout_minutes)
+ self._export_model_to_csv(self._model_id, executor, prefix_lowercase, remaining_timeout_minutes)
  print("Model CSV export completed")

  print("Execute solver job")
@@ -874,7 +857,7 @@ class SolverModelPB:
  remaining_timeout_minutes = calc_remaining_timeout_minutes(
  start_time, query_timeout_mins, config_file_path=config_file_path
  )
- self._import_solver_results_from_csv(model_id, executor, prefix_lowercase, remaining_timeout_minutes)
+ self._import_solver_results_from_views()

  else: # protobuf format
  # Protobuf format: model and results are exchanged via binary protobuf
@@ -924,11 +907,22 @@ class SolverModelPB:
  raise ValueError(
  f"Expected RelExecutor, but got {type(executor).__name__}."
  )
- load_point_relation = f"""
- def delete[:{self.point._name}]: {self.point._name}
- def insert(:{self.point._name}, variable, value): {self.points._name}(int128[{point_index}], variable, value)
- """
- executor.execute_raw(textwrap.dedent(load_point_relation))
+
+ use_csv_store = True
+ if self._solver is not None:
+ use_csv_store = self._solver.engine_settings.get("store", {})\
+ .get("csv", {})\
+ .get("enabled", True)
+
+ if use_csv_store:
+ self._load_point_counter += 1
+ b.define(self._point_updates(self._load_point_counter, point_index))
+ else:
+ load_point_relation = f"""
+ def delete[:{self.point._name}]: {self.point._name}
+ def insert(:{self.point._name}, variable, value): {self.points._name}(int128[{point_index}], variable, value)
+ """
+ executor.execute_raw(textwrap.dedent(load_point_relation))

  def summarize_result(self) -> Any:
  """Print solver result summary.
@@ -12,7 +12,10 @@ from relationalai.semantics.metamodel.util import OrderedSet, group_by, NameCach

  from relationalai.semantics.rel import rel, rel_utils as u, builtins as rel_bt

- from ..metamodel.rewrite import (Flatten, ExtractNestedLogicals, DNFUnionSplitter, DischargeConstraints, FormatOutputs)
+ from ..metamodel.rewrite import (
+ Flatten, DNFUnionSplitter, DischargeConstraints, FormatOutputs, ExtractNestedLogicals,
+ # HandleAggregationsAndRanks
+ )
  from ..lqp.rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter, SplitMultiCheckRequires

  import math
@@ -30,13 +33,14 @@ class Compiler(c.Compiler):
  DischargeConstraints(),
  Checker(),
  CDC(), # specialize to physical relations before extracting nested and typing
- ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
+ ExtractNestedLogicals(),
  InferTypes(),
  DNFUnionSplitter(),
  ExtractKeys(),
- FormatOutputs(),
+ FormatOutputs(use_rel=True),
  ExtractCommon(),
  Flatten(),
+ # HandleAggregationsAndRanks(),
  Splinter(),
  QuantifyVars(),
  ])
@@ -318,7 +318,7 @@ class RelExecutor(e.Executor):
  if self.dry_run:
  return DataFrame()

- cols, extra_cols = self._compute_cols(task, task_model)
+ cols, extra_cols, _ = self._compute_cols(task, task_model)

  if not export_to:
  if format == "pandas":
@@ -0,0 +1,188 @@
+ from datetime import datetime
+ from collections import defaultdict
+ from rich.text import Text
+ from rich.console import Console
+ from relationalai.util.format import format_duration
+
+ # Create a console for rendering Rich text to ANSI codes
+ # Force terminal mode and no pager to ensure ANSI codes are always generated
+ _console = Console(force_terminal=True, legacy_windows=False)
+
+ def render_line(line_data):
+ """Convert structured line data to rendered string with optional ANSI color codes.
+
+ Args:
+ line_data: dict with 'text' (str) and 'color' (str|None) fields.
+ color can be 'green', or None
+
+ Returns:
+ str: Plain text or ANSI-colored text
+ """
+ color = line_data.get("color")
+ if color:
+ text = Text(line_data["text"], style=color)
+ with _console.capture() as capture:
+ _console.print(text, end="")
+ return capture.get()
+ return line_data["text"]
+
+ def parse_ts(ts):
+ ts = ts.replace("Z", "")
+
+ # Handle fractional seconds - normalize to 6 digits (microseconds)
+ if "." in ts:
+ parts = ts.split(".")
+ if len(parts) == 2:
+ # Pad or truncate fractional seconds to 6 digits
+ fractional = parts[1].ljust(6, '0')[:6]
+ ts = f"{parts[0]}.{fractional}"
+
+ return datetime.fromisoformat(ts)
+
+ def duration_ms(task):
+ if "start_time" not in task or "end_time" not in task:
+ return 0
+ return int((parse_ts(task["end_time"]) - parse_ts(task["start_time"])).total_seconds() * 1000)
+
+ def build_graph(tasks):
+ children = defaultdict(list)
+ parents = defaultdict(list)
+
+ for tid, t in tasks.items():
+ if "origin_id" in t:
+ parent = t["origin_id"]
+ children[parent].append(tid)
+ parents[tid].append(parent)
+
+ return parents, children
+
+ def find_roots(tasks, parents):
+ # roots = nodes with no parents (or parents not in tasks)
+ return [tid for tid in tasks if tid not in parents or not any(p in tasks for p in parents[tid])]
+
+ def node_label(tid, task):
+ """Create structured label data for a task node.
+
+ Returns:
+ list of dicts with 'text' (str) and 'color' (str|None) fields.
+ First item is the task line, subsequent items are warning details.
+ """
+ name = task.get("task_name", tid)
+ is_finished = "start_time" in task and "end_time" in task
+ has_warnings = "warnings" in task and task["warnings"]
+ warning_count = 0
+ warning_text = ""
+ text_color = 'green' if is_finished else None
+
+ lines = []
+
+ if has_warnings:
+ warning_count = len(task["warnings"])
+ warning_text = f" ⚠️ {warning_count} warning" + ("s" if warning_count > 1 else "")
+
+ if is_finished:
+ dur_seconds = duration_ms(task) / 1000.0
+ duration_str = format_duration(dur_seconds)
+
+ lines.append({"text": f"{name} ({duration_str}){warning_text}", "color": text_color})
+ else:
+ lines.append({"text": f"{name} (Evaluating...){warning_text}", "color": text_color})
+
+ # Add warning details as separate lines (for both finished and evaluating tasks)
+ if has_warnings:
+ for warning_code, warning_data in task["warnings"].items():
+ message = warning_data.get("message", "")
+ lines.append({"text": f" ⚠️ {warning_code}: {message}", "color": text_color})
+
+ return lines
+
+ def build_tree(tid, tasks, children, prefix="", is_last=True, seen=None):
+ """Build tree structure and return as list of structured line data dicts.
+
+ Returns:
+ list of dicts with 'text' (str) and 'color' (str|None) fields
+ """
+ if seen is None:
+ seen = set()
+
+ if tid not in tasks:
+ return []
+
+ if tid in seen:
+ return []
+ seen.add(tid)
+
+ task = tasks[tid]
+ connector = "└─ " if is_last else "├─ "
+ label_lines = node_label(tid, task)
+
+ # Build lines with proper prefix
+ lines = []
+ for i, label_data in enumerate(label_lines):
+ if i == 0:
+ # First line gets the tree connector
+ full_text = prefix + connector + label_data["text"]
+ else:
+ # Warning detail lines get continuation prefix
+ continuation_prefix = prefix + (" " if is_last else "│ ")
+ full_text = continuation_prefix + label_data["text"]
+ lines.append({"text": full_text, "color": label_data["color"]})
+
+ new_prefix = prefix + (" " if is_last else "│ ")
+
+ if tid not in children:
+ return lines
+
+ child_list = [c for c in children[tid] if c in tasks and should_include_task(c, tasks[c])]
+ for i, child in enumerate(child_list):
+ last = i == len(child_list) - 1
+ lines.extend(build_tree(child, tasks, children, new_prefix, last, seen))
+
+ return lines
+
+ def should_include_task(task_id, task):
+ """Filter out error tasks that aren't useful for execution flow analysis"""
+ task_name = task.get("task_name", task_id)
+ return not task_name.startswith(":error_pyrel_error_attrs")
+
+ def format_execution_tree(progress):
+ """Format execution tree and return as a string."""
+ if not progress or "tasks" not in progress:
+ return ""
+
+ tasks = progress["tasks"]
+ parents, children = build_graph(tasks)
+ roots = find_roots(tasks, parents)
+
+ # Filter out error tasks from roots
+ filtered_roots = [root for root in roots if should_include_task(root, tasks[root])]
+
+ # Sort roots: in-progress tasks first, then finished tasks, alphabetically within each group
+ def sort_key(task_id):
+ task = tasks[task_id]
+ is_finished = "start_time" in task and "end_time" in task
+ task_name = task.get("task_name", task_id)
+ return (is_finished, task_name)
+
+ sorted_roots = sorted(filtered_roots, key=sort_key)
+
+ # Build structured data first
+ all_line_data = []
+ for root in sorted_roots:
+ # Add root node (may include warning lines)
+ label_lines = node_label(root, tasks[root])
+ all_line_data.extend(label_lines)
+
+ # Add children
+ if root in children:
+ child_list = [c for c in children[root] if c in tasks and should_include_task(c, tasks[c])]
+ for i, child in enumerate(child_list):
+ last = i == len(child_list) - 1
+ all_line_data.extend(build_tree(child, tasks, children, "", last, set()))
+
+ # Empty line after each root tree
+ all_line_data.append({"text": "", "color": None})
+
+ # Render structured data to strings
+ rendered_lines = [render_line(line_data) for line_data in all_line_data]
+ return "\n\n" + "\n".join(rendered_lines)
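A hypothetical usage sketch for the new `txn_progress` module; the shape of the `progress` payload (task ids, `origin_id`, timestamps, warnings) is inferred from the functions above and may differ from the real transaction metadata:

```python
from relationalai.tools.txn_progress import format_execution_tree

progress = {
    "tasks": {
        "t1": {
            "task_name": "compile",
            "start_time": "2024-01-01T00:00:00.000Z",
            "end_time": "2024-01-01T00:00:01.250Z",    # finished -> shown in green with duration
        },
        "t2": {
            "task_name": "evaluate",
            "origin_id": "t1",                         # rendered as a child of t1
            "start_time": "2024-01-01T00:00:01.300Z",  # no end_time -> "(Evaluating...)"
        },
    }
}

print(format_execution_tree(progress))
```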
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: relationalai
- Version: 0.13.4
+ Version: 0.13.5
  Summary: RelationalAI Library and CLI
  Author-email: RelationalAI <support@relational.ai>
  License-File: LICENSE