relationalai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. relationalai/clients/snowflake.py +37 -5
  2. relationalai/clients/use_index_poller.py +11 -1
  3. relationalai/semantics/internal/internal.py +29 -7
  4. relationalai/semantics/lqp/compiler.py +1 -1
  5. relationalai/semantics/lqp/constructors.py +6 -0
  6. relationalai/semantics/lqp/executor.py +23 -38
  7. relationalai/semantics/lqp/intrinsics.py +4 -3
  8. relationalai/semantics/lqp/model2lqp.py +6 -12
  9. relationalai/semantics/lqp/passes.py +4 -2
  10. relationalai/semantics/lqp/rewrite/__init__.py +2 -1
  11. relationalai/semantics/lqp/rewrite/function_annotations.py +91 -56
  12. relationalai/semantics/lqp/rewrite/functional_dependencies.py +282 -0
  13. relationalai/semantics/metamodel/builtins.py +6 -0
  14. relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
  15. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +1 -1
  16. relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +9 -9
  17. relationalai/semantics/metamodel/rewrite/flatten.py +18 -149
  18. relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
  19. relationalai/semantics/reasoners/graph/core.py +98 -70
  20. relationalai/semantics/reasoners/optimization/__init__.py +55 -10
  21. relationalai/semantics/reasoners/optimization/common.py +63 -8
  22. relationalai/semantics/reasoners/optimization/solvers_dev.py +39 -33
  23. relationalai/semantics/reasoners/optimization/solvers_pb.py +1033 -385
  24. relationalai/semantics/rel/compiler.py +21 -2
  25. relationalai/semantics/tests/test_snapshot_abstract.py +3 -0
  26. relationalai/tools/cli.py +10 -0
  27. relationalai/tools/cli_controls.py +15 -0
  28. relationalai/util/otel_handler.py +10 -4
  29. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/METADATA +1 -1
  30. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/RECORD +33 -31
  31. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/WHEEL +0 -0
  32. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/entry_points.txt +0 -0
  33. {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,7 @@ import time
6
6
 
7
7
  from relationalai.semantics.snowflake import Table
8
8
  from relationalai.semantics import std
9
- from relationalai.semantics.internal import internal as b # TODO(coey) change b name or remove b.?
9
+ from relationalai.semantics.internal import internal as b
10
10
  from relationalai.semantics.rel.executor import RelExecutor
11
11
  from relationalai.semantics.lqp.executor import LQPExecutor
12
12
  from relationalai.tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
@@ -57,18 +57,29 @@ class SolverModelDev:
57
57
  # self.objective_values = model.Relationship(f"point {{i:int}} has objective value {{val:{data_type}}}", short_name=_name("objective_values"))
58
58
  # self.primal_statuses = model.Relationship("point {i:int} has primal status {status:str}", short_name=_name("primal_statuses"))
59
59
 
60
+ self._model_info = {
61
+ "num_variables": self.variables,
62
+ "num_constraints": self.constraints,
63
+ "num_min_objectives": self.min_objectives,
64
+ "num_max_objectives": self.max_objectives,
65
+ }
66
+
60
67
  # TODO(coey) assert that it is a property? not just a relationship.
61
- def solve_for(self, expr: b.Relationship | b.Fragment, populate: bool = True, **kwargs):
62
- where = []
68
+ def solve_for(self, expr, where: list = [], populate: bool = True, **kwargs):
63
69
  if isinstance(expr, b.Fragment):
64
- assert expr._select and len(expr._select) == 1 and expr._where, "Fragment input for `solve_for` must have exactly one select and a where clause"
65
- rel = expr._select[0]
66
- where.extend(expr._where)
70
+ # TODO(coey) remove in future
71
+ raise ValueError("select fragment argument to `solve_for` is deprecated; instead use `where = [conditions...]` kwarg to specify optional grounding conditions")
72
+ elif isinstance(expr, b.Expression):
73
+ # must be of the form rel(a, ..., x) where the last element is the decision variable
74
+ rel = expr._op
75
+ assert isinstance(rel, b.Relationship)
76
+ params = expr._params
67
77
  elif isinstance(expr, b.Relationship):
68
78
  rel = expr
79
+ params = [b.field_to_type(self._model, f) for f in rel._fields]
69
80
  else:
70
- raise ValueError(f"Invalid expression type {type(expr)} for `solve_for`; must be a Relationship or Fragment")
71
- assert rel._parent and rel._short_name, "Relationship for `solve_for` must have a parent and a short name"
81
+ raise ValueError(f"Invalid expression type {type(expr)} for `solve_for`")
82
+ assert len(params) == len(rel._fields)
72
83
  assert rel not in self._variable_relationships
73
84
 
74
85
  self._variable_relationships.add(rel)
@@ -81,7 +92,7 @@ class SolverModelDev:
81
92
  new_kwargs["type"] = "cont" if self._data_type == "float" else "int"
82
93
  for (key, val) in new_kwargs.items():
83
94
  if key == "name":
84
- assert isinstance(val, (_Any, list)), f"Expected {key} to be a value or list, got {type(val)}"
95
+ assert isinstance(val, _Any) or isinstance(val, list), f"Expected {key} to be a value or list, got {type(val)}"
85
96
  defs.append(self.variable_name(node, make_name(val)))
86
97
  elif key == "type":
87
98
  assert val in ("cont", "int", "bin"), f"Unsupported variable type {val} for `solve_for`; must be cont, int, or bin"
@@ -96,9 +107,7 @@ class SolverModelDev:
96
107
  self.satisfy(b.require(b.Expression(b.Relationship.builtins[op], rel, val)).where(*where))
97
108
  else:
98
109
  raise ValueError(f"Invalid keyword argument {key} for `solve_for`")
99
-
100
- where.append(_make_hash((rel._short_name, rel._parent), node))
101
- b.define(*defs).where(*where)
110
+ b.define(*defs).where(*where, _make_hash((rel._short_name, rel._parent or 0), node))
102
111
 
103
112
  if populate:
104
113
  # get variable values from the result point (populated by the solver)
@@ -108,12 +117,10 @@ class SolverModelDev:
108
117
 
109
118
  return None
110
119
 
111
- def minimize(self, expr: _Number, name: _String | list | None = None):
112
- assert isinstance(expr, _Number)
120
+ def minimize(self, expr, name: _String | list | None = None):
113
121
  return self._handle_expr(self.min_objectives, expr, name)
114
122
 
115
- def maximize(self, expr: _Number, name: _String | list | None = None):
116
- assert isinstance(expr, _Number)
123
+ def maximize(self, expr, name: _String | list | None = None):
117
124
  return self._handle_expr(self.max_objectives, expr, name)
118
125
 
119
126
  def satisfy(self, expr: b.Fragment, check: bool = False, name: _String | list | None = None):
@@ -351,18 +358,10 @@ class SolverModelDev:
351
358
  # get scalar information
352
359
  def __getattr__(self, name: str):
353
360
  df = None
354
- # model info
355
- if name in {"num_variables", "num_constraints", "num_min_objectives", "num_max_objectives"}:
356
- map = {
357
- "num_variables": self.variables,
358
- "num_constraints": self.constraints,
359
- "num_min_objectives": self.min_objectives,
360
- "num_max_objectives": self.max_objectives,
361
- }
361
+ if name in self._model_info:
362
362
  node = b.Hash.ref()
363
- df = b.select(b.count(node).where(map[name](node)) | 0).to_df()
364
- # result info
365
- if name in {"error", "termination_status", "solver_version", "printed_model", "solve_time_sec", "objective_value", "result_count"}:
363
+ df = b.select(b.count(node).where(self._model_info[name](node)) | 0).to_df()
364
+ elif name in {"error", "termination_status", "solver_version", "printed_model", "solve_time_sec", "objective_value", "result_count"}:
366
365
  val = b.String.ref()
367
366
  df = b.select(val).where(self.result_info(name, val)).to_df()
368
367
  if df is not None:
@@ -404,7 +403,7 @@ def _rewrite(expr: b.Producer | b.Fragment, ctx: ExprContext):
404
403
  elif isinstance(expr, (b.Relationship, b.RelationshipRef, b.RelationshipFieldRef)):
405
404
  rel = expr if isinstance(expr, b.Relationship) else expr._relationship
406
405
  if rel in ctx.solver_model._variable_relationships:
407
- return std.hash(rel._short_name, expr._parent)
406
+ return std.hash(rel._short_name, expr._parent or 0)
408
407
  return None
409
408
 
410
409
  elif isinstance(expr, b.Union):
@@ -465,11 +464,18 @@ def _rewrite(expr: b.Producer | b.Fragment, ctx: ExprContext):
465
464
  subctx = ExprContext(sm)
466
465
  ctx.subcontext.append(subctx)
467
466
  subctx.where.extend(expr._where._where)
468
- arg_hash = b.Hash.ref()
469
- subctx.where.append(_make_hash((sm._expr_id, *pre_args), arg_hash)) # TODO also add sym_arg here?
470
467
  sm._expr_id += 1
471
468
  subctx.define.append(sm.operator(node, op))
472
- subctx.define.append(sm.unordered_args_hash(node, arg_hash, sym_arg)) # TODO what if some values are data not hashes?
469
+
470
+ # special_ordered_set_type_2 has two ordered arguments: rank and variables
471
+ if op == "special_ordered_set_type_2":
472
+ assert len(pre_args) == 1, "special_ordered_set_type_2 expects exactly 2 arguments (rank, variables)"
473
+ subctx.define.append(sm.ordered_args_hash(node, pre_args[0], sym_arg))
474
+ else:
475
+ # other aggregate operators use unordered args
476
+ arg_hash = b.Hash.ref()
477
+ subctx.where.append(_make_hash((sm._expr_id, *pre_args), arg_hash))
478
+ subctx.define.append(sm.unordered_args_hash(node, arg_hash, sym_arg))
473
479
  return node
474
480
 
475
481
  elif isinstance(expr, b.Fragment):
@@ -502,7 +508,7 @@ def _expr_strings_rec(x, names_dict, ops_dict, args_dict):
502
508
  s = f"({s})"
503
509
  arg_strs.append(s)
504
510
 
505
- if op in agg_ops:
511
+ if op in agg_ops and not op == "special_ordered_set_type_2":
506
512
  # sort unordered args to improve determinism
507
513
  arg_strs.sort()
508
514
 
@@ -522,7 +528,7 @@ infix_ops = set(["+", "-", "*", "/", "^"])
522
528
  infix_comps = set(["=", "!=", "<", "<=", ">", ">=", "implies"])
523
529
  infixs = infix_ops.union(infix_comps)
524
530
  prefix_ops = set(["abs", "log", "exp"])
525
- agg_ops = set(["sum", "count", "min", "max", "all_different"])
531
+ agg_ops = set(["sum", "count", "min", "max", "all_different", "special_ordered_set_type_2"])
526
532
 
527
533
  # _variable_types = {
528
534
  # "cont": 40,