relationalai 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. relationalai/clients/client.py +3 -4
  2. relationalai/clients/exec_txn_poller.py +62 -31
  3. relationalai/clients/resources/snowflake/direct_access_resources.py +6 -5
  4. relationalai/clients/resources/snowflake/snowflake.py +54 -51
  5. relationalai/clients/resources/snowflake/use_index_poller.py +1 -1
  6. relationalai/semantics/internal/snowflake.py +5 -1
  7. relationalai/semantics/lqp/algorithms.py +173 -0
  8. relationalai/semantics/lqp/builtins.py +199 -2
  9. relationalai/semantics/lqp/executor.py +90 -41
  10. relationalai/semantics/lqp/export_rewriter.py +40 -0
  11. relationalai/semantics/lqp/ir.py +28 -2
  12. relationalai/semantics/lqp/model2lqp.py +218 -45
  13. relationalai/semantics/lqp/passes.py +13 -658
  14. relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  15. relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  16. relationalai/semantics/lqp/rewrite/annotate_constraints.py +22 -10
  17. relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  18. relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  19. relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  20. relationalai/semantics/lqp/rewrite/functional_dependencies.py +31 -2
  21. relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  22. relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  23. relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  24. relationalai/semantics/lqp/utils.py +11 -1
  25. relationalai/semantics/lqp/validators.py +14 -1
  26. relationalai/semantics/metamodel/builtins.py +2 -1
  27. relationalai/semantics/metamodel/compiler.py +2 -1
  28. relationalai/semantics/metamodel/dependency.py +12 -3
  29. relationalai/semantics/metamodel/executor.py +11 -1
  30. relationalai/semantics/metamodel/factory.py +2 -2
  31. relationalai/semantics/metamodel/helpers.py +7 -0
  32. relationalai/semantics/metamodel/ir.py +3 -2
  33. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  34. relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  35. relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  36. relationalai/semantics/metamodel/typer/checker.py +6 -4
  37. relationalai/semantics/metamodel/typer/typer.py +2 -5
  38. relationalai/semantics/metamodel/visitor.py +4 -3
  39. relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  40. relationalai/semantics/reasoners/optimization/solvers_pb.py +3 -4
  41. relationalai/semantics/rel/compiler.py +2 -1
  42. relationalai/semantics/rel/executor.py +3 -2
  43. relationalai/semantics/tests/lqp/__init__.py +0 -0
  44. relationalai/semantics/tests/lqp/algorithms.py +345 -0
  45. relationalai/semantics/tests/test_snapshot_abstract.py +2 -1
  46. relationalai/tools/cli_controls.py +216 -67
  47. relationalai/util/format.py +5 -2
  48. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/METADATA +2 -2
  49. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/RECORD +52 -42
  50. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/WHEEL +0 -0
  51. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/entry_points.txt +0 -0
  52. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,8 @@ class Flatten(Pass):
13
13
  """
14
14
  Traverses the model's root to flatten it as much as possible. The result of this pass is
15
15
  a Logical root where all nested tasks that represent a rule in Rel are extracted to the
16
- top level.
16
+ top level. Additionally, any Sequence is promoted to the top level Logical (but
17
+ encapsulated by a Logical).
17
18
 
18
19
  - nested logical with updates becomes a top-level logical (a rule)
19
20
 
@@ -122,6 +123,35 @@ class Flatten(Pass):
122
123
  Logical
123
124
  lookup tmp2
124
125
  output
126
+
127
+ - a Sequence is promoted to the top level Logical, encapsulated by a Logical:
128
+ From:
129
+ Logical
130
+ Logical
131
+ lookup
132
+ derive foo
133
+ Sequence
134
+ Logical
135
+ ...
136
+ Loop
137
+ Sequence
138
+ ...
139
+ Logical
140
+ ...
141
+ To:
142
+ Logical
143
+ Logical
144
+ lookup
145
+ derive foo
146
+ Logical
147
+ Sequence
148
+ Logical
149
+ ...
150
+ Loop
151
+ Sequence
152
+ ...
153
+ Logical
154
+ ...
125
155
  """
126
156
 
127
157
  def __init__(self, use_sql: bool=False):
@@ -181,11 +211,8 @@ class Flatten(Pass):
181
211
  def handle(self, task: ir.Task, ctx: Context) -> Flatten.HandleResult:
182
212
  if isinstance(task, ir.Logical):
183
213
  return self.handle_logical(task, ctx)
184
- elif isinstance(task, ir.Union) and (task.hoisted or self._use_sql):
185
- # Only flatten Unions which hoist variables. If there are no hoisted variables,
186
- # then the Union acts as a filter, and it can be inefficient to flatten it.
187
- #
188
- # However, for the SQL backend, we always need to flatten Unions for correct SQL
214
+ elif isinstance(task, ir.Union) and self._use_sql:
215
+ # The SQL backend needs to flatten Unions for correct SQL
189
216
  # generation.
190
217
  return self.handle_union(task, ctx)
191
218
  elif isinstance(task, ir.Match):
@@ -194,6 +221,8 @@ class Flatten(Pass):
194
221
  return self.handle_require(task, ctx)
195
222
  elif isinstance(task, ir.Not):
196
223
  return self.handle_not(task, ctx)
224
+ elif isinstance(task, ir.Sequence):
225
+ return self.handle_sequence(task, ctx)
197
226
  else:
198
227
  return Flatten.HandleResult(task)
199
228
 
@@ -253,9 +282,9 @@ class Flatten(Pass):
253
282
 
254
283
  for output in groups["outputs"]:
255
284
  assert(isinstance(output, ir.Output))
256
- new_body = info.task_dependencies(output)
257
- new_body.update(ctx.extra_tasks)
258
- new_body.add(output)
285
+ new_body = OrderedSet.from_iterable(t.clone() for t in info.task_dependencies(output))
286
+ new_body.update(t.clone() for t in ctx.extra_tasks)
287
+ new_body.add(output.clone())
259
288
  ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(new_body), task.annotations))
260
289
 
261
290
  return Flatten.HandleResult(None)
@@ -263,9 +292,9 @@ class Flatten(Pass):
263
292
  # if there are updates, extract as a new top level rule
264
293
  if groups["updates"]:
265
294
  # add task dependencies to the body
266
- body.prefix(ctx.info.task_dependencies(task))
295
+ body.prefix(t.clone() for t in ctx.info.task_dependencies(task))
267
296
  # potentially add context extra tasks
268
- body.update(ctx.extra_tasks)
297
+ body.update(t.clone() for t in ctx.extra_tasks)
269
298
  ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations))
270
299
  return Flatten.HandleResult(None)
271
300
 
@@ -278,7 +307,7 @@ class Flatten(Pass):
278
307
  agg = cast(ir.Aggregate, groups["aggregates"].some())
279
308
 
280
309
  # add agg dependencies to the body
281
- body.prefix(ctx.info.task_dependencies(agg))
310
+ body.prefix(t.clone() for t in ctx.info.task_dependencies(agg))
282
311
 
283
312
  # extract a new logical for the aggregate, exposing aggregate group-by and results
284
313
  exposed_vars = OrderedSet.from_iterable(list(agg.group) + helpers.aggregate_outputs(agg))
@@ -298,7 +327,7 @@ class Flatten(Pass):
298
327
  rank = cast(ir.Rank, groups["ranks"].some())
299
328
 
300
329
  # add rank dependencies to the body
301
- body.prefix(ctx.info.task_dependencies(rank))
330
+ body.prefix(t.clone() for t in ctx.info.task_dependencies(rank))
302
331
  # for rank, we sort by the args, but the result includes the keys to preserve bag semantics.
303
332
  exposed_vars_raw = list(rank.projection) + list(rank.group) + list(rank.args) +[rank.result]
304
333
  # deduplicate vars
@@ -487,6 +516,14 @@ class Flatten(Pass):
487
516
  task.annotations
488
517
  ))
489
518
 
519
+ def handle_sequence(self, task: ir.Sequence, ctx: Context):
520
+ new_logical = f.logical(
521
+ body = [task],
522
+ engine = task.engine
523
+ )
524
+ ctx.rewrite_ctx.top_level.append(new_logical)
525
+ return Flatten.HandleResult(None)
526
+
490
527
  #--------------------------------------------------
491
528
  # Helpers
492
529
  #--------------------------------------------------
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
  from typing import Tuple
3
3
 
4
- from relationalai.semantics.metamodel import builtins, ir, factory as f, types, visitor
4
+ from relationalai.semantics.metamodel import builtins, ir, factory as f, types, visitor, helpers
5
5
  from relationalai.semantics.metamodel.compiler import Pass, group_tasks
6
6
  from relationalai.semantics.metamodel.util import OrderedSet
7
7
  from relationalai.semantics.metamodel.util import FrozenOrderedSet
@@ -63,7 +63,7 @@ def adjust_outputs(task: ir.Logical, outputs: OrderedSet[ir.Task], wide_outputs:
63
63
  # Remove the original output. This is replaced by per-column outputs below
64
64
  body.remove(output)
65
65
 
66
- is_export = builtins.export_annotation in output.annotations
66
+ is_export = helpers.is_export(output)
67
67
 
68
68
  # Generate an output for each "column"
69
69
  # output looks like def output(:cols, :col000, key0, key1, value):
@@ -100,7 +100,13 @@ def _generate_output_column(output: ir.Output, idx: int, alias: tuple[str, ir.Va
100
100
  (not is_primitive(alias[1].type) or alias[1].type == types.Hash)):
101
101
 
102
102
  uuid = f.var(f"{alias[0]}_{idx}_uuid", types.String)
103
- aliases.append((uuid.name, uuid))
103
+
104
+ if not is_primitive(alias[1].type):
105
+ # For non-primitive types, we keep the original alias
106
+ aliases.append((alias[0], uuid))
107
+ else:
108
+ # For Hash types, we use the uuid name as alias
109
+ aliases.append((uuid.name, uuid))
104
110
 
105
111
  return [
106
112
  ir.Lookup(None, builtins.uuid_to_string, (alias[1], uuid)),
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
5
5
  from typing import Optional, List, Union as PyUnion, Tuple, cast
6
6
 
7
7
  from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
8
- from relationalai.semantics.metamodel import ir, types, visitor, compiler
8
+ from relationalai.semantics.metamodel import ir, types, visitor, compiler, executor
9
9
  import rich
10
10
 
11
11
 
@@ -39,7 +39,8 @@ class CheckEnv:
39
39
 
40
40
  def _complain(self, node: ir.Node, msg: str):
41
41
  """Report an error."""
42
- self.diags.append(CheckError(msg, node))
42
+ if not executor.SUPPRESS_TYPE_ERRORS:
43
+ self.diags.append(CheckError(msg, node))
43
44
 
44
45
 
45
46
  @dataclass
@@ -306,8 +307,9 @@ class CheckModel(visitor.DAGVisitor):
306
307
  for x in node.hoisted:
307
308
  if not CheckModel._variable_occurs_in(x, node.body):
308
309
  self.env._complain(node, f"Variable {ir.node_to_string(x).strip()} is hoisted but not used in the body of {ir.node_to_string(node).strip()}.")
309
- if not CheckModel._variable_occurs_in(node.iter, node.body):
310
- self.env._complain(node, f"Variable {node.iter} is the loop iterator but is not used in the body of {ir.node_to_string(node).strip()}.")
310
+ for iter_var in node.iter:
311
+ if not CheckModel._variable_occurs_in(iter_var, node.body):
312
+ self.env._complain(node, f"Variable {iter_var} is the loop iterator but is not used in the body of {ir.node_to_string(node).strip()}.")
311
313
  return super().visit_loop(node, parent)
312
314
 
313
315
  def visit_update(self, node: ir.Update, parent: Optional[ir.Node]=None):
@@ -6,7 +6,7 @@ import datetime
6
6
  from decimal import Decimal as PyDecimal
7
7
  from typing import Optional, Union, Tuple
8
8
  from relationalai import debugging
9
- from relationalai.semantics.metamodel import builtins, helpers, ir, types, visitor, compiler, factory as f
9
+ from relationalai.semantics.metamodel import builtins, helpers, ir, types, visitor, compiler, factory as f, executor
10
10
  from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
11
11
  import rich
12
12
  import sys
@@ -1361,9 +1361,6 @@ class Replacer(visitor.Rewriter):
1361
1361
  # Typer pass
1362
1362
  #--------------------------------------------------
1363
1363
 
1364
- # global flag to suppress type errors from being printed
1365
- SUPPRESS_TYPE_ERRORS = False
1366
-
1367
1364
  class InferTypes(compiler.Pass):
1368
1365
  def __init__(self):
1369
1366
  super().__init__()
@@ -1392,7 +1389,7 @@ class InferTypes(compiler.Pass):
1392
1389
  with debugging.span("type.replace"):
1393
1390
  final = Replacer(w.net).walk(model)
1394
1391
 
1395
- if not SUPPRESS_TYPE_ERRORS:
1392
+ if not executor.SUPPRESS_TYPE_ERRORS:
1396
1393
  for err in w.net.errors:
1397
1394
  rich.print(str(err), file=sys.stderr)
1398
1395
 
@@ -466,7 +466,8 @@ class Visitor(GenericVisitor[None]):
466
466
  self._walk_engine(node.engine, node)
467
467
  for h in node.hoisted:
468
468
  self._walk_var_or_default(h, node)
469
- self._walk_var(node.iter, node)
469
+ for iter in node.iter:
470
+ self._walk_var(iter, node)
470
471
  self._walk_node(node.body, node)
471
472
  for a in node.annotations:
472
473
  self._walk_node(a, node)
@@ -935,9 +936,9 @@ class Rewriter():
935
936
  #
936
937
  def handle_loop(self, node: ir.Loop, parent: ir.Node):
937
938
  hoisted = rewrite_list(ir.VarOrDefault, lambda n: self.walk(n, node), node.hoisted)
938
- iter_val = self.walk(node.iter, node)
939
+ iter = rewrite_list(ir.Var, lambda n: self.walk(n, node), node.iter)
939
940
  body = self.walk(node.body, node)
940
- return node.reconstruct(node.engine, hoisted, iter_val, body, node.annotations)
941
+ return node.reconstruct(node.engine, hoisted, iter, body, node.concurrency, node.annotations)
941
942
 
942
943
  def handle_break(self, node: ir.Break, parent: ir.Node):
943
944
  check = self.walk(node.check, node)
@@ -333,7 +333,7 @@ class SolverModelDev:
333
333
  executor.execute_raw(textwrap.dedent(f"""
334
334
  def delete[:{self.point._name}]: {self.point._name}
335
335
  def insert(:{self.point._name}, var, val): {self.points._name}(int128[{i}], var, val)
336
- """), readonly=False)
336
+ """))
337
337
  return None
338
338
 
339
339
  # print summary of the solver result
@@ -599,7 +599,7 @@ class SolverModelPB:
599
599
  }}]}}
600
600
  """)
601
601
 
602
- executor.execute_raw(export_rel, readonly=False, query_timeout_mins=query_timeout_mins)
602
+ executor.execute_raw(export_rel, query_timeout_mins=query_timeout_mins)
603
603
 
604
604
  def _import_solver_results_from_csv(
605
605
  self,
@@ -695,7 +695,7 @@ class SolverModelPB:
695
695
  }}
696
696
  """)
697
697
 
698
- executor.execute_raw(load_and_extract_rel, readonly=False, query_timeout_mins=query_timeout_mins)
698
+ executor.execute_raw(load_and_extract_rel, query_timeout_mins=query_timeout_mins)
699
699
 
700
700
  def _export_model_to_protobuf(
701
701
  self,
@@ -791,7 +791,6 @@ class SolverModelPB:
791
791
 
792
792
  executor.execute_raw(
793
793
  textwrap.dedent(extract_rel) + textwrap.dedent(insert_points_relation),
794
- readonly=False,
795
794
  query_timeout_mins=query_timeout_mins
796
795
  )
797
796
 
@@ -929,7 +928,7 @@ class SolverModelPB:
929
928
  def delete[:{self.point._name}]: {self.point._name}
930
929
  def insert(:{self.point._name}, variable, value): {self.points._name}(int128[{point_index}], variable, value)
931
930
  """
932
- executor.execute_raw(textwrap.dedent(load_point_relation), readonly=False)
931
+ executor.execute_raw(textwrap.dedent(load_point_relation))
933
932
 
934
933
  def summarize_result(self) -> Any:
935
934
  """Print solver result summary.
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  from typing import Any, Iterable, Sequence as PySequence, cast, Tuple, Union
3
4
  from dataclasses import dataclass, field
4
5
  from decimal import Decimal as PyDecimal
@@ -843,7 +844,7 @@ class ModelToRel:
843
844
 
844
845
  def _effect_name(self, n: ir.Task):
845
846
  """ Return the name to be used for the effect (e.g. the relation name, output, etc). """
846
- if isinstance(n, ir.Output) and bt.export_annotation in n.annotations:
847
+ if helpers.is_export(n):
847
848
  return "Export_Relation"
848
849
  elif isinstance(n, ir.Output):
849
850
  return "output"
@@ -355,8 +355,9 @@ class RelExecutor(e.Executor):
355
355
 
356
356
  # NOTE(coey): this is added temporarily to support executing Rel for the solvers library in EA.
357
357
  # It can be removed once this is no longer needed by the solvers library.
358
- def execute_raw(self, raw_rel:str, readonly:bool=True, query_timeout_mins:int|None=None) -> DataFrame:
359
- raw_results = self.resources.exec_raw(self.database, self.engine, raw_rel, readonly, nowait_durable=True, query_timeout_mins=query_timeout_mins)
358
+ def execute_raw(self, raw_rel:str, query_timeout_mins:int|None=None) -> DataFrame:
359
+ # NOTE intentionally hard-coding to read-only=False, because read-only Rel queries are deprecated.
360
+ raw_results = self.resources.exec_raw(self.database, self.engine, raw_rel, False, nowait_durable=True, query_timeout_mins=query_timeout_mins)
360
361
  df, errs = result_helpers.format_results(raw_results, None, generation=Generation.QB) # Pass None for task parameter
361
362
  self.report_errors(errs)
362
363
  return df
File without changes
@@ -0,0 +1,345 @@
1
+ """
2
+ Constructing Metamodel IR with Algorithms
3
+
4
+ We introduce a set of programmatic constructs that provide a convenient syntax for
5
+ constructing PyRel's metamodel IR representations for Loopy algorithms. Importantly, these
6
+ macros construct a new model using PyRel declarations constructed with a _base model_. The
7
+ base model also needs to be used to declare all concepts and relationships.
8
+
9
+ Below we illustrate the use of these macros by constructing a simple reachability
10
+ algorithm, whose Rel-like pseudo-code is as follows:
11
+
12
+ ```
13
+ algorithm
14
+ setup
15
+ def edge = { (1,2); (2,3); (3,4); (4,1) }
16
+ def source = { 1 }
17
+ end setup
18
+ @global empty reachable = {}
19
+ loop
20
+ def frontier = source
21
+ def reachable = frontier
22
+ while (true)
23
+ def next_frontier = frontier . edge
24
+ def frontier = next_frontier
25
+ monus frontier = reachable # frontier = frontier - reachable
26
+ upsert reachable = frontier # reachable = reachable ∪ frontier
27
+ break break_reachable = empty(frontier)
28
+ end while
29
+ end loop
30
+ end algorithm
31
+ ```
32
+
33
+ The PyRel's metamodel IR for the above algorithm is constructed with the utilities as
34
+ follows.
35
+
36
+ ```
37
+ base_model = Model("algorithm_builder", dry_run=True)
38
+
39
+ # Input (context) data
40
+
41
+ edge = base_model.Relationship("Edge from {source:int} to {target:int}")
42
+ source = base_model.Relationship("Source node {node:int}")
43
+
44
+ with algorithm(base_model):
45
+ setup(
46
+ define(edge(1,2), edge(2,3), edge(3,4), edge(4,1)),
47
+ define(source(1))
48
+ )
49
+
50
+ # "local" variables and relations
51
+ n = Integer.ref()
52
+ m = Integer.ref()
53
+ reachable = base_model.Relationship("Reachable node {node:int}")
54
+ frontier = base_model.Relationship("Frontier node {node:int}")
55
+ next_frontier = base_model.Relationship("Next frontier node {node:int}")
56
+
57
+ global_(empty(define(reachable(n))))
58
+ assign(define(frontier(n)).where(source(n)))
59
+ assign(define(reachable(n)).where(frontier(n)))
60
+ with while_():
61
+ assign(define(next_frontier(m)).where(frontier(n), edge(n, m)))
62
+ assign(define(frontier(m)).where(next_frontier(m)))
63
+ monus(define(frontier(n)).where(reachable(n)))
64
+ upsert(0)(define(reachable(n)).where(frontier(n)))
65
+ break_(where(not_(frontier(n))))
66
+
67
+ # Prints the PyRel Metamodel (IR)
68
+ print(get_metamodel())
69
+
70
+ # Prints the LQP transaction
71
+ print(get_lqp_str())
72
+ ```
73
+ """
74
+ from relationalai.semantics import Model
75
+ from relationalai.semantics.metamodel import factory, ir, types
76
+ from relationalai.semantics.internal.internal import Fragment
77
+ from relationalai.semantics.lqp.algorithms import (
78
+ mk_empty, mk_assign, mk_upsert, mk_global, mk_monus
79
+ )
80
+ from relationalai.semantics.lqp.constructors import mk_transaction
81
+ from relationalai.semantics.lqp.compiler import Compiler
82
+ from relationalai.semantics.lqp import ir as lqp, builtins
83
+ from typing import cast, TypeGuard, Optional, Sequence
84
+ from lqp import print as lqp_print
85
+ import threading
86
+ from contextlib import contextmanager
87
+
88
+
89
+ # While the constructors are very light-weight they enforce
90
+ # the following grammar for algorithms:
91
+ #
92
+ # <Algorithm> := with algorithm(base_model): <Script>
93
+ # <Script> := <Instruction>*
94
+ # <Instruction> := <BaseInstruction> | <Loop>
95
+ # <BaseInstruction> := [global_(] empty(Fragment) [)]
96
+ # | [global_(] assign(<Fragment>) [)]
97
+ # | break(<Fragment>)
98
+ # | upsert(<Int>)(<Fragment>)
99
+ # | monus(<Fragment>)
100
+ # <Loop> := with while_(): <Script>
101
+ #
102
+ # Note: the global_ annotation can only be applied to empty and assign instructions
103
+ # at the top level of the algorithm script.
104
+
105
+ _storage = threading.local()
106
+
107
+ def get_builder() -> 'AlgorithmBuilder':
108
+ """ Retrieves the thread-local AlgorithmBuilder instance."""
109
+ global _storage
110
+ if not(hasattr(_storage, "algorithm_builder")):
111
+ _storage.algorithm_builder = AlgorithmBuilder()
112
+ return _storage.algorithm_builder
113
+
114
+ def get_metamodel() -> ir.Model:
115
+ """ Retrieves the compiled metamodel IR for the previous algorithm. Can only be used
116
+ after an algorithm has been defined."""
117
+ return get_builder().get_metamodel()
118
+
119
+ def get_lqp_str() -> str:
120
+ """ Retrieves the LQP string representation for the previous algorithm. Can only be used
121
+ after an algorithm has been defined."""
122
+ return get_builder().get_lqp_str()
123
+
124
+ @contextmanager
125
+ def algorithm(model:Model):
126
+ """ Context manager for defining an algorithm on the given base model."""
127
+ get_builder().begin_algorithm(model)
128
+ yield
129
+ get_builder().end_algorithm()
130
+
131
+ @contextmanager
132
+ def while_():
133
+ """ Context manager for defining a while loop within an algorithm."""
134
+ get_builder().begin_while_loop()
135
+ yield
136
+ get_builder().end_while_loop()
137
+
138
+ def setup(*stmts:Fragment):
139
+ """ Defines the setup section of an algorithm: a collection of PyRel statement that
140
+ prepare input data for the algorithm."""
141
+ builder = get_builder()
142
+ assert len(builder.script_stacks) == 1, "setup can only be called at the top-level of an algorithm"
143
+ assert builder.setup_fragments is None, "setup can only be called once per algorithm"
144
+ builder.set_setup_fragments(stmts)
145
+
146
+ def global_(pos:int):
147
+ """ Marks a top-level `empty` or `assign` instruction as defining a global relation."""
148
+ assert type(pos) is int, "global_ can only be applied to empty and assign"
149
+ builder = get_builder()
150
+ assert len(builder.script_stacks) == 1, "global_ can only be applied to top-level instructions"
151
+ assert len(builder.script_stacks[0].instructions) == pos + 1
152
+ task = cast(ir.Task, mk_global(builder.script_stacks[0].instructions[pos]))
153
+ builder.script_stacks[0].instructions[pos] = task
154
+ builder.add_global_relation(task)
155
+
156
+ def empty(stmt) -> int:
157
+ """ Marks a PyRel statement as an assignment of empty relation. The statement must not
158
+ have a body (no where clause)."""
159
+ assert has_empty_body(stmt), "Empty instruction must have an empty body"
160
+ task = get_builder().compile_statement(stmt)
161
+ task = cast(ir.Task, mk_empty(task))
162
+ return get_builder().append_task(task)
163
+
164
+ def assign(stmt) -> int:
165
+ """ Marks a PyRel statement as an assignment instruction."""
166
+ task = get_builder().compile_statement(stmt)
167
+ task = cast(ir.Task, mk_assign(task))
168
+ return get_builder().append_task(task)
169
+
170
+ def upsert_with_arity(arity:int, stmt:Fragment):
171
+ task = get_builder().compile_statement(stmt)
172
+ task = cast(ir.Task, mk_upsert(task, arity))
173
+ get_builder().append_task(task)
174
+
175
+ def upsert(arity:int):
176
+ """ Marks a PyRel statement as an upsert instruction with the given arity."""
177
+ assert type(arity) is int and arity >= 0, "arity must be a non-negative integer"
178
+ return lambda stmt: upsert_with_arity(arity, stmt)
179
+
180
+ def monus(stmt: Fragment) -> int:
181
+ """ Marks a PyRel statement as a Boolean monus (set difference) instruction."""
182
+ task = get_builder().compile_statement(stmt)
183
+ task = cast(ir.Task, mk_monus(task, types.Bool, "or", 0))
184
+ return get_builder().append_task(task)
185
+
186
+ def break_(stmt):
187
+ """ Marks a PyRel statement as a break instruction. The statement must be headless (no define clause)."""
188
+ assert has_no_head(stmt), "Break instruction must have a headless fragment"
189
+ task = get_builder().compile_statement(stmt)
190
+ assert isinstance(task, ir.Logical)
191
+ break_condition = [cond for cond in task.body if not isinstance(cond, ir.Update)]
192
+ break_node = factory.break_(factory.logical(break_condition))
193
+ get_builder().append_task(break_node)
194
+
195
+ def has_empty_body(stmt) -> TypeGuard[Fragment]:
196
+ if not isinstance(stmt, Fragment):
197
+ return False
198
+ return len(stmt._where) == 0
199
+
200
+ def has_no_head(frag):
201
+ return len(frag._define) == 0
202
+
203
+
204
+ class ScriptBuilder:
205
+ """
206
+ Builder for Loopy scripts.
207
+ """
208
+ def __init__(self):
209
+ self.instructions:list[ir.Task] = []
210
+
211
+ def add_task(self, instr:ir.Task) -> int:
212
+ self.instructions.append(instr)
213
+ return len(self.instructions) - 1
214
+
215
+ def build_script(self, annos:list[ir.Annotation]) -> ir.Sequence:
216
+ return factory.sequence(
217
+ tasks=self.instructions,
218
+ annos=[builtins.script_annotation()] + annos
219
+ )
220
+
221
+
222
+ class AlgorithmBuilder:
223
+ """
224
+ Builder for Loopy algorithms.
225
+ """
226
+ def __init__(self):
227
+ self.script_stacks:list[ScriptBuilder] = []
228
+ self.compiled_model:Optional[ir.Model] = None
229
+ self.global_relations:list[str] = []
230
+ self.base_model:Optional[Model] = None
231
+ self.setup_fragments:Optional[list[Fragment]] = None
232
+
233
+ def begin_algorithm(self, base_model:Model):
234
+ self.base_model = base_model
235
+ self.script_stacks = [ScriptBuilder()]
236
+ self.compiled_model = None
237
+ self.global_relations:list[str] = []
238
+ self.setup_fragments:Optional[list[Fragment]] = None
239
+
240
+ def add_global_relation(self, task:ir.Task):
241
+ assert isinstance(task, ir.Logical)
242
+ for t in task.body:
243
+ if isinstance(t, ir.Update):
244
+ if t.relation.name not in self.global_relations:
245
+ self.global_relations.append(t.relation.name)
246
+
247
+ def set_setup_fragments(self, fragments:Sequence[Fragment]):
248
+ self.setup_fragments = list(fragments)
249
+
250
+ def compile_statement(self, stmt:Fragment) -> ir.Task:
251
+ assert self.base_model is not None
252
+ task = self.base_model._compiler.compile_task(stmt)
253
+ return task
254
+
255
+ def append_task(self, task:ir.Task) -> int:
256
+ assert len(self.script_stacks) > 0
257
+ return self.script_stacks[-1].add_task(task)
258
+
259
+ def begin_while_loop(self):
260
+ script_builder = ScriptBuilder()
261
+ self.script_stacks.append(script_builder)
262
+
263
+ def end_while_loop(self):
264
+ script_builder = self.script_stacks.pop()
265
+ while_script = script_builder.build_script([builtins.while_annotation()])
266
+ loop = factory.loop(while_script, annos=[builtins.while_annotation()])
267
+ self.append_task(loop)
268
+
269
+ def end_algorithm(self):
270
+ assert len(self.script_stacks) == 1
271
+ script_builder = self.script_stacks.pop()
272
+ algorithm_script = script_builder.build_script([builtins.algorithm_annotation()])
273
+ setup = self.compile_setup()
274
+ algorithm_logical = factory.logical(setup + [algorithm_script])
275
+ self.compiled_model = factory.compute_model(algorithm_logical)
276
+
277
+ def compile_setup(self) -> list[ir.Logical]:
278
+ if self.setup_fragments is None:
279
+ return []
280
+ assert self.setup_fragments is not None
281
+ assert self.base_model is not None
282
+ setup_tasks = []
283
+ for stmt in self.setup_fragments:
284
+ task = self.base_model._compiler.compile_task(stmt)
285
+ setup_tasks.append(task)
286
+ return setup_tasks
287
+
288
+ def get_metamodel(self) -> ir.Model:
289
+ """ Retrieves the compiled metamodel IR for the previous algorithm. """
290
+ metamodel = self.compiled_model
291
+ assert metamodel is not None, "No metamodel available. You must first define algorithm."
292
+ return metamodel
293
+
294
+ def get_lqp_str(self) -> str:
295
+ lqp = self.get_lqp()
296
+ options = lqp_print.ugly_config.copy()
297
+ options[str(lqp_print.PrettyOptions.PRINT_NAMES)] = True
298
+ options[str(lqp_print.PrettyOptions.PRINT_DEBUG)] = False
299
+ lqp_str = lqp_print.to_string(lqp, options)
300
+ return lqp_str
301
+
302
+ def get_lqp(self):
303
+ model = self.get_metamodel()
304
+
305
+ compiler = Compiler()
306
+ rewritten_model = compiler.rewrite(model)
307
+ write_epoch = compiler.do_compile(rewritten_model, {'fragment_id': b"f1"})[1]
308
+
309
+ define = cast(lqp.Define, write_epoch.writes[0].write_type)
310
+ debug_info = define.fragment.debug_info
311
+
312
+ read_epoch = self._build_read_epoch(debug_info)
313
+
314
+ transaction = mk_transaction([write_epoch, read_epoch])
315
+
316
+ return transaction
317
+
318
+ def _build_read_epoch(self, debug_info:lqp.DebugInfo) -> lqp.Epoch:
319
+ reads = []
320
+
321
+ relation_id:dict[str,lqp.RelationId] = dict()
322
+ for rel_id, rel_name in debug_info.id_to_orig_name.items():
323
+ if rel_name in self.global_relations:
324
+ relation_id[rel_name] = rel_id
325
+
326
+ global_relation_names = [rel for rel in self.global_relations if rel in relation_id]
327
+
328
+ for (i, rel_name) in enumerate(global_relation_names):
329
+ read = lqp.Read(
330
+ meta = None,
331
+ read_type = lqp.Output(
332
+ meta=None,
333
+ name=f"{rel_name}",
334
+ relation_id=relation_id[rel_name],
335
+ )
336
+ )
337
+ reads.append(read)
338
+
339
+ read_epoch = lqp.Epoch(
340
+ meta = None,
341
+ writes = [],
342
+ reads = reads,
343
+ )
344
+
345
+ return read_epoch
@@ -20,7 +20,7 @@ class AbstractSnapshotTest(ABC):
20
20
  provider:Provider = cast(SFProvider, Provider()) # type: ignore
21
21
 
22
22
  def run_snapshot_test(self, snapshot, script_path, db_schema=None, use_sql=False, use_lqp=True, use_rel=False,
23
- use_direct_access=False, e2e=False, use_csv=True, e2e_only=False):
23
+ use_direct_access=False, e2e=False, use_csv=True, e2e_only=False, emit_constraints=False):
24
24
  # Resolve use_lqp
25
25
  use_lqp = use_lqp and (not use_rel) # use_rel overrides because use_lqp is default.
26
26
 
@@ -47,6 +47,7 @@ class AbstractSnapshotTest(ABC):
47
47
  'model_suffix': "" if not e2e else f"_{unique_name}",
48
48
  'use_sql': use_sql,
49
49
  'reasoner.rule.use_lqp': use_lqp,
50
+ 'reasoner.rule.emit_constraints': emit_constraints,
50
51
  'keep_model': False,
51
52
  # fix the current time to keep snapshots stable
52
53
  'datetime_now': datetime.datetime.fromisoformat("2025-12-01T12:00:00+00:00"),