relationalai-1.0.0a3-py3-none-any.whl → relationalai-1.0.0a5-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- relationalai/config/config.py +47 -21
- relationalai/config/connections/__init__.py +5 -2
- relationalai/config/connections/duckdb.py +2 -2
- relationalai/config/connections/local.py +31 -0
- relationalai/config/connections/snowflake.py +0 -1
- relationalai/config/external/raiconfig_converter.py +235 -0
- relationalai/config/external/raiconfig_models.py +202 -0
- relationalai/config/external/utils.py +31 -0
- relationalai/config/shims.py +1 -0
- relationalai/semantics/__init__.py +10 -8
- relationalai/semantics/backends/sql/sql_compiler.py +1 -4
- relationalai/semantics/experimental/__init__.py +0 -0
- relationalai/semantics/experimental/builder.py +295 -0
- relationalai/semantics/experimental/builtins.py +154 -0
- relationalai/semantics/frontend/base.py +67 -42
- relationalai/semantics/frontend/core.py +34 -6
- relationalai/semantics/frontend/front_compiler.py +209 -37
- relationalai/semantics/frontend/pprint.py +6 -2
- relationalai/semantics/metamodel/__init__.py +7 -0
- relationalai/semantics/metamodel/metamodel.py +2 -0
- relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
- relationalai/semantics/metamodel/pprint.py +6 -1
- relationalai/semantics/metamodel/rewriter.py +11 -7
- relationalai/semantics/metamodel/typer.py +116 -41
- relationalai/semantics/reasoners/__init__.py +11 -0
- relationalai/semantics/reasoners/graph/__init__.py +35 -0
- relationalai/semantics/reasoners/graph/core.py +9028 -0
- relationalai/semantics/std/__init__.py +30 -10
- relationalai/semantics/std/aggregates.py +641 -12
- relationalai/semantics/std/common.py +146 -13
- relationalai/semantics/std/constraints.py +71 -1
- relationalai/semantics/std/datetime.py +904 -21
- relationalai/semantics/std/decimals.py +143 -2
- relationalai/semantics/std/floats.py +57 -4
- relationalai/semantics/std/integers.py +98 -4
- relationalai/semantics/std/math.py +857 -35
- relationalai/semantics/std/numbers.py +216 -20
- relationalai/semantics/std/re.py +213 -5
- relationalai/semantics/std/strings.py +437 -44
- relationalai/shims/executor.py +60 -52
- relationalai/shims/fixtures.py +85 -0
- relationalai/shims/helpers.py +26 -2
- relationalai/shims/hoister.py +28 -9
- relationalai/shims/mm2v0.py +204 -173
- relationalai/tools/cli/cli.py +192 -10
- relationalai/tools/cli/components/progress_reader.py +1 -1
- relationalai/tools/cli/docs.py +394 -0
- relationalai/tools/debugger.py +11 -4
- relationalai/tools/qb_debugger.py +435 -0
- relationalai/tools/typer_debugger.py +1 -2
- relationalai/util/dataclasses.py +3 -5
- relationalai/util/docutils.py +1 -2
- relationalai/util/error.py +2 -5
- relationalai/util/python.py +23 -0
- relationalai/util/runtime.py +1 -2
- relationalai/util/schema.py +2 -4
- relationalai/util/structures.py +4 -2
- relationalai/util/tracing.py +8 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
- v0/relationalai/__init__.py +1 -1
- v0/relationalai/clients/client.py +52 -18
- v0/relationalai/clients/exec_txn_poller.py +122 -0
- v0/relationalai/clients/local.py +23 -8
- v0/relationalai/clients/resources/azure/azure.py +36 -11
- v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
- v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
- v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
- v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
- v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
- v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
- v0/relationalai/clients/types.py +5 -0
- v0/relationalai/errors.py +19 -1
- v0/relationalai/semantics/lqp/algorithms.py +173 -0
- v0/relationalai/semantics/lqp/builtins.py +199 -2
- v0/relationalai/semantics/lqp/executor.py +68 -37
- v0/relationalai/semantics/lqp/ir.py +28 -2
- v0/relationalai/semantics/lqp/model2lqp.py +215 -45
- v0/relationalai/semantics/lqp/passes.py +13 -658
- v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
- v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
- v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
- v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
- v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
- v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
- v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
- v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
- v0/relationalai/semantics/lqp/utils.py +11 -1
- v0/relationalai/semantics/lqp/validators.py +14 -1
- v0/relationalai/semantics/metamodel/builtins.py +2 -1
- v0/relationalai/semantics/metamodel/compiler.py +2 -1
- v0/relationalai/semantics/metamodel/dependency.py +12 -3
- v0/relationalai/semantics/metamodel/executor.py +11 -1
- v0/relationalai/semantics/metamodel/factory.py +2 -2
- v0/relationalai/semantics/metamodel/helpers.py +7 -0
- v0/relationalai/semantics/metamodel/ir.py +3 -2
- v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
- v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
- v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
- v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
- v0/relationalai/semantics/metamodel/visitor.py +4 -3
- v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
- v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
- v0/relationalai/semantics/rel/compiler.py +2 -1
- v0/relationalai/semantics/rel/executor.py +3 -2
- v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
- v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
- v0/relationalai/tools/cli.py +339 -186
- v0/relationalai/tools/cli_controls.py +216 -67
- v0/relationalai/tools/cli_helpers.py +410 -6
- v0/relationalai/util/format.py +5 -2
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
relationalai/semantics/metamodel/rewriter.py

@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Type
 
 from relationalai.util.structures import OrderedSet
 from .metamodel import (
-    Node, Model, NumberType, ScalarType,
+    Node, Model, Relation, UnresolvedRelation, NumberType, ScalarType, Logical, Sequence, Union, Match, Until, Wait, Loop, Break, Require,
     Not, Exists, Lookup, Update, Aggregate, Construct, Var, Annotation, Field, Literal, UnionType, ListType, TupleType
 )
 import datetime as _dt
@@ -17,7 +17,11 @@ NO_WALK = object()
 
 WALK_FIELDS: Dict[Type[Node], tuple[str, ...]] = {
     # Top-level
-    Model: ("root",),
+    Model: ("root", "relations",),
+
+    # Relation
+    Relation: ("fields",),
+    UnresolvedRelation: ("fields",),
 
     # Control flow
     Logical: ("body",),
@@ -131,7 +135,7 @@ class Walker:
         ns = {"NO_WALK": NO_WALK}
         lines = ["def walker(self, node):"]
         lines.append(f"    no_walk = self.{enter_name}(node)")
-        lines.append(
+        lines.append("    if no_walk is not NO_WALK:")
         lines.extend(f"        self._walk(node.{fld})" for fld in WALK_FIELDS.get(node, ()))
         lines.append(f"    self.{exit_name}(node)")
         exec("\n".join(lines), ns)
@@ -240,8 +244,8 @@ class Rewriter:
 
         # pre-order hook: enter_<name>(node) -> optional replacement
         lines.append(f"    tmp = self.{enter_name}(node)")
-        lines.append(
-        lines.append(
+        lines.append("    if tmp is not None and tmp is not node:")
+        lines.append("        new = tmp")
         # rewrite children in WALK_FIELDS, track changes
         lines.extend(f"    v_{fld} = self._rewrite(new.{fld})" for fld in fields)
 
@@ -253,8 +257,8 @@ class Rewriter:
 
         # post-order hook: <name>(node) -> optional replacement
         lines.append(f"    tmp = self.{exit_name}(new)")
-        lines.append(
-        lines.append(
+        lines.append("    if tmp is not None and tmp is not new:")
+        lines.append("        new = tmp")
 
         lines.append("    return new")
 
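The rewriter changes above extend the generated walkers to visit Model.relations and relation fields, and adjust how the generated enter hooks short-circuit traversal. As a rough illustration of the hook protocol this code implies (subclass Walker, define enter_<nodename> hooks, return NO_WALK to skip a subtree, call the instance on a node to traverse), here is a hedged sketch; the CountLogicals class and the commented-out usage are hypothetical and not part of the package:

# Hypothetical sketch of the Walker hook protocol suggested by the diff above;
# not the library's documented API.
from relationalai.semantics.metamodel.rewriter import Walker, NO_WALK

class CountLogicals(Walker):
    def __init__(self):
        super().__init__()
        self.count = 0

    def enter_logical(self, node):
        # pre-order hook for Logical nodes; returning nothing lets the walker
        # descend into the fields listed in WALK_FIELDS
        self.count += 1

    def enter_require(self, node):
        # returning the NO_WALK sentinel skips this node's children
        return NO_WALK

# usage (building a metamodel Model is out of scope for this sketch):
# walker = CountLogicals()
# walker(model)   # Analyzer.analyze in typer.py invokes its walker the same way
# print(walker.count)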
|
relationalai/semantics/metamodel/typer.py

@@ -10,9 +10,36 @@ from ...util.naming import sanitize
 from ...util.structures import OrderedSet
 
 from . import metamodel as mm, builtins as bt
-from .rewriter import Walker, Rewriter
+from .rewriter import Walker, Rewriter
 from .builtins import builtins as b
 
+from contextlib import contextmanager
+
+
+#--------------------------------------------------
+# Typer Testing
+#--------------------------------------------------
+COLLECTED_ERRORS = None
+@contextmanager
+def errors(error_type):
+    """
+    Check that this type error is raised within the context.
+
+    :param error_type: the type of the expected error.
+    """
+    global COLLECTED_ERRORS
+    COLLECTED_ERRORS = []
+    try:
+        yield
+    except Exception:
+        # we expect the typer to raise an error
+        pass
+    finally:
+        if not any(isinstance(error, error_type) for error in COLLECTED_ERRORS):
+            exc("ExpectedWarning", f"Expected warning of type {error_type.__name__} but none was raised.")
+        COLLECTED_ERRORS = None
+
+
 #--------------------------------------------------
 # Typer
 #--------------------------------------------------
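The errors context manager added above exists to support typer tests: while COLLECTED_ERRORS is active, the Typer appends inference errors to the list instead of reporting them (see the next hunk), and the context fails if no error of the expected type was collected. A hedged sketch of how a test might use it; build_bad_model and run_typer are hypothetical fixtures standing in for however the test suite constructs a model and invokes the typer:

# Hypothetical test sketch; only errors() and the TyperError subclasses
# (e.g. UnresolvedOverload, TypeMismatch) come from the diff itself.
from relationalai.semantics.metamodel.typer import errors, UnresolvedOverload

def test_reports_unresolved_overload(build_bad_model, run_typer):
    with errors(UnresolvedOverload):
        # running the typer over a model with an unresolvable call should
        # collect an UnresolvedOverload; the context manager only reports a
        # failure if no error of that type was collected
        run_typer(build_bad_model())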
@@ -59,7 +86,7 @@ class Typer:
         Analyzer(net).analyze(node)
 
         # propagate the types through the network
-        with tracing.span("typer.propagate") as span:
+        with tracing.span("typer.propagate"): # as span:
             net.propagate()
             # span["type_graph"] = net.to_mermaid()
 
@@ -68,12 +95,18 @@ class Typer:
         replacer = Replacer(net)
         final = replacer.rewrite(node)
 
-        # report any errors found during
+        # report any errors found during type inference
         if net.errors:
-
-
-
-
+            if COLLECTED_ERRORS is not None:
+                # collecting errors for tests, just add them to the list
+                COLLECTED_ERRORS.extend(net.errors)
+                # raise to avoid further processing
+                raise Exception("Typer errors collected.")
+            else:
+                for error in net.errors:
+                    error.report()
+                if self.enforce:
+                    exc("TyperError", "Type errors detected during type inference.")
         return final
 
 
@@ -146,10 +179,8 @@ class PropagationNetwork():
         # TODO - consider renaming this to UnresolvedReference
         self.errors.append(UnresolvedOverload(node, [self.resolve(a) for a in node.args]))
 
-    def unresolved_type(self, node:
-
-    # that we could not resolve
-        self.errors.append(UnresolvedType(node))
+    def unresolved_type(self, node: mm.Lookup|mm.Aggregate, arg: mm.Var):
+        self.errors.append(UnresolvedType(node, arg))
 
     def has_errors(self, node: Node) -> bool:
         for mismatch in self.errors:
@@ -259,12 +290,12 @@ class PropagationNetwork():
         # start with the loaded roots + all literals + sources of edges without back edges
         work_list.extend(self.loaded_roots)
         for source in self.edges.keys():
-            if isinstance(source, (mm.Literal)) or not
+            if isinstance(source, (mm.Literal)) or source not in self.back_edges:
                 work_list.append(source)
 
         # limit the number of iterations to avoid infinite loops
         i = 0
-        max_iterations = 100 * len(self.edges)
+        max_iterations = 100 * (len(self.edges) + len(work_list))
 
         # propagate types until we reach a fixed point
         while work_list:
@@ -388,7 +419,7 @@ class PropagationNetwork():
             # this is only attempted if all input types match the field types, i.e. no
             # conversions are needed
             resolved_fields = types
-            if bt.is_function(relation) and len(set(resolved_fields)) == 1 and not
+            if bt.is_function(relation) and len(set(resolved_fields)) == 1 and relation not in self.NON_TYPE_PRESERVERS and\
                 all(type_matches(arg_type, field_type) for arg_type, field_type, field in zip(resolved_args, types, relation.fields) if field.input):
 
                 input_types = set([arg_type for field, arg_type
@@ -409,6 +440,9 @@ class PropagationNetwork():
                 for field, field_type, arg in zip(relation.fields, resolved_fields, task.args):
                     if not field.input and isinstance(arg, mm.Var):
                         self.add_resolved_type(arg, field_type)
+                        # this can be the way to learn the specialized number type for the field
+                        if bt.is_abstract(field.type):
+                            self.add_resolved_type(field, field_type)
 
             elif b.core.TypeVar in types:
                 # this relation contains type vars, so we have to make sure that all args
@@ -446,13 +480,27 @@ class PropagationNetwork():
 
             else:
                 # no typevar or number specialization shenanigans, just propagate field types to args
+                is_population_lookup = isinstance(relation, mm.TypeNode)
                 for field, field_type, arg, arg_type in zip(relation.fields, resolved_fields, task.args, resolved_args):
-                    #
-
-
-
-
-
+                    # if the arg does not match the declared field type, and cannot be converted into it, it's a mismatch
+                    # note that we use the declared field type here, field.type, not the resolved field type, which may
+                    # have been refined by updates
+                    if not type_matches(arg_type, field.type, accept_expected_super_types=is_population_lookup) and not conversion_allowed(arg_type, field.type):
+                        # if the field is an output and the arg is still abstract, we may be able to refine it later,
+                        # as long as they are compatible, so skip reporting as it will come back later
+                        # TODO - replace this with the analyzer reversing an edge for this case
+                        if not field.input and bt.is_abstract(arg_type) and (type_matches(field.type, arg_type, accept_expected_super_types=is_population_lookup) or conversion_allowed(field.type, arg_type)):
+                            continue
+                        self.type_mismatch(task, field.type, arg_type)
+                    elif isinstance(arg, mm.Var):
+                        if bt.is_abstract(field_type) and bt.is_abstract(arg_type):
+                            # if the resolved field type is still abstract, we cannot resolve the type of the arg
+                            self.unresolved_type(task, arg)
+                        elif bt.is_abstract(arg_type) or bt.extends(field_type, arg_type):
+                            # we learned something about this arg, so store it
+                            self.add_resolved_type(arg, field_type)
+                        elif not type_matches(arg_type, field.type) and is_population_lookup:
+                            self.add_resolved_type(arg, field_type)
 
         # we try to preserve types for relations that are functions (i.e. potentially multiple
         # input but a single output) and where all types are the same. However, there are some
@@ -526,7 +574,8 @@ class PropagationNetwork():
             if isinstance(x, mm.NumberType):
                 if number is None or x.scale > number.scale or (x.scale == number.scale and x.precision > number.precision):
                     number = x
-
+        if number is None:
+            number = b.core.DefaultNumber
         return number, [number if bt.is_number(t) else t for t in field_types]
 
 
@@ -537,15 +586,15 @@ class PropagationNetwork():
     # draw the network as a mermaid graph for the debugger
     def to_mermaid(self, max_edges=500) -> str:
         # add links for edges while collecting nodes
-        nodes =
+        nodes = dict()
         link_strs = []
         # edges
         for src, dsts in self.edges.items():
-            nodes.
+            nodes[src.id] = src
             for dst in dsts:
                 if len(link_strs) > max_edges:
                     break
-                nodes.
+                nodes[dst.id] = dst
                 link_strs.append(f"n{src.id} --> n{dst.id}")
             if len(link_strs) > max_edges:
                 break
@@ -568,7 +617,7 @@ class PropagationNetwork():
 
         resolved = self.resolved_types
         node_strs = []
-        for node in nodes:
+        for _, node in nodes.items():
            klass = ""
            if isinstance(node, mm.Var):
                ir_type = resolved.get(node) or self.resolve(node)
@@ -642,10 +691,6 @@ class Analyzer(Walker):
     def analyze(self, node: mm.Node):
         self(node)
 
-    # TODO - ignoring requires for now because the typing of constraints seems incorrect
-    def enter_require(self, require: mm.Require):
-        return NO_WALK
-
     def compute_potential_targets(self, relation: mm.Relation):
         # register potential targets for placeholders
         if bt.is_placeholder(relation):
@@ -681,7 +726,7 @@ class Analyzer(Walker):
         is_placeholder = bt.is_placeholder(relation)
         for field, arg in zip(relation.fields, task.args):
             if isinstance(arg, (mm.Var, mm.Literal)):
-                if field.input:
+                if field.input or isinstance(arg, mm.Literal):
                     # we need to resolve all inputs before resolving the relation
                     self.net.add_edge(arg, task)
                 else:
@@ -708,9 +753,34 @@ class Replacer(Rewriter):
     def __init__(self, net:PropagationNetwork):
         super().__init__()
         self.net = net
+        # map from relation id to rewritten relation, to make sure that lookups/updates
+        # point to the correct object in case we refined the relation field types
+        self.relations:dict[int, mm.Relation] = {}
+        # logicals created during rewriting (to inline only the ones we created)
+        self.logicals = set()
+
+    def rewrite(self, node: T) -> T:
+        try:
+            if isinstance(node, mm.Model):
+                # first rewrite all relations to update their field types
+                relations = self(node.relations) # type: ignore
+                assert isinstance(relations, tuple)
+                # index the rewritten relations
+                self.relations = {r.id: r for r in relations}
+                # then rewrite the root with the updated relations
+                root = self(node.root) # type: ignore
+                return node.mut(relations = relations, root = root)
+            else:
+                return self(node) # type: ignore
+        finally:
+            self.relations = {}
+            self.logicals = set()
 
-    def
-
+    def wrap(self, body: tuple) -> mm.Logical:
+        """ Wrap the given body in a logical and track it for inlining later. """
+        logical = mm.Logical(body)
+        self.logicals.add(logical.id)
+        return logical
 
     def logical(self, logical: mm.Logical):
         if len(logical.body) == 0:
@@ -718,7 +788,7 @@ class Replacer(Rewriter):
         # inline logicals that are just there to group other nodes during rewrite
         body = []
         for child in logical.body:
-            if isinstance(child, mm.Logical) and not child.optional and not child.scope:
+            if isinstance(child, mm.Logical) and not child.optional and not child.scope and child.id in self.logicals:
                 body.extend(child.body)
             else:
                 body.append(child)
@@ -729,8 +799,6 @@ class Replacer(Rewriter):
     #--------------------------------------------------
 
     def field(self, node: mm.Field):
-        # TODO - this is only modifying the relation in the model, but then we have a new
-        # relation there, which is different than the object referenced by tasks.
         if node in self.net.resolved_types:
             return mm.Field(node.name, self.net.resolved_types[node], node.input, _relation = node._relation)
         return node
@@ -781,7 +849,7 @@ class Replacer(Rewriter):
             args = get_lookup_args(node, target)
             types = [f.type for f in get_relation_fields(resolved_relations[0], node.relation.name)]
             # adding this logical to avoid issues in the old backend
-            branches.append(
+            branches.append(self.wrap((self.convert_arguments(node, target, args, types=types, force_copy=True),)))
         return mm.Union(tuple(branches))
 
     def convert_arguments(self, node: mm.Lookup|mm.Update, relation: mm.Relation, args: Iterable[mm.Value]|None=None, types: Iterable[mm.Type]|None=None, force_copy=False) -> mm.Logical|mm.Lookup|mm.Update:
@@ -789,6 +857,9 @@
         have these types. Convert any arguments as needed and return a new node with the
         proper relation and converted args. If multiple conversions are needed, return a
         logical that contains all the conversion tasks plus the final node. """
+        # ensure we use the rewritten relation if available
+        if relation.id in self.relations:
+            relation = self.relations[relation.id]
         args = args or node.args
         types = types or [self.net.resolve(f) for f in relation.fields]
         number_type = self.net.resolved_number.get(node.id)
@@ -820,7 +891,7 @@ class Replacer(Rewriter):
         # if we need conversion tasks, wrap in a logical
         if len(tasks) == 1:
             return tasks[0]
-        return
+        return self.wrap(tuple(tasks))
 
     def visit_eq_lookup(self, node: mm.Lookup):
         (left, right) = node.args
@@ -842,7 +913,7 @@ class Replacer(Rewriter):
             return node
 
         tasks.append(mm.Lookup(b.core.eq, tuple(final_args)))
-        return
+        return self.wrap(tuple(tasks))
 
 #--------------------------------------------------
 # Helpers
@@ -929,6 +1000,9 @@ def to_type(value: mm.Value|mm.Field|mm.Literal) -> mm.Type:
     if isinstance(value, mm.Field):
         return b.core.Field
 
+    if isinstance(value, mm.Relation):
+        return b.core.Relation
+
     if isinstance(value, tuple):
         return mm.TupleType(element_types=tuple(to_type(v) for v in value))
 
@@ -1187,7 +1261,7 @@ class TypeMismatch(TyperError):
     actual: mm.Type
 
     def message(self) -> str:
-        return f"Expected {get_name(self.expected)}, got {get_name(self.actual)}"
+        return f"Expected '{get_name(self.expected)}', got '{get_name(self.actual)}'"
 
 @dataclass
 class InvalidType(TyperError):
@@ -1204,10 +1278,11 @@ class UnresolvedOverload(TyperError):
         assert isinstance(self.node, (mm.Lookup, mm.Update, mm.Aggregate))
         rel = get_relation(self.node)
         types = ', '.join([get_name(t) for t in self.arg_types])
-        return f"Unresolved overload: {rel.name}({types})"
+        return f"Unresolved overload: '{rel.name}({types})'"
 
 @dataclass
 class UnresolvedType(TyperError):
+    arg: mm.Var
 
     def message(self) -> str:
-        return "Unable to determine concrete type."
+        return f"Unable to determine concrete type for argument '{self.arg.name}'."
relationalai/semantics/reasoners/__init__.py (new file)

@@ -0,0 +1,11 @@
+"""
+The RelationalAI Semantics Reasoners Module.
+"""
+
+# Mark this package's docstrings for inclusion
+# in automatically generated web documentation,
+# by default as early access.
+# TODO: Remove dependency on v0, once this functionality is supported.
+from v0.relationalai.docutils import ProductStage # type: ignore[import-not-found]
+__include_in_docs__ = True
+__rai_product_stage__ = ProductStage.EARLY_ACCESS
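This new __init__ module only sets documentation markers (__include_in_docs__ and __rai_product_stage__); the documentation tooling that consumes them is not part of this diff. The sketch below is illustrative only, showing how such module-level flags can be read at runtime:

import importlib

# Illustrative only: inspects the module-level markers added above; the real
# documentation pipeline that consumes them is not shown in this diff.
mod = importlib.import_module("relationalai.semantics.reasoners")
if getattr(mod, "__include_in_docs__", False):
    stage = getattr(mod, "__rai_product_stage__", None)
    print(f"{mod.__name__} is documented; product stage: {stage}")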
relationalai/semantics/reasoners/graph/__init__.py (new file)

@@ -0,0 +1,35 @@
+"""
+RelationalAI Graph Library
+"""
+
+__version__ = "0.0.0"
+
+import warnings
+from .core import Graph
+
+# Mark this package's docstrings for inclusion
+# in automatically generated web documentation.
+__include_in_docs__ = True
+
+# Warn on import that this package is at an early stage of development,
+# intended for internal consumers only, and ask those internal consumers
+# to contact the symbolic reasoning team such that we can track usage,
+# get feedback, and help folks through breaking changes.
+warnings.warn(
+    (
+        "\n\nThis library is still in early stages of development and is intended "
+        "for internal use only. Among other considerations, interfaces will change, "
+        "and performance is appropriate only for exploring small graphs. Please "
+        "see this package's README for additional information.\n\n"
+        "If you are an internal user seeing this, please also contact "
+        "the symbolic reasoning team such that we can track usage, get "
+        "feedback, and help you through breaking changes.\n"
+    ),
+    FutureWarning,
+    stacklevel=2
+)
+
+# Finally make this package's core functionality publicly available.
+__all__ = [
+    "Graph",
+]
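Because the new graph package calls warnings.warn(..., FutureWarning) at module level, the early-access notice surfaces the moment the package is imported. A minimal sketch of observing that warning, assuming relationalai 1.0.0a5 is installed and this is the first import of the package in the process:

import warnings

# The FutureWarning comes from the module-level warnings.warn(...) shown above,
# so it only fires on the first import of the package in a given process.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from relationalai.semantics.reasoners.graph import Graph  # noqa: F401

assert any(issubclass(w.category, FutureWarning) for w in caught)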