relationalai 1.0.0a1__py3-none-any.whl → 1.0.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/semantics/frontend/base.py +3 -0
- relationalai/semantics/frontend/front_compiler.py +5 -2
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/metamodel.py +32 -4
- relationalai/semantics/metamodel/pprint.py +5 -3
- relationalai/semantics/metamodel/typer.py +324 -297
- relationalai/semantics/std/aggregates.py +0 -1
- relationalai/semantics/std/datetime.py +4 -1
- relationalai/shims/executor.py +22 -4
- relationalai/shims/mm2v0.py +108 -38
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/METADATA +1 -1
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/RECORD +27 -27
- v0/relationalai/errors.py +23 -0
- v0/relationalai/semantics/internal/internal.py +4 -4
- v0/relationalai/semantics/internal/snowflake.py +2 -1
- v0/relationalai/semantics/lqp/executor.py +16 -11
- v0/relationalai/semantics/lqp/model2lqp.py +42 -4
- v0/relationalai/semantics/lqp/passes.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
- v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
- v0/relationalai/semantics/metamodel/builtins.py +8 -6
- v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
- v0/relationalai/semantics/reasoners/graph/core.py +8 -9
- v0/relationalai/semantics/sql/compiler.py +2 -2
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/WHEEL +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/entry_points.txt +0 -0
- {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/top_level.txt +0 -0
|
@@ -396,6 +396,41 @@ def _translate_ascending_rank(ctx: TranslationCtx, limit: int, result_var: lqp.V
|
|
|
396
396
|
terms=terms,
|
|
397
397
|
)
|
|
398
398
|
|
|
399
|
+
def _rename_shadowed_abstraction_vars(
|
|
400
|
+
ctx: TranslationCtx,
|
|
401
|
+
aggr: ir.Aggregate,
|
|
402
|
+
abstr_args: list[Tuple[lqp.Var, lqp.Type]],
|
|
403
|
+
body_conjs: list[lqp.Formula]
|
|
404
|
+
) -> list[Tuple[lqp.Var, lqp.Type]]:
|
|
405
|
+
"""
|
|
406
|
+
Rename abstraction variables that shadow group-by variables.
|
|
407
|
+
|
|
408
|
+
This can happen when the same variable appears in both aggr.group and as an input
|
|
409
|
+
to the aggregation, e.g., min(Person.age).per(Person.age). The group-by variables
|
|
410
|
+
are in the outer scope, while the abstraction parameters are in the inner scope,
|
|
411
|
+
so we need different names to avoid shadowing.
|
|
412
|
+
"""
|
|
413
|
+
# Get the LQP names of group-by variables
|
|
414
|
+
group_var_names = set()
|
|
415
|
+
for group_var in aggr.group:
|
|
416
|
+
lqp_var = _translate_var(ctx, group_var)
|
|
417
|
+
group_var_names.add(lqp_var.name)
|
|
418
|
+
|
|
419
|
+
# Rename any abstraction parameters that conflict with group-by variables
|
|
420
|
+
renamed_abstr_args = []
|
|
421
|
+
for var, typ in abstr_args:
|
|
422
|
+
if var.name in group_var_names:
|
|
423
|
+
# This variable shadows a group-by variable, so rename it
|
|
424
|
+
fresh_var = gen_unique_var(ctx, var.name)
|
|
425
|
+
# Add an equality constraint: fresh_var == var
|
|
426
|
+
# var is a free variable referring to the outer scope group-by variable
|
|
427
|
+
body_conjs.append(mk_primitive("rel_primitive_eq", [fresh_var, var]))
|
|
428
|
+
renamed_abstr_args.append((fresh_var, typ))
|
|
429
|
+
else:
|
|
430
|
+
renamed_abstr_args.append((var, typ))
|
|
431
|
+
|
|
432
|
+
return renamed_abstr_args
|
|
433
|
+
|
|
399
434
|
def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Formula) -> Union[lqp.Reduce, lqp.Formula]:
|
|
400
435
|
# TODO: handle this properly
|
|
401
436
|
aggr_name = aggr.aggregation.name
|
|
@@ -432,6 +467,9 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
|
|
|
432
467
|
body_conjs.extend(projected_eqs)
|
|
433
468
|
abstr_args: list[Tuple[lqp.Var, lqp.Type]] = projected_args + input_args
|
|
434
469
|
|
|
470
|
+
# Rename abstraction variables that shadow group-by variables
|
|
471
|
+
abstr_args = _rename_shadowed_abstraction_vars(ctx, aggr, abstr_args, body_conjs)
|
|
472
|
+
|
|
435
473
|
if aggr_name == "count":
|
|
436
474
|
assert len(output_terms) == 1, "Count and avg expect a single output variable"
|
|
437
475
|
assert isinstance(meta_output_terms[0], ir.Var)
|
|
@@ -441,7 +479,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
|
|
|
441
479
|
one_var, eq = constant_to_var(ctx, to_lqp_value(1, meta_output_terms[0].type), "one")
|
|
442
480
|
body_conjs.append(eq)
|
|
443
481
|
abstr_args.append((one_var, typ))
|
|
444
|
-
body = mk_and(body_conjs)
|
|
445
482
|
|
|
446
483
|
# Average needs to wrap the reduce in Exists(Conjunction(Reduce, div))
|
|
447
484
|
if aggr_name == "avg":
|
|
@@ -454,7 +491,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
|
|
|
454
491
|
one_var, eq = constant_to_var(ctx, to_lqp_value(1, types.Int64), "one")
|
|
455
492
|
body_conjs.append(eq)
|
|
456
493
|
abstr_args.append((one_var, count_type))
|
|
457
|
-
body = mk_and(body_conjs)
|
|
458
494
|
|
|
459
495
|
# The average will produce two output variables: sum and count.
|
|
460
496
|
sum_result = gen_unique_var(ctx, "sum")
|
|
@@ -462,6 +498,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
|
|
|
462
498
|
|
|
463
499
|
# Second to last is the variable we're summing over.
|
|
464
500
|
(sum_var, sum_type) = abstr_args[-2]
|
|
501
|
+
body = mk_and(body_conjs)
|
|
465
502
|
|
|
466
503
|
result = lqp.Reduce(
|
|
467
504
|
op=lqp_avg_op(ctx, aggr.aggregation.name, sum_var.name, sum_type),
|
|
@@ -494,6 +531,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
|
|
|
494
531
|
# `input_args` hold the types of the input arguments, but they may have been modified
|
|
495
532
|
# if we're dealing with a count, so we use `abstr_args` to find the type.
|
|
496
533
|
(aggr_arg, aggr_arg_type) = abstr_args[-1]
|
|
534
|
+
body = mk_and(body_conjs)
|
|
497
535
|
|
|
498
536
|
# Group-bys do not need to be handled at all, since they are introduced outside already
|
|
499
537
|
reduce = lqp.Reduce(
|
|
@@ -668,11 +706,11 @@ def to_lqp_value(value: ir.PyValue, value_type: ir.Type) -> lqp.Value:
|
|
|
668
706
|
val = value
|
|
669
707
|
elif typ.type_name == lqp.TypeName.STRING and isinstance(value, str):
|
|
670
708
|
val = value
|
|
671
|
-
elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, PyDecimal):
|
|
709
|
+
elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, (int, float, PyDecimal)):
|
|
672
710
|
precision = typ.parameters[0].value
|
|
673
711
|
scale = typ.parameters[1].value
|
|
674
712
|
assert isinstance(precision, int) and isinstance(scale, int)
|
|
675
|
-
val = lqp.DecimalValue(precision=precision, scale=scale, value=value, meta=None)
|
|
713
|
+
val = lqp.DecimalValue(precision=precision, scale=scale, value=PyDecimal(value), meta=None)
|
|
676
714
|
elif typ.type_name == lqp.TypeName.DATE and isinstance(value, date):
|
|
677
715
|
val = lqp.DateValue(value=value, meta=None)
|
|
678
716
|
elif typ.type_name == lqp.TypeName.DATETIME and isinstance(value, datetime):
|
|
@@ -200,7 +200,7 @@ class CDC(Pass):
|
|
|
200
200
|
Get the relation that represents this property var in this wide_cdc_relation. If the
|
|
201
201
|
relation is not yet available in the context, this method will create and register it.
|
|
202
202
|
"""
|
|
203
|
-
relation_name = wide_cdc_relation.name
|
|
203
|
+
relation_name = helpers.sanitize(wide_cdc_relation.name).replace("-", "_")
|
|
204
204
|
key = (relation_name, property.name)
|
|
205
205
|
if key not in ctx.cdc_relations:
|
|
206
206
|
# the property relation is overloaded for all properties of the same wide cdc relation, so they have
|
|
@@ -335,9 +335,45 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
335
335
|
partitions:dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
|
|
336
336
|
dependencies:dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
|
|
337
337
|
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
338
|
+
def dfs_collect_deps(task, deps):
|
|
339
|
+
if isinstance(task, ir.Lookup):
|
|
340
|
+
args = helpers.vars(task.args)
|
|
341
|
+
for i, v in enumerate(args):
|
|
342
|
+
# v depends on all previous vars
|
|
343
|
+
for j in range(i):
|
|
344
|
+
deps[v].add(args[j])
|
|
345
|
+
# for ternary+ lookups, a var also depends on the next vars
|
|
346
|
+
if i > 0 and len(args) >= 3:
|
|
347
|
+
for j in range(i+1, len(args)):
|
|
348
|
+
deps[v].add(args[j])
|
|
349
|
+
elif isinstance(task, ir.Construct):
|
|
350
|
+
vars = helpers.vars(task.values)
|
|
351
|
+
for val_var in vars:
|
|
352
|
+
deps[task.id_var].add(val_var)
|
|
353
|
+
elif isinstance(task, ir.Logical):
|
|
354
|
+
for child in task.body:
|
|
355
|
+
dfs_collect_deps(child, deps)
|
|
356
|
+
elif isinstance(task, (ir.Match, ir.Union)):
|
|
357
|
+
for child in task.tasks:
|
|
358
|
+
dfs_collect_deps(child, deps)
|
|
359
|
+
|
|
360
|
+
for task in tasks:
|
|
361
|
+
dfs_collect_deps(task, dependencies)
|
|
362
|
+
|
|
363
|
+
def dfs_transitive_deps(var, visited):
|
|
364
|
+
for dep_var in dependencies[var]:
|
|
365
|
+
if dep_var not in visited:
|
|
366
|
+
visited.add(dep_var)
|
|
367
|
+
dfs_transitive_deps(dep_var, visited)
|
|
368
|
+
|
|
369
|
+
transitive_deps = defaultdict(OrderedSet)
|
|
370
|
+
for var in list(dependencies.keys()):
|
|
371
|
+
visited = OrderedSet()
|
|
372
|
+
dfs_transitive_deps(var, visited)
|
|
373
|
+
transitive_deps[var] = visited
|
|
374
|
+
dependencies = transitive_deps
|
|
375
|
+
|
|
376
|
+
for var in vars:
|
|
341
377
|
extended_vars = OrderedSet[ir.Var]()
|
|
342
378
|
extended_vars.add(var)
|
|
343
379
|
|
|
@@ -347,28 +383,33 @@ class ExtractKeysRewriter(Rewriter):
|
|
|
347
383
|
for task in tasks:
|
|
348
384
|
if task in partitions[var]:
|
|
349
385
|
continue
|
|
350
|
-
|
|
386
|
+
|
|
387
|
+
if isinstance(task, (ir.Logical, ir.Match, ir.Union)):
|
|
388
|
+
hoisted = helpers.hoisted_vars(task.hoisted)
|
|
389
|
+
if var in hoisted:
|
|
390
|
+
partitions[var].add(task)
|
|
391
|
+
there_is_progress = True
|
|
392
|
+
elif isinstance(task, ir.Construct):
|
|
393
|
+
if task.id_var == var:
|
|
394
|
+
partitions[var].add(task)
|
|
395
|
+
there_is_progress = True
|
|
351
396
|
elif isinstance(task, ir.Lookup):
|
|
352
397
|
args = helpers.vars(task.args)
|
|
353
398
|
if len(args) == 1 and args[0] in extended_vars:
|
|
354
399
|
partitions[var].add(task)
|
|
355
|
-
|
|
400
|
+
there_is_progress = True
|
|
401
|
+
# NOTE: heuristics to have dot_joins work
|
|
356
402
|
elif len(args) >= 3 and args[-2] in extended_vars:
|
|
357
403
|
partitions[var].add(task)
|
|
358
404
|
extended_vars.add(args[-1])
|
|
359
|
-
dependencies[var].add(args[-1])
|
|
360
405
|
there_is_progress = True
|
|
361
406
|
elif len(args) > 1 and args[-1] in extended_vars:
|
|
362
407
|
partitions[var].add(task)
|
|
363
408
|
for arg in args[:-1]:
|
|
364
409
|
extended_vars.add(arg)
|
|
365
|
-
dependencies[var].add(arg)
|
|
366
|
-
there_is_progress = True
|
|
367
|
-
elif isinstance(task, ir.Logical):
|
|
368
|
-
hoisted = helpers.hoisted_vars(task.hoisted)
|
|
369
|
-
if var in hoisted:
|
|
370
|
-
partitions[var].add(task)
|
|
371
410
|
there_is_progress = True
|
|
411
|
+
else:
|
|
412
|
+
assert False, f"invalid node kind {type(task)}"
|
|
372
413
|
|
|
373
414
|
return partitions, dependencies
|
|
374
415
|
|
|
@@ -443,12 +443,14 @@ datetime_second = f.relation("datetime_second", [f.input_field("a", types.DateTi
|
|
|
443
443
|
datetime_weekday = f.relation("datetime_weekday", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
|
|
444
444
|
|
|
445
445
|
# Other
|
|
446
|
-
range = f.relation(
|
|
447
|
-
|
|
448
|
-
f.input_field("stop", types.
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
])
|
|
446
|
+
range = f.relation(
|
|
447
|
+
"range",
|
|
448
|
+
[f.input_field("start", types.Number), f.input_field("stop", types.Number), f.input_field("step", types.Number), f.field("result", types.Number)],
|
|
449
|
+
overloads=[
|
|
450
|
+
f.relation("range", [f.input_field("start", types.Int64), f.input_field("stop", types.Int64), f.input_field("step", types.Int64), f.field("result", types.Int64)]),
|
|
451
|
+
f.relation("range", [f.input_field("start", types.Int128), f.input_field("stop", types.Int128), f.input_field("step", types.Int128), f.field("result", types.Int128)]),
|
|
452
|
+
],
|
|
453
|
+
)
|
|
452
454
|
|
|
453
455
|
hash = f.relation("hash", [f.input_field("args", types.AnyList), f.field("hash", types.Hash)])
|
|
454
456
|
|
|
@@ -124,10 +124,10 @@ class Flatten(Pass):
|
|
|
124
124
|
output
|
|
125
125
|
"""
|
|
126
126
|
|
|
127
|
-
def __init__(self,
|
|
127
|
+
def __init__(self, use_sql: bool=False):
|
|
128
128
|
super().__init__()
|
|
129
129
|
self.name_cache = NameCache(start_from_one=True)
|
|
130
|
-
self.
|
|
130
|
+
self._use_sql = use_sql
|
|
131
131
|
|
|
132
132
|
|
|
133
133
|
#--------------------------------------------------
|
|
@@ -181,7 +181,12 @@ class Flatten(Pass):
|
|
|
181
181
|
def handle(self, task: ir.Task, ctx: Context) -> Flatten.HandleResult:
|
|
182
182
|
if isinstance(task, ir.Logical):
|
|
183
183
|
return self.handle_logical(task, ctx)
|
|
184
|
-
elif isinstance(task, ir.Union):
|
|
184
|
+
elif isinstance(task, ir.Union) and (task.hoisted or self._use_sql):
|
|
185
|
+
# Only flatten Unions which hoist variables. If there are no hoisted variables,
|
|
186
|
+
# then the Union acts as a filter, and it can be inefficient to flatten it.
|
|
187
|
+
#
|
|
188
|
+
# However, for the SQL backend, we always need to flatten Unions for correct SQL
|
|
189
|
+
# generation.
|
|
185
190
|
return self.handle_union(task, ctx)
|
|
186
191
|
elif isinstance(task, ir.Match):
|
|
187
192
|
return self.handle_match(task, ctx)
|
|
@@ -238,7 +243,7 @@ class Flatten(Pass):
|
|
|
238
243
|
# If there are outputs, flatten each into its own top-level rule, along with its
|
|
239
244
|
# dependencies.
|
|
240
245
|
if groups["outputs"]:
|
|
241
|
-
if
|
|
246
|
+
if self._use_sql:
|
|
242
247
|
ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations))
|
|
243
248
|
return Flatten.HandleResult(None)
|
|
244
249
|
|
|
@@ -6222,9 +6222,9 @@ class Graph():
|
|
|
6222
6222
|
def _distance_reversed_non_weighted(self):
|
|
6223
6223
|
"""Lazily define and cache the self._distance_reversed_non_weighted relationship, a non-public helper."""
|
|
6224
6224
|
_distance_reversed_non_weighted_rel = self._model.Relationship(f"{{node_u:{self._NodeConceptStr}}} and {{node_v:{self._NodeConceptStr}}} have a reversed distance of {{d:Integer}}")
|
|
6225
|
-
node_u, node_v, node_n
|
|
6225
|
+
node_u, node_v, node_n = self.Node.ref(), self.Node.ref(), self.Node.ref()
|
|
6226
6226
|
node_u, node_v, d = union(
|
|
6227
|
-
where(node_u == node_v, d1
|
|
6227
|
+
where(node_u == node_v, d1 := 0).select(node_u, node_v, d1), # Base case.
|
|
6228
6228
|
where(self._edge(node_v, node_n),
|
|
6229
6229
|
d2 := _distance_reversed_non_weighted_rel(node_u, node_n, Integer) + 1).select(node_u, node_v, d2) # Recursive case.
|
|
6230
6230
|
)
|
|
@@ -6326,13 +6326,12 @@ class Graph():
|
|
|
6326
6326
|
_is_connected_rel.annotate(annotations.track("graphs", "is_connected"))
|
|
6327
6327
|
|
|
6328
6328
|
where(
|
|
6329
|
-
|
|
6330
|
-
|
|
6331
|
-
|
|
6332
|
-
|
|
6333
|
-
|
|
6334
|
-
|
|
6335
|
-
).define(_is_connected_rel(False))
|
|
6329
|
+
union(
|
|
6330
|
+
self._num_nodes(0),
|
|
6331
|
+
count(self._reachable_from_min_node(self.Node.ref())) == self._num_nodes(Integer.ref())
|
|
6332
|
+
)
|
|
6333
|
+
).define(_is_connected_rel(True)) \
|
|
6334
|
+
| define(_is_connected_rel(False))
|
|
6336
6335
|
|
|
6337
6336
|
return _is_connected_rel
|
|
6338
6337
|
|
|
@@ -33,7 +33,7 @@ class Compiler(c.Compiler):
|
|
|
33
33
|
ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
|
|
34
34
|
InferTypes(),
|
|
35
35
|
DNFUnionSplitter(),
|
|
36
|
-
Flatten(
|
|
36
|
+
Flatten(use_sql=True),
|
|
37
37
|
rewrite.RecursiveUnion(),
|
|
38
38
|
rewrite.DoubleNegation(),
|
|
39
39
|
rewrite.SortOutputQuery()
|
|
@@ -1264,7 +1264,7 @@ class ModelToSQL:
|
|
|
1264
1264
|
assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and fourth arguments (index, part) must be variables"
|
|
1265
1265
|
builtin_vars[part] = part_expr
|
|
1266
1266
|
builtin_vars[index] = index_expr
|
|
1267
|
-
elif relation == builtins.range:
|
|
1267
|
+
elif relation == builtins.range or relation in builtins.range.overloads:
|
|
1268
1268
|
assert len(args) == 4, f"Expected 4 args for `range`, got {len(args)}: {args}"
|
|
1269
1269
|
start_raw, stop_raw, step_raw, result = args
|
|
1270
1270
|
start = self._var_to_expr(start_raw, reference, resolve_builtin_var, var_to_construct)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|