relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,108 @@
1
+ from v0.relationalai.semantics.metamodel.compiler import Pass
2
+ from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
3
+
4
+ from typing import cast
5
+ import pandas as pd
6
+ import hashlib
7
+
8
# Creates intermediary relations for all Data nodes and replaces said Data nodes
# with a Lookup into these created relations. Data nodes with identical rows AND
# identical column types reuse a single intermediary relation.
class EliminateData(Pass):
    # `options` is unused by this pass; the signature mirrors the Pass interface.
    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        return self.DataRewriter().walk(model)

    # Does the actual work: collects one new relation (and one rule deriving it)
    # per distinct Data node while walking, then splices them into the model.
    class DataRewriter(visitor.Rewriter):
        # Relations created for Data nodes, in creation order.
        new_relations: list[ir.Relation]
        # Rules populating new_relations; new_updates[i] derives new_relations[i].
        new_updates: list[ir.Logical]
        # Counter for naming new relations.
        # It must be that new_count == len new_updates == len new_relations.
        new_count: int
        # Cache for Data nodes to avoid creating duplicate intermediary relations.
        # Keyed by (content digest, column types) -- see _data_cache_key.
        data_cache: dict[tuple, ir.Relation]

        def __init__(self):
            self.new_relations = []
            self.new_updates = []
            self.new_count = 0
            self.data_cache = {}
            super().__init__()

        # Create a cache key for a Data node from its content and its column types.
        # The row digest alone is not enough: two Data nodes holding equal values
        # but bound to differently-typed vars would otherwise share one relation,
        # whose fields and literals carry the types of whichever node came first.
        def _data_cache_key(self, node: ir.Data) -> tuple:
            row_hashes = pd.util.hash_pandas_object(node.data).values
            digest = hashlib.sha256(bytes(row_hashes)).hexdigest()
            return (digest, tuple(var.type for var in node.vars))

        # Return the intermediary relation standing in for `node`, creating the
        # relation and the rule that populates it the first time the content is seen.
        def _intermediary_relation(self, node: ir.Data) -> ir.Relation:
            cache_key = self._data_cache_key(node)
            cached = self.data_cache.get(cache_key)
            if cached is not None:
                return cached

            self.new_count += 1
            intermediary_relation = f.relation(
                f"formerly_Data_{self.new_count}",
                [f.field(v.name, v.type) for v in node.vars],
            )
            self.new_relations.append(intermediary_relation)

            # For each row (union branch), equate each value with its variable
            # (logical); the union hoists the vars so the derive below can use them.
            rows = f.union(
                [
                    f.logical(
                        [
                            f.lookup(rel_builtins.eq, [f.literal(val, var.type), var])
                            for (val, var) in zip(row, node.vars)
                        ],
                    )
                    for row in node
                ],
                hoisted = node.vars,
            )
            self.new_updates.append(f.logical([
                rows,
                # And pop it back into the relation.
                f.update(intermediary_relation, node.vars, ir.Effect.derive),
            ]))

            # Cache the result for reuse.
            self.data_cache[cache_key] = intermediary_relation
            return intermediary_relation

        # Create a new intermediary relation representing the Data (and pop it in
        # new_updates/new_relations) and replace this Data with a Lookup of said
        # intermediary.
        def handle_data(self, node: ir.Data, parent: ir.Node) -> ir.Lookup:
            return f.lookup(self._intermediary_relation(node), node.vars)

        # Walks the model for the handle_data work then updates the model with
        # the new state.
        def handle_model(self, model: ir.Model, parent: None):
            walked_model = super().handle_model(model, parent)
            assert len(self.new_relations) == len(self.new_updates) == self.new_count

            # No Data nodes were found: nothing to add. Return the walked model
            # (not the input) so the result is always the output of this walk.
            if self.new_count == 0:
                return walked_model

            # This is okay because it's LQP: the root is expected to be a Logical.
            assert isinstance(walked_model.root, ir.Logical)
            root_logical = cast(ir.Logical, walked_model.root)

            # Add the new intermediary relations and their rules to the model.
            return ir.Model(
                walked_model.engines,
                walked_model.relations | self.new_relations,
                walked_model.types,
                ir.Logical(
                    root_logical.engine,
                    root_logical.hoisted,
                    root_logical.body + tuple(self.new_updates),
                    root_logical.annotations,
                ),
                walked_model.annotations,
            )
@@ -118,6 +118,17 @@ class ExtractKeys(Pass):
118
118
  the same here).
119
119
  """
120
120
  class ExtractKeysRewriter(Rewriter):
121
def __init__(self):
    super().__init__()
    # One compound-key variable per distinct (ordered) tuple of output keys, so
    # outputs that share the same keys also share the same compound-key variable.
    self.compound_keys: dict[tuple[ir.Var, ...], ir.Var] = {}

def _get_compound_key(self, orig_keys: Iterable[ir.Var]) -> ir.Var:
    """Return the compound-key variable for `orig_keys`, creating it on first use.

    The keys are normalized to a tuple before being used as a dict key. The caller
    passes `output_keys`, which may well be a list (unhashable -> TypeError) or
    another one-shot iterable (which would never match a previous entry); a tuple
    keeps the lookup valid and order-sensitive for any iterable.
    """
    key = tuple(orig_keys)
    compound_key = self.compound_keys.get(key)
    if compound_key is None:
        compound_key = f.var("compound_key", types.Hash)
        self.compound_keys[key] = compound_key
    return compound_key
131
+
121
132
  def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
122
133
  outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
123
134
  # We are not in a logical with an output at this level.
@@ -170,7 +181,7 @@ class ExtractKeysRewriter(Rewriter):
170
181
  annos = list(output.annotations)
171
182
  annos.append(f.annotation(builtins.output_keys, tuple(output_keys)))
172
183
  # Create a compound key that will be used in place of the original keys.
173
- compound_key = f.var("compound_key", types.Hash)
184
+ compound_key = self._get_compound_key(output_keys)
174
185
 
175
186
  for key_combination in combinations:
176
187
  missing_keys = OrderedSet.from_iterable(output_keys)
@@ -192,8 +203,13 @@ class ExtractKeysRewriter(Rewriter):
192
203
  # handle the construct node in each clone
193
204
  values: list[ir.Value] = [compound_key.type]
194
205
  for key in output_keys:
195
- assert isinstance(key.type, ir.ScalarType)
196
- values.append(ir.Literal(types.String, key.type.name))
206
+ if isinstance(key.type, ir.UnionType):
207
+ # the typer can derive union types when multiple distinct entities flow
208
+ # into a relation's field, so use AnyEntity as the type marker
209
+ values.append(ir.Literal(types.String, "AnyEntity"))
210
+ else:
211
+ assert isinstance(key.type, ir.ScalarType)
212
+ values.append(ir.Literal(types.String, key.type.name))
197
213
  if key in key_combination:
198
214
  values.append(key)
199
215
  body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
@@ -408,6 +424,12 @@ class ExtractKeysRewriter(Rewriter):
408
424
  for arg in args[:-1]:
409
425
  extended_vars.add(arg)
410
426
  there_is_progress = True
427
+ elif isinstance(task, ir.Not):
428
+ if isinstance(task.task, ir.Logical):
429
+ hoisted = helpers.hoisted_vars(task.task.hoisted)
430
+ if var in hoisted:
431
+ partitions[var].add(task)
432
+ there_is_progress = True
411
433
  else:
412
434
  assert False, f"invalid node kind {type(task)}"
413
435
 
@@ -0,0 +1,77 @@
1
+ from v0.relationalai.semantics.metamodel.compiler import Pass
2
+ from v0.relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
3
+ from v0.relationalai.semantics.metamodel import types
4
+
5
# Generate date arithmetic expressions, such as
# `rel_primitive_date_add(:day, [date] delta, res_2)`, by finding the period
# expression for the delta and adding the period type to the date arithmetic expression.
#
# date_add and its kin are generated by a period expression, e.g.,
# `day(delta, res_1)`
# followed by the date arithmetic expression using the period
# `date_add([date] res_1 res_2)`
#
# Two rewriters run in sequence: the first records which variable holds which
# period (turning the period lookup into a cast), the second prepends that period
# as a Symbol literal to the matching date/datetime add/subtract lookup.
class PeriodMath(Pass):
    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
        period_rewriter = self.PeriodRewriter()
        model = period_rewriter.walk(model)
        # period_vars is only fully populated after the first walk completes.
        period_math_rewriter = self.PeriodMathRewriter(period_rewriter.period_vars)
        return period_math_rewriter.walk(model)

    # Find all period builtins. We need to make them safe for the emitter (either by
    # translating to a cast, or removing) and store the variable and period type for use
    # in the date/datetime add/subtract expressions.
    class PeriodRewriter(visitor.Rewriter):
        # Builtin relation names that construct a period; built once, not per lookup.
        _PERIOD_BUILTINS = frozenset({
            "year", "month", "week", "day", "hour", "minute", "second",
            "millisecond", "microsecond", "nanosecond",
        })

        def __init__(self):
            super().__init__()
            # Maps the result variable of each period lookup to its period name.
            self.period_vars: dict[ir.Var, str] = {}

        def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
            if not rel_builtins.is_builtin(node.relation):
                return node
            if node.relation.name not in self._PERIOD_BUILTINS:
                return node

            assert len(node.args) == 2, "Expect 2 arguments for period builtins"
            assert isinstance(node.args[1], ir.Var), "Expect result to be a variable"
            result_var = node.args[1]
            self.period_vars[result_var] = node.relation.name

            # Ideally we could now remove the unused and unhandled period type construction
            # but we may also need to cast the original variable to an Int64 for use by the
            # date/datetime add/subtract expressions.
            # TODO: Remove the node entirely where possible and update uses of the result
            return f.lookup(rel_builtins.cast, [types.Int64, node.args[0], result_var])

    # Update date/datetime add/subtract expressions with period information.
    class PeriodMathRewriter(visitor.Rewriter):
        # Builtin relation names that take a period operand; built once, not per lookup.
        _DATE_MATH_BUILTINS = frozenset({
            "date_add", "date_subtract", "datetime_add", "datetime_subtract",
        })

        def __init__(self, period_vars: dict[ir.Var, str]):
            super().__init__()
            self.period_vars: dict[ir.Var, str] = period_vars

        def handle_lookup(self, node: ir.Lookup, parent: ir.Node) -> ir.Lookup:
            if not rel_builtins.is_builtin(node.relation):
                return node
            if node.relation.name not in self._DATE_MATH_BUILTINS:
                return node

            if len(node.args) == 4:
                # We've already visited this lookup (the period literal was prepended).
                return node

            assert isinstance(node.args[1], ir.Var), "Expect period to be a variable"
            period_var = node.args[1]
            assert period_var in self.period_vars, "datemath found, but no vars to insert"

            period = self.period_vars[period_var]
            new_args = [f.literal(period, types.Symbol), *node.args]
            return f.lookup(node.relation, new_args)
@@ -69,7 +69,7 @@ class VarScopeInfo(Visitor):
69
69
  ir.Var, ir.Literal, ir.Relation, ir.Field,
70
70
  ir.Default, ir.Output, ir.Update, ir.Aggregate,
71
71
  ir.Check, ir.Require,
72
- ir.Annotation, ir.Rank)
72
+ ir.Annotation, ir.Rank, ir.Break)
73
73
 
74
74
  def __init__(self):
75
75
  super().__init__()
@@ -103,16 +103,29 @@ class VarScopeInfo(Visitor):
103
103
  self._record(node, scope_vars)
104
104
 
105
105
  elif isinstance(node, (ir.Match, ir.Union)):
106
- # Match/Union inherits the vars in scope from its sub-tasks.
106
+ # Match/Union only inherits vars if they are in scope for all sub-tasks.
107
107
  scope_vars = ordered_set()
108
+ # Prime the search with the first sub-task's vars.
109
+ if node.tasks:
110
+ scope_vars.update(self._vars_in_scope.get(node.tasks[0].id, None))
111
+
108
112
  for task in node.tasks:
109
113
  sub_scope_vars = self._vars_in_scope.get(task.id, None)
110
- if sub_scope_vars:
111
- scope_vars.update(sub_scope_vars)
114
+ if not scope_vars or not sub_scope_vars:
115
+ scope_vars = ordered_set()
116
+ break
117
+ scope_vars = (scope_vars & sub_scope_vars)
118
+
112
119
  # Hoisted vars are not considered for quantification at this level.
113
120
  scope_vars.difference_update(helpers.hoisted_vars(node.hoisted))
114
121
  self._record(node, scope_vars)
115
122
 
123
+ elif isinstance(node, (ir.Loop, ir.Sequence)):
124
+ # Variables in Loops and Sequences are scoped exclusively within the body and
125
+ # not propagated outside. No need to record any variables, as they shouldn't be
126
+ # in scope for the node itself
127
+ pass
128
+
116
129
  elif isinstance(node, ir.Logical):
117
130
  self._do_logical(node)
118
131
 
@@ -128,6 +141,9 @@ class VarScopeInfo(Visitor):
128
141
  all_nested_vars = ordered_set()
129
142
  output_vars = ordered_set()
130
143
 
144
+ # Collect variables nested in child Logical and Not nodes
145
+ nested_vars_in_task: dict[ir.Var, int] = dict()
146
+
131
147
  # Collect all variables from logical sub-tasks
132
148
  for task in node.body:
133
149
  if isinstance(task, ir.Output):
@@ -140,19 +156,29 @@ class VarScopeInfo(Visitor):
140
156
  scope_vars.add(var)
141
157
  continue
142
158
 
143
- sub_scope_vars = self._vars_in_scope.get(task.id, None)
144
-
145
159
  # Hoisted variables from sub-tasks are brought again into scope.
146
160
  if isinstance(task, (ir.Logical, ir.Union, ir.Match)):
147
161
  scope_vars.update(helpers.hoisted_vars(task.hoisted))
148
162
 
149
- if sub_scope_vars:
150
- if isinstance(task, ir.Logical):
151
- all_nested_vars.update(sub_scope_vars)
152
- elif not isinstance(task, ir.Not):
153
- # For all other node kinds (except Not), just propagate the variables in scope.
154
- # Not nodes stop the propagation of variables coming from their sub-tasks.
155
- scope_vars.update(sub_scope_vars)
163
+ # Get variables in sub-task scope
164
+ sub_scope_vars = self._vars_in_scope.get(task.id, ordered_set())
165
+
166
+ if isinstance(task, ir.Logical):
167
+ # Logical child nodes should have their nested variables quantified
168
+ # only if they are needed in more than one child task.
169
+ for var in sub_scope_vars:
170
+ if var not in nested_vars_in_task:
171
+ nested_vars_in_task[var] = 0
172
+ nested_vars_in_task[var] += 1
173
+ elif not isinstance(task, ir.Not):
174
+ # Other nodes with nested variables need to be quantified at this level
175
+ scope_vars.update(sub_scope_vars)
176
+
177
+ for v, c in nested_vars_in_task.items():
178
+ # If the variable appears in more than one nested child, then it needs to be
179
+ # quantified here. Otherwise, it will be handled in the child node
180
+ if c > 1:
181
+ all_nested_vars.add(v)
156
182
 
157
183
  # Nested variables also need to be introduced, provided they are not output variables.
158
184
  for var in all_nested_vars:
@@ -190,37 +216,30 @@ class FindQuantificationNodes(Visitor):
190
216
  def __init__(self, var_info: VarScopeInfo):
191
217
  super().__init__()
192
218
  self._vars_in_scope = var_info._vars_in_scope
193
- self._handled_vars = ordered_set()
219
+ self.handled_vars: dict[int, OrderedSet[ir.Var]] = {}
194
220
  self.node_quantifies_vars = {}
195
221
 
196
222
  def enter(self, node: ir.Node, parent: Optional[ir.Node]=None) -> "Visitor":
197
223
  if contains_only_declarable_constraints(node):
198
224
  return self
199
225
 
226
+ handled_vars = self.handled_vars.get(parent.id, ordered_set()) if parent else ordered_set()
227
+ # Clone the set to avoid modifying parent's handled vars
228
+ handled_vars = OrderedSet.from_iterable(handled_vars)
229
+
200
230
  if isinstance(node, (ir.Logical, ir.Not)):
201
231
  ignored_vars = _ignored_vars(node)
202
- self._handled_vars.update(ignored_vars)
232
+ handled_vars.update(ignored_vars)
203
233
 
204
234
  scope_vars = self._vars_in_scope.get(node.id, None)
205
235
  if scope_vars:
206
- scope_vars.difference_update(self._handled_vars)
236
+ scope_vars.difference_update(handled_vars)
207
237
  if scope_vars:
208
- self._handled_vars.update(scope_vars)
238
+ handled_vars.update(scope_vars)
209
239
  self.node_quantifies_vars[node.id] = scope_vars
210
- return self
211
-
212
- def leave(self, node: ir.Node, parent: Optional[ir.Node]=None) -> ir.Node:
213
- if contains_only_declarable_constraints(node):
214
- return node
215
240
 
216
- if isinstance(node, (ir.Logical, ir.Not)):
217
- ignored_vars = _ignored_vars(node)
218
- self._handled_vars.difference_update(ignored_vars)
219
-
220
- scope_vars = self._vars_in_scope.get(node.id, None)
221
- if scope_vars:
222
- self._handled_vars.difference_update(scope_vars)
223
- return node
241
+ self.handled_vars[node.id] = handled_vars
242
+ return self
224
243
 
225
244
  class QuantifyVarsRewriter(Rewriter):
226
245
  """
@@ -254,7 +273,12 @@ class QuantifyVarsRewriter(Rewriter):
254
273
  # in IR directly may do so and the flatten pass doesn't split them yet.
255
274
  if len(agg_or_rank_tasks) > 0:
256
275
  print(f"Multiple aggregate/rank tasks found: {agg_or_rank_tasks} and {task}")
257
- agg_or_rank_tasks.append(task)
276
+ # If the agg/rank depends on any of the vars being quantified here,
277
+ # then it needs to be inside the quantification
278
+ if any(var in helpers.vars(task.projection) for var in vars):
279
+ inner_tasks.append(task)
280
+ else:
281
+ agg_or_rank_tasks.append(task)
258
282
 
259
283
  else:
260
284
  inner_tasks.append(task)
@@ -283,6 +307,16 @@ class QuantifyVarsRewriter(Rewriter):
283
307
 
284
308
  return node if node.task is new_task else f.not_(new_task)
285
309
 
310
def handle_union(self, node: ir.Union, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Union:
    """Rewrite every branch of a Union, rebuilding it only if some branch changed.

    An empty Union is returned untouched. A rebuilt Union keeps the original
    hoisted variables.
    NOTE(review): the rebuilt Union is created via `f.union(tasks, hoisted)` only,
    so any engine/annotations on `node` are not forwarded -- confirm intended.
    """
    if node.tasks:
        rewritten_tasks = self.walk_list(node.tasks, node)
        if rewritten_tasks is not node.tasks:
            return f.union(
                tasks = rewritten_tasks,
                hoisted = node.hoisted,
            )
    return node
319
+
286
320
  # To avoid unnecessary cloning of vars in the visitor.
287
321
  def handle_var(self, node: ir.Var, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Var:
288
322
  return node