relationalai 0.13.4__py3-none-any.whl → 0.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/exec_txn_poller.py +51 -20
- relationalai/clients/local.py +15 -7
- relationalai/clients/resources/snowflake/__init__.py +2 -2
- relationalai/clients/resources/snowflake/direct_access_resources.py +8 -4
- relationalai/clients/resources/snowflake/snowflake.py +16 -11
- relationalai/experimental/solvers.py +8 -0
- relationalai/semantics/lqp/executor.py +3 -3
- relationalai/semantics/lqp/model2lqp.py +34 -28
- relationalai/semantics/lqp/passes.py +6 -3
- relationalai/semantics/lqp/result_helpers.py +76 -12
- relationalai/semantics/lqp/rewrite/__init__.py +2 -0
- relationalai/semantics/lqp/rewrite/extract_common.py +3 -1
- relationalai/semantics/lqp/rewrite/extract_keys.py +85 -20
- relationalai/semantics/lqp/rewrite/flatten_script.py +301 -0
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +12 -7
- relationalai/semantics/lqp/rewrite/quantify_vars.py +12 -3
- relationalai/semantics/lqp/rewrite/unify_definitions.py +9 -3
- relationalai/semantics/metamodel/dependency.py +9 -0
- relationalai/semantics/metamodel/executor.py +17 -10
- relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
- relationalai/semantics/metamodel/rewrite/flatten.py +1 -2
- relationalai/semantics/metamodel/rewrite/format_outputs.py +131 -46
- relationalai/semantics/metamodel/rewrite/handle_aggregations_and_ranks.py +237 -0
- relationalai/semantics/metamodel/typer/typer.py +1 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +101 -107
- relationalai/semantics/rel/compiler.py +7 -3
- relationalai/semantics/rel/executor.py +1 -1
- relationalai/tools/txn_progress.py +188 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/METADATA +1 -1
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/RECORD +33 -30
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/WHEEL +0 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/entry_points.txt +0 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,30 +21,37 @@ class Executor():
|
|
|
21
21
|
def execute(self, model: Model, task:Task, format:Literal["pandas", "snowpark"]="pandas") -> Union[DataFrame, Any]:
|
|
22
22
|
raise NotImplementedError(f"execute: {self}")
|
|
23
23
|
|
|
24
|
-
def _compute_cols(self, task: ir.Task, final_model: ir.Model|None) -> Tuple[list[str], list[str]]:
|
|
25
|
-
cols = []
|
|
26
|
-
extra_cols = []
|
|
24
|
+
def _compute_cols(self, task: ir.Task, final_model: ir.Model|None) -> Tuple[list[str], list[str], list[int]]:
|
|
25
|
+
cols = [] # all cols in output
|
|
26
|
+
extra_cols = [] # all key cols not in output
|
|
27
|
+
key_locs = [] # locations of output key cols in all output cols
|
|
28
|
+
|
|
27
29
|
# we assume only queries have outputs
|
|
28
30
|
original_outputs = collect_by_type(ir.Output, task) if task else None
|
|
29
31
|
outputs = collect_by_type(ir.Output, final_model) if final_model else None
|
|
30
|
-
|
|
32
|
+
|
|
33
|
+
# there are some outputs, and some have keys
|
|
31
34
|
if original_outputs and outputs and not all(not out.keys for out in outputs):
|
|
32
35
|
assert len(original_outputs) == 1
|
|
33
36
|
original_output = original_outputs[0]
|
|
34
37
|
original_cols = []
|
|
35
38
|
original_cols_val = []
|
|
36
|
-
for alias, val in original_output.aliases:
|
|
37
|
-
if not alias:
|
|
38
|
-
continue
|
|
39
|
-
original_cols.append(alias)
|
|
40
|
-
original_cols_val.append(val)
|
|
41
39
|
|
|
42
40
|
keys = outputs[0].keys
|
|
43
41
|
assert keys
|
|
42
|
+
|
|
44
43
|
for out in outputs:
|
|
45
44
|
assert out.keys is not None
|
|
46
45
|
assert set(out.keys) == set(keys), "outputs with different key sets in the same query"
|
|
47
46
|
|
|
47
|
+
for (idx, (alias, val)) in enumerate(original_output.aliases):
|
|
48
|
+
if not alias:
|
|
49
|
+
continue
|
|
50
|
+
original_cols.append(alias)
|
|
51
|
+
original_cols_val.append(val)
|
|
52
|
+
if isinstance(val, ir.Var) and val in keys:
|
|
53
|
+
key_locs.append(idx)
|
|
54
|
+
|
|
48
55
|
extra_cols = []
|
|
49
56
|
name_cache = NameCache(start_from_one=True)
|
|
50
57
|
for key in keys:
|
|
@@ -54,7 +61,7 @@ class Executor():
|
|
|
54
61
|
elif outputs:
|
|
55
62
|
cols = [alias for alias, _ in outputs[-1].aliases if alias]
|
|
56
63
|
|
|
57
|
-
return cols, extra_cols
|
|
64
|
+
return cols, extra_cols, key_locs
|
|
58
65
|
|
|
59
66
|
def _postprocess_df(self, config: Config, df: DataFrame, extra_cols: list[str]) -> DataFrame:
|
|
60
67
|
if bool(config.get("compiler.debug_hidden_keys", False)):
|
|
@@ -3,5 +3,6 @@ from .dnf_union_splitter import DNFUnionSplitter
|
|
|
3
3
|
from .extract_nested_logicals import ExtractNestedLogicals
|
|
4
4
|
from .flatten import Flatten
|
|
5
5
|
from .format_outputs import FormatOutputs
|
|
6
|
+
from .handle_aggregations_and_ranks import HandleAggregationsAndRanks
|
|
6
7
|
|
|
7
|
-
__all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten", "FormatOutputs"]
|
|
8
|
+
__all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten", "FormatOutputs", "HandleAggregationsAndRanks"]
|
|
@@ -8,21 +8,22 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
|
|
|
8
8
|
from relationalai.semantics.metamodel.typer.typer import is_primitive
|
|
9
9
|
|
|
10
10
|
class FormatOutputs(Pass):
|
|
11
|
-
def __init__(self,
|
|
11
|
+
def __init__(self, use_rel: bool=False):
|
|
12
12
|
super().__init__()
|
|
13
|
-
self.
|
|
13
|
+
self._use_rel = use_rel
|
|
14
14
|
|
|
15
15
|
#--------------------------------------------------
|
|
16
16
|
# Public API
|
|
17
17
|
#--------------------------------------------------
|
|
18
18
|
def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
|
|
19
19
|
wide_outputs = options.get("wide_outputs", False)
|
|
20
|
-
return self.OutputRewriter(wide_outputs).walk(model)
|
|
20
|
+
return self.OutputRewriter(wide_outputs, self._use_rel).walk(model)
|
|
21
21
|
|
|
22
22
|
class OutputRewriter(visitor.Rewriter):
|
|
23
|
-
def __init__(self, wide_outputs: bool = False):
|
|
23
|
+
def __init__(self, wide_outputs: bool = False, use_rel: bool = False):
|
|
24
24
|
super().__init__()
|
|
25
25
|
self.wide_outputs = wide_outputs
|
|
26
|
+
self._use_rel = use_rel
|
|
26
27
|
|
|
27
28
|
def handle_logical(self, node: ir.Logical, parent: ir.Node):
|
|
28
29
|
# Rewrite children first
|
|
@@ -36,62 +37,146 @@ class FormatOutputs(Pass):
|
|
|
36
37
|
if not groups["outputs"]:
|
|
37
38
|
return node
|
|
38
39
|
|
|
39
|
-
|
|
40
|
+
if self.wide_outputs:
|
|
41
|
+
return adjust_wide_outputs(node, groups["outputs"])
|
|
42
|
+
|
|
43
|
+
return adjust_gnf_outputs(self._use_rel, node, groups["outputs"])
|
|
40
44
|
|
|
41
45
|
#--------------------------------------------------
|
|
42
46
|
# GNF vs wide output support
|
|
43
47
|
#--------------------------------------------------
|
|
44
|
-
def adjust_outputs(task: ir.Logical, outputs: OrderedSet[ir.Task], wide_outputs: bool = False):
|
|
45
48
|
|
|
49
|
+
# For wide outputs, only adjust the output task to include the keys.
|
|
50
|
+
# output looks like: (key0, key1, val0, val1, ...)
|
|
51
|
+
def adjust_wide_outputs(task: ir.Logical, outputs: OrderedSet[ir.Task]):
|
|
52
|
+
body = list(task.body)
|
|
53
|
+
for output in outputs:
|
|
54
|
+
assert(isinstance(output, ir.Output))
|
|
55
|
+
if output.keys:
|
|
56
|
+
body.remove(output)
|
|
57
|
+
body.append(rewrite_wide_output(output))
|
|
58
|
+
return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
|
|
59
|
+
|
|
60
|
+
# For GNF outputs we need to generate a rule for each "column" in the output
|
|
61
|
+
# and potentially one wide key column
|
|
62
|
+
def adjust_gnf_outputs(use_rel: bool, task: ir.Logical, outputs: OrderedSet[ir.Task]):
|
|
46
63
|
body = list(task.body)
|
|
64
|
+
for output in outputs:
|
|
65
|
+
assert(isinstance(output, ir.Output))
|
|
66
|
+
if output.keys:
|
|
67
|
+
# Remove the original output. This is replaced by per-column outputs below
|
|
68
|
+
body.remove(output)
|
|
69
|
+
|
|
70
|
+
is_export = helpers.is_export(output)
|
|
71
|
+
|
|
72
|
+
# Exports and Rel execution rely on all columns being in GNF format.
|
|
73
|
+
if is_export or use_rel:
|
|
74
|
+
_adjust_all_gnf_outputs(body, output, is_export)
|
|
75
|
+
else: # Otherwise, put all keys into one wide keys relation
|
|
76
|
+
_adjust_outputs_with_wide_keys(body, output)
|
|
77
|
+
|
|
78
|
+
return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
|
|
79
|
+
|
|
80
|
+
# Generate an output for each "column"
|
|
81
|
+
# output looks like: def output(:cols, :col000, key0, key1, value)
|
|
82
|
+
def _adjust_all_gnf_outputs(body, output: ir.Output, is_export: bool):
|
|
83
|
+
assert output.keys
|
|
84
|
+
|
|
85
|
+
original_cols = OrderedSet()
|
|
86
|
+
for idx, alias in enumerate(output.aliases):
|
|
87
|
+
# Skip None values which are used as a placeholder for missing values
|
|
88
|
+
if alias[1] is None:
|
|
89
|
+
continue
|
|
90
|
+
original_cols.add(alias[1])
|
|
91
|
+
body.extend(_generate_output_column_gnf(output, idx, alias, is_export))
|
|
92
|
+
|
|
93
|
+
idx = len(output.aliases)
|
|
94
|
+
for key in output.keys:
|
|
95
|
+
if key not in original_cols:
|
|
96
|
+
body.extend(_generate_output_column_gnf(output, idx, (key.name, key), is_export))
|
|
97
|
+
idx += 1
|
|
98
|
+
|
|
99
|
+
# Generate an output for each value "column" and one wide output for all the keys
|
|
100
|
+
# * value output looks like: def output(:cols, :col000, key0, key1, value)
|
|
101
|
+
# * key output looks like:
|
|
102
|
+
# def output(:keys, output_key_0, output_key_1, other_key_0, ...)
|
|
103
|
+
#
|
|
104
|
+
# Exceptions: keys for exports and compound keys are converted to GNF, same as the
|
|
105
|
+
# value columns.
|
|
106
|
+
def _adjust_outputs_with_wide_keys(body, output: ir.Output):
|
|
107
|
+
assert output.keys
|
|
108
|
+
|
|
109
|
+
original_cols = OrderedSet()
|
|
110
|
+
val_cols: list[Tuple[str, ir.Value] | None] = []
|
|
111
|
+
key_cols: OrderedSet[Tuple[str, ir.Value]] = OrderedSet()
|
|
112
|
+
key_cols.add(("keys", f.literal("keys", types.Symbol))) # name key col so we can identify it later
|
|
113
|
+
for alias in output.aliases:
|
|
114
|
+
# None values are used as a placeholder for missing values
|
|
115
|
+
# They are added to maintain the correct col count when enumerated below
|
|
116
|
+
if alias[1] is None:
|
|
117
|
+
val_cols.append(None)
|
|
118
|
+
continue
|
|
119
|
+
|
|
120
|
+
original_cols.add(alias[1])
|
|
121
|
+
|
|
122
|
+
if isinstance(alias[1], ir.Var) and alias[1] in output.keys: # note: skips compound keys
|
|
123
|
+
key_cols.add(alias)
|
|
124
|
+
else:
|
|
125
|
+
val_cols.append(alias)
|
|
126
|
+
|
|
127
|
+
# Add keys not in output to the end
|
|
128
|
+
for key in output.keys:
|
|
129
|
+
if key not in original_cols:
|
|
130
|
+
key_cols.add((key.name, key))
|
|
131
|
+
|
|
132
|
+
# Generate GNF val cols
|
|
133
|
+
for idx, alias in enumerate(val_cols):
|
|
134
|
+
if alias:
|
|
135
|
+
new_col = _generate_output_column(output, idx, alias, key_cols)
|
|
136
|
+
body.extend(new_col)
|
|
137
|
+
|
|
138
|
+
# Create a wide key column with all keys
|
|
139
|
+
if len(key_cols) > 1:
|
|
140
|
+
body.append(ir.Output(
|
|
141
|
+
output.engine,
|
|
142
|
+
key_cols.frozen(),
|
|
143
|
+
output.keys,
|
|
144
|
+
output.annotations
|
|
145
|
+
))
|
|
146
|
+
|
|
147
|
+
# Generate a relation representing a single col in GNF form
|
|
148
|
+
def _generate_output_column(output: ir.Output, idx: int, alias: tuple[str, ir.Value], key_cols):
|
|
149
|
+
if not output.keys:
|
|
150
|
+
return [output]
|
|
47
151
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
for output in outputs:
|
|
51
|
-
assert(isinstance(output, ir.Output))
|
|
52
|
-
if output.keys:
|
|
53
|
-
body.remove(output)
|
|
54
|
-
body.append(rewrite_wide_output(output))
|
|
55
|
-
return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
|
|
152
|
+
aliases = [("cols", f.literal("cols", types.Symbol))]
|
|
153
|
+
aliases.append(("col", f.literal(f"col{idx:03}", types.Symbol)))
|
|
56
154
|
|
|
57
|
-
#
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
body.extend(_generate_output_column(output, idx, alias, is_export))
|
|
77
|
-
|
|
78
|
-
idx = len(output.aliases)
|
|
79
|
-
for key in output.keys:
|
|
80
|
-
if key not in original_cols:
|
|
81
|
-
body.extend(_generate_output_column(output, idx, (key.name, key), is_export))
|
|
82
|
-
idx += 1
|
|
83
|
-
|
|
84
|
-
return ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
|
|
85
|
-
|
|
86
|
-
# TODO: return non list?
|
|
87
|
-
def _generate_output_column(output: ir.Output, idx: int, alias: tuple[str, ir.Value], is_export: bool):
|
|
155
|
+
# Append all keys at the start
|
|
156
|
+
keys = iter(key_cols)
|
|
157
|
+
assert next(keys) == ("keys", f.literal("keys", types.Symbol)) # skip col name
|
|
158
|
+
for key in keys:
|
|
159
|
+
aliases.append((f"key_{key[0]}_{idx}", key[1]))
|
|
160
|
+
|
|
161
|
+
aliases.append(alias) # append val
|
|
162
|
+
|
|
163
|
+
return [
|
|
164
|
+
ir.Output(
|
|
165
|
+
output.engine,
|
|
166
|
+
FrozenOrderedSet.from_iterable(aliases),
|
|
167
|
+
output.keys,
|
|
168
|
+
output.annotations
|
|
169
|
+
)
|
|
170
|
+
]
|
|
171
|
+
|
|
172
|
+
# Generate a relation representing a single col in GNF form for export
|
|
173
|
+
def _generate_output_column_gnf(output: ir.Output, idx: int, alias: tuple[str, ir.Value], is_export: bool):
|
|
88
174
|
if not output.keys:
|
|
89
175
|
return [output]
|
|
90
176
|
|
|
91
177
|
aliases = [("cols", f.literal("cols", types.Symbol))] if not is_export else []
|
|
92
178
|
aliases.append(("col", f.literal(f"col{idx:03}", types.Symbol)))
|
|
93
179
|
|
|
94
|
-
# Append all keys at the start
|
|
95
180
|
for k in output.keys:
|
|
96
181
|
aliases.append((f"key_{k.name}_{idx}", k))
|
|
97
182
|
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from relationalai.semantics.metamodel import ir, helpers
|
|
4
|
+
from relationalai.semantics.metamodel.visitor import Rewriter
|
|
5
|
+
from relationalai.semantics.metamodel.compiler import Pass, group_tasks
|
|
6
|
+
from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
|
|
7
|
+
from relationalai.semantics.metamodel import dependency
|
|
8
|
+
|
|
9
|
+
# This rewrite pass handles aggregations and ranks to ensure that their dependencies are
|
|
10
|
+
# self-contained in the required format for emitting Rel/LQP. The expected format is that
|
|
11
|
+
# aggregations and ranks are each contained in their own Logical (which hoists the output
|
|
12
|
+
# vars), with all dependencies pulled into the same Logical.
|
|
13
|
+
#
|
|
14
|
+
# For example,
|
|
15
|
+
#
|
|
16
|
+
# Logical ^[v::Int128]
|
|
17
|
+
# ... <dependencies> ...
|
|
18
|
+
# sum([foo::Foo], [], [a::Int128, v::Int128])
|
|
19
|
+
#
|
|
20
|
+
# Firstly, the pass ensures that all dependencies
|
|
21
|
+
# required by an aggregation/rank are contained locally in the same Logical as the
|
|
22
|
+
# aggregation/rank. For example, in the following Logical:
|
|
23
|
+
#
|
|
24
|
+
# Logical ^[v::Int128]
|
|
25
|
+
# Logical ^[a=None, foo=None]
|
|
26
|
+
# Foo(foo::Foo)
|
|
27
|
+
# a(foo::Foo, a::Int128)
|
|
28
|
+
# sum([foo::Foo], [], [a::Int128, v::Int128])
|
|
29
|
+
# common(foo::Foo, a::Int128)
|
|
30
|
+
#
|
|
31
|
+
# The aggregation `sum` depends on the relation `common`. So, the lookup for `common` needs
|
|
32
|
+
# to be pulled into the same Logical as the aggregation.
|
|
33
|
+
#
|
|
34
|
+
# Logical ^[v::Int128]
|
|
35
|
+
# Logical
|
|
36
|
+
# Logical ^[a=None, foo=None]
|
|
37
|
+
# Foo(foo::Foo)
|
|
38
|
+
# a(foo::Foo, a::Int128)
|
|
39
|
+
# common(foo::Foo, a::Int128)
|
|
40
|
+
# sum([foo::Foo], [], [a::Int128, v::Int128])
|
|
41
|
+
# common(foo::Foo, a::Int128)
|
|
42
|
+
#
|
|
43
|
+
# Secondly, the pass separates Logicals containing more than one aggregation/rank into
|
|
44
|
+
# separate Logicals, each containing a single aggregation/rank.
|
|
45
|
+
#
|
|
46
|
+
# Thirdly, the pass renames variables introduced inside aggregation/rank bodies to ensure
|
|
47
|
+
# they do not clash with variables outside.
|
|
48
|
+
|
|
49
|
+
class HandleAggregationsAndRanks(Pass):
|
|
50
|
+
def __init__(self):
|
|
51
|
+
super().__init__()
|
|
52
|
+
|
|
53
|
+
#--------------------------------------------------
|
|
54
|
+
# Public API
|
|
55
|
+
#--------------------------------------------------
|
|
56
|
+
def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
|
|
57
|
+
dep_info = dependency.analyze(model.root)
|
|
58
|
+
|
|
59
|
+
r = AggregationsRanksRewriter(dep_info)
|
|
60
|
+
result = r.walk(model)
|
|
61
|
+
|
|
62
|
+
rn = AggregationsRanksVarRenameRewriter()
|
|
63
|
+
result = rn.walk(result)
|
|
64
|
+
|
|
65
|
+
return result
|
|
66
|
+
|
|
67
|
+
# The AggregationsRanksRewriter ensures that each aggregation and rank is contained
|
|
68
|
+
# in its own Logical, with all dependencies pulled into the same Logical.
|
|
69
|
+
class AggregationsRanksRewriter(Rewriter):
|
|
70
|
+
def __init__(self, dep_info):
|
|
71
|
+
super().__init__()
|
|
72
|
+
self.info = dep_info
|
|
73
|
+
self.rewritten: dict[int, ir.Node] = {}
|
|
74
|
+
|
|
75
|
+
def handle_logical(self, node: ir.Logical, parent: ir.Node):
|
|
76
|
+
groups = group_tasks(node.body, {
|
|
77
|
+
"aggregates_and_ranks": (ir.Aggregate, ir.Rank),
|
|
78
|
+
})
|
|
79
|
+
|
|
80
|
+
aggregates_and_ranks = groups["aggregates_and_ranks"]
|
|
81
|
+
|
|
82
|
+
# If there are no aggregates or ranks, then just recurse into the Logical body
|
|
83
|
+
if not aggregates_and_ranks:
|
|
84
|
+
return super().handle_logical(node, parent)
|
|
85
|
+
|
|
86
|
+
agg_rank_logicals = []
|
|
87
|
+
for agg_rank in aggregates_and_ranks:
|
|
88
|
+
# Gather all dependencies of the logical containing the aggregate/rank
|
|
89
|
+
agg_deps = self.info.task_dependencies(agg_rank)
|
|
90
|
+
|
|
91
|
+
# Reconstruct the body of the Logical containing the aggregate/rank, starting
|
|
92
|
+
# with the existing body
|
|
93
|
+
body = ordered_set()
|
|
94
|
+
|
|
95
|
+
# agg_body is the inner body containing the dependencies of the aggregate/rank
|
|
96
|
+
agg_body = ordered_set()
|
|
97
|
+
for t in node.body:
|
|
98
|
+
if isinstance(t, (ir.Output, ir.Update)):
|
|
99
|
+
# Outputs and Updates need to be kept in the outer body, rather than
|
|
100
|
+
# nested inside the aggregate/rank body.
|
|
101
|
+
body.add(t)
|
|
102
|
+
elif t not in aggregates_and_ranks:
|
|
103
|
+
agg_body.add(t)
|
|
104
|
+
|
|
105
|
+
# Add all other dependencies
|
|
106
|
+
for dep in agg_deps:
|
|
107
|
+
# HACK: there are bugs in the dependency analysis that can cause cycles.
|
|
108
|
+
# Avoid these cycles because otherwise they can cause infinite recursion.
|
|
109
|
+
if agg_rank in self.info.task_dependencies(dep) or node in self.info.task_dependencies(dep):
|
|
110
|
+
continue
|
|
111
|
+
agg_body.add(dep)
|
|
112
|
+
|
|
113
|
+
body.add(ir.Logical(node.engine, tuple(), tuple(agg_body)))
|
|
114
|
+
|
|
115
|
+
# Add the actual aggregate/rank
|
|
116
|
+
body.add(agg_rank.clone())
|
|
117
|
+
|
|
118
|
+
# Construct the final Logical holding the aggregate/rank contents.
|
|
119
|
+
# Output variables need to be hoisted
|
|
120
|
+
if isinstance(agg_rank, ir.Aggregate):
|
|
121
|
+
output_vars = [v for v in helpers.vars(agg_rank.args) if not helpers.is_aggregate_input(v, agg_rank)]
|
|
122
|
+
else:
|
|
123
|
+
assert isinstance(agg_rank, ir.Rank)
|
|
124
|
+
output_vars = [a for a in agg_rank.args]
|
|
125
|
+
output_vars.append(agg_rank.result)
|
|
126
|
+
|
|
127
|
+
agg_logical = ir.Logical(
|
|
128
|
+
engine=node.engine,
|
|
129
|
+
hoisted=tuple(output_vars),
|
|
130
|
+
body=tuple(body)
|
|
131
|
+
)
|
|
132
|
+
agg_rank_logicals.append(agg_logical)
|
|
133
|
+
|
|
134
|
+
if len(agg_rank_logicals) == 1:
|
|
135
|
+
# If there's only one, no need to create a parent Logical
|
|
136
|
+
result = agg_rank_logicals[0]
|
|
137
|
+
else:
|
|
138
|
+
# Otherwise, create a parent Logical, ensuring all body vars are hoisted
|
|
139
|
+
hoisted = OrderedSet()
|
|
140
|
+
for agg_rank_logical in agg_rank_logicals:
|
|
141
|
+
hoisted.update(agg_rank_logical.hoisted)
|
|
142
|
+
|
|
143
|
+
result = ir.Logical(
|
|
144
|
+
engine=node.engine,
|
|
145
|
+
hoisted=tuple(hoisted),
|
|
146
|
+
body=tuple(agg_rank_logicals)
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
# Rewrite the children
|
|
150
|
+
result = super().handle_logical(result, parent)
|
|
151
|
+
|
|
152
|
+
# Make a deep copy so that each task has a unique id. This is important for later
|
|
153
|
+
# rewrite passes (namely QuantifyVars) which identify tasks by id.
|
|
154
|
+
result = DeepCopyRewriter().walk(result)
|
|
155
|
+
return result
|
|
156
|
+
|
|
157
|
+
# The AggregationsRanksVarRenameRewriter renames variables inside aggregation/rank
|
|
158
|
+
# bodies to ensure they do not clash with variables outside. It is careful to keep certain
|
|
159
|
+
# variables unrenamed because they need to interact with the outside, namely the group-by
|
|
160
|
+
# variables and output variables.
|
|
161
|
+
class AggregationsRanksVarRenameRewriter(Rewriter):
|
|
162
|
+
class RenameRewriter(Rewriter):
|
|
163
|
+
def __init__(self, to_keep: set[ir.Var], suffix: str):
|
|
164
|
+
super().__init__()
|
|
165
|
+
self.to_keep = to_keep
|
|
166
|
+
self.suffix = suffix
|
|
167
|
+
self.renamed_vars: dict[ir.Var, ir.Var] = {}
|
|
168
|
+
|
|
169
|
+
def handle_default(self, node: ir.Default, parent):
|
|
170
|
+
if node.var in self.to_keep:
|
|
171
|
+
return node
|
|
172
|
+
|
|
173
|
+
return ir.Default(
|
|
174
|
+
var=self.handle_var(node.var, node),
|
|
175
|
+
value=node.value
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
def handle_var(self, node: ir.Var, parent):
|
|
179
|
+
if node in self.to_keep:
|
|
180
|
+
return node
|
|
181
|
+
|
|
182
|
+
if node in self.renamed_vars:
|
|
183
|
+
return self.renamed_vars[node]
|
|
184
|
+
|
|
185
|
+
# Rename var
|
|
186
|
+
result = ir.Var(
|
|
187
|
+
type=node.type,
|
|
188
|
+
name=f"{node.name}_{self.suffix}"
|
|
189
|
+
)
|
|
190
|
+
self.renamed_vars[node] = result
|
|
191
|
+
return result
|
|
192
|
+
|
|
193
|
+
def __init__(self):
|
|
194
|
+
super().__init__()
|
|
195
|
+
self.renamed_vars: dict[ir.Var, ir.Var] = {}
|
|
196
|
+
|
|
197
|
+
def handle_logical(self, node: ir.Logical, parent: ir.Node):
|
|
198
|
+
groups = group_tasks(node.body, {
|
|
199
|
+
"aggregates_and_ranks": (ir.Aggregate, ir.Rank),
|
|
200
|
+
})
|
|
201
|
+
|
|
202
|
+
aggregates_and_ranks = groups["aggregates_and_ranks"]
|
|
203
|
+
if not aggregates_and_ranks:
|
|
204
|
+
return super().handle_logical(node, parent)
|
|
205
|
+
|
|
206
|
+
# There should only be one, because the AggregationsRanksRewriter should have
|
|
207
|
+
# separated them out.
|
|
208
|
+
assert len(aggregates_and_ranks) == 1, "Multiple aggregate/ranks still found after rewriting dependencies"
|
|
209
|
+
|
|
210
|
+
# Rename at this level
|
|
211
|
+
agg_rank = aggregates_and_ranks[0]
|
|
212
|
+
if isinstance(agg_rank, ir.Aggregate):
|
|
213
|
+
vars_to_keep = set(agg_rank.group)
|
|
214
|
+
output_vars = [v for v in helpers.vars(agg_rank.args) if not helpers.is_aggregate_input(v, agg_rank)]
|
|
215
|
+
vars_to_keep.update(output_vars)
|
|
216
|
+
|
|
217
|
+
result = self.RenameRewriter(vars_to_keep, 'agg').walk(node)
|
|
218
|
+
else:
|
|
219
|
+
assert isinstance(agg_rank, ir.Rank)
|
|
220
|
+
vars_to_keep = set(agg_rank.group)
|
|
221
|
+
output_vars = helpers.vars(agg_rank.args) + [agg_rank.result]
|
|
222
|
+
vars_to_keep.update(output_vars)
|
|
223
|
+
vars_to_keep.update(agg_rank.projection)
|
|
224
|
+
|
|
225
|
+
result = self.RenameRewriter(vars_to_keep, 'rank').walk(node)
|
|
226
|
+
|
|
227
|
+
# Process children
|
|
228
|
+
result = super().handle_logical(result, parent)
|
|
229
|
+
return result
|
|
230
|
+
|
|
231
|
+
class DeepCopyRewriter(Rewriter):
|
|
232
|
+
def walk(self, node, parent=None):
|
|
233
|
+
new_node = super().walk(node, parent)
|
|
234
|
+
if isinstance(new_node, ir.Task):
|
|
235
|
+
return new_node.clone()
|
|
236
|
+
else:
|
|
237
|
+
return new_node
|
|
@@ -1092,7 +1092,7 @@ class Replacer(visitor.Rewriter):
|
|
|
1092
1092
|
|
|
1093
1093
|
def handle_var(self, node: ir.Var, parent: ir.Node):
|
|
1094
1094
|
if node.id in self.net.resolved_types:
|
|
1095
|
-
return
|
|
1095
|
+
return node.reconstruct(name=node.name, type=self.net.resolved_types[node.id])
|
|
1096
1096
|
return node
|
|
1097
1097
|
|
|
1098
1098
|
def handle_literal(self, node: ir.Literal, parent: ir.Node):
|