relationalai 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. relationalai/clients/client.py +3 -4
  2. relationalai/clients/exec_txn_poller.py +62 -31
  3. relationalai/clients/resources/snowflake/direct_access_resources.py +6 -5
  4. relationalai/clients/resources/snowflake/snowflake.py +54 -51
  5. relationalai/clients/resources/snowflake/use_index_poller.py +1 -1
  6. relationalai/semantics/internal/snowflake.py +5 -1
  7. relationalai/semantics/lqp/algorithms.py +173 -0
  8. relationalai/semantics/lqp/builtins.py +199 -2
  9. relationalai/semantics/lqp/executor.py +90 -41
  10. relationalai/semantics/lqp/export_rewriter.py +40 -0
  11. relationalai/semantics/lqp/ir.py +28 -2
  12. relationalai/semantics/lqp/model2lqp.py +218 -45
  13. relationalai/semantics/lqp/passes.py +13 -658
  14. relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  15. relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  16. relationalai/semantics/lqp/rewrite/annotate_constraints.py +22 -10
  17. relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  18. relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  19. relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  20. relationalai/semantics/lqp/rewrite/functional_dependencies.py +31 -2
  21. relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  22. relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  23. relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  24. relationalai/semantics/lqp/utils.py +11 -1
  25. relationalai/semantics/lqp/validators.py +14 -1
  26. relationalai/semantics/metamodel/builtins.py +2 -1
  27. relationalai/semantics/metamodel/compiler.py +2 -1
  28. relationalai/semantics/metamodel/dependency.py +12 -3
  29. relationalai/semantics/metamodel/executor.py +11 -1
  30. relationalai/semantics/metamodel/factory.py +2 -2
  31. relationalai/semantics/metamodel/helpers.py +7 -0
  32. relationalai/semantics/metamodel/ir.py +3 -2
  33. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  34. relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  35. relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  36. relationalai/semantics/metamodel/typer/checker.py +6 -4
  37. relationalai/semantics/metamodel/typer/typer.py +2 -5
  38. relationalai/semantics/metamodel/visitor.py +4 -3
  39. relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  40. relationalai/semantics/reasoners/optimization/solvers_pb.py +3 -4
  41. relationalai/semantics/rel/compiler.py +2 -1
  42. relationalai/semantics/rel/executor.py +3 -2
  43. relationalai/semantics/tests/lqp/__init__.py +0 -0
  44. relationalai/semantics/tests/lqp/algorithms.py +345 -0
  45. relationalai/semantics/tests/test_snapshot_abstract.py +2 -1
  46. relationalai/tools/cli_controls.py +216 -67
  47. relationalai/util/format.py +5 -2
  48. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/METADATA +2 -2
  49. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/RECORD +52 -42
  50. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/WHEEL +0 -0
  51. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/entry_points.txt +0 -0
  52. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/licenses/LICENSE +0 -0
relationalai/semantics/lqp/rewrite/unify_definitions.py
@@ -0,0 +1,317 @@
+ from relationalai.semantics.metamodel.compiler import Pass
+ from relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
+ from relationalai.semantics.metamodel.typer import typer
+ from relationalai.semantics.metamodel import helpers
+ from relationalai.semantics.metamodel.util import FrozenOrderedSet, OrderedSet
+
+
+ from typing import cast, Union, Optional, Iterable
+ from collections import defaultdict
+
+ # LQP does not support multiple definitions for the same relation. This pass unifies all
+ # definitions for each relation into a single definition using a union.
+ class UnifyDefinitions(Pass):
+     def __init__(self):
+         super().__init__()
+
+     def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
+         # Maintain a cache of renamings for each relation. These need to be consistent
+         # across all definitions of the same relation.
+         self.renamed_relation_args: dict[Union[ir.Value, ir.Relation], list[ir.Var]] = {}
+
+         root = cast(ir.Logical, model.root)
+         new_tasks = self.get_combined_multidefs(root)
+         return ir.Model(
+             model.engines,
+             model.relations,
+             model.types,
+             f.logical(
+                 tuple(new_tasks),
+                 root.hoisted,
+                 root.engine,
+             ),
+             model.annotations,
+         )
+
+     def _get_heads(self, logical: ir.Logical) -> list[Union[ir.Update, ir.Output]]:
+         derives = []
+         for task in logical.body:
+             if isinstance(task, ir.Update) and task.effect == ir.Effect.derive:
+                 derives.append(task)
+             elif isinstance(task, ir.Output):
+                 derives.append(task)
+         return derives
+
+     def _get_non_heads(self, logical: ir.Logical) -> list[ir.Task]:
+         non_derives = []
+         for task in logical.body:
+             if not (isinstance(task, ir.Update) and task.effect == ir.Effect.derive) and not isinstance(task, ir.Output):
+                 non_derives.append(task)
+         return non_derives
+
+     def _get_head_identifier(self, head: Union[ir.Update, ir.Output]) -> Optional[ir.Value]:
+         if isinstance(head, ir.Update):
+             return head.relation
+         else:
+             assert isinstance(head, ir.Output)
+             if len(head.aliases) <= 2:
+                 # For processing here, we need output to have at least the column markers
+                 # `cols` and `col`, and also a key
+                 return None
+
+             output_alias_names = helpers.output_alias_names(head.aliases)
+             output_vals = helpers.output_values(head.aliases)
+
+             # For normal outputs, the pattern is output[keys](cols, "col000" as 'col', ...)
+             if output_alias_names[0] == "cols" and output_alias_names[1] == "col":
+                 return output_vals[1]
+
+             # For exports, the pattern is output[keys]("col000" as 'col', ...)
+             if helpers.is_export(head):
+                 if output_alias_names[0] == "col":
+                     return output_vals[0]
+
+             return None
+
+     def get_combined_multidefs(self, root: ir.Logical) -> list[ir.Logical]:
+         # Step 1: Group tasks by the relation they define.
+         relation_to_tasks: dict[Union[None, ir.Value, ir.Relation], list[ir.Logical]] = defaultdict(list)
+
+         for task in root.body:
+             task = cast(ir.Logical, task)
+             task_heads = self._get_heads(task)
+
+             # Some relations do not need to be grouped, e.g., if they don't contain a
+             # derive. Use `None` as a placeholder key for these cases.
+             if len(task_heads) != 1:
+                 relation_to_tasks[None].append(task)
+                 continue
+
+             head_id = self._get_head_identifier(task_heads[0])
+             relation_to_tasks[head_id].append(task)
+
+         # Step 2: For each relation, combine all of the body definitions into a union.
+         result_tasks = []
+         for relation, tasks in relation_to_tasks.items():
+             # If there's only one task for the relation, or if grouping is not needed, then
+             # just keep the original tasks.
+             if len(tasks) == 1 or relation is None:
+                 result_tasks.extend(tasks)
+                 continue
+
+             result_tasks.append(self._combine_tasks_into_union(tasks))
+         return result_tasks
+
+     def _get_variable_mapping(self, logical: ir.Logical) -> dict[ir.Value, ir.Var]:
+         heads = self._get_heads(logical)
+         assert len(heads) == 1, "should only have one head in a logical at this stage"
+         head = heads[0]
+
+         var_mapping = {}
+         head_id = self._get_head_identifier(head)
+
+         if isinstance(head, ir.Update):
+             args_for_renaming = head.args
+         else:
+             assert isinstance(head, ir.Output)
+             output_alias_names = helpers.output_alias_names(head.aliases)
+             if output_alias_names[0] == "cols" and output_alias_names[1] == "col":
+                 assert len(head.aliases) > 2
+
+                 # For outputs, we do not need to rename the `cols` and `col` markers or the
+                 # keys.
+                 output_values = helpers.output_values(head.aliases)[2:]
+
+             else:
+                 assert helpers.is_export(head) and output_alias_names[0] == "col"
+                 assert len(head.aliases) > 1
+
+                 # For exports, we do not need to rename the `col` marker or the keys.
+                 output_values = helpers.output_values(head.aliases)[1:]
+
+             args_for_renaming = []
+             for v in output_values:
+                 if head.keys and isinstance(v, ir.Var) and v in head.keys:
+                     continue
+                 args_for_renaming.append(v)
+
+         if head_id not in self.renamed_relation_args:
+             renamed_vars = []
+             for (i, arg) in enumerate(args_for_renaming):
+                 typ = typer.to_type(arg)
+                 assert arg not in var_mapping, "args of update should be unique"
+                 if isinstance(arg, ir.Var):
+                     var_mapping[arg] = ir.Var(typ, arg.name)
+                 else:
+                     var_mapping[arg] = ir.Var(typ, f"arg_{i}")
+
+                 renamed_vars.append(var_mapping[arg])
+             self.renamed_relation_args[head_id] = renamed_vars
+         else:
+             for (arg, var) in zip(args_for_renaming, self.renamed_relation_args[head_id]):
+                 var_mapping[arg] = var
+
+         return var_mapping
+
+     def _rename_variables(self, logical: ir.Logical) -> ir.Logical:
+         class RenameVisitor(visitor.Rewriter):
+             def __init__(self, var_mapping: dict[ir.Value, ir.Var]):
+                 super().__init__()
+                 self.var_mapping = var_mapping
+
+             def _get_mapped_value(self, val: ir.Value) -> ir.Value:
+                 if isinstance(val, tuple):
+                     return tuple(self._get_mapped_value(t) for t in val)
+                 return self.var_mapping.get(val, val)
+
+             def _get_mapped_values(self, vals: Iterable[ir.Value]) -> list[ir.Value]:
+                 return [self._get_mapped_value(v) for v in vals]
+
+             def handle_var(self, node: ir.Var, parent: ir.Node) -> ir.Var:
+                 return self.var_mapping.get(node, node)
+
+             # TODO: ideally, extend the rewriter class to allow rewriting PyValue to Var so
+             # we don't need to separately handle all cases containing them.
+             def handle_update(self, node: ir.Update, parent: ir.Node) -> ir.Update:
+                 return ir.Update(
+                     node.engine,
+                     node.relation,
+                     tuple(self._get_mapped_values(node.args)),
+                     node.effect,
+                     node.annotations,
+                 )
+
+             def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
+                 return ir.Lookup(
+                     node.engine,
+                     node.relation,
+                     tuple(self._get_mapped_values(node.args)),
+                     node.annotations,
+                 )
+
+             def handle_output(self, node: ir.Output, parent: ir.Node) -> ir.Output:
+                 new_aliases = FrozenOrderedSet(
+                     [(name, self._get_mapped_value(value)) for name, value in node.aliases]
+                 )
+                 if node.keys:
+                     new_keys = FrozenOrderedSet(
+                         [self.var_mapping.get(key, key) for key in node.keys]
+                     )
+                 else:
+                     new_keys = node.keys
+
+                 return ir.Output(
+                     node.engine,
+                     new_aliases,
+                     new_keys,
+                     node.annotations,
+                 )
+
+             def handle_construct(self, node: ir.Construct, parent: ir.Node) -> ir.Construct:
+                 new_values = tuple(self._get_mapped_values(node.values))
+                 new_id_var = self.var_mapping.get(node.id_var, node.id_var)
+                 return ir.Construct(
+                     node.engine,
+                     new_values,
+                     new_id_var,
+                     node.annotations,
+                 )
+
+             def handle_aggregate(self, node: ir.Aggregate, parent: ir.Node) -> ir.Aggregate:
+                 new_projection = tuple(self.var_mapping.get(arg, arg) for arg in node.projection)
+                 new_group = tuple(self.var_mapping.get(arg, arg) for arg in node.group)
+                 new_args = tuple(self._get_mapped_values(node.args))
+                 return ir.Aggregate(
+                     node.engine,
+                     node.aggregation,
+                     new_projection,
+                     new_group,
+                     new_args,
+                     node.annotations,
+                 )
+
+             def handle_rank(self, node: ir.Rank, parent: ir.Node) -> ir.Rank:
+                 new_projection = tuple(self.var_mapping.get(arg, arg) for arg in node.projection)
+                 new_group = tuple(self.var_mapping.get(arg, arg) for arg in node.group)
+                 new_args = tuple(self.var_mapping.get(arg, arg) for arg in node.args)
+                 new_result = self.var_mapping.get(node.result, node.result)
+
+                 return ir.Rank(
+                     node.engine,
+                     new_projection,
+                     new_group,
+                     new_args,
+                     node.arg_is_ascending,
+                     new_result,
+                     node.limit,
+                     node.annotations,
+                 )
+
+         var_mapping = self._get_variable_mapping(logical)
+
+         renamer = RenameVisitor(var_mapping)
+         result = renamer.walk(logical)
+
+         # Also need to append the equality for each renamed constant. E.g., if the mapping
+         # contains (50.0::FLOAT -> arg_2::FLOAT), we need to add
+         # `eq(arg_2::FLOAT, 50.0::FLOAT)` to the result.
+         value_eqs = []
+         for (old_var, new_var) in var_mapping.items():
+             if not isinstance(old_var, ir.Var):
+                 value_eqs.append(f.lookup(rel_builtins.eq, [new_var, old_var]))
+
+         return ir.Logical(
+             result.engine,
+             result.hoisted,
+             tuple(value_eqs) + tuple(result.body),
+             result.annotations,
+         )
+
+     # This function is the main workhorse for this rewrite pass. It takes a list of tasks
+     # that define the same relation, and combines them into a single task that defines
+     # the relation using a union of all of the bodies.
+     def _combine_tasks_into_union(self, tasks: list[ir.Logical]) -> ir.Logical:
+         # Step 1: Rename the variables in all tasks so that they will match the final derive
+         # after reconstructing into a union
+         renamed_tasks = [self._rename_variables(task) for task in tasks]
+
+         # Step 2: Get the final derive
+         derives = self._get_heads(renamed_tasks[0])
+         assert len(derives) == 1, "should only have one derive in a logical at this stage"
+         # Also make sure that all the derives are the same. This should be the case because
+         # we renamed all the variables to be the same in step 1.
+         for task in renamed_tasks[1:]:
+             assert self._get_heads(task) == derives, "all derives should be the same"
+
+         derive = derives[0]
+
+         # Step 3: Remove the final `derive` from each task
+         renamed_task_bodies = [
+             f.logical(
+                 tuple(self._get_non_heads(t)),  # Only keep non-head tasks
+                 t.hoisted,
+                 t.engine,
+             )
+             for t in renamed_tasks
+         ]
+
+         # Deduplicate bodies
+         renamed_task_bodies = OrderedSet.from_iterable(renamed_task_bodies).get_list()
+
+         # Step 4: Construct a union of all the task bodies
+         if len(renamed_task_bodies) == 1:
+             # If there's only one body after deduplication, no need to create a union
+             new_body = renamed_task_bodies[0]
+         else:
+             new_body = f.union(
+                 tuple(renamed_task_bodies),
+                 [],
+                 renamed_tasks[0].engine,
+             )
+
+         # Step 5: Add the final derive back
+         return f.logical(
+             (new_body, derive),
+             [],
+             renamed_tasks[0].engine,
+         )
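
The grouping in get_combined_multidefs is easier to see on a toy example. A minimal standalone sketch of the same idea, using plain strings instead of IR nodes (the adult/minor rules are invented for illustration):

# Standalone sketch (not the metamodel API) of what UnifyDefinitions does:
# group rules by head, then merge multi-definition heads into a single
# union-bodied rule, since LQP allows one definition per relation.
from collections import defaultdict

rules = [
    ("adult", "person(x), age(x, a), a >= 18"),
    ("adult", "person(x), employed(x)"),
    ("minor", "person(x), age(x, a), a < 18"),
]

grouped: dict[str, list[str]] = defaultdict(list)
for head, body in rules:                      # Step 1: group by the relation defined
    grouped[head].append(body)

unified = {                                   # Step 2: merge each group into a union
    head: bodies[0] if len(bodies) == 1 else "union(" + " ; ".join(bodies) + ")"
    for head, bodies in grouped.items()
}

for head, body in unified.items():
    print(f"{head} :- {body}")
# adult :- union(person(x), age(x, a), a >= 18 ; person(x), employed(x))
# minor :- person(x), age(x, a), a < 18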
relationalai/semantics/lqp/utils.py
@@ -3,6 +3,7 @@ from relationalai.semantics.metamodel import ir
  from relationalai.semantics.metamodel.helpers import sanitize
  from relationalai.semantics.metamodel.util import FrozenOrderedSet
 
+ from dataclasses import dataclass
  from hashlib import sha256
  from typing import Tuple
 
@@ -43,16 +44,25 @@ class UniqueNames:
          self.id_to_name[id] = name
          return name
 
+ @dataclass(frozen=True)
+ class ExportDescriptor:
+     relation_id: lqp.RelationId
+     column_name: str
+     column_number: int
+     column_type: lqp.Type
+
  class TranslationCtx:
      def __init__(self, def_names: UniqueNames = UniqueNames()):
          # TODO: comment these fields
          self.def_names = def_names
          self.var_names = UniqueNames()
          self.output_names = UniqueNames()
+         # A counter for break rules generated during translation of while loops
+         self.break_rule_counter = 0
          # Map relation IDs to their original names for debugging and pretty printing.
          self.rel_id_to_orig_name = {}
          self.output_ids: list[tuple[lqp.RelationId, str]] = []
-         self.export_ids: list[tuple[lqp.RelationId, int, lqp.Type]] = []
+         self.export_descriptors: list[ExportDescriptor] = []
 
  def gen_rel_id(ctx: TranslationCtx, orig_name: str, suffix: str = "") -> lqp.RelationId:
      relation_id = lqp.RelationId(id=lqp_hash(orig_name + suffix), meta=None)
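
Replacing the (RelationId, int, Type) tuples in export_ids with a frozen dataclass gives each column property a name. A standalone sketch of the pattern, with int/str standing in for the internal lqp.RelationId and lqp.Type:

from dataclasses import dataclass

# Stand-ins: `int` for lqp.RelationId, `str` for lqp.Type.
@dataclass(frozen=True)
class ExportDescriptor:
    relation_id: int
    column_name: str
    column_number: int
    column_type: str

# Before: ctx.export_ids.append((rel_id, 2, col_type)) left readers guessing what
# each position meant; now the fields are named and the record is immutable.
desc = ExportDescriptor(relation_id=42, column_name="price", column_number=2, column_type="FLOAT")
print(desc.column_name, desc.column_number)  # price 2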
relationalai/semantics/lqp/validators.py
@@ -6,6 +6,11 @@ CompilableType = Union[
      ir.Logical,
      ir.Union,
 
+     # Loops
+     ir.Loop,
+     ir.Sequence,
+     ir.Break,
+
      # Formulas
      ir.Lookup,
      ir.Exists,
@@ -36,7 +41,7 @@ def assert_valid_input(model: ir.Model) -> None:
 
  def _assert_valid_subtask(task: ir.Task) -> None:
      # TODO: assert what subtasks should look like
-     assert isinstance(task, ir.Logical), f"expected logical task, got {type(task)}"
+     assert isinstance(task, (ir.Logical, ir.Sequence)), f"expected logical task, got {type(task)}"
      _assert_task_compilable(task)
 
  def _assert_task_compilable(task: ir.Task) -> None:
@@ -51,6 +56,14 @@ def _assert_task_compilable(task: ir.Task) -> None:
          assert_valid_update(task)
          effect = task.effect
          assert effect == ir.Effect.derive, "only derive supported at the moment"
+     elif isinstance(task, ir.Sequence):
+         assert any(anno.relation.name == "script" for anno in task.annotations), "only @script sequences supported at the moment"
+         for subtask in task.tasks:
+             _assert_task_compilable(subtask)
+     elif isinstance(task, ir.Loop):
+         assert isinstance(task.body, ir.Sequence), f"expected loop body to be a sequence, got {type(task.body)}"
+         for subtask in task.body.tasks:
+             _assert_task_compilable(subtask)
 
  def assert_valid_update(update: ir.Update) -> None:
      effect = update.effect
relationalai/semantics/metamodel/builtins.py
@@ -474,7 +474,8 @@ external = f.relation("external", [])
  external_annotation = f.annotation(external, [])
 
  # indicates an output is meant to be exported
- export = f.relation("export", [])
+ export = f.relation("export", [f.input_field("fqn", types.String)])
+ # convenience for when there are no arguments (this is deprecated as fqn should always be used)
  export_annotation = f.annotation(export, [])
 
  # indicates this relation is a concept population
relationalai/semantics/metamodel/compiler.py
@@ -21,7 +21,8 @@ class Compiler():
          for p in self.passes:
              with debugging.span(p.name) as span:
                  model = p.rewrite(model, options)
-                 span["metamodel"] = str(model.root)
+                 if debugging.DEBUG:
+                     span["metamodel"] = str(model.root)
              p.reset()
          return model
 
relationalai/semantics/metamodel/dependency.py
@@ -31,6 +31,8 @@ class DependencyInfo():
      parent: dict[int, ir.Task] = field(default_factory=dict)
      # keep track of replacements that were made during a rewrite
      replacements: dict[int, ir.Task] = field(default_factory=dict)
+     # keep track of which logicals are effectful
+     effectful: set[int] = field(default_factory=set)
 
      def task_inputs(self, node: ir.Task) -> Optional[OrderedSet[ir.Var]]:
          """ The input variables for this task, if any. """
@@ -165,7 +167,7 @@ class Cluster():
          # this is a binders cluster, which is a candidate to being merged
          self.mergeable = not self.required and isinstance(task, helpers.BINDERS)
          # this is a cluster that will only hold an effect
-         self.effectful = isinstance(task, helpers.EFFECTS)
+         self.effectful = isinstance(task, helpers.EFFECTS) or task.id in info.effectful
          # this is a cluster that will only hold a composite
          self.composite = isinstance(task, helpers.COMPOSITES)
          # content is either a single task or a set of tasks
@@ -374,7 +376,6 @@ class DependencyAnalysis(visitor.Visitor):
      def __init__(self, info: DependencyInfo):
          self.info = info
 
-
      def enter(self, node: ir.Node, parent: Optional[ir.Node]=None):
          # keep track of parents of all nodes
          if parent and isinstance(parent, ir.Task):
@@ -456,7 +457,7 @@ class DependencyAnalysis(visitor.Visitor):
          # if c1 has an effect and c2 is a composite without hoisted variables or with a
          # hoisted variable that does not have a default (it is a plain var), then c2
          # behaves like a filter and c1 depends on it.
-         if c1.effectful and c2.composite:
+         if c1.effectful and c2.composite and not c2.effectful:
              task = c2.content.some()
              assert(isinstance(task, helpers.COMPOSITES))
              if not task.hoisted:
@@ -608,6 +609,10 @@ class BindingAnalysis(visitor.Visitor):
          else:
              map[key.id].add(val)
 
+     def leave(self, node: ir.Node, parent: Optional[ir.Node]=None):
+         if parent and node.id in self.info.effectful:
+             self.info.effectful.add(parent.id)
+         return super().leave(node, parent)
 
      #
      # Composite tasks
@@ -768,6 +773,8 @@ class BindingAnalysis(visitor.Visitor):
 
 
      def visit_update(self, node: ir.Update, parent: Optional[ir.Node]):
+         assert parent is not None
+         self.info.effectful.add(parent.id)
          # register variables being used as arguments to the update, it's always considered an input
          for v in helpers.vars(node.args):
              self.input(node, v)
@@ -816,6 +823,8 @@
 
 
      def visit_output(self, node: ir.Output, parent: Optional[ir.Node]):
+         assert parent is not None
+         self.info.effectful.add(parent.id)
          # register variables being output, they always considered an input to the task
          for v in helpers.output_vars(node.aliases):
              self.input(node, v)
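
The new effectful set is populated bottom-up: visit_update/visit_output mark their parent, and leave propagates the marker to each ancestor as the traversal unwinds. A self-contained sketch of that propagation on a toy tree (dicts instead of IR nodes):

# Toy tree: an "update" is the only effect here.
effectful: set[int] = set()

tree = {"id": 1, "kind": "logical", "children": [
    {"id": 2, "kind": "logical", "children": [{"id": 3, "kind": "update", "children": []}]},
    {"id": 4, "kind": "logical", "children": [{"id": 5, "kind": "lookup", "children": []}]},
]}

def visit(node, parent=None):
    if node["kind"] in ("update", "output") and parent is not None:
        effectful.add(parent["id"])           # visit_update/visit_output mark the parent
    for child in node["children"]:
        visit(child, node)
    if parent is not None and node["id"] in effectful:
        effectful.add(parent["id"])           # leave() bubbles the marker upward

visit(tree)
print(sorted(effectful))  # [1, 2] -- only the branch containing the update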
relationalai/semantics/metamodel/executor.py
@@ -1,15 +1,21 @@
  from __future__ import annotations
 
  from pandas import DataFrame
- from typing import Any, Union, Tuple, Literal
+ from typing import Any, Union, Tuple, Literal, TYPE_CHECKING
 
  from relationalai.clients.config import Config
  from relationalai.semantics.metamodel import Model, Task, ir
  from relationalai.semantics.metamodel.visitor import collect_by_type
+ if TYPE_CHECKING:
+     from relationalai.semantics.internal.internal import Model as InternalModel
+
  from .util import NameCache
 
  import rich
 
+ # global flag to suppress type errors from being printed
+ SUPPRESS_TYPE_ERRORS = False
+
  class Executor():
      """ Interface for an object that can execute the program specified by a model. """
      def execute(self, model: Model, task:Task, format:Literal["pandas", "snowpark"]="pandas") -> Union[DataFrame, Any]:
@@ -59,3 +65,7 @@ class Executor():
          if col in df.columns:
              df = df.drop(col, axis=1)
          return df
+
+     def export_to_csv(self, model: "InternalModel", query) -> str:
+         ### Only implemented in the LQP executor for now.
+         raise NotImplementedError(f"export_to_csv is not supported by {type(self).__name__}")
relationalai/semantics/metamodel/factory.py
@@ -277,8 +277,8 @@ def for_all(vars: PySequence[ir.Var], task: ir.Task, engine: Optional[ir.Engine]
  #
 
  # loops body until a break condition is met
- def loop(iter: ir.Var, body: ir.Task, hoisted: PySequence[ir.VarOrDefault]=[], engine: Optional[ir.Engine]=None, annos: PySequence[ir.Annotation]=[]):
-     return ir.Loop(engine, tuple(hoisted), iter, body, FrozenOrderedSet(annos))
+ def loop(body: ir.Task, iter: PySequence[ir.Var]=[], hoisted: PySequence[ir.VarOrDefault]=[], concurrency:int=1, engine: Optional[ir.Engine]=None, annos: PySequence[ir.Annotation]=[]):
+     return ir.Loop(engine, tuple(hoisted), tuple(iter), body, concurrency, FrozenOrderedSet(annos))
 
  def break_(check: ir.Task, engine: Optional[ir.Engine]=None, annos: PySequence[ir.Annotation]=[]):
      return Break(check, engine, annos)
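
Note the breaking change to f.loop: iter moves from a single leading ir.Var to an optional sequence, and a concurrency knob is added. A toy mirror of the new shape (not the real ir.Loop), showing how call sites change:

from dataclasses import dataclass
from typing import Sequence, Tuple

@dataclass(frozen=True)
class Loop:                    # toy stand-in, not the real ir.Loop
    iter: Tuple[str, ...]      # was: iter: str (exactly one variable)
    body: str
    concurrency: int = 1       # new knob; default 1 keeps the old sequential behavior

def loop(body: str, iter: Sequence[str] = (), concurrency: int = 1) -> Loop:
    return Loop(tuple(iter), body, concurrency)

# Old call sites: loop(i_var, body). New call sites put the body first and may
# iterate several variables concurrently:
print(loop("step; break when done", iter=["i", "j"], concurrency=4))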
relationalai/semantics/metamodel/helpers.py
@@ -24,6 +24,13 @@ def sanitize(name:str) -> str:
  # Checks
  #--------------------------------------------------
 
+ def is_export(node: ir.Node):
+     """ Whether this node is an export output. """
+     return isinstance(node, ir.Output) and (
+         builtins.export_annotation in node.annotations or
+         any(annotation.relation == builtins.export for annotation in node.annotations)
+     )
+
  def is_concept_lookup(node: ir.Lookup|ir.Relation):
      """ Whether this task is a concept lookup. """
      if isinstance(node, ir.Lookup) and is_concept_lookup(node.relation):
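
is_export accepts both the legacy zero-argument export annotation and any annotation over the export relation, which per the builtins change above now carries an fqn field. A standalone sketch with stand-in classes:

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Relation:
    name: str

@dataclass(frozen=True)
class Annotation:
    relation: Relation
    args: Tuple[str, ...] = ()

export = Relation("export")
export_annotation = Annotation(export)  # deprecated zero-argument form

def is_export(annotations) -> bool:
    # Either the legacy annotation object itself, or any annotation over `export`
    # (e.g. one carrying an fqn argument).
    return export_annotation in annotations or any(a.relation == export for a in annotations)

print(is_export((Annotation(export, ("db.schema.table",)),)))  # True
print(is_export((Annotation(Relation("external")),)))          # False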
relationalai/semantics/metamodel/ir.py
@@ -427,8 +427,9 @@ class Loop(Task):
      """Execute the body in a loop, incrementing the iter variable, until a break sub-task in
      the body succeeds."""
      hoisted: Tuple[VarOrDefault, ...]
-     iter: Var
+     iter: Tuple[Var, ...]
      body: Task
+     concurrency: int = 1
      annotations:FrozenOrderedSet[Annotation] = annotations_field()
 
      @acceptor
@@ -856,7 +857,7 @@ class Printer(BasePrinter):
 
          # Iteration (Loops)
          elif isinstance(node, Loop):
-             self.print_hoisted(depth, f"Loop ⇓[{self.value_to_string(node.iter)}]{annos_str}", node.hoisted)
+             self.print_hoisted(depth, f"Loop ⇓[{', '.join([self.value_to_string(v) for v in node.iter])}] concurrency={node.concurrency} {annos_str}", node.hoisted)
              self.pprint(node.body, depth + 1, print_ids=print_ids)
 
          elif isinstance(node, Break):
relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py
@@ -3,7 +3,7 @@ from __future__ import annotations
  from relationalai.semantics.metamodel import ir
  from relationalai.semantics.metamodel.compiler import Pass
  from relationalai.semantics.metamodel.visitor import Visitor, Rewriter
- from relationalai.semantics.metamodel.util import OrderedSet
+ from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
  from relationalai.semantics.metamodel import helpers, factory as f
  from typing import Optional, Any
 
@@ -106,6 +106,7 @@ class DNFExtractor(Visitor):
          # The logical that contains the output.
          # The assumption for the IR at this point is that there is only one output.
          self.output_logical: Optional[ir.Logical] = None
+         self.output_keys: OrderedSet[ir.Var] = ordered_set()
          self.active_negations: list[ir.Not] = []
          # Nodes that have to split into multiple similar nodes, depending on the changes
          # of sub-nodes.
@@ -120,6 +121,8 @@
          if any(isinstance(x, ir.Output) for x in node.body):
              assert not self.output_logical, "multiple outputs"
              self.output_logical = node
+             output_node = next(x for x in node.body if isinstance(x, ir.Output))
+             self.output_keys = helpers.collect_vars(output_node)
 
          elif isinstance(node, ir.Not):
              self.active_negations.append(node)
@@ -168,29 +171,36 @@
              self.output_logical and
              len(self.active_negations) % 2 == 0 and
              len(node.tasks) > 1):
-             # We split the union when there is a branch with vars "X,Y" and another with "X,Z"
-             # If some branches have vars "X, Y, Z" and others have "X, Y" or "Y, Z" we don't split
+             # We split the union when there are vars not present in all branches that are
+             # present in the output keys. If vars are not in output keys then they act as
+             # filters only and do not require splitting.
              should_split = False
-             all_vars = helpers.collect_vars(node.tasks[0])
-             for t in node.tasks[1:]:
-                 vars = helpers.collect_vars(t)
-                 curr_intersection = vars.get_set().intersection(all_vars.get_set())
-                 should_split |= not (curr_intersection == vars.get_set() or curr_intersection == all_vars.get_set())
-                 if should_split:
-                     replacements:list[ir.Task] = []
-                     for t in node.tasks:
-                         # If some branch should already be replaced, we flatten all the replacements here.
-                         if t in self.replaced_by:
-                             replacements.extend(self.replaced_by[t])
-                         else:
-                             replacements.append(t)
-                     self.replaced_by[node] = replacements
-                     self.should_split.add(parent)
-                     break
-                 all_vars.update(vars)
+             all_vars = helpers.collect_vars(node).get_set()
+             all_vars &= self.output_keys.get_set()
+
+             if all_vars:
+                 for t in node.tasks:
+                     vars = helpers.collect_vars(t).get_set()
+                     curr_intersection = vars.intersection(all_vars)
+                     if curr_intersection != all_vars:
+                         should_split = True
+                         break
+
+             if should_split:
+                 replacements:list[ir.Task] = []
+                 for t in node.tasks:
+                     # If some branch should already be replaced, we flatten all
+                     # the replacements here.
+                     if t in self.replaced_by:
+                         replacements.extend(self.replaced_by[t])
+                     else:
+                         replacements.append(t)
+                 self.replaced_by[node] = replacements
+                 self.should_split.add(parent)
 
          if isinstance(node, ir.Logical) and node == self.output_logical:
              self.output_logical = None
+             self.output_keys = ordered_set()
 
          return node
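
The new criterion only splits a union when a branch is missing a variable that both occurs somewhere in the union and is an output key; variables outside the keys act as filters. A minimal standalone version of the check:

# Branches are modeled as sets of variable names; the guard that the union has
# more than one branch is elided for brevity.
def should_split(branches: list[set[str]], output_keys: set[str]) -> bool:
    all_vars = set().union(*branches) & output_keys   # union vars restricted to output keys
    return bool(all_vars) and any(branch & all_vars != all_vars for branch in branches)

print(should_split([{"x", "y"}, {"x", "z"}], {"x", "y", "z"}))  # True: y, z not in all branches
print(should_split([{"x", "y"}, {"x", "z"}], {"x"}))            # False: y, z are mere filters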