relationalai 0.13.4__py3-none-any.whl → 0.13.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. relationalai/clients/exec_txn_poller.py +51 -20
  2. relationalai/clients/local.py +15 -7
  3. relationalai/clients/resources/snowflake/__init__.py +2 -2
  4. relationalai/clients/resources/snowflake/direct_access_resources.py +8 -4
  5. relationalai/clients/resources/snowflake/snowflake.py +16 -11
  6. relationalai/experimental/solvers.py +8 -0
  7. relationalai/semantics/lqp/executor.py +3 -3
  8. relationalai/semantics/lqp/model2lqp.py +34 -28
  9. relationalai/semantics/lqp/passes.py +6 -3
  10. relationalai/semantics/lqp/result_helpers.py +76 -12
  11. relationalai/semantics/lqp/rewrite/__init__.py +2 -0
  12. relationalai/semantics/lqp/rewrite/extract_common.py +3 -1
  13. relationalai/semantics/lqp/rewrite/extract_keys.py +85 -20
  14. relationalai/semantics/lqp/rewrite/flatten_script.py +301 -0
  15. relationalai/semantics/lqp/rewrite/functional_dependencies.py +12 -7
  16. relationalai/semantics/lqp/rewrite/quantify_vars.py +12 -3
  17. relationalai/semantics/lqp/rewrite/unify_definitions.py +9 -3
  18. relationalai/semantics/metamodel/dependency.py +9 -0
  19. relationalai/semantics/metamodel/executor.py +17 -10
  20. relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
  21. relationalai/semantics/metamodel/rewrite/flatten.py +1 -2
  22. relationalai/semantics/metamodel/rewrite/format_outputs.py +131 -46
  23. relationalai/semantics/metamodel/rewrite/handle_aggregations_and_ranks.py +237 -0
  24. relationalai/semantics/metamodel/typer/typer.py +1 -1
  25. relationalai/semantics/reasoners/optimization/solvers_pb.py +101 -107
  26. relationalai/semantics/rel/compiler.py +7 -3
  27. relationalai/semantics/rel/executor.py +1 -1
  28. relationalai/tools/txn_progress.py +188 -0
  29. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/METADATA +1 -1
  30. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/RECORD +33 -30
  31. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/WHEEL +0 -0
  32. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/entry_points.txt +0 -0
  33. {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/licenses/LICENSE +0 -0
relationalai/semantics/lqp/rewrite/extract_keys.py
@@ -15,7 +15,7 @@ Given an Output with a group of keys (some of them potentially null),
 * generate all the valid combinations of keys being present or not
 * first all keys are present,
 * then we remove one key at a time,
-* then we remove two keys at a time,and so on.
+* then we remove two keys at a time, and so on.
 * the last combination is when all the *nullable* keys are missing.
 * for each combination:
 * create a compound (hash) key
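(Aside: the enumeration described above is plain subset generation over the nullable keys. A minimal sketch under that reading; presence_combinations, non_nullable, and nullable are illustrative names, not part of the package.)

from itertools import combinations

def presence_combinations(non_nullable, nullable):
    # Drop 0, 1, 2, ... nullable keys at a time, from "all keys present"
    # down to "all nullable keys missing".
    combos = []
    for k in range(len(nullable) + 1):
        for dropped in combinations(nullable, k):
            present = tuple(key for key in nullable if key not in dropped)
            combos.append(tuple(non_nullable) + present)
    return combos

print(presence_combinations(["id"], ["email", "phone"]))
# [('id', 'email', 'phone'), ('id', 'phone'), ('id', 'email'), ('id',)]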
@@ -103,10 +103,13 @@ Logical
         construct(Hash, "Foo", foo, compound_key)
         output[compound_key](v1, None, None)
 """
+
+
 class ExtractKeys(Pass):
-    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
+    def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
         return ExtractKeysRewriter().walk(model)
 
+
 """
 * First, figure out all tasks that are common for all alternative logicals that will be
   generated
@@ -117,6 +120,8 @@ class ExtractKeys(Pass):
   missing (None will be filtered out in a later step -- we just need the column number to be
   the same here).
 """
+
+
 class ExtractKeysRewriter(Rewriter):
     def __init__(self):
         super().__init__()
@@ -129,7 +134,9 @@ class ExtractKeysRewriter(Rewriter):
         self.compound_keys[orig_keys] = compound_key
         return compound_key
 
-    def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
+    def handle_logical(
+        self, node: ir.Logical, parent: ir.Node, ctx: Optional[Any] = None
+    ) -> ir.Logical:
         outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
         # We are not in a logical with an output at this level.
         if not outputs:
@@ -170,7 +177,9 @@ class ExtractKeysRewriter(Rewriter):
         partitions, deps = self.partition_tasks(flat_body, all_vars)
 
         # Compute all valid key combinations (keys that are not null)
-        combinations = self.key_combinations(nullable_keys, deps, 0, non_nullable_keys.get_list())
+        combinations = self.key_combinations(
+            nullable_keys, deps, 0, non_nullable_keys.get_list()
+        )
         # there is no need to transform if there is only a single combination
         if len(combinations) == 1:
             return node
@@ -212,7 +221,9 @@ class ExtractKeysRewriter(Rewriter):
                 values.append(ir.Literal(types.String, key.type.name))
                 if key in key_combination:
                     values.append(key)
-            body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
+            body.add(
+                ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen())
+            )
 
             # find variables used only inside the negated context
             negative_vars = OrderedSet[ir.Var]()
@@ -233,7 +244,9 @@ class ExtractKeysRewriter(Rewriter):
             problematic_out_vars = OrderedSet[ir.Var]()
             for out_var in out_vars:
                 out_deps = deps[out_var]
-                if out_var in missing_keys:
+                if out_var in negative_vars:
+                    missing_out_vars.add(out_var)
+                elif out_var in missing_keys:
                     missing_out_vars.add(out_var)
                 elif any(x in missing_keys for x in out_deps):
                     missing_out_vars.add(out_var)
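(The classification this hunk extends can be condensed into a standalone toy function; classify_missing and the dict-based deps are hypothetical stand-ins for the OrderedSet-based originals. An output column becomes NULL when its variable is only bound in a negated context, is itself a missing key, or depends on a missing key.)

def classify_missing(out_vars, deps, missing_keys, negative_vars):
    missing = set()
    for v in out_vars:
        if v in negative_vars or v in missing_keys:
            missing.add(v)
        elif any(d in missing_keys for d in deps.get(v, ())):
            missing.add(v)
    return missing

print(classify_missing(
    ["name", "city"],
    {"city": ["country"]},
    missing_keys={"country"},
    negative_vars=set(),
))  # {'city'}: its dependency 'country' is a missing key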
@@ -248,8 +261,17 @@ class ExtractKeysRewriter(Rewriter):
             exclude_vars = out_vars
             has_problematic_var = False
 
-            self.negate_missing_keys(body, missing_keys, var_to_default, partitions, deps,
-                out_vars, exclude_vars, negative_vars, has_problematic_var)
+            self.negate_missing_keys(
+                body,
+                missing_keys,
+                var_to_default,
+                partitions,
+                deps,
+                out_vars,
+                exclude_vars,
+                negative_vars,
+                has_problematic_var,
+            )
 
             new_output_aliases = []
             for alias, out_value in output.aliases:
@@ -269,6 +291,23 @@ class ExtractKeysRewriter(Rewriter):
         # logicals that don't hoist variables are essentially filters like lookups
         if not node.hoisted:
             return True
+
+        # If the body contains an aggregate, and the Logical hoists only the aggregate
+        # output, then this node behaves as a lookup
+        if any(isinstance(t, ir.Aggregate) for t in node.body):
+            hoisted_vars = helpers.hoisted_vars(node.hoisted)
+            aggregate_outputs = []
+            for t in node.body:
+                if isinstance(t, ir.Aggregate):
+                    aggregate_outputs.extend(
+                        v
+                        for v in helpers.vars(t.args)
+                        if not helpers.is_aggregate_input(v, t)
+                    )
+
+            if hoisted_vars == aggregate_outputs:
+                return True
+
         if len(node.body) != 1:
             return False
         inner = node.body[0]
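(The new rule in this hunk can be illustrated self-contained. The dataclasses below are toy stand-ins for ir.Logical and ir.Aggregate, and acts_as_lookup is a hypothetical condensation of the real check, which goes through helpers.hoisted_vars and helpers.is_aggregate_input: a Logical that hoists exactly the outputs of its aggregates only constrains its siblings, so it can be treated like a lookup.)

from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str

@dataclass
class Aggregate:
    inputs: tuple   # aggregated-over variables
    outputs: tuple  # result variables, e.g. a count

@dataclass
class Logical:
    hoisted: tuple
    body: list

def acts_as_lookup(node: Logical) -> bool:
    if not node.hoisted:
        return True
    agg_outputs = [v for t in node.body if isinstance(t, Aggregate) for v in t.outputs]
    return list(node.hoisted) == agg_outputs

n, c = Var("n"), Var("c")
print(acts_as_lookup(Logical(hoisted=(c,), body=[Aggregate((n,), (c,))])))  # True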
@@ -347,9 +386,9 @@ class ExtractKeysRewriter(Rewriter):
 
     # given a set of variables, compute the tasks that each variable is using and also
     # other variables needed for this variable to bind correctly
-    def partition_tasks(self, tasks:Iterable[ir.Task], vars:Iterable[ir.Var]):
-        partitions:dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
-        dependencies:dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
+    def partition_tasks(self, tasks: Iterable[ir.Task], vars: Iterable[ir.Var]):
+        partitions: dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
+        dependencies: dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
 
         def dfs_collect_deps(task, deps):
             if isinstance(task, ir.Lookup):
@@ -360,7 +399,7 @@ class ExtractKeysRewriter(Rewriter):
                         deps[v].add(args[j])
                     # for ternary+ lookups, a var also depends on the next vars
                     if i > 0 and len(args) >= 3:
-                        for j in range(i+1, len(args)):
+                        for j in range(i + 1, len(args)):
                             deps[v].add(args[j])
             elif isinstance(task, ir.Construct):
                 vars = helpers.vars(task.values)
@@ -436,11 +475,21 @@ class ExtractKeysRewriter(Rewriter):
         return partitions, dependencies
 
     # Generate all the valid combinations of non-nullable keys and nullable keys.
-    def key_combinations(self, nullable_keys: OrderedSet[ir.Var], key_deps, idx: int, non_null_keys: list[ir.Var]) -> OrderedSet[Tuple[ir.Var]]:
+    def key_combinations(
+        self,
+        nullable_keys: OrderedSet[ir.Var],
+        key_deps,
+        idx: int,
+        non_null_keys: list[ir.Var],
+    ) -> OrderedSet[Tuple[ir.Var]]:
         if idx < len(nullable_keys):
             key = nullable_keys[idx]
-            set1 = self.key_combinations(nullable_keys, key_deps, idx + 1, non_null_keys + [key])
-            set2 = self.key_combinations(nullable_keys, key_deps, idx + 1, non_null_keys)
+            set1 = self.key_combinations(
+                nullable_keys, key_deps, idx + 1, non_null_keys + [key]
+            )
+            set2 = self.key_combinations(
+                nullable_keys, key_deps, idx + 1, non_null_keys
+            )
             set1.update(set2)
             return set1
         else:
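(The reformatted recursion above is the classic include/exclude scheme: each nullable key either joins the combination, set1, or is skipped, set2. A minimal standalone sketch with plain lists and sets in place of OrderedSet:)

def key_combinations(nullable, idx, chosen):
    if idx < len(nullable):
        with_key = key_combinations(nullable, idx + 1, chosen + [nullable[idx]])
        without_key = key_combinations(nullable, idx + 1, chosen)
        return with_key | without_key
    return {tuple(chosen)}

print(sorted(key_combinations(["email", "phone"], 0, ["id"]),
             key=lambda t: (-len(t), t)))
# [('id', 'email', 'phone'), ('id', 'email'), ('id', 'phone'), ('id',)]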
@@ -449,13 +498,25 @@ class ExtractKeysRewriter(Rewriter):
                 # If a key depends on other keys, all of them should be present in this combination.
                 # If some dependency is not present, ignore the current key.
                 deps = key_deps.get(k)
-                if deps and any(dk in nullable_keys and dk not in non_null_keys for dk in deps):
+                if deps and any(
+                    dk in nullable_keys and dk not in non_null_keys for dk in deps
+                ):
                     continue
                 final_keys.append(k)
             return OrderedSet.from_iterable([tuple(final_keys)])
 
-    def negate_missing_keys(self, body, missing_keys, var_to_default, partitions, deps,
-            out_vars, exclude_vars, negative_vars, has_problematic_var:bool):
+    def negate_missing_keys(
+        self,
+        body,
+        missing_keys,
+        var_to_default,
+        partitions,
+        deps,
+        out_vars,
+        exclude_vars,
+        negative_vars,
+        has_problematic_var: bool,
+    ):
         # for keys that are not present in the current combination
         # we have to include their tasks negated
         negated_tasks = OrderedSet[ir.Task]()
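(The dependency filter at the top of this hunk drops a key whose nullable dependency is absent from the combination, instead of emitting an invalid combination. A sketch of that rule; filter_by_deps is a hypothetical standalone function, whereas the real code filters inline inside key_combinations:)

def filter_by_deps(combo, nullable, key_deps):
    kept = []
    for k in combo:
        deps = key_deps.get(k, ())
        if any(d in nullable and d not in combo for d in deps):
            continue  # a nullable dependency is missing: ignore this key
        kept.append(k)
    return tuple(kept)

# "city" depends on "country", so it only survives when "country" is present:
print(filter_by_deps(("id", "city"), {"city", "country"}, {"city": ["country"]}))
# ('id',)
print(filter_by_deps(("id", "country", "city"), {"city", "country"}, {"city": ["country"]}))
# ('id', 'country', 'city')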
@@ -492,7 +553,9 @@ class ExtractKeysRewriter(Rewriter):
             out_deps = deps[out_var]
             if has_problematic_var and any(x in missing_keys for x in out_deps):
                 continue
-            elif not has_problematic_var and any(x in missing_keys or x in negative_vars for x in out_deps):
+            elif not has_problematic_var and any(
+                x in missing_keys or x in negative_vars for x in out_deps
+            ):
                 continue
 
             default = var_to_default.get(out_var)
@@ -509,4 +572,6 @@ class ExtractKeysRewriter(Rewriter):
             else:
                 property_body.update(partition)
             if property_body:
-                body.add(f.logical(tuple(property_body), [default] if default else []))
+                body.add(
+                    f.logical(tuple(property_body), [default] if default else [out_var])
+                )
relationalai/semantics/lqp/rewrite/flatten_script.py (new file)
@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+from relationalai.semantics.metamodel import ir, factory as f, helpers
+from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
+from relationalai.semantics.metamodel.rewrite.flatten import Flatten, negate, extend_body
+from relationalai.semantics.lqp.algorithms import is_script, mk_assign
+
+class FlattenScript(Flatten):
+    """
+    Flattens Match nodes inside @script Sequence blocks. This pass extends Flatten
+    of standard Logicals, to reuse a number of utilities, but DOES NOT flatten Match
+    nodes outside of scripts (which are handled by the Flatten pass).
+
+    Unlike the regular Flatten pass which extracts to top-level, FlattenScript
+    maintains order by inserting intermediate relations right before they're used
+    within the Sequence. This is necessary because order matters in a script Sequence.
+
+    Additionally, if the original Logical is a Loopy instruction, the flattened Logicals are
+    made Loopy too (`@assign`).
+
+    Example:
+
+    === BEFORE ===
+    Logical
+        Sequence @script @algorithm
+            Logical
+                dom(n)
+                Match ⇑[k]
+                    Logical ⇑[k]
+                        value(n, k)
+                    Logical ⇑[k]
+                        k = 0
+                → derive result(n, k) @assign @global
+                filter(n)
+
+    === AFTER ===
+    Logical
+        Sequence @script @algorithm
+            Logical
+                dom(n)
+                filter(n)
+                Logical ⇑[v]
+                    value(n, v)
+                → derive _match_1(n, v) @assign
+            Logical
+                dom(n)
+                filter(n)
+                Logical ⇑[v]
+                    v = 0
+                Not
+                    _match_1(n, _)
+                → derive _match_2(n, v) @assign
+            Logical
+                dom(n)
+                filter(n)
+                Union ⇑[v]
+                    _match_1(n, v)
+                    _match_2(n, v)
+                → derive result(n, v) @assign @global
+    """
+
+    class Context(Flatten.Context):
+        """Extended context with script tracking."""
+        def __init__(self, model: ir.Model, options: dict):
+            super().__init__(model, options)
+            self.in_script: bool = False
+
+    def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
+        """Traverse the model and flatten Match nodes inside script Sequences."""
+        ctx = FlattenScript.Context(model, options)
+        result = self.handle(model.root, ctx)
+
+        if result.replacement is None:
+            return model
+
+        # Convert relations list to FrozenOrderedSet (adding any new intermediate relations)
+        new_relations = OrderedSet.from_iterable(model.relations).update(ctx.rewrite_ctx.relations).frozen()
+
+        return ir.Model(
+            model.engines,
+            new_relations,
+            model.types,
+            result.replacement
+        )
+
+    def handle(self, task: ir.Task, ctx: Flatten.Context) -> Flatten.HandleResult:
+        """Override handle to add Loop support."""
+        if isinstance(task, ir.Loop):
+            return self.handle_loop(task, ctx)
+        return super().handle(task, ctx)
+
+    def handle_loop(self, task: ir.Loop, ctx: Flatten.Context) -> Flatten.HandleResult:
+        """Recursively handle the body of the loop."""
+        result = self.handle(task.body, ctx)
+
+        assert(result.replacement)
+
+        # If body unchanged, return original loop
+        if result.replacement is task.body:
+            return Flatten.HandleResult(task)
+
+        # Return new loop with handled body
+        return Flatten.HandleResult(ir.Loop(
+            task.engine,
+            task.hoisted,
+            task.iter,
+            result.replacement,
+            task.concurrency,
+            task.annotations
+        ))
+
+    def handle_logical(self, task: ir.Logical, ctx: Context):  # type: ignore[override]
+        """
+        Handle Logical nodes.
+
+        Outside scripts: simple traversal to find nested scripts.
+        Inside scripts: prevent extraction to top-level and keep everything in sequence.
+
+        Note: Type checker complains about parameter type narrowing, but this is safe
+        because we only create FlattenScript.Context in our own rewrite() method.
+        """
+        # Recursively process children
+        body: OrderedSet[ir.Task] = ordered_set()
+        for child in task.body:
+            result = self.handle(child, ctx)
+            if result.replacement is not None:
+                if ctx.in_script and isinstance(result.replacement, ir.Logical) and not result.replacement.hoisted:
+                    # Inside script: inline simple logicals without hoisting
+                    body.update(result.replacement.body)
+                else:
+                    body.add(result.replacement)
+
+        if not body:
+            return Flatten.HandleResult(None)
+
+        return Flatten.HandleResult(
+            ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
+        )
+
+    def flatten_match_in_logical(self, logical: ir.Logical, match: ir.Match, match_idx: int, ctx: Context) -> list[ir.Logical]:
+        """
+        Flatten a Match inside a Logical within a script Sequence.
+        Returns a list of Logicals to be inserted in sequence.
+        """
+        if not match.tasks:
+            return [logical]
+
+        # Split the logical into: tasks_before, match, tasks_after
+        tasks_before = list(logical.body[:match_idx])
+        tasks_after = list(logical.body[match_idx + 1:])
+
+        # Separate tasks_after into filters (non-Update tasks) and updates (Update tasks)
+        # Filters are constraints that should be included in all branches
+        filters = [task for task in tasks_after if not isinstance(task, ir.Update)]
+        updates = [task for task in tasks_after if isinstance(task, ir.Update)]
+
+        # Compute exposed variables
+        exposed_vars = self.compute_exposed_vars(match, match.tasks, ctx)
+
+        # Use dependency analysis for branch bodies (like flatten.py does)
+        branch_dependencies = ctx.info.task_dependencies(match)
+
+        # Collect all dependencies for the final Logical (tasks before + filters after)
+        final_dependencies = tasks_before + filters
+
+        # Negation length for wildcards
+        outputs = ctx.info.task_outputs(match)
+        negation_len = len(outputs) if outputs else 0
+
+        # Result: list of Logicals to insert
+        result_logicals = []
+        references = []
+        negated_reference = None
+
+        # Process each branch
+        for branch in match.tasks:
+            # Create connection relation for this branch
+            name = helpers.create_task_name(self.name_cache, branch, "_match")
+            relation = helpers.create_connection_relation(branch, exposed_vars, ctx.rewrite_ctx, name)
+
+            # Handle the branch (recursively process nested structures)
+            result = self.handle(branch, ctx)
+            branch_content = result.replacement if result.replacement else branch
+
+            # Update dependency tracking if branch was transformed
+            if result.replacement:
+                ctx.info.replaced(branch, result.replacement)
+
+            # Build logical for this branch
+            branch_body: OrderedSet[ir.Task] = ordered_set()
+
+            # Add dependencies (using dependency analysis, not all tasks)
+            branch_body.update(branch_dependencies)
+
+            # Add branch content using extend_body helper
+            extend_body(branch_body, branch_content)
+
+            # Add negation of previous branches (after branch content)
+            if negated_reference:
+                branch_body.add(negated_reference)
+
+            # Add derive to connection relation
+            branch_update = f.derive(relation, exposed_vars)
+            branch_body.add(branch_update)
+
+            # Create the Logical for this branch
+            branch_logical = mk_assign(ir.Logical(match.engine, tuple(), tuple(branch_body)))
+
+            result_logicals.append(branch_logical)
+
+            # Update references for final union
+            reference = f.lookup(relation, exposed_vars)
+            negated_reference = negate(reference, negation_len)
+            references.append(reference)
+
+        # Create final Logical with Union and remaining tasks
+        final_body: OrderedSet[ir.Task] = ordered_set()
+        final_body.update(final_dependencies)
+
+        # Add union of all branches
+        union = f.union(references, match.hoisted)
+        final_body.add(union)
+
+        # Add updates that came after the match (filters are already in dependencies)
+        final_body.update(updates)
+
+        # Create final logical preserving the original annotations
+        final_logical = ir.Logical(logical.engine, logical.hoisted, tuple(final_body), logical.annotations)
+
+        result_logicals.append(final_logical)
+
+        return result_logicals
+
+    def handle_sequence(self, task: ir.Sequence, ctx: Context):  # type: ignore[override]
+        """
+        Handle a Sequence.
+
+        If it's a script: set context flag and flatten Match nodes within.
+        If not a script: simple traversal (Flatten pass already processed it).
+
+        Note: Type checker complains about parameter type narrowing, but this is safe
+        because we only create FlattenScript.Context in our own rewrite() method.
+        """
+        if not is_script(task):
+            # Not a script sequence - already processed by Flatten, just return as-is
+            return Flatten.HandleResult(task)
+
+        # This is a script - mark context and process with flattening
+        old_in_script = ctx.in_script
+        ctx.in_script = True
+
+        # Process the sequence tasks
+        new_tasks: list[ir.Task] = []
+        for child in task.tasks:
+            # Check if this child is a Logical with Match that needs flattening
+            if isinstance(child, ir.Logical):
+                new_tasks.extend(self.try_flatten_logical(child, ctx))
+                continue
+
+            # No flattening needed, process normally
+            result = self.handle(child, ctx)
+            if result.replacement is not None:
+                new_tasks.append(result.replacement)
+
+        # Restore context
+        ctx.in_script = old_in_script
+
+        return Flatten.HandleResult(
+            ir.Sequence(task.engine, task.hoisted, tuple(new_tasks), task.annotations)
+        )
+
+    def try_flatten_logical(self, logical: ir.Logical, ctx: Context) -> list[ir.Logical]:
+        """
+        Flatten all Matches in a Logical.
+        Iteratively flattens until no more Matches remain in any of the resulting Logicals.
+        """
+        worklist = [logical]
+        result = []
+
+        while worklist:
+            current = worklist.pop()
+
+            # Find first Match in current logical
+            match = None
+            match_idx = -1
+            for i, child in enumerate(current.body):
+                if isinstance(child, ir.Match):
+                    match = child
+                    match_idx = i
+                    break
+
+            if match is None:
+                # No Match found - this logical is done
+                result.append(current)
+            else:
+                # Flatten and add results back to worklist for further processing.
+                # Reverse so that pop() returns them in the original order.
+                flattened = self.flatten_match_in_logical(current, match, match_idx, ctx)
+                worklist.extend(reversed(flattened))
+
+        return result
relationalai/semantics/lqp/rewrite/functional_dependencies.py
@@ -322,22 +322,27 @@ class FunctionalDependency:
 
 def contains_only_declarable_constraints(node: Node) -> bool:
     """
-    Checks whether the input `Logical` node contains only `Require` nodes annotated with
-    `declare_constraint`.
+    Checks whether the input node contains only `Require` nodes annotated with
+    `declare_constraint` (or such a node itself).
     """
+    # Check if the node itself is a Require node with declarable constraint
+    if isinstance(node, Require):
+        return is_declarable_constraint(node)
+
+    # Otherwise, check if it is a Logical node containing only declarable constraints
     if not isinstance(node, Logical):
         return False
     if len(node.body) == 0:
         return False
     for task in node.body:
-        if not isinstance(task, Require):
-            return False
-        if not is_declarable_constraint(task):
+        if not contains_only_declarable_constraints(task):
             return False
     return True
 
-def is_declarable_constraint(node: Require) -> bool:
+def is_declarable_constraint(node: Node) -> bool:
     """
-    Checks whether the input `Require` node is annotated with `declare_constraint`.
+    Checks whether the input node is a `Require` node annotated with `declare_constraint`.
     """
+    if not isinstance(node, Require):
+        return False
     return builtins.declare_constraint_annotation in node.annotations
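(What the change buys: the check is now recursive, so nested Logicals of declarable Requires are accepted where previously only a flat Logical of Requires passed. A toy illustration; Require and Logical below are simplified stand-ins for the metamodel nodes, with a boolean in place of the annotation lookup:)

from dataclasses import dataclass

@dataclass
class Require:
    declared: bool  # stands in for the declare_constraint annotation

@dataclass
class Logical:
    body: tuple

def only_declarable(node) -> bool:
    if isinstance(node, Require):
        return node.declared
    if not isinstance(node, Logical) or not node.body:
        return False
    return all(only_declarable(t) for t in node.body)

print(only_declarable(Logical((Require(True), Logical((Require(True),))))))  # True
print(only_declarable(Logical((Require(True), Require(False)))))             # False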
relationalai/semantics/lqp/rewrite/quantify_vars.py
@@ -45,7 +45,7 @@ def _ignored_vars(node: ir.Logical|ir.Not):
 
         elif isinstance(task, ir.Rank):
             # Variables that are keys, and not in the group-by, don't need to be quantified.
-            for var in task.args + task.projection:
+            for var in task.args:
                 if var not in task.group:
                     vars_to_ignore.add(var)
 
@@ -149,13 +149,22 @@ class VarScopeInfo(Visitor):
         if isinstance(task, ir.Output):
             output_vars.update(helpers.output_vars(task.aliases))
 
-        if isinstance(task, (ir.Aggregate, ir.Rank)):
-            # Variables that are in the group-by, and not in the projections, can come into scope.
+        if isinstance(task, ir.Aggregate):
+            # Variables that are in the group-by, and not in the args,
+            # can come into scope.
             for var in task.group:
                 if var not in task.args:
                     scope_vars.add(var)
             continue
 
+        if isinstance(task, ir.Rank):
+            # Variables that are in the group-by or projection, and not in the args,
+            # can come into scope.
+            for var in task.group + task.projection:
+                if var not in task.args:
+                    scope_vars.add(var)
+            continue
+
         # Hoisted variables from sub-tasks are brought again into scope.
         if isinstance(task, (ir.Logical, ir.Union, ir.Match)):
             scope_vars.update(helpers.hoisted_vars(task.hoisted))
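(The split above boils down to which variable pool feeds the scope: for Rank the projection now counts alongside the group-by, for Aggregate it does not. A hypothetical standalone condensation of that rule, not code from the package:)

def scope_vars_from(task_kind, group, projection, args):
    # Rank brings group-by and projection variables into scope; Aggregate
    # only the group-by. Variables already in args stay out.
    pool = group + projection if task_kind == "rank" else group
    return [v for v in pool if v not in args]

print(scope_vars_from("aggregate", group=["g"], projection=["p"], args=["a"]))  # ['g']
print(scope_vars_from("rank",      group=["g"], projection=["p"], args=["a"]))  # ['g', 'p']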
relationalai/semantics/lqp/rewrite/unify_definitions.py
@@ -54,9 +54,9 @@ class UnifyDefinitions(Pass):
             return head.relation
         else:
             assert isinstance(head, ir.Output)
-            if len(head.aliases) <= 2:
-                # For processing here, we need output to have at least the column markers
-                # `cols` and `col`, and also a key
+            if len(head.aliases) < 2:
+                # For processing here, we need output to have at least the column marker
+                # `keys` or both `cols` and `col` markers, and also a key
                 return None
 
             output_alias_names = helpers.output_alias_names(head.aliases)
@@ -65,6 +65,8 @@ class UnifyDefinitions(Pass):
         # For normal outputs, the pattern is output[keys](cols, "col000" as 'col', ...)
         if output_alias_names[0] == "cols" and output_alias_names[1] == "col":
             return output_vals[1]
+        if output_alias_names[0] == "keys":  # handle wide keys relation
+            return output_vals[0]
 
         # For exports, the pattern is output[keys]("col000" as 'col', ...)
         if helpers.is_export(head):
@@ -122,6 +124,10 @@ class UnifyDefinitions(Pass):
             # keys.
             output_values = helpers.output_values(head.aliases)[2:]
 
+        elif output_alias_names[0] == "keys":  # handle wide keys output
+            assert len(head.aliases) > 1
+            output_values = helpers.output_values(head.aliases)[1:]
+
         else:
             assert helpers.is_export(head) and output_alias_names[0] == "col"
             assert len(head.aliases) > 1
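(Both hunks extend the same alias-marker dispatch: the leading alias names decide where the key lives among the output values. A sketch of that dispatch; output_key_value is a hypothetical standalone function over plain lists, and the real code also handles the export case:)

def output_key_value(alias_names, values):
    if alias_names[:2] == ["cols", "col"]:
        return values[1]   # output[keys](cols, "col000" as 'col', ...)
    if alias_names[0] == "keys":
        return values[0]   # wide keys relation: the key comes first
    return None

print(output_key_value(["cols", "col"], ["c", "k"]))  # 'k'
print(output_key_value(["keys", "v"], ["k", "v"]))    # 'k'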
relationalai/semantics/metamodel/dependency.py
@@ -500,6 +500,15 @@ class DependencyAnalysis(visitor.Visitor):
                 assert(isinstance(c2task, helpers.COMPOSITES))
                 if not c2task.hoisted:
                     return True
+
+            # c1 is a composite with hoisted variables; it depends on c2 if c2 is mergeable
+            # and they share variables, hence behaving like a filter
+            if c1.composite and c2.mergeable and c1.shares_variable(c2):
+                c1task = c1.content.some()
+                assert(isinstance(c1task, helpers.COMPOSITES))
+                if c1task.hoisted:
+                    return True
+
             return False
 
         cs = list(clusters)
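(The new ordering rule: a hoisting composite must be scheduled after a mergeable cluster it shares variables with, because that cluster acts as a filter on the composite's inputs. A toy model; Cluster below is a simplified stand-in for the DependencyAnalysis clusters, with a vars set replacing shares_variable:)

from dataclasses import dataclass

@dataclass
class Cluster:
    composite: bool
    mergeable: bool
    hoisted: bool
    vars: frozenset

def depends_on(c1: Cluster, c2: Cluster) -> bool:
    return (c1.composite and c2.mergeable and c1.hoisted
            and bool(c1.vars & c2.vars))

agg = Cluster(composite=True, mergeable=False, hoisted=True, vars=frozenset({"n", "v"}))
flt = Cluster(composite=False, mergeable=True, hoisted=False, vars=frozenset({"n"}))
print(depends_on(agg, flt))  # True: the filter must run before the composite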