relationalai 0.13.4__py3-none-any.whl → 0.13.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. relationalai/clients/exec_txn_poller.py +51 -20
  2. relationalai/clients/local.py +15 -7
  3. relationalai/clients/resources/snowflake/__init__.py +2 -2
  4. relationalai/clients/resources/snowflake/direct_access_resources.py +8 -4
  5. relationalai/clients/resources/snowflake/snowflake.py +16 -11
  6. relationalai/experimental/solvers.py +8 -0
  7. relationalai/semantics/lqp/executor.py +3 -3
  8. relationalai/semantics/lqp/model2lqp.py +34 -28
  9. relationalai/semantics/lqp/passes.py +6 -3
  10. relationalai/semantics/lqp/result_helpers.py +76 -12
  11. relationalai/semantics/lqp/rewrite/__init__.py +2 -0
  12. relationalai/semantics/lqp/rewrite/extract_common.py +3 -1
  13. relationalai/semantics/lqp/rewrite/extract_keys.py +85 -20
  14. relationalai/semantics/lqp/rewrite/flatten_script.py +301 -0
  15. relationalai/semantics/lqp/rewrite/functional_dependencies.py +12 -7
  16. relationalai/semantics/lqp/rewrite/quantify_vars.py +12 -3
  17. relationalai/semantics/lqp/rewrite/unify_definitions.py +9 -3
  18. relationalai/semantics/metamodel/dependency.py +9 -0
  19. relationalai/semantics/metamodel/executor.py +17 -10
  20. relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
  21. relationalai/semantics/metamodel/rewrite/flatten.py +1 -2
  22. relationalai/semantics/metamodel/rewrite/format_outputs.py +131 -46
  23. relationalai/semantics/metamodel/rewrite/handle_aggregations_and_ranks.py +237 -0
  24. relationalai/semantics/metamodel/typer/typer.py +1 -1
  25. relationalai/semantics/reasoners/optimization/solvers_pb.py +101 -107
  26. relationalai/semantics/rel/compiler.py +7 -3
  27. relationalai/semantics/rel/executor.py +1 -1
  28. relationalai/tools/txn_progress.py +188 -0
  29. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/METADATA +1 -1
  30. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/RECORD +33 -30
  31. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/WHEEL +0 -0
  32. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/entry_points.txt +0 -0
  33. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/licenses/LICENSE +0 -0
@@ -21,30 +21,37 @@ class Executor():
21
21
  def execute(self, model: Model, task:Task, format:Literal["pandas", "snowpark"]="pandas") -> Union[DataFrame, Any]:
22
22
  raise NotImplementedError(f"execute: {self}")
23
23
 
24
- def _compute_cols(self, task: ir.Task, final_model: ir.Model|None) -> Tuple[list[str], list[str]]:
25
- cols = []
26
- extra_cols = []
24
+ def _compute_cols(self, task: ir.Task, final_model: ir.Model|None) -> Tuple[list[str], list[str], list[int]]:
25
+ cols = [] # all cols in output
26
+ extra_cols = [] # all key cols not in output
27
+ key_locs = [] # locations of output key cols in all output cols
28
+
27
29
  # we assume only queries have outputs
28
30
  original_outputs = collect_by_type(ir.Output, task) if task else None
29
31
  outputs = collect_by_type(ir.Output, final_model) if final_model else None
30
- # there are some outputs, and they all have keys
32
+
33
+ # there are some outputs, and some have keys
31
34
  if original_outputs and outputs and not all(not out.keys for out in outputs):
32
35
  assert len(original_outputs) == 1
33
36
  original_output = original_outputs[0]
34
37
  original_cols = []
35
38
  original_cols_val = []
36
- for alias, val in original_output.aliases:
37
- if not alias:
38
- continue
39
- original_cols.append(alias)
40
- original_cols_val.append(val)
41
39
 
42
40
  keys = outputs[0].keys
43
41
  assert keys
42
+
44
43
  for out in outputs:
45
44
  assert out.keys is not None
46
45
  assert set(out.keys) == set(keys), "outputs with different key sets in the same query"
47
46
 
47
+ for (idx, (alias, val)) in enumerate(original_output.aliases):
48
+ if not alias:
49
+ continue
50
+ original_cols.append(alias)
51
+ original_cols_val.append(val)
52
+ if isinstance(val, ir.Var) and val in keys:
53
+ key_locs.append(idx)
54
+
48
55
  extra_cols = []
49
56
  name_cache = NameCache(start_from_one=True)
50
57
  for key in keys:
@@ -54,7 +61,7 @@ class Executor():
54
61
  elif outputs:
55
62
  cols = [alias for alias, _ in outputs[-1].aliases if alias]
56
63
 
57
- return cols, extra_cols
64
+ return cols, extra_cols, key_locs
58
65
 
59
66
  def _postprocess_df(self, config: Config, df: DataFrame, extra_cols: list[str]) -> DataFrame:
60
67
  if bool(config.get("compiler.debug_hidden_keys", False)):
@@ -3,5 +3,6 @@ from .dnf_union_splitter import DNFUnionSplitter
3
3
  from .extract_nested_logicals import ExtractNestedLogicals
4
4
  from .flatten import Flatten
5
5
  from .format_outputs import FormatOutputs
6
+ from .handle_aggregations_and_ranks import HandleAggregationsAndRanks
6
7
 
7
- __all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten", "FormatOutputs"]
8
+ __all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten", "FormatOutputs", "HandleAggregationsAndRanks"]
@@ -585,7 +585,6 @@ def extend_body(body: OrderedSet[ir.Task], extra: ir.Task):
585
585
  tuple(logical_body)
586
586
  ))
587
587
  else:
588
- # no hoists, just inline
589
- body.update(extra.body)
588
+ body.add(extra)
590
589
  else:
591
590
  body.add(extra)
@@ -8,21 +8,22 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
8
8
  from relationalai.semantics.metamodel.typer.typer import is_primitive
9
9
 
10
10
  class FormatOutputs(Pass):
11
- def __init__(self, handle_outputs: bool=True):
11
+ def __init__(self, use_rel: bool=False):
12
12
  super().__init__()
13
- self._handle_outputs = handle_outputs
13
+ self._use_rel = use_rel
14
14
 
15
15
  #--------------------------------------------------
16
16
  # Public API
17
17
  #--------------------------------------------------
18
18
  def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
19
19
  wide_outputs = options.get("wide_outputs", False)
20
- return self.OutputRewriter(wide_outputs).walk(model)
20
+ return self.OutputRewriter(wide_outputs, self._use_rel).walk(model)
21
21
 
22
22
  class OutputRewriter(visitor.Rewriter):
23
- def __init__(self, wide_outputs: bool = False):
23
+ def __init__(self, wide_outputs: bool = False, use_rel: bool = False):
24
24
  super().__init__()
25
25
  self.wide_outputs = wide_outputs
26
+ self._use_rel = use_rel
26
27
 
27
28
  def handle_logical(self, node: ir.Logical, parent: ir.Node):
28
29
  # Rewrite children first
@@ -36,62 +37,146 @@ class FormatOutputs(Pass):
36
37
  if not groups["outputs"]:
37
38
  return node
38
39
 
39
- return adjust_outputs(node, groups["outputs"], self.wide_outputs)
40
+ if self.wide_outputs:
41
+ return adjust_wide_outputs(node, groups["outputs"])
42
+
43
+ return adjust_gnf_outputs(self._use_rel, node, groups["outputs"])
40
44
 
41
45
  #--------------------------------------------------
42
46
  # GNF vs wide output support
43
47
  #--------------------------------------------------
44
- def adjust_outputs(task: ir.Logical, outputs: OrderedSet[ir.Task], wide_outputs: bool = False):
45
48
 
49
+ # For wide outputs, only adjust the output task to include the keys.
50
+ # output looks like: (key0, key1, val0, val1, ...)
51
+ def adjust_wide_outputs(task: ir.Logical, outputs: OrderedSet[ir.Task]):
52
+ body = list(task.body)
53
+ for output in outputs:
54
+ assert(isinstance(output, ir.Output))
55
+ if output.keys:
56
+ body.remove(output)
57
+ body.append(rewrite_wide_output(output))
58
+ return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
59
+
60
+ # For GNF outputs we need to generate a rule for each "column" in the output
61
+ # and potentially one wide key column
62
+ def adjust_gnf_outputs(use_rel: bool, task: ir.Logical, outputs: OrderedSet[ir.Task]):
46
63
  body = list(task.body)
64
+ for output in outputs:
65
+ assert(isinstance(output, ir.Output))
66
+ if output.keys:
67
+ # Remove the original output. This is replaced by per-column outputs below
68
+ body.remove(output)
69
+
70
+ is_export = helpers.is_export(output)
71
+
72
+ # Exports and Rel execution rely on all columns being in GNF format.
73
+ if is_export or use_rel:
74
+ _adjust_all_gnf_outputs(body, output, is_export)
75
+ else: # Otherwise, put all keys into one wide keys relation
76
+ _adjust_outputs_with_wide_keys(body, output)
77
+
78
+ return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
79
+
80
+ # Generate an output for each "column"
81
+ # output looks like: def output(:cols, :col000, key0, key1, value)
82
+ def _adjust_all_gnf_outputs(body, output: ir.Output, is_export: bool):
83
+ assert output.keys
84
+
85
+ original_cols = OrderedSet()
86
+ for idx, alias in enumerate(output.aliases):
87
+ # Skip None values which are used as a placeholder for missing values
88
+ if alias[1] is None:
89
+ continue
90
+ original_cols.add(alias[1])
91
+ body.extend(_generate_output_column_gnf(output, idx, alias, is_export))
92
+
93
+ idx = len(output.aliases)
94
+ for key in output.keys:
95
+ if key not in original_cols:
96
+ body.extend(_generate_output_column_gnf(output, idx, (key.name, key), is_export))
97
+ idx += 1
98
+
99
+ # Generate an output for each value "column" and one wide output for all the keys
100
+ # * value output looks like: def output(:cols, :col000, key0, key1, value)
101
+ # * key output looks like:
102
+ # def output(:keys, output_key_0, output_key_1, other_key_0, ...)
103
+ #
104
+ # Exceptions: keys for exports and compound keys are converted to GNF, same as the
105
+ # value columns.
106
+ def _adjust_outputs_with_wide_keys(body, output: ir.Output):
107
+ assert output.keys
108
+
109
+ original_cols = OrderedSet()
110
+ val_cols: list[Tuple[str, ir.Value] | None] = []
111
+ key_cols: OrderedSet[Tuple[str, ir.Value]] = OrderedSet()
112
+ key_cols.add(("keys", f.literal("keys", types.Symbol))) # name key col so we can identify it later
113
+ for alias in output.aliases:
114
+ # None values are used as a placeholder for missing values
115
+ # They are added to maintain the correct col count when enumerated below
116
+ if alias[1] is None:
117
+ val_cols.append(None)
118
+ continue
119
+
120
+ original_cols.add(alias[1])
121
+
122
+ if isinstance(alias[1], ir.Var) and alias[1] in output.keys: # note: skips compound keys
123
+ key_cols.add(alias)
124
+ else:
125
+ val_cols.append(alias)
126
+
127
+ # Add keys not in output to the end
128
+ for key in output.keys:
129
+ if key not in original_cols:
130
+ key_cols.add((key.name, key))
131
+
132
+ # Generate GNF val cols
133
+ for idx, alias in enumerate(val_cols):
134
+ if alias:
135
+ new_col = _generate_output_column(output, idx, alias, key_cols)
136
+ body.extend(new_col)
137
+
138
+ # Create a wide key column with all keys
139
+ if len(key_cols) > 1:
140
+ body.append(ir.Output(
141
+ output.engine,
142
+ key_cols.frozen(),
143
+ output.keys,
144
+ output.annotations
145
+ ))
146
+
147
+ # Generate a relation representing a single col in GNF form
148
+ def _generate_output_column(output: ir.Output, idx: int, alias: tuple[str, ir.Value], key_cols):
149
+ if not output.keys:
150
+ return [output]
47
151
 
48
- # For wide outputs, only adjust the output task to include the keys.
49
- if wide_outputs:
50
- for output in outputs:
51
- assert(isinstance(output, ir.Output))
52
- if output.keys:
53
- body.remove(output)
54
- body.append(rewrite_wide_output(output))
55
- return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
152
+ aliases = [("cols", f.literal("cols", types.Symbol))]
153
+ aliases.append(("col", f.literal(f"col{idx:03}", types.Symbol)))
56
154
 
57
- # For GNF outputs we need to generate a rule for each "column" in the output
58
- else:
59
- # First split outputs in potentially multiple outputs, one for each "column"
60
- for output in outputs:
61
- assert(isinstance(output, ir.Output))
62
- if output.keys:
63
- # Remove the original output. This is replaced by per-column outputs below
64
- body.remove(output)
65
-
66
- is_export = helpers.is_export(output)
67
-
68
- # Generate an output for each "column"
69
- # output looks like def output(:cols, :col000, key0, key1, value):
70
- original_cols = OrderedSet()
71
- for idx, alias in enumerate(output.aliases):
72
- # Skip None values which are used as a placeholder for missing values
73
- if alias[1] is None:
74
- continue
75
- original_cols.add(alias[1])
76
- body.extend(_generate_output_column(output, idx, alias, is_export))
77
-
78
- idx = len(output.aliases)
79
- for key in output.keys:
80
- if key not in original_cols:
81
- body.extend(_generate_output_column(output, idx, (key.name, key), is_export))
82
- idx += 1
83
-
84
- return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
85
-
86
- # TODO: return non list?
87
- def _generate_output_column(output: ir.Output, idx: int, alias: tuple[str, ir.Value], is_export: bool):
155
+ # Append all keys at the start
156
+ keys = iter(key_cols)
157
+ assert next(keys) == ("keys", f.literal("keys", types.Symbol)) # skip col name
158
+ for key in keys:
159
+ aliases.append((f"key_{key[0]}_{idx}", key[1]))
160
+
161
+ aliases.append(alias) # append val
162
+
163
+ return [
164
+ ir.Output(
165
+ output.engine,
166
+ FrozenOrderedSet.from_iterable(aliases),
167
+ output.keys,
168
+ output.annotations
169
+ )
170
+ ]
171
+
172
+ # Generate a relation representing a single col in GNF form for export
173
+ def _generate_output_column_gnf(output: ir.Output, idx: int, alias: tuple[str, ir.Value], is_export: bool):
88
174
  if not output.keys:
89
175
  return [output]
90
176
 
91
177
  aliases = [("cols", f.literal("cols", types.Symbol))] if not is_export else []
92
178
  aliases.append(("col", f.literal(f"col{idx:03}", types.Symbol)))
93
179
 
94
- # Append all keys at the start
95
180
  for k in output.keys:
96
181
  aliases.append((f"key_{k.name}_{idx}", k))
97
182
 
@@ -0,0 +1,237 @@
1
+ from __future__ import annotations
2
+
3
+ from relationalai.semantics.metamodel import ir, helpers
4
+ from relationalai.semantics.metamodel.visitor import Rewriter
5
+ from relationalai.semantics.metamodel.compiler import Pass, group_tasks
6
+ from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
7
+ from relationalai.semantics.metamodel import dependency
8
+
9
+ # This rewrite pass handles aggregations and ranks to ensure that their dependencies are
10
+ # self-contained in the required format for emitting Rel/LQP. The expected format is that
11
+ # aggregations and ranks are each contained in their own Logical (which hoists the output
12
+ # vars), with all dependencies pulled into the same Logical.
13
+ #
14
+ # For example,
15
+ #
16
+ # Logical ^[v::Int128]
17
+ # ... <dependencies> ...
18
+ # sum([foo::Foo], [], [a::Int128, v::Int128])
19
+ #
20
+ # Firstly, the pass ensures that all dependencies
21
+ # required by an aggregation/rank are contained locally in the same Logical as the
22
+ # aggregation/rank. For example, in the following Logical:
23
+ #
24
+ # Logical ^[v::Int128]
25
+ # Logical ^[a=None, foo=None]
26
+ # Foo(foo::Foo)
27
+ # a(foo::Foo, a::Int128)
28
+ # sum([foo::Foo], [], [a::Int128, v::Int128])
29
+ # common(foo::Foo, a::Int128)
30
+ #
31
+ # The aggregation `sum` depends on the relation `common`. So, the lookup for `common` needs
32
+ # to be pulled into the same Logical as the aggregation.
33
+ #
34
+ # Logical ^[v::Int128]
35
+ # Logical
36
+ # Logical ^[a=None, foo=None]
37
+ # Foo(foo::Foo)
38
+ # a(foo::Foo, a::Int128)
39
+ # common(foo::Foo, a::Int128)
40
+ # sum([foo::Foo], [], [a::Int128, v::Int128])
41
+ # common(foo::Foo, a::Int128)
42
+ #
43
+ # Secondly, the pass separates Logicals containing more than one aggregation/rank into
44
+ # separate Logicals, each containing a single aggregation/rank.
45
+ #
46
+ # Thirdly, the pass renames variables introduced inside aggregation/rank bodies to ensure
47
+ # they do not clash with variables outside.
48
+
49
+ class HandleAggregationsAndRanks(Pass):
50
+ def __init__(self):
51
+ super().__init__()
52
+
53
+ #--------------------------------------------------
54
+ # Public API
55
+ #--------------------------------------------------
56
+ def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
57
+ dep_info = dependency.analyze(model.root)
58
+
59
+ r = AggregationsRanksRewriter(dep_info)
60
+ result = r.walk(model)
61
+
62
+ rn = AggregationsRanksVarRenameRewriter()
63
+ result = rn.walk(result)
64
+
65
+ return result
66
+
67
+ # The AggregationsRanksRewriter ensures that each aggregation and rank is contained
68
+ # in its own Logical, with all dependencies pulled into the same Logical.
69
+ class AggregationsRanksRewriter(Rewriter):
70
+ def __init__(self, dep_info):
71
+ super().__init__()
72
+ self.info = dep_info
73
+ self.rewritten: dict[int, ir.Node] = {}
74
+
75
+ def handle_logical(self, node: ir.Logical, parent: ir.Node):
76
+ groups = group_tasks(node.body, {
77
+ "aggregates_and_ranks": (ir.Aggregate, ir.Rank),
78
+ })
79
+
80
+ aggregates_and_ranks = groups["aggregates_and_ranks"]
81
+
82
+ # If there are no aggregates or ranks, then just recurse into the Logical body
83
+ if not aggregates_and_ranks:
84
+ return super().handle_logical(node, parent)
85
+
86
+ agg_rank_logicals = []
87
+ for agg_rank in aggregates_and_ranks:
88
+ # Gather all dependencies of the logical containing the aggregate/rank
89
+ agg_deps = self.info.task_dependencies(agg_rank)
90
+
91
+ # Reconstruct the body of the Logical containing the aggregate/rank, starting
92
+ # with the existing body
93
+ body = ordered_set()
94
+
95
+ # agg_body is the inner body containing the dependencies of the aggregate/rank
96
+ agg_body = ordered_set()
97
+ for t in node.body:
98
+ if isinstance(t, (ir.Output, ir.Update)):
99
+ # Outputs and Updates need to be kept in the outer body, rather than
100
+ # nested inside the aggregate/rank body.
101
+ body.add(t)
102
+ elif t not in aggregates_and_ranks:
103
+ agg_body.add(t)
104
+
105
+ # Add all other dependencies
106
+ for dep in agg_deps:
107
+ # HACK: there are bugs in the dependency analysis that can cause cycles.
108
+ # Avoid these cycles because otherwise they can cause infinite recursion.
109
+ if agg_rank in self.info.task_dependencies(dep) or node in self.info.task_dependencies(dep):
110
+ continue
111
+ agg_body.add(dep)
112
+
113
+ body.add(ir.Logical(node.engine, tuple(), tuple(agg_body)))
114
+
115
+ # Add the actual aggregate/rank
116
+ body.add(agg_rank.clone())
117
+
118
+ # Construct the final Logical holding the aggregate/rank contents.
119
+ # Output variables need to be hoisted
120
+ if isinstance(agg_rank, ir.Aggregate):
121
+ output_vars = [v for v in helpers.vars(agg_rank.args) if not helpers.is_aggregate_input(v, agg_rank)]
122
+ else:
123
+ assert isinstance(agg_rank, ir.Rank)
124
+ output_vars = [a for a in agg_rank.args]
125
+ output_vars.append(agg_rank.result)
126
+
127
+ agg_logical = ir.Logical(
128
+ engine=node.engine,
129
+ hoisted=tuple(output_vars),
130
+ body=tuple(body)
131
+ )
132
+ agg_rank_logicals.append(agg_logical)
133
+
134
+ if len(agg_rank_logicals) == 1:
135
+ # If there's only one, no need to create a parent Logical
136
+ result = agg_rank_logicals[0]
137
+ else:
138
+ # Otherwise, create a parent Logical, ensuring all body vars are hoisted
139
+ hoisted = OrderedSet()
140
+ for agg_rank_logical in agg_rank_logicals:
141
+ hoisted.update(agg_rank_logical.hoisted)
142
+
143
+ result = ir.Logical(
144
+ engine=node.engine,
145
+ hoisted=tuple(hoisted),
146
+ body=tuple(agg_rank_logicals)
147
+ )
148
+
149
+ # Rewrite the children
150
+ result = super().handle_logical(result, parent)
151
+
152
+ # Make a deep copy so that each task has a unique id. This is important for later
153
+ # rewrite passes (namely QuantifyVars) which identify tasks by id.
154
+ result = DeepCopyRewriter().walk(result)
155
+ return result
156
+
157
+ # The AggregationsRanksVarRenameRewriter renames variables inside aggregation/rank
158
+ # bodies to ensure they do not clash with variables outside. It is careful to keep certain
159
+ # variables unrenamed because they need to interact with the outside, namely the group-by
160
+ # variables and output variables.
161
+ class AggregationsRanksVarRenameRewriter(Rewriter):
162
+ class RenameRewriter(Rewriter):
163
+ def __init__(self, to_keep: set[ir.Var], suffix: str):
164
+ super().__init__()
165
+ self.to_keep = to_keep
166
+ self.suffix = suffix
167
+ self.renamed_vars: dict[ir.Var, ir.Var] = {}
168
+
169
+ def handle_default(self, node: ir.Default, parent):
170
+ if node.var in self.to_keep:
171
+ return node
172
+
173
+ return ir.Default(
174
+ var=self.handle_var(node.var, node),
175
+ value=node.value
176
+ )
177
+
178
+ def handle_var(self, node: ir.Var, parent):
179
+ if node in self.to_keep:
180
+ return node
181
+
182
+ if node in self.renamed_vars:
183
+ return self.renamed_vars[node]
184
+
185
+ # Rename var
186
+ result = ir.Var(
187
+ type=node.type,
188
+ name=f"{node.name}_{self.suffix}"
189
+ )
190
+ self.renamed_vars[node] = result
191
+ return result
192
+
193
+ def __init__(self):
194
+ super().__init__()
195
+ self.renamed_vars: dict[ir.Var, ir.Var] = {}
196
+
197
+ def handle_logical(self, node: ir.Logical, parent: ir.Node):
198
+ groups = group_tasks(node.body, {
199
+ "aggregates_and_ranks": (ir.Aggregate, ir.Rank),
200
+ })
201
+
202
+ aggregates_and_ranks = groups["aggregates_and_ranks"]
203
+ if not aggregates_and_ranks:
204
+ return super().handle_logical(node, parent)
205
+
206
+ # There should only be one, because the AggregationsRanksRewriter should have
207
+ # separated them out.
208
+ assert len(aggregates_and_ranks) == 1, "Multiple aggregate/ranks still found after rewriting dependencies"
209
+
210
+ # Rename at this level
211
+ agg_rank = aggregates_and_ranks[0]
212
+ if isinstance(agg_rank, ir.Aggregate):
213
+ vars_to_keep = set(agg_rank.group)
214
+ output_vars = [v for v in helpers.vars(agg_rank.args) if not helpers.is_aggregate_input(v, agg_rank)]
215
+ vars_to_keep.update(output_vars)
216
+
217
+ result = self.RenameRewriter(vars_to_keep, 'agg').walk(node)
218
+ else:
219
+ assert isinstance(agg_rank, ir.Rank)
220
+ vars_to_keep = set(agg_rank.group)
221
+ output_vars = helpers.vars(agg_rank.args) + [agg_rank.result]
222
+ vars_to_keep.update(output_vars)
223
+ vars_to_keep.update(agg_rank.projection)
224
+
225
+ result = self.RenameRewriter(vars_to_keep, 'rank').walk(node)
226
+
227
+ # Process children
228
+ result = super().handle_logical(result, parent)
229
+ return result
230
+
231
+ class DeepCopyRewriter(Rewriter):
232
+ def walk(self, node, parent=None):
233
+ new_node = super().walk(node, parent)
234
+ if isinstance(new_node, ir.Task):
235
+ return new_node.clone()
236
+ else:
237
+ return new_node
@@ -1092,7 +1092,7 @@ class Replacer(visitor.Rewriter):
1092
1092
 
1093
1093
  def handle_var(self, node: ir.Var, parent: ir.Node):
1094
1094
  if node.id in self.net.resolved_types:
1095
- return f.var(node.name, self.net.resolved_types[node.id])
1095
+ return node.reconstruct(name=node.name, type=self.net.resolved_types[node.id])
1096
1096
  return node
1097
1097
 
1098
1098
  def handle_literal(self, node: ir.Literal, parent: ir.Node):