pepflow 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +6 -1
- pepflow/constraint.py +58 -1
- pepflow/constraint_test.py +71 -0
- pepflow/e2e_test.py +83 -4
- pepflow/expression_manager.py +329 -44
- pepflow/expression_manager_test.py +150 -0
- pepflow/function.py +294 -52
- pepflow/function_test.py +180 -114
- pepflow/interactive_constraint.py +165 -75
- pepflow/parameter.py +187 -0
- pepflow/parameter_test.py +128 -0
- pepflow/pep.py +263 -16
- pepflow/pep_context.py +122 -6
- pepflow/pep_context_test.py +25 -0
- pepflow/pep_test.py +8 -0
- pepflow/point.py +155 -49
- pepflow/point_test.py +40 -188
- pepflow/scalar.py +260 -47
- pepflow/scalar_test.py +102 -130
- pepflow/solver.py +170 -3
- pepflow/solver_test.py +50 -2
- pepflow/utils.py +39 -7
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/METADATA +24 -5
- pepflow-0.1.5.dist-info/RECORD +28 -0
- pepflow-0.1.4.dist-info/RECORD +0 -24
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/top_level.txt +0 -0
pepflow/parameter.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
from __future__ import annotations
|
21
|
+
|
22
|
+
import attrs
|
23
|
+
|
24
|
+
from pepflow import utils
|
25
|
+
|
26
|
+
# Sentinel used as the default of `resolve_parameters.get(...)` in
# `Parameter.get_value`; it marks that a parameter name has no entry in the
# resolution mapping (checked with `is`).
# NOTE(review): a mapping value that is this very string would be misreported
# as missing; an `object()` sentinel would avoid that. Kept as a string in case
# other modules rely on this value.
NOT_FOUND = "__NOT_FOUND__"
|
28
|
+
|
29
|
+
|
30
|
+
@attrs.frozen
class ParameterRepresentation:
    """Immutable record of a single binary operation ``left_param <op> right_param``.

    It is stored as the ``eval_expression`` of a composite :class:`Parameter`.
    Each operand is either a plain numerical value or another :class:`Parameter`,
    so nested operations form an expression tree.
    """

    # The arithmetic operation (one of utils.Op.ADD / SUB / MUL / DIV is handled
    # by Parameter).
    op: utils.Op
    # Left-hand operand.
    left_param: utils.NUMERICAL_TYPE | Parameter
    # Right-hand operand.
    right_param: utils.NUMERICAL_TYPE | Parameter
|
35
|
+
|
36
|
+
|
37
|
+
def eval_parameter(
    param: Parameter | utils.NUMERICAL_TYPE,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE],
) -> utils.NUMERICAL_TYPE:
    """Return the numerical value of ``param``.

    Args:
        param: A :class:`Parameter` (named or composite) or a plain numerical value.
        resolve_parameters: Mapping from parameter name to numerical value, used to
            resolve named parameters.

    Returns:
        ``param.get_value(resolve_parameters)`` when ``param`` is a
        :class:`Parameter`; ``param`` itself when it is already numerical.

    Raises:
        ValueError: If ``param`` is neither a :class:`Parameter` nor numerical.
            :meth:`Parameter.get_value` may also raise ValueError for a name
            missing from ``resolve_parameters``.
    """
    if isinstance(param, Parameter):
        resolved = param.get_value(resolve_parameters)
    elif utils.is_numerical(param):
        resolved = param
    else:
        raise ValueError(f"Encounter the unknown parameter type: {param} ({type(param)})")
    return resolved
|
46
|
+
|
47
|
+
|
48
|
+
@attrs.frozen
class Parameter:
    """A symbolic parameter whose numerical value is supplied at evaluation time.

    A parameter is either:

    * *named* -- ``name`` is set and ``eval_expression`` is ``None``. Its value is
      looked up by name in the ``resolve_parameters`` mapping passed to
      :meth:`get_value`.
    * *composite* -- ``name`` is ``None`` and ``eval_expression`` records one
      binary operation (``+``, ``-``, ``*``, ``/``) on parameters and/or
      numerical values. Composite parameters are produced by the arithmetic
      operators defined below.

    Exactly one of ``name`` and ``eval_expression`` must be provided.

    Example:
        >>> pp = (Parameter("a") + 2) * Parameter("b")
        >>> str(pp)
        '((a+2)*b)'
        >>> pp.get_value({"a": 3, "b": 6})
        30
    """

    # If name is None, it is a composite parameter.
    name: str | None

    eval_expression: ParameterRepresentation | None = None

    def __attrs_post_init__(self):
        # Enforce "exactly one of name / eval_expression".
        if self.name is None and self.eval_expression is None:
            raise ValueError(
                "For a parameter, must specify a name or an eval_expression"
            )
        if self.name is None or self.eval_expression is None:
            return

        raise ValueError(
            "For a parameter, only one of name or eval_expression should be None."
        )

    def __repr__(self):
        """Render the parameter as its name, or as a fully parenthesised expression.

        Raises:
            ValueError: If a composite parameter holds an unsupported op.
        """
        if self.eval_expression is None:
            return self.name

        op = self.eval_expression.op
        left_param = self.eval_expression.left_param
        right_param = self.eval_expression.right_param
        # TODO having a better parentheses handling.
        if op == utils.Op.ADD:
            return f"({left_param}+{right_param})"
        if op == utils.Op.SUB:
            return f"({left_param}-{right_param})"
        if op == utils.Op.MUL:
            return f"({left_param}*{right_param})"
        if op == utils.Op.DIV:
            return f"({left_param}/{right_param})"

        # Falling off the end would return None, which makes repr()/str() fail
        # with an unhelpful "__repr__ returned non-string" TypeError.
        raise ValueError(f"Encountered unknown {op=} when representing the parameter.")

    def get_value(
        self, resolve_parameters: dict[str, utils.NUMERICAL_TYPE]
    ) -> utils.NUMERICAL_TYPE:
        """Evaluate the parameter to a number.

        Args:
            resolve_parameters: Mapping from parameter name to numerical value.

        Returns:
            For a named parameter, the mapped value. For a composite parameter,
            the result of applying its op to the recursively evaluated operands.

        Raises:
            ValueError: If a named parameter is missing from
                ``resolve_parameters`` or the op is unsupported.
        """
        if self.eval_expression is None:
            val = resolve_parameters.get(self.name, NOT_FOUND)
            if val is NOT_FOUND:
                raise ValueError(f"Cannot resolve Parameter named: {self.name}")
            return val
        op = self.eval_expression.op
        left_param = eval_parameter(self.eval_expression.left_param, resolve_parameters)
        right_param = eval_parameter(
            self.eval_expression.right_param, resolve_parameters
        )

        if op == utils.Op.ADD:
            return left_param + right_param
        if op == utils.Op.SUB:
            return left_param - right_param
        if op == utils.Op.MUL:
            return left_param * right_param
        if op == utils.Op.DIV:
            return left_param / right_param

        raise ValueError(f"Encountered unknown {op=} when evaluating the parameter.")

    def _compose(self, op, other, *, reflected=False):
        """Build the composite parameter ``self <op> other``.

        When ``reflected`` is set, it builds ``other <op> self`` instead.

        Returns ``NotImplemented`` for operands that are neither numerical nor
        parameters, so Python can try the other operand's reflected method
        (e.g. a Scalar or Point handling ``Parameter * Point``).
        """
        if not utils.is_numerical_or_parameter(other):
            return NotImplemented
        left, right = (other, self) if reflected else (self, other)
        return Parameter(
            name=None,
            eval_expression=ParameterRepresentation(
                op=op, left_param=left, right_param=right
            ),
        )

    def __add__(self, other):
        return self._compose(utils.Op.ADD, other)

    def __radd__(self, other):
        return self._compose(utils.Op.ADD, other, reflected=True)

    def __sub__(self, other):
        return self._compose(utils.Op.SUB, other)

    def __rsub__(self, other):
        return self._compose(utils.Op.SUB, other, reflected=True)

    def __mul__(self, other):
        return self._compose(utils.Op.MUL, other)

    def __rmul__(self, other):
        return self._compose(utils.Op.MUL, other, reflected=True)

    def __truediv__(self, other):
        return self._compose(utils.Op.DIV, other)

    def __rtruediv__(self, other):
        return self._compose(utils.Op.DIV, other, reflected=True)
|
@@ -0,0 +1,128 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
from typing import Iterator
|
21
|
+
|
22
|
+
import numpy as np
|
23
|
+
import pytest
|
24
|
+
import sympy as sp
|
25
|
+
|
26
|
+
from pepflow import pep_context as pc
|
27
|
+
from pepflow.expression_manager import ExpressionManager
|
28
|
+
from pepflow.parameter import Parameter
|
29
|
+
from pepflow.point import Point
|
30
|
+
from pepflow.scalar import Scalar
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Provide a fresh current PEPContext named "test"; unset the global context afterwards."""
    context = pc.PEPContext("test").set_as_current()
    try:
        yield context
    finally:
        pc.set_current_context(None)
|
39
|
+
|
40
|
+
|
41
|
+
def test_parameter_interact_with_scalar(pep_context: pc.PEPContext):
    """Arithmetic mixing a Parameter and a Scalar, in either order, yields a Scalar."""
    pm1 = Parameter("pm1")
    s1 = Scalar(is_basis=True, tags=["s1"])

    # Parameter's operators return NotImplemented for a Scalar operand, so these
    # must be handled by Scalar's (reflected) operators.
    results = [
        pm1 + s1,
        s1 + pm1,
        pm1 - s1,
        s1 - pm1,
        s1 * pm1,
        pm1 * s1,
        s1 / pm1,
    ]
    for result in results:
        assert isinstance(result, Scalar)
|
52
|
+
|
53
|
+
|
54
|
+
def test_parameter_interact_with_point(pep_context: pc.PEPContext):
    """Scaling a Point by a Parameter, in either order, yields a Point."""
    pm1 = Parameter("pm1")
    p1 = Point(is_basis=True, tags=["p1"])

    for result in (p1 * pm1, pm1 * p1, p1 / pm1):
        assert isinstance(result, Point)
|
61
|
+
|
62
|
+
|
63
|
+
def test_parameter_composition_with_point_and_scalar(pep_context: pc.PEPContext):
    """Parameters compose with scalars and squared points into a readable expression."""
    pm1, pm2 = Parameter("pm1"), Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    s1 = Scalar(is_basis=True, tags=["s1"])

    combined = s1 + pm1 + pm2 * p1**2
    assert str(combined) == "s1+pm1+pm2*|p1|^2"
|
71
|
+
|
72
|
+
|
73
|
+
def test_parameter_composition(pep_context: pc.PEPContext):
    """Composite parameters print fully parenthesised and evaluate exactly."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")

    with_int = (pm1 + 2) * pm2
    assert str(with_int) == "((pm1+2)*pm2)"
    assert with_int.get_value({"pm1": 3, "pm2": 6}) == 30

    # sympy rationals keep the arithmetic exact: (1/3 + 1/2) * 6/5 == 1.
    with_rational = (pm1 + sp.Rational(1, 2)) * pm2
    assert str(with_rational) == "((pm1+1/2)*pm2)"
    resolved = {"pm1": sp.Rational(1, 3), "pm2": sp.Rational(6, 5)}
    assert with_rational.get_value(resolved) == 1
|
84
|
+
|
85
|
+
|
86
|
+
def test_expression_manager_eval_with_parameter(pep_context: pc.PEPContext):
    """The same point expression re-evaluates under different parameter assignments."""
    pm1 = Parameter("pm1")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    p3 = pm1 * p1 + p2 / 4

    for pm1_value in (2.3, 3.4):
        manager = ExpressionManager(pep_context, {"pm1": pm1_value})
        np.testing.assert_allclose(
            manager.eval_point(p3).vector, np.array([pm1_value, 0.25])
        )
|
97
|
+
|
98
|
+
|
99
|
+
def test_expression_manager_eval_with_parameter_scalar(pep_context: pc.PEPContext):
    """Parameters inside a scalar expression are substituted during evaluation."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    s1 = Scalar(is_basis=True, tags=["s1"])
    s2 = pm1 * p1 * p2 + pm2 + s1

    manager = ExpressionManager(pep_context, {"pm1": 2.4, "pm2": 4.3})
    evaluated = manager.eval_scalar(s2)
    # pm2 -> constant term, s1 -> vector term, and pm1 * <p1, p2> is split
    # evenly over the two off-diagonal matrix entries (2.4 / 2 = 1.2).
    assert np.isclose(evaluated.constant, 4.3)
    np.testing.assert_allclose(evaluated.vector, np.array([1]))
    np.testing.assert_allclose(evaluated.matrix, np.array([[0, 1.2], [1.2, 0]]))
|
113
|
+
|
114
|
+
|
115
|
+
def test_expression_manager_eval_composition(pep_context: pc.PEPContext):
    """Composite parameters used as coefficients are evaluated before substitution."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    s1 = Scalar(is_basis=True, tags=["s1"])

    s2 = 1 / pm1 * p1 * p2 + (pm2 + 1) * s1
    evaluated = ExpressionManager(pep_context, {"pm1": 0.5, "pm2": 4.3}).eval_scalar(s2)
    # 1 / 0.5 = 2 is split over the off-diagonal entries; 4.3 + 1 = 5.3 scales s1.
    assert np.isclose(evaluated.constant, 0)
    np.testing.assert_allclose(evaluated.vector, np.array([5.3]))
    np.testing.assert_allclose(evaluated.matrix, np.array([[0, 1.0], [1.0, 0]]))
|
pepflow/pep.py
CHANGED
@@ -20,6 +20,7 @@
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
22
|
import contextlib
|
23
|
+
from collections import defaultdict
|
23
24
|
from typing import TYPE_CHECKING, Any, Iterator
|
24
25
|
|
25
26
|
import attrs
|
@@ -30,22 +31,50 @@ from pepflow import pep_context as pc
|
|
30
31
|
from pepflow import point as pt
|
31
32
|
from pepflow import scalar as sc
|
32
33
|
from pepflow import solver as ps
|
34
|
+
from pepflow import utils
|
33
35
|
from pepflow.constants import PSD_CONSTRAINT
|
34
36
|
|
35
37
|
if TYPE_CHECKING:
|
36
38
|
from pepflow.constraint import Constraint
|
37
39
|
from pepflow.function import Function
|
38
|
-
from pepflow.solver import DualVariableManager
|
40
|
+
from pepflow.solver import DualVariableManager, PrimalVariableManager
|
39
41
|
|
40
42
|
|
41
43
|
@attrs.frozen
|
42
44
|
class PEPResult:
|
45
|
+
"""
|
46
|
+
A data class object that contains the results of solving the Primal
|
47
|
+
PEP.
|
48
|
+
|
49
|
+
Attributes:
|
50
|
+
primal_opt_value (float): The objective value of the solved Primal PEP.
|
51
|
+
dual_var_manager (:class:`DualVariableManager`): A manager object which
|
52
|
+
provides access to the dual variables associated with the
|
53
|
+
constraints of the Primal PEP.
|
54
|
+
solver_status (Any): States whether the solver managed to solve the
|
55
|
+
Primal PEP successfully.
|
56
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object used to
|
57
|
+
solve the Primal PEP.
|
58
|
+
|
59
|
+
"""
|
60
|
+
|
43
61
|
primal_opt_value: float
|
44
62
|
dual_var_manager: DualVariableManager
|
45
63
|
solver_status: Any
|
46
64
|
context: pc.PEPContext
|
47
65
|
|
48
66
|
def get_function_dual_variables(self) -> dict[Function, np.ndarray]:
|
67
|
+
"""
|
68
|
+
Return a dictionary which contains the associated dual variables of the
|
69
|
+
interpolation constraints for Primal PEP.
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
dict[:class:`Function`, np.ndarray]: A dictionary where the keys
|
73
|
+
are :class:`Function` objects, and the values are the dual
|
74
|
+
variables associated with the interpolation constraints of the
|
75
|
+
:class:`Function` key.
|
76
|
+
"""
|
77
|
+
|
49
78
|
def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
|
50
79
|
# Check if we need to update the order.
|
51
80
|
return (
|
@@ -68,31 +97,90 @@ class PEPResult:
|
|
68
97
|
|
69
98
|
return df_dict_matrix
|
70
99
|
|
71
|
-
def get_psd_dual_matrix(self) -> np.ndarray:
    """
    Return the dual variable matrix of the constraint that the Primal PEP
    decision variable :math:`G` is PSD.

    Returns:
        np.ndarray: The dual value recorded under ``PSD_CONSTRAINT`` by the
        dual variable manager, converted to a NumPy array.
    """
    psd_dual = self.dual_var_manager.dual_value(PSD_CONSTRAINT)
    return np.array(psd_dual)
|
73
110
|
|
74
111
|
|
112
|
+
@attrs.frozen
class DualPEPResult:
    """
    Immutable container for the outcome of solving the Dual PEP.

    Attributes:
        dual_opt_value (float): Optimal objective value of the Dual PEP.
        primal_var_manager (:class:`PrimalVariableManager`): Manager giving
            access to the decision variables of the Dual PEP (i.e. the dual
            variables of the Primal PEP).
        solver_status (Any): Status reported by the solver, indicating whether
            the Dual PEP was solved successfully.
        context (:class:`PEPContext`): The :class:`PEPContext` the Dual PEP was
            built from.
    """

    dual_opt_value: float
    primal_var_manager: PrimalVariableManager
    solver_status: Any
    context: pc.PEPContext
|
133
|
+
|
134
|
+
|
75
135
|
class PEPBuilder:
|
76
|
-
"""The main class for PEP
|
136
|
+
"""The main class for Primal and Dual PEP formulation.
|
137
|
+
|
138
|
+
Attributes:
|
139
|
+
init_conditions (list[:class:`Constraint`]): A list of all the initial
|
140
|
+
conditions associated with this PEP.
|
141
|
+
functions (list[:class:`Function`]): A list of all the functions
|
142
|
+
associated with this PEP.
|
143
|
+
performance_metric (:class:`Scalar`): The performance metric for this
|
144
|
+
PEP.
|
145
|
+
relaxed_constraints (list[str]): A list of names of the constraints
|
146
|
+
that will be ignored when the Primal or Dual PEP is constructed.
|
147
|
+
dual_val_constraint (dict[str, list[tuple[str, float]]]): A dictionary
|
148
|
+
of the form `{constraint_name: [(op, val), ...]}`. The `constraint_name`
|
149
|
+
is the name of the constraint the dual variable is associated with.
|
150
|
+
The `op` is a string for the type of relation, i.e., `lt`, `gt`,
|
151
|
+
or `eq`. The `val` is the value for the other side of the
|
152
|
+
constraint. For example, consider `{"f:x_1,x_0": [("eq", 0)]}`.
|
153
|
+
Denote the associated dual variable as :math:`\\lambda_{1,0}`.
|
154
|
+
Then, this means to add a constraint of the form
|
155
|
+
:math:`\\lambda_{1,0} = 0` to the Dual PEP. Because it is hard to
|
156
|
+
judge if the constraint associated with `constraint_name` is
|
157
|
+
active, we suggest not adding dual variable constraints manually
|
158
|
+
but instead use the interactive dashboard.
|
159
|
+
"""
|
77
160
|
|
78
161
|
def __init__(self):
    # PEPContext objects held by this builder, keyed by string
    # (presumably the context name -- confirm against make_context).
    self.pep_context_dict: dict[str, pc.PEPContext] = {}

    # Problem definition: initial conditions, declared functions, objective.
    self.init_conditions: list[Constraint] = []
    self.functions: list[Function] = []
    self.performance_metric: sc.Scalar | None = None

    # Names of constraints that are dropped when the Primal or Dual PEP is built.
    # TODO: consider a dedicated manager object instead of a bare list.
    self.relaxed_constraints: list[str] = []

    # Maps a constraint name to a list of (op, val) pairs, op in {"lt", "gt", "eq"},
    # which constrain the associated dual variable when the Dual PEP is built.
    # It is hard to tell whether such a constraint is active, so prefer setting
    # these through the interactive dashboard rather than by hand.
    self.dual_val_constraint: dict[str, list[tuple[str, float]]] = defaultdict(list)
|
89
176
|
|
90
177
|
def clear_setup(self):
    """Reset the :class:`PEPBuilder` problem definition.

    Empties the initial conditions, functions, relaxed-constraint names and
    dual value constraints in place, and unsets the performance metric.
    ``pep_context_dict`` is left untouched.
    """
    for container in (
        self.init_conditions,
        self.functions,
        self.relaxed_constraints,
        self.dual_val_constraint,
    ):
        container.clear()
    self.performance_metric = None
|
96
184
|
|
97
185
|
@contextlib.contextmanager
|
98
186
|
def make_context(
|
@@ -130,20 +218,116 @@ class PEPBuilder:
|
|
130
218
|
return point
|
131
219
|
|
132
220
|
def set_initial_constraint(self, constraint):
    """
    Add an initial condition to this PEP.

    Conditions accumulate: each call appends to ``init_conditions``.

    Args:
        constraint (:class:`Constraint`): The initial condition to add.
    """
    conditions = self.init_conditions
    conditions.append(constraint)
|
134
229
|
|
135
230
|
def set_performance_metric(self, metric: sc.Scalar):
    """
    Set (replacing any previous value) the performance metric of this PEP.

    Args:
        metric (:class:`Scalar`): The scalar expression used as the
            performance metric.
    """
    self.performance_metric = metric
|
137
239
|
|
138
240
|
def set_relaxed_constraints(self, relaxed_constraints: list[str]):
    """
    Mark constraints, by name, to be ignored when the PEP is built.

    Names are appended to the existing ``relaxed_constraints`` list; earlier
    entries are kept.

    Args:
        relaxed_constraints (list[str]): Names of the constraints to ignore.
    """
    already_relaxed = self.relaxed_constraints
    already_relaxed.extend(relaxed_constraints)
|
140
249
|
|
141
|
-
def add_dual_val_constraint(
    self, constraint_name: str, op: str, val: float
) -> None:
    """
    Register a constraint on the dual variable of a named constraint.

    The pair ``(op, val)`` is appended to ``dual_val_constraint[constraint_name]``.
    When the Dual PEP is solved, it is applied through the constraint's
    ``dual_lt``, ``dual_gt`` or ``dual_eq`` method.

    Args:
        constraint_name (str): Name of the constraint whose dual variable is
            constrained.
        op (str): One of ``"lt"``, ``"gt"`` or ``"eq"``.
        val (float): Numerical value for the other side of the relation.

    Raises:
        ValueError: If ``op`` is not supported or ``val`` is not numerical.
    """
    if op not in ("lt", "gt", "eq"):
        raise ValueError(f"op must be one of `lt`, `gt`, or `eq` but get {op}")
    if not utils.is_numerical(val):
        raise ValueError("Value must be some numerical value.")

    self.dual_val_constraint[constraint_name].append((op, val))
|
259
|
+
|
260
|
+
def declare_func(self, function_class: type[Function], tag: str, **kwargs):
    """
    Create a basis function, tag it, and register it with this PEP.

    Args:
        function_class (type[:class:`Function`]): The function class to
            instantiate, e.g. :class:`ConvexFunction` or
            :class:`SmoothConvexFunction`.
        tag (str): Tag added to the new function's tags; it can later be used
            with :meth:`get_func_by_tag`.
        **kwargs: Extra constructor arguments for ``function_class``, e.g. the
            smoothness parameter ``L`` of :class:`SmoothConvexFunction`.

    Returns:
        :class:`Function`: The newly declared function.
    """
    declared = function_class(is_basis=True, composition=None, **kwargs)
    declared.add_tag(tag)
    self.functions.append(declared)
    return declared
|
145
279
|
|
146
|
-
def get_func_by_tag(self, tag: str):
    """
    Look up a declared function by one of its tags.

    Args:
        tag (str): The tag to search for.

    Returns:
        :class:`Function`: The first function in ``functions`` whose tags
        contain ``tag``.

    Raises:
        ValueError: If no declared function carries ``tag``.

    Note:
        Currently, only basis :class:`Function` objects can be retrieved.
        This will be updated eventually.
    """
    # TODO: Add support to return composite functions as well. Right now we can only return base functions
    found = next((func for func in self.functions if tag in func.tags), None)
    if found is None:
        raise ValueError("Cannot find the function of given tag.")
    return found
|
301
|
+
|
302
|
+
def solve(
    self,
    context: pc.PEPContext | None = None,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
):
    """
    Solve the Primal PEP; a shorthand for :meth:`solve_primal`.

    Args:
        context (:class:`PEPContext`): Context to solve in; `None` means the
            current global context.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): Maps parameter
            names to numerical values.

    Returns:
        Whatever :meth:`solve_primal` returns.
    """
    primal_solver = self.solve_primal
    return primal_solver(context, resolve_parameters=resolve_parameters)
|
308
|
+
|
309
|
+
def solve_primal(
|
310
|
+
self,
|
311
|
+
context: pc.PEPContext | None = None,
|
312
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
|
313
|
+
):
|
314
|
+
"""
|
315
|
+
Solve the Primal PEP associated with this :class:`PEPBuilder` object
|
316
|
+
using the given :class:`PEPContext` object.
|
317
|
+
|
318
|
+
Args:
|
319
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object used
|
320
|
+
to solve the Primal PEP associated with this
|
321
|
+
:class:`PEPBuilder` object. `None` if we consider the current
|
322
|
+
global :class:`PEPContext` object.
|
323
|
+
resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
|
324
|
+
maps the name of parameters to the numerical values.
|
325
|
+
|
326
|
+
Returns:
|
327
|
+
:class:`PEPResult`: A :class:`PEPResult` object that contains the
|
328
|
+
information obtained after solving the Primal PEP associated with
|
329
|
+
this :class:`PEPBuilder` object.
|
330
|
+
"""
|
147
331
|
if context is None:
|
148
332
|
context = pc.get_current_context()
|
149
333
|
if context is None:
|
@@ -151,23 +335,86 @@ class PEPBuilder:
|
|
151
335
|
|
152
336
|
all_constraints: list[Constraint] = [*self.init_conditions]
|
153
337
|
for f in self.functions:
|
154
|
-
all_constraints.extend(f.get_interpolation_constraints())
|
155
|
-
all_constraints.extend(context.opt_conditions[f])
|
338
|
+
all_constraints.extend(f.get_interpolation_constraints(context))
|
156
339
|
|
157
340
|
# for now, we heavily rely on the CVX. We can make a wrapper class to avoid
|
158
341
|
# direct dependency in the future.
|
159
|
-
solver = ps.
|
342
|
+
solver = ps.CVXPrimalSolver(
|
160
343
|
perf_metric=self.performance_metric,
|
161
344
|
constraints=[
|
162
345
|
c for c in all_constraints if c.name not in self.relaxed_constraints
|
163
346
|
],
|
164
347
|
context=context,
|
165
348
|
)
|
166
|
-
problem = solver.build_problem()
|
167
|
-
result = problem.solve(
|
349
|
+
problem = solver.build_problem(resolve_parameters=resolve_parameters)
|
350
|
+
result = problem.solve()
|
168
351
|
return PEPResult(
|
169
352
|
primal_opt_value=result,
|
170
353
|
dual_var_manager=solver.dual_var_manager,
|
171
354
|
solver_status=problem.status,
|
172
355
|
context=context,
|
173
356
|
)
|
357
|
+
|
358
|
+
def solve_dual(
    self,
    context: pc.PEPContext | None = None,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
):
    """
    Solve the Dual PEP associated with this :class:`PEPBuilder` object
    using the given :class:`PEPContext` object.

    Constraints named in ``relaxed_constraints`` are dropped. For every
    remaining constraint, each ``(op, val)`` pair registered for its name in
    ``dual_val_constraint`` is applied through the constraint's ``dual_lt``,
    ``dual_gt`` or ``dual_eq`` method before the problem is built.

    Args:
        context (:class:`PEPContext`): The :class:`PEPContext` object used
            to solve the Dual PEP associated with this :class:`PEPBuilder`
            object. `None` if we consider the current global
            :class:`PEPContext` object.
        resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
            maps the name of parameters to the numerical values.

    Returns:
        :class:`DualPEPResult`: A :class:`DualPEPResult` object that
        contains the information obtained after solving the Dual PEP
        associated with this :class:`PEPBuilder` object.

    Raises:
        RuntimeError: If no context is given and no current context is set.
        ValueError: If a registered dual value constraint has an unknown op.
    """
    if context is None:
        context = pc.get_current_context()
    if context is None:
        raise RuntimeError("Did you forget to create a context?")

    all_constraints: list[Constraint] = [*self.init_conditions]
    for f in self.functions:
        all_constraints.extend(f.get_interpolation_constraints(context))

    # TODO: Consider a better API and interface to adding constraint for primal
    # variable of dual problem. We can add `extra_primal_val_constraints` to add
    # more constraints on primal var in dual PEP, i.e. dual var in primal PEP.
    relaxed = set(self.relaxed_constraints)  # built once; O(1) membership below
    constraints = []
    for c in all_constraints:
        if c.name in relaxed:
            continue
        # Use `.get` rather than indexing: `dual_val_constraint` is a
        # defaultdict, so indexing would insert an empty entry for every
        # constraint name as a side effect of solving.
        # NOTE(review): dual_lt/dual_gt/dual_eq are invoked on the constraint
        # object itself; if they accumulate state, repeated solves may register
        # duplicates -- confirm against Constraint.
        for op, val in self.dual_val_constraint.get(c.name, ()):
            if op == "lt":
                c.dual_lt(val)
            elif op == "gt":
                c.dual_gt(val)
            elif op == "eq":
                c.dual_eq(val)
            else:
                raise ValueError(f"Unknown op when construct the {c}")
        constraints.append(c)

    dual_solver = ps.CVXDualSolver(
        perf_metric=self.performance_metric,
        constraints=constraints,
        context=context,
    )
    problem = dual_solver.build_problem(resolve_parameters=resolve_parameters)
    result = problem.solve()

    return DualPEPResult(
        dual_opt_value=result,
        primal_var_manager=dual_solver.primal_var_manager,
        solver_status=problem.status,
        context=context,
    )
|