relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
v0/relationalai/semantics/lqp/rewrite/__init__.py
@@ -1,18 +1,30 @@
+ from .algorithm import AlgorithmPass
  from .annotate_constraints import AnnotateConstraints
  from .cdc import CDC
+ from .constants_to_vars import ConstantsToVars
+ from .deduplicate_vars import DeduplicateVars
+ from .eliminate_data import EliminateData
  from .extract_common import ExtractCommon
  from .extract_keys import ExtractKeys
  from .function_annotations import FunctionAnnotations, SplitMultiCheckRequires
+ from .period_math import PeriodMath
  from .quantify_vars import QuantifyVars
  from .splinter import Splinter
+ from .unify_definitions import UnifyDefinitions

  __all__ = [
+     "AlgorithmPass",
      "AnnotateConstraints",
      "CDC",
+     "ConstantsToVars",
+     "DeduplicateVars",
+     "EliminateData",
      "ExtractCommon",
      "ExtractKeys",
      "FunctionAnnotations",
+     "PeriodMath",
      "QuantifyVars",
      "Splinter",
      "SplitMultiCheckRequires",
+     "UnifyDefinitions",
  ]
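
The new passes exported above follow the same `Pass.rewrite(model, options)` protocol as the existing ones (visible in the new files below), so they can be chained the same way. A minimal sketch of such a chain; the helper name and the pass ordering are illustrative only, not necessarily what `passes.py` actually wires up:

    from v0.relationalai.semantics.metamodel import ir
    from v0.relationalai.semantics.lqp.rewrite import (
        ConstantsToVars, DeduplicateVars, AlgorithmPass,
    )

    # Hypothetical helper: run the new LQP rewrite passes in sequence.
    def run_new_passes(model: ir.Model) -> ir.Model:
        for p in (ConstantsToVars(), DeduplicateVars(), AlgorithmPass()):
            # Each pass takes a Model and returns a rewritten Model.
            model = p.rewrite(model)
        return model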
v0/relationalai/semantics/lqp/rewrite/algorithm.py
@@ -0,0 +1,385 @@
+ from collections import defaultdict
+ from typing import Optional, TypeGuard, Union, cast
+ from v0.relationalai.semantics.metamodel import ir, helpers, factory
+ from v0.relationalai.semantics.metamodel.compiler import Pass
+ from v0.relationalai.semantics.metamodel.visitor import Visitor, Rewriter, collect_by_type
+ from v0.relationalai.semantics.lqp.algorithms import (
+     is_script, is_algorithm_script, is_logical_instruction, is_update_instruction,
+     get_instruction_head_rels, get_instruction_body_rels, mk_assign, split_instruction
+ )
+
+ class AlgorithmPass(Pass):
+     """
+     Transforms algorithm scripts by normalizing Loopy constructs (iterative algorithms).
+
+     This pass applies three main rewriting transformations to Metamodel IR that prepare
+     algorithm scripts for execution, in the order listed below:
+
+     1. **Intermediate Rescoping**: Moves nested logical intermediate relations from their
+        original logical scope into algorithm scripts, placing them immediately before each
+        instruction that uses them (which can include Break instructions). Removes
+        intermediates from the logical scope if they are only used within algorithms.
+        TODO: Monitor https://github.com/RelationalAI/relationalai-python/pull/3187
+
+        Example (Metamodel IR):
+        BEFORE:
+            Logical
+                Logical
+                    R(x::Int128, y::Int128)
+                    → derive _nested_logical_1(x::Int128, y::Int128) @assign
+                Sequence @script @algorithm
+                    Logical
+                        _nested_logical_1(a::Int128, b::Int128)
+                        → derive S(a::Int128, b::Int128) @assign
+
+        AFTER:
+            Logical
+                Sequence @script @algorithm
+                    Logical
+                        R(x::Int128, y::Int128)
+                        → derive _nested_logical_1(x::Int128, y::Int128) @assign
+                    Logical
+                        _nested_logical_1(a::Int128, b::Int128)
+                        → derive S(a::Int128, b::Int128) @assign
+
+     2. **Update Normalization**: Transforms Loopy update operations (@upsert, @monoid, @monus)
+        to use a single body atom. Complex bodies with multiple lookups or additional
+        operations are normalized by introducing intermediate relations.
+
+        Example (Metamodel IR):
+        BEFORE:
+            Logical
+                R(x::Int128, y::Int128)
+                S(y::Int128, z::Int128)
+                → derive T(x::Int128, z::Int128) @upsert
+
+        AFTER:
+            Logical
+                R(x::Int128, y::Int128)
+                S(y::Int128, z::Int128)
+                → derive _loopy_update_intermediate_1(x::Int128, z::Int128) @assign
+
+            Logical
+                _loopy_update_intermediate_1(x::Int128, z::Int128)
+                → derive T(x::Int128, z::Int128) @upsert
+
+     3. **Recursive Assignment Decoupling**: Decouples self-referential assignments where the
+        head relation appears in the body by introducing a copy relation. This transformation
+        is required for BackIR analysis compatibility.
+
+        Example (Metamodel IR):
+        BEFORE:
+            Logical
+                iter(i::Int128)
+                rel_primitive_int128_add(i::Int128, 1::Int128, i_plus_1::Int128)
+                → derive iter(i_plus_1::Int128) @assign
+
+        AFTER:
+            Logical
+                iter(i::Int128)
+                → derive _loopy_iter_copy_1(i::Int128) @assign
+
+            Logical
+                _loopy_iter_copy_1(i::Int128)
+                rel_primitive_int128_add(i::Int128, 1::Int128, i_plus_1::Int128)
+                → derive iter(i_plus_1::Int128) @assign
+     """
+     def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
+         # Find all nested logical intermediates.
+         intermediate_finder = FindIntermediates()
+         model.accept(intermediate_finder)
+
+         intermediate_analyzer = AnalyzeIntermediateUse(set(intermediate_finder.intermediates.keys()))
+         model.accept(intermediate_analyzer)
+
+         # Determine which intermediates to move and which to remove.
+         uses_intermediates: dict[Union[ir.Logical, ir.Break], set[ir.Logical]] = defaultdict(set)
+         remove_declarations: set[ir.Logical] = set()
+         for rel, decl in intermediate_finder.intermediates.items():
+             if rel not in intermediate_analyzer.used_outside_algorithm:
+                 remove_declarations.add(decl)
+             for instr in intermediate_analyzer.used_in_alg_instruction[rel]:
+                 uses_intermediates[instr].add(decl)
+
+         # Rescope intermediates.
+         rescoper = IntermediateRescoper(uses_intermediates, remove_declarations)
+         model = rescoper.walk(model)
+
+         # Normalize Loopy updates.
+         normalizer = UpdateNormalizer()
+         model = normalizer.walk(model)
+
+         # Decouple recursive assignments.
+         decomposer = RecursiveAssignmentDecoupling()
+         model = decomposer.walk(model)
+
+         return model
+
+ class FindIntermediates(Visitor):
+     """
+     Gathers all `_nested_logical.*` intermediates defined in a Logical scope (where order
+     does not matter); in particular, it does NOT gather intermediates declared in the scope
+     of a Sequence.
+     """
+     def __init__(self):
+         self.intermediates: dict[ir.Relation, ir.Logical] = dict()
+         self._inside_algorithm: bool = False
+
+     def visit_logical(self, node: ir.Logical, parent: Optional[ir.Node]):
+         if is_logical_instruction(node):
+             heads = get_instruction_head_rels(node)
+             for rel in heads:
+                 if rel.name.startswith("_nested_logical"):
+                     self.intermediates[rel] = node
+         else:
+             super().visit_logical(node, parent)
+
+     def visit_sequence(self, node: ir.Sequence, parent: Optional[ir.Node]):
+         if is_algorithm_script(node):
+             self._inside_algorithm = True
+         super().visit_sequence(node, parent)
+         if is_algorithm_script(node):
+             self._inside_algorithm = False
+
+
+ class AnalyzeIntermediateUse(Visitor):
+     """
+     Identifies, for each nested logical intermediate, the algorithm instructions that
+     use it. Additionally, determines whether the intermediate is used anywhere
+     outside of an algorithm.
+     """
+     def __init__(self, intermediate_relations: set[ir.Relation]):
+         self.intermediates = intermediate_relations
+         self.used_in_algorithm: set[ir.Relation] = set()
+         self.used_in_alg_instruction: dict[ir.Relation, set[Union[ir.Logical, ir.Break]]] = {rel: set() for rel in intermediate_relations}
+         self.used_outside_algorithm: set[ir.Relation] = set()
+
+         self._current_algorithm: Optional[ir.Sequence] = None
+
+     def register_use(self, instr: Union[ir.Logical, ir.Break], uses_intermediates: set[ir.Relation]):
+         # This instruction uses intermediates.
+         if self._current_algorithm is not None:
+             # The instruction is inside an algorithm.
+             for rel in uses_intermediates:
+                 self.used_in_algorithm.add(rel)
+                 self.used_in_alg_instruction[rel].add(instr)
+         else:
+             self.used_outside_algorithm.update(uses_intermediates)
+
+     def visit_break(self, node: ir.Break, parent: Optional[ir.Node]):
+         lookups = collect_by_type(ir.Lookup, node)
+         lookup_rels = {lookup.relation for lookup in lookups}
+         uses_intermediates = lookup_rels.intersection(self.intermediates)
+         self.register_use(node, uses_intermediates)
+         super().visit_break(node, parent)
+
+     def visit_logical(self, node: ir.Logical, parent: Optional[ir.Node]):
+         if is_logical_instruction(node):
+             body = get_instruction_body_rels(node)
+             uses_intermediates = body.intersection(self.intermediates)
+             self.register_use(node, uses_intermediates)
+         else:
+             super().visit_logical(node, parent)
+
+     def visit_sequence(self, node: ir.Sequence, parent: Optional[ir.Node]):
+         if is_algorithm_script(node):
+             self._current_algorithm = node
+         super().visit_sequence(node, parent)
+         if is_algorithm_script(node):
+             self._current_algorithm = None
+
+ class IntermediateRescoper(Rewriter):
+     """
+     Moves nested logical intermediates used by algorithm instructions out of the logical
+     scope and into each algorithm that uses them, placing them before every instruction
+     that uses them. Removes an intermediate from the logical scope if it is not used
+     anywhere else.
+
+     * `uses_intermediates`: a mapping from algorithm instructions to the set of nested
+       logical intermediates they use.
+     * `remove_declarations`: the set of nested logical intermediates to remove from the
+       logical scope because they are not used anywhere else.
+     """
+     def __init__(self,
+                  uses_intermediates: dict[Union[ir.Logical, ir.Break], set[ir.Logical]],
+                  remove_declarations: set[ir.Logical]):
+         super().__init__()
+         self.uses_intermediates = uses_intermediates
+         self.remove_declarations = remove_declarations
+
+     def handle_logical(self, node: ir.Logical, parent: ir.Node) -> ir.Logical:
+         body = []
+         for child in node.body:
+             if child in self.remove_declarations:
+                 continue
+             child = self.walk(child, node)
+             body.append(child)
+         return node.reconstruct(node.engine, node.hoisted, tuple(body), node.annotations)
+
+     def handle_sequence(self, node: ir.Sequence, parent: ir.Node) -> ir.Sequence:
+         tasks = []
+         for child in node.tasks:
+             if child in self.uses_intermediates:
+                 assert isinstance(child, (ir.Logical, ir.Break))
+                 for intermediate in self.uses_intermediates[child]:
+                     tasks.append(mk_assign(intermediate))
+             child = self.walk(child, node)
+             tasks.append(child)
+         return node.reconstruct(node.engine, node.hoisted, tuple(tasks), node.annotations)
+
+ class UpdateNormalizer(Rewriter):
+     """
+     Normalizes Loopy update operations (upsert, monoid, and monus) to use a single atom
+     in their body. For any update operation with a more complex body, it introduces a new
+     intermediate relation to hold the body's results.
+     """
+     def __init__(self):
+         super().__init__()
+         self._inside_algorithm: bool = False
+         self._intermediate_counter: int = 0
+
+     # Tests whether the given update operation requires normalization:
+     # * the body has more than one Lookup operation, or
+     # * the body has tasks other than Lookup and Update.
+     def _requires_update_normalization(self, update: ir.Task) -> bool:
+         if not isinstance(update, ir.Logical):
+             return False
+         if not is_update_instruction(update):
+             return False
+         _, lookups, others = split_instruction(update)
+         return len(lookups) > 1 or len(others) > 0
+
+     def handle_sequence(self, node: ir.Sequence, parent: ir.Node) -> ir.Sequence:
+         if is_algorithm_script(node):
+             self._inside_algorithm = True
+
+         if self._inside_algorithm:
+             new_tasks = []
+             for task in node.tasks:
+                 if self._requires_update_normalization(task):
+                     assert isinstance(task, ir.Logical)
+                     intermediate, normalized_update = self._normalize_update_instruction(task)
+                     new_tasks.extend((intermediate, normalized_update))
+                 else:
+                     new_tasks.append(self.walk(task, node))
+             result = node.reconstruct(node.engine, node.hoisted, tuple(new_tasks), node.annotations)
+         else:
+             result = super().handle_sequence(node, parent)
+
+         if is_algorithm_script(node):
+             self._inside_algorithm = False
+
+         return result
+
+     def _normalize_update_instruction(self, update_instr: ir.Logical) -> tuple[ir.Logical, ir.Logical]:
+         update, lookups, others = split_instruction(update_instr)
+
+         var_list = helpers.vars(update.args)
+
+         intermediate_rel = factory.relation(
+             self._fresh_intermediate_name(), [
+                 factory.field(f"arg_{i}", var.type) for i, var in enumerate(var_list)
+             ]
+         )
+
+         intermediate_derive = factory.derive(intermediate_rel, var_list)
+         intermediate_logical = mk_assign(factory.logical(
+             engine=update_instr.engine,
+             hoisted=update_instr.hoisted,
+             body=(*lookups, *others, intermediate_derive),
+             annos=list(update_instr.annotations)
+         ))
+         assert isinstance(intermediate_logical, ir.Logical)
+
+         intermediate_lookup = factory.lookup(
+             intermediate_rel,
+             var_list
+         )
+
+         normalized_update = factory.logical(
+             engine=update_instr.engine,
+             hoisted=update_instr.hoisted,
+             body=(intermediate_lookup, update),
+             annos=list(update_instr.annotations)
+         )
+
+         return (intermediate_logical, normalized_update)
+
+     def _fresh_intermediate_name(self) -> str:
+         self._intermediate_counter += 1
+         return f"_loopy_update_intermediate_{self._intermediate_counter}"
+
+ class RecursiveAssignmentDecoupling(Rewriter):
+     """
+     Decouples assignments whose definition is "recursive", i.e., whose body contains the
+     head, e.g., `assign iter = iter + 1`. Currently, BackIR analysis cannot properly handle
+     such assignments. They are decoupled by introducing a new intermediate copy relation;
+     in the example above, `assign iter_copy = iter; assign iter = iter_copy + 1`.
+     Performance is not affected because the backend can identify the new assignment as a
+     copy operation, so execution does not materialize the intermediate relation.
+     """
+     def __init__(self):
+         super().__init__()
+         self._intermediate_copy_counter: int = 0
+         # Controls the head_rel -> copy_rel substitution during traversal.
+         self._perform_substitution: bool = False
+         self._head_rel: Optional[ir.Relation] = None
+         self._copy_rel: Optional[ir.Relation] = None
+
+     def _fresh_copy_rel_name(self, rel_name: str) -> str:
+         self._intermediate_copy_counter += 1
+         return f"_loopy_{rel_name}_copy_{self._intermediate_copy_counter}"
+
+     def handle_sequence(self, node: ir.Sequence, parent: ir.Node) -> ir.Sequence:
+         if is_script(node):
+             new_tasks = []
+             for task in node.tasks:
+                 if self._is_recursive_assignment(task):
+                     assert isinstance(task, ir.Logical)
+                     intermediate_copy, decomposed_assign = self._decouple_recursive_assignment(task, parent)
+                     new_tasks.extend((intermediate_copy, decomposed_assign))
+                 else:
+                     new_tasks.append(self.walk(task, node))
+             return node.reconstruct(node.engine, node.hoisted, tuple(new_tasks), node.annotations)
+         else:
+             return super().handle_sequence(node, parent)
+
+     def _is_recursive_assignment(self, task: ir.Task) -> TypeGuard[ir.Logical]:
+         if is_logical_instruction(task):
+             heads = get_instruction_head_rels(task)
+             body = get_instruction_body_rels(task)
+             return len(body & heads) > 0
+         return False
+
+     def _decouple_recursive_assignment(self, rule: ir.Logical, parent: ir.Node) -> tuple[ir.Logical, ir.Logical]:
+         # We have `assign rel(x,...) = ..., rel(y,...), ...`
+         update, _, _ = split_instruction(rule)
+         self._head_rel = update.relation
+
+         copy_rel_name = self._fresh_copy_rel_name(self._head_rel.name)
+
+         self._copy_rel = factory.relation(copy_rel_name, list(self._head_rel.fields))
+         # Build `assign copy_rel(x,...) = rel(x,...)`
+         copy_rule = cast(ir.Logical, mk_assign(
+             factory.logical([
+                 factory.lookup(self._head_rel, update.args),
+                 factory.update(self._copy_rel, update.args, update.effect)
+             ])
+         ))
+
+         # Build `assign rel(x,...) = ..., copy_rel(y,...), ...`
+         self._perform_substitution = True
+         rewritten_rule = self.walk(rule, parent)
+         self._perform_substitution = False
+
+         self._head_rel = None
+         self._copy_rel = None
+
+         return (copy_rule, rewritten_rule)
+
+     def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
+         if self._perform_substitution and node.relation == self._head_rel:
+             assert self._copy_rel is not None
+             return factory.lookup(self._copy_rel, node.args)
+         return super().handle_lookup(node, parent)
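
For intuition, the recursive-assignment decoupling above can be reproduced on a toy rule representation (plain tuples instead of the package's `ir` types; everything here is illustrative):

    # A rule is (head, body_atoms); atoms are just relation names.
    Rule = tuple[str, list[str]]

    def decouple(rule: Rule, copy_name: str) -> list[Rule]:
        head, body = rule
        if head not in body:
            return [rule]                        # not recursive: unchanged
        copy_rule: Rule = (copy_name, [head])    # assign iter_copy = iter
        rewritten: Rule = (head, [copy_name if atom == head else atom for atom in body])
        return [copy_rule, rewritten]            # assign iter = iter_copy + 1

    # An `iter = iter + 1` style rule: the head "iter" also appears in its own body.
    print(decouple(("iter", ["iter", "rel_primitive_int128_add"]), "_loopy_iter_copy_1"))
    # [('_loopy_iter_copy_1', ['iter']),
    #  ('iter', ['_loopy_iter_copy_1', 'rel_primitive_int128_add'])]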
v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py
@@ -0,0 +1,70 @@
+ from v0.relationalai.semantics.metamodel.compiler import Pass
+ from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
+ from v0.relationalai.semantics.metamodel.typer import typer
+
+ from typing import List, Sequence, Tuple, Union
+
+ # Rewrite constants to vars in Updates. This results in a more normalized format where
+ # updates contain only variables, which allows for easier rewrites in later passes.
+ class ConstantsToVars(Pass):
+     def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
+         r = self.ConstantToVarRewriter()
+         return r.walk(model)
+
+     # Return 1) a new list of Values in which scalar constants are replaced by fresh Vars
+     # and 2) equality lookups between each fresh Var and the constant it replaces.
+     @staticmethod
+     def replace_constants_with_vars(vals: Sequence[ir.Value]) -> Tuple[List[ir.Value], List[ir.Lookup]]:
+         new_vals = []
+         eqs = []
+
+         for i, val in enumerate(vals):
+             if isinstance(val, ir.PyValue) or isinstance(val, ir.Literal):
+                 # Replace the constant with a new Var.
+                 typ = typer.to_type(val)
+                 assert isinstance(typ, ir.ScalarType), "can only replace scalar constants with vars"
+                 new_var = ir.Var(typ, f"{typ.name.lower()}_{i}")
+                 new_vals.append(new_var)
+                 eqs.append(f.lookup(rel_builtins.eq, [new_var, val]))
+             else:
+                 new_vals.append(val)
+
+         return new_vals, eqs
+
+     @staticmethod
+     def dedup_update(update: ir.Update) -> List[Union[ir.Update, ir.Lookup]]:
+         deduped_vals, req_lookups = ConstantsToVars.replace_constants_with_vars(update.args)
+         new_update = ir.Update(
+             update.engine,
+             update.relation,
+             tuple(deduped_vals),
+             update.effect,
+             update.annotations,
+         )
+         return req_lookups + [new_update]
+
+     # Does the actual work.
+     class ConstantToVarRewriter(visitor.Rewriter):
+         def __init__(self):
+             super().__init__()
+
+         # We implement handle_logical instead of handle_update because in
+         # addition to modifying said update we require new lookups (equalities
+         # between replacement variables and the original constants).
+         def handle_logical(self, node: ir.Logical, parent: ir.Node):
+             # Recurse over subtasks first.
+             node = super().handle_logical(node, parent)
+
+             new_body = []
+             for subtask in node.body:
+                 if isinstance(subtask, ir.Update):
+                     new_body.extend(ConstantsToVars.dedup_update(subtask))
+                 else:
+                     new_body.append(subtask)
+
+             return ir.Logical(
+                 node.engine,
+                 node.hoisted,
+                 tuple(new_body),
+                 node.annotations
+             )
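
The constants-to-vars rewrite above, shown on a toy argument list (plain Python values instead of `ir.Value`; names here are illustrative):

    # Toy version of replace_constants_with_vars: constants become fresh variables,
    # pinned down by extra equality atoms.
    def constants_to_vars(args: list) -> tuple[list, list]:
        new_args, eqs = [], []
        for i, a in enumerate(args):
            if isinstance(a, str):               # treat strings as variable names
                new_args.append(a)
            else:                                # a constant: introduce v_i with v_i = a
                var = f"v_{i}"
                new_args.append(var)
                eqs.append(("eq", var, a))
        return new_args, eqs

    # update T(x, 3) becomes update T(x, v_1) plus the lookup eq(v_1, 3)
    print(constants_to_vars(["x", 3]))           # (['x', 'v_1'], [('eq', 'v_1', 3)])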
v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py
@@ -0,0 +1,104 @@
+ from v0.relationalai.semantics.metamodel.compiler import Pass
+ from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
+ from v0.relationalai.semantics.metamodel import helpers
+ from v0.relationalai.semantics.metamodel.util import FrozenOrderedSet
+
+ from v0.relationalai.semantics.lqp.utils import output_names
+
+ from typing import List, Sequence, Tuple, Union
+
+ # Deduplicate Vars in Updates and Outputs.
+ class DeduplicateVars(Pass):
+     def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
+         r = self.VarDeduplicator()
+         return r.walk(model)
+
+     # Return 1) a new list of Values with no duplicates (at the object level) and
+     # 2) equalities between each original Value and its deduplicated replacement.
+     @staticmethod
+     def dedup_values(vals: Sequence[ir.Value]) -> Tuple[List[ir.Value], List[ir.Lookup]]:
+         # If a var is seen more than once, it is a duplicate and we create
+         # a new Var and equate it with the one already seen.
+         seen_vars = set()
+
+         new_vals = []
+         eqs = []
+
+         for i, val in enumerate(vals):
+             # Duplicates can only occur within Vars.
+             # TODO: we don't know for sure if these are the only relevant cases.
+             if isinstance(val, ir.Default) or isinstance(val, ir.Var):
+                 var = val if isinstance(val, ir.Var) else val.var
+                 if var in seen_vars:
+                     new_var = ir.Var(var.type, var.name + "_dup_" + str(i))
+                     new_val = new_var if isinstance(val, ir.Var) else ir.Default(new_var, val.value)
+                     new_vals.append(new_val)
+                     eqs.append(f.lookup(rel_builtins.eq, [new_var, var]))
+                 else:
+                     seen_vars.add(var)
+                     new_vals.append(val)
+             else:
+                 # No possibility of problematic duplication.
+                 new_vals.append(val)
+
+         return new_vals, eqs
+
+     # Returns a reconstructed output with no duplicate variable objects
+     # (dedup_values) and the now-necessary equalities between any two previously
+     # duplicated variables.
+     @staticmethod
+     def dedup_output(output: ir.Output) -> List[Union[ir.Output, ir.Lookup]]:
+         vals = helpers.output_values(output.aliases)
+         deduped_vals, req_lookups = DeduplicateVars.dedup_values(vals)
+         # Need the names so we can recombine.
+         alias_names = output_names(output.aliases)
+         new_output = ir.Output(
+             output.engine,
+             FrozenOrderedSet(list(zip(alias_names, deduped_vals))),
+             output.keys,
+             output.annotations,
+         )
+         return req_lookups + [new_output]
+
+     # Returns a replacement update with no duplicate variable objects
+     # (dedup_values) and the now-necessary equalities between any two previously
+     # duplicated variables.
+     @staticmethod
+     def dedup_update(update: ir.Update) -> List[Union[ir.Update, ir.Lookup]]:
+         deduped_vals, req_lookups = DeduplicateVars.dedup_values(update.args)
+         new_update = ir.Update(
+             update.engine,
+             update.relation,
+             tuple(deduped_vals),
+             update.effect,
+             update.annotations,
+         )
+         return req_lookups + [new_update]
+
+     # Does the actual work.
+     class VarDeduplicator(visitor.Rewriter):
+         def __init__(self):
+             super().__init__()
+
+         # We implement handle_logical instead of handle_update/handle_output
+         # because in addition to modifying said update/output we require new
+         # lookups (equality between original and deduplicated variables).
+         def handle_logical(self, node: ir.Logical, parent: ir.Node):
+             # Recurse over subtasks first.
+             node = super().handle_logical(node, parent)
+
+             new_body = []
+             for subtask in node.body:
+                 if isinstance(subtask, ir.Output):
+                     new_body.extend(DeduplicateVars.dedup_output(subtask))
+                 elif isinstance(subtask, ir.Update):
+                     new_body.extend(DeduplicateVars.dedup_update(subtask))
+                 else:
+                     new_body.append(subtask)
+
+             return ir.Logical(
+                 node.engine,
+                 node.hoisted,
+                 tuple(new_body),
+                 node.annotations
+             )
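
And the corresponding toy version of the var-deduplication step above (plain strings standing in for `ir.Var`; purely illustrative):

    # Toy version of DeduplicateVars.dedup_values: a repeated variable gets a fresh
    # alias plus an equality tying the alias back to the original.
    def dedup_vars(args: list[str]) -> tuple[list[str], list[tuple[str, str, str]]]:
        seen, new_args, eqs = set(), [], []
        for i, v in enumerate(args):
            if v in seen:
                alias = f"{v}_dup_{i}"
                new_args.append(alias)
                eqs.append(("eq", alias, v))
            else:
                seen.add(v)
                new_args.append(v)
        return new_args, eqs

    # output (x, x) becomes output (x, x_dup_1) plus the lookup eq(x_dup_1, x)
    print(dedup_vars(["x", "x"]))                # (['x', 'x_dup_1'], [('eq', 'x_dup_1', 'x')])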