relationalai 1.0.0a1__py3-none-any.whl → 1.0.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/semantics/frontend/base.py +3 -0
- relationalai/semantics/frontend/front_compiler.py +5 -2
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/metamodel.py +32 -4
- relationalai/semantics/metamodel/pprint.py +5 -3
- relationalai/semantics/metamodel/typer.py +324 -297
- relationalai/semantics/std/aggregates.py +0 -1
- relationalai/semantics/std/datetime.py +4 -1
- relationalai/shims/executor.py +26 -5
- relationalai/shims/mm2v0.py +119 -44
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/METADATA +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/RECORD +57 -48
- v0/relationalai/__init__.py +69 -22
- v0/relationalai/clients/__init__.py +15 -2
- v0/relationalai/clients/client.py +4 -4
- v0/relationalai/clients/local.py +5 -5
- v0/relationalai/clients/resources/__init__.py +8 -0
- v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
- v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
- v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
- v0/relationalai/clients/resources/snowflake/direct_access_resources.py +711 -0
- v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
- v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
- v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
- v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +606 -1392
- v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +43 -12
- v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
- v0/relationalai/clients/resources/snowflake/util.py +387 -0
- v0/relationalai/early_access/dsl/ir/executor.py +4 -4
- v0/relationalai/early_access/dsl/snow/api.py +2 -1
- v0/relationalai/errors.py +23 -0
- v0/relationalai/experimental/solvers.py +7 -7
- v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
- v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
- v0/relationalai/semantics/internal/internal.py +4 -4
- v0/relationalai/semantics/internal/snowflake.py +3 -2
- v0/relationalai/semantics/lqp/executor.py +20 -22
- v0/relationalai/semantics/lqp/model2lqp.py +42 -4
- v0/relationalai/semantics/lqp/passes.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
- v0/relationalai/semantics/metamodel/builtins.py +8 -6
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
- v0/relationalai/semantics/metamodel/util.py +6 -5
- v0/relationalai/semantics/reasoners/graph/core.py +8 -9
- v0/relationalai/semantics/rel/executor.py +14 -11
- v0/relationalai/semantics/sql/compiler.py +2 -2
- v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
- v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
- v0/relationalai/tools/cli.py +26 -30
- v0/relationalai/tools/cli_helpers.py +10 -2
- v0/relationalai/util/otel_configuration.py +2 -1
- v0/relationalai/util/otel_handler.py +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a3.dist-info}/top_level.txt +0 -0
- /v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
|
@@ -10,7 +10,6 @@ AggValue = Value | Distinct
|
|
|
10
10
|
# Aggregates
|
|
11
11
|
#------------------------------------------------------
|
|
12
12
|
|
|
13
|
-
# TODO - overloads
|
|
14
13
|
_sum = library.Relation("sum", fields=[Field.input("value", Numeric), Field("result", Numeric)],
|
|
15
14
|
overloads=[[Number, Number], [Float, Float]])
|
|
16
15
|
_count = library.Relation("count", fields=[Field("result", Integer)])
|
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
|
3
3
|
from relationalai.semantics.std import floats
|
|
4
4
|
|
|
5
5
|
from . import StringValue, IntegerValue, DateValue, DateTimeValue, math, common
|
|
6
|
-
from ..frontend.base import Aggregate, Library, Concept, NumberConcept, Expression, Field, Literal, Variable
|
|
6
|
+
from ..frontend.base import Aggregate, Library, Concept, MetaRef, NumberConcept, Expression, Field, Literal, Variable
|
|
7
7
|
from ..frontend.core import Float, Number, String, Integer, Date, DateTime
|
|
8
|
+
from ..frontend import core
|
|
8
9
|
from .. import select
|
|
9
10
|
|
|
10
11
|
from typing import Union, Literal
|
|
@@ -162,6 +163,7 @@ class date:
|
|
|
162
163
|
num_days = cls.period_days(start, end)
|
|
163
164
|
if freq in ["W", "M", "Y"]:
|
|
164
165
|
range_end = math.ceil(num_days * _days[freq])
|
|
166
|
+
range_end = core.cast(MetaRef(Integer), range_end)
|
|
165
167
|
else:
|
|
166
168
|
range_end = num_days
|
|
167
169
|
# date_range is inclusive. add 1 since std.range is exclusive
|
|
@@ -348,6 +350,7 @@ class datetime:
|
|
|
348
350
|
_end = num_ms
|
|
349
351
|
else:
|
|
350
352
|
_end = math.ceil(num_ms * Float(_milliseconds[freq]))
|
|
353
|
+
_end = core.cast(MetaRef(Integer), _end)
|
|
351
354
|
# datetime_range is inclusive. add 1 since common.range is exclusive
|
|
352
355
|
ix = common.range(0, _end + 1, 1)
|
|
353
356
|
else:
|
relationalai/shims/executor.py
CHANGED
|
@@ -10,7 +10,10 @@ from v0.relationalai.semantics.rel.executor import RelExecutor
|
|
|
10
10
|
from v0.relationalai.semantics.metamodel import ir as v0, factory as v0_factory
|
|
11
11
|
from v0.relationalai.semantics.metamodel.visitor import collect_by_type
|
|
12
12
|
from v0.relationalai.semantics.snowflake import Table as v0Table
|
|
13
|
-
|
|
13
|
+
try:
|
|
14
|
+
from v0.relationalai.clients.snowflake import Provider as v0Provider #type: ignore
|
|
15
|
+
except ImportError:
|
|
16
|
+
from v0.relationalai.clients.resources.snowflake import Provider as v0Provider
|
|
14
17
|
from v0.relationalai.clients.config import Config
|
|
15
18
|
|
|
16
19
|
# from ..config import Config
|
|
@@ -29,10 +32,28 @@ TYPER_DEBUGGER=False
|
|
|
29
32
|
# PRINT_RESULT=True
|
|
30
33
|
# TYPER_DEBUGGER=True
|
|
31
34
|
|
|
35
|
+
@lru_cache()
|
|
36
|
+
def get_config():
|
|
37
|
+
return Config()
|
|
38
|
+
|
|
39
|
+
def with_source(item: mm.Node):
|
|
40
|
+
if not hasattr(item, "source"):
|
|
41
|
+
raise ValueError(f"Item {item} has no source")
|
|
42
|
+
elif item.source is None:
|
|
43
|
+
return {}
|
|
44
|
+
elif debugging.DEBUG:
|
|
45
|
+
source = item.source.block
|
|
46
|
+
if source:
|
|
47
|
+
return { "file": source.file, "line": source.line, "source": source.source }
|
|
48
|
+
else:
|
|
49
|
+
return {"file":item.source.file, "line":item.source.line}
|
|
50
|
+
else:
|
|
51
|
+
return {"file":item.source.file, "line":item.source.line}
|
|
52
|
+
|
|
32
53
|
def execute(query: Fragment, model: Model|None = None, executor=None, export_to="", update=False):
|
|
33
54
|
if not executor:
|
|
34
55
|
# use_lqp = Config().reasoner.rule.use_lqp
|
|
35
|
-
use_lqp = bool(
|
|
56
|
+
use_lqp = bool(get_config().get("reasoner.rule.use_lqp", True))
|
|
36
57
|
executor = "lqp" if use_lqp else "rel"
|
|
37
58
|
mm_model = model.to_metamodel() if model else None
|
|
38
59
|
mm_query = query.to_metamodel()
|
|
@@ -41,7 +62,7 @@ def execute(query: Fragment, model: Model|None = None, executor=None, export_to=
|
|
|
41
62
|
|
|
42
63
|
def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp", export_to="", update=False, model: Model|None = None):
|
|
43
64
|
# perform type inference
|
|
44
|
-
typer = Typer()
|
|
65
|
+
typer = Typer(enforce=False)
|
|
45
66
|
# normalize the metamodel
|
|
46
67
|
normalizer = Normalize()
|
|
47
68
|
# translate the metamodel into a v0 query
|
|
@@ -117,7 +138,7 @@ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp"
|
|
|
117
138
|
f.write(msg)
|
|
118
139
|
f.write('\n')
|
|
119
140
|
|
|
120
|
-
if DRY_RUN:
|
|
141
|
+
if DRY_RUN or get_config().get("compiler.dry_run", False):
|
|
121
142
|
results = []
|
|
122
143
|
else:
|
|
123
144
|
# create snowflake tables for all the tables that have been used
|
|
@@ -132,7 +153,7 @@ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp"
|
|
|
132
153
|
|
|
133
154
|
# get an executor and execute
|
|
134
155
|
executor = _get_executor(executor, model.name if model else "")
|
|
135
|
-
with debugging.span("query", tag=None, export_to=export_to) as query_span:
|
|
156
|
+
with debugging.span("query", tag=None, export_to=export_to, dsl="", **with_source(mm_query)) as query_span:
|
|
136
157
|
if isinstance(executor, (LQPExecutor, RelExecutor)):
|
|
137
158
|
results = executor.execute(v0_model, v0_query, export_to=export_table, update=update)
|
|
138
159
|
else:
|
relationalai/shims/mm2v0.py
CHANGED
|
@@ -21,6 +21,8 @@ from ..semantics.backends.lqp import annotations as lqp_annotations
|
|
|
21
21
|
from v0.relationalai.semantics.metamodel import ir as v0, builtins as v0_builtins, types as v0_types, factory as f
|
|
22
22
|
from v0.relationalai.semantics.metamodel.util import FrozenOrderedSet, frozen, ordered_set, filter_by_type, OrderedSet
|
|
23
23
|
from v0.relationalai.semantics.internal.internal import literal_value_to_type
|
|
24
|
+
from v0.relationalai.semantics.metamodel.typer import typer as v0_typer
|
|
25
|
+
from v0.relationalai.clients.util import IdentityParser
|
|
24
26
|
|
|
25
27
|
from .hoister import Hoister
|
|
26
28
|
from .helpers import is_output_update, is_main_output
|
|
@@ -131,6 +133,16 @@ class Translator():
|
|
|
131
133
|
def translate_frozen(self, nodes: seq[mm.Node], parent, ctx) -> FrozenOrderedSet[v0.Node]:
|
|
132
134
|
return frozen(*self.translate_seq(nodes, parent, ctx))
|
|
133
135
|
|
|
136
|
+
#------------------------------------------------------
|
|
137
|
+
# Helper
|
|
138
|
+
#------------------------------------------------------
|
|
139
|
+
|
|
140
|
+
# NOTE: This has to match what is done in the v0 snowflake.Table class as that is what CDC
|
|
141
|
+
# produces. If there's a quote in the table name, then we take it verbatim, otherwise we lowercase it.
|
|
142
|
+
def translate_table_name(self, table: mm.Table) -> str:
|
|
143
|
+
name = IdentityParser(table.name).identifier
|
|
144
|
+
name = name.lower() if '"' not in table.name else table.name.replace('"', '_')
|
|
145
|
+
return sanitize(name)
|
|
134
146
|
|
|
135
147
|
#-----------------------------------------------------------------------------
|
|
136
148
|
# Capabilities, Reasoners
|
|
@@ -279,12 +291,22 @@ class Translator():
|
|
|
279
291
|
else:
|
|
280
292
|
overloads.update(x) # type: ignore
|
|
281
293
|
overloads.update(self.translate_frozen(r.overloads, r, ctx)) # type: ignore
|
|
294
|
+
annotations = self.translate_seq(r.annotations, r, ctx) # type: ignore
|
|
295
|
+
name = r.name
|
|
296
|
+
fields = self.translate_seq(r.fields, r, ctx) # type: ignore
|
|
297
|
+
# We need to turn column relations into what the CDC pass would otherwise produce
|
|
298
|
+
# by making the relation name be the table name and adding a symbol field at the front
|
|
299
|
+
# representing the column name
|
|
300
|
+
if r.fields and isinstance(r.fields[0].type, mm.Table) and r in r.fields[0].type.columns:
|
|
301
|
+
name = self.translate_table_name(r.fields[0].type)
|
|
302
|
+
fields = (v0.Field("symbol", v0_types.Symbol, False), *fields) # type: ignore
|
|
303
|
+
annotations = annotations + (v0_builtins.external_annotation,)
|
|
282
304
|
|
|
283
305
|
return v0.Relation(
|
|
284
|
-
name=
|
|
285
|
-
fields=
|
|
306
|
+
name=name,
|
|
307
|
+
fields=frozen(*fields), # type: ignore
|
|
286
308
|
requires=self.translate_frozen(r.requires, r, ctx), # type: ignore
|
|
287
|
-
annotations=
|
|
309
|
+
annotations=frozen(*annotations), # type: ignore
|
|
288
310
|
overloads=overloads.frozen(), # type: ignore
|
|
289
311
|
)
|
|
290
312
|
|
|
@@ -321,6 +343,7 @@ class Translator():
|
|
|
321
343
|
b.core.Date: v0_types.Date,
|
|
322
344
|
b.core.DateTime: v0_types.DateTime,
|
|
323
345
|
b.core.Float: v0_types.Float,
|
|
346
|
+
b.core.Hash: v0_types.Hash,
|
|
324
347
|
}
|
|
325
348
|
|
|
326
349
|
def translate_scalartype(self, t: mm.ScalarType, parent: mm.Node, ctx) -> v0.ScalarType|v0.Relation|None:
|
|
@@ -334,18 +357,14 @@ class Translator():
|
|
|
334
357
|
assert isinstance(actual_type, v0.ScalarType)
|
|
335
358
|
fields = [v0.Field(name="entity", type=actual_type, input=False)] # type: ignore
|
|
336
359
|
annotations = [v0_builtins.concept_relation_annotation]
|
|
360
|
+
name = t.name
|
|
337
361
|
if isinstance(t, mm.Table):
|
|
338
362
|
annotations.append(v0_builtins.external_annotation)
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
fields.append(v0.Field(
|
|
342
|
-
name=col.name,
|
|
343
|
-
type=self.translate_node(col.fields[-1].type, col, Context.MODEL), # type: ignore
|
|
344
|
-
input=False
|
|
345
|
-
))
|
|
363
|
+
name = self.translate_table_name(t)
|
|
364
|
+
fields.insert(0, v0.Field(name="symbol", type=v0_types.Symbol, input=False)) # type: ignore
|
|
346
365
|
|
|
347
366
|
type_relation = v0.Relation(
|
|
348
|
-
name=
|
|
367
|
+
name=name,
|
|
349
368
|
fields=tuple(fields), # type: ignore
|
|
350
369
|
requires=frozen(),
|
|
351
370
|
annotations=frozen(*annotations), # type: ignore
|
|
@@ -437,6 +456,17 @@ class Translator():
|
|
|
437
456
|
def translate_table(self, t: mm.Table, parent, ctx):
|
|
438
457
|
return self.translate_scalartype(t, parent, ctx)
|
|
439
458
|
|
|
459
|
+
def rewrite_cdc_args(self, l: mm.Lookup, args, parent, ctx):
|
|
460
|
+
# If this is a lookup for an external table column,
|
|
461
|
+
# we have to prepend the column symbol to the args
|
|
462
|
+
root_type = l.relation.fields[0].type
|
|
463
|
+
if isinstance(root_type, mm.Table):
|
|
464
|
+
self.used_tables.add(root_type)
|
|
465
|
+
is_table = l.relation == root_type
|
|
466
|
+
if is_table or l.relation in root_type.columns:
|
|
467
|
+
sym = "METADATA$KEY" if is_table else l.relation.name
|
|
468
|
+
args = (v0.Literal(type=v0_types.Symbol, value=sym), *args)
|
|
469
|
+
return args
|
|
440
470
|
|
|
441
471
|
# -----------------------------------------------------------------------------
|
|
442
472
|
# Values
|
|
@@ -541,7 +571,7 @@ class Translator():
|
|
|
541
571
|
# inline logicals if possible
|
|
542
572
|
new_children = []
|
|
543
573
|
for c in children:
|
|
544
|
-
if isinstance(c, v0.Logical) and not c.hoisted and len(c.body) == 1:
|
|
574
|
+
if isinstance(c, v0.Logical) and not c.hoisted and len(c.body) == 1 and not isinstance(c.body[0], (v0.Aggregate, v0.Rank)):
|
|
545
575
|
new_children.extend(c.body)
|
|
546
576
|
else:
|
|
547
577
|
new_children.append(c)
|
|
@@ -590,10 +620,49 @@ class Translator():
|
|
|
590
620
|
return outputs
|
|
591
621
|
|
|
592
622
|
# if this is an optional logical but we're not hoisting anything and not updating,
|
|
593
|
-
# then it's
|
|
623
|
+
# then it's possible we're filtering an outer variable, but only for this column. If
|
|
624
|
+
# so, we need to alias the output and hoist it.
|
|
594
625
|
# this is important because the LQP stack blows up if there's a logical with no effect
|
|
595
626
|
if l.optional and not hoisted and not any(isinstance(c, v0.Update) for c in children):
|
|
596
|
-
return outputs
|
|
627
|
+
# if there are no lookups, then this really is a no-op, just return outputs
|
|
628
|
+
# LQP blows up with e.g. a match-only logical
|
|
629
|
+
if not any(isinstance(c, v0.Lookup) for c in children):
|
|
630
|
+
return outputs
|
|
631
|
+
# otherwise, make sure we filter the outer variable through aliasing
|
|
632
|
+
new_children = [*children]
|
|
633
|
+
new_hoists = []
|
|
634
|
+
new_outputs = []
|
|
635
|
+
# add an eq to a new var for the output, hoist the new var, change the output to use the new var
|
|
636
|
+
for output in outputs:
|
|
637
|
+
# shim outputs always only have one alias
|
|
638
|
+
(name, orig_var) = output.aliases.data[0]
|
|
639
|
+
# if the original var is not a Var or Literal, or it's already in the keys, skip (we can't filter keys)
|
|
640
|
+
if not isinstance(orig_var, (v0.Var, v0.Literal)) or (output.keys and orig_var in output.keys): #type: ignore
|
|
641
|
+
new_outputs.append(output)
|
|
642
|
+
continue
|
|
643
|
+
new_var = v0.Var(
|
|
644
|
+
type=orig_var.type,
|
|
645
|
+
name=f"{orig_var.name if isinstance(orig_var, v0.Var) else 'literal'}_hoisted"
|
|
646
|
+
)
|
|
647
|
+
new_output = v0.Output(
|
|
648
|
+
engine=None,
|
|
649
|
+
aliases=frozen((name, new_var)),
|
|
650
|
+
keys=output.keys,
|
|
651
|
+
annotations=output.annotations
|
|
652
|
+
)
|
|
653
|
+
eq = v0.Lookup(
|
|
654
|
+
engine=None,
|
|
655
|
+
relation=v0_builtins.eq,
|
|
656
|
+
args=(new_var, orig_var),
|
|
657
|
+
annotations=frozen()
|
|
658
|
+
)
|
|
659
|
+
new_children.append(eq)
|
|
660
|
+
new_hoists.append(v0.Default(new_var, None))
|
|
661
|
+
new_outputs.append(new_output)
|
|
662
|
+
outputs = new_outputs
|
|
663
|
+
hoisted = tuple(new_hoists)
|
|
664
|
+
children = tuple(new_children)
|
|
665
|
+
# return outputs
|
|
597
666
|
|
|
598
667
|
# return outputs + a logical with the other children
|
|
599
668
|
outputs.append(
|
|
@@ -782,44 +851,22 @@ class Translator():
|
|
|
782
851
|
args=(var, col_var), # type: ignore
|
|
783
852
|
annotations=self.translate_frozen(l.annotations, l, ctx) # type: ignore
|
|
784
853
|
)
|
|
785
|
-
|
|
854
|
+
# if this is a data column, we just ignore it as the data node already binds the variables
|
|
855
|
+
elif isinstance(l.args[0].type, mm.Data):
|
|
786
856
|
return None
|
|
857
|
+
# Otherwise we keep the lookup because that's what the LQP stack expect (the lookup gets repeated for
|
|
858
|
+
# each column)
|
|
859
|
+
|
|
787
860
|
|
|
788
861
|
relation, args = self._resolve_reading(l, ctx)
|
|
789
862
|
if relation is None:
|
|
790
863
|
return None
|
|
791
864
|
|
|
792
|
-
# External Table Column lookups
|
|
793
|
-
# we have to take the 6nf column relations and pull them into a single wide lookup
|
|
794
|
-
# making sure that the variable get mapped correctly. To match the expectations of
|
|
795
|
-
# v0, we also have to make sure that if we're looking up the table row itself, that
|
|
796
|
-
# it is wrapped in its own logical
|
|
797
|
-
root_type = l.relation.fields[0].type
|
|
798
|
-
if isinstance(root_type, mm.Table):
|
|
799
|
-
self.used_tables.add(root_type)
|
|
800
|
-
assert isinstance(l.args[0], mm.Var)
|
|
801
|
-
is_col = l.relation in root_type.columns
|
|
802
|
-
is_table = l.relation == root_type
|
|
803
|
-
if is_col:
|
|
804
|
-
self.column_map[l.args[0]][l.relation] = args[-1]
|
|
805
|
-
# we always lookup the full table, so replace the relation and args
|
|
806
|
-
relation = self.translate_node(root_type, l, ctx)
|
|
807
|
-
|
|
808
|
-
# this is a lookup on the table itself or the columns, translate to the column vars
|
|
809
|
-
if is_col or is_table:
|
|
810
|
-
mapped = self.column_map.get(l.args[0], {})
|
|
811
|
-
col_args = []
|
|
812
|
-
for col in root_type.columns:
|
|
813
|
-
v = mapped.setdefault(col, v0.Var(
|
|
814
|
-
type=self.translate_node(col.fields[-1].type, col, Context.MODEL), # type: ignore
|
|
815
|
-
name=f"{col.name}"
|
|
816
|
-
))
|
|
817
|
-
col_args.append(v)
|
|
818
|
-
args = tuple([args[0], *col_args]) # type: ignore
|
|
819
865
|
|
|
820
866
|
# Specific rewrites
|
|
821
867
|
rewrite = self.rewrite_lookup(l, parent, ctx)
|
|
822
868
|
if rewrite is None:
|
|
869
|
+
args = self.rewrite_cdc_args(l, args, parent, ctx)
|
|
823
870
|
# General translation
|
|
824
871
|
rewrite = v0.Lookup(
|
|
825
872
|
engine=self.translate_reasoner(l.reasoner, l, Context.MODEL),
|
|
@@ -1056,10 +1103,12 @@ class Translator():
|
|
|
1056
1103
|
else:
|
|
1057
1104
|
replaced_args.append(arg)
|
|
1058
1105
|
# translate the lookup
|
|
1106
|
+
args = tuple(self.translate_value(arg, l, ctx) for arg in replaced_args)
|
|
1107
|
+
args = self.rewrite_cdc_args(l, args, l, ctx)
|
|
1059
1108
|
lookup = v0.Lookup(
|
|
1060
1109
|
engine=self.translate_reasoner(l.reasoner, l, Context.MODEL),
|
|
1061
1110
|
relation=self.translate_node(l.relation), # type: ignore
|
|
1062
|
-
args=
|
|
1111
|
+
args=args,
|
|
1063
1112
|
annotations=self.translate_frozen(l.annotations, l, ctx) # type: ignore
|
|
1064
1113
|
)
|
|
1065
1114
|
# subtract 1 from the index to convert from 1-based to 0-based
|
|
@@ -1109,6 +1158,18 @@ class Translator():
|
|
|
1109
1158
|
def decrement(self, l: mm.Lookup, index: int, ctx):
|
|
1110
1159
|
""" Rewrite the lookup such that the arg at `index` is decremented by 1 before the
|
|
1111
1160
|
lookup. """
|
|
1161
|
+
x = l.args[index]
|
|
1162
|
+
if isinstance(x, mm.Literal):
|
|
1163
|
+
# if the arg is a literal, just decrement the literal directly
|
|
1164
|
+
new_literal = mm.Literal(type=x.type, value=x.value - 1) # type: ignore
|
|
1165
|
+
return v0.Lookup(
|
|
1166
|
+
engine=None,
|
|
1167
|
+
relation=self.translate_node(l.relation), # type: ignore
|
|
1168
|
+
args=tuple(self.translate_value(arg, l, ctx) if i != index else self.translate_value(new_literal, l, ctx) for i, arg in enumerate(l.args)),
|
|
1169
|
+
annotations=self.translate_frozen(l.annotations, l, ctx) # type: ignore
|
|
1170
|
+
)
|
|
1171
|
+
|
|
1172
|
+
# arg is not a literal, so we need to create a tmp var to store the decremented value
|
|
1112
1173
|
tmp = self.translate_value(mm.Var(type=b.core.Number, name="tmp"), l, ctx)
|
|
1113
1174
|
# lookup(..., tmp, ...)
|
|
1114
1175
|
new = v0.Lookup(
|
|
@@ -1137,7 +1198,21 @@ class Translator():
|
|
|
1137
1198
|
inputs = []
|
|
1138
1199
|
outputs = []
|
|
1139
1200
|
args = []
|
|
1140
|
-
|
|
1201
|
+
fields = l.relation.fields
|
|
1202
|
+
# if there are overloads, we need to cast based on the most compatible overload
|
|
1203
|
+
if l.relation.overloads:
|
|
1204
|
+
inf = float("inf")
|
|
1205
|
+
min_cost = inf
|
|
1206
|
+
for overload in l.relation.overloads:
|
|
1207
|
+
total = 0
|
|
1208
|
+
for arg, field in zip(l.args, overload.fields):
|
|
1209
|
+
if v0_typer.to_type(arg) != field.type:
|
|
1210
|
+
total += 1
|
|
1211
|
+
if total < min_cost:
|
|
1212
|
+
min_cost = total
|
|
1213
|
+
fields = overload.fields
|
|
1214
|
+
|
|
1215
|
+
for arg, field in zip(l.args, fields):
|
|
1141
1216
|
target_type = field.type
|
|
1142
1217
|
if target_type is None or not isinstance(arg, (v0.Var, v0.Literal)) or arg.type == target_type:
|
|
1143
1218
|
args.append(arg)
|