relationalai 0.12.7__py3-none-any.whl → 0.12.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/snowflake.py +37 -5
- relationalai/clients/use_index_poller.py +11 -1
- relationalai/semantics/internal/internal.py +29 -7
- relationalai/semantics/lqp/compiler.py +1 -1
- relationalai/semantics/lqp/constructors.py +6 -0
- relationalai/semantics/lqp/executor.py +23 -38
- relationalai/semantics/lqp/intrinsics.py +4 -3
- relationalai/semantics/lqp/model2lqp.py +6 -12
- relationalai/semantics/lqp/passes.py +4 -2
- relationalai/semantics/lqp/rewrite/__init__.py +2 -1
- relationalai/semantics/lqp/rewrite/function_annotations.py +91 -56
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +282 -0
- relationalai/semantics/metamodel/builtins.py +6 -0
- relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
- relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +1 -1
- relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +9 -9
- relationalai/semantics/metamodel/rewrite/flatten.py +18 -149
- relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
- relationalai/semantics/reasoners/graph/core.py +98 -70
- relationalai/semantics/reasoners/optimization/__init__.py +55 -10
- relationalai/semantics/reasoners/optimization/common.py +63 -8
- relationalai/semantics/reasoners/optimization/solvers_dev.py +39 -33
- relationalai/semantics/reasoners/optimization/solvers_pb.py +1033 -385
- relationalai/semantics/rel/compiler.py +21 -2
- relationalai/semantics/tests/test_snapshot_abstract.py +3 -0
- relationalai/tools/cli.py +10 -0
- relationalai/tools/cli_controls.py +15 -0
- relationalai/util/otel_handler.py +10 -4
- {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/METADATA +1 -1
- {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/RECORD +33 -31
- {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/WHEEL +0 -0
- {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.7.dist-info → relationalai-0.12.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,7 @@ import time
|
|
|
6
6
|
|
|
7
7
|
from relationalai.semantics.snowflake import Table
|
|
8
8
|
from relationalai.semantics import std
|
|
9
|
-
from relationalai.semantics.internal import internal as b
|
|
9
|
+
from relationalai.semantics.internal import internal as b
|
|
10
10
|
from relationalai.semantics.rel.executor import RelExecutor
|
|
11
11
|
from relationalai.semantics.lqp.executor import LQPExecutor
|
|
12
12
|
from relationalai.tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
|
|
@@ -57,18 +57,29 @@ class SolverModelDev:
|
|
|
57
57
|
# self.objective_values = model.Relationship(f"point {{i:int}} has objective value {{val:{data_type}}}", short_name=_name("objective_values"))
|
|
58
58
|
# self.primal_statuses = model.Relationship("point {i:int} has primal status {status:str}", short_name=_name("primal_statuses"))
|
|
59
59
|
|
|
60
|
+
self._model_info = {
|
|
61
|
+
"num_variables": self.variables,
|
|
62
|
+
"num_constraints": self.constraints,
|
|
63
|
+
"num_min_objectives": self.min_objectives,
|
|
64
|
+
"num_max_objectives": self.max_objectives,
|
|
65
|
+
}
|
|
66
|
+
|
|
60
67
|
# TODO(coey) assert that it is a property? not just a relationship.
|
|
61
|
-
def solve_for(self, expr:
|
|
62
|
-
where = []
|
|
68
|
+
def solve_for(self, expr, where: list = [], populate: bool = True, **kwargs):
|
|
63
69
|
if isinstance(expr, b.Fragment):
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
70
|
+
# TODO(coey) remove in future
|
|
71
|
+
raise ValueError("select fragment argument to `solve_for` is deprecated; instead use `where = [conditions...]` kwarg to specify optional grounding conditions")
|
|
72
|
+
elif isinstance(expr, b.Expression):
|
|
73
|
+
# must be of the form rel(a, ..., x) where the last element is the decision variable
|
|
74
|
+
rel = expr._op
|
|
75
|
+
assert isinstance(rel, b.Relationship)
|
|
76
|
+
params = expr._params
|
|
67
77
|
elif isinstance(expr, b.Relationship):
|
|
68
78
|
rel = expr
|
|
79
|
+
params = [b.field_to_type(self._model, f) for f in rel._fields]
|
|
69
80
|
else:
|
|
70
|
-
raise ValueError(f"Invalid expression type {type(expr)} for `solve_for
|
|
71
|
-
assert
|
|
81
|
+
raise ValueError(f"Invalid expression type {type(expr)} for `solve_for`")
|
|
82
|
+
assert len(params) == len(rel._fields)
|
|
72
83
|
assert rel not in self._variable_relationships
|
|
73
84
|
|
|
74
85
|
self._variable_relationships.add(rel)
|
|
@@ -81,7 +92,7 @@ class SolverModelDev:
|
|
|
81
92
|
new_kwargs["type"] = "cont" if self._data_type == "float" else "int"
|
|
82
93
|
for (key, val) in new_kwargs.items():
|
|
83
94
|
if key == "name":
|
|
84
|
-
assert isinstance(val, (
|
|
95
|
+
assert isinstance(val, _Any) or isinstance(val, list), f"Expected {key} to be a value or list, got {type(val)}"
|
|
85
96
|
defs.append(self.variable_name(node, make_name(val)))
|
|
86
97
|
elif key == "type":
|
|
87
98
|
assert val in ("cont", "int", "bin"), f"Unsupported variable type {val} for `solve_for`; must be cont, int, or bin"
|
|
@@ -96,9 +107,7 @@ class SolverModelDev:
|
|
|
96
107
|
self.satisfy(b.require(b.Expression(b.Relationship.builtins[op], rel, val)).where(*where))
|
|
97
108
|
else:
|
|
98
109
|
raise ValueError(f"Invalid keyword argument {key} for `solve_for`")
|
|
99
|
-
|
|
100
|
-
where.append(_make_hash((rel._short_name, rel._parent), node))
|
|
101
|
-
b.define(*defs).where(*where)
|
|
110
|
+
b.define(*defs).where(*where, _make_hash((rel._short_name, rel._parent or 0), node))
|
|
102
111
|
|
|
103
112
|
if populate:
|
|
104
113
|
# get variable values from the result point (populated by the solver)
|
|
@@ -108,12 +117,10 @@ class SolverModelDev:
|
|
|
108
117
|
|
|
109
118
|
return None
|
|
110
119
|
|
|
111
|
-
def minimize(self, expr
|
|
112
|
-
assert isinstance(expr, _Number)
|
|
120
|
+
def minimize(self, expr, name: _String | list | None = None):
|
|
113
121
|
return self._handle_expr(self.min_objectives, expr, name)
|
|
114
122
|
|
|
115
|
-
def maximize(self, expr
|
|
116
|
-
assert isinstance(expr, _Number)
|
|
123
|
+
def maximize(self, expr, name: _String | list | None = None):
|
|
117
124
|
return self._handle_expr(self.max_objectives, expr, name)
|
|
118
125
|
|
|
119
126
|
def satisfy(self, expr: b.Fragment, check: bool = False, name: _String | list | None = None):
|
|
@@ -351,18 +358,10 @@ class SolverModelDev:
|
|
|
351
358
|
# get scalar information
|
|
352
359
|
def __getattr__(self, name: str):
|
|
353
360
|
df = None
|
|
354
|
-
|
|
355
|
-
if name in {"num_variables", "num_constraints", "num_min_objectives", "num_max_objectives"}:
|
|
356
|
-
map = {
|
|
357
|
-
"num_variables": self.variables,
|
|
358
|
-
"num_constraints": self.constraints,
|
|
359
|
-
"num_min_objectives": self.min_objectives,
|
|
360
|
-
"num_max_objectives": self.max_objectives,
|
|
361
|
-
}
|
|
361
|
+
if name in self._model_info:
|
|
362
362
|
node = b.Hash.ref()
|
|
363
|
-
df = b.select(b.count(node).where(
|
|
364
|
-
|
|
365
|
-
if name in {"error", "termination_status", "solver_version", "printed_model", "solve_time_sec", "objective_value", "result_count"}:
|
|
363
|
+
df = b.select(b.count(node).where(self._model_info[name](node)) | 0).to_df()
|
|
364
|
+
elif name in {"error", "termination_status", "solver_version", "printed_model", "solve_time_sec", "objective_value", "result_count"}:
|
|
366
365
|
val = b.String.ref()
|
|
367
366
|
df = b.select(val).where(self.result_info(name, val)).to_df()
|
|
368
367
|
if df is not None:
|
|
@@ -404,7 +403,7 @@ def _rewrite(expr: b.Producer | b.Fragment, ctx: ExprContext):
|
|
|
404
403
|
elif isinstance(expr, (b.Relationship, b.RelationshipRef, b.RelationshipFieldRef)):
|
|
405
404
|
rel = expr if isinstance(expr, b.Relationship) else expr._relationship
|
|
406
405
|
if rel in ctx.solver_model._variable_relationships:
|
|
407
|
-
return std.hash(rel._short_name, expr._parent)
|
|
406
|
+
return std.hash(rel._short_name, expr._parent or 0)
|
|
408
407
|
return None
|
|
409
408
|
|
|
410
409
|
elif isinstance(expr, b.Union):
|
|
@@ -465,11 +464,18 @@ def _rewrite(expr: b.Producer | b.Fragment, ctx: ExprContext):
|
|
|
465
464
|
subctx = ExprContext(sm)
|
|
466
465
|
ctx.subcontext.append(subctx)
|
|
467
466
|
subctx.where.extend(expr._where._where)
|
|
468
|
-
arg_hash = b.Hash.ref()
|
|
469
|
-
subctx.where.append(_make_hash((sm._expr_id, *pre_args), arg_hash)) # TODO also add sym_arg here?
|
|
470
467
|
sm._expr_id += 1
|
|
471
468
|
subctx.define.append(sm.operator(node, op))
|
|
472
|
-
|
|
469
|
+
|
|
470
|
+
# special_ordered_set_type_2 has two ordered arguments: rank and variables
|
|
471
|
+
if op == "special_ordered_set_type_2":
|
|
472
|
+
assert len(pre_args) == 1, "special_ordered_set_type_2 expects exactly 2 arguments (rank, variables)"
|
|
473
|
+
subctx.define.append(sm.ordered_args_hash(node, pre_args[0], sym_arg))
|
|
474
|
+
else:
|
|
475
|
+
# other aggregate operators use unordered args
|
|
476
|
+
arg_hash = b.Hash.ref()
|
|
477
|
+
subctx.where.append(_make_hash((sm._expr_id, *pre_args), arg_hash))
|
|
478
|
+
subctx.define.append(sm.unordered_args_hash(node, arg_hash, sym_arg))
|
|
473
479
|
return node
|
|
474
480
|
|
|
475
481
|
elif isinstance(expr, b.Fragment):
|
|
@@ -502,7 +508,7 @@ def _expr_strings_rec(x, names_dict, ops_dict, args_dict):
|
|
|
502
508
|
s = f"({s})"
|
|
503
509
|
arg_strs.append(s)
|
|
504
510
|
|
|
505
|
-
if op in agg_ops:
|
|
511
|
+
if op in agg_ops and not op == "special_ordered_set_type_2":
|
|
506
512
|
# sort unordered args to improve determinism
|
|
507
513
|
arg_strs.sort()
|
|
508
514
|
|
|
@@ -522,7 +528,7 @@ infix_ops = set(["+", "-", "*", "/", "^"])
|
|
|
522
528
|
infix_comps = set(["=", "!=", "<", "<=", ">", ">=", "implies"])
|
|
523
529
|
infixs = infix_ops.union(infix_comps)
|
|
524
530
|
prefix_ops = set(["abs", "log", "exp"])
|
|
525
|
-
agg_ops = set(["sum", "count", "min", "max", "all_different"])
|
|
531
|
+
agg_ops = set(["sum", "count", "min", "max", "all_different", "special_ordered_set_type_2"])
|
|
526
532
|
|
|
527
533
|
# _variable_types = {
|
|
528
534
|
# "cont": 40,
|