relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/config/config.py +47 -21
- relationalai/config/connections/__init__.py +5 -2
- relationalai/config/connections/duckdb.py +2 -2
- relationalai/config/connections/local.py +31 -0
- relationalai/config/connections/snowflake.py +0 -1
- relationalai/config/external/raiconfig_converter.py +235 -0
- relationalai/config/external/raiconfig_models.py +202 -0
- relationalai/config/external/utils.py +31 -0
- relationalai/config/shims.py +1 -0
- relationalai/semantics/__init__.py +10 -8
- relationalai/semantics/backends/sql/sql_compiler.py +1 -4
- relationalai/semantics/experimental/__init__.py +0 -0
- relationalai/semantics/experimental/builder.py +295 -0
- relationalai/semantics/experimental/builtins.py +154 -0
- relationalai/semantics/frontend/base.py +67 -42
- relationalai/semantics/frontend/core.py +34 -6
- relationalai/semantics/frontend/front_compiler.py +209 -37
- relationalai/semantics/frontend/pprint.py +6 -2
- relationalai/semantics/metamodel/__init__.py +7 -0
- relationalai/semantics/metamodel/metamodel.py +2 -0
- relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
- relationalai/semantics/metamodel/pprint.py +6 -1
- relationalai/semantics/metamodel/rewriter.py +11 -7
- relationalai/semantics/metamodel/typer.py +116 -41
- relationalai/semantics/reasoners/__init__.py +11 -0
- relationalai/semantics/reasoners/graph/__init__.py +35 -0
- relationalai/semantics/reasoners/graph/core.py +9028 -0
- relationalai/semantics/std/__init__.py +30 -10
- relationalai/semantics/std/aggregates.py +641 -12
- relationalai/semantics/std/common.py +146 -13
- relationalai/semantics/std/constraints.py +71 -1
- relationalai/semantics/std/datetime.py +904 -21
- relationalai/semantics/std/decimals.py +143 -2
- relationalai/semantics/std/floats.py +57 -4
- relationalai/semantics/std/integers.py +98 -4
- relationalai/semantics/std/math.py +857 -35
- relationalai/semantics/std/numbers.py +216 -20
- relationalai/semantics/std/re.py +213 -5
- relationalai/semantics/std/strings.py +437 -44
- relationalai/shims/executor.py +60 -52
- relationalai/shims/fixtures.py +85 -0
- relationalai/shims/helpers.py +26 -2
- relationalai/shims/hoister.py +28 -9
- relationalai/shims/mm2v0.py +204 -173
- relationalai/tools/cli/cli.py +192 -10
- relationalai/tools/cli/components/progress_reader.py +1 -1
- relationalai/tools/cli/docs.py +394 -0
- relationalai/tools/debugger.py +11 -4
- relationalai/tools/qb_debugger.py +435 -0
- relationalai/tools/typer_debugger.py +1 -2
- relationalai/util/dataclasses.py +3 -5
- relationalai/util/docutils.py +1 -2
- relationalai/util/error.py +2 -5
- relationalai/util/python.py +23 -0
- relationalai/util/runtime.py +1 -2
- relationalai/util/schema.py +2 -4
- relationalai/util/structures.py +4 -2
- relationalai/util/tracing.py +8 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
- v0/relationalai/__init__.py +1 -1
- v0/relationalai/clients/client.py +52 -18
- v0/relationalai/clients/exec_txn_poller.py +122 -0
- v0/relationalai/clients/local.py +23 -8
- v0/relationalai/clients/resources/azure/azure.py +36 -11
- v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
- v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
- v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
- v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
- v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
- v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
- v0/relationalai/clients/types.py +5 -0
- v0/relationalai/errors.py +19 -1
- v0/relationalai/semantics/lqp/algorithms.py +173 -0
- v0/relationalai/semantics/lqp/builtins.py +199 -2
- v0/relationalai/semantics/lqp/executor.py +68 -37
- v0/relationalai/semantics/lqp/ir.py +28 -2
- v0/relationalai/semantics/lqp/model2lqp.py +215 -45
- v0/relationalai/semantics/lqp/passes.py +13 -658
- v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
- v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
- v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
- v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
- v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
- v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
- v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
- v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
- v0/relationalai/semantics/lqp/utils.py +11 -1
- v0/relationalai/semantics/lqp/validators.py +14 -1
- v0/relationalai/semantics/metamodel/builtins.py +2 -1
- v0/relationalai/semantics/metamodel/compiler.py +2 -1
- v0/relationalai/semantics/metamodel/dependency.py +12 -3
- v0/relationalai/semantics/metamodel/executor.py +11 -1
- v0/relationalai/semantics/metamodel/factory.py +2 -2
- v0/relationalai/semantics/metamodel/helpers.py +7 -0
- v0/relationalai/semantics/metamodel/ir.py +3 -2
- v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
- v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
- v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
- v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
- v0/relationalai/semantics/metamodel/visitor.py +4 -3
- v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
- v0/relationalai/semantics/rel/compiler.py +2 -1
- v0/relationalai/semantics/rel/executor.py +3 -2
- v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
- v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
- v0/relationalai/tools/cli.py +339 -186
- v0/relationalai/tools/cli_controls.py +216 -67
- v0/relationalai/tools/cli_helpers.py +410 -6
- v0/relationalai/util/format.py +5 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,30 @@
|
|
|
1
|
+
from .algorithm import AlgorithmPass
|
|
1
2
|
from .annotate_constraints import AnnotateConstraints
|
|
2
3
|
from .cdc import CDC
|
|
4
|
+
from .constants_to_vars import ConstantsToVars
|
|
5
|
+
from .deduplicate_vars import DeduplicateVars
|
|
6
|
+
from .eliminate_data import EliminateData
|
|
3
7
|
from .extract_common import ExtractCommon
|
|
4
8
|
from .extract_keys import ExtractKeys
|
|
5
9
|
from .function_annotations import FunctionAnnotations, SplitMultiCheckRequires
|
|
10
|
+
from .period_math import PeriodMath
|
|
6
11
|
from .quantify_vars import QuantifyVars
|
|
7
12
|
from .splinter import Splinter
|
|
13
|
+
from .unify_definitions import UnifyDefinitions
|
|
8
14
|
|
|
9
15
|
__all__ = [
|
|
16
|
+
"AlgorithmPass",
|
|
10
17
|
"AnnotateConstraints",
|
|
11
18
|
"CDC",
|
|
19
|
+
"ConstantsToVars",
|
|
20
|
+
"DeduplicateVars",
|
|
21
|
+
"EliminateData",
|
|
12
22
|
"ExtractCommon",
|
|
13
23
|
"ExtractKeys",
|
|
14
24
|
"FunctionAnnotations",
|
|
25
|
+
"PeriodMath",
|
|
15
26
|
"QuantifyVars",
|
|
16
27
|
"Splinter",
|
|
17
28
|
"SplitMultiCheckRequires",
|
|
29
|
+
"UnifyDefinitions",
|
|
18
30
|
]
|
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from typing import Optional, TypeGuard, Union, cast
|
|
3
|
+
from v0.relationalai.semantics.metamodel import ir, helpers, factory
|
|
4
|
+
from v0.relationalai.semantics.metamodel.compiler import Pass
|
|
5
|
+
from v0.relationalai.semantics.metamodel.visitor import Visitor, Rewriter, collect_by_type
|
|
6
|
+
from v0.relationalai.semantics.lqp.algorithms import (
|
|
7
|
+
is_script, is_algorithm_script,is_logical_instruction, is_update_instruction,
|
|
8
|
+
get_instruction_head_rels, get_instruction_body_rels, mk_assign, split_instruction
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
class AlgorithmPass(Pass):
    """
    Transforms algorithm scripts by normalizing Loopy constructs (iterative algorithm).

    This pass applies three main rewriting transformations to Metamodel IR that prepare
    algorithm scripts for execution, in the order listed below:

    1. **Intermediate Rescoping**: Moves nested logical intermediate relations from their
       original logical scope into algorithm scripts, placing them immediately before each
       instruction that uses them (which can include Break instructions). Removes
       intermediates from the logical scope if they're only used within algorithms.
       TODO: Monitor https://github.com/RelationalAI/relationalai-python/pull/3187

       Example (Metamodel IR):
           BEFORE:
               Logical
                   Logical
                       R(x::Int128, y::Int128)
                       → derive _nested_logical_1(x::Int128, y::Int128) @assign
                   Sequence @script @algorithm
                       Logical
                           _nested_logical_1(a::Int128, b::Int128)
                           → derive S(a::Int128, b::Int128) @assign

           AFTER:
               Logical
                   Sequence @script @algorithm
                       Logical
                           R(x::Int128, y::Int128)
                           → derive _nested_logical_1(x::Int128, y::Int128) @assign
                       Logical
                           _nested_logical_1(a::Int128, b::Int128)
                           → derive S(a::Int128, b::Int128) @assign

    2. **Update Normalization**: Transforms Loopy update operations (@upsert, @monoid, @monus)
       to use a single body atom. Complex bodies with multiple lookups or additional
       operations are normalized by introducing intermediate relations.

       Example (Metamodel IR):
           BEFORE:
               Logical
                   R(x::Int128, y::Int128)
                   S(y::Int128, z::Int128)
                   → derive T(x::Int128, z::Int128) @upsert

           AFTER:
               Logical
                   R(x::Int128, y::Int128)
                   S(y::Int128, z::Int128)
                   → derive _loopy_update_intermediate_1(x::Int128, z::Int128) @assign

               Logical
                   _loopy_update_intermediate_1(x::Int128, z::Int128)
                   → derive T(x::Int128, z::Int128) @upsert

    3. **Recursive Assignment Decoupling**: Decouples self-referential assignments where the
       head relation appears in the body by introducing a copy relation. This transformation
       is required for BackIR analysis compatibility.

       Example (Metamodel IR):
           BEFORE:
               Logical
                   iter(i::Int128)
                   rel_primitive_int128_add(i::Int128, 1::Int128, i_plus_1::Int128)
                   → derive iter(i_plus_1::Int128) @assign

           AFTER:
               Logical
                   iter(i::Int128)
                   → derive _loopy_iter_copy_1(i::Int128) @assign

               Logical
                   _loopy_iter_copy_1(i::Int128)
                   rel_primitive_int128_add(i::Int128, 1::Int128, i_plus_1::Int128)
                   → derive iter(i_plus_1::Int128) @assign
    """
    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        """
        Apply intermediate rescoping, update normalization and recursive-assignment
        decoupling (in that order) and return the rewritten model. `options` is unused.
        """
        # Find all nested logical intermediates declared in logical scope.
        intermediate_finder = FindIntermediates()
        model.accept(intermediate_finder)

        # Find where each intermediate is used: inside algorithm instructions or elsewhere.
        intermediate_analyzer = AnalyzeIntermediateUse(set(intermediate_finder.intermediates.keys()))
        model.accept(intermediate_analyzer)

        # Determine which intermediates to move and which declarations to remove.
        uses_intermediates: dict[Union[ir.Logical, ir.Break], set[ir.Logical]] = defaultdict(set)
        # Declarations that must stay in the logical scope because at least one of the
        # relations they define is still read outside of an algorithm. A single declaration
        # can define several intermediates (one per head relation), so the decision to
        # remove it must consider all of them, not each relation in isolation.
        kept_declarations: set[ir.Logical] = set()
        # Declarations copied into at least one algorithm script by the rescoper.
        moved_declarations: set[ir.Logical] = set()
        for rel, decl in intermediate_finder.intermediates.items():
            if rel in intermediate_analyzer.used_outside_algorithm:
                kept_declarations.add(decl)
            if rel in intermediate_analyzer.used_in_algorithm:
                moved_declarations.add(decl)
            for instr in intermediate_analyzer.used_in_alg_instruction[rel]:
                uses_intermediates[instr].add(decl)
        # Only delete a declaration that now lives inside an algorithm and is not needed
        # anywhere else; never delete one that was not copied anywhere.
        remove_declarations: set[ir.Logical] = moved_declarations - kept_declarations

        # Rescope intermediates
        rescoper = IntermediateRescoper(uses_intermediates, remove_declarations)
        model = rescoper.walk(model)

        # Normalize Loopy updates
        normalizer = UpdateNormalizer()
        model = normalizer.walk(model)

        # Decouple recursive assignments
        decoupler = RecursiveAssignmentDecoupling()
        model = decoupler.walk(model)

        return model
|
|
117
|
+
|
|
118
|
+
class FindIntermediates(Visitor):
    """
    Gathers all `_nested_logical.*` intermediates defined in a Logical scope (where order
    doesn't matter); in particular DOES NOT gather any intermediates declared inside an
    algorithm script (a Sequence for which `is_algorithm_script` holds), because those are
    already ordered by the script itself.

    After visiting, `intermediates` maps each intermediate relation to the Logical
    instruction whose head defines it.
    """
    def __init__(self):
        self.intermediates: dict[ir.Relation, ir.Logical] = dict()
        # True while visiting the subtree of an algorithm script.
        self._inside_algorithm: bool = False

    def visit_logical(self, node: ir.Logical, parent: Optional[ir.Node]):
        if is_logical_instruction(node):
            # Declarations already inside an algorithm must not be rescoped again; gathering
            # them would make the rescoper insert duplicate copies before their users.
            if not self._inside_algorithm:
                for rel in get_instruction_head_rels(node):
                    if rel.name.startswith("_nested_logical"):
                        self.intermediates[rel] = node
        else:
            super().visit_logical(node, parent)

    def visit_sequence(self, node: ir.Sequence, parent: Optional[ir.Node]):
        # Save and restore the flag so that a nested algorithm script does not clear it
        # while an enclosing algorithm script is still being visited.
        was_inside = self._inside_algorithm
        if is_algorithm_script(node):
            self._inside_algorithm = True
        try:
            super().visit_sequence(node, parent)
        finally:
            self._inside_algorithm = was_inside
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class AnalyzeIntermediateUse(Visitor):
    """
    Identifies, for each nested logical intermediate, the algorithm instructions that
    use it. Additionally, determines whether the intermediate is used anywhere
    outside of an algorithm.

    Results after visiting:
    * `used_in_algorithm`: intermediates read by at least one algorithm instruction.
    * `used_in_alg_instruction`: for each intermediate, the algorithm instructions
      (Logical instructions or Breaks) that read it.
    * `used_outside_algorithm`: intermediates read by some instruction outside any
      algorithm script.
    """
    def __init__(self, intermediate_relations: set[ir.Relation]):
        self.intermediates = intermediate_relations
        self.used_in_algorithm: set[ir.Relation] = set()
        self.used_in_alg_instruction: dict[ir.Relation, set[Union[ir.Logical, ir.Break]]] = {rel: set() for rel in intermediate_relations}
        self.used_outside_algorithm: set[ir.Relation] = set()

        # The innermost algorithm script currently being visited, if any.
        self._current_algorithm: Optional[ir.Sequence] = None

    def register_use(self, instr: Union[ir.Logical, ir.Break], uses_intermediates: set[ir.Relation]):
        """Record that `instr` reads the intermediates in `uses_intermediates`."""
        if self._current_algorithm is not None:
            # instruction is inside an algorithm
            for rel in uses_intermediates:
                self.used_in_algorithm.add(rel)
                self.used_in_alg_instruction[rel].add(instr)
        else:
            self.used_outside_algorithm.update(uses_intermediates)

    def visit_break(self, node: ir.Break, parent: Optional[ir.Node]):
        # A Break reads every relation looked up anywhere inside it.
        lookups = collect_by_type(ir.Lookup, node)
        lookup_rels = {lookup.relation for lookup in lookups}
        uses_intermediates = lookup_rels.intersection(self.intermediates)
        self.register_use(node, uses_intermediates)
        super().visit_break(node, parent)

    def visit_logical(self, node: ir.Logical, parent: Optional[ir.Node]):
        if is_logical_instruction(node):
            body = get_instruction_body_rels(node)
            uses_intermediates = body.intersection(self.intermediates)
            self.register_use(node, uses_intermediates)
        else:
            super().visit_logical(node, parent)

    def visit_sequence(self, node: ir.Sequence, parent: Optional[ir.Node]):
        # Save and restore so that leaving a nested algorithm script does not make the rest
        # of an enclosing algorithm look like it is outside any algorithm.
        previous_algorithm = self._current_algorithm
        if is_algorithm_script(node):
            self._current_algorithm = node
        try:
            super().visit_sequence(node, parent)
        finally:
            self._current_algorithm = previous_algorithm
|
|
190
|
+
|
|
191
|
+
class IntermediateRescoper(Rewriter):
    """
    Copies nested logical intermediate declarations into algorithm scripts and prunes
    them from logical scopes.

    Every algorithm instruction listed in `uses_intermediates` is preceded, inside its
    Sequence, by an assignment (built with `mk_assign`) of each declaration it reads.
    Every declaration listed in `remove_declarations` is dropped from any Logical body
    that contains it.

    * `uses_intermediates`: algorithm instruction -> nested logical declarations it reads.
    * `remove_declarations`: declarations to drop from logical scopes.
    """
    def __init__(self,
                 uses_intermediates: dict[Union[ir.Logical, ir.Break], set[ir.Logical]],
                 remove_declarations: set[ir.Logical]):
        super().__init__()
        self.uses_intermediates = uses_intermediates
        self.remove_declarations = remove_declarations

    def handle_logical(self, node: ir.Logical, parent: ir.Node) -> ir.Logical:
        # Rewrite every surviving child; pruned declarations are skipped entirely.
        surviving = tuple(
            self.walk(task, node)
            for task in node.body
            if task not in self.remove_declarations
        )
        return node.reconstruct(node.engine, node.hoisted, surviving, node.annotations)

    def handle_sequence(self, node: ir.Sequence, parent: ir.Node) -> ir.Sequence:
        rebuilt = []
        for task in node.tasks:
            if task in self.uses_intermediates:
                assert isinstance(task, (ir.Logical, ir.Break))
                # Materialize each needed intermediate right before its user.
                rebuilt.extend(mk_assign(decl) for decl in self.uses_intermediates[task])
            rebuilt.append(self.walk(task, node))
        return node.reconstruct(node.engine, node.hoisted, tuple(rebuilt), node.annotations)
|
|
228
|
+
|
|
229
|
+
class UpdateNormalizer(Rewriter):
    """
    This pass normalizes Loopy Update operations (upsert, monoid, and monus) to use a single
    atom in their body. For any Update operation with more complex body, it introduces a new
    intermediate relation to hold the body results.

    Only Sequences inside an algorithm script (including the algorithm script itself and any
    Sequence nested within it) are normalized.
    """
    def __init__(self):
        super().__init__()
        # True while rewriting the subtree of an algorithm script.
        self._inside_algorithm: bool = False
        self._intermediate_counter: int = 0

    # Tests if the given Update operation requires normalization
    # * the body has more than one Lookup operation, or
    # * the body has other tasks than Lookup and Update
    def _requires_update_normalization(self, update: ir.Task) -> bool:
        if not isinstance(update, ir.Logical) or not is_update_instruction(update):
            return False
        _, lookups, others = split_instruction(update)
        return len(lookups) > 1 or len(others) > 0

    def handle_sequence(self, node: ir.Sequence, parent: ir.Node) -> ir.Sequence:
        # Save and restore the flag so that leaving a nested algorithm script does not stop
        # normalization of the remaining tasks of an enclosing algorithm script.
        was_inside = self._inside_algorithm
        if is_algorithm_script(node):
            self._inside_algorithm = True
        try:
            if not self._inside_algorithm:
                return super().handle_sequence(node, parent)

            new_tasks = []
            for task in node.tasks:
                if self._requires_update_normalization(task):
                    assert isinstance(task, ir.Logical)
                    intermediate, normalized_update = self._normalize_update_instruction(task)
                    new_tasks.extend((intermediate, normalized_update))
                else:
                    new_tasks.append(self.walk(task, node))
            return node.reconstruct(node.engine, node.hoisted, tuple(new_tasks), node.annotations)
        finally:
            self._inside_algorithm = was_inside

    def _normalize_update_instruction(self, update_instr: ir.Logical) -> tuple[ir.Logical, ir.Logical]:
        """
        Split `update_instr` into two instructions:
        1. an assignment (via `mk_assign`) of a fresh intermediate relation over the
           variables of the update's arguments, whose body is the original lookups and
           other tasks, and
        2. the original update, whose body is a single lookup of that intermediate.
        """
        update, lookups, others = split_instruction(update_instr)

        var_list = helpers.vars(update.args)

        intermediate_rel = factory.relation(
            self._fresh_intermediate_name(), [
                factory.field(f"arg_{i}", var.type) for i, var in enumerate(var_list)
            ]
        )

        intermediate_derive = factory.derive(intermediate_rel, var_list)
        intermediate_logical = mk_assign(factory.logical(
            engine=update_instr.engine,
            hoisted=update_instr.hoisted,
            body=(*lookups, *others, intermediate_derive),
            annos=list(update_instr.annotations)
        ))
        assert isinstance(intermediate_logical, ir.Logical)

        intermediate_lookup = factory.lookup(
            intermediate_rel,
            var_list
        )

        normalized_update = factory.logical(
            engine=update_instr.engine,
            hoisted=update_instr.hoisted,
            body=(intermediate_lookup, update),
            annos=list(update_instr.annotations)
        )

        return (intermediate_logical, normalized_update)

    def _fresh_intermediate_name(self) -> str:
        self._intermediate_counter += 1
        return f"_loopy_update_intermediate_{self._intermediate_counter}"
|
|
311
|
+
|
|
312
|
+
class RecursiveAssignmentDecoupling(Rewriter):
    """
    Decouples assignments whose definition is "recursive", i.e., the body contains the head
    relation, e.g., `assign iter = iter + 1`. Currently, BackIR analysis cannot properly
    handle such assignments. Such assignments are decoupled by introducing a new intermediate
    copy relation; in the example above, `assign iter_copy = iter; assign iter = iter_copy + 1`.
    The performance is not affected because the backend can identify the new assignment as a
    copy operation and the execution will not lead to materialization of the intermediate
    relation.
    """
    def __init__(self):
        super().__init__()
        self._intermediate_copy_counter: int = 0
        # control of head_rel -> copy_rel substitution in traversal
        self._perform_substitution: bool = False
        self._head_rel: Optional[ir.Relation] = None
        self._copy_rel: Optional[ir.Relation] = None

    def _fresh_copy_rel_name(self, rel_name:str) -> str:
        self._intermediate_copy_counter += 1
        return f"_loopy_{rel_name}_copy_{self._intermediate_copy_counter}"

    def handle_sequence(self, node: ir.Sequence, parent: ir.Node) -> ir.Sequence:
        if not is_script(node):
            return super().handle_sequence(node, parent)

        new_tasks = []
        for task in node.tasks:
            if self._is_recursive_assignment(task):
                assert isinstance(task, ir.Logical)
                # The task's parent is this Sequence, exactly as for the `walk` below.
                intermediate_copy, decoupled_assign = self._decouple_recursive_assignment(task, node)
                new_tasks.extend((intermediate_copy, decoupled_assign))
            else:
                new_tasks.append(self.walk(task, node))
        return node.reconstruct(node.engine, node.hoisted, tuple(new_tasks), node.annotations)

    def _is_recursive_assignment(self, task: ir.Task) -> TypeGuard[ir.Logical]:
        """True if `task` is a logical instruction that reads one of its own head relations."""
        if is_logical_instruction(task):
            heads = get_instruction_head_rels(task)
            body = get_instruction_body_rels(task)
            return len(body & heads) > 0
        return False

    @staticmethod
    def _copy_args(relation: ir.Relation, head_args) -> list:
        """
        Arguments used for both sides of the copy rule.

        The copy must contain every tuple of `relation`, so its arguments must be pairwise
        distinct variables. The head arguments are reused when they already are; otherwise
        (a constant or a repeated variable in the head would filter the copy) fresh
        variables are created from the relation's fields.
        """
        args = list(head_args)
        if all(isinstance(arg, ir.Var) for arg in args) and len(set(args)) == len(args):
            return args
        return [
            ir.Var(fld.type, f"{fld.name}_copy_{i}")
            for i, fld in enumerate(relation.fields)
        ]

    def _decouple_recursive_assignment(self, rule: ir.Logical, parent: ir.Node) -> tuple[ir.Logical, ir.Logical]:
        """
        Return `(copy_rule, rewritten_rule)` where `copy_rule` assigns a fresh copy of the
        head relation and `rewritten_rule` is `rule` with every lookup of the head relation
        redirected to that copy.
        """
        # we have `assign rel(x,...) = ..., rel(y,...), ...`
        update, _, _ = split_instruction(rule)
        head_rel = update.relation
        copy_rel = factory.relation(self._fresh_copy_rel_name(head_rel.name), list(head_rel.fields))

        # build `assign copy_rel(x,...) = rel(x,...)` over distinct variables
        copy_args = self._copy_args(head_rel, update.args)
        copy_rule = cast(ir.Logical, mk_assign(
            factory.logical([
                factory.lookup(head_rel, copy_args),
                factory.update(copy_rel, copy_args, update.effect)
            ])
        ))

        # build `assign rel(x,...) = ..., copy_rel(y,...), ...`
        self._head_rel = head_rel
        self._copy_rel = copy_rel
        self._perform_substitution = True
        try:
            rewritten_rule = self.walk(rule, parent)
        finally:
            # Always reset the substitution state, even if the walk fails.
            self._perform_substitution = False
            self._head_rel = None
            self._copy_rel = None

        return (copy_rule, rewritten_rule)

    def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
        # While decoupling, redirect reads of the head relation to its copy.
        if self._perform_substitution and node.relation == self._head_rel:
            assert self._copy_rel is not None
            return factory.lookup(self._copy_rel, node.args)
        return super().handle_lookup(node, parent)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from v0.relationalai.semantics.metamodel.compiler import Pass
|
|
2
|
+
from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
|
|
3
|
+
from v0.relationalai.semantics.metamodel.typer import typer
|
|
4
|
+
|
|
5
|
+
from typing import List, Sequence, Tuple, Union
|
|
6
|
+
|
|
7
|
+
# Rewrite constants to vars in Updates. This results in a more normalized format where
# updates contain only variables. This allows for easier rewrites in later passes.
class ConstantsToVars(Pass):
    """
    Replace constant arguments of every Update with fresh variables, each bound to its
    constant by an equality lookup placed before the Update in the same Logical body.
    """
    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        return self.ConstantToVarRewriter().walk(model)

    @staticmethod
    def replace_constants_with_vars(vals: Sequence[ir.Value]) -> Tuple[List[ir.Value], List[ir.Lookup]]:
        """
        Return 1) a copy of `vals` where every constant (`ir.PyValue` or `ir.Literal`) is
        replaced by a fresh Var named after its type and position, and 2) `eq` lookups
        binding each fresh Var to the constant it replaces.

        Raises AssertionError if a constant's type is not an `ir.ScalarType`.
        """
        replaced: List[ir.Value] = []
        bindings: List[ir.Lookup] = []
        for position, value in enumerate(vals):
            is_constant = isinstance(value, ir.PyValue) or isinstance(value, ir.Literal)
            if not is_constant:
                replaced.append(value)
                continue
            const_type = typer.to_type(value)
            assert isinstance(const_type, ir.ScalarType), "can only replace scalar constants with vars"
            fresh = ir.Var(const_type, f"{const_type.name.lower()}_{position}")
            replaced.append(fresh)
            bindings.append(f.lookup(rel_builtins.eq, [fresh, value]))
        return replaced, bindings

    @staticmethod
    def dedup_update(update: ir.Update) -> List[Union[ir.Update, ir.Lookup]]:
        """
        Return the `eq` lookups for `update`'s constants followed by a copy of `update`
        whose constant arguments are replaced by the bound variables.
        """
        new_args, eq_lookups = ConstantsToVars.replace_constants_with_vars(update.args)
        rebuilt = ir.Update(
            update.engine,
            update.relation,
            tuple(new_args),
            update.effect,
            update.annotations,
        )
        return [*eq_lookups, rebuilt]

    class ConstantToVarRewriter(visitor.Rewriter):
        """Rewriter that applies `ConstantsToVars.dedup_update` to every Update."""

        # Logical (not Update) is the rewrite point because the replacement of an Update
        # also needs new sibling lookups in the enclosing body.
        def handle_logical(self, node: ir.Logical, parent: ir.Node):
            # Rewrite nested tasks first.
            node = super().handle_logical(node, parent)

            body = []
            for task in node.body:
                if isinstance(task, ir.Update):
                    body.extend(ConstantsToVars.dedup_update(task))
                else:
                    body.append(task)

            return ir.Logical(node.engine, node.hoisted, tuple(body), node.annotations)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from v0.relationalai.semantics.metamodel.compiler import Pass
|
|
2
|
+
from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
|
|
3
|
+
from v0.relationalai.semantics.metamodel import helpers
|
|
4
|
+
from v0.relationalai.semantics.metamodel.util import FrozenOrderedSet
|
|
5
|
+
|
|
6
|
+
from v0.relationalai.semantics.lqp.utils import output_names
|
|
7
|
+
|
|
8
|
+
from typing import List, Sequence, Tuple, Union
|
|
9
|
+
|
|
10
|
+
# Deduplicate Vars in Updates and Outputs.
class DeduplicateVars(Pass):
    """
    Ensure no Var object appears more than once among the arguments of an Update or the
    values of an Output. Each repeated occurrence is replaced by a fresh Var that is
    equated with the original through an `eq` lookup placed before it in the same body.
    """
    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        return self.VarDeduplicator().walk(model)

    @staticmethod
    def dedup_values(vals: Sequence[ir.Value]) -> Tuple[List[ir.Value], List[ir.Lookup]]:
        """
        Return 1) a copy of `vals` in which every Var (bare or wrapped in `ir.Default`) after
        its first occurrence is replaced by a fresh `<name>_dup_<index>` Var (re-wrapped in
        `ir.Default` when the original was), and 2) `eq` lookups equating each fresh Var
        with the Var it replaces. Other values are passed through unchanged.
        """
        seen = set()
        out: List[ir.Value] = []
        equalities: List[ir.Lookup] = []

        for idx, value in enumerate(vals):
            # TODO: we don't know for sure if these are the only relevant cases.
            holds_var = isinstance(value, ir.Default) or isinstance(value, ir.Var)
            if not holds_var:
                out.append(value)
                continue

            base = value if isinstance(value, ir.Var) else value.var
            if base not in seen:
                seen.add(base)
                out.append(value)
                continue

            dup = ir.Var(base.type, base.name + "_dup_" + str(idx))
            out.append(dup if isinstance(value, ir.Var) else ir.Default(dup, value.value))
            equalities.append(f.lookup(rel_builtins.eq, [dup, base]))

        return out, equalities

    @staticmethod
    def dedup_output(output: ir.Output) -> List[Union[ir.Output, ir.Lookup]]:
        """
        Return the equality lookups from `dedup_values` followed by a copy of `output`
        whose alias values are deduplicated (alias names and keys are preserved).
        """
        values = helpers.output_values(output.aliases)
        new_values, eq_lookups = DeduplicateVars.dedup_values(values)
        # Alias names are needed to pair them back up with the rewritten values.
        names = output_names(output.aliases)
        rebuilt = ir.Output(
            output.engine,
            FrozenOrderedSet(list(zip(names, new_values))),
            output.keys,
            output.annotations,
        )
        return [*eq_lookups, rebuilt]

    @staticmethod
    def dedup_update(update: ir.Update) -> List[Union[ir.Update, ir.Lookup]]:
        """
        Return the equality lookups from `dedup_values` followed by a copy of `update`
        whose arguments are deduplicated.
        """
        new_args, eq_lookups = DeduplicateVars.dedup_values(update.args)
        rebuilt = ir.Update(
            update.engine,
            update.relation,
            tuple(new_args),
            update.effect,
            update.annotations,
        )
        return [*eq_lookups, rebuilt]

    class VarDeduplicator(visitor.Rewriter):
        """Rewriter that deduplicates Vars in every Output and Update."""

        # Logical (not Update/Output) is the rewrite point because the replacement also
        # needs new sibling lookups in the enclosing body.
        def handle_logical(self, node: ir.Logical, parent: ir.Node):
            # Rewrite nested tasks first.
            node = super().handle_logical(node, parent)

            body = []
            for task in node.body:
                if isinstance(task, ir.Output):
                    body.extend(DeduplicateVars.dedup_output(task))
                elif isinstance(task, ir.Update):
                    body.extend(DeduplicateVars.dedup_update(task))
                else:
                    body.append(task)

            return ir.Logical(node.engine, node.hoisted, tuple(body), node.annotations)
|