relationalai 0.13.4__py3-none-any.whl → 0.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/exec_txn_poller.py +51 -20
- relationalai/clients/local.py +15 -7
- relationalai/clients/resources/snowflake/__init__.py +2 -2
- relationalai/clients/resources/snowflake/direct_access_resources.py +8 -4
- relationalai/clients/resources/snowflake/snowflake.py +16 -11
- relationalai/experimental/solvers.py +8 -0
- relationalai/semantics/lqp/executor.py +3 -3
- relationalai/semantics/lqp/model2lqp.py +34 -28
- relationalai/semantics/lqp/passes.py +6 -3
- relationalai/semantics/lqp/result_helpers.py +76 -12
- relationalai/semantics/lqp/rewrite/__init__.py +2 -0
- relationalai/semantics/lqp/rewrite/extract_common.py +3 -1
- relationalai/semantics/lqp/rewrite/extract_keys.py +85 -20
- relationalai/semantics/lqp/rewrite/flatten_script.py +301 -0
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +12 -7
- relationalai/semantics/lqp/rewrite/quantify_vars.py +12 -3
- relationalai/semantics/lqp/rewrite/unify_definitions.py +9 -3
- relationalai/semantics/metamodel/dependency.py +9 -0
- relationalai/semantics/metamodel/executor.py +17 -10
- relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
- relationalai/semantics/metamodel/rewrite/flatten.py +1 -2
- relationalai/semantics/metamodel/rewrite/format_outputs.py +131 -46
- relationalai/semantics/metamodel/rewrite/handle_aggregations_and_ranks.py +237 -0
- relationalai/semantics/metamodel/typer/typer.py +1 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +101 -107
- relationalai/semantics/rel/compiler.py +7 -3
- relationalai/semantics/rel/executor.py +1 -1
- relationalai/tools/txn_progress.py +188 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/METADATA +1 -1
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/RECORD +33 -30
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/WHEEL +0 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/entry_points.txt +0 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/licenses/LICENSE +0 -0
relationalai/semantics/lqp/rewrite/extract_keys.py

@@ -15,7 +15,7 @@ Given an Output with a group of keys (some of them potentially null),
 * generate all the valid combinations of keys being present or not
 * first all keys are present,
 * then we remove one key at a time,
-* then we remove two keys at a time,and so on.
+* then we remove two keys at a time, and so on.
 * the last combination is when all the *nullable* keys are missing.
 * for each combination:
 * create a compound (hash) key
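
The enumeration the docstring describes is a powerset walk over the nullable keys, ordered from all keys present down to all nullable keys missing. A minimal standalone sketch of that ordering (hypothetical helper, not the package's code):

    from itertools import combinations

    def key_presence_combinations(non_nullable, nullable):
        # Non-nullable keys appear in every combination; nullable keys are
        # dropped zero at a time, then one at a time, then two, and so on.
        for r in range(len(nullable), -1, -1):
            for present in combinations(nullable, r):
                yield tuple(non_nullable) + present

    # key_presence_combinations(["id"], ["a", "b"]) yields
    # ("id", "a", "b"), ("id", "a"), ("id", "b"), ("id",)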
@@ -103,10 +103,13 @@ Logical
     construct(Hash, "Foo", foo, compound_key)
     output[compound_key](v1, None, None)
 """
+
+
 class ExtractKeys(Pass):
-    def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
+    def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
         return ExtractKeysRewriter().walk(model)
 
+
 """
 * First, figure out all tasks that are common for all alternative logicals that will be
   generated
@@ -117,6 +120,8 @@ class ExtractKeys(Pass):
   missing (None will be filtered out in a later step -- we just need the column number to be
   the same here).
 """
+
+
 class ExtractKeysRewriter(Rewriter):
     def __init__(self):
         super().__init__()
@@ -129,7 +134,9 @@ class ExtractKeysRewriter(Rewriter):
             self.compound_keys[orig_keys] = compound_key
         return compound_key
 
-    def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx: Optional[Any] = None) -> ir.Logical:
+    def handle_logical(
+        self, node: ir.Logical, parent: ir.Node, ctx: Optional[Any] = None
+    ) -> ir.Logical:
         outputs = [x for x in node.body if isinstance(x, ir.Output) and x.keys]
         # We are not in a logical with an output at this level.
         if not outputs:
@@ -170,7 +177,9 @@ class ExtractKeysRewriter(Rewriter):
         partitions, deps = self.partition_tasks(flat_body, all_vars)
 
         # Compute all valid key combinations (keys that are not null)
-        combinations = self.key_combinations(nullable_keys, deps, 0, non_nullable_keys.get_list())
+        combinations = self.key_combinations(
+            nullable_keys, deps, 0, non_nullable_keys.get_list()
+        )
         # there is no need to transform if there is only a single combination
         if len(combinations) == 1:
             return node
@@ -212,7 +221,9 @@ class ExtractKeysRewriter(Rewriter):
                 values.append(ir.Literal(types.String, key.type.name))
                 if key in key_combination:
                     values.append(key)
-            body.add(ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen()))
+            body.add(
+                ir.Construct(None, tuple(values), compound_key, OrderedSet().frozen())
+            )
 
             # find variables used only inside the negated context
             negative_vars = OrderedSet[ir.Var]()
@@ -233,7 +244,9 @@ class ExtractKeysRewriter(Rewriter):
             problematic_out_vars = OrderedSet[ir.Var]()
             for out_var in out_vars:
                 out_deps = deps[out_var]
-                if out_var in missing_keys:
+                if out_var in negative_vars:
+                    missing_out_vars.add(out_var)
+                elif out_var in missing_keys:
                     missing_out_vars.add(out_var)
                 elif any(x in missing_keys for x in out_deps):
                     missing_out_vars.add(out_var)
@@ -248,8 +261,17 @@ class ExtractKeysRewriter(Rewriter):
             exclude_vars = out_vars
             has_problematic_var = False
 
-            self.negate_missing_keys(body, missing_keys, var_to_default, partitions, deps,
-                out_vars, exclude_vars, negative_vars, has_problematic_var)
+            self.negate_missing_keys(
+                body,
+                missing_keys,
+                var_to_default,
+                partitions,
+                deps,
+                out_vars,
+                exclude_vars,
+                negative_vars,
+                has_problematic_var,
+            )
 
             new_output_aliases = []
             for alias, out_value in output.aliases:
@@ -269,6 +291,23 @@ class ExtractKeysRewriter(Rewriter):
         # logicals that don't hoist variables are essentially filters like lookups
         if not node.hoisted:
             return True
+
+        # If the body contains an aggregate, and the Logical hoists only the aggregate
+        # output, then this node behaves as a lookup
+        if any(isinstance(t, ir.Aggregate) for t in node.body):
+            hoisted_vars = helpers.hoisted_vars(node.hoisted)
+            aggregate_outputs = []
+            for t in node.body:
+                if isinstance(t, ir.Aggregate):
+                    aggregate_outputs.extend(
+                        v
+                        for v in helpers.vars(t.args)
+                        if not helpers.is_aggregate_input(v, t)
+                    )
+
+            if hoisted_vars == aggregate_outputs:
+                return True
+
         if len(node.body) != 1:
             return False
         inner = node.body[0]
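
The added check treats a Logical as a lookup when its hoisted variables are exactly the outputs of the aggregates in its body. A toy restatement of the rule (stand-in class, not the package's IR):

    from dataclasses import dataclass, field

    @dataclass
    class Aggregate:  # stand-in for ir.Aggregate
        args: list
        inputs: set = field(default_factory=set)

    def hoists_only_aggregate_outputs(hoisted_vars, body):
        # Collect every aggregate output (args that are not inputs); if the
        # hoisted variables are exactly those, the node only binds them.
        outs = []
        for t in body:
            if isinstance(t, Aggregate):
                outs.extend(v for v in t.args if v not in t.inputs)
        return hoisted_vars == outs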
@@ -347,9 +386,9 @@ class ExtractKeysRewriter(Rewriter):
 
     # given a set of variables, compute the tasks that each variable is using and also
     # other variables needed for this variable to bind correctly
-    def partition_tasks(self, tasks:Iterable[ir.Task], vars:Iterable[ir.Var]):
-        partitions:dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
-        dependencies:dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
+    def partition_tasks(self, tasks: Iterable[ir.Task], vars: Iterable[ir.Var]):
+        partitions: dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
+        dependencies: dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
 
         def dfs_collect_deps(task, deps):
             if isinstance(task, ir.Lookup):
@@ -360,7 +399,7 @@ class ExtractKeysRewriter(Rewriter):
                     deps[v].add(args[j])
                     # for ternary+ lookups, a var also depends on the next vars
                     if i > 0 and len(args) >= 3:
-                        for j in range(i+1, len(args)):
+                        for j in range(i + 1, len(args)):
                             deps[v].add(args[j])
             elif isinstance(task, ir.Construct):
                 vars = helpers.vars(task.values)
@@ -436,11 +475,21 @@ class ExtractKeysRewriter(Rewriter):
         return partitions, dependencies
 
     # Generate all the valid combinations of non-nullable keys and nullable keys.
-    def key_combinations(self, nullable_keys:OrderedSet[ir.Var], key_deps, idx:int, non_null_keys:list[ir.Var]) -> OrderedSet[Tuple[ir.Var]]:
+    def key_combinations(
+        self,
+        nullable_keys: OrderedSet[ir.Var],
+        key_deps,
+        idx: int,
+        non_null_keys: list[ir.Var],
+    ) -> OrderedSet[Tuple[ir.Var]]:
         if idx < len(nullable_keys):
             key = nullable_keys[idx]
-            set1 = self.key_combinations(nullable_keys, key_deps, idx+1, non_null_keys + [key])
-            set2 = self.key_combinations(nullable_keys, key_deps, idx+1, non_null_keys)
+            set1 = self.key_combinations(
+                nullable_keys, key_deps, idx + 1, non_null_keys + [key]
+            )
+            set2 = self.key_combinations(
+                nullable_keys, key_deps, idx + 1, non_null_keys
+            )
             set1.update(set2)
             return set1
         else:
@@ -449,13 +498,25 @@ class ExtractKeysRewriter(Rewriter):
             # If a key depends on other keys, all of them should be present in this combination.
             # If some dependency is not present, ignore the current key.
             deps = key_deps.get(k)
-            if deps and any(dk in nullable_keys and dk not in non_null_keys for dk in deps):
+            if deps and any(
+                dk in nullable_keys and dk not in non_null_keys for dk in deps
+            ):
                 continue
             final_keys.append(k)
         return OrderedSet.from_iterable([tuple(final_keys)])
 
-    def negate_missing_keys(self, body, missing_keys, var_to_default, partitions, deps,
-            out_vars, exclude_vars, negative_vars, has_problematic_var:bool):
+    def negate_missing_keys(
+        self,
+        body,
+        missing_keys,
+        var_to_default,
+        partitions,
+        deps,
+        out_vars,
+        exclude_vars,
+        negative_vars,
+        has_problematic_var: bool,
+    ):
         # for keys that are not present in the current combination
         # we have to include their tasks negated
         negated_tasks = OrderedSet[ir.Task]()
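
The deps check inside key_combinations drops a key from a combination when any nullable key it depends on is absent. In isolation the rule looks like this (hypothetical standalone function):

    def filter_by_deps(combination, key_deps, nullable_keys):
        # Keep a key only if every nullable key it depends on is also present.
        kept = []
        for k in combination:
            deps = key_deps.get(k)
            if deps and any(d in nullable_keys and d not in combination for d in deps):
                continue
            kept.append(k)
        return tuple(kept)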
@@ -492,7 +553,9 @@ class ExtractKeysRewriter(Rewriter):
             out_deps = deps[out_var]
             if has_problematic_var and any(x in missing_keys for x in out_deps):
                 continue
-            elif not has_problematic_var and any(x in missing_keys for x in out_deps):
+            elif not has_problematic_var and any(
+                x in missing_keys or x in negative_vars for x in out_deps
+            ):
                 continue
 
             default = var_to_default.get(out_var)
@@ -509,4 +572,6 @@ class ExtractKeysRewriter(Rewriter):
             else:
                 property_body.update(partition)
             if property_body:
-                body.add(f.logical(tuple(property_body), [default] if default else [out_var]))
+                body.add(
+                    f.logical(tuple(property_body), [default] if default else [out_var])
+                )
relationalai/semantics/lqp/rewrite/flatten_script.py (new file)

@@ -0,0 +1,301 @@
+from __future__ import annotations
+
+from relationalai.semantics.metamodel import ir, factory as f, helpers
+from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
+from relationalai.semantics.metamodel.rewrite.flatten import Flatten, negate, extend_body
+from relationalai.semantics.lqp.algorithms import is_script, mk_assign
+
+class FlattenScript(Flatten):
+    """
+    Flattens Match nodes inside @script Sequence blocks. This pass extends Flatten
+    of standard Logicals, to reuse a number of utilities, but DOES NOT flatten Match
+    nodes outside of scripts (which are handled by the Flatten pass).
+
+    Unlike the regular Flatten pass which extracts to top-level, FlattenScript
+    maintains order by inserting intermediate relations right before they're used
+    within the Sequence. This is necessary because order matters in a script Sequence.
+
+    Additionally, if the original Logical is a Loopy instruction, the flattened Logicals are
+    made Loopy too (`@assign`).
+
+    Example:
+
+    === BEFORE ===
+    Logical
+        Sequence @script @algorithm
+            Logical
+                dom(n)
+                Match ⇑[k]
+                    Logical ⇑[k]
+                        value(n, k)
+                    Logical ⇑[k]
+                        k = 0
+                → derive result(n, k) @assign @global
+                filter(n)
+
+    === AFTER ===
+    Logical
+        Sequence @script @algorithm
+            Logical
+                dom(n)
+                filter(n)
+                Logical ⇑[v]
+                    value(n, v)
+                → derive _match_1(n, v) @assign
+            Logical
+                dom(n)
+                filter(n)
+                Logical ⇑[v]
+                    v = 0
+                Not
+                    _match_1(n, _)
+                → derive _match_2(n, v) @assign
+            Logical
+                dom(n)
+                filter(n)
+                Union ⇑[v]
+                    _match_1(n, v)
+                    _match_2(n, v)
+                → derive result(n, v) @assign @global
+    """
+
+    class Context(Flatten.Context):
+        """Extended context with script tracking."""
+        def __init__(self, model: ir.Model, options: dict):
+            super().__init__(model, options)
+            self.in_script: bool = False
+
+    def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
+        """Traverse the model and flatten Match nodes inside script Sequences."""
+        ctx = FlattenScript.Context(model, options)
+        result = self.handle(model.root, ctx)
+
+        if result.replacement is None:
+            return model
+
+        # Convert relations list to FrozenOrderedSet (adding any new intermediate relations)
+        new_relations = OrderedSet.from_iterable(model.relations).update(ctx.rewrite_ctx.relations).frozen()
+
+        return ir.Model(
+            model.engines,
+            new_relations,
+            model.types,
+            result.replacement
+        )
+
+    def handle(self, task: ir.Task, ctx: Flatten.Context) -> Flatten.HandleResult:
+        """Override handle to add Loop support."""
+        if isinstance(task, ir.Loop):
+            return self.handle_loop(task, ctx)
+        return super().handle(task, ctx)
+
+    def handle_loop(self, task: ir.Loop, ctx: Flatten.Context) -> Flatten.HandleResult:
+        """Recursively handle the body of the loop."""
+        result = self.handle(task.body, ctx)
+
+        assert(result.replacement)
+
+        # If body unchanged, return original loop
+        if result.replacement is task.body:
+            return Flatten.HandleResult(task)
+
+        # Return new loop with handled body
+        return Flatten.HandleResult(ir.Loop(
+            task.engine,
+            task.hoisted,
+            task.iter,
+            result.replacement,
+            task.concurrency,
+            task.annotations
+        ))
+
+    def handle_logical(self, task: ir.Logical, ctx: Context):  # type: ignore[override]
+        """
+        Handle Logical nodes.
+
+        Outside scripts: simple traversal to find nested scripts.
+        Inside scripts: prevent extraction to top-level and keep everything in sequence.
+
+        Note: Type checker complains about parameter type narrowing, but this is safe
+        because we only create FlattenScript.Context in our own rewrite() method.
+        """
+        # Recursively process children
+        body: OrderedSet[ir.Task] = ordered_set()
+        for child in task.body:
+            result = self.handle(child, ctx)
+            if result.replacement is not None:
+                if ctx.in_script and isinstance(result.replacement, ir.Logical) and not result.replacement.hoisted:
+                    # Inside script: inline simple logicals without hoisting
+                    body.update(result.replacement.body)
+                else:
+                    body.add(result.replacement)
+
+        if not body:
+            return Flatten.HandleResult(None)
+
+        return Flatten.HandleResult(
+            ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations)
+        )
+
+    def flatten_match_in_logical(self, logical: ir.Logical, match: ir.Match, match_idx: int, ctx: Context) -> list[ir.Logical]:
+        """
+        Flatten a Match inside a Logical within a script Sequence.
+        Returns a list of Logicals to be inserted in sequence.
+        """
+        if not match.tasks:
+            return [logical]
+
+        # Split the logical into: tasks_before, match, tasks_after
+        tasks_before = list(logical.body[:match_idx])
+        tasks_after = list(logical.body[match_idx + 1:])
+
+        # Separate tasks_after into filters (non-Update tasks) and updates (Update tasks)
+        # Filters are constraints that should be included in all branches
+        filters = [task for task in tasks_after if not isinstance(task, ir.Update)]
+        updates = [task for task in tasks_after if isinstance(task, ir.Update)]
+
+        # Compute exposed variables
+        exposed_vars = self.compute_exposed_vars(match, match.tasks, ctx)
+
+        # Use dependency analysis for branch bodies (like flatten.py does)
+        branch_dependencies = ctx.info.task_dependencies(match)
+
+        # Collect all dependencies for the final Logical (tasks before + filters after)
+        final_dependencies = tasks_before + filters
+
+        # Negation length for wildcards
+        outputs = ctx.info.task_outputs(match)
+        negation_len = len(outputs) if outputs else 0
+
+        # Result: list of Logicals to insert
+        result_logicals = []
+        references = []
+        negated_reference = None
+
+        # Process each branch
+        for branch in match.tasks:
+            # Create connection relation for this branch
+            name = helpers.create_task_name(self.name_cache, branch, "_match")
+            relation = helpers.create_connection_relation(branch, exposed_vars, ctx.rewrite_ctx, name)
+
+            # Handle the branch (recursively process nested structures)
+            result = self.handle(branch, ctx)
+            branch_content = result.replacement if result.replacement else branch
+
+            # Update dependency tracking if branch was transformed
+            if result.replacement:
+                ctx.info.replaced(branch, result.replacement)
+
+            # Build logical for this branch
+            branch_body: OrderedSet[ir.Task] = ordered_set()
+
+            # Add dependencies (using dependency analysis, not all tasks)
+            branch_body.update(branch_dependencies)
+
+            # Add branch content using extend_body helper
+            extend_body(branch_body, branch_content)
+
+            # Add negation of previous branches (after branch content)
+            if negated_reference:
+                branch_body.add(negated_reference)
+
+            # Add derive to connection relation
+            branch_update = f.derive(relation, exposed_vars)
+            branch_body.add(branch_update)
+
+            # Create the Logical for this branch
+            branch_logical = mk_assign(ir.Logical(match.engine, tuple(), tuple(branch_body)))
+
+            result_logicals.append(branch_logical)
+
+            # Update references for final union
+            reference = f.lookup(relation, exposed_vars)
+            negated_reference = negate(reference, negation_len)
+            references.append(reference)
+
+        # Create final Logical with Union and remaining tasks
+        final_body: OrderedSet[ir.Task] = ordered_set()
+        final_body.update(final_dependencies)
+
+        # Add union of all branches
+        union = f.union(references, match.hoisted)
+        final_body.add(union)
+
+        # Add updates that came after the match (filters are already in dependencies)
+        final_body.update(updates)
+
+        # Create final logical preserving the original annotations
+        final_logical = ir.Logical(logical.engine, logical.hoisted, tuple(final_body), logical.annotations)
+
+        result_logicals.append(final_logical)
+
+        return result_logicals
+
+    def handle_sequence(self, task: ir.Sequence, ctx: Context):  # type: ignore[override]
+        """
+        Handle a Sequence.
+
+        If it's a script: set context flag and flatten Match nodes within.
+        If not a script: simple traversal (Flatten pass already processed it).
+
+        Note: Type checker complains about parameter type narrowing, but this is safe
+        because we only create FlattenScript.Context in our own rewrite() method.
+        """
+        if not is_script(task):
+            # Not a script sequence - already processed by Flatten, just return as-is
+            return Flatten.HandleResult(task)
+
+        # This is a script - mark context and process with flattening
+        old_in_script = ctx.in_script
+        ctx.in_script = True
+
+        # Process the sequence tasks
+        new_tasks: list[ir.Task] = []
+        for child in task.tasks:
+            # Check if this child is a Logical with Match that needs flattening
+            if isinstance(child, ir.Logical):
+                new_tasks.extend(self.try_flatten_logical(child, ctx))
+                continue
+
+            # No flattening needed, process normally
+            result = self.handle(child, ctx)
+            if result.replacement is not None:
+                new_tasks.append(result.replacement)
+
+        # Restore context
+        ctx.in_script = old_in_script
+
+        return Flatten.HandleResult(
+            ir.Sequence(task.engine, task.hoisted, tuple(new_tasks), task.annotations)
+        )
+
+    def try_flatten_logical(self, logical: ir.Logical, ctx: Context) -> list[ir.Logical]:
+        """
+        Flatten all Matches in a Logical.
+        Iteratively flattens until no more Matches remain in any of the resulting Logicals.
+        """
+        worklist = [logical]
+        result = []
+
+        while worklist:
+            current = worklist.pop()
+
+            # Find first Match in current logical
+            match = None
+            match_idx = -1
+            for i, child in enumerate(current.body):
+                if isinstance(child, ir.Match):
+                    match = child
+                    match_idx = i
+                    break
+
+            if match is None:
+                # No Match found - this logical is done
+                result.append(current)
+            else:
+                # Flatten and add results back to worklist for further processing.
+                # Reverse so that pop() returns them in the original order.
+                flattened = self.flatten_match_in_logical(current, match, match_idx, ctx)
+                worklist.extend(reversed(flattened))
+
+        return result
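
The `_match_1` / `Not _match_1` / Union chaining in the AFTER example is the usual first-match-wins encoding of Match: each branch derives into its own intermediate relation and is guarded by the negation of the previous branch's relation. A plain-Python illustration of that semantics (hypothetical, only to show the effect):

    def first_match_wins(domain, branches):
        # branches: (condition, value) pairs; each element takes its value
        # from the first branch whose condition holds, mirroring the
        # negation chain over _match_1, _match_2, ... above.
        matched = {}
        for cond, value in branches:
            for n in domain:
                if n not in matched and cond(n):  # "Not _match_i(n, _)"
                    matched[n] = value(n)
        return matched

    # first_match_wins([1, 2, 3], [(lambda n: n > 2, lambda n: n),
    #                              (lambda n: True, lambda n: 0)])
    # -> {3: 3, 1: 0, 2: 0}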
relationalai/semantics/lqp/rewrite/functional_dependencies.py

@@ -322,22 +322,27 @@ class FunctionalDependency:
 
 def contains_only_declarable_constraints(node: Node) -> bool:
     """
-    Checks whether the input Logical node contains only `Require` nodes annotated with
-    `declare_constraint`.
+    Checks whether the input node contains only `Require` nodes annotated with
+    `declare_constraint` (or such a node itself).
     """
+    # Check if the node itself is a Require node with declarable constraint
+    if isinstance(node, Require):
+        return is_declarable_constraint(node)
+
+    # Otherwise, check if it is a Logical node containing only declarable constraints
     if not isinstance(node, Logical):
         return False
     if len(node.body) == 0:
         return False
     for task in node.body:
-        if not isinstance(task, Require):
-            return False
-        if not is_declarable_constraint(task):
+        if not contains_only_declarable_constraints(task):
             return False
     return True
 
-def is_declarable_constraint(node: Require) -> bool:
+def is_declarable_constraint(node: Node) -> bool:
     """
-    Checks whether the input `Require` node is annotated with `declare_constraint`.
+    Checks whether the input node is a `Require` node annotated with `declare_constraint`.
     """
+    if not isinstance(node, Require):
+        return False
     return builtins.declare_constraint_annotation in node.annotations
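
After this change the check is recursive: a `Require` node passes on its own, and a `Logical` passes only if every child passes. A toy model of the same shape (hypothetical stand-in classes):

    class Require:
        def __init__(self, declared=True):
            self.declared = declared

    class Logical:
        def __init__(self, *body):
            self.body = body

    def only_declarable(node):
        # Mirrors the recursion above: Require nodes are checked directly,
        # Logical nodes must contain only declarable constraints.
        if isinstance(node, Require):
            return node.declared
        if isinstance(node, Logical) and node.body:
            return all(only_declarable(t) for t in node.body)
        return False

    assert only_declarable(Logical(Require(), Logical(Require())))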
relationalai/semantics/lqp/rewrite/quantify_vars.py

@@ -45,7 +45,7 @@ def _ignored_vars(node: ir.Logical|ir.Not):
 
         elif isinstance(task, ir.Rank):
             # Variables that are keys, and not in the group-by, don't need to be quantified.
-            for var in task.args
+            for var in task.args:
                 if var not in task.group:
                     vars_to_ignore.add(var)
 
@@ -149,13 +149,22 @@ class VarScopeInfo(Visitor):
         if isinstance(task, ir.Output):
             output_vars.update(helpers.output_vars(task.aliases))
 
-        if isinstance(task, (ir.Aggregate, ir.Rank)):
-            # Variables that are in the group-by, and not in the args, can come into scope.
+        if isinstance(task, ir.Aggregate):
+            # Variables that are in the group-by, and not in the args,
+            # can come into scope.
             for var in task.group:
                 if var not in task.args:
                     scope_vars.add(var)
             continue
 
+        if isinstance(task, ir.Rank):
+            # Variables that are in the group-by or projection, and not in the args,
+            # can come into scope.
+            for var in task.group + task.projection:
+                if var not in task.args:
+                    scope_vars.add(var)
+            continue
+
         # Hoisted variables from sub-tasks are brought again into scope.
         if isinstance(task, (ir.Logical, ir.Union, ir.Match)):
             scope_vars.update(helpers.hoisted_vars(task.hoisted))
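
The new `ir.Rank` branch mirrors the `ir.Aggregate` one but also considers the projection. Condensed, the rule is (hypothetical task object):

    def rank_scope_vars(task):
        # Group-by and projection variables that are not arguments of the
        # Rank come into scope.
        return [v for v in task.group + task.projection if v not in task.args]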
relationalai/semantics/lqp/rewrite/unify_definitions.py

@@ -54,9 +54,9 @@ class UnifyDefinitions(Pass):
             return head.relation
         else:
             assert isinstance(head, ir.Output)
-            if len(head.aliases)
-                # For processing here, we need output to have at least the column
-                # `cols` and `col
+            if len(head.aliases) < 2:
+                # For processing here, we need output to have at least the column marker
+                # `keys` or both `cols` and `col` markers, and also a key
                 return None
 
             output_alias_names = helpers.output_alias_names(head.aliases)
@@ -65,6 +65,8 @@ class UnifyDefinitions(Pass):
             # For normal outputs, the pattern is output[keys](cols, "col000" as 'col', ...)
             if output_alias_names[0] == "cols" and output_alias_names[1] == "col":
                 return output_vals[1]
+            if output_alias_names[0] == "keys":  # handle wide keys relation
+                return output_vals[0]
 
             # For exports, the pattern is output[keys]("col000" as 'col', ...)
             if helpers.is_export(head):
@@ -122,6 +124,10 @@ class UnifyDefinitions(Pass):
             # keys.
             output_values = helpers.output_values(head.aliases)[2:]
 
+        elif output_alias_names[0] == "keys":  # handle wide keys output
+            assert len(head.aliases) > 1
+            output_values = helpers.output_values(head.aliases)[1:]
+
         else:
             assert helpers.is_export(head) and output_alias_names[0] == "col"
             assert len(head.aliases) > 1
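
Together these two hunks add a third output shape, keyed on the first alias name. A toy dispatch showing the shapes side by side (hypothetical helper, slices taken from the hunks above):

    def output_values_for(alias_names, values):
        # Wide-keys outputs start with a "keys" marker; normal outputs start
        # with "cols" followed by "col".
        if alias_names[0] == "keys":
            return values[1:]
        if alias_names[:2] == ["cols", "col"]:
            return values[2:]
        raise ValueError("unrecognized output shape")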
relationalai/semantics/metamodel/dependency.py

@@ -500,6 +500,15 @@ class DependencyAnalysis(visitor.Visitor):
             assert(isinstance(c2task, helpers.COMPOSITES))
             if not c2task.hoisted:
                 return True
+
+            # c1 is a composite with hoisted variables; it depends on c2 if c2 is mergeable
+            # and they share variables, hence behaving like a filter
+            if c1.composite and c2.mergeable and c1.shares_variable(c2):
+                c1task = c1.content.some()
+                assert(isinstance(c1task, helpers.COMPOSITES))
+                if c1task.hoisted:
+                    return True
+
             return False
 
         cs = list(clusters)
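
The added branch makes the filter rule symmetric: a composite cluster that hoists variables now also depends on a mergeable cluster it shares variables with, since the mergeable side behaves like a filter for it. A runnable toy version of just the new case (hypothetical Cluster stand-in):

    from dataclasses import dataclass

    @dataclass
    class Cluster:  # toy stand-in for the analysis clusters
        composite: bool
        mergeable: bool
        hoisted: bool
        vars: frozenset

        def shares_variable(self, other):
            return bool(self.vars & other.vars)

    def filter_dependency_new_case(c1, c2):
        # c1 is a composite that hoists variables; c2 is mergeable and
        # shares a variable with c1, so c2 must be computed first.
        return (c1.composite and c2.mergeable
                and c1.shares_variable(c2) and c1.hoisted)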