relationalai 0.12.9__py3-none-any.whl → 0.12.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. relationalai/__init__.py +9 -0
  2. relationalai/clients/__init__.py +2 -2
  3. relationalai/clients/local.py +571 -0
  4. relationalai/debugging.py +5 -2
  5. relationalai/semantics/__init__.py +2 -2
  6. relationalai/semantics/internal/__init__.py +2 -2
  7. relationalai/semantics/internal/internal.py +24 -7
  8. relationalai/semantics/lqp/README.md +34 -0
  9. relationalai/semantics/lqp/constructors.py +2 -1
  10. relationalai/semantics/lqp/executor.py +13 -2
  11. relationalai/semantics/lqp/ir.py +4 -0
  12. relationalai/semantics/lqp/model2lqp.py +41 -2
  13. relationalai/semantics/lqp/passes.py +6 -4
  14. relationalai/semantics/lqp/rewrite/__init__.py +2 -0
  15. relationalai/semantics/lqp/rewrite/annotate_constraints.py +55 -0
  16. relationalai/semantics/lqp/rewrite/extract_keys.py +22 -3
  17. relationalai/semantics/lqp/rewrite/functional_dependencies.py +42 -10
  18. relationalai/semantics/lqp/rewrite/quantify_vars.py +14 -0
  19. relationalai/semantics/lqp/validators.py +3 -0
  20. relationalai/semantics/metamodel/builtins.py +5 -0
  21. relationalai/semantics/metamodel/rewrite/flatten.py +10 -4
  22. relationalai/semantics/metamodel/typer/typer.py +13 -0
  23. relationalai/semantics/metamodel/types.py +2 -1
  24. relationalai/semantics/reasoners/graph/core.py +44 -53
  25. relationalai/tools/debugger.py +4 -2
  26. relationalai/tools/qb_debugger.py +5 -3
  27. {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/METADATA +2 -2
  28. {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/RECORD +31 -28
  29. {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/WHEEL +0 -0
  30. {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/entry_points.txt +0 -0
  31. {relationalai-0.12.9.dist-info → relationalai-0.12.11.dist-info}/licenses/LICENSE +0 -0
@@ -621,7 +621,7 @@ class Producer:
621
621
 
622
622
  if self._model and self._model._strict:
623
623
  raise AttributeError(f"{self._name} has no relationship `{name}`")
624
- if topmost_parent is not concept:
624
+ if topmost_parent is not concept and topmost_parent not in Concept.builtin_concepts:
625
625
  topmost_parent._relationships[name] = topmost_parent._get_relationship(name)
626
626
  rich.print(f"[red bold][Implicit Subtype Relationship][/red bold] [yellow]{concept}.{name}[/yellow] appended to topmost parent [yellow]{topmost_parent}[/yellow] instead")
627
627
 
@@ -1165,7 +1165,10 @@ Primitive = Concept.builtins["Primitive"] = Concept("Primitive")
1165
1165
  Error = Concept.builtins["Error"] = ErrorConcept("Error")
1166
1166
 
1167
1167
  def _register_builtin(name):
1168
- c = Concept(name, extends=[Primitive])
1168
+ if name == "AnyEntity":
1169
+ c = Concept(name)
1170
+ else:
1171
+ c = Concept(name, extends=[Primitive])
1169
1172
  Concept.builtin_concepts.add(c)
1170
1173
  Concept.builtins[name] = c
1171
1174
 
@@ -1174,6 +1177,7 @@ for builtin in types.builtin_types:
1174
1177
  if isinstance(builtin, ir.ScalarType):
1175
1178
  _register_builtin(builtin.name)
1176
1179
 
1180
+ AnyEntity = Concept.builtins["AnyEntity"]
1177
1181
  Float = Concept.builtins["Float"]
1178
1182
  Number = Concept.builtins["Number"]
1179
1183
  Int64 = Concept.builtins["Int64"]
@@ -2896,10 +2900,9 @@ class Compiler():
2896
2900
  if concept not in self.types:
2897
2901
  self.to_type(concept)
2898
2902
  self.to_relation(concept)
2899
- if concept._extends:
2900
- rule = self.concept_inheritance_rule(concept)
2901
- if rule:
2902
- rules.append(rule)
2903
+ rule = self.concept_inheritance_rule(concept)
2904
+ if rule:
2905
+ rules.append(rule)
2903
2906
  unresolved = []
2904
2907
  for relationship in model.relationships:
2905
2908
  if relationship not in self.relations:
@@ -3204,8 +3207,11 @@ class Compiler():
3204
3207
  # filter extends to get only non-primitive parents
3205
3208
  parents = []
3206
3209
  for parent in concept._extends:
3207
- if not parent._is_primitive():
3210
+ if not parent._is_primitive() and parent is not AnyEntity:
3208
3211
  parents.append(parent)
3212
+ # always extend AnyEntity for non-primitive types that are not built-in
3213
+ if not concept._is_primitive() and concept not in Concept.builtin_concepts:
3214
+ parents.append(AnyEntity)
3209
3215
  # only extends primitive types, no need for inheritance rules
3210
3216
  if not parents:
3211
3217
  return None
@@ -3218,6 +3224,17 @@ class Compiler():
3218
3224
  *[f.derive(self.to_relation(parent), [var]) for parent in parents]
3219
3225
  ])
3220
3226
 
3227
def concept_any_entity_rule(self, entities:list[Concept]):
    """
    Build one rule that derives `AnyEntity(v)` for every member of the given concepts.

    The generated rule has the shape
    `union(lookup(e_1, v), ..., lookup(e_n, v)) => AnyEntity(v)`,
    where `v` is a fresh variable of type `Any`.
    """
    target = self.to_relation(AnyEntity)
    v = f.var("v", types.Any)
    sources = [f.lookup(self.to_relation(entity), [v]) for entity in entities]
    return f.logical([f.union(sources), f.derive(target, [v])])
3237
+
3221
3238
  def relation_dict(self, items:dict[Relationship|Concept, Producer], ctx:CompilerContext) -> dict[ir.Relation, list[ir.Var]]:
3222
3239
  return {self.to_relation(k): unwrap_list(self.lookup(v, ctx)) for k, v in items.items()}
3223
3240
 
@@ -0,0 +1,34 @@
1
+ # Logic Engine LQP Backend
2
+
3
+ The logic engine runs the *Logical Query Protocol* (*LQP* for short). This module includes a
4
+ compiler from the semantic metamodel to LQP along with an executor.
5
+
6
+ ## Running against a local logic engine
7
+
8
+ For development and testing, it is possible to run PyRel models against a local logic engine
9
+ server process.
10
+
11
+ To start your local server, please refer to the [logic engine
12
+ docs](https://github.com/RelationalAI/raicode/tree/master/src/Server#starting-the-server).
13
+
14
+ With the local server running, add this to your `raiconfig.toml`:
15
+
16
+ ```toml
17
+ [profile.local]
18
+ platform = "local"
19
+ engine = "local"
20
+ host = "localhost"
21
+ port = 8010
22
+ ```
23
+
24
+ Then set `active_profile = "local"` at the top of the file.
25
+
26
+ **Known limitations:**
27
+
28
+ Local execution does not support running against Snowflake source tables.
29
+
30
+ At the moment, locally created databases cannot be cleaned up by the client, so you
31
+ will need to clear your local pager directory periodically.
32
+
33
+ At the moment, local execution is only supported for fast-path transactions, i.e., those
34
+ that complete in less than 5 seconds. Polling support will be added soon.
@@ -63,5 +63,6 @@ def mk_attribute(name: str, args: list[lqp.Value]) -> lqp.Attribute:
63
63
  def mk_transaction(
64
64
  epochs: list[lqp.Epoch],
65
65
  configure: lqp.Configure = lqp.construct_configure({}, None),
66
+ sync = None
66
67
  ) -> lqp.Transaction:
67
- return lqp.Transaction(epochs=epochs, configure=configure, meta=None)
68
+ return lqp.Transaction(epochs=epochs, configure=configure, sync=sync, meta=None)
@@ -66,6 +66,8 @@ class LQPExecutor(e.Executor):
66
66
  resource_class = rai.clients.snowflake.Resources
67
67
  if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
68
68
  resource_class = rai.clients.snowflake.DirectAccessResources
69
+ if self.config.get("platform", "") == "local":
70
+ resource_class = rai.clients.local.LocalResources
69
71
  # NOTE: language="lqp" is not strictly required for LQP execution, but it
70
72
  # will significantly improve performance.
71
73
  self._resources = resource_class(
@@ -311,6 +313,12 @@ class LQPExecutor(e.Executor):
311
313
  config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
312
314
  return construct_configure(config_dict, None)
313
315
 
316
def _should_sync(self, model):
    """
    Return the `Sync` directive for the next transaction, or `None` if none is needed.

    A `Sync` with no fragments is produced whenever `model` differs from
    `self._last_model`. The caller also treats a non-`None` result as the signal to
    recompile and install the model epoch.
    """
    if self._last_model != model:
        return lqp_ir.Sync(fragments=[], meta=None)
    return None
321
+
314
322
  def _compile_intrinsics(self) -> lqp_ir.Epoch:
315
323
  """Construct an epoch that defines a number of built-in definitions used by the
316
324
  emitter."""
@@ -334,6 +342,7 @@ class LQPExecutor(e.Executor):
334
342
  meta=None,
335
343
  )
336
344
 
345
+ # [RAI-40997] We eagerly undefine query fragments so they are not committed to storage
337
346
  def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
338
347
  fragment_ids = []
339
348
 
@@ -363,7 +372,9 @@ class LQPExecutor(e.Executor):
363
372
  epochs = []
364
373
  epochs.append(self._compile_intrinsics())
365
374
 
366
- if self._last_model != model:
375
+ sync = self._should_sync(model)
376
+
377
+ if sync is not None:
367
378
  with debugging.span("compile", metamodel=model) as install_span:
368
379
  install_span["compile_type"] = "model"
369
380
  _, model_epoch = self.compiler.compile(model, {"fragment_id": b"model"})
@@ -383,7 +394,7 @@ class LQPExecutor(e.Executor):
383
394
  epochs.append(self._compile_undefine_query(query_epoch))
384
395
 
385
396
  txn_span["compile_type"] = "query"
386
- txn = mk_transaction(epochs=epochs, configure=configure)
397
+ txn = mk_transaction(epochs=epochs, configure=configure, sync=sync)
387
398
  txn_span["lqp"] = lqp_print.to_string(txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
388
399
 
389
400
  validate_lqp(txn)
@@ -4,6 +4,7 @@ __all__ = [
4
4
  "SourceInfo",
5
5
  "LqpNode",
6
6
  "Declaration",
7
+ "FunctionalDependency",
7
8
  "Def",
8
9
  "Loop",
9
10
  "Abstraction",
@@ -45,6 +46,7 @@ __all__ = [
45
46
  "Read",
46
47
  "Epoch",
47
48
  "Transaction",
49
+ "Sync",
48
50
  "DebugInfo",
49
51
  "Configure",
50
52
  "IVMConfig",
@@ -59,6 +61,7 @@ from lqp.ir import (
59
61
  SourceInfo,
60
62
  LqpNode,
61
63
  Declaration,
64
+ FunctionalDependency,
62
65
  Def,
63
66
  Loop,
64
67
  Abstraction,
@@ -100,6 +103,7 @@ from lqp.ir import (
100
103
  Read,
101
104
  Epoch,
102
105
  Transaction,
106
+ Sync,
103
107
  DebugInfo,
104
108
  Configure,
105
109
  IVMConfig,
@@ -11,7 +11,9 @@ from relationalai.semantics.lqp.constructors import (
11
11
  )
12
12
  from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var
13
13
  from relationalai.semantics.lqp.validators import assert_valid_input
14
-
14
+ from relationalai.semantics.lqp.rewrite.functional_dependencies import (
15
+ normalized_fd, contains_only_declarable_constraints
16
+ )
15
17
  from decimal import Decimal as PyDecimal
16
18
  from datetime import datetime, date, timezone
17
19
  from typing import Tuple, cast, Union, Optional
@@ -102,6 +104,43 @@ def _get_export_reads(export_ids: list[tuple[lqp.RelationId, int, lqp.Type]]) ->
102
104
  return (export_filename, col_info, reads)
103
105
 
104
106
def _translate_to_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
    """
    Translate a top-level rule into LQP declarations.

    A rule whose body consists solely of `Require` nodes annotated with
    `declare_constraint` becomes a list of constraint declarations. Every other rule
    goes through the standard definition translation.
    """
    only_constraints = contains_only_declarable_constraints(rule)
    translate = _translate_to_constraint_decls if only_constraints else _translate_to_standard_decl
    return translate(ctx, rule)
111
+
112
def _translate_to_constraint_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
    """
    Translate a rule made only of declarable `Require` nodes into LQP
    `FunctionalDependency` declarations, one per `Require`.

    Each `Require` is normalized into a functional dependency (guard, keys, values).
    A dependency that mentions a variable whose type is still unresolved (`Any`) is
    skipped with a warning instead of being emitted.
    """
    decls: list[lqp.Declaration] = []
    for require in rule.body:
        # `_translate_to_decls` routes rules here only after
        # `contains_only_declarable_constraints` holds, so each task is a Require
        # that normalizes into an FD.
        assert isinstance(require, ir.Require)
        fd = normalized_fd(require)
        assert fd is not None

        if any(types.is_any(v.type) for v in fd.keys + fd.values):
            warn(f"Ignoring FD with unresolved type: {fd}")
            continue

        typed_keys = [_translate_term(ctx, k) for k in fd.keys]
        typed_values = [_translate_term(ctx, v) for v in fd.values]
        typed_vars: list[Tuple[lqp.Var, lqp.Type]] = typed_keys + typed_values # type: ignore
        guard_atoms = [_translate_to_atom(ctx, atom) for atom in fd.guard]

        # The guard abstracts over all key and value variables; keys/values are the
        # bare variables (types dropped).
        decls.append(lqp.FunctionalDependency(
            guard=mk_abstraction(typed_vars, mk_and(guard_atoms)),
            keys=[var for (var, _) in typed_keys], # type: ignore
            values=[var for (var, _) in typed_values], # type: ignore
            meta=None,
        ))
    return decls
142
+
143
+ def _translate_to_standard_decl(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
105
144
  effects = collect_by_type((ir.Output, ir.Update), rule)
106
145
  aggregates = collect_by_type(ir.Aggregate, rule)
107
146
  ranks = collect_by_type(ir.Rank, rule)
@@ -452,7 +491,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
452
491
 
453
492
  return mk_exists(result_terms, conjunction)
454
493
 
455
- # `input_args`` hold the types of the input arguments, but they may have been modified
494
+ # `input_args` hold the types of the input arguments, but they may have been modified
456
495
  # if we're dealing with a count, so we use `abstr_args` to find the type.
457
496
  (aggr_arg, aggr_arg_type) = abstr_args[-1]
458
497
 
@@ -6,9 +6,11 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
6
6
 
7
7
  from relationalai.semantics.metamodel.rewrite import Flatten
8
8
 
9
- from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
10
- from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter, SplitMultiCheckRequires
11
-
9
+ from ..metamodel.rewrite import DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
10
+ from .rewrite import (
11
+ AnnotateConstraints, CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars,
12
+ Splinter, SplitMultiCheckRequires
13
+ )
12
14
  from relationalai.semantics.lqp.utils import output_names
13
15
 
14
16
  from typing import cast, List, Sequence, Tuple, Union, Optional, Iterable
@@ -20,7 +22,7 @@ def lqp_passes() -> list[Pass]:
20
22
  return [
21
23
  SplitMultiCheckRequires(),
22
24
  FunctionAnnotations(),
23
- DischargeConstraints(),
25
+ AnnotateConstraints(),
24
26
  Checker(),
25
27
  CDC(), # specialize to physical relations before extracting nested and typing
26
28
  ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
@@ -1,3 +1,4 @@
1
+ from .annotate_constraints import AnnotateConstraints
1
2
  from .cdc import CDC
2
3
  from .extract_common import ExtractCommon
3
4
  from .extract_keys import ExtractKeys
@@ -6,6 +7,7 @@ from .quantify_vars import QuantifyVars
6
7
  from .splinter import Splinter
7
8
 
8
9
  __all__ = [
10
+ "AnnotateConstraints",
9
11
  "CDC",
10
12
  "ExtractCommon",
11
13
  "ExtractKeys",
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ from relationalai.semantics.metamodel import builtins
4
+ from relationalai.semantics.metamodel.ir import Node, Model, Require
5
+ from relationalai.semantics.metamodel.compiler import Pass
6
+ from relationalai.semantics.metamodel.rewrite.discharge_constraints import (
7
+ DischargeConstraintsVisitor
8
+ )
9
+ from relationalai.semantics.lqp.rewrite.functional_dependencies import (
10
+ is_valid_unique_constraint, normalized_fd
11
+ )
12
+
13
+
14
+
15
class AnnotateConstraints(Pass):
    """
    Variant of the `DischargeConstraints` pass that discharges only those Require nodes
    that cannot be declared as constraints in LQP.

    More precisely, the pass annotates Require nodes according to how they should be
    treated when generating code:
    * `@declare_constraint` if the Require represents a constraint that can be declared
      in LQP. The Flatten pass keeps these nodes.
    * `@discharged` if the Require represents a constraint that should be dismissed
      during code generation, namely when it cannot be declared in LQP and uses one of
      the `unique`, `exclusive`, `anyof` builtins. The Flatten pass removes these nodes
      from the IR model.
    """

    def rewrite(self, model: Model, options: dict = {}) -> Model:
        # `options` is not used here; the signature presumably mirrors `Pass.rewrite`.
        return AnnotateConstraintsRewriter().walk(model)
31
+
32
+
33
class AnnotateConstraintsRewriter(DischargeConstraintsVisitor):
    """
    Visitor that annotates Require nodes for LQP code generation.

    A Require that can be declared as an LQP constraint gets the `declare_constraint`
    annotation and is otherwise reconstructed unchanged. Every other Require is
    delegated to `DischargeConstraintsVisitor.handle_require`.
    """

    def _should_be_declarable_constraint(self, node: Require) -> bool:
        """
        Return True if `node` is a valid unique constraint whose normalized functional
        dependency is non-structural.
        """
        if not is_valid_unique_constraint(node):
            return False
        # Currently, we only declare non-structural functional dependencies.
        fd = normalized_fd(node)
        assert fd is not None  # already checked by is_valid_unique_constraint
        return not fd.is_structural

    def handle_require(self, node: Require, parent: Node):
        """
        Add `declare_constraint` to declarable Requires; defer all others to the base
        visitor's discharge handling.
        """
        if self._should_be_declarable_constraint(node):
            return node.reconstruct(
                node.engine,
                node.domain,
                node.checks,
                node.annotations | [builtins.declare_constraint_annotation],
            )

        return super().handle_require(node, parent)
@@ -249,6 +249,24 @@ class ExtractKeysRewriter(Rewriter):
249
249
 
250
250
  return f.logical(tuple(outer_body), [])
251
251
 
252
def noop_logical(self, node: ir.Logical) -> bool:
    """
    Return True if `node` can be treated like a plain filter (e.g. a lookup).

    That holds when the logical hoists no variables at all. It also holds when its
    body is a single `Match`/`Union` that already hoists every variable the logical
    itself hoists.
    """
    if not node.hoisted:
        return True
    if len(node.body) != 1 or not isinstance(node.body[0], (ir.Match, ir.Union)):
        return False
    inner = node.body[0]
    outer_vars = helpers.hoisted_vars(node.hoisted)
    inner_vars = helpers.hoisted_vars(inner.hoisted)
    return all(v in inner_vars for v in outer_vars)
269
+
252
270
  # compute inital information that's needed for later steps. E.g., what's nullable or
253
271
  # not, do some output columns have a default value, etc.
254
272
  def preprocess_logical(self, node: ir.Logical, output_keys: Iterable[ir.Var]):
@@ -264,10 +282,11 @@ class ExtractKeysRewriter(Rewriter):
264
282
  non_nullable_vars.update(vars)
265
283
  top_level_tasks.add(task)
266
284
  elif isinstance(task, ir.Logical):
267
- # logicals that don't hoist variables are essentially filters like lookups
268
- if not task.hoisted:
285
+ if self.noop_logical(task):
269
286
  top_level_tasks.add(task)
270
- # TODO: should we do something about the inner variables?
287
+ non_nullable_vars.update(helpers.hoisted_vars(task.hoisted))
288
+ continue
289
+
271
290
  for h in task.hoisted:
272
291
  # Hoisted vars without a default are not nullable
273
292
  if isinstance(h, ir.Var):
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
  from typing import Optional, Sequence
3
3
  from relationalai.semantics.internal import internal
4
4
  from relationalai.semantics.metamodel.ir import (
5
- Require, Logical, Var, Relation, Lookup, ScalarType
5
+ Node, Require, Logical, Var, Relation, Lookup, ScalarType
6
6
  )
7
7
  from relationalai.semantics.metamodel import builtins
8
8
 
@@ -130,14 +130,16 @@ def _split_unique_require_node(node: Require) -> Optional[tuple[list[Var], list[
130
130
  return None
131
131
 
132
132
  # collect variables
133
- all_vars: set[Var] = set()
133
+ all_vars: list[Var] = []
134
134
  for lookup in guard:
135
135
  for arg in lookup.args:
136
136
  if not isinstance(arg, Var):
137
137
  return None
138
- all_vars.add(arg)
138
+ if arg in all_vars:
139
+ continue
140
+ all_vars.append(arg)
139
141
 
140
- unique_vars: set[Var] = set()
142
+ unique_vars: list[Var] = []
141
143
  if len(unique_atom.args) != 1:
142
144
  return None
143
145
  if not isinstance(unique_atom.args[0], (internal.TupleArg, tuple)):
@@ -147,10 +149,12 @@ def _split_unique_require_node(node: Require) -> Optional[tuple[list[Var], list[
147
149
  for arg in unique_atom.args[0]:
148
150
  if not isinstance(arg, Var):
149
151
  return None
150
- unique_vars.add(arg)
152
+ if arg in unique_vars:
153
+ return None
154
+ unique_vars.append(arg)
151
155
 
152
156
  # check that unique vars are a subset of other vars
153
- if not unique_vars.issubset(all_vars):
157
+ if not set(unique_vars).issubset(set(all_vars)):
154
158
  return None
155
159
 
156
160
  return list(all_vars), list(unique_vars), guard
@@ -218,10 +222,10 @@ class FunctionalDependency:
218
222
  - `X` and `Y` are disjoint and covering sets of variables used in `φ`
219
223
  """
220
224
  def __init__(self, guard: Sequence[Lookup], keys: Sequence[Var], values: Sequence[Var]):
221
- self.guard = frozenset(guard)
222
- self.keys = frozenset(keys)
223
- self.values = frozenset(values)
224
- assert self.keys.isdisjoint(self.values), "Keys and values must be disjoint"
225
+ self.guard = tuple(guard)
226
+ self.keys = tuple(keys)
227
+ self.values = tuple(values)
228
+ assert set(self.keys).isdisjoint(set(self.values)), "Keys and values must be disjoint"
225
229
 
226
230
  # for structural fd check
227
231
  self._is_structural:bool = False
@@ -280,3 +284,31 @@ class FunctionalDependency:
280
284
  raise ValueError("Functional dependency is not structural")
281
285
  assert self._structural_rank is not None
282
286
  return self._structural_rank
287
+
288
def __str__(self) -> str:
    """Render as `guard: {keys} -> {values}`, with guard atoms joined by ` ∧ `."""
    def _render(separator: str, items) -> str:
        return separator.join(map(str, items)).strip()

    guard_str = _render(" ∧ ", self.guard)
    keys_str = _render(", ", self.keys)
    values_str = _render(", ", self.values)
    return f"{guard_str}: {{{keys_str}}} -> {{{values_str}}}"
293
+
294
def contains_only_declarable_constraints(node: Node) -> bool:
    """
    Return True iff `node` is a `Logical` with a non-empty body in which every task is
    a `Require` annotated with `declare_constraint`.
    """
    if not isinstance(node, Logical) or len(node.body) == 0:
        return False
    return all(
        isinstance(task, Require) and is_declarable_constraint(task)
        for task in node.body
    )
309
+
310
def is_declarable_constraint(node: Require) -> bool:
    """Return True if the `Require` node carries the `declare_constraint` annotation."""
    marker = builtins.declare_constraint_annotation
    return marker in node.annotations
@@ -5,6 +5,7 @@ from relationalai.semantics.metamodel.compiler import Pass
5
5
  from relationalai.semantics.metamodel.visitor import Visitor, Rewriter
6
6
  from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
7
7
  from typing import Optional, Any, Tuple, Iterable
8
+ from .functional_dependencies import contains_only_declarable_constraints
8
9
 
9
10
  class QuantifyVars(Pass):
10
11
  """
@@ -67,6 +68,7 @@ class VarScopeInfo(Visitor):
67
68
  IGNORED_NODES = (ir.Type,
68
69
  ir.Var, ir.Literal, ir.Relation, ir.Field,
69
70
  ir.Default, ir.Output, ir.Update, ir.Aggregate,
71
+ ir.Check, ir.Require,
70
72
  ir.Annotation, ir.Rank)
71
73
 
72
74
  def __init__(self):
@@ -74,6 +76,9 @@ class VarScopeInfo(Visitor):
74
76
  self._vars_in_scope = {}
75
77
 
76
78
  def leave(self, node: ir.Node, parent: Optional[ir.Node]=None):
79
+ if contains_only_declarable_constraints(node):
80
+ return node
81
+
77
82
  if isinstance(node, ir.Lookup):
78
83
  self._record(node, helpers.vars(node.args))
79
84
 
@@ -189,6 +194,9 @@ class FindQuantificationNodes(Visitor):
189
194
  self.node_quantifies_vars = {}
190
195
 
191
196
  def enter(self, node: ir.Node, parent: Optional[ir.Node]=None) -> "Visitor":
197
+ if contains_only_declarable_constraints(node):
198
+ return self
199
+
192
200
  if isinstance(node, (ir.Logical, ir.Not)):
193
201
  ignored_vars = _ignored_vars(node)
194
202
  self._handled_vars.update(ignored_vars)
@@ -202,6 +210,9 @@ class FindQuantificationNodes(Visitor):
202
210
  return self
203
211
 
204
212
  def leave(self, node: ir.Node, parent: Optional[ir.Node]=None) -> ir.Node:
213
+ if contains_only_declarable_constraints(node):
214
+ return node
215
+
205
216
  if isinstance(node, (ir.Logical, ir.Not)):
206
217
  ignored_vars = _ignored_vars(node)
207
218
  self._handled_vars.difference_update(ignored_vars)
@@ -221,6 +232,9 @@ class QuantifyVarsRewriter(Rewriter):
221
232
  self.node_quantifies_vars = quant.node_quantifies_vars
222
233
 
223
234
  def handle_logical(self, node: ir.Logical, parent: ir.Node, ctx:Optional[Any]=None) -> ir.Logical:
235
+ if contains_only_declarable_constraints(node):
236
+ return node
237
+
224
238
  new_body = self.walk_list(node.body, node)
225
239
 
226
240
  if node.id in self.node_quantifies_vars:
@@ -21,6 +21,9 @@ CompilableType = Union[
21
21
  # Effects
22
22
  ir.Output,
23
23
  ir.Update,
24
+
25
+ # Constraints
26
+ ir.Require,
24
27
  ]
25
28
 
26
29
  # Preconditions
@@ -524,6 +524,11 @@ recursion_config_annotation = f.annotation(recursion_config, [])
524
524
  discharged = f.relation("discharged", [])
525
525
  discharged_annotation = f.annotation(discharged, [])
526
526
 
527
+ # Require nodes with this annotation will be kept in the final metamodel to be emitted as
528
+ # constraint declarations (LQP)
529
+ declare_constraint = f.relation("declare_constraint", [])
530
+ declare_constraint_annotation = f.annotation(declare_constraint, [])
531
+
527
532
  #
528
533
  # Aggregations
529
534
  #
@@ -5,7 +5,7 @@ from typing import Tuple
5
5
 
6
6
  from relationalai.semantics.metamodel import builtins, ir, factory as f, helpers
7
7
  from relationalai.semantics.metamodel.compiler import Pass, group_tasks
8
- from relationalai.semantics.metamodel.util import OrderedSet, ordered_set, NameCache
8
+ from relationalai.semantics.metamodel.util import NameCache, OrderedSet, ordered_set
9
9
  from relationalai.semantics.metamodel import dependency
10
10
  from relationalai.semantics.metamodel.typer.typer import to_type
11
11
 
@@ -419,9 +419,15 @@ class Flatten(Pass):
419
419
  def handle_require(self, req: ir.Require, ctx: Context):
420
420
  # only extract the domain if it is a somewhat complex Logical and there's more than
421
421
  # one check, otherwise insert it straight into all checks
422
- domain = req.domain
423
- # only generate logic for not discharged requires
424
- if builtins.discharged_annotation not in req.annotations:
422
+ if builtins.discharged_annotation in req.annotations:
423
+ # remove discharged Requires
424
+ return Flatten.HandleResult(None)
425
+ elif builtins.declare_constraint_annotation in req.annotations:
426
+ # leave Requires that are declared constraints
427
+ return Flatten.HandleResult(req)
428
+ else:
429
+ # generate logic for remaining requires
430
+ domain = req.domain
425
431
  if len(req.checks) > 1 and isinstance(domain, ir.Logical) and len(domain.body) > 1:
426
432
  body = OrderedSet.from_iterable(domain.body)
427
433
  vars = helpers.hoisted_vars(domain.hoisted)
@@ -156,6 +156,10 @@ def type_matches(actual:ir.Type, expected:ir.Type, allow_expected_parents=False)
156
156
  if actual == types.Any or expected == types.Any:
157
157
  return True
158
158
 
159
+ # any entity matches any entity (surprise surprise!)
160
+ if extends_any_entity(expected) and not is_primitive(actual):
161
+ return True
162
+
159
163
  # all decimals match across each other
160
164
  if types.is_decimal(actual) and types.is_decimal(expected):
161
165
  return True
@@ -288,6 +292,15 @@ def is_base_primitive(type:ir.Type) -> bool:
288
292
  def is_primitive(type:ir.Type) -> bool:
289
293
  return to_base_primitive(type) is not None
290
294
 
295
def extends_any_entity(type:ir.Type) -> bool:
    """
    Return True if `type` is `AnyEntity`, or is a scalar type that has `AnyEntity`
    among its (transitive) super types.
    """
    if type == types.AnyEntity:
        return True
    if not isinstance(type, ir.ScalarType):
        return False
    return any(extends_any_entity(parent) for parent in type.super_types)
303
+
291
304
  def invalid_type(type:ir.Type) -> bool:
292
305
  if isinstance(type, ir.UnionType):
293
306
  # if there are multiple primitives, or a primtive and a non-primitive
@@ -80,6 +80,7 @@ GenericDecimal = ir.ScalarType("GenericDecimal", util.frozen())
80
80
  #
81
81
  Null = ir.ScalarType("Null", util.frozen())
82
82
  Any = ir.ScalarType("Any", util.frozen())
83
+ AnyEntity = ir.ScalarType("AnyEntity", util.frozen())
83
84
  Hash = ir.ScalarType("Hash", util.frozen())
84
85
  String = ir.ScalarType("String", util.frozen())
85
86
  Int64 = ir.ScalarType("Int64")
@@ -144,7 +145,7 @@ def is_null(t: ir.Type) -> bool:
144
145
 
145
146
  def is_abstract_type(t: ir.Type) -> bool:
146
147
  if isinstance(t, ir.ScalarType):
147
- return t in [Any, Number, GenericDecimal]
148
+ return t in [Any, AnyEntity, Number, GenericDecimal]
148
149
  elif isinstance(t, ir.ListType):
149
150
  return is_abstract_type(t.element_type)
150
151
  elif isinstance(t, ir.TupleType):