relationalai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (33)
  1. relationalai/clients/snowflake.py +37 -5
  2. relationalai/clients/use_index_poller.py +11 -1
  3. relationalai/semantics/internal/internal.py +29 -7
  4. relationalai/semantics/lqp/compiler.py +1 -1
  5. relationalai/semantics/lqp/constructors.py +6 -0
  6. relationalai/semantics/lqp/executor.py +23 -38
  7. relationalai/semantics/lqp/intrinsics.py +4 -3
  8. relationalai/semantics/lqp/model2lqp.py +6 -12
  9. relationalai/semantics/lqp/passes.py +4 -2
  10. relationalai/semantics/lqp/rewrite/__init__.py +2 -1
  11. relationalai/semantics/lqp/rewrite/function_annotations.py +91 -56
  12. relationalai/semantics/lqp/rewrite/functional_dependencies.py +282 -0
  13. relationalai/semantics/metamodel/builtins.py +6 -0
  14. relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
  15. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +1 -1
  16. relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +9 -9
  17. relationalai/semantics/metamodel/rewrite/flatten.py +18 -149
  18. relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
  19. relationalai/semantics/reasoners/graph/core.py +98 -70
  20. relationalai/semantics/reasoners/optimization/__init__.py +55 -10
  21. relationalai/semantics/reasoners/optimization/common.py +63 -8
  22. relationalai/semantics/reasoners/optimization/solvers_dev.py +39 -33
  23. relationalai/semantics/reasoners/optimization/solvers_pb.py +1033 -385
  24. relationalai/semantics/rel/compiler.py +21 -2
  25. relationalai/semantics/tests/test_snapshot_abstract.py +3 -0
  26. relationalai/tools/cli.py +10 -0
  27. relationalai/tools/cli_controls.py +15 -0
  28. relationalai/util/otel_handler.py +10 -4
  29. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/METADATA +1 -1
  30. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/RECORD +33 -31
  31. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/WHEEL +0 -0
  32. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/entry_points.txt +0 -0
  33. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/snowflake.py
@@ -851,7 +851,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
             self.generation
         )
         # If cache is valid (data freshness has not expired), skip polling
-        if not poller.cache.is_valid():
+        if poller.cache.is_valid():
+            cached_sources = len(poller.cache.sources)
+            total_sources = len(sources_list)
+            cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+            message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+            if cached_timestamp:
+                print(f"\n{message} (cached at {cached_timestamp})\n")
+            else:
+                print(f"\n{message}\n")
+        else:
             return poller.poll()
 
     #--------------------------------------------------
@@ -3284,12 +3294,24 @@ class DirectAccessResources(Resources):
         try:
             response = _send_request()
             if response.status_code != 200:
+                # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
+                # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
+                # For skip_auto_create=False, continue to auto-creation logic below
+                if response.status_code == 404 and skip_auto_create:
+                    return response
+
                 try:
                     message = response.json().get("message", "")
                 except requests.exceptions.JSONDecodeError:
-                    raise ResponseStatusException(
-                        f"Failed to parse error response from endpoint {endpoint}.", response
-                    )
+                    # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
+                    # this should have been caught by the 404 check above, so this is an error.
+                    # For skip_auto_create=False, we explicitly check status_code below,
+                    # so we don't need to parse the message.
+                    if skip_auto_create:
+                        raise ResponseStatusException(
+                            f"Failed to parse error response from endpoint {endpoint}.", response
+                        )
+                    message = ""  # Not used when we check status_code directly
 
                 # fix engine on engine error and retry
                 # Skip auto-retry if skip_auto_create is True to avoid recursion
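
The branching above is easiest to see in isolation. Below is a distilled, hypothetical sketch of the new error-handling flow; `handle_error` is not a method from this diff, and the exception class is stubbed:

```python
import requests

class ResponseStatusException(Exception):
    """Stub standing in for the real class in snowflake.py."""
    def __init__(self, message: str, response):
        super().__init__(message)
        self.response = response

def handle_error(response: requests.Response, endpoint: str, skip_auto_create: bool):
    # 404 with skip_auto_create: hand the response back so the caller
    # (e.g. get_engine) can map it to None instead of auto-creating.
    if response.status_code == 404 and skip_auto_create:
        return response
    try:
        message = response.json().get("message", "")
    except requests.exceptions.JSONDecodeError:
        if skip_auto_create:
            raise ResponseStatusException(
                f"Failed to parse error response from endpoint {endpoint}.", response
            )
        message = ""  # unused: the auto-create path checks status_code directly
    return message
```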
@@ -3482,7 +3504,17 @@ class DirectAccessResources(Resources):
             generation=self.generation,
         )
         # If cache is valid (data freshness has not expired), skip polling
-        if not poller.cache.is_valid():
+        if poller.cache.is_valid():
+            cached_sources = len(poller.cache.sources)
+            total_sources = len(sources_list)
+            cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+            message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+            if cached_timestamp:
+                print(f"\n{message} (cached at {cached_timestamp})\n")
+            else:
+                print(f"\n{message}\n")
+        else:
             return poller.poll()
 
     def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
relationalai/clients/use_index_poller.py
@@ -250,7 +250,17 @@ class UseIndexPoller:
         # Cache was used - show how many sources were cached
         total_sources = len(self.cache.sources)
         cached_sources = total_sources - len(self.sources)
-        progress.add_sub_task(f"Using cached data for {cached_sources}/{total_sources} data streams", task_id="cache_usage", category=TASK_CATEGORY_CACHE)
+
+        # Get the timestamp when sources were cached
+        entry = self.cache._metadata.get("cachedIndices", {}).get(self.cache.key, {})
+        cached_timestamp = entry.get("last_use_index_update_on", "")
+
+        message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+        # Format the message with timestamp
+        if cached_timestamp:
+            message += f" (cached at {cached_timestamp})"
+
+        progress.add_sub_task(message, task_id="cache_usage", category=TASK_CATEGORY_CACHE)
         # Complete the subtask immediately since it's just informational
         progress.complete_sub_task("cache_usage")
 
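The same timestamped message is built at all three call sites; isolated with hypothetical values:

```python
cached_sources, total_sources = 3, 5
cached_timestamp = "2024-05-01T12:00:00Z"  # "" when no timestamp was recorded

message = f"Using cached data for {cached_sources}/{total_sources} data streams"
if cached_timestamp:
    message += f" (cached at {cached_timestamp})"

print(message)  # Using cached data for 3/5 data streams (cached at 2024-05-01T12:00:00Z)
```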
relationalai/semantics/internal/internal.py
@@ -40,7 +40,7 @@ _global_id = peekable(itertools.count(0))
 
 # Single context variable with default values
 _overrides = ContextVar("overrides", default = {})
-def overrides(key: str, default: bool | str | dict):
+def overrides(key: str, default: bool | str | dict | datetime | None):
     return _overrides.get().get(key, default)
 
 # Flag that users set in the config or directly on the model, but that can still be
@@ -60,6 +60,13 @@ def with_overrides(**kwargs):
     finally:
         _overrides.reset(token)
 
+# Intrinsic values to override for stable snapshots.
+def get_intrinsic_overrides() -> dict[str, Any]:
+    datetime_now = overrides('datetime_now', None)
+    if datetime_now is not None:
+        return {'datetime_now': datetime_now}
+    return {}
+
 #--------------------------------------------------
 # Root tracking
 #--------------------------------------------------
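
Taken together, this override plumbing lets a test pin the compile-time notion of "now". A minimal self-contained sketch of the pattern (the `@contextmanager` wrapper and the merge semantics of `with_overrides` are assumptions; the diff only shows its try/finally body):

```python
from contextlib import contextmanager
from contextvars import ContextVar
from datetime import datetime, timezone
from typing import Any

_overrides: ContextVar[dict] = ContextVar("overrides", default={})

def overrides(key: str, default: bool | str | dict | datetime | None):
    return _overrides.get().get(key, default)

@contextmanager
def with_overrides(**kwargs):
    # Assumed: new keys shadow existing ones for the dynamic extent of the block.
    token = _overrides.set({**_overrides.get(), **kwargs})
    try:
        yield
    finally:
        _overrides.reset(token)

def get_intrinsic_overrides() -> dict[str, Any]:
    datetime_now = overrides('datetime_now', None)
    if datetime_now is not None:
        return {'datetime_now': datetime_now}
    return {}

# A snapshot test can pin "now" so compiled output is stable across runs:
pinned = datetime(2024, 1, 1, tzinfo=timezone.utc)
with with_overrides(datetime_now=pinned):
    assert get_intrinsic_overrides() == {'datetime_now': pinned}
assert get_intrinsic_overrides() == {}  # override is gone outside the block
```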
@@ -953,12 +960,25 @@ class Concept(Producer):
         self._validate_identifier_relationship(rel)
         self._add_ref_scheme(*args)
 
-    def _add_ref_scheme(self, *args: Relationship|RelationshipReading):
-        self._reference_schemes.append(args)
-        # assumed that all Relationship|RelationshipReading are defined on the identified Concept
-        fields = tuple([arg.__getitem__(0) for arg in args])
-        uc = Unique(*fields, model=self._model)
-        require(uc.to_expressions())
+    def _add_ref_scheme(self, *rels: Relationship|RelationshipReading):
+        # Thanks to prior validation we can safely assume that
+        # * the input types are correct
+        # * all relationships are binary and defined on this concept
+
+        self._reference_schemes.append(rels)
+
+        # For every concept x, every field f has at most one value y:
+        # f(x,y): x -> y holds.
+        concept_fields = tuple([rel.__getitem__(0) for rel in rels])
+        for field in concept_fields:
+            concept_uc = Unique(field, model=self._model)
+            require(concept_uc.to_expressions())
+
+        # For any combination of field values there is at most one concept x:
+        # f₁(x,y₁) ∧ … ∧ fₙ(x,yₙ): {y₁,…,yₙ} → {x}
+        key_fields = tuple([rel.__getitem__(1) for rel in rels])
+        key_uc = Unique(*key_fields, model=self._model)
+        require(key_uc.to_expressions())
 
     def _validate_identifier_relationship(self, rel:Relationship|RelationshipReading):
         if rel._arity() != 2:
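
The two `Unique` requirements correspond to two directions of functional dependency. A plain-Python illustration with hypothetical data (not the PyRel API):

```python
facts = [("alice", "a@x.com"), ("bob", "b@x.com")]  # email(x, y) pairs

# 1. x -> y: each concept has at most one value for each identifying field
by_concept = {}
for x, y in facts:
    assert by_concept.setdefault(x, y) == y, f"{x} has two values for the field"

# 2. (y1, ..., yn) -> x: a full tuple of key values identifies at most one concept
by_key = {}
for x, y in facts:
    assert by_key.setdefault((y,), x) == x, f"key {(y,)} maps to two concepts"
```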
@@ -2603,6 +2623,7 @@ class Model():
         config_overrides = overrides('config', {})
         for k, v in config_overrides.items():
             self._config.set(k, v)
+        self._intrinsic_overrides = get_intrinsic_overrides()
         self._strict = cast(bool, overrides('strict', strict))
         self._use_lqp = overridable_flag('reasoner.rule.use_lqp', self._config, use_lqp, default=not self._use_sql)
         self._enable_otel_handler = overridable_flag('enable_otel_handler', self._config, enable_otel_handler, default=False)
@@ -2644,6 +2665,7 @@
                 wide_outputs=self._wide_outputs,
                 connection=self._connection,
                 config=self._config,
+                intrinsic_overrides=self._intrinsic_overrides,
             )
         elif self._use_sql:
             self._executor = SnowflakeExecutor(
relationalai/semantics/lqp/compiler.py
@@ -14,7 +14,7 @@ class Compiler(c.Compiler):
         super().__init__(lqp_passes())
         self.def_names = UniqueNames()
 
-    def do_compile(self, model: ir.Model, options:dict={}) -> tuple[Optional[tuple], lqp.Transaction]:
+    def do_compile(self, model: ir.Model, options:dict={}) -> tuple[Optional[tuple], lqp.Epoch]:
         fragment_id: bytes = options.get("fragment_id", bytes(404))
         # Reset the var context for each compilation
         # TODO: Change to unique var names per lookup
relationalai/semantics/lqp/constructors.py
@@ -59,3 +59,9 @@ def mk_pragma(name: str, terms: list[lqp.Var]) -> lqp.Pragma:
 
 def mk_attribute(name: str, args: list[lqp.Value]) -> lqp.Attribute:
     return lqp.Attribute(name=name, args=args, meta=None)
+
+def mk_transaction(
+    epochs: list[lqp.Epoch],
+    configure: lqp.Configure = lqp.construct_configure({}, None),
+) -> lqp.Transaction:
+    return lqp.Transaction(epochs=epochs, configure=configure, meta=None)
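
One design note: the default `configure` is a function default, so it is evaluated once at definition time; that is safe only as long as `Configure` values are treated as immutable. A hypothetical call, mirroring how `compile_lqp` (further down) assembles a transaction from independently compiled epochs:

```python
# epoch variables are placeholders for compiler output, not real objects
txn = mk_transaction(epochs=[intrinsics_epoch, model_epoch, query_epoch])
```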
relationalai/semantics/lqp/executor.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from collections import defaultdict
+from datetime import datetime, timezone
 import atexit
 import re
 
@@ -13,6 +14,7 @@ from relationalai.semantics.lqp import result_helpers
 from relationalai.semantics.metamodel import ir, factory as f, executor as e
 from relationalai.semantics.lqp.compiler import Compiler
 from relationalai.semantics.lqp.intrinsics import mk_intrinsic_datetime_now
+from relationalai.semantics.lqp.constructors import mk_transaction
 from relationalai.semantics.lqp.types import lqp_type_to_sql
 from lqp import print as lqp_print, ir as lqp_ir
 from lqp.parser import construct_configure
@@ -39,6 +41,9 @@ class LQPExecutor(e.Executor):
         wide_outputs: bool = False,
         connection: Session | None = None,
         config: Config | None = None,
+        # In order to facilitate snapshot testing, we allow overriding intrinsic definitions
+        # like the current time, which would otherwise change between runs.
+        intrinsic_overrides: dict = {},
     ) -> None:
         super().__init__()
         self.database = database
@@ -48,6 +53,7 @@ class LQPExecutor(e.Executor):
         self.compiler = Compiler()
         self.connection = connection
         self.config = config or Config()
+        self.intrinsic_overrides = intrinsic_overrides
         self._resources = None
         self._last_model = None
         self._last_sources_version = (-1, None)
@@ -311,19 +317,16 @@ class LQPExecutor(e.Executor):
         with debugging.span("compile_intrinsics") as span:
             span["compile_type"] = "intrinsics"
 
+            now = self.intrinsic_overrides.get('datetime_now', datetime.now(timezone.utc))
+
             debug_info = lqp_ir.DebugInfo(id_to_orig_name={}, meta=None)
             intrinsics_fragment = lqp_ir.Fragment(
                 id = lqp_ir.FragmentId(id=b"__pyrel_lqp_intrinsics", meta=None),
-                declarations = [
-                    mk_intrinsic_datetime_now(),
-                ],
+                declarations = [mk_intrinsic_datetime_now(now)],
                 debug_info = debug_info,
                 meta = None,
             )
 
-
-            span["lqp"] = lqp_print.to_string(intrinsics_fragment, {"print_names": True, "print_debug": False, "print_csv_filename": False})
-
             return lqp_ir.Epoch(
                 writes=[
                     lqp_ir.Write(write_type=lqp_ir.Define(fragment=intrinsics_fragment, meta=None), meta=None)
@@ -354,47 +357,38 @@ class LQPExecutor(e.Executor):
 
     def compile_lqp(self, model: ir.Model, task: ir.Task):
         configure = self._construct_configure()
+        # Merge the epochs into a single transaction. Long term the query bits should all
+        # go into a WhatIf action and the intrinsics could be fused with either of them. But
+        # for now we just use separate epochs.
+        epochs = []
+        epochs.append(self._compile_intrinsics())
 
-        model_txn = None
         if self._last_model != model:
             with debugging.span("compile", metamodel=model) as install_span:
                 install_span["compile_type"] = "model"
-                _, model_txn = self.compiler.compile(model, {"fragment_id": b"model"})
-                model_txn = txn_with_configure(model_txn, configure)
-                install_span["lqp"] = lqp_print.to_string(model_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
+                _, model_epoch = self.compiler.compile(model, {"fragment_id": b"model"})
+                epochs.append(model_epoch)
                 self._last_model = model
 
-        with debugging.span("compile", metamodel=task) as compile_span:
-            compile_span["compile_type"] = "query"
+        with debugging.span("compile", metamodel=task) as txn_span:
             query = f.compute_model(f.logical([task]))
             options = {
                 "wide_outputs": self.wide_outputs,
                 "fragment_id": b"query",
             }
             result, final_model = self.compiler.compile_inner(query, options)
-            export_info, query_txn = result
-            query_txn = txn_with_configure(query_txn, configure)
-            compile_span["lqp"] = lqp_print.to_string(query_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
+            export_info, query_epoch = result
 
-            # Merge the epochs into a single transactions. Long term the query bits should all
-            # go into a WhatIf action and the intrinsics could be fused with either of them. But
-            # for now we just use separate epochs.
-            epochs = []
+            epochs.append(query_epoch)
+            epochs.append(self._compile_undefine_query(query_epoch))
 
-            epochs.append(self._compile_intrinsics())
+            txn_span["compile_type"] = "query"
+            txn = mk_transaction(epochs=epochs, configure=configure)
+            txn_span["lqp"] = lqp_print.to_string(txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
 
-            if model_txn is not None:
-                epochs.append(model_txn.epochs[0])
-
-            query_txn_epoch = query_txn.epochs[0]
-            epochs.append(query_txn_epoch)
-            epochs.append(self._compile_undefine_query(query_txn_epoch))
-
-            txn = lqp_ir.Transaction(epochs=epochs, configure=configure, meta=None)
             validate_lqp(txn)
 
             txn_proto = convert_transaction(txn)
-            # TODO (azreika): Should export_info be encoded as part of the txn_proto? [RAI-40312]
             return final_model, export_info, txn_proto
 
     # TODO (azreika): This should probably be split up into exporting and other processing. There are quite a lot of arguments here...
@@ -462,12 +456,3 @@ class LQPExecutor(e.Executor):
             # If processing the results failed, revert to the previous model.
             self._last_model = previous_model
             raise e
-
-def txn_with_configure(txn: lqp_ir.Transaction, configure: lqp_ir.Configure) -> lqp_ir.Transaction:
-    """ Return a new transaction with the given configure. If the transaction already has
-    a configure, it is replaced. """
-    return lqp_ir.Transaction(
-        epochs=txn.epochs,
-        configure=configure,
-        meta=txn.meta,
-    )
relationalai/semantics/lqp/intrinsics.py
@@ -1,15 +1,16 @@
-from datetime import datetime, timezone
+from datetime import datetime
 
 from relationalai.semantics.lqp import ir as lqp
 from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, mk_type, mk_primitive
 from relationalai.semantics.lqp.utils import lqp_hash
 
-def mk_intrinsic_datetime_now() -> lqp.Def:
+# Constructs a definition of the current datetime.
+def mk_intrinsic_datetime_now(dt: datetime) -> lqp.Def:
     """Constructs a definition of the current datetime."""
     id = lqp_hash("__pyrel_lqp_intrinsic_datetime_now")
     out = lqp.Var(name="out", meta=None)
     out_type = mk_type(lqp.TypeName.DATETIME)
-    now = mk_value(lqp.DateTimeValue(value=datetime.now(timezone.utc), meta=None))
+    now = mk_value(lqp.DateTimeValue(value=dt, meta=None))
     datetime_now = mk_abstraction(
         [(out, out_type)],
         mk_primitive("rel_primitive_eq", [out, now]),
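
With the timestamp injected, the intrinsic becomes deterministic under test. A sketch, assuming the module above is importable:

```python
from datetime import datetime, timezone
from relationalai.semantics.lqp.intrinsics import mk_intrinsic_datetime_now

pinned = datetime(2024, 1, 1, tzinfo=timezone.utc)
defn_a = mk_intrinsic_datetime_now(pinned)
defn_b = mk_intrinsic_datetime_now(pinned)
# Both definitions embed the same DateTimeValue(pinned), so printed LQP
# snapshots compare equal run to run, whereas datetime.now(timezone.utc)
# previously changed on every compilation.
```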
relationalai/semantics/lqp/model2lqp.py
@@ -19,8 +19,8 @@ from warnings import warn
 import re
 import uuid
 
-""" Main access point. Converts the model IR to an LQP transaction. """
-def to_lqp(model: ir.Model, fragment_name: bytes, ctx: TranslationCtx) -> tuple[Optional[tuple], lqp.Transaction]:
+# Main access point for translating metamodel to lqp. Converts the model IR to an LQP epoch.
+def to_lqp(model: ir.Model, fragment_name: bytes, ctx: TranslationCtx) -> tuple[Optional[tuple], lqp.Epoch]:
     assert_valid_input(model)
     decls: list[lqp.Declaration] = []
     reads: list[lqp.Read] = []
@@ -50,16 +50,10 @@ def to_lqp(model: ir.Model, fragment_name: bytes, ctx: TranslationCtx) -> tuple[
     fragment = lqp.Fragment(id=fragment_id, declarations=decls, meta=None, debug_info=debug_info)
     define_op = lqp.Define(fragment=fragment, meta=None)
 
-    txn = lqp.Transaction(
-        epochs=[
-            lqp.Epoch(
-                reads=reads,
-                writes=[lqp.Write(write_type=define_op, meta=None)],
-                meta=None
-            )
-        ],
-        configure=lqp.construct_configure({}, None),
-        meta=None,
+    txn = lqp.Epoch(
+        reads=reads,
+        writes=[lqp.Write(write_type=define_op, meta=None)],
+        meta=None
     )
 
     return (export_info, txn)
relationalai/semantics/lqp/passes.py
@@ -6,8 +6,8 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
 
 from relationalai.semantics.metamodel.rewrite import Flatten
 
-from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals
-from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter
+from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
+from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter, SplitMultiCheckRequires
 
 from relationalai.semantics.lqp.utils import output_names
 
@@ -18,6 +18,7 @@ import hashlib
 
 def lqp_passes() -> list[Pass]:
     return [
+        SplitMultiCheckRequires(),
         FunctionAnnotations(),
         DischargeConstraints(),
         Checker(),
@@ -27,6 +28,7 @@ def lqp_passes() -> list[Pass]:
         DNFUnionSplitter(),
         ExtractKeys(),
         ExtractCommon(),
+        FormatOutputs(),
         Flatten(),
         Splinter(), # Splits multi-headed rules into multiple rules
         QuantifyVars(), # Adds missing existentials
relationalai/semantics/lqp/rewrite/__init__.py
@@ -1,7 +1,7 @@
 from .cdc import CDC
 from .extract_common import ExtractCommon
 from .extract_keys import ExtractKeys
-from .function_annotations import FunctionAnnotations
+from .function_annotations import FunctionAnnotations, SplitMultiCheckRequires
 from .quantify_vars import QuantifyVars
 from .splinter import Splinter
 
@@ -12,4 +12,5 @@ __all__ = [
     "FunctionAnnotations",
     "QuantifyVars",
     "Splinter",
+    "SplitMultiCheckRequires",
 ]
relationalai/semantics/lqp/rewrite/function_annotations.py
@@ -1,79 +1,114 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
-from relationalai.semantics.metamodel import ir, compiler as c, visitor as v, builtins
-from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
+from typing import Optional
+from relationalai.semantics.metamodel import builtins
+from relationalai.semantics.metamodel.ir import (
+    Node, Model, Require, Logical, Relation, Annotation, Update
+)
+from relationalai.semantics.metamodel.compiler import Pass
+from relationalai.semantics.metamodel.visitor import Rewriter, Visitor
+from relationalai.semantics.lqp.rewrite.functional_dependencies import (
+    is_valid_unique_constraint, normalized_fd
+)
 
+# In future iterations of the PyRel metamodel, `Require` nodes will have a single `Check`
+# (and no `errors`). Currently, however, unique constraints may result in multiple
+# `Check` nodes, and for simplicity we split them into separate `Require` nodes. This
+# step will be removed in the future.
+#
+# Note that unique constraints always have an empty `domain`, so we apply the splitting
+# only to such `Require` nodes.
+class SplitMultiCheckRequires(Pass):
+    """
+    Splits unique Require nodes that have an empty domain but multiple checks into
+    multiple Require nodes with a single check each.
+    """
+
+    def rewrite(self, model: Model, options: dict = {}) -> Model:
+        return SplitMultiCheckRequiresRewriter().walk(model)
+
+
+class SplitMultiCheckRequiresRewriter(Rewriter):
+    """
+    Splits unique Require nodes that have an empty domain but multiple checks into
+    multiple Require nodes with a single check each.
+    """
+    def handle_require(self, node: Require, parent: Node):
 
-class FunctionAnnotations(c.Pass):
+        if isinstance(node.domain, Logical) and not node.domain.body and len(node.checks) > 1:
+            require_nodes = []
+            for check in node.checks:
+                single_check = self.walk(check, node)
+                require_nodes.append(
+                    node.reconstruct(node.engine, node.domain, (single_check,), node.annotations)
+                )
+            return require_nodes
+
+        return node
+
+
+class FunctionAnnotations(Pass):
     """
-    Pass marks all appropriate relations with `function` annotation.
-    Criteria:
-    - there is a Require node with `unique` builtin (appeared as a result of `require(unique(...))`)
-    - `unique` declared for all the fields in a derived relation expect the last
+    Marks all appropriate relations with the `function` annotation. Collects functional
+    dependencies from unique Require nodes and uses this information to identify
+    functional relations.
     """
 
-    def rewrite(self, model: ir.Model, options: dict = {}) -> ir.Model:
-        collect_fd = CollectFunctionalRelationsVisitor()
-        new_model = collect_fd.walk(model)
-        # mark relations collected by previous visitor with `@function` annotation
-        return FunctionalAnnotationsVisitor(collect_fd.functional_relations).walk(new_model)
+    def rewrite(self, model: Model, options: dict = {}) -> Model:
+        collect_fds = CollectFDsVisitor()
+        collect_fds.visit_model(model, None)
+        annotated_model = FunctionalAnnotationsRewriter(collect_fds.functional_relations).walk(model)
+        return annotated_model
 
 
-@dataclass
-class CollectFunctionalRelationsVisitor(v.Rewriter):
+class CollectFDsVisitor(Visitor):
     """
-    Visitor collects all relations which should be marked with `functional` annotation.
+    Visitor collects all unique constraints.
     """
 
+    # Currently, only information about k-functional FDs is collected.
     def __init__(self):
         super().__init__()
-        self.functional_relations = ordered_set()
-
-    def handle_check(self, node: ir.Check, parent: ir.Node):
-        check = self.walk(node.check, node)
-        assert isinstance(check, ir.Logical)
-        unique_vars = []
-        for item in check.body:
-            # collect vars from `unique` builtin
-            if isinstance(item, ir.Lookup) and item.relation.name == builtins.unique.name:
-                var_set = set()
-                for vargs in item.args:
-                    assert isinstance(vargs, tuple)
-                    var_set.update(vargs)
-                unique_vars.append(var_set)
-        functional_rel = []
-        # mark relations as functional when at least 1 `unique` builtin
-        if len(unique_vars) > 0:
-            for item in check.body:
-                if isinstance(item, ir.Lookup) and not item.relation.name == builtins.unique.name:
-                    for var_set in unique_vars:
-                        # when unique declared for all the fields except the last one in the relation mark it as functional
-                        if var_set == set(item.args[:-1]):
-                            functional_rel.append(item.relation)
-
-        self.functional_relations.update(functional_rel)
-        return node.reconstruct(check, node.error, node.annotations)
-
-
-@dataclass
-class FunctionalAnnotationsVisitor(v.Rewriter):
+        self.functional_relations: dict[Relation, int] = {}
+
+    def visit_require(self, node: Require, parent: Optional[Node]):
+        if is_valid_unique_constraint(node):
+            fd = normalized_fd(node)
+            assert fd is not None
+            if fd.is_structural:
+                relation = fd.structural_relation
+                k = fd.structural_rank
+                current_k = self.functional_relations.get(relation, 0)
+                self.functional_relations[relation] = max(current_k, k)
+
+
+class FunctionalAnnotationsRewriter(Rewriter):
     """
-    This visitor marks functional_relations with `function` annotation.
+    This visitor marks functional_relations with the `@function(:checked [, k])` annotation.
     """
 
-    def __init__(self, functional_relations: OrderedSet):
+    def __init__(self, functional_relations: dict[Relation, int]):
         super().__init__()
-        self._functional_relations = functional_relations
+        self.functional_relations = functional_relations
+
+    def get_functional_annotation(self, rel: Relation) -> Optional[Annotation]:
+        k = self.functional_relations.get(rel, None)
+        if k is None:
+            return None
+        if k == 1:
+            return builtins.function_checked_annotation
+        return builtins.function_ranked_checked_annotation(k)
 
-    def handle_relation(self, node: ir.Relation, parent: ir.Node):
-        if node in self._functional_relations:
-            return node.reconstruct(node.name, node.fields, node.requires, node.annotations | [builtins.function_checked_annotation],
-                node.overloads)
+    def handle_relation(self, node: Relation, parent: Node):
+        function_annotation = self.get_functional_annotation(node)
+        if function_annotation:
+            return node.reconstruct(node.name, node.fields, node.requires,
+                node.annotations | [function_annotation], node.overloads)
         return node.reconstruct(node.name, node.fields, node.requires, node.annotations, node.overloads)
 
-    def handle_update(self, node: ir.Update, parent: ir.Node):
-        if node.relation in self._functional_relations:
+    def handle_update(self, node: Update, parent: Node):
+        function_annotation = self.get_functional_annotation(node.relation)
+        if function_annotation:
             return node.reconstruct(node.engine, node.relation, node.args, node.effect,
-                node.annotations | [builtins.function_checked_annotation])
+                node.annotations | [function_annotation])
         return node.reconstruct(node.engine, node.relation, node.args, node.effect, node.annotations)
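
To make the rank-to-annotation mapping concrete, a standalone sketch (strings stand in for the `builtins.*` annotation objects; the dict stands in for the FDs collected by `CollectFDsVisitor`):

```python
functional_relations = {"person_age": 1, "top3_scores": 3}

def pick_annotation(rel: str):
    # k == 1: plain checked function annotation; k > 1: the annotation carries the rank
    k = functional_relations.get(rel)
    if k is None:
        return None
    return "@function(:checked)" if k == 1 else f"@function(:checked, {k})"

assert pick_annotation("person_age") == "@function(:checked)"
assert pick_annotation("top3_scores") == "@function(:checked, 3)"
assert pick_annotation("unknown") is None
```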