relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Type
3
3
 
4
4
  from relationalai.util.structures import OrderedSet
5
5
  from .metamodel import (
6
- Node, Model, NumberType, ScalarType, Task, Logical, Sequence, Union, Match, Until, Wait, Loop, Break, Require,
6
+ Node, Model, Relation, UnresolvedRelation, NumberType, ScalarType, Logical, Sequence, Union, Match, Until, Wait, Loop, Break, Require,
7
7
  Not, Exists, Lookup, Update, Aggregate, Construct, Var, Annotation, Field, Literal, UnionType, ListType, TupleType
8
8
  )
9
9
  import datetime as _dt
@@ -17,7 +17,11 @@ NO_WALK = object()
17
17
 
18
18
  WALK_FIELDS: Dict[Type[Node], tuple[str, ...]] = {
19
19
  # Top-level
20
- Model: ("root",),
20
+ Model: ("root", "relations",),
21
+
22
+ # Relation
23
+ Relation: ("fields",),
24
+ UnresolvedRelation: ("fields",),
21
25
 
22
26
  # Control flow
23
27
  Logical: ("body",),
@@ -131,7 +135,7 @@ class Walker:
131
135
  ns = {"NO_WALK": NO_WALK}
132
136
  lines = ["def walker(self, node):"]
133
137
  lines.append(f" no_walk = self.{enter_name}(node)")
134
- lines.append(f" if no_walk is not NO_WALK:")
138
+ lines.append(" if no_walk is not NO_WALK:")
135
139
  lines.extend(f" self._walk(node.{fld})" for fld in WALK_FIELDS.get(node, ()))
136
140
  lines.append(f" self.{exit_name}(node)")
137
141
  exec("\n".join(lines), ns)
@@ -240,8 +244,8 @@ class Rewriter:
240
244
 
241
245
  # pre-order hook: enter_<name>(node) -> optional replacement
242
246
  lines.append(f" tmp = self.{enter_name}(node)")
243
- lines.append(f" if tmp is not None and tmp is not node:")
244
- lines.append(f" new = tmp")
247
+ lines.append(" if tmp is not None and tmp is not node:")
248
+ lines.append(" new = tmp")
245
249
  # rewrite children in WALK_FIELDS, track changes
246
250
  lines.extend(f" v_{fld} = self._rewrite(new.{fld})" for fld in fields)
247
251
 
@@ -253,8 +257,8 @@ class Rewriter:
253
257
 
254
258
  # post-order hook: <name>(node) -> optional replacement
255
259
  lines.append(f" tmp = self.{exit_name}(new)")
256
- lines.append(f" if tmp is not None and tmp is not new:")
257
- lines.append(f" new = tmp")
260
+ lines.append(" if tmp is not None and tmp is not new:")
261
+ lines.append(" new = tmp")
258
262
 
259
263
  lines.append(" return new")
260
264
 
@@ -10,9 +10,36 @@ from ...util.naming import sanitize
10
10
  from ...util.structures import OrderedSet
11
11
 
12
12
  from . import metamodel as mm, builtins as bt
13
- from .rewriter import Walker, Rewriter, NO_WALK
13
+ from .rewriter import Walker, Rewriter
14
14
  from .builtins import builtins as b
15
15
 
16
+ from contextlib import contextmanager
17
+
18
+
19
+ #--------------------------------------------------
20
+ # Typer Testing
21
+ #--------------------------------------------------
22
+ COLLECTED_ERRORS = None
23
+ @contextmanager
24
+ def errors(error_type):
25
+ """
26
+ Check that this type error is raised within the context.
27
+
28
+ :param error_type: the type of the expected error.
29
+ """
30
+ global COLLECTED_ERRORS
31
+ COLLECTED_ERRORS = []
32
+ try:
33
+ yield
34
+ except Exception:
35
+ # we expect the typer to raise an error
36
+ pass
37
+ finally:
38
+ if not any(isinstance(error, error_type) for error in COLLECTED_ERRORS):
39
+ exc("ExpectedWarning", f"Expected warning of type {error_type.__name__} but none was raised.")
40
+ COLLECTED_ERRORS = None
41
+
42
+
16
43
  #--------------------------------------------------
17
44
  # Typer
18
45
  #--------------------------------------------------
@@ -59,7 +86,7 @@ class Typer:
59
86
  Analyzer(net).analyze(node)
60
87
 
61
88
  # propagate the types through the network
62
- with tracing.span("typer.propagate") as span:
89
+ with tracing.span("typer.propagate"): # as span:
63
90
  net.propagate()
64
91
  # span["type_graph"] = net.to_mermaid()
65
92
 
@@ -68,12 +95,18 @@ class Typer:
68
95
  replacer = Replacer(net)
69
96
  final = replacer.rewrite(node)
70
97
 
71
- # report any errors found during typing
98
+ # report any errors found during type inference
72
99
  if net.errors:
73
- for error in net.errors:
74
- error.report()
75
- if self.enforce:
76
- exc("TyperError", "Type errors detected during type inference.")
100
+ if COLLECTED_ERRORS is not None:
101
+ # collecting errors for tests, just add them to the list
102
+ COLLECTED_ERRORS.extend(net.errors)
103
+ # raise to avoid further processing
104
+ raise Exception("Typer errors collected.")
105
+ else:
106
+ for error in net.errors:
107
+ error.report()
108
+ if self.enforce:
109
+ exc("TyperError", "Type errors detected during type inference.")
77
110
  return final
78
111
 
79
112
 
@@ -146,10 +179,8 @@ class PropagationNetwork():
146
179
  # TODO - consider renaming this to UnresolvedReference
147
180
  self.errors.append(UnresolvedOverload(node, [self.resolve(a) for a in node.args]))
148
181
 
149
- def unresolved_type(self, node: Node):
150
- # TODO - this is not being used yet, we need a pass at the end to check for any node
151
- # that we could not resolve
152
- self.errors.append(UnresolvedType(node))
182
+ def unresolved_type(self, node: mm.Lookup|mm.Aggregate, arg: mm.Var):
183
+ self.errors.append(UnresolvedType(node, arg))
153
184
 
154
185
  def has_errors(self, node: Node) -> bool:
155
186
  for mismatch in self.errors:
@@ -259,12 +290,12 @@ class PropagationNetwork():
259
290
  # start with the loaded roots + all literals + sources of edges without back edges
260
291
  work_list.extend(self.loaded_roots)
261
292
  for source in self.edges.keys():
262
- if isinstance(source, (mm.Literal)) or not source in self.back_edges:
293
+ if isinstance(source, (mm.Literal)) or source not in self.back_edges:
263
294
  work_list.append(source)
264
295
 
265
296
  # limit the number of iterations to avoid infinite loops
266
297
  i = 0
267
- max_iterations = 100 * len(self.edges)
298
+ max_iterations = 100 * (len(self.edges) + len(work_list))
268
299
 
269
300
  # propagate types until we reach a fixed point
270
301
  while work_list:
@@ -388,7 +419,7 @@ class PropagationNetwork():
388
419
  # this is only attempted if all input types match the field types, i.e. no
389
420
  # conversions are needed
390
421
  resolved_fields = types
391
- if bt.is_function(relation) and len(set(resolved_fields)) == 1 and not relation in self.NON_TYPE_PRESERVERS and\
422
+ if bt.is_function(relation) and len(set(resolved_fields)) == 1 and relation not in self.NON_TYPE_PRESERVERS and\
392
423
  all(type_matches(arg_type, field_type) for arg_type, field_type, field in zip(resolved_args, types, relation.fields) if field.input):
393
424
 
394
425
  input_types = set([arg_type for field, arg_type
@@ -409,6 +440,9 @@ class PropagationNetwork():
409
440
  for field, field_type, arg in zip(relation.fields, resolved_fields, task.args):
410
441
  if not field.input and isinstance(arg, mm.Var):
411
442
  self.add_resolved_type(arg, field_type)
443
+ # this can be the way to learn the specialized number type for the field
444
+ if bt.is_abstract(field.type):
445
+ self.add_resolved_type(field, field_type)
412
446
 
413
447
  elif b.core.TypeVar in types:
414
448
  # this relation contains type vars, so we have to make sure that all args
@@ -446,13 +480,27 @@ class PropagationNetwork():
446
480
 
447
481
  else:
448
482
  # no typevar or number specialization shenanigans, just propagate field types to args
483
+ is_population_lookup = isinstance(relation, mm.TypeNode)
449
484
  for field, field_type, arg, arg_type in zip(relation.fields, resolved_fields, task.args, resolved_args):
450
- # add a resolved type if we learned more about the arg's type
451
- if isinstance(arg, mm.Var) and (bt.is_abstract(arg_type) or bt.extends(field_type, arg_type)):
452
- self.add_resolved_type(arg, field_type)
453
- elif not type_matches(arg_type, field_type) and not conversion_allowed(arg_type, field_type):
454
- self.type_mismatch(task, field_type, arg_type)
455
-
485
+ # if the arg does not match the declared field type, and cannot be converted into it, it's a mismatch
486
+ # note that we use the declared field type here, field.type, not the resolved field type, which may
487
+ # have been refined by updates
488
+ if not type_matches(arg_type, field.type, accept_expected_super_types=is_population_lookup) and not conversion_allowed(arg_type, field.type):
489
+ # if the field is an output and the arg is still abstract, we may be able to refine it later,
490
+ # as long as they are compatible, so skip reporting as it will come back later
491
+ # TODO - replace this with the analyzer reversing an edge for this case
492
+ if not field.input and bt.is_abstract(arg_type) and (type_matches(field.type, arg_type, accept_expected_super_types=is_population_lookup) or conversion_allowed(field.type, arg_type)):
493
+ continue
494
+ self.type_mismatch(task, field.type, arg_type)
495
+ elif isinstance(arg, mm.Var):
496
+ if bt.is_abstract(field_type) and bt.is_abstract(arg_type):
497
+ # if the resolved field type is still abstract, we cannot resolve the type of the arg
498
+ self.unresolved_type(task, arg)
499
+ elif bt.is_abstract(arg_type) or bt.extends(field_type, arg_type):
500
+ # we learned something about this arg, so store it
501
+ self.add_resolved_type(arg, field_type)
502
+ elif not type_matches(arg_type, field.type) and is_population_lookup:
503
+ self.add_resolved_type(arg, field_type)
456
504
 
457
505
  # we try to preserve types for relations that are functions (i.e. potentially multiple
458
506
  # input but a single output) and where all types are the same. However, there are some
@@ -526,7 +574,8 @@ class PropagationNetwork():
526
574
  if isinstance(x, mm.NumberType):
527
575
  if number is None or x.scale > number.scale or (x.scale == number.scale and x.precision > number.precision):
528
576
  number = x
529
- assert(number is not None)
577
+ if number is None:
578
+ number = b.core.DefaultNumber
530
579
  return number, [number if bt.is_number(t) else t for t in field_types]
531
580
 
532
581
 
@@ -537,15 +586,15 @@ class PropagationNetwork():
537
586
  # draw the network as a mermaid graph for the debugger
538
587
  def to_mermaid(self, max_edges=500) -> str:
539
588
  # add links for edges while collecting nodes
540
- nodes = OrderedSet()
589
+ nodes = dict()
541
590
  link_strs = []
542
591
  # edges
543
592
  for src, dsts in self.edges.items():
544
- nodes.add(src)
593
+ nodes[src.id] = src
545
594
  for dst in dsts:
546
595
  if len(link_strs) > max_edges:
547
596
  break
548
- nodes.add(dst)
597
+ nodes[dst.id] = dst
549
598
  link_strs.append(f"n{src.id} --> n{dst.id}")
550
599
  if len(link_strs) > max_edges:
551
600
  break
@@ -568,7 +617,7 @@ class PropagationNetwork():
568
617
 
569
618
  resolved = self.resolved_types
570
619
  node_strs = []
571
- for node in nodes:
620
+ for _, node in nodes.items():
572
621
  klass = ""
573
622
  if isinstance(node, mm.Var):
574
623
  ir_type = resolved.get(node) or self.resolve(node)
@@ -642,10 +691,6 @@ class Analyzer(Walker):
642
691
  def analyze(self, node: mm.Node):
643
692
  self(node)
644
693
 
645
- # TODO - ignoring requires for now because the typing of constraints seems incorrect
646
- def enter_require(self, require: mm.Require):
647
- return NO_WALK
648
-
649
694
  def compute_potential_targets(self, relation: mm.Relation):
650
695
  # register potential targets for placeholders
651
696
  if bt.is_placeholder(relation):
@@ -681,7 +726,7 @@ class Analyzer(Walker):
681
726
  is_placeholder = bt.is_placeholder(relation)
682
727
  for field, arg in zip(relation.fields, task.args):
683
728
  if isinstance(arg, (mm.Var, mm.Literal)):
684
- if field.input:
729
+ if field.input or isinstance(arg, mm.Literal):
685
730
  # we need to resolve all inputs before resolving the relation
686
731
  self.net.add_edge(arg, task)
687
732
  else:
@@ -708,9 +753,34 @@ class Replacer(Rewriter):
708
753
  def __init__(self, net:PropagationNetwork):
709
754
  super().__init__()
710
755
  self.net = net
756
+ # map from relation id to rewritten relation, to make sure that lookups/updates
757
+ # point to the correct object in case we refined the relation field types
758
+ self.relations:dict[int, mm.Relation] = {}
759
+ # logicals created during rewriting (to inline only the ones we created)
760
+ self.logicals = set()
761
+
762
+ def rewrite(self, node: T) -> T:
763
+ try:
764
+ if isinstance(node, mm.Model):
765
+ # first rewrite all relations to update their field types
766
+ relations = self(node.relations) # type: ignore
767
+ assert isinstance(relations, tuple)
768
+ # index the rewritten relations
769
+ self.relations = {r.id: r for r in relations}
770
+ # then rewrite the root with the updated relations
771
+ root = self(node.root) # type: ignore
772
+ return node.mut(relations = relations, root = root)
773
+ else:
774
+ return self(node) # type: ignore
775
+ finally:
776
+ self.relations = {}
777
+ self.logicals = set()
711
778
 
712
- def rewrite(self, model: T) -> T:
713
- return self(model) # type: ignore
779
+ def wrap(self, body: tuple) -> mm.Logical:
780
+ """ Wrap the given body in a logical and track it for inlining later. """
781
+ logical = mm.Logical(body)
782
+ self.logicals.add(logical.id)
783
+ return logical
714
784
 
715
785
  def logical(self, logical: mm.Logical):
716
786
  if len(logical.body) == 0:
@@ -718,7 +788,7 @@ class Replacer(Rewriter):
718
788
  # inline logicals that are just there to group other nodes during rewrite
719
789
  body = []
720
790
  for child in logical.body:
721
- if isinstance(child, mm.Logical) and not child.optional and not child.scope:
791
+ if isinstance(child, mm.Logical) and not child.optional and not child.scope and child.id in self.logicals:
722
792
  body.extend(child.body)
723
793
  else:
724
794
  body.append(child)
@@ -729,8 +799,6 @@ class Replacer(Rewriter):
729
799
  #--------------------------------------------------
730
800
 
731
801
  def field(self, node: mm.Field):
732
- # TODO - this is only modifying the relation in the model, but then we have a new
733
- # relation there, which is different than the object referenced by tasks.
734
802
  if node in self.net.resolved_types:
735
803
  return mm.Field(node.name, self.net.resolved_types[node], node.input, _relation = node._relation)
736
804
  return node
@@ -781,7 +849,7 @@ class Replacer(Rewriter):
781
849
  args = get_lookup_args(node, target)
782
850
  types = [f.type for f in get_relation_fields(resolved_relations[0], node.relation.name)]
783
851
  # adding this logical to avoid issues in the old backend
784
- branches.append(mm.Logical((self.convert_arguments(node, target, args, types=types, force_copy=True),)))
852
+ branches.append(self.wrap((self.convert_arguments(node, target, args, types=types, force_copy=True),)))
785
853
  return mm.Union(tuple(branches))
786
854
 
787
855
  def convert_arguments(self, node: mm.Lookup|mm.Update, relation: mm.Relation, args: Iterable[mm.Value]|None=None, types: Iterable[mm.Type]|None=None, force_copy=False) -> mm.Logical|mm.Lookup|mm.Update:
@@ -789,6 +857,9 @@ class Replacer(Rewriter):
789
857
  have these types. Convert any arguments as needed and return a new node with the
790
858
  proper relation and converted args. If multiple conversions are needed, return a
791
859
  logical that contains all the conversion tasks plus the final node. """
860
+ # ensure we use the rewritten relation if available
861
+ if relation.id in self.relations:
862
+ relation = self.relations[relation.id]
792
863
  args = args or node.args
793
864
  types = types or [self.net.resolve(f) for f in relation.fields]
794
865
  number_type = self.net.resolved_number.get(node.id)
@@ -820,7 +891,7 @@ class Replacer(Rewriter):
820
891
  # if we need conversion tasks, wrap in a logical
821
892
  if len(tasks) == 1:
822
893
  return tasks[0]
823
- return mm.Logical(tuple(tasks))
894
+ return self.wrap(tuple(tasks))
824
895
 
825
896
  def visit_eq_lookup(self, node: mm.Lookup):
826
897
  (left, right) = node.args
@@ -842,7 +913,7 @@ class Replacer(Rewriter):
842
913
  return node
843
914
 
844
915
  tasks.append(mm.Lookup(b.core.eq, tuple(final_args)))
845
- return mm.Logical(tuple(tasks))
916
+ return self.wrap(tuple(tasks))
846
917
 
847
918
  #--------------------------------------------------
848
919
  # Helpers
@@ -929,6 +1000,9 @@ def to_type(value: mm.Value|mm.Field|mm.Literal) -> mm.Type:
929
1000
  if isinstance(value, mm.Field):
930
1001
  return b.core.Field
931
1002
 
1003
+ if isinstance(value, mm.Relation):
1004
+ return b.core.Relation
1005
+
932
1006
  if isinstance(value, tuple):
933
1007
  return mm.TupleType(element_types=tuple(to_type(v) for v in value))
934
1008
 
@@ -1187,7 +1261,7 @@ class TypeMismatch(TyperError):
1187
1261
  actual: mm.Type
1188
1262
 
1189
1263
  def message(self) -> str:
1190
- return f"Expected {get_name(self.expected)}, got {get_name(self.actual)}"
1264
+ return f"Expected '{get_name(self.expected)}', got '{get_name(self.actual)}'"
1191
1265
 
1192
1266
  @dataclass
1193
1267
  class InvalidType(TyperError):
@@ -1204,10 +1278,11 @@ class UnresolvedOverload(TyperError):
1204
1278
  assert isinstance(self.node, (mm.Lookup, mm.Update, mm.Aggregate))
1205
1279
  rel = get_relation(self.node)
1206
1280
  types = ', '.join([get_name(t) for t in self.arg_types])
1207
- return f"Unresolved overload: {rel.name}({types})"
1281
+ return f"Unresolved overload: '{rel.name}({types})'"
1208
1282
 
1209
1283
  @dataclass
1210
1284
  class UnresolvedType(TyperError):
1285
+ arg: mm.Var
1211
1286
 
1212
1287
  def message(self) -> str:
1213
- return "Unable to determine concrete type."
1288
+ return f"Unable to determine concrete type for argument '{self.arg.name}'."
@@ -0,0 +1,11 @@
1
+ """
2
+ The RelationalAI Semantics Reasoners Module.
3
+ """
4
+
5
+ # Mark this package's docstrings for inclusion
6
+ # in automatically generated web documentation,
7
+ # by default as early access.
8
+ # TODO: Remove dependency on v0, once this functionality is supported.
9
+ from v0.relationalai.docutils import ProductStage # type: ignore[import-not-found]
10
+ __include_in_docs__ = True
11
+ __rai_product_stage__ = ProductStage.EARLY_ACCESS
@@ -0,0 +1,35 @@
1
+ """
2
+ RelationalAI Graph Library
3
+ """
4
+
5
+ __version__ = "0.0.0"
6
+
7
+ import warnings
8
+ from .core import Graph
9
+
10
+ # Mark this package's docstrings for inclusion
11
+ # in automatically generated web documentation.
12
+ __include_in_docs__ = True
13
+
14
+ # Warn on import that this package is at an early stage of development,
15
+ # intended for internal consumers only, and ask those internal consumers
16
+ # to contact the symbolic reasoning team such that we can track usage,
17
+ # get feedback, and help folks through breaking changes.
18
+ warnings.warn(
19
+ (
20
+ "\n\nThis library is still in early stages of development and is intended "
21
+ "for internal use only. Among other considerations, interfaces will change, "
22
+ "and performance is appropriate only for exploring small graphs. Please "
23
+ "see this package's README for additional information.\n\n"
24
+ "If you are an internal user seeing this, please also contact "
25
+ "the symbolic reasoning team such that we can track usage, get "
26
+ "feedback, and help you through breaking changes.\n"
27
+ ),
28
+ FutureWarning,
29
+ stacklevel=2
30
+ )
31
+
32
+ # Finally make this package's core functionality publicly available.
33
+ __all__ = [
34
+ "Graph",
35
+ ]