relationalai 0.12.8__py3-none-any.whl → 0.12.10__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (39)
  1. relationalai/__init__.py +9 -0
  2. relationalai/clients/__init__.py +2 -2
  3. relationalai/clients/local.py +571 -0
  4. relationalai/clients/snowflake.py +106 -83
  5. relationalai/debugging.py +5 -2
  6. relationalai/semantics/__init__.py +2 -2
  7. relationalai/semantics/internal/__init__.py +2 -2
  8. relationalai/semantics/internal/internal.py +53 -14
  9. relationalai/semantics/lqp/README.md +34 -0
  10. relationalai/semantics/lqp/compiler.py +1 -1
  11. relationalai/semantics/lqp/constructors.py +7 -0
  12. relationalai/semantics/lqp/executor.py +35 -39
  13. relationalai/semantics/lqp/intrinsics.py +4 -3
  14. relationalai/semantics/lqp/ir.py +4 -0
  15. relationalai/semantics/lqp/model2lqp.py +47 -14
  16. relationalai/semantics/lqp/passes.py +7 -4
  17. relationalai/semantics/lqp/rewrite/__init__.py +4 -1
  18. relationalai/semantics/lqp/rewrite/annotate_constraints.py +55 -0
  19. relationalai/semantics/lqp/rewrite/extract_keys.py +22 -3
  20. relationalai/semantics/lqp/rewrite/function_annotations.py +91 -56
  21. relationalai/semantics/lqp/rewrite/functional_dependencies.py +314 -0
  22. relationalai/semantics/lqp/rewrite/quantify_vars.py +14 -0
  23. relationalai/semantics/lqp/validators.py +3 -0
  24. relationalai/semantics/metamodel/builtins.py +10 -0
  25. relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +5 -4
  26. relationalai/semantics/metamodel/rewrite/flatten.py +10 -4
  27. relationalai/semantics/metamodel/typer/typer.py +13 -0
  28. relationalai/semantics/metamodel/types.py +2 -1
  29. relationalai/semantics/reasoners/graph/core.py +44 -53
  30. relationalai/semantics/rel/compiler.py +19 -1
  31. relationalai/semantics/tests/test_snapshot_abstract.py +3 -0
  32. relationalai/tools/debugger.py +4 -2
  33. relationalai/tools/qb_debugger.py +5 -3
  34. relationalai/util/otel_handler.py +10 -4
  35. {relationalai-0.12.8.dist-info → relationalai-0.12.10.dist-info}/METADATA +2 -2
  36. {relationalai-0.12.8.dist-info → relationalai-0.12.10.dist-info}/RECORD +39 -35
  37. {relationalai-0.12.8.dist-info → relationalai-0.12.10.dist-info}/WHEEL +0 -0
  38. {relationalai-0.12.8.dist-info → relationalai-0.12.10.dist-info}/entry_points.txt +0 -0
  39. {relationalai-0.12.8.dist-info → relationalai-0.12.10.dist-info}/licenses/LICENSE +0 -0

relationalai/semantics/lqp/executor.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 from collections import defaultdict
+from datetime import datetime, timezone
 import atexit
 import re
 
@@ -13,6 +14,7 @@ from relationalai.semantics.lqp import result_helpers
 from relationalai.semantics.metamodel import ir, factory as f, executor as e
 from relationalai.semantics.lqp.compiler import Compiler
 from relationalai.semantics.lqp.intrinsics import mk_intrinsic_datetime_now
+from relationalai.semantics.lqp.constructors import mk_transaction
 from relationalai.semantics.lqp.types import lqp_type_to_sql
 from lqp import print as lqp_print, ir as lqp_ir
 from lqp.parser import construct_configure
@@ -39,6 +41,9 @@ class LQPExecutor(e.Executor):
         wide_outputs: bool = False,
         connection: Session | None = None,
         config: Config | None = None,
+        # In order to facilitate snapshot testing, we allow overriding intrinsic definitions
+        # like the current time, which would otherwise change between runs.
+        intrinsic_overrides: dict = {},
     ) -> None:
         super().__init__()
         self.database = database
@@ -48,6 +53,7 @@ class LQPExecutor(e.Executor):
         self.compiler = Compiler()
         self.connection = connection
         self.config = config or Config()
+        self.intrinsic_overrides = intrinsic_overrides
         self._resources = None
         self._last_model = None
         self._last_sources_version = (-1, None)
@@ -60,6 +66,8 @@ class LQPExecutor(e.Executor):
            resource_class = rai.clients.snowflake.Resources
            if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
                resource_class = rai.clients.snowflake.DirectAccessResources
+           if self.config.get("platform", "") == "local":
+               resource_class = rai.clients.local.LocalResources
            # NOTE: language="lqp" is not strictly required for LQP execution, but it
            # will significantly improve performance.
            self._resources = resource_class(
@@ -305,25 +313,28 @@ class LQPExecutor(e.Executor):
         config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
         return construct_configure(config_dict, None)
 
+    def _should_sync(self, model) :
+        if self._last_model != model:
+            return lqp_ir.Sync(fragments=[], meta=None)
+        else :
+            return None
+
     def _compile_intrinsics(self) -> lqp_ir.Epoch:
         """Construct an epoch that defines a number of built-in definitions used by the
         emitter."""
         with debugging.span("compile_intrinsics") as span:
             span["compile_type"] = "intrinsics"
 
+            now = self.intrinsic_overrides.get('datetime_now', datetime.now(timezone.utc))
+
             debug_info = lqp_ir.DebugInfo(id_to_orig_name={}, meta=None)
             intrinsics_fragment = lqp_ir.Fragment(
                 id = lqp_ir.FragmentId(id=b"__pyrel_lqp_intrinsics", meta=None),
-                declarations = [
-                    mk_intrinsic_datetime_now(),
-                ],
+                declarations = [mk_intrinsic_datetime_now(now)],
                 debug_info = debug_info,
                 meta = None,
             )
 
-
-            span["lqp"] = lqp_print.to_string(intrinsics_fragment, {"print_names": True, "print_debug": False, "print_csv_filename": False})
-
             return lqp_ir.Epoch(
                 writes=[
                     lqp_ir.Write(write_type=lqp_ir.Define(fragment=intrinsics_fragment, meta=None), meta=None)
@@ -331,6 +342,7 @@ class LQPExecutor(e.Executor):
                 meta=None,
             )
 
+    # [RAI-40997] We eagerly undefine query fragments so they are not committed to storage
     def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
         fragment_ids = []
 
@@ -354,47 +366,40 @@ class LQPExecutor(e.Executor):
 
     def compile_lqp(self, model: ir.Model, task: ir.Task):
         configure = self._construct_configure()
+        # Merge the epochs into a single transaction. Long term the query bits should all
+        # go into a WhatIf action and the intrinsics could be fused with either of them. But
+        # for now we just use separate epochs.
+        epochs = []
+        epochs.append(self._compile_intrinsics())
 
-        model_txn = None
-        if self._last_model != model:
+        sync = self._should_sync(model)
+
+        if sync is not None:
             with debugging.span("compile", metamodel=model) as install_span:
                 install_span["compile_type"] = "model"
-                _, model_txn = self.compiler.compile(model, {"fragment_id": b"model"})
-                model_txn = txn_with_configure(model_txn, configure)
-                install_span["lqp"] = lqp_print.to_string(model_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
+                _, model_epoch = self.compiler.compile(model, {"fragment_id": b"model"})
+                epochs.append(model_epoch)
                 self._last_model = model
 
-        with debugging.span("compile", metamodel=task) as compile_span:
-            compile_span["compile_type"] = "query"
+        with debugging.span("compile", metamodel=task) as txn_span:
             query = f.compute_model(f.logical([task]))
             options = {
                 "wide_outputs": self.wide_outputs,
                 "fragment_id": b"query",
            }
             result, final_model = self.compiler.compile_inner(query, options)
-            export_info, query_txn = result
-            query_txn = txn_with_configure(query_txn, configure)
-            compile_span["lqp"] = lqp_print.to_string(query_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
+            export_info, query_epoch = result
 
-            # Merge the epochs into a single transactions. Long term the query bits should all
-            # go into a WhatIf action and the intrinsics could be fused with either of them. But
-            # for now we just use separate epochs.
-            epochs = []
+            epochs.append(query_epoch)
+            epochs.append(self._compile_undefine_query(query_epoch))
 
-            epochs.append(self._compile_intrinsics())
-
-            if model_txn is not None:
-                epochs.append(model_txn.epochs[0])
+            txn_span["compile_type"] = "query"
+            txn = mk_transaction(epochs=epochs, configure=configure, sync=sync)
+            txn_span["lqp"] = lqp_print.to_string(txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
 
-            query_txn_epoch = query_txn.epochs[0]
-            epochs.append(query_txn_epoch)
-            epochs.append(self._compile_undefine_query(query_txn_epoch))
-
-            txn = lqp_ir.Transaction(epochs=epochs, configure=configure, meta=None)
             validate_lqp(txn)
 
             txn_proto = convert_transaction(txn)
-            # TODO (azreika): Should export_info be encoded as part of the txn_proto? [RAI-40312]
             return final_model, export_info, txn_proto
 
     # TODO (azreika): This should probably be split up into exporting and other processing. There are quite a lot of arguments here...
@@ -462,12 +467,3 @@ class LQPExecutor(e.Executor):
             # If processing the results failed, revert to the previous model.
             self._last_model = previous_model
             raise e
-
-def txn_with_configure(txn: lqp_ir.Transaction, configure: lqp_ir.Configure) -> lqp_ir.Transaction:
-    """ Return a new transaction with the given configure. If the transaction already has
-    a configure, it is replaced. """
-    return lqp_ir.Transaction(
-        epochs=txn.epochs,
-        configure=configure,
-        meta=txn.meta,
-    )
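
Note: a minimal sketch of the transaction shape now assembled by compile_lqp, using only names from the diff above (mk_transaction, lqp_ir.Sync); the epoch variables and model_changed are placeholders standing in for the values computed in the method:

    # Placeholders: intrinsics_epoch, model_epoch, query_epoch and undefine_epoch stand in
    # for the epochs compiled above; model_changed stands in for self._last_model != model.
    epochs = [intrinsics_epoch]
    sync = lqp_ir.Sync(fragments=[], meta=None) if model_changed else None
    if sync is not None:
        epochs.append(model_epoch)
    epochs += [query_epoch, undefine_epoch]
    txn = mk_transaction(epochs=epochs, configure=configure, sync=sync)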

relationalai/semantics/lqp/intrinsics.py

@@ -1,15 +1,16 @@
-from datetime import datetime, timezone
+from datetime import datetime
 
 from relationalai.semantics.lqp import ir as lqp
 from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, mk_type, mk_primitive
 from relationalai.semantics.lqp.utils import lqp_hash
 
-def mk_intrinsic_datetime_now() -> lqp.Def:
+# Constructs a definition of the current datetime.
+def mk_intrinsic_datetime_now(dt: datetime) -> lqp.Def:
     """Constructs a definition of the current datetime."""
     id = lqp_hash("__pyrel_lqp_intrinsic_datetime_now")
     out = lqp.Var(name="out", meta=None)
     out_type = mk_type(lqp.TypeName.DATETIME)
-    now = mk_value(lqp.DateTimeValue(value=datetime.now(timezone.utc), meta=None))
+    now = mk_value(lqp.DateTimeValue(value=dt, meta=None))
     datetime_now = mk_abstraction(
         [(out, out_type)],
         mk_primitive("rel_primitive_eq", [out, now]),
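
Note: the intrinsic_overrides hook added to LQPExecutor feeds this constructor. A small sketch of pinning the timestamp for snapshot tests, assuming only the names shown in this diff (the executor's other constructor arguments are elided):

    from datetime import datetime, timezone

    # Pin the datetime intrinsic so snapshot output does not change between runs.
    fixed_now = datetime(2025, 1, 1, tzinfo=timezone.utc)
    overrides = {"datetime_now": fixed_now}
    # Inside _compile_intrinsics this resolves to:
    now = overrides.get("datetime_now", datetime.now(timezone.utc))
    decl = mk_intrinsic_datetime_now(now)  # a Def equating `out` with the pinned value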

relationalai/semantics/lqp/ir.py

@@ -4,6 +4,7 @@ __all__ = [
     "SourceInfo",
     "LqpNode",
     "Declaration",
+    "FunctionalDependency",
     "Def",
     "Loop",
     "Abstraction",
@@ -45,6 +46,7 @@ __all__ = [
     "Read",
     "Epoch",
     "Transaction",
+    "Sync",
     "DebugInfo",
     "Configure",
     "IVMConfig",
@@ -59,6 +61,7 @@ from lqp.ir import (
     SourceInfo,
     LqpNode,
     Declaration,
+    FunctionalDependency,
     Def,
     Loop,
     Abstraction,
@@ -100,6 +103,7 @@ from lqp.ir import (
     Read,
     Epoch,
     Transaction,
+    Sync,
     DebugInfo,
     Configure,
     IVMConfig,

relationalai/semantics/lqp/model2lqp.py

@@ -11,7 +11,9 @@ from relationalai.semantics.lqp.constructors import (
 )
 from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var
 from relationalai.semantics.lqp.validators import assert_valid_input
-
+from relationalai.semantics.lqp.rewrite.functional_dependencies import (
+    normalized_fd, contains_only_declarable_constraints
+)
 from decimal import Decimal as PyDecimal
 from datetime import datetime, date, timezone
 from typing import Tuple, cast, Union, Optional
@@ -19,8 +21,8 @@ from warnings import warn
 import re
 import uuid
 
-""" Main access point. Converts the model IR to an LQP transaction. """
-def to_lqp(model: ir.Model, fragment_name: bytes, ctx: TranslationCtx) -> tuple[Optional[tuple], lqp.Transaction]:
+# Main access point for translating metamodel to lqp. Converts the model IR to an LQP epoch.
+def to_lqp(model: ir.Model, fragment_name: bytes, ctx: TranslationCtx) -> tuple[Optional[tuple], lqp.Epoch]:
     assert_valid_input(model)
     decls: list[lqp.Declaration] = []
     reads: list[lqp.Read] = []
@@ -50,16 +52,10 @@ def to_lqp(model: ir.Model, fragment_name: bytes, ctx: TranslationCtx) -> tuple[
     fragment = lqp.Fragment(id=fragment_id, declarations=decls, meta=None, debug_info=debug_info)
     define_op = lqp.Define(fragment=fragment, meta=None)
 
-    txn = lqp.Transaction(
-        epochs=[
-            lqp.Epoch(
-                reads=reads,
-                writes=[lqp.Write(write_type=define_op, meta=None)],
-                meta=None
-            )
-        ],
-        configure=lqp.construct_configure({}, None),
-        meta=None,
+    txn = lqp.Epoch(
+        reads=reads,
+        writes=[lqp.Write(write_type=define_op, meta=None)],
+        meta=None
     )
 
     return (export_info, txn)
@@ -108,6 +104,43 @@ def _get_export_reads(export_ids: list[tuple[lqp.RelationId, int, lqp.Type]]) ->
     return (export_filename, col_info, reads)
 
 def _translate_to_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
+    if contains_only_declarable_constraints(rule):
+        return _translate_to_constraint_decls(ctx, rule)
+    else:
+        return _translate_to_standard_decl(ctx, rule)
+
+def _translate_to_constraint_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
+    constraint_decls: list[lqp.Declaration] = []
+    for task in rule.body:
+        assert isinstance(task, ir.Require)
+        fd = normalized_fd(task)
+        assert fd is not None
+
+        # check for unresolved types
+        if any(types.is_any(var.type) for var in fd.keys + fd.values):
+            warn(f"Ignoring FD with unresolved type: {fd}")
+            continue
+
+        lqp_typed_keys = [_translate_term(ctx, key) for key in fd.keys]
+        lqp_typed_values = [_translate_term(ctx, value) for value in fd.values]
+        lqp_typed_vars:list[Tuple[lqp.Var, lqp.Type]] = lqp_typed_keys + lqp_typed_values # type: ignore
+        lqp_guard_atoms = [_translate_to_atom(ctx, atom) for atom in fd.guard]
+        lqp_guard = mk_abstraction(lqp_typed_vars, mk_and(lqp_guard_atoms))
+        lqp_keys:list[lqp.Var] = [var for (var, _) in lqp_typed_keys] # type: ignore
+        lqp_values:list[lqp.Var] = [var for (var, _) in lqp_typed_values] # type: ignore
+
+        fd_decl = lqp.FunctionalDependency(
+            guard=lqp_guard,
+            keys=lqp_keys,
+            values=lqp_values,
+            meta=None
+        )
+
+        constraint_decls.append(fd_decl)
+
+    return constraint_decls
+
+def _translate_to_standard_decl(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
     effects = collect_by_type((ir.Output, ir.Update), rule)
     aggregates = collect_by_type(ir.Aggregate, rule)
     ranks = collect_by_type(ir.Rank, rule)
@@ -458,7 +491,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
 
         return mk_exists(result_terms, conjunction)
 
-    # `input_args`` hold the types of the input arguments, but they may have been modified
+    # `input_args` hold the types of the input arguments, but they may have been modified
    # if we're dealing with a count, so we use `abstr_args` to find the type.
    (aggr_arg, aggr_arg_type) = abstr_args[-1]
 
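
Note: for orientation, _translate_to_constraint_decls produces roughly the following shape for a single key/value pair; the variable names are illustrative, and key_type, value_type and guard_atoms are placeholders for the translated types and guard atoms of the normalized FD:

    # Illustrative only: one key determines one value under the translated guard.
    key = lqp.Var(name="k", meta=None)
    value = lqp.Var(name="v", meta=None)
    guard = mk_abstraction([(key, key_type), (value, value_type)], mk_and(guard_atoms))
    fd_decl = lqp.FunctionalDependency(guard=guard, keys=[key], values=[value], meta=None)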

relationalai/semantics/lqp/passes.py

@@ -6,9 +6,11 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
 
 from relationalai.semantics.metamodel.rewrite import Flatten
 
-from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
-from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter
-
+from ..metamodel.rewrite import DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
+from .rewrite import (
+    AnnotateConstraints, CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars,
+    Splinter, SplitMultiCheckRequires
+)
 from relationalai.semantics.lqp.utils import output_names
 
 from typing import cast, List, Sequence, Tuple, Union, Optional, Iterable
@@ -18,8 +20,9 @@ import hashlib
 
 def lqp_passes() -> list[Pass]:
     return [
+        SplitMultiCheckRequires(),
         FunctionAnnotations(),
-        DischargeConstraints(),
+        AnnotateConstraints(),
         Checker(),
         CDC(), # specialize to physical relations before extracting nested and typing
         ExtractNestedLogicals(), # before InferTypes to avoid extracting casts

relationalai/semantics/lqp/rewrite/__init__.py

@@ -1,15 +1,18 @@
+from .annotate_constraints import AnnotateConstraints
 from .cdc import CDC
 from .extract_common import ExtractCommon
 from .extract_keys import ExtractKeys
-from .function_annotations import FunctionAnnotations
+from .function_annotations import FunctionAnnotations, SplitMultiCheckRequires
 from .quantify_vars import QuantifyVars
 from .splinter import Splinter
 
 __all__ = [
+    "AnnotateConstraints",
     "CDC",
     "ExtractCommon",
     "ExtractKeys",
     "FunctionAnnotations",
     "QuantifyVars",
     "Splinter",
+    "SplitMultiCheckRequires",
 ]

relationalai/semantics/lqp/rewrite/annotate_constraints.py (new file)

@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from relationalai.semantics.metamodel import builtins
+from relationalai.semantics.metamodel.ir import Node, Model, Require
+from relationalai.semantics.metamodel.compiler import Pass
+from relationalai.semantics.metamodel.rewrite.discharge_constraints import (
+    DischargeConstraintsVisitor
+)
+from relationalai.semantics.lqp.rewrite.functional_dependencies import (
+    is_valid_unique_constraint, normalized_fd
+)
+
+
+
+class AnnotateConstraints(Pass):
+    """
+    Extends `DischargeConstraints` pass by discharging only those Require nodes that cannot
+    be declared as constraints in LQP.
+
+    More precisely, the pass annotates Require nodes depending on how they should be
+    treated when generating code:
+    * `@declare_constraint` if the Require represents a constraint that can be declared in LQP.
+    * `@discharge` if the Require represents a constraint that should be dismissed during
+      code generation. Namely, when it cannot be declared in LQP and uses one of the
+      `unique`, `exclusive`, `anyof` builtins. These nodes are removed from the IR model
+      in the Flatten pass.
+    """
+
+    def rewrite(self, model: Model, options: dict = {}) -> Model:
+        return AnnotateConstraintsRewriter().walk(model)
+
+
+class AnnotateConstraintsRewriter(DischargeConstraintsVisitor):
+    """
+    Visitor marks all nodes which should be removed from IR model with `discharge` annotation.
+    """
+
+    def _should_be_declarable_constraint(self, node: Require) -> bool:
+        if not is_valid_unique_constraint(node):
+            return False
+        # Currently, we only declare non-structural functional dependencies.
+        fd = normalized_fd(node)
+        assert fd is not None # already checked by _is_valid_unique_constraint
+        return not fd.is_structural
+
+    def handle_require(self, node: Require, parent: Node):
+        if self._should_be_declarable_constraint(node):
+            return node.reconstruct(
+                node.engine,
+                node.domain,
+                node.checks,
+                node.annotations | [builtins.declare_constraint_annotation]
+            )
+
+        return super().handle_require(node, parent)
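
Note: a minimal usage sketch, assuming only the Pass interface shown above:

    # Run the pass over a metamodel. Require nodes that normalize to a non-structural FD
    # gain the @declare_constraint annotation; the remaining unique/exclusive/anyof
    # constraints are marked for discharge and later removed by the Flatten pass.
    annotated_model = AnnotateConstraints().rewrite(model)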

relationalai/semantics/lqp/rewrite/extract_keys.py

@@ -249,6 +249,24 @@ class ExtractKeysRewriter(Rewriter):
 
         return f.logical(tuple(outer_body), [])
 
+    def noop_logical(self, node: ir.Logical) -> bool:
+        # logicals that don't hoist variables are essentially filters like lookups
+        if not node.hoisted:
+            return True
+        if len(node.body) != 1:
+            return False
+        inner = node.body[0]
+        if not isinstance(inner, (ir.Match, ir.Union)):
+            return False
+        outer_vars = helpers.hoisted_vars(node.hoisted)
+        inner_vars = helpers.hoisted_vars(inner.hoisted)
+        for v in outer_vars:
+            if v not in inner_vars:
+                return False
+        # all vars hoisted by the outer logical, are also
+        # hoisted by the inner Match/Union
+        return True
+
     # compute inital information that's needed for later steps. E.g., what's nullable or
     # not, do some output columns have a default value, etc.
     def preprocess_logical(self, node: ir.Logical, output_keys: Iterable[ir.Var]):
@@ -264,10 +282,11 @@ class ExtractKeysRewriter(Rewriter):
                 non_nullable_vars.update(vars)
                 top_level_tasks.add(task)
             elif isinstance(task, ir.Logical):
-                # logicals that don't hoist variables are essentially filters like lookups
-                if not task.hoisted:
+                if self.noop_logical(task):
                     top_level_tasks.add(task)
-                    # TODO: should we do something about the inner variables?
+                    non_nullable_vars.update(helpers.hoisted_vars(task.hoisted))
+                    continue
+
                 for h in task.hoisted:
                     # Hoisted vars without a default are not nullable
                     if isinstance(h, ir.Var):

relationalai/semantics/lqp/rewrite/function_annotations.py

@@ -1,79 +1,114 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-from relationalai.semantics.metamodel import ir, compiler as c, visitor as v, builtins
-from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
+from typing import Optional
+from relationalai.semantics.metamodel import builtins
+from relationalai.semantics.metamodel.ir import (
+    Node, Model, Require, Logical, Relation, Annotation, Update
+)
+from relationalai.semantics.metamodel.compiler import Pass
+from relationalai.semantics.metamodel.visitor import Rewriter, Visitor
+from relationalai.semantics.lqp.rewrite.functional_dependencies import (
+    is_valid_unique_constraint, normalized_fd
+)
 
+# In the future iterations of PyRel metamodel, `Require` nodes will have a single `Check`
+# (and no `errors`). Currently, however, the unique constraints may result in multiple
+# `Check` nodes and for simplicity we split them in to separate `Require` nodes. This step
+# will be removed in the future.
+#
+# Note that unique constraints always have an empty `domain` so apply the splitting only
+# to such `Require` nodes.
+class SplitMultiCheckRequires(Pass):
+    """
+    Pass splits unique Require nodes that have empty domain but multiple checks into multiple
+    Require nodes with single check each.
+    """
+
+    def rewrite(self, model: Model, options: dict = {}) -> Model:
+        return SplitMultiCheckRequiresRewriter().walk(model)
+
+
+class SplitMultiCheckRequiresRewriter(Rewriter):
+    """
+    Splits unique Require nodes that have empty domain but multiple checks into multiple
+    Require nodes with single check each.
+    """
+    def handle_require(self, node: Require, parent: Node):
 
-class FunctionAnnotations(c.Pass):
+        if isinstance(node.domain, Logical) and not node.domain.body and len(node.checks) > 1:
+            require_nodes = []
+            for check in node.checks:
+                single_check = self.walk(check, node)
+                require_nodes.append(
+                    node.reconstruct(node.engine, node.domain, (single_check,), node.annotations)
+                )
+            return require_nodes
+
+        return node
+
+
+class FunctionAnnotations(Pass):
     """
-    Pass marks all appropriate relations with `function` annotation.
-    Criteria:
-    - there is a Require node with `unique` builtin (appeared as a result of `require(unique(...))`)
-    - `unique` declared for all the fields in a derived relation expect the last
+    Pass marks all appropriate relations with `function` annotation. Collects functional
+    dependencies from unique Require nodes and uses this information to identify functional
+    relations.
     """
 
-    def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
-        collect_fd = CollectFunctionalRelationsVisitor()
-        new_model = collect_fd.walk(model)
-        # mark relations collected by previous visitor with `@function` annotation
-        return FunctionalAnnotationsVisitor(collect_fd.functional_relations).walk(new_model)
+    def rewrite(self, model: Model, options: dict = {}) -> Model:
+        collect_fds = CollectFDsVisitor()
+        collect_fds.visit_model(model, None)
+        annotated_model = FunctionalAnnotationsRewriter(collect_fds.functional_relations).walk(model)
+        return annotated_model
 
 
-@dataclass
-class CollectFunctionalRelationsVisitor(v.Rewriter):
+class CollectFDsVisitor(Visitor):
     """
-    Visitor collects all relations which should be marked with `functional` annotation.
+    Visitor collects all unique constraints.
     """
 
+    # Currently, only information about k-functional fd is collected.
     def __init__(self):
         super().__init__()
-        self.functional_relations = ordered_set()
-
-    def handle_check(self, node: ir.Check, parent: ir.Node):
-        check = self.walk(node.check, node)
-        assert isinstance(check, ir.Logical)
-        unique_vars = []
-        for item in check.body:
-            # collect vars from `unique` builtin
-            if isinstance(item, ir.Lookup) and item.relation.name == builtins.unique.name:
-                var_set = set()
-                for vargs in item.args:
-                    assert isinstance(vargs, tuple)
-                    var_set.update(vargs)
-                unique_vars.append(var_set)
-        functional_rel = []
-        # mark relations as functional when at least 1 `unique` builtin
-        if len(unique_vars) > 0:
-            for item in check.body:
-                if isinstance(item, ir.Lookup) and not item.relation.name == builtins.unique.name:
-                    for var_set in unique_vars:
-                        # when unique declared for all the fields except the last one in the relation mark it as functional
-                        if var_set == set(item.args[:-1]):
-                            functional_rel.append(item.relation)
-
-        self.functional_relations.update(functional_rel)
-        return node.reconstruct(check, node.error, node.annotations)
-
-
-@dataclass
-class FunctionalAnnotationsVisitor(v.Rewriter):
+        self.functional_relations:dict[Relation, int] = {}
+
+    def visit_require(self, node: Require, parent: Optional[Node]):
+        if is_valid_unique_constraint(node):
+            fd = normalized_fd(node)
+            assert fd is not None
+            if fd.is_structural:
+                relation = fd.structural_relation
+                k = fd.structural_rank
+                current_k = self.functional_relations.get(relation, 0)
+                self.functional_relations[relation] = max(current_k, k)
+
+
+class FunctionalAnnotationsRewriter(Rewriter):
     """
-    This visitor marks functional_relations with `function` annotation.
+    This visitor marks functional_relations with `@function(:checked [, k])` annotation.
     """
 
-    def __init__(self, functional_relations: OrderedSet):
+    def __init__(self, functional_relations: dict[Relation, int]):
         super().__init__()
-        self._functional_relations = functional_relations
+        self.functional_relations = functional_relations
+
+    def get_functional_annotation(self, rel: Relation) -> Optional[Annotation]:
+        k = self.functional_relations.get(rel, None)
+        if k is None:
+            return None
+        if k == 1:
+            return builtins.function_checked_annotation
+        return builtins.function_ranked_checked_annotation(k)
 
-    def handle_relation(self, node: ir.Relation, parent: ir.Node):
-        if node in self._functional_relations:
-            return node.reconstruct(node.name, node.fields, node.requires, node.annotations | [builtins.function_checked_annotation],
-                node.overloads)
+    def handle_relation(self, node: Relation, parent: Node):
+        function_annotation = self.get_functional_annotation(node)
+        if function_annotation:
+            return node.reconstruct(node.name, node.fields, node.requires,
+                node.annotations | [function_annotation], node.overloads)
         return node.reconstruct(node.name, node.fields, node.requires, node.annotations, node.overloads)
 
-    def handle_update(self, node: ir.Update, parent: ir.Node):
-        if node.relation in self._functional_relations:
+    def handle_update(self, node: Update, parent: Node):
+        function_annotation = self.get_functional_annotation(node.relation)
+        if function_annotation:
             return node.reconstruct(node.engine, node.relation, node.args, node.effect,
-                node.annotations | [builtins.function_checked_annotation])
+                node.annotations | [function_annotation])
         return node.reconstruct(node.engine, node.relation, node.args, node.effect, node.annotations)
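
Note: a small sketch of how the rewriter picks the annotation, assuming only the names in this diff; rel_a, rel_b and rel_c are placeholder Relation objects:

    # functional_relations maps each relation to the largest k collected for it.
    rewriter = FunctionalAnnotationsRewriter({rel_a: 1, rel_b: 3})
    rewriter.get_functional_annotation(rel_a)  # -> builtins.function_checked_annotation
    rewriter.get_functional_annotation(rel_b)  # -> builtins.function_ranked_checked_annotation(3)
    rewriter.get_functional_annotation(rel_c)  # -> None; relation and its updates stay unannotated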