relationalai 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/client.py +3 -4
- relationalai/clients/exec_txn_poller.py +62 -31
- relationalai/clients/resources/snowflake/direct_access_resources.py +6 -5
- relationalai/clients/resources/snowflake/snowflake.py +47 -51
- relationalai/semantics/lqp/algorithms.py +173 -0
- relationalai/semantics/lqp/builtins.py +199 -2
- relationalai/semantics/lqp/executor.py +65 -36
- relationalai/semantics/lqp/ir.py +28 -2
- relationalai/semantics/lqp/model2lqp.py +215 -45
- relationalai/semantics/lqp/passes.py +13 -658
- relationalai/semantics/lqp/rewrite/__init__.py +12 -0
- relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
- relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
- relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
- relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
- relationalai/semantics/lqp/rewrite/period_math.py +77 -0
- relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
- relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
- relationalai/semantics/lqp/utils.py +11 -1
- relationalai/semantics/lqp/validators.py +14 -1
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/compiler.py +2 -1
- relationalai/semantics/metamodel/dependency.py +12 -3
- relationalai/semantics/metamodel/executor.py +11 -1
- relationalai/semantics/metamodel/factory.py +2 -2
- relationalai/semantics/metamodel/helpers.py +7 -0
- relationalai/semantics/metamodel/ir.py +3 -2
- relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
- relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
- relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
- relationalai/semantics/metamodel/typer/checker.py +6 -4
- relationalai/semantics/metamodel/typer/typer.py +2 -5
- relationalai/semantics/metamodel/visitor.py +4 -3
- relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +3 -4
- relationalai/semantics/rel/compiler.py +2 -1
- relationalai/semantics/rel/executor.py +3 -2
- relationalai/semantics/tests/lqp/__init__.py +0 -0
- relationalai/semantics/tests/lqp/algorithms.py +345 -0
- relationalai/tools/cli_controls.py +216 -67
- relationalai/util/format.py +5 -2
- {relationalai-0.13.2.dist-info → relationalai-0.13.3.dist-info}/METADATA +1 -1
- {relationalai-0.13.2.dist-info → relationalai-0.13.3.dist-info}/RECORD +46 -37
- {relationalai-0.13.2.dist-info → relationalai-0.13.3.dist-info}/WHEEL +0 -0
- {relationalai-0.13.2.dist-info → relationalai-0.13.3.dist-info}/entry_points.txt +0 -0
- {relationalai-0.13.2.dist-info → relationalai-0.13.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,6 +3,7 @@ from relationalai.semantics.metamodel import ir
|
|
|
3
3
|
from relationalai.semantics.metamodel.helpers import sanitize
|
|
4
4
|
from relationalai.semantics.metamodel.util import FrozenOrderedSet
|
|
5
5
|
|
|
6
|
+
from dataclasses import dataclass
|
|
6
7
|
from hashlib import sha256
|
|
7
8
|
from typing import Tuple
|
|
8
9
|
|
|
@@ -43,16 +44,25 @@ class UniqueNames:
|
|
|
43
44
|
self.id_to_name[id] = name
|
|
44
45
|
return name
|
|
45
46
|
|
|
47
|
+
@dataclass(frozen=True)
|
|
48
|
+
class ExportDescriptor:
|
|
49
|
+
relation_id: lqp.RelationId
|
|
50
|
+
column_name: str
|
|
51
|
+
column_number: int
|
|
52
|
+
column_type: lqp.Type
|
|
53
|
+
|
|
46
54
|
class TranslationCtx:
|
|
47
55
|
def __init__(self, def_names: UniqueNames = UniqueNames()):
|
|
48
56
|
# TODO: comment these fields
|
|
49
57
|
self.def_names = def_names
|
|
50
58
|
self.var_names = UniqueNames()
|
|
51
59
|
self.output_names = UniqueNames()
|
|
60
|
+
# A counter for break rules generated during translation of while loops
|
|
61
|
+
self.break_rule_counter = 0
|
|
52
62
|
# Map relation IDs to their original names for debugging and pretty printing.
|
|
53
63
|
self.rel_id_to_orig_name = {}
|
|
54
64
|
self.output_ids: list[tuple[lqp.RelationId, str]] = []
|
|
55
|
-
self.
|
|
65
|
+
self.export_descriptors: list[ExportDescriptor] = []
|
|
56
66
|
|
|
57
67
|
def gen_rel_id(ctx: TranslationCtx, orig_name: str, suffix: str = "") -> lqp.RelationId:
|
|
58
68
|
relation_id = lqp.RelationId(id=lqp_hash(orig_name + suffix), meta=None)
|
|
@@ -6,6 +6,11 @@ CompilableType = Union[
|
|
|
6
6
|
ir.Logical,
|
|
7
7
|
ir.Union,
|
|
8
8
|
|
|
9
|
+
# Loops
|
|
10
|
+
ir.Loop,
|
|
11
|
+
ir.Sequence,
|
|
12
|
+
ir.Break,
|
|
13
|
+
|
|
9
14
|
# Formulas
|
|
10
15
|
ir.Lookup,
|
|
11
16
|
ir.Exists,
|
|
@@ -36,7 +41,7 @@ def assert_valid_input(model: ir.Model) -> None:
|
|
|
36
41
|
|
|
37
42
|
def _assert_valid_subtask(task: ir.Task) -> None:
|
|
38
43
|
# TODO: assert what subtasks should look like
|
|
39
|
-
assert isinstance(task, ir.Logical), f"expected logical task, got {type(task)}"
|
|
44
|
+
assert isinstance(task, (ir.Logical, ir.Sequence)), f"expected logical task, got {type(task)}"
|
|
40
45
|
_assert_task_compilable(task)
|
|
41
46
|
|
|
42
47
|
def _assert_task_compilable(task: ir.Task) -> None:
|
|
@@ -51,6 +56,14 @@ def _assert_task_compilable(task: ir.Task) -> None:
|
|
|
51
56
|
assert_valid_update(task)
|
|
52
57
|
effect = task.effect
|
|
53
58
|
assert effect == ir.Effect.derive, "only derive supported at the moment"
|
|
59
|
+
elif isinstance(task, ir.Sequence):
|
|
60
|
+
assert any(anno.relation.name == "script" for anno in task.annotations), "only @script sequences supported at the moment"
|
|
61
|
+
for subtask in task.tasks:
|
|
62
|
+
_assert_task_compilable(subtask)
|
|
63
|
+
elif isinstance(task, ir.Loop):
|
|
64
|
+
assert isinstance(task.body, ir.Sequence), f"expected loop body to be a sequence, got {type(task.body)}"
|
|
65
|
+
for subtask in task.body.tasks:
|
|
66
|
+
_assert_task_compilable(subtask)
|
|
54
67
|
|
|
55
68
|
def assert_valid_update(update: ir.Update) -> None:
|
|
56
69
|
effect = update.effect
|
|
@@ -474,7 +474,8 @@ external = f.relation("external", [])
|
|
|
474
474
|
external_annotation = f.annotation(external, [])
|
|
475
475
|
|
|
476
476
|
# indicates an output is meant to be exported
|
|
477
|
-
export = f.relation("export", [])
|
|
477
|
+
export = f.relation("export", [f.input_field("fqn", types.String)])
|
|
478
|
+
# convenience for when there are no arguments (this is deprecated as fqn should always be used)
|
|
478
479
|
export_annotation = f.annotation(export, [])
|
|
479
480
|
|
|
480
481
|
# indicates this relation is a concept population
|
|
@@ -31,6 +31,8 @@ class DependencyInfo():
|
|
|
31
31
|
parent: dict[int, ir.Task] = field(default_factory=dict)
|
|
32
32
|
# keep track of replacements that were made during a rewrite
|
|
33
33
|
replacements: dict[int, ir.Task] = field(default_factory=dict)
|
|
34
|
+
# keep track of which logicals are effectful
|
|
35
|
+
effectful: set[int] = field(default_factory=set)
|
|
34
36
|
|
|
35
37
|
def task_inputs(self, node: ir.Task) -> Optional[OrderedSet[ir.Var]]:
|
|
36
38
|
""" The input variables for this task, if any. """
|
|
@@ -165,7 +167,7 @@ class Cluster():
|
|
|
165
167
|
# this is a binders cluster, which is a candidate to being merged
|
|
166
168
|
self.mergeable = not self.required and isinstance(task, helpers.BINDERS)
|
|
167
169
|
# this is a cluster that will only hold an effect
|
|
168
|
-
self.effectful = isinstance(task, helpers.EFFECTS)
|
|
170
|
+
self.effectful = isinstance(task, helpers.EFFECTS) or task.id in info.effectful
|
|
169
171
|
# this is a cluster that will only hold a composite
|
|
170
172
|
self.composite = isinstance(task, helpers.COMPOSITES)
|
|
171
173
|
# content is either a single task or a set of tasks
|
|
@@ -374,7 +376,6 @@ class DependencyAnalysis(visitor.Visitor):
|
|
|
374
376
|
def __init__(self, info: DependencyInfo):
|
|
375
377
|
self.info = info
|
|
376
378
|
|
|
377
|
-
|
|
378
379
|
def enter(self, node: ir.Node, parent: Optional[ir.Node]=None):
|
|
379
380
|
# keep track of parents of all nodes
|
|
380
381
|
if parent and isinstance(parent, ir.Task):
|
|
@@ -456,7 +457,7 @@ class DependencyAnalysis(visitor.Visitor):
|
|
|
456
457
|
# if c1 has an effect and c2 is a composite without hoisted variables or with a
|
|
457
458
|
# hoisted variable that does not have a default (it is a plain var), then c2
|
|
458
459
|
# behaves like a filter and c1 depends on it.
|
|
459
|
-
if c1.effectful and c2.composite:
|
|
460
|
+
if c1.effectful and c2.composite and not c2.effectful:
|
|
460
461
|
task = c2.content.some()
|
|
461
462
|
assert(isinstance(task, helpers.COMPOSITES))
|
|
462
463
|
if not task.hoisted:
|
|
@@ -608,6 +609,10 @@ class BindingAnalysis(visitor.Visitor):
|
|
|
608
609
|
else:
|
|
609
610
|
map[key.id].add(val)
|
|
610
611
|
|
|
612
|
+
def leave(self, node: ir.Node, parent: Optional[ir.Node]=None):
|
|
613
|
+
if parent and node.id in self.info.effectful:
|
|
614
|
+
self.info.effectful.add(parent.id)
|
|
615
|
+
return super().leave(node, parent)
|
|
611
616
|
|
|
612
617
|
#
|
|
613
618
|
# Composite tasks
|
|
@@ -768,6 +773,8 @@ class BindingAnalysis(visitor.Visitor):
|
|
|
768
773
|
|
|
769
774
|
|
|
770
775
|
def visit_update(self, node: ir.Update, parent: Optional[ir.Node]):
|
|
776
|
+
assert parent is not None
|
|
777
|
+
self.info.effectful.add(parent.id)
|
|
771
778
|
# register variables being used as arguments to the update, it's always considered an input
|
|
772
779
|
for v in helpers.vars(node.args):
|
|
773
780
|
self.input(node, v)
|
|
@@ -816,6 +823,8 @@ class BindingAnalysis(visitor.Visitor):
|
|
|
816
823
|
|
|
817
824
|
|
|
818
825
|
def visit_output(self, node: ir.Output, parent: Optional[ir.Node]):
|
|
826
|
+
assert parent is not None
|
|
827
|
+
self.info.effectful.add(parent.id)
|
|
819
828
|
# register variables being output, they always considered an input to the task
|
|
820
829
|
for v in helpers.output_vars(node.aliases):
|
|
821
830
|
self.input(node, v)
|
|
@@ -1,15 +1,21 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from pandas import DataFrame
|
|
4
|
-
from typing import Any, Union, Tuple, Literal
|
|
4
|
+
from typing import Any, Union, Tuple, Literal, TYPE_CHECKING
|
|
5
5
|
|
|
6
6
|
from relationalai.clients.config import Config
|
|
7
7
|
from relationalai.semantics.metamodel import Model, Task, ir
|
|
8
8
|
from relationalai.semantics.metamodel.visitor import collect_by_type
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from relationalai.semantics.internal.internal import Model as InternalModel
|
|
11
|
+
|
|
9
12
|
from .util import NameCache
|
|
10
13
|
|
|
11
14
|
import rich
|
|
12
15
|
|
|
16
|
+
# global flag to suppress type errors from being printed
|
|
17
|
+
SUPPRESS_TYPE_ERRORS = False
|
|
18
|
+
|
|
13
19
|
class Executor():
|
|
14
20
|
""" Interface for an object that can execute the program specified by a model. """
|
|
15
21
|
def execute(self, model: Model, task:Task, format:Literal["pandas", "snowpark"]="pandas") -> Union[DataFrame, Any]:
|
|
@@ -59,3 +65,7 @@ class Executor():
|
|
|
59
65
|
if col in df.columns:
|
|
60
66
|
df = df.drop(col, axis=1)
|
|
61
67
|
return df
|
|
68
|
+
|
|
69
|
+
def export_to_csv(self, model: "InternalModel", query) -> str:
|
|
70
|
+
### Only implemented in the LQP executor for now.
|
|
71
|
+
raise NotImplementedError(f"export_to_csv is not supported by {type(self).__name__}")
|
|
@@ -277,8 +277,8 @@ def for_all(vars: PySequence[ir.Var], task: ir.Task, engine: Optional[ir.Engine]
|
|
|
277
277
|
#
|
|
278
278
|
|
|
279
279
|
# loops body until a break condition is met
|
|
280
|
-
def loop(
|
|
281
|
-
return ir.Loop(engine, tuple(hoisted), iter, body, FrozenOrderedSet(annos))
|
|
280
|
+
def loop(body: ir.Task, iter: PySequence[ir.Var]=[], hoisted: PySequence[ir.VarOrDefault]=[], concurrency:int=1, engine: Optional[ir.Engine]=None, annos: PySequence[ir.Annotation]=[]):
|
|
281
|
+
return ir.Loop(engine, tuple(hoisted), tuple(iter), body, concurrency, FrozenOrderedSet(annos))
|
|
282
282
|
|
|
283
283
|
def break_(check: ir.Task, engine: Optional[ir.Engine]=None, annos: PySequence[ir.Annotation]=[]):
|
|
284
284
|
return Break(check, engine, annos)
|
|
@@ -24,6 +24,13 @@ def sanitize(name:str) -> str:
|
|
|
24
24
|
# Checks
|
|
25
25
|
#--------------------------------------------------
|
|
26
26
|
|
|
27
|
+
def is_export(node: ir.Node):
|
|
28
|
+
""" Whether this node is an export output. """
|
|
29
|
+
return isinstance(node, ir.Output) and (
|
|
30
|
+
builtins.export_annotation in node.annotations or
|
|
31
|
+
any(annotation.relation == builtins.export for annotation in node.annotations)
|
|
32
|
+
)
|
|
33
|
+
|
|
27
34
|
def is_concept_lookup(node: ir.Lookup|ir.Relation):
|
|
28
35
|
""" Whether this task is a concept lookup. """
|
|
29
36
|
if isinstance(node, ir.Lookup) and is_concept_lookup(node.relation):
|
|
@@ -427,8 +427,9 @@ class Loop(Task):
|
|
|
427
427
|
"""Execute the body in a loop, incrementing the iter variable, until a break sub-task in
|
|
428
428
|
the body succeeds."""
|
|
429
429
|
hoisted: Tuple[VarOrDefault, ...]
|
|
430
|
-
iter: Var
|
|
430
|
+
iter: Tuple[Var, ...]
|
|
431
431
|
body: Task
|
|
432
|
+
concurrency: int = 1
|
|
432
433
|
annotations:FrozenOrderedSet[Annotation] = annotations_field()
|
|
433
434
|
|
|
434
435
|
@acceptor
|
|
@@ -856,7 +857,7 @@ class Printer(BasePrinter):
|
|
|
856
857
|
|
|
857
858
|
# Iteration (Loops)
|
|
858
859
|
elif isinstance(node, Loop):
|
|
859
|
-
self.print_hoisted(depth, f"Loop ⇓[{self.value_to_string(node.iter)}]{annos_str}", node.hoisted)
|
|
860
|
+
self.print_hoisted(depth, f"Loop ⇓[{', '.join([self.value_to_string(v) for v in node.iter])}] concurrency={node.concurrency} {annos_str}", node.hoisted)
|
|
860
861
|
self.pprint(node.body, depth + 1, print_ids=print_ids)
|
|
861
862
|
|
|
862
863
|
elif isinstance(node, Break):
|
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
from relationalai.semantics.metamodel import ir
|
|
4
4
|
from relationalai.semantics.metamodel.compiler import Pass
|
|
5
5
|
from relationalai.semantics.metamodel.visitor import Visitor, Rewriter
|
|
6
|
-
from relationalai.semantics.metamodel.util import OrderedSet
|
|
6
|
+
from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
|
|
7
7
|
from relationalai.semantics.metamodel import helpers, factory as f
|
|
8
8
|
from typing import Optional, Any
|
|
9
9
|
|
|
@@ -106,6 +106,7 @@ class DNFExtractor(Visitor):
|
|
|
106
106
|
# The logical that contains the output.
|
|
107
107
|
# The assumption for the IR at this point is that there is only one output.
|
|
108
108
|
self.output_logical: Optional[ir.Logical] = None
|
|
109
|
+
self.output_keys: OrderedSet[ir.Var] = ordered_set()
|
|
109
110
|
self.active_negations: list[ir.Not] = []
|
|
110
111
|
# Nodes that have to split into multiple similar nodes, depending on the changes
|
|
111
112
|
# of sub-nodes.
|
|
@@ -120,6 +121,8 @@ class DNFExtractor(Visitor):
|
|
|
120
121
|
if any(isinstance(x, ir.Output) for x in node.body):
|
|
121
122
|
assert not self.output_logical, "multiple outputs"
|
|
122
123
|
self.output_logical = node
|
|
124
|
+
output_node = next(x for x in node.body if isinstance(x, ir.Output))
|
|
125
|
+
self.output_keys = helpers.collect_vars(output_node)
|
|
123
126
|
|
|
124
127
|
elif isinstance(node, ir.Not):
|
|
125
128
|
self.active_negations.append(node)
|
|
@@ -168,29 +171,36 @@ class DNFExtractor(Visitor):
|
|
|
168
171
|
self.output_logical and
|
|
169
172
|
len(self.active_negations) % 2 == 0 and
|
|
170
173
|
len(node.tasks) > 1):
|
|
171
|
-
# We split the union when there
|
|
172
|
-
#
|
|
174
|
+
# We split the union when there are vars not present in all branches that are
|
|
175
|
+
# present in the output keys. If vars are not in output keys then they act as
|
|
176
|
+
# filters only and do not require splitting.
|
|
173
177
|
should_split = False
|
|
174
|
-
all_vars = helpers.collect_vars(node.
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
178
|
+
all_vars = helpers.collect_vars(node).get_set()
|
|
179
|
+
all_vars &= self.output_keys.get_set()
|
|
180
|
+
|
|
181
|
+
if all_vars:
|
|
182
|
+
for t in node.tasks:
|
|
183
|
+
vars = helpers.collect_vars(t).get_set()
|
|
184
|
+
curr_intersection = vars.intersection(all_vars)
|
|
185
|
+
if curr_intersection != all_vars:
|
|
186
|
+
should_split = True
|
|
187
|
+
break
|
|
188
|
+
|
|
189
|
+
if should_split:
|
|
190
|
+
replacements:list[ir.Task] = []
|
|
191
|
+
for t in node.tasks:
|
|
192
|
+
# If some branch should already be replaced, we flatten all
|
|
193
|
+
# the replacements here.
|
|
194
|
+
if t in self.replaced_by:
|
|
195
|
+
replacements.extend(self.replaced_by[t])
|
|
196
|
+
else:
|
|
197
|
+
replacements.append(t)
|
|
198
|
+
self.replaced_by[node] = replacements
|
|
199
|
+
self.should_split.add(parent)
|
|
191
200
|
|
|
192
201
|
if isinstance(node, ir.Logical) and node == self.output_logical:
|
|
193
202
|
self.output_logical = None
|
|
203
|
+
self.output_keys = ordered_set()
|
|
194
204
|
|
|
195
205
|
return node
|
|
196
206
|
|
|
@@ -13,7 +13,8 @@ class Flatten(Pass):
|
|
|
13
13
|
"""
|
|
14
14
|
Traverses the model's root to flatten it as much as possible. The result of this pass is
|
|
15
15
|
a Logical root where all nested tasks that represent a rule in Rel are extracted to the
|
|
16
|
-
top level.
|
|
16
|
+
top level. Additionally, any Sequence is promoted to the top level Logical (but
|
|
17
|
+
encapsulated by a Logical).
|
|
17
18
|
|
|
18
19
|
- nested logical with updates becomes a top-level logical (a rule)
|
|
19
20
|
|
|
@@ -122,6 +123,35 @@ class Flatten(Pass):
|
|
|
122
123
|
Logical
|
|
123
124
|
lookup tmp2
|
|
124
125
|
output
|
|
126
|
+
|
|
127
|
+
- a Sequence is promoted to the top level Logical, encapsulated by a Logical:
|
|
128
|
+
From:
|
|
129
|
+
Logical
|
|
130
|
+
Logical
|
|
131
|
+
lookup
|
|
132
|
+
derive foo
|
|
133
|
+
Sequence
|
|
134
|
+
Logical
|
|
135
|
+
...
|
|
136
|
+
Loop
|
|
137
|
+
Sequence
|
|
138
|
+
...
|
|
139
|
+
Logical
|
|
140
|
+
...
|
|
141
|
+
To:
|
|
142
|
+
Logical
|
|
143
|
+
Logical
|
|
144
|
+
lookup
|
|
145
|
+
derive foo
|
|
146
|
+
Logical
|
|
147
|
+
Sequence
|
|
148
|
+
Logical
|
|
149
|
+
...
|
|
150
|
+
Loop
|
|
151
|
+
Sequence
|
|
152
|
+
...
|
|
153
|
+
Logical
|
|
154
|
+
...
|
|
125
155
|
"""
|
|
126
156
|
|
|
127
157
|
def __init__(self, use_sql: bool=False):
|
|
@@ -181,11 +211,8 @@ class Flatten(Pass):
|
|
|
181
211
|
def handle(self, task: ir.Task, ctx: Context) -> Flatten.HandleResult:
|
|
182
212
|
if isinstance(task, ir.Logical):
|
|
183
213
|
return self.handle_logical(task, ctx)
|
|
184
|
-
elif isinstance(task, ir.Union) and
|
|
185
|
-
#
|
|
186
|
-
# then the Union acts as a filter, and it can be inefficient to flatten it.
|
|
187
|
-
#
|
|
188
|
-
# However, for the SQL backend, we always need to flatten Unions for correct SQL
|
|
214
|
+
elif isinstance(task, ir.Union) and self._use_sql:
|
|
215
|
+
# The SQL backend needs to flatten Unions for correct SQL
|
|
189
216
|
# generation.
|
|
190
217
|
return self.handle_union(task, ctx)
|
|
191
218
|
elif isinstance(task, ir.Match):
|
|
@@ -194,6 +221,8 @@ class Flatten(Pass):
|
|
|
194
221
|
return self.handle_require(task, ctx)
|
|
195
222
|
elif isinstance(task, ir.Not):
|
|
196
223
|
return self.handle_not(task, ctx)
|
|
224
|
+
elif isinstance(task, ir.Sequence):
|
|
225
|
+
return self.handle_sequence(task, ctx)
|
|
197
226
|
else:
|
|
198
227
|
return Flatten.HandleResult(task)
|
|
199
228
|
|
|
@@ -253,9 +282,9 @@ class Flatten(Pass):
|
|
|
253
282
|
|
|
254
283
|
for output in groups["outputs"]:
|
|
255
284
|
assert(isinstance(output, ir.Output))
|
|
256
|
-
new_body = info.task_dependencies(output)
|
|
257
|
-
new_body.update(ctx.extra_tasks)
|
|
258
|
-
new_body.add(output)
|
|
285
|
+
new_body = OrderedSet.from_iterable(t.clone() for t in info.task_dependencies(output))
|
|
286
|
+
new_body.update(t.clone() for t in ctx.extra_tasks)
|
|
287
|
+
new_body.add(output.clone())
|
|
259
288
|
ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(new_body), task.annotations))
|
|
260
289
|
|
|
261
290
|
return Flatten.HandleResult(None)
|
|
@@ -263,9 +292,9 @@ class Flatten(Pass):
|
|
|
263
292
|
# if there are updates, extract as a new top level rule
|
|
264
293
|
if groups["updates"]:
|
|
265
294
|
# add task dependencies to the body
|
|
266
|
-
body.prefix(ctx.info.task_dependencies(task))
|
|
295
|
+
body.prefix(t.clone() for t in ctx.info.task_dependencies(task))
|
|
267
296
|
# potentially add context extra tasks
|
|
268
|
-
body.update(ctx.extra_tasks)
|
|
297
|
+
body.update(t.clone() for t in ctx.extra_tasks)
|
|
269
298
|
ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations))
|
|
270
299
|
return Flatten.HandleResult(None)
|
|
271
300
|
|
|
@@ -278,7 +307,7 @@ class Flatten(Pass):
|
|
|
278
307
|
agg = cast(ir.Aggregate, groups["aggregates"].some())
|
|
279
308
|
|
|
280
309
|
# add agg dependencies to the body
|
|
281
|
-
body.prefix(ctx.info.task_dependencies(agg))
|
|
310
|
+
body.prefix(t.clone() for t in ctx.info.task_dependencies(agg))
|
|
282
311
|
|
|
283
312
|
# extract a new logical for the aggregate, exposing aggregate group-by and results
|
|
284
313
|
exposed_vars = OrderedSet.from_iterable(list(agg.group) + helpers.aggregate_outputs(agg))
|
|
@@ -298,7 +327,7 @@ class Flatten(Pass):
|
|
|
298
327
|
rank = cast(ir.Rank, groups["ranks"].some())
|
|
299
328
|
|
|
300
329
|
# add rank dependencies to the body
|
|
301
|
-
body.prefix(ctx.info.task_dependencies(rank))
|
|
330
|
+
body.prefix(t.clone() for t in ctx.info.task_dependencies(rank))
|
|
302
331
|
# for rank, we sort by the args, but the result includes the keys to preserve bag semantics.
|
|
303
332
|
exposed_vars_raw = list(rank.projection) + list(rank.group) + list(rank.args) +[rank.result]
|
|
304
333
|
# deduplicate vars
|
|
@@ -487,6 +516,14 @@ class Flatten(Pass):
|
|
|
487
516
|
task.annotations
|
|
488
517
|
))
|
|
489
518
|
|
|
519
|
+
def handle_sequence(self, task: ir.Sequence, ctx: Context):
|
|
520
|
+
new_logical = f.logical(
|
|
521
|
+
body = [task],
|
|
522
|
+
engine = task.engine
|
|
523
|
+
)
|
|
524
|
+
ctx.rewrite_ctx.top_level.append(new_logical)
|
|
525
|
+
return Flatten.HandleResult(None)
|
|
526
|
+
|
|
490
527
|
#--------------------------------------------------
|
|
491
528
|
# Helpers
|
|
492
529
|
#--------------------------------------------------
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from typing import Tuple
|
|
3
3
|
|
|
4
|
-
from relationalai.semantics.metamodel import builtins, ir, factory as f, types, visitor
|
|
4
|
+
from relationalai.semantics.metamodel import builtins, ir, factory as f, types, visitor, helpers
|
|
5
5
|
from relationalai.semantics.metamodel.compiler import Pass, group_tasks
|
|
6
6
|
from relationalai.semantics.metamodel.util import OrderedSet
|
|
7
7
|
from relationalai.semantics.metamodel.util import FrozenOrderedSet
|
|
@@ -63,7 +63,7 @@ def adjust_outputs(task: ir.Logical, outputs: OrderedSet[ir.Task], wide_outputs:
|
|
|
63
63
|
# Remove the original output. This is replaced by per-column outputs below
|
|
64
64
|
body.remove(output)
|
|
65
65
|
|
|
66
|
-
is_export =
|
|
66
|
+
is_export = helpers.is_export(output)
|
|
67
67
|
|
|
68
68
|
# Generate an output for each "column"
|
|
69
69
|
# output looks like def output(:cols, :col000, key0, key1, value):
|
|
@@ -100,7 +100,13 @@ def _generate_output_column(output: ir.Output, idx: int, alias: tuple[str, ir.Va
|
|
|
100
100
|
(not is_primitive(alias[1].type) or alias[1].type == types.Hash)):
|
|
101
101
|
|
|
102
102
|
uuid = f.var(f"{alias[0]}_{idx}_uuid", types.String)
|
|
103
|
-
|
|
103
|
+
|
|
104
|
+
if not is_primitive(alias[1].type):
|
|
105
|
+
# For non-primitive types, we keep the original alias
|
|
106
|
+
aliases.append((alias[0], uuid))
|
|
107
|
+
else:
|
|
108
|
+
# For Hash types, we use the uuid name as alias
|
|
109
|
+
aliases.append((uuid.name, uuid))
|
|
104
110
|
|
|
105
111
|
return [
|
|
106
112
|
ir.Lookup(None, builtins.uuid_to_string, (alias[1], uuid)),
|
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|
|
5
5
|
from typing import Optional, List, Union as PyUnion, Tuple, cast
|
|
6
6
|
|
|
7
7
|
from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
|
|
8
|
-
from relationalai.semantics.metamodel import ir, types, visitor, compiler
|
|
8
|
+
from relationalai.semantics.metamodel import ir, types, visitor, compiler, executor
|
|
9
9
|
import rich
|
|
10
10
|
|
|
11
11
|
|
|
@@ -39,7 +39,8 @@ class CheckEnv:
|
|
|
39
39
|
|
|
40
40
|
def _complain(self, node: ir.Node, msg: str):
|
|
41
41
|
"""Report an error."""
|
|
42
|
-
|
|
42
|
+
if not executor.SUPPRESS_TYPE_ERRORS:
|
|
43
|
+
self.diags.append(CheckError(msg, node))
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
@dataclass
|
|
@@ -306,8 +307,9 @@ class CheckModel(visitor.DAGVisitor):
|
|
|
306
307
|
for x in node.hoisted:
|
|
307
308
|
if not CheckModel._variable_occurs_in(x, node.body):
|
|
308
309
|
self.env._complain(node, f"Variable {ir.node_to_string(x).strip()} is hoisted but not used in the body of {ir.node_to_string(node).strip()}.")
|
|
309
|
-
|
|
310
|
-
|
|
310
|
+
for iter_var in node.iter:
|
|
311
|
+
if not CheckModel._variable_occurs_in(iter_var, node.body):
|
|
312
|
+
self.env._complain(node, f"Variable {iter_var} is the loop iterator but is not used in the body of {ir.node_to_string(node).strip()}.")
|
|
311
313
|
return super().visit_loop(node, parent)
|
|
312
314
|
|
|
313
315
|
def visit_update(self, node: ir.Update, parent: Optional[ir.Node]=None):
|
|
@@ -6,7 +6,7 @@ import datetime
|
|
|
6
6
|
from decimal import Decimal as PyDecimal
|
|
7
7
|
from typing import Optional, Union, Tuple
|
|
8
8
|
from relationalai import debugging
|
|
9
|
-
from relationalai.semantics.metamodel import builtins, helpers, ir, types, visitor, compiler, factory as f
|
|
9
|
+
from relationalai.semantics.metamodel import builtins, helpers, ir, types, visitor, compiler, factory as f, executor
|
|
10
10
|
from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
|
|
11
11
|
import rich
|
|
12
12
|
import sys
|
|
@@ -1361,9 +1361,6 @@ class Replacer(visitor.Rewriter):
|
|
|
1361
1361
|
# Typer pass
|
|
1362
1362
|
#--------------------------------------------------
|
|
1363
1363
|
|
|
1364
|
-
# global flag to suppress type errors from being printed
|
|
1365
|
-
SUPPRESS_TYPE_ERRORS = False
|
|
1366
|
-
|
|
1367
1364
|
class InferTypes(compiler.Pass):
|
|
1368
1365
|
def __init__(self):
|
|
1369
1366
|
super().__init__()
|
|
@@ -1392,7 +1389,7 @@ class InferTypes(compiler.Pass):
|
|
|
1392
1389
|
with debugging.span("type.replace"):
|
|
1393
1390
|
final = Replacer(w.net).walk(model)
|
|
1394
1391
|
|
|
1395
|
-
if not SUPPRESS_TYPE_ERRORS:
|
|
1392
|
+
if not executor.SUPPRESS_TYPE_ERRORS:
|
|
1396
1393
|
for err in w.net.errors:
|
|
1397
1394
|
rich.print(str(err), file=sys.stderr)
|
|
1398
1395
|
|
|
@@ -466,7 +466,8 @@ class Visitor(GenericVisitor[None]):
|
|
|
466
466
|
self._walk_engine(node.engine, node)
|
|
467
467
|
for h in node.hoisted:
|
|
468
468
|
self._walk_var_or_default(h, node)
|
|
469
|
-
|
|
469
|
+
for iter in node.iter:
|
|
470
|
+
self._walk_var(iter, node)
|
|
470
471
|
self._walk_node(node.body, node)
|
|
471
472
|
for a in node.annotations:
|
|
472
473
|
self._walk_node(a, node)
|
|
@@ -935,9 +936,9 @@ class Rewriter():
|
|
|
935
936
|
#
|
|
936
937
|
def handle_loop(self, node: ir.Loop, parent: ir.Node):
|
|
937
938
|
hoisted = rewrite_list(ir.VarOrDefault, lambda n: self.walk(n, node), node.hoisted)
|
|
938
|
-
|
|
939
|
+
iter = rewrite_list(ir.Var, lambda n: self.walk(n, node), node.iter)
|
|
939
940
|
body = self.walk(node.body, node)
|
|
940
|
-
return node.reconstruct(node.engine, hoisted,
|
|
941
|
+
return node.reconstruct(node.engine, hoisted, iter, body, node.concurrency, node.annotations)
|
|
941
942
|
|
|
942
943
|
def handle_break(self, node: ir.Break, parent: ir.Node):
|
|
943
944
|
check = self.walk(node.check, node)
|
|
@@ -333,7 +333,7 @@ class SolverModelDev:
|
|
|
333
333
|
executor.execute_raw(textwrap.dedent(f"""
|
|
334
334
|
def delete[:{self.point._name}]: {self.point._name}
|
|
335
335
|
def insert(:{self.point._name}, var, val): {self.points._name}(int128[{i}], var, val)
|
|
336
|
-
""")
|
|
336
|
+
"""))
|
|
337
337
|
return None
|
|
338
338
|
|
|
339
339
|
# print summary of the solver result
|
|
@@ -599,7 +599,7 @@ class SolverModelPB:
|
|
|
599
599
|
}}]}}
|
|
600
600
|
""")
|
|
601
601
|
|
|
602
|
-
executor.execute_raw(export_rel,
|
|
602
|
+
executor.execute_raw(export_rel, query_timeout_mins=query_timeout_mins)
|
|
603
603
|
|
|
604
604
|
def _import_solver_results_from_csv(
|
|
605
605
|
self,
|
|
@@ -695,7 +695,7 @@ class SolverModelPB:
|
|
|
695
695
|
}}
|
|
696
696
|
""")
|
|
697
697
|
|
|
698
|
-
executor.execute_raw(load_and_extract_rel,
|
|
698
|
+
executor.execute_raw(load_and_extract_rel, query_timeout_mins=query_timeout_mins)
|
|
699
699
|
|
|
700
700
|
def _export_model_to_protobuf(
|
|
701
701
|
self,
|
|
@@ -791,7 +791,6 @@ class SolverModelPB:
|
|
|
791
791
|
|
|
792
792
|
executor.execute_raw(
|
|
793
793
|
textwrap.dedent(extract_rel) + textwrap.dedent(insert_points_relation),
|
|
794
|
-
readonly=False,
|
|
795
794
|
query_timeout_mins=query_timeout_mins
|
|
796
795
|
)
|
|
797
796
|
|
|
@@ -929,7 +928,7 @@ class SolverModelPB:
|
|
|
929
928
|
def delete[:{self.point._name}]: {self.point._name}
|
|
930
929
|
def insert(:{self.point._name}, variable, value): {self.points._name}(int128[{point_index}], variable, value)
|
|
931
930
|
"""
|
|
932
|
-
executor.execute_raw(textwrap.dedent(load_point_relation)
|
|
931
|
+
executor.execute_raw(textwrap.dedent(load_point_relation))
|
|
933
932
|
|
|
934
933
|
def summarize_result(self) -> Any:
|
|
935
934
|
"""Print solver result summary.
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
from typing import Any, Iterable, Sequence as PySequence, cast, Tuple, Union
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from decimal import Decimal as PyDecimal
|
|
@@ -843,7 +844,7 @@ class ModelToRel:
|
|
|
843
844
|
|
|
844
845
|
def _effect_name(self, n: ir.Task):
|
|
845
846
|
""" Return the name to be used for the effect (e.g. the relation name, output, etc). """
|
|
846
|
-
if
|
|
847
|
+
if helpers.is_export(n):
|
|
847
848
|
return "Export_Relation"
|
|
848
849
|
elif isinstance(n, ir.Output):
|
|
849
850
|
return "output"
|
|
@@ -355,8 +355,9 @@ class RelExecutor(e.Executor):
|
|
|
355
355
|
|
|
356
356
|
# NOTE(coey): this is added temporarily to support executing Rel for the solvers library in EA.
|
|
357
357
|
# It can be removed once this is no longer needed by the solvers library.
|
|
358
|
-
def execute_raw(self, raw_rel:str,
|
|
359
|
-
|
|
358
|
+
def execute_raw(self, raw_rel:str, query_timeout_mins:int|None=None) -> DataFrame:
|
|
359
|
+
# NOTE intentionally hard-coding to read-only=False, because read-only Rel queries are deprecated.
|
|
360
|
+
raw_results = self.resources.exec_raw(self.database, self.engine, raw_rel, False, nowait_durable=True, query_timeout_mins=query_timeout_mins)
|
|
360
361
|
df, errs = result_helpers.format_results(raw_results, None, generation=Generation.QB) # Pass None for task parameter
|
|
361
362
|
self.report_errors(errs)
|
|
362
363
|
return df
|
|
File without changes
|