relationalai 1.0.0a1__py3-none-any.whl → 1.0.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. relationalai/semantics/frontend/base.py +3 -0
  2. relationalai/semantics/frontend/front_compiler.py +5 -2
  3. relationalai/semantics/metamodel/builtins.py +2 -1
  4. relationalai/semantics/metamodel/metamodel.py +32 -4
  5. relationalai/semantics/metamodel/pprint.py +5 -3
  6. relationalai/semantics/metamodel/typer.py +324 -297
  7. relationalai/semantics/std/aggregates.py +0 -1
  8. relationalai/semantics/std/datetime.py +4 -1
  9. relationalai/shims/executor.py +22 -4
  10. relationalai/shims/mm2v0.py +108 -38
  11. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/METADATA +1 -1
  12. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/RECORD +27 -27
  13. v0/relationalai/errors.py +23 -0
  14. v0/relationalai/semantics/internal/internal.py +4 -4
  15. v0/relationalai/semantics/internal/snowflake.py +2 -1
  16. v0/relationalai/semantics/lqp/executor.py +16 -11
  17. v0/relationalai/semantics/lqp/model2lqp.py +42 -4
  18. v0/relationalai/semantics/lqp/passes.py +1 -1
  19. v0/relationalai/semantics/lqp/rewrite/cdc.py +1 -1
  20. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +53 -12
  21. v0/relationalai/semantics/metamodel/builtins.py +8 -6
  22. v0/relationalai/semantics/metamodel/rewrite/flatten.py +9 -4
  23. v0/relationalai/semantics/reasoners/graph/core.py +8 -9
  24. v0/relationalai/semantics/sql/compiler.py +2 -2
  25. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/WHEEL +0 -0
  26. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/entry_points.txt +0 -0
  27. {relationalai-1.0.0a1.dist-info → relationalai-1.0.0a2.dist-info}/top_level.txt +0 -0
v0/relationalai/semantics/lqp/model2lqp.py
@@ -396,6 +396,41 @@ def _translate_ascending_rank(ctx: TranslationCtx, limit: int, result_var: lqp.V
         terms=terms,
     )
 
+def _rename_shadowed_abstraction_vars(
+    ctx: TranslationCtx,
+    aggr: ir.Aggregate,
+    abstr_args: list[Tuple[lqp.Var, lqp.Type]],
+    body_conjs: list[lqp.Formula]
+) -> list[Tuple[lqp.Var, lqp.Type]]:
+    """
+    Rename abstraction variables that shadow group-by variables.
+
+    This can happen when the same variable appears both in aggr.group and as an
+    input to the aggregation, e.g., min(Person.age).per(Person.age). The group-by
+    variables are in the outer scope, while the abstraction parameters are in the
+    inner scope, so we need different names to avoid shadowing.
+    """
+    # Get the LQP names of the group-by variables
+    group_var_names = set()
+    for group_var in aggr.group:
+        lqp_var = _translate_var(ctx, group_var)
+        group_var_names.add(lqp_var.name)
+
+    # Rename any abstraction parameters that conflict with group-by variables
+    renamed_abstr_args = []
+    for var, typ in abstr_args:
+        if var.name in group_var_names:
+            # This variable shadows a group-by variable, so rename it
+            fresh_var = gen_unique_var(ctx, var.name)
+            # Add an equality constraint: fresh_var == var
+            # var is a free variable referring to the outer scope group-by variable
+            body_conjs.append(mk_primitive("rel_primitive_eq", [fresh_var, var]))
+            renamed_abstr_args.append((fresh_var, typ))
+        else:
+            renamed_abstr_args.append((var, typ))
+
+    return renamed_abstr_args
+
 def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Formula) -> Union[lqp.Reduce, lqp.Formula]:
     # TODO: handle this properly
     aggr_name = aggr.aggregation.name
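The shadowing problem this helper solves can be reproduced in miniature. The sketch below is illustrative only: fresh_name, rename_shadowed, and the tuple encoding of equality constraints are hypothetical stand-ins for gen_unique_var, _rename_shadowed_abstraction_vars, and mk_primitive("rel_primitive_eq", ...).

from itertools import count

_ids = count()

def fresh_name(base: str) -> str:
    # Stand-in for gen_unique_var: append a unique suffix.
    return f"{base}_{next(_ids)}"

def rename_shadowed(params: list[str], group_vars: set[str], body: list) -> list[str]:
    # For each abstraction parameter that collides with an outer group-by
    # variable, substitute a fresh name and equate it with the outer variable.
    renamed = []
    for p in params:
        if p in group_vars:
            fresh = fresh_name(p)
            body.append(("eq", fresh, p))  # fresh == outer group-by var
            renamed.append(fresh)
        else:
            renamed.append(p)
    return renamed

# min(Person.age).per(Person.age): "age" is both grouped and aggregated over.
body: list = []
print(rename_shadowed(["age"], {"age"}, body), body)
# ['age_0'] [('eq', 'age_0', 'age')]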
@@ -432,6 +467,9 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
     body_conjs.extend(projected_eqs)
     abstr_args: list[Tuple[lqp.Var, lqp.Type]] = projected_args + input_args
 
+    # Rename abstraction variables that shadow group-by variables
+    abstr_args = _rename_shadowed_abstraction_vars(ctx, aggr, abstr_args, body_conjs)
+
     if aggr_name == "count":
         assert len(output_terms) == 1, "Count and avg expect a single output variable"
         assert isinstance(meta_output_terms[0], ir.Var)
@@ -441,7 +479,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
         one_var, eq = constant_to_var(ctx, to_lqp_value(1, meta_output_terms[0].type), "one")
         body_conjs.append(eq)
         abstr_args.append((one_var, typ))
-        body = mk_and(body_conjs)
 
     # Average needs to wrap the reduce in Exists(Conjunction(Reduce, div))
     if aggr_name == "avg":
@@ -454,7 +491,6 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
         one_var, eq = constant_to_var(ctx, to_lqp_value(1, types.Int64), "one")
         body_conjs.append(eq)
         abstr_args.append((one_var, count_type))
-        body = mk_and(body_conjs)
 
         # The average will produce two output variables: sum and count.
         sum_result = gen_unique_var(ctx, "sum")
@@ -462,6 +498,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
 
         # Second to last is the variable we're summing over.
         (sum_var, sum_type) = abstr_args[-2]
+        body = mk_and(body_conjs)
 
         result = lqp.Reduce(
             op=lqp_avg_op(ctx, aggr.aggregation.name, sum_var.name, sum_type),
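For orientation, avg is not a primitive reduce here: each input row contributes its value to a sum and a constant one to a count, and the reduce is then wrapped with the division, as the comments above describe. A plain-Python paraphrase of that decomposition (illustrative only, not package code):

def avg_via_sum_and_count(values: list[float]) -> float:
    # Each row feeds the running sum and, via the appended "one" variable,
    # the running count; the Reduce yields the (sum, count) pair.
    total = 0.0
    n = 0
    for v in values:
        total += v
        n += 1
    # The division is the extra conjunct the Reduce gets wrapped with.
    return total / n

print(avg_via_sum_and_count([10.0, 20.0, 40.0]))  # 23.333...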
@@ -494,6 +531,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
     # `input_args` hold the types of the input arguments, but they may have been modified
     # if we're dealing with a count, so we use `abstr_args` to find the type.
     (aggr_arg, aggr_arg_type) = abstr_args[-1]
+    body = mk_and(body_conjs)
 
     # Group-bys do not need to be handled at all, since they are introduced outside already
     reduce = lqp.Reduce(
@@ -668,11 +706,11 @@ def to_lqp_value(value: ir.PyValue, value_type: ir.Type) -> lqp.Value:
         val = value
     elif typ.type_name == lqp.TypeName.STRING and isinstance(value, str):
         val = value
-    elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, PyDecimal):
+    elif typ.type_name == lqp.TypeName.DECIMAL and isinstance(value, (int, float, PyDecimal)):
         precision = typ.parameters[0].value
         scale = typ.parameters[1].value
         assert isinstance(precision, int) and isinstance(scale, int)
-        val = lqp.DecimalValue(precision=precision, scale=scale, value=value, meta=None)
+        val = lqp.DecimalValue(precision=precision, scale=scale, value=PyDecimal(value), meta=None)
     elif typ.type_name == lqp.TypeName.DATE and isinstance(value, date):
         val = lqp.DateValue(value=value, meta=None)
     elif typ.type_name == lqp.TypeName.DATETIME and isinstance(value, datetime):
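Assuming PyDecimal is an alias for the standard library's decimal.Decimal (suggested by the name, not shown in this diff), the widened isinstance check lets int and float literals reach the DECIMAL branch, and PyDecimal(value) normalizes them. One stdlib behavior worth knowing when reading this change: a float converts to its exact binary value, so it can carry representation error that a later quantize to the declared scale would round away. Standalone illustration of stdlib behavior only:

from decimal import Decimal as PyDecimal

print(PyDecimal(3))      # 3 -- ints convert exactly
print(PyDecimal(0.1))    # 0.1000000000000000055511151231257827021181583404541015625
print(PyDecimal(0.1).quantize(PyDecimal("0.01")))  # 0.10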
v0/relationalai/semantics/lqp/passes.py
@@ -390,7 +390,7 @@ class EliminateData(Pass):
                 [
                     f.logical(
                         [
-                            f.lookup(rel_builtins.eq, [f.literal(val), var])
+                            f.lookup(rel_builtins.eq, [f.literal(val, var.type), var])
                             for (val, var) in zip(row, node.vars)
                         ],
                     )
v0/relationalai/semantics/lqp/rewrite/cdc.py
@@ -200,7 +200,7 @@ class CDC(Pass):
         Get the relation that represents this property var in this wide_cdc_relation. If the
         relation is not yet available in the context, this method will create and register it.
         """
-        relation_name = wide_cdc_relation.name.lower().replace(".", "_")
+        relation_name = helpers.sanitize(wide_cdc_relation.name).replace("-", "_")
         key = (relation_name, property.name)
         if key not in ctx.cdc_relations:
             # the property relation is overloaded for all properties of the same wide cdc relation, so they have
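The old derivation only normalized dots, so other characters in a fully qualified name (spaces, quotes) survived into the relation name. What helpers.sanitize does exactly is not shown in this diff; assuming it lowercases and maps disallowed characters to hyphens, the composed transformation behaves roughly like the hypothetical sketch below:

import re

def sanitize(name: str) -> str:
    # Hypothetical stand-in for helpers.sanitize; the real helper may differ.
    return re.sub(r"[^a-z0-9_]", "-", name.lower())

old = "MY_DB.MY_SCHEMA.MY TABLE".lower().replace(".", "_")
new = sanitize("MY_DB.MY_SCHEMA.MY TABLE").replace("-", "_")
print(old)  # my_db_my_schema_my table  (the space survives)
print(new)  # my_db_my_schema_my_table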
v0/relationalai/semantics/lqp/rewrite/extract_keys.py
@@ -335,9 +335,45 @@ class ExtractKeysRewriter(Rewriter):
         partitions:dict[ir.Var, OrderedSet[ir.Task]] = defaultdict(OrderedSet)
         dependencies:dict[ir.Var, OrderedSet[ir.Var]] = defaultdict(OrderedSet)
 
-        worklist = list(vars)
-        while worklist:
-            var = worklist.pop()
+        def dfs_collect_deps(task, deps):
+            if isinstance(task, ir.Lookup):
+                args = helpers.vars(task.args)
+                for i, v in enumerate(args):
+                    # v depends on all previous vars
+                    for j in range(i):
+                        deps[v].add(args[j])
+                    # for ternary+ lookups, a var also depends on the next vars
+                    if i > 0 and len(args) >= 3:
+                        for j in range(i+1, len(args)):
+                            deps[v].add(args[j])
+            elif isinstance(task, ir.Construct):
+                vars = helpers.vars(task.values)
+                for val_var in vars:
+                    deps[task.id_var].add(val_var)
+            elif isinstance(task, ir.Logical):
+                for child in task.body:
+                    dfs_collect_deps(child, deps)
+            elif isinstance(task, (ir.Match, ir.Union)):
+                for child in task.tasks:
+                    dfs_collect_deps(child, deps)
+
+        for task in tasks:
+            dfs_collect_deps(task, dependencies)
+
+        def dfs_transitive_deps(var, visited):
+            for dep_var in dependencies[var]:
+                if dep_var not in visited:
+                    visited.add(dep_var)
+                    dfs_transitive_deps(dep_var, visited)
+
+        transitive_deps = defaultdict(OrderedSet)
+        for var in list(dependencies.keys()):
+            visited = OrderedSet()
+            dfs_transitive_deps(var, visited)
+            transitive_deps[var] = visited
+        dependencies = transitive_deps
+
+        for var in vars:
             extended_vars = OrderedSet[ir.Var]()
             extended_vars.add(var)
 
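The rewrite above replaces the old worklist with two depth-first passes: one collects direct variable dependencies from the task tree, and one expands them to their transitive closure. The closure step in isolation, over plain dicts and sets (illustrative only, not package code):

from collections import defaultdict

def transitive_closure(deps: dict[str, set[str]]) -> dict[str, set[str]]:
    # Mirrors dfs_transitive_deps: the visited set doubles as the result.
    def dfs(var: str, visited: set[str]) -> None:
        for dep in deps.get(var, ()):
            if dep not in visited:
                visited.add(dep)
                dfs(dep, visited)

    closed: dict[str, set[str]] = defaultdict(set)
    for var in deps:
        visited: set[str] = set()
        dfs(var, visited)
        closed[var] = visited
    return closed

# id_var depends on name, name on key, so id_var transitively needs both.
print(transitive_closure({"id_var": {"name"}, "name": {"key"}}))
# {'id_var': {'name', 'key'}, 'name': {'key'}}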
@@ -347,28 +383,33 @@ class ExtractKeysRewriter(Rewriter):
             for task in tasks:
                 if task in partitions[var]:
                     continue
-                    # Already added this task to this partition
+
+                if isinstance(task, (ir.Logical, ir.Match, ir.Union)):
+                    hoisted = helpers.hoisted_vars(task.hoisted)
+                    if var in hoisted:
+                        partitions[var].add(task)
+                        there_is_progress = True
+                elif isinstance(task, ir.Construct):
+                    if task.id_var == var:
+                        partitions[var].add(task)
+                        there_is_progress = True
                 elif isinstance(task, ir.Lookup):
                     args = helpers.vars(task.args)
                     if len(args) == 1 and args[0] in extended_vars:
                         partitions[var].add(task)
-                        # TODO: hack to have dot_joins work
+                        there_is_progress = True
+                    # NOTE: heuristics to have dot_joins work
                     elif len(args) >= 3 and args[-2] in extended_vars:
                         partitions[var].add(task)
                         extended_vars.add(args[-1])
-                        dependencies[var].add(args[-1])
                         there_is_progress = True
                     elif len(args) > 1 and args[-1] in extended_vars:
                         partitions[var].add(task)
                         for arg in args[:-1]:
                             extended_vars.add(arg)
-                            dependencies[var].add(arg)
-                        there_is_progress = True
-                elif isinstance(task, ir.Logical):
-                    hoisted = helpers.hoisted_vars(task.hoisted)
-                    if var in hoisted:
-                        partitions[var].add(task)
                         there_is_progress = True
+                else:
+                    assert False, f"invalid node kind {type(task)}"
 
         return partitions, dependencies
 
v0/relationalai/semantics/metamodel/builtins.py
@@ -443,12 +443,14 @@ datetime_second = f.relation("datetime_second", [f.input_field("a", types.DateTi
 datetime_weekday = f.relation("datetime_weekday", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
 
 # Other
-range = f.relation("range", [
-    f.input_field("start", types.Int64),
-    f.input_field("stop", types.Int64),
-    f.input_field("step", types.Int64),
-    f.field("result", types.Int64),
-])
+range = f.relation(
+    "range",
+    [f.input_field("start", types.Number), f.input_field("stop", types.Number), f.input_field("step", types.Number), f.field("result", types.Number)],
+    overloads=[
+        f.relation("range", [f.input_field("start", types.Int64), f.input_field("stop", types.Int64), f.input_field("step", types.Int64), f.field("result", types.Int64)]),
+        f.relation("range", [f.input_field("start", types.Int128), f.input_field("stop", types.Int128), f.input_field("step", types.Int128), f.field("result", types.Int128)]),
+    ],
+)
 
 hash = f.relation("hash", [f.input_field("args", types.AnyList), f.field("hash", types.Hash)])
 
v0/relationalai/semantics/metamodel/rewrite/flatten.py
@@ -124,10 +124,10 @@ class Flatten(Pass):
     output
     """
 
-    def __init__(self, handle_outputs: bool=True):
+    def __init__(self, use_sql: bool=False):
         super().__init__()
         self.name_cache = NameCache(start_from_one=True)
-        self._handle_outputs = handle_outputs
+        self._use_sql = use_sql
 
 
     #--------------------------------------------------
@@ -181,7 +181,12 @@ class Flatten(Pass):
     def handle(self, task: ir.Task, ctx: Context) -> Flatten.HandleResult:
         if isinstance(task, ir.Logical):
             return self.handle_logical(task, ctx)
-        elif isinstance(task, ir.Union):
+        elif isinstance(task, ir.Union) and (task.hoisted or self._use_sql):
+            # Only flatten Unions which hoist variables. If there are no hoisted variables,
+            # then the Union acts as a filter, and it can be inefficient to flatten it.
+            #
+            # However, for the SQL backend, we always need to flatten Unions for correct SQL
+            # generation.
             return self.handle_union(task, ctx)
         elif isinstance(task, ir.Match):
             return self.handle_match(task, ctx)
@@ -238,7 +243,7 @@ class Flatten(Pass):
         # If there are outputs, flatten each into its own top-level rule, along with its
         # dependencies.
         if groups["outputs"]:
-            if not self._handle_outputs:
+            if self._use_sql:
                 ctx.rewrite_ctx.top_level.append(ir.Logical(task.engine, task.hoisted, tuple(body), task.annotations))
                 return Flatten.HandleResult(None)
 
v0/relationalai/semantics/reasoners/graph/core.py
@@ -6222,9 +6222,9 @@ class Graph():
     def _distance_reversed_non_weighted(self):
         """Lazily define and cache the self._distance_reversed_non_weighted relationship, a non-public helper."""
         _distance_reversed_non_weighted_rel = self._model.Relationship(f"{{node_u:{self._NodeConceptStr}}} and {{node_v:{self._NodeConceptStr}}} have a reversed distance of {{d:Integer}}")
-        node_u, node_v, node_n, d1 = self.Node.ref(), self.Node.ref(), self.Node.ref(), Integer.ref()
+        node_u, node_v, node_n = self.Node.ref(), self.Node.ref(), self.Node.ref()
         node_u, node_v, d = union(
-            where(node_u == node_v, d1 == 0).select(node_u, node_v, d1), # Base case.
+            where(node_u == node_v, d1 := 0).select(node_u, node_v, d1), # Base case.
             where(self._edge(node_v, node_n),
                 d2 := _distance_reversed_non_weighted_rel(node_u, node_n, Integer) + 1).select(node_u, node_v, d2) # Recursive case.
         )
@@ -6326,13 +6326,12 @@ class Graph():
         _is_connected_rel.annotate(annotations.track("graphs", "is_connected"))
 
         where(
-            self._num_nodes(0) |
-            count(self._reachable_from_min_node(self.Node.ref())) == self._num_nodes(Integer.ref())
-        ).define(_is_connected_rel(True))
-
-        where(
-            not_(_is_connected_rel(True))
-        ).define(_is_connected_rel(False))
+            union(
+                self._num_nodes(0),
+                count(self._reachable_from_min_node(self.Node.ref())) == self._num_nodes(Integer.ref())
+            )
+        ).define(_is_connected_rel(True)) \
+            | define(_is_connected_rel(False))
 
         return _is_connected_rel
 
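The rewritten rule folds the two definitions into one: connected is true for the empty graph or when every node is reachable from the minimum node, with the trailing `| define(_is_connected_rel(False))` supplying the default. In plain Python the decision reads as (illustrative only, not package code):

def is_connected(num_nodes: int, reachable_from_min: int) -> bool:
    # Empty graph counts as connected; otherwise all nodes must be reachable.
    return num_nodes == 0 or reachable_from_min == num_nodes

print(is_connected(0, 0))  # True
print(is_connected(3, 3))  # True
print(is_connected(3, 2))  # False (the default branch)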
v0/relationalai/semantics/sql/compiler.py
@@ -33,7 +33,7 @@ class Compiler(c.Compiler):
            ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
            InferTypes(),
            DNFUnionSplitter(),
-           Flatten(handle_outputs=False),
+           Flatten(use_sql=True),
            rewrite.RecursiveUnion(),
            rewrite.DoubleNegation(),
            rewrite.SortOutputQuery()
@@ -1264,7 +1264,7 @@ class ModelToSQL:
             assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and fourth arguments (index, part) must be variables"
             builtin_vars[part] = part_expr
             builtin_vars[index] = index_expr
-        elif relation == builtins.range:
+        elif relation == builtins.range or relation in builtins.range.overloads:
             assert len(args) == 4, f"Expected 4 args for `range`, got {len(args)}: {args}"
             start_raw, stop_raw, step_raw, result = args
             start = self._var_to_expr(start_raw, reference, resolve_builtin_var, var_to_construct)