relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/config/config.py +47 -21
- relationalai/config/connections/__init__.py +5 -2
- relationalai/config/connections/duckdb.py +2 -2
- relationalai/config/connections/local.py +31 -0
- relationalai/config/connections/snowflake.py +0 -1
- relationalai/config/external/raiconfig_converter.py +235 -0
- relationalai/config/external/raiconfig_models.py +202 -0
- relationalai/config/external/utils.py +31 -0
- relationalai/config/shims.py +1 -0
- relationalai/semantics/__init__.py +10 -8
- relationalai/semantics/backends/sql/sql_compiler.py +1 -4
- relationalai/semantics/experimental/__init__.py +0 -0
- relationalai/semantics/experimental/builder.py +295 -0
- relationalai/semantics/experimental/builtins.py +154 -0
- relationalai/semantics/frontend/base.py +67 -42
- relationalai/semantics/frontend/core.py +34 -6
- relationalai/semantics/frontend/front_compiler.py +209 -37
- relationalai/semantics/frontend/pprint.py +6 -2
- relationalai/semantics/metamodel/__init__.py +7 -0
- relationalai/semantics/metamodel/metamodel.py +2 -0
- relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
- relationalai/semantics/metamodel/pprint.py +6 -1
- relationalai/semantics/metamodel/rewriter.py +11 -7
- relationalai/semantics/metamodel/typer.py +116 -41
- relationalai/semantics/reasoners/__init__.py +11 -0
- relationalai/semantics/reasoners/graph/__init__.py +35 -0
- relationalai/semantics/reasoners/graph/core.py +9028 -0
- relationalai/semantics/std/__init__.py +30 -10
- relationalai/semantics/std/aggregates.py +641 -12
- relationalai/semantics/std/common.py +146 -13
- relationalai/semantics/std/constraints.py +71 -1
- relationalai/semantics/std/datetime.py +904 -21
- relationalai/semantics/std/decimals.py +143 -2
- relationalai/semantics/std/floats.py +57 -4
- relationalai/semantics/std/integers.py +98 -4
- relationalai/semantics/std/math.py +857 -35
- relationalai/semantics/std/numbers.py +216 -20
- relationalai/semantics/std/re.py +213 -5
- relationalai/semantics/std/strings.py +437 -44
- relationalai/shims/executor.py +60 -52
- relationalai/shims/fixtures.py +85 -0
- relationalai/shims/helpers.py +26 -2
- relationalai/shims/hoister.py +28 -9
- relationalai/shims/mm2v0.py +204 -173
- relationalai/tools/cli/cli.py +192 -10
- relationalai/tools/cli/components/progress_reader.py +1 -1
- relationalai/tools/cli/docs.py +394 -0
- relationalai/tools/debugger.py +11 -4
- relationalai/tools/qb_debugger.py +435 -0
- relationalai/tools/typer_debugger.py +1 -2
- relationalai/util/dataclasses.py +3 -5
- relationalai/util/docutils.py +1 -2
- relationalai/util/error.py +2 -5
- relationalai/util/python.py +23 -0
- relationalai/util/runtime.py +1 -2
- relationalai/util/schema.py +2 -4
- relationalai/util/structures.py +4 -2
- relationalai/util/tracing.py +8 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
- v0/relationalai/__init__.py +1 -1
- v0/relationalai/clients/client.py +52 -18
- v0/relationalai/clients/exec_txn_poller.py +122 -0
- v0/relationalai/clients/local.py +23 -8
- v0/relationalai/clients/resources/azure/azure.py +36 -11
- v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
- v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
- v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
- v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
- v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
- v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
- v0/relationalai/clients/types.py +5 -0
- v0/relationalai/errors.py +19 -1
- v0/relationalai/semantics/lqp/algorithms.py +173 -0
- v0/relationalai/semantics/lqp/builtins.py +199 -2
- v0/relationalai/semantics/lqp/executor.py +68 -37
- v0/relationalai/semantics/lqp/ir.py +28 -2
- v0/relationalai/semantics/lqp/model2lqp.py +215 -45
- v0/relationalai/semantics/lqp/passes.py +13 -658
- v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
- v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
- v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
- v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
- v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
- v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
- v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
- v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
- v0/relationalai/semantics/lqp/utils.py +11 -1
- v0/relationalai/semantics/lqp/validators.py +14 -1
- v0/relationalai/semantics/metamodel/builtins.py +2 -1
- v0/relationalai/semantics/metamodel/compiler.py +2 -1
- v0/relationalai/semantics/metamodel/dependency.py +12 -3
- v0/relationalai/semantics/metamodel/executor.py +11 -1
- v0/relationalai/semantics/metamodel/factory.py +2 -2
- v0/relationalai/semantics/metamodel/helpers.py +7 -0
- v0/relationalai/semantics/metamodel/ir.py +3 -2
- v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
- v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
- v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
- v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
- v0/relationalai/semantics/metamodel/visitor.py +4 -3
- v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
- v0/relationalai/semantics/rel/compiler.py +2 -1
- v0/relationalai/semantics/rel/executor.py +3 -2
- v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
- v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
- v0/relationalai/tools/cli.py +339 -186
- v0/relationalai/tools/cli_controls.py +216 -67
- v0/relationalai/tools/cli_helpers.py +410 -6
- v0/relationalai/util/format.py +5 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from v0.relationalai.semantics.metamodel.compiler import Pass
|
|
2
|
+
from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
|
|
3
|
+
|
|
4
|
+
from typing import cast
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import hashlib
|
|
7
|
+
|
|
8
|
+
# Creates intermediary relations for all Data nodes and replaces said Data nodes
# with a Lookup into these created relations. Data nodes that would produce the
# same relation (same variable types and same row contents) share one relation.
class EliminateData(Pass):
    """Replace every ``ir.Data`` node with a lookup into a derived relation.

    Each distinct Data node gets a relation named ``formerly_Data_<n>``; the rule
    deriving that relation from the node's rows is appended to the body of the
    model's root ``ir.Logical``.
    """

    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        # A fresh rewriter per call, so counters and caches never leak between models.
        r = self.DataRewriter()
        return r.walk(model)

    # Does the actual work.
    class DataRewriter(visitor.Rewriter):
        """Walk a model, swapping Data nodes for lookups into new relations.

        New relations and their defining rules are accumulated during the walk
        and spliced into the model by ``handle_model``.
        """

        new_relations: list[ir.Relation]
        new_updates: list[ir.Logical]
        # Counter for naming new relations.
        # Invariant: new_count == len(new_updates) == len(new_relations).
        new_count: int
        # Cache from _data_cache_key(node) to the intermediary relation already
        # created for an equivalent Data node.
        data_cache: dict[tuple, ir.Relation]

        def __init__(self):
            self.new_relations = []
            self.new_updates = []
            self.new_count = 0
            self.data_cache = {}
            super().__init__()

        def _data_cache_key(self, node: ir.Data) -> tuple:
            """Return a hashable key identifying the relation ``node`` would produce.

            The generated relation is determined by the node's variable types
            (which fix its arity, its field types and the literal types) and by
            the row values, so both go into the key. Hashing the row contents
            alone is not enough: every zero-row DataFrame hashes to the same
            (empty) digest whatever its column count, and identical values may be
            typed differently by the node's variables.
            """
            row_hashes = pd.util.hash_pandas_object(node.data).values
            content_digest = hashlib.sha256(row_hashes.tobytes()).hexdigest()
            var_types = tuple(v.type for v in node.vars)
            return (var_types, content_digest)

        def _intermediary_relation(self, node: ir.Data) -> ir.Relation:
            """Return the relation standing in for ``node``, creating it if needed.

            On a cache miss this creates ``formerly_Data_<n>`` with one field per
            variable of ``node``, records it in ``new_relations``, and records the
            rule deriving its rows in ``new_updates``.
            """
            cache_key = self._data_cache_key(node)
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

            self.new_count += 1
            intermediary_name = f"formerly_Data_{self.new_count}"

            intermediary_relation = f.relation(
                intermediary_name,
                [f.field(v.name, v.type) for v in node.vars]
            )
            self.new_relations.append(intermediary_relation)

            intermediary_update = f.logical([
                # For each row (union), equate values and their variable (logical).
                f.union(
                    [
                        f.logical(
                            [
                                f.lookup(rel_builtins.eq, [f.literal(val, var.type), var])
                                for (val, var) in zip(row, node.vars)
                            ],
                        )
                        for row in node
                    ],
                    hoisted=node.vars,
                ),
                # And pop it back into the relation.
                f.update(intermediary_relation, node.vars, ir.Effect.derive),
            ])
            self.new_updates.append(intermediary_update)

            # Cache the result for reuse.
            self.data_cache[cache_key] = intermediary_relation

            return intermediary_relation

        # Create a new intermediary relation representing the Data (and pop it in
        # new_updates/new_relations) and replace this Data with a Lookup of said
        # intermediary.
        def handle_data(self, node: ir.Data, parent: ir.Node) -> ir.Lookup:
            intermediary_relation = self._intermediary_relation(node)
            return f.lookup(intermediary_relation, node.vars)

        # Walks the model for the handle_data work then updates the model with
        # the new state.
        def handle_model(self, model: ir.Model, parent: None):
            walked_model = super().handle_model(model, parent)
            assert len(self.new_relations) == len(self.new_updates) and self.new_count == len(self.new_relations)

            # This is okay because it's LQP: the root is always a Logical here.
            assert isinstance(walked_model.root, ir.Logical)
            root_logical = cast(ir.Logical, walked_model.root)

            # No Data nodes were replaced, so there is nothing to add. Return the
            # walked model (not the input) so no other rewrite done by the walk
            # is discarded.
            if self.new_count == 0:
                return walked_model

            # Add the new intermediaries from handle_data to the model.
            return ir.Model(
                walked_model.engines,
                walked_model.relations | self.new_relations,
                walked_model.types,
                ir.Logical(
                    root_logical.engine,
                    root_logical.hoisted,
                    root_logical.body + tuple(self.new_updates),
                    root_logical.annotations,
                ),
                walked_model.annotations,
            )
|
|
@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
|
|
|
118
118
|
the same here).
|
|
119
119
|
"""
|
|
120
120
|
class ExtractKeysRewriter(Rewriter):
|
|
121
|
+
def __init__(self):
    super().__init__()
    # Cache: tuple of original output keys -> the compound-key var standing in
    # for them, so the same key combination always maps to the same var.
    self.compound_keys: dict[tuple[ir.Var, ...], ir.Var] = {}

def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
    """Return the compound-key variable for ``orig_keys``, creating it on first use.

    ``orig_keys`` is normalized to a tuple before it is used as the cache key.
    The parameter only promises an ``Iterable``: lists are unhashable, and other
    mutable containers are unsafe as dict keys because they can change after
    insertion.
    """
    key = tuple(orig_keys)
    compound_key = self.compound_keys.get(key)
    if compound_key is None:
        compound_key = f.var("compound_key", types.Hash)
        self.compound_keys[key] = compound_key
    return compound_key
|
|
131
|
+
|
|
121
132
|
def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
|
|
122
133
|
outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
|
|
123
134
|
# We are not in a logical with an output at this level.
|
|
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
170
181
|
annos = list(output.annotations)
|
|
171
182
|
annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
|
|
172
183
|
# Create a compound key that will be used in place of the original keys.
|
|
173
|
-
compound_key =
|
|
184
|
+
compound_key = self._get_compound_key(output_keys)
|
|
174
185
|
|
|
175
186
|
for key_combination in combinations:
|
|
176
187
|
missing_keys = OrderedSet.from_iterable(output_keys)
|
|
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
192
203
|
# handle the construct node in each clone
|
|
193
204
|
values: list[ir.Value] = [compound_key.type]
|
|
194
205
|
for key in output_keys:
|
|
195
|
-
|
|
196
|
-
|
|
206
|
+
if isinstance(key.type, ir.UnionType):
|
|
207
|
+
# the typer can derive union types when multiple distinct entities flow
|
|
208
|
+
# into a relation's field, so use AnyEntity as the type marker
|
|
209
|
+
values.append(ir.Literal(types.String, "AnyEntity"))
|
|
210
|
+
else:
|
|
211
|
+
assert isinstance(key.type, ir.ScalarType)
|
|
212
|
+
values.append(ir.Literal(types.String, key.type.name))
|
|
197
213
|
if key in key_combination:
|
|
198
214
|
values.append(key)
|
|
199
215
|
body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
|
|
@@ -408,6 +424,12 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
408
424
|
for arg in args[:-1]:
|
|
409
425
|
extended_vars.add(arg)
|
|
410
426
|
there_is_progress = True
|
|
427
|
+
elif isinstance(task, ir.Not):
|
|
428
|
+
if isinstance(task.task, ir.Logical):
|
|
429
|
+
hoisted = helpers.hoisted_vars(task.task.hoisted)
|
|
430
|
+
if var in hoisted:
|
|
431
|
+
partitions[var].add(task)
|
|
432
|
+
there_is_progress = True
|
|
411
433
|
else:
|
|
412
434
|
assert False, f"invalid node kind {type(task)}"
|
|
413
435
|
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from v0.relationalai.semantics.metamodel.compiler import Pass
|
|
2
|
+
from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
|
|
3
|
+
from v0.relationalai.semantics.metamodel import types
|
|
4
|
+
|
|
5
|
+
# Names of the period builtins (e.g. `day(delta, res)`) handled by PeriodRewriter.
_PERIOD_BUILTINS = frozenset({
    "year", "month", "week", "day", "hour", "minute", "second",
    "millisecond", "microsecond", "nanosecond",
})

# Date/datetime arithmetic builtins that PeriodMathRewriter annotates with a period.
_PERIOD_MATH_BUILTINS = frozenset({
    "date_add", "date_subtract", "datetime_add", "datetime_subtract",
})


# Generate date arithmetic expressions, such as
# `rel_primitive_date_add(:day, [date] delta, res_2)`, by finding the period
# expression for the delta and adding the period type to the date arithmetic
# expression.
#
# date_add and its kin are generated by a period expression, e.g.,
# `day(delta, res_1)`
# followed by the date arithmetic expression using the period
# `date_add([date] res_1 res_2)`
class PeriodMath(Pass):
    """Two-phase pass: record period variables, then attach their period to date math."""

    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        # Phase 1: rewrite period builtins and remember which var holds which period.
        period_rewriter = self.PeriodRewriter()
        model = period_rewriter.walk(model)
        # Phase 2: prepend the recorded period to date/datetime add/subtract lookups.
        period_math_rewriter = self.PeriodMathRewriter(period_rewriter.period_vars)
        model = period_math_rewriter.walk(model)
        return model

    # Find all period builtins. We need to make them safe for the emitter (either by
    # translating to a cast, or removing) and store the variable and period type for use
    # in the date/datetime add/subtract expressions.
    class PeriodRewriter(visitor.Rewriter):
        def __init__(self):
            super().__init__()
            # Result var of each period builtin -> the period name (e.g. "day").
            self.period_vars: dict[ir.Var, str] = {}

        def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
            """Turn `period(x, res)` into `cast(Int64, x, res)` and record `res -> period`."""
            if not rel_builtins.is_builtin(node.relation):
                return node

            if node.relation.name not in _PERIOD_BUILTINS:
                return node

            assert len(node.args) == 2, "Expect 2 arguments for period builtins"
            assert isinstance(node.args[1], ir.Var), "Expect result to be a variable"
            period = node.relation.name
            result_var = node.args[1]
            self.period_vars[result_var] = period

            # Ideally we could now remove the unused and unhandled period type construction
            # but we may also need to cast the original variable to an Int64 for use by the
            # date/datetime add/subtract expressions.
            # TODO: Remove the node entirely where possible and update uses of the result
            return f.lookup(rel_builtins.cast, [types.Int64, node.args[0], result_var])

    # Update date/datetime add/subtract expressions with period information.
    class PeriodMathRewriter(visitor.Rewriter):
        def __init__(self, period_vars: dict[ir.Var, str]):
            super().__init__()
            self.period_vars: dict[ir.Var, str] = period_vars

        def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
            """Prepend the period name (as a Symbol literal) to date math lookups.

            The period is looked up from the lookup's second argument, which must
            be a variable recorded by PeriodRewriter.
            """
            if not rel_builtins.is_builtin(node.relation):
                return node

            if node.relation.name not in _PERIOD_MATH_BUILTINS:
                return node

            if len(node.args) == 4:
                # We've already visited this lookup
                return node

            assert isinstance(node.args[1], ir.Var), "Expect period to be a variable"
            period_var = node.args[1]
            assert period_var in self.period_vars, "datemath found, but no vars to insert"

            period = self.period_vars[period_var]

            new_args = [f.literal(period, types.Symbol), *node.args]

            return f.lookup(node.relation, new_args)
|
|
@@ -69,7 +69,7 @@ class VarScopeInfo(Visitor):
|
|
|
69
69
|
ir.Var, ir.Literal, ir.Relation, ir.Field,
|
|
70
70
|
ir.Default, ir.Output, ir.Update, ir.Aggregate,
|
|
71
71
|
ir.Check, ir.Require,
|
|
72
|
-
ir.Annotation, ir.Rank)
|
|
72
|
+
ir.Annotation, ir.Rank, ir.Break)
|
|
73
73
|
|
|
74
74
|
def __init__(self):
|
|
75
75
|
super().__init__()
|
|
@@ -103,16 +103,29 @@ class VarScopeInfo(Visitor):
|
|
|
103
103
|
self._record(node, scope_vars)
|
|
104
104
|
|
|
105
105
|
elif isinstance(node, (ir.Match, ir.Union)):
|
|
106
|
-
# Match/Union inherits
|
|
106
|
+
# Match/Union only inherits vars if they are in scope for all sub-tasks.
|
|
107
107
|
scope_vars = ordered_set()
|
|
108
|
+
# Prime the search with the first sub-task's vars.
|
|
109
|
+
if node.tasks:
|
|
110
|
+
scope_vars.update(self._vars_in_scope.get(node.tasks[0].id, None))
|
|
111
|
+
|
|
108
112
|
for task in node.tasks:
|
|
109
113
|
sub_scope_vars = self._vars_in_scope.get(task.id, None)
|
|
110
|
-
if sub_scope_vars:
|
|
111
|
-
scope_vars
|
|
114
|
+
if not scope_vars or not sub_scope_vars:
|
|
115
|
+
scope_vars = ordered_set()
|
|
116
|
+
break
|
|
117
|
+
scope_vars = (scope_vars & sub_scope_vars)
|
|
118
|
+
|
|
112
119
|
# Hoisted vars are not considered for quantification at this level.
|
|
113
120
|
scope_vars.difference_update(helpers.hoisted_vars(node.hoisted))
|
|
114
121
|
self._record(node, scope_vars)
|
|
115
122
|
|
|
123
|
+
elif isinstance(node, (ir.Loop, ir.Sequence)):
|
|
124
|
+
# Variables in Loops and Sequences are scoped exclusively within the body and
|
|
125
|
+
# not propagated outside. No need to record any variables, as they shouldn't be
|
|
126
|
+
# in scope for the node itself
|
|
127
|
+
pass
|
|
128
|
+
|
|
116
129
|
elif isinstance(node, ir.Logical):
|
|
117
130
|
self._do_logical(node)
|
|
118
131
|
|
|
@@ -128,6 +141,9 @@ class VarScopeInfo(Visitor):
|
|
|
128
141
|
all_nested_vars = ordered_set()
|
|
129
142
|
output_vars = ordered_set()
|
|
130
143
|
|
|
144
|
+
# Collect variables nested in child Logical and Not nodes
|
|
145
|
+
nested_vars_in_task: dict[ir.Var, int] = dict()
|
|
146
|
+
|
|
131
147
|
# Collect all variables from logical sub-tasks
|
|
132
148
|
for task in node.body:
|
|
133
149
|
if isinstance(task, ir.Output):
|
|
@@ -140,19 +156,29 @@ class VarScopeInfo(Visitor):
|
|
|
140
156
|
scope_vars.add(var)
|
|
141
157
|
continue
|
|
142
158
|
|
|
143
|
-
sub_scope_vars = self._vars_in_scope.get(task.id, None)
|
|
144
|
-
|
|
145
159
|
# Hoisted variables from sub-tasks are brought again into scope.
|
|
146
160
|
if isinstance(task, (ir.Logical, ir.Union, ir.Match)):
|
|
147
161
|
scope_vars.update(helpers.hoisted_vars(task.hoisted))
|
|
148
162
|
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
163
|
+
# Get variables in sub-task scope
|
|
164
|
+
sub_scope_vars = self._vars_in_scope.get(task.id, ordered_set())
|
|
165
|
+
|
|
166
|
+
if isinstance(task, ir.Logical):
|
|
167
|
+
# Logical child nodes should have their nested variables quantified
|
|
168
|
+
# only if they are needed in more than one child task.
|
|
169
|
+
for var in sub_scope_vars:
|
|
170
|
+
if var not in nested_vars_in_task:
|
|
171
|
+
nested_vars_in_task[var] = 0
|
|
172
|
+
nested_vars_in_task[var] += 1
|
|
173
|
+
elif not isinstance(task, ir.Not):
|
|
174
|
+
# Other nodes with nested variables need to be quantified at this level
|
|
175
|
+
scope_vars.update(sub_scope_vars)
|
|
176
|
+
|
|
177
|
+
for v, c in nested_vars_in_task.items():
|
|
178
|
+
# If the variable appears in more than one nested child, then it needs to be
|
|
179
|
+
# quantified here. Otherwise, it will be handled in the child node
|
|
180
|
+
if c > 1:
|
|
181
|
+
all_nested_vars.add(v)
|
|
156
182
|
|
|
157
183
|
# Nested variables also need to be introduced, provided they are not output variables.
|
|
158
184
|
for var in all_nested_vars:
|
|
@@ -190,37 +216,30 @@ class FindQuantificationNodes(Visitor):
|
|
|
190
216
|
def __init__(self, var_info: VarScopeInfo):
|
|
191
217
|
super().__init__()
|
|
192
218
|
self._vars_in_scope = var_info._vars_in_scope
|
|
193
|
-
self.
|
|
219
|
+
self.handled_vars: dict[int, OrderedSet[ir.Var]] = {}
|
|
194
220
|
self.node_quantifies_vars = {}
|
|
195
221
|
|
|
196
222
|
def enter(self, node: ir.Node, parent: Optional[ir.Node]=None) -> "Visitor":
|
|
197
223
|
if contains_only_declarable_constraints(node):
|
|
198
224
|
return self
|
|
199
225
|
|
|
226
|
+
handled_vars = self.handled_vars.get(parent.id, ordered_set()) if parent else ordered_set()
|
|
227
|
+
# Clone the set to avoid modifying parent's handled vars
|
|
228
|
+
handled_vars = OrderedSet.from_iterable(handled_vars)
|
|
229
|
+
|
|
200
230
|
if isinstance(node, (ir.Logical, ir.Not)):
|
|
201
231
|
ignored_vars = _ignored_vars(node)
|
|
202
|
-
|
|
232
|
+
handled_vars.update(ignored_vars)
|
|
203
233
|
|
|
204
234
|
scope_vars = self._vars_in_scope.get(node.id, None)
|
|
205
235
|
if scope_vars:
|
|
206
|
-
scope_vars.difference_update(
|
|
236
|
+
scope_vars.difference_update(handled_vars)
|
|
207
237
|
if scope_vars:
|
|
208
|
-
|
|
238
|
+
handled_vars.update(scope_vars)
|
|
209
239
|
self.node_quantifies_vars[node.id] = scope_vars
|
|
210
|
-
return self
|
|
211
|
-
|
|
212
|
-
def leave(self, node: ir.Node, parent: Optional[ir.Node]=None) -> ir.Node:
|
|
213
|
-
if contains_only_declarable_constraints(node):
|
|
214
|
-
return node
|
|
215
240
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
self._handled_vars.difference_update(ignored_vars)
|
|
219
|
-
|
|
220
|
-
scope_vars = self._vars_in_scope.get(node.id, None)
|
|
221
|
-
if scope_vars:
|
|
222
|
-
self._handled_vars.difference_update(scope_vars)
|
|
223
|
-
return node
|
|
241
|
+
self.handled_vars[node.id] = handled_vars
|
|
242
|
+
return self
|
|
224
243
|
|
|
225
244
|
class QuantifyVarsRewriter(Rewriter):
|
|
226
245
|
"""
|
|
@@ -254,7 +273,12 @@ class QuantifyVarsRewriter(Rewriter):
|
|
|
254
273
|
# in IR directly may do so and the flatten pass doesn't split them yet.
|
|
255
274
|
if len(agg_or_rank_tasks) > 0:
|
|
256
275
|
print(f"Multiple aggregate/rank tasks found: {agg_or_rank_tasks} and {task}")
|
|
257
|
-
|
|
276
|
+
# If the agg/rank depends on any of the vars being quantified here,
|
|
277
|
+
# then it needs to be inside the quantification
|
|
278
|
+
if any(var in helpers.vars(task.projection) for var in vars):
|
|
279
|
+
inner_tasks.append(task)
|
|
280
|
+
else:
|
|
281
|
+
agg_or_rank_tasks.append(task)
|
|
258
282
|
|
|
259
283
|
else:
|
|
260
284
|
inner_tasks.append(task)
|
|
@@ -283,6 +307,16 @@ class QuantifyVarsRewriter(Rewriter):
|
|
|
283
307
|
|
|
284
308
|
return node if node.task is new_task else f.not_(new_task)
|
|
285
309
|
|
|
310
|
+
def handle_union(self, node: ir.Union, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Union:
    """Rewrite the branches of a Union, reusing ``node`` when nothing changed.

    Only the tasks and hoisted vars are passed to ``f.union`` when a new node is
    built. NOTE(review): any other attributes of ``node`` (e.g. annotations, if
    Union carries them) are not forwarded; confirm this is intended.
    """
    # An empty union has nothing to rewrite.
    if not node.tasks:
        return node

    rewritten_tasks = self.walk_list(node.tasks, node)
    if rewritten_tasks is node.tasks:
        return node

    return f.union(tasks=rewritten_tasks, hoisted=node.hoisted)
|
|
319
|
+
|
|
286
320
|
def handle_var(self, node: ir.Var, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Var:
    """Return ``node`` as-is.

    Vars never need rewriting in this pass. Short-circuiting here avoids
    unnecessary cloning of vars in the visitor.
    """
    return node
|