relationalai 0.12.9__py3-none-any.whl → 0.12.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/__init__.py +9 -0
- relationalai/clients/__init__.py +2 -2
- relationalai/clients/local.py +571 -0
- relationalai/debugging.py +5 -2
- relationalai/semantics/__init__.py +2 -2
- relationalai/semantics/internal/__init__.py +2 -2
- relationalai/semantics/internal/internal.py +24 -7
- relationalai/semantics/lqp/README.md +34 -0
- relationalai/semantics/lqp/constructors.py +2 -1
- relationalai/semantics/lqp/executor.py +13 -2
- relationalai/semantics/lqp/ir.py +4 -0
- relationalai/semantics/lqp/model2lqp.py +41 -2
- relationalai/semantics/lqp/passes.py +6 -4
- relationalai/semantics/lqp/rewrite/__init__.py +2 -0
- relationalai/semantics/lqp/rewrite/annotate_constraints.py +55 -0
- relationalai/semantics/lqp/rewrite/extract_keys.py +22 -3
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +42 -10
- relationalai/semantics/lqp/rewrite/quantify_vars.py +14 -0
- relationalai/semantics/lqp/validators.py +3 -0
- relationalai/semantics/metamodel/builtins.py +5 -0
- relationalai/semantics/metamodel/rewrite/flatten.py +10 -4
- relationalai/semantics/metamodel/typer/typer.py +13 -0
- relationalai/semantics/metamodel/types.py +2 -1
- relationalai/semantics/reasoners/graph/core.py +44 -53
- relationalai/tools/debugger.py +4 -2
- relationalai/tools/qb_debugger.py +5 -3
- {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/METADATA +2 -2
- {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/RECORD +31 -28
- {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/WHEEL +0 -0
- {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/licenses/LICENSE +0 -0
relationalai/semantics/internal/internal.py
CHANGED

```diff
@@ -621,7 +621,7 @@ class Producer:
 
     if self._model and self._model._strict:
         raise AttributeError(f"{self._name} has no relationship `{name}`")
-    if topmost_parent is not concept:
+    if topmost_parent is not concept and topmost_parent not in Concept.builtin_concepts:
        topmost_parent._relationships[name] = topmost_parent._get_relationship(name)
        rich.print(f"[red bold][Implicit Subtype Relationship][/red bold] [yellow]{concept}.{name}[/yellow] appended to topmost parent [yellow]{topmost_parent}[/yellow] instead")
 
@@ -1165,7 +1165,10 @@ Primitive = Concept.builtins["Primitive"] = Concept("Primitive")
 Error = Concept.builtins["Error"] = ErrorConcept("Error")
 
 def _register_builtin(name):
-
+    if name == "AnyEntity":
+        c = Concept(name)
+    else:
+        c = Concept(name, extends=[Primitive])
     Concept.builtin_concepts.add(c)
     Concept.builtins[name] = c
 
@@ -1174,6 +1177,7 @@ for builtin in types.builtin_types:
     if isinstance(builtin, ir.ScalarType):
         _register_builtin(builtin.name)
 
+AnyEntity = Concept.builtins["AnyEntity"]
 Float = Concept.builtins["Float"]
 Number = Concept.builtins["Number"]
 Int64 = Concept.builtins["Int64"]
@@ -2896,10 +2900,9 @@ class Compiler():
         if concept not in self.types:
             self.to_type(concept)
             self.to_relation(concept)
-
-
-
-            rules.append(rule)
+            rule = self.concept_inheritance_rule(concept)
+            if rule:
+                rules.append(rule)
         unresolved = []
         for relationship in model.relationships:
             if relationship not in self.relations:
@@ -3204,8 +3207,11 @@ class Compiler():
         # filter extends to get only non-primitive parents
         parents = []
         for parent in concept._extends:
-            if not parent._is_primitive():
+            if not parent._is_primitive() and parent is not AnyEntity:
                 parents.append(parent)
+        # always extend AnyEntity for non-primitive types that are not built-in
+        if not concept._is_primitive() and concept not in Concept.builtin_concepts:
+            parents.append(AnyEntity)
         # only extends primitive types, no need for inheritance rules
         if not parents:
             return None
@@ -3218,6 +3224,17 @@ class Compiler():
             *[f.derive(self.to_relation(parent), [var]) for parent in parents]
         ])
 
+    def concept_any_entity_rule(self, entities:list[Concept]):
+        """
+        Generate an inheritance rule for all these entities to AnyEntity.
+        """
+        any_entity_relation = self.to_relation(AnyEntity)
+        var = f.var("v", types.Any)
+        return f.logical([
+            f.union([f.lookup(self.to_relation(e), [var]) for e in entities]),
+            f.derive(any_entity_relation, [var])
+        ])
+
     def relation_dict(self, items:dict[Relationship|Concept, Producer], ctx:CompilerContext) -> dict[ir.Relation, list[ir.Var]]:
         return {self.to_relation(k): unwrap_list(self.lookup(v, ctx)) for k, v in items.items()}
 
```
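A minimal sketch of what the registration above implies, based only on what the hunks show: `AnyEntity` is registered like the other builtins (without extending `Primitive`) and exposed at module level.

```python
# Hedged sketch based on the hunks above: AnyEntity is a builtin Concept,
# registered by _register_builtin and exported at module level in internal.py.
from relationalai.semantics.internal.internal import AnyEntity, Concept

assert AnyEntity is Concept.builtins["AnyEntity"]
assert AnyEntity in Concept.builtin_concepts
```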
relationalai/semantics/lqp/README.md
ADDED

````diff
@@ -0,0 +1,34 @@
+# Logic Engine LQP Backend
+
+The logic engine runs the *Logical Query Protocol* (short *LQP*). This module includes a
+compiler from the semantic metamodel to LQP along with an executor.
+
+## Running against a local logic engine
+
+For development and testing, it is possible to run PyRel models against a local logic engine
+server process.
+
+To start your local server, please refer to the [logic engine
+docs](https://github.com/RelationalAI/raicode/tree/master/src/Server#starting-the-server).
+
+With the local server running, add this to your `raiconfig.toml`:
+
+```toml
+[profile.local]
+platform = "local"
+engine = "local"
+host = "localhost"
+port = 8010
+```
+
+Then set `active_profile = "local"` at the top of the file.
+
+**Known limitations:**
+
+Local execution does not support running against Snowflake source tables.
+
+At the moment, locally created databases cannot be cleaned up by the client. Eventually you
+will need to clear your local pager directory.
+
+At the moment, local execution is only supported for fast-path transactions, i.e. those
+which complete in less than 5 seconds. Polling support will be added soon.
````
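Putting the two configuration snippets from the README together, a complete minimal `raiconfig.toml` for local execution would look roughly like this (a sketch; any other profiles already in the file are unaffected):

```toml
# Minimal raiconfig.toml for local execution, combining the snippets above:
# active_profile at the top of the file, the local profile below it.
active_profile = "local"

[profile.local]
platform = "local"
engine = "local"
host = "localhost"
port = 8010
```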
relationalai/semantics/lqp/constructors.py
CHANGED

```diff
@@ -63,5 +63,6 @@ def mk_attribute(name: str, args: list[lqp.Value]) -> lqp.Attribute:
 def mk_transaction(
     epochs: list[lqp.Epoch],
     configure: lqp.Configure = lqp.construct_configure({}, None),
+    sync = None
 ) -> lqp.Transaction:
-    return lqp.Transaction(epochs=epochs, configure=configure, meta=None)
+    return lqp.Transaction(epochs=epochs, configure=configure, sync=sync, meta=None)
```
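A hedged usage sketch of the extended constructor, assuming the `lqp` IR module and `mk_transaction` are in scope as in `constructors.py`; the empty `Sync` mirrors what the executor below builds when the model has changed.

```python
# Hypothetical call: sync defaults to None, which preserves the old behaviour;
# an explicit Sync is threaded through to lqp.Transaction unchanged.
txn = mk_transaction(
    epochs=[],  # placeholder; real callers pass compiled epochs here
    sync=lqp.Sync(fragments=[], meta=None),
)
```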
relationalai/semantics/lqp/executor.py
CHANGED

```diff
@@ -66,6 +66,8 @@ class LQPExecutor(e.Executor):
         resource_class = rai.clients.snowflake.Resources
         if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
             resource_class = rai.clients.snowflake.DirectAccessResources
+        if self.config.get("platform", "") == "local":
+            resource_class = rai.clients.local.LocalResources
         # NOTE: language="lqp" is not strictly required for LQP execution, but it
         # will significantly improve performance.
         self._resources = resource_class(
@@ -311,6 +313,12 @@ class LQPExecutor(e.Executor):
         config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
         return construct_configure(config_dict, None)
 
+    def _should_sync(self, model) :
+        if self._last_model != model:
+            return lqp_ir.Sync(fragments=[], meta=None)
+        else :
+            return None
+
     def _compile_intrinsics(self) -> lqp_ir.Epoch:
         """Construct an epoch that defines a number of built-in definitions used by the
         emitter."""
@@ -334,6 +342,7 @@ class LQPExecutor(e.Executor):
             meta=None,
         )
 
+    # [RAI-40997] We eagerly undefine query fragments so they are not committed to storage
     def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
         fragment_ids = []
 
@@ -363,7 +372,9 @@ class LQPExecutor(e.Executor):
         epochs = []
         epochs.append(self._compile_intrinsics())
 
-
+        sync = self._should_sync(model)
+
+        if sync is not None:
             with debugging.span("compile", metamodel=model) as install_span:
                 install_span["compile_type"] = "model"
                 _, model_epoch = self.compiler.compile(model, {"fragment_id": b"model"})
@@ -383,7 +394,7 @@ class LQPExecutor(e.Executor):
         epochs.append(self._compile_undefine_query(query_epoch))
 
         txn_span["compile_type"] = "query"
-        txn = mk_transaction(epochs=epochs, configure=configure)
+        txn = mk_transaction(epochs=epochs, configure=configure, sync=sync)
         txn_span["lqp"] = lqp_print.to_string(txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
 
         validate_lqp(txn)
```
relationalai/semantics/lqp/ir.py
CHANGED

```diff
@@ -4,6 +4,7 @@ __all__ = [
     "SourceInfo",
     "LqpNode",
     "Declaration",
+    "FunctionalDependency",
     "Def",
     "Loop",
     "Abstraction",
@@ -45,6 +46,7 @@ __all__ = [
     "Read",
     "Epoch",
     "Transaction",
+    "Sync",
     "DebugInfo",
     "Configure",
     "IVMConfig",
@@ -59,6 +61,7 @@ from lqp.ir import (
     SourceInfo,
     LqpNode,
     Declaration,
+    FunctionalDependency,
     Def,
     Loop,
     Abstraction,
@@ -100,6 +103,7 @@ from lqp.ir import (
     Read,
     Epoch,
     Transaction,
+    Sync,
     DebugInfo,
     Configure,
     IVMConfig,
```
relationalai/semantics/lqp/model2lqp.py
CHANGED

```diff
@@ -11,7 +11,9 @@ from relationalai.semantics.lqp.constructors import (
 )
 from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var
 from relationalai.semantics.lqp.validators import assert_valid_input
-
+from relationalai.semantics.lqp.rewrite.functional_dependencies import (
+    normalized_fd, contains_only_declarable_constraints
+)
 from decimal import Decimal as PyDecimal
 from datetime import datetime, date, timezone
 from typing import Tuple, cast, Union, Optional
@@ -102,6 +104,43 @@ def _get_export_reads(export_ids: list[tuple[lqp.RelationId, int, lqp.Type]]) ->
     return (export_filename, col_info, reads)
 
 def _translate_to_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
+    if contains_only_declarable_constraints(rule):
+        return _translate_to_constraint_decls(ctx, rule)
+    else:
+        return _translate_to_standard_decl(ctx, rule)
+
+def _translate_to_constraint_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
+    constraint_decls: list[lqp.Declaration] = []
+    for task in rule.body:
+        assert isinstance(task, ir.Require)
+        fd = normalized_fd(task)
+        assert fd is not None
+
+        # check for unresolved types
+        if any(types.is_any(var.type) for var in fd.keys + fd.values):
+            warn(f"Ignoring FD with unresolved type: {fd}")
+            continue
+
+        lqp_typed_keys = [_translate_term(ctx, key) for key in fd.keys]
+        lqp_typed_values = [_translate_term(ctx, value) for value in fd.values]
+        lqp_typed_vars:list[Tuple[lqp.Var, lqp.Type]] = lqp_typed_keys + lqp_typed_values # type: ignore
+        lqp_guard_atoms = [_translate_to_atom(ctx, atom) for atom in fd.guard]
+        lqp_guard = mk_abstraction(lqp_typed_vars, mk_and(lqp_guard_atoms))
+        lqp_keys:list[lqp.Var] = [var for (var, _) in lqp_typed_keys] # type: ignore
+        lqp_values:list[lqp.Var] = [var for (var, _) in lqp_typed_values] # type: ignore
+
+        fd_decl = lqp.FunctionalDependency(
+            guard=lqp_guard,
+            keys=lqp_keys,
+            values=lqp_values,
+            meta=None
+        )
+
+        constraint_decls.append(fd_decl)
+
+    return constraint_decls
+
+def _translate_to_standard_decl(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
     effects = collect_by_type((ir.Output, ir.Update), rule)
     aggregates = collect_by_type(ir.Aggregate, rule)
     ranks = collect_by_type(ir.Rank, rule)
@@ -452,7 +491,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
 
         return mk_exists(result_terms, conjunction)
 
-    # `input_args
+    # `input_args` hold the types of the input arguments, but they may have been modified
     # if we're dealing with a count, so we use `abstr_args` to find the type.
     (aggr_arg, aggr_arg_type) = abstr_args[-1]
 
```
relationalai/semantics/lqp/passes.py
CHANGED

```diff
@@ -6,9 +6,11 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
 
 from relationalai.semantics.metamodel.rewrite import Flatten
 
-from ..metamodel.rewrite import
-from .rewrite import
-
+from ..metamodel.rewrite import DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
+from .rewrite import (
+    AnnotateConstraints, CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars,
+    Splinter, SplitMultiCheckRequires
+)
 from relationalai.semantics.lqp.utils import output_names
 
 from typing import cast, List, Sequence, Tuple, Union, Optional, Iterable
@@ -20,7 +22,7 @@ def lqp_passes() -> list[Pass]:
     return [
         SplitMultiCheckRequires(),
         FunctionAnnotations(),
-
+        AnnotateConstraints(),
         Checker(),
         CDC(), # specialize to physical relations before extracting nested and typing
         ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
```
relationalai/semantics/lqp/rewrite/__init__.py
CHANGED

```diff
@@ -1,3 +1,4 @@
+from .annotate_constraints import AnnotateConstraints
 from .cdc import CDC
 from .extract_common import ExtractCommon
 from .extract_keys import ExtractKeys
@@ -6,6 +7,7 @@ from .quantify_vars import QuantifyVars
 from .splinter import Splinter
 
 __all__ = [
+    "AnnotateConstraints",
     "CDC",
     "ExtractCommon",
     "ExtractKeys",
```
relationalai/semantics/lqp/rewrite/annotate_constraints.py
ADDED

```diff
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from relationalai.semantics.metamodel import builtins
+from relationalai.semantics.metamodel.ir import Node, Model, Require
+from relationalai.semantics.metamodel.compiler import Pass
+from relationalai.semantics.metamodel.rewrite.discharge_constraints import (
+    DischargeConstraintsVisitor
+)
+from relationalai.semantics.lqp.rewrite.functional_dependencies import (
+    is_valid_unique_constraint, normalized_fd
+)
+
+
+
+class AnnotateConstraints(Pass):
+    """
+    Extends `DischargeConstraints` pass by discharging only those Require nodes that cannot
+    be declared as constraints in LQP.
+
+    More precisely, the pass annotates Require nodes depending on how they should be
+    treated when generating code:
+    * `@declare_constraint` if the Require represents a constraint that can be declared in LQP.
+    * `@discharge` if the Require represents a constraint that should be dismissed during
+      code generation. Namely, when it cannot be declared in LQP and uses one of the
+      `unique`, `exclusive`, `anyof` builtins. These nodes are removed from the IR model
+      in the Flatten pass.
+    """
+
+    def rewrite(self, model: Model, options: dict = {}) -> Model:
+        return AnnotateConstraintsRewriter().walk(model)
+
+
+class AnnotateConstraintsRewriter(DischargeConstraintsVisitor):
+    """
+    Visitor marks all nodes which should be removed from IR model with `discharge` annotation.
+    """
+
+    def _should_be_declarable_constraint(self, node: Require) -> bool:
+        if not is_valid_unique_constraint(node):
+            return False
+        # Currently, we only declare non-structural functional dependencies.
+        fd = normalized_fd(node)
+        assert fd is not None # already checked by _is_valid_unique_constraint
+        return not fd.is_structural
+
+    def handle_require(self, node: Require, parent: Node):
+        if self._should_be_declarable_constraint(node):
+            return node.reconstruct(
+                node.engine,
+                node.domain,
+                node.checks,
+                node.annotations | [builtins.declare_constraint_annotation]
+            )
+
+        return super().handle_require(node, parent)
```
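The pass is wired into `lqp_passes()` (see `passes.py` above); a hedged standalone sketch of the same `Pass` interface:

```python
# Hypothetical standalone invocation of the pass; in the shipped pipeline it
# runs between FunctionAnnotations() and Checker() inside lqp_passes().
from relationalai.semantics.metamodel.ir import Model
from relationalai.semantics.lqp.rewrite import AnnotateConstraints

def annotate(model: Model) -> Model:
    # Require nodes kept with @declare_constraint are later emitted as LQP
    # FunctionalDependency declarations; @discharge nodes are dropped in Flatten.
    return AnnotateConstraints().rewrite(model)
```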
relationalai/semantics/lqp/rewrite/extract_keys.py
CHANGED

```diff
@@ -249,6 +249,24 @@ class ExtractKeysRewriter(Rewriter):
 
         return f.logical(tuple(outer_body), [])
 
+    def noop_logical(self, node: ir.Logical) -> bool:
+        # logicals that don't hoist variables are essentially filters like lookups
+        if not node.hoisted:
+            return True
+        if len(node.body) != 1:
+            return False
+        inner = node.body[0]
+        if not isinstance(inner, (ir.Match, ir.Union)):
+            return False
+        outer_vars = helpers.hoisted_vars(node.hoisted)
+        inner_vars = helpers.hoisted_vars(inner.hoisted)
+        for v in outer_vars:
+            if v not in inner_vars:
+                return False
+        # all vars hoisted by the outer logical, are also
+        # hoisted by the inner Match/Union
+        return True
+
     # compute inital information that's needed for later steps. E.g., what's nullable or
     # not, do some output columns have a default value, etc.
     def preprocess_logical(self, node: ir.Logical, output_keys: Iterable[ir.Var]):
@@ -264,10 +282,11 @@ class ExtractKeysRewriter(Rewriter):
                 non_nullable_vars.update(vars)
                 top_level_tasks.add(task)
             elif isinstance(task, ir.Logical):
-
-                if not task.hoisted:
+                if self.noop_logical(task):
                     top_level_tasks.add(task)
-
+                    non_nullable_vars.update(helpers.hoisted_vars(task.hoisted))
+                    continue
+
                 for h in task.hoisted:
                     # Hoisted vars without a default are not nullable
                     if isinstance(h, ir.Var):
```
relationalai/semantics/lqp/rewrite/functional_dependencies.py
CHANGED

```diff
@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import Optional, Sequence
 from relationalai.semantics.internal import internal
 from relationalai.semantics.metamodel.ir import (
-    Require, Logical, Var, Relation, Lookup, ScalarType
+    Node, Require, Logical, Var, Relation, Lookup, ScalarType
 )
 from relationalai.semantics.metamodel import builtins
 
@@ -130,14 +130,16 @@ def _split_unique_require_node(node: Require) -> Optional[tuple[list[Var], list[
         return None
 
     # collect variables
-    all_vars:
+    all_vars: list[Var] = []
     for lookup in guard:
         for arg in lookup.args:
             if not isinstance(arg, Var):
                 return None
-            all_vars
+            if arg in all_vars:
+                continue
+            all_vars.append(arg)
 
-    unique_vars:
+    unique_vars: list[Var] = []
     if len(unique_atom.args) != 1:
         return None
     if not isinstance(unique_atom.args[0], (internal.TupleArg, tuple)):
@@ -147,10 +149,12 @@ def _split_unique_require_node(node: Require) -> Optional[tuple[list[Var], list[
     for arg in unique_atom.args[0]:
         if not isinstance(arg, Var):
             return None
-        unique_vars
+        if arg in unique_vars:
+            return None
+        unique_vars.append(arg)
 
     # check that unique vars are a subset of other vars
-    if not unique_vars.issubset(all_vars):
+    if not set(unique_vars).issubset(set(all_vars)):
         return None
 
     return list(all_vars), list(unique_vars), guard
@@ -218,10 +222,10 @@ class FunctionalDependency:
     - `X` and `Y` are disjoint and covering sets of variables used in `φ`
     """
     def __init__(self, guard: Sequence[Lookup], keys: Sequence[Var], values: Sequence[Var]):
-        self.guard =
-        self.keys =
-        self.values =
-        assert self.keys.isdisjoint(self.values), "Keys and values must be disjoint"
+        self.guard = tuple(guard)
+        self.keys = tuple(keys)
+        self.values = tuple(values)
+        assert set(self.keys).isdisjoint(set(self.values)), "Keys and values must be disjoint"
 
         # for structural fd check
         self._is_structural:bool = False
@@ -280,3 +284,31 @@ class FunctionalDependency:
             raise ValueError("Functional dependency is not structural")
         assert self._structural_rank is not None
         return self._structural_rank
+
+    def __str__(self) -> str:
+        guard_str = " ∧ ".join([str(atom) for atom in self.guard]).strip()
+        keys_str = ", ".join([str(var) for var in self.keys]).strip()
+        values_str = ", ".join([str(var) for var in self.values]).strip()
+        return f"{guard_str}: {{{keys_str}}} -> {{{values_str}}}"
+
+def contains_only_declarable_constraints(node: Node) -> bool:
+    """
+    Checks whether the input `Logical` node contains only `Require` nodes annotated with
+    `declare_constraint`.
+    """
+    if not isinstance(node, Logical):
+        return False
+    if len(node.body) == 0:
+        return False
+    for task in node.body:
+        if not isinstance(task, Require):
+            return False
+        if not is_declarable_constraint(task):
+            return False
+    return True
+
+def is_declarable_constraint(node: Require) -> bool:
+    """
+    Checks whether the input `Require` node is annotated with `declare_constraint`.
+    """
+    return builtins.declare_constraint_annotation in node.annotations
```
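A small illustration of the string format produced by the new `__str__` (plain strings stand in for the real `ir.Lookup`/`ir.Var` objects, which are hypothetical names here):

```python
# Mirrors the f-string in __str__: "<guard atoms>: {keys} -> {values}".
guard = ["person(p)", "has_ssn(p, s)"]  # hypothetical guard atoms
keys, values = ["p"], ["s"]
print(f"{' ∧ '.join(guard)}: {{{', '.join(keys)}}} -> {{{', '.join(values)}}}")
# person(p) ∧ has_ssn(p, s): {p} -> {s}
```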
relationalai/semantics/lqp/rewrite/quantify_vars.py
CHANGED

```diff
@@ -5,6 +5,7 @@ from relationalai.semantics.metamodel.compiler import Pass
 from relationalai.semantics.metamodel.visitor import Visitor, Rewriter
 from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
 from typing import Optional, Any, Tuple, Iterable
+from .functional_dependencies import contains_only_declarable_constraints
 
 class QuantifyVars(Pass):
     """
@@ -67,6 +68,7 @@ class VarScopeInfo(Visitor):
     IGNORED_NODES = (ir.Type,
                      ir.Var, ir.Literal, ir.Relation, ir.Field,
                      ir.Default, ir.Output, ir.Update, ir.Aggregate,
+                     ir.Check, ir.Require,
                      ir.Annotation, ir.Rank)
 
     def __init__(self):
@@ -74,6 +76,9 @@ class VarScopeInfo(Visitor):
         self._vars_in_scope = {}
 
     def leave(self, node: ir.Node, parent: Optional[ir.Node]=None):
+        if contains_only_declarable_constraints(node):
+            return node
+
         if isinstance(node, ir.Lookup):
             self._record(node, helpers.vars(node.args))
 
@@ -189,6 +194,9 @@ class FindQuantificationNodes(Visitor):
         self.node_quantifies_vars = {}
 
     def enter(self, node: ir.Node, parent: Optional[ir.Node]=None) -> "Visitor":
+        if contains_only_declarable_constraints(node):
+            return self
+
         if isinstance(node, (ir.Logical, ir.Not)):
             ignored_vars = _ignored_vars(node)
             self._handled_vars.update(ignored_vars)
@@ -202,6 +210,9 @@ class FindQuantificationNodes(Visitor):
         return self
 
     def leave(self, node: ir.Node, parent: Optional[ir.Node]=None) -> ir.Node:
+        if contains_only_declarable_constraints(node):
+            return node
+
         if isinstance(node, (ir.Logical, ir.Not)):
             ignored_vars = _ignored_vars(node)
             self._handled_vars.difference_update(ignored_vars)
@@ -221,6 +232,9 @@ class QuantifyVarsRewriter(Rewriter):
         self.node_quantifies_vars = quant.node_quantifies_vars
 
     def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
+        if contains_only_declarable_constraints(node):
+            return node
+
         new_body = self.walk_list(node.body, node)
 
         if node.id in self.node_quantifies_vars:
```
relationalai/semantics/metamodel/builtins.py
CHANGED

```diff
@@ -524,6 +524,11 @@ recursion_config_annotation = f.annotation(recursion_config, [])
 discharged = f.relation("discharged", [])
 discharged_annotation = f.annotation(discharged, [])
 
+# Require nodes with this annotation will be kept in the final metamodel to be emitted as
+# constraint declarations (LQP)
+declare_constraint = f.relation("declare_constraint", [])
+declare_constraint_annotation = f.annotation(declare_constraint, [])
+
 #
 # Aggregations
 #
```
relationalai/semantics/metamodel/rewrite/flatten.py
CHANGED

```diff
@@ -5,7 +5,7 @@ from typing import Tuple
 
 from relationalai.semantics.metamodel import builtins, ir, factory as f, helpers
 from relationalai.semantics.metamodel.compiler import Pass, group_tasks
-from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
+from relationalai.semantics.metamodel.util import NameCache, OrderedSet, ordered_set
 from relationalai.semantics.metamodel import dependency
 from relationalai.semantics.metamodel.typer.typer import to_type
 
@@ -419,9 +419,15 @@ class Flatten(Pass):
     def handle_require(self, req: ir.Require, ctx: Context):
         # only extract the domain if it is a somewhat complex Logical and there's more than
         # one check, otherwise insert it straight into all checks
-
-
-
+        if builtins.discharged_annotation in req.annotations:
+            # remove discharged Requires
+            return Flatten.HandleResult(None)
+        elif builtins.declare_constraint_annotation in req.annotations:
+            # leave Requires that are declared constraints
+            return Flatten.HandleResult(req)
+        else:
+            # generate logic for remaining requires
+            domain = req.domain
         if len(req.checks) > 1 and isinstance(domain, ir.Logical) and len(domain.body) > 1:
             body = OrderedSet.from_iterable(domain.body)
             vars = helpers.hoisted_vars(domain.hoisted)
```
relationalai/semantics/metamodel/typer/typer.py
CHANGED

```diff
@@ -156,6 +156,10 @@ def type_matches(actual:ir.Type, expected:ir.Type, allow_expected_parents=False)
     if actual == types.Any or expected == types.Any:
         return True
 
+    # any entity matches any entity (surprise surprise!)
+    if extends_any_entity(expected) and not is_primitive(actual):
+        return True
+
     # all decimals match across each other
     if types.is_decimal(actual) and types.is_decimal(expected):
         return True
@@ -288,6 +292,15 @@ def is_base_primitive(type:ir.Type) -> bool:
 def is_primitive(type:ir.Type) -> bool:
     return to_base_primitive(type) is not None
 
+def extends_any_entity(type:ir.Type) -> bool:
+    if type == types.AnyEntity:
+        return True
+    if isinstance(type, ir.ScalarType):
+        for parent in type.super_types:
+            if extends_any_entity(parent):
+                return True
+    return False
+
 def invalid_type(type:ir.Type) -> bool:
     if isinstance(type, ir.UnionType):
         # if there are multiple primitives, or a primtive and a non-primitive
```
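A hedged sketch of the new helper's behaviour, using only the builtin types from `relationalai.semantics.metamodel.types`:

```python
# Hedged sketch: extends_any_entity walks super_types looking for AnyEntity.
from relationalai.semantics.metamodel import types
from relationalai.semantics.metamodel.typer.typer import extends_any_entity

print(extends_any_entity(types.AnyEntity))  # True: the base case of the recursion
print(extends_any_entity(types.String))     # False: primitives never reach AnyEntity
```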
relationalai/semantics/metamodel/types.py
CHANGED

```diff
@@ -80,6 +80,7 @@ GenericDecimal = ir.ScalarType("GenericDecimal", util.frozen())
 #
 Null = ir.ScalarType("Null", util.frozen())
 Any = ir.ScalarType("Any", util.frozen())
+AnyEntity = ir.ScalarType("AnyEntity", util.frozen())
 Hash = ir.ScalarType("Hash", util.frozen())
 String = ir.ScalarType("String", util.frozen())
 Int64 = ir.ScalarType("Int64")
@@ -144,7 +145,7 @@ def is_null(t: ir.Type) -> bool:
 
 def is_abstract_type(t: ir.Type) -> bool:
     if isinstance(t, ir.ScalarType):
-        return t in [Any, Number, GenericDecimal]
+        return t in [Any, AnyEntity, Number, GenericDecimal]
     elif isinstance(t, ir.ListType):
         return is_abstract_type(t.element_type)
     elif isinstance(t, ir.TupleType):
```