pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +5 -1
- pepflow/constraint.py +58 -1
- pepflow/e2e_test.py +47 -3
- pepflow/expression_manager.py +272 -57
- pepflow/expression_manager_test.py +36 -2
- pepflow/function.py +180 -10
- pepflow/parameter.py +187 -0
- pepflow/parameter_test.py +128 -0
- pepflow/pep.py +254 -14
- pepflow/pep_context.py +116 -0
- pepflow/pep_context_test.py +21 -0
- pepflow/point.py +155 -49
- pepflow/point_test.py +12 -0
- pepflow/scalar.py +260 -47
- pepflow/scalar_test.py +15 -0
- pepflow/solver.py +170 -3
- pepflow/solver_test.py +50 -2
- pepflow/utils.py +39 -7
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/METADATA +12 -11
- pepflow-0.1.5.dist-info/RECORD +28 -0
- pepflow-0.1.4a1.dist-info/RECORD +0 -26
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/top_level.txt +0 -0
pepflow/__init__.py
CHANGED
@@ -21,6 +21,7 @@
|
|
21
21
|
from .constants import PSD_CONSTRAINT as PSD_CONSTRAINT
|
22
22
|
from .constraint import Constraint as Constraint
|
23
23
|
from .expression_manager import ExpressionManager as ExpressionManager
|
24
|
+
from .expression_manager import represent_matrix_by_basis as represent_matrix_by_basis
|
24
25
|
|
25
26
|
# interactive_constraint
|
26
27
|
from .interactive_constraint import launch as launch
|
@@ -28,6 +29,7 @@ from .interactive_constraint import launch as launch
|
|
28
29
|
# pep
|
29
30
|
from .pep import PEPBuilder as PEPBuilder
|
30
31
|
from .pep import PEPResult as PEPResult
|
32
|
+
from .pep import DualPEPResult as DualPEPResult
|
31
33
|
from .pep_context import PEPContext as PEPContext
|
32
34
|
from .pep_context import get_current_context as get_current_context
|
33
35
|
from .pep_context import set_current_context as set_current_context
|
@@ -43,8 +45,10 @@ from .scalar import EvaluatedScalar as EvaluatedScalar
|
|
43
45
|
from .scalar import Scalar as Scalar
|
44
46
|
|
45
47
|
# Solver
|
46
|
-
from .solver import
|
48
|
+
from .solver import CVXPrimalSolver as CVXPrimalSolver
|
49
|
+
from .solver import CVXDualSolver as CVXDualSolver
|
47
50
|
from .solver import DualVariableManager as DualVariableManager
|
51
|
+
from .solver import PrimalVariableManager as PrimalVariableManager
|
48
52
|
|
49
53
|
# Others
|
50
54
|
from .utils import SOP as SOP
|
pepflow/constraint.py
CHANGED
@@ -31,8 +31,65 @@ from pepflow import utils
|
|
31
31
|
|
32
32
|
@attrs.frozen
class Constraint:
    """A :class:`Constraint` object that represents inequalities and
    equalities of :class:`Scalar` objects.

    Denote an arbitrary :class:`Scalar` object as `x`. Constraints represent:
    `x <= 0`, `x >= 0`, and `x = 0`.

    Attributes:
        scalar (:class:`Scalar`): The :class:`Scalar` object involved in
            the inequality or equality.
        comparator (:class:`Comparator`): :class:`Comparator` is an enumeration
            that can be either `GT`, `LT`, or `EQ`. They represent `>=`, `<=`,
            and `=` respectively.
        name (str): The unique name of the :class:`Constraint` object.
        associated_dual_var_constraints (list[tuple[utils.Comparator, float]]):
            A list of all the constraints imposed on the associated dual
            variable of this :class:`Constraint` object.
    """

    scalar: Scalar | float
    comparator: utils.Comparator
    name: str

    # Used to represent the constraint on primal variable in dual PEP.
    # The list is mutated in place by the `dual_*` methods below. It is
    # excluded from the attrs-generated `__hash__` (`hash=False`) because a
    # list is unhashable and would otherwise make every frozen instance
    # unhashable. It still takes part in `__eq__`.
    associated_dual_var_constraints: list[tuple[utils.Comparator, float]] = attrs.field(
        factory=list, hash=False
    )

    def _add_dual_var_constraint(self, comparator: utils.Comparator, val: float) -> None:
        """Validate `val` and record the relation `lambd <comparator> val`.

        Args:
            comparator (:class:`Comparator`): The relation to impose.
            val (float): The other object in the relation.

        Raises:
            ValueError: If `val` is not a numerical value.
        """
        if not utils.is_numerical(val):
            raise ValueError(f"The input {val=} must be a numerical value")
        # Mutating the list (not rebinding the attribute) is allowed on a
        # frozen attrs instance.
        self.associated_dual_var_constraints.append((comparator, val))

    def dual_lt(self, val: float) -> None:
        """
        Denote the associated dual variable of this constraint as `lambd`.
        This generates a relation of the form `lambd <= val`.

        Args:
            val (float): The other object in the relation.

        Raises:
            ValueError: If `val` is not a numerical value.
        """
        self._add_dual_var_constraint(utils.Comparator.LT, val)

    def dual_gt(self, val: float) -> None:
        """
        Denote the associated dual variable of this constraint as `lambd`.
        This generates a relation of the form `lambd >= val`.

        Args:
            val (float): The other object in the relation.

        Raises:
            ValueError: If `val` is not a numerical value.
        """
        self._add_dual_var_constraint(utils.Comparator.GT, val)

    def dual_eq(self, val: float) -> None:
        """
        Denote the associated dual variable of this constraint as `lambd`.
        This generates a relation of the form `lambd = val`.

        Args:
            val (float): The other object in the relation.

        Raises:
            ValueError: If `val` is not a numerical value.
        """
        self._add_dual_var_constraint(utils.Comparator.EQ, val)
|
pepflow/e2e_test.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
import math
|
2
2
|
|
3
|
-
from pepflow import function
|
3
|
+
from pepflow import function
|
4
|
+
from pepflow import parameter as pm
|
5
|
+
from pepflow import pep
|
4
6
|
from pepflow import pep_context as pc
|
5
7
|
|
6
8
|
|
@@ -28,10 +30,47 @@ def test_gd_e2e():
|
|
28
30
|
pep_builder.set_performance_metric(
|
29
31
|
f.function_value(p) - f.function_value(x_star)
|
30
32
|
)
|
31
|
-
result = pep_builder.
|
33
|
+
result = pep_builder.solve_primal()
|
32
34
|
expected_opt_value = 1 / (4 * i + 2)
|
33
35
|
assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
|
34
36
|
|
37
|
+
dual_result = pep_builder.solve_dual()
|
38
|
+
assert math.isclose(
|
39
|
+
dual_result.dual_opt_value, expected_opt_value, rel_tol=1e-3
|
40
|
+
)
|
41
|
+
|
42
|
+
|
43
|
+
def test_gd_diff_stepsize_e2e():
    """Gradient descent with step size 1/L, re-solved for several values of L.

    The PEP is built once with `L` left as a named parameter. Each solve
    substitutes a concrete value through `resolve_parameters`. Both the primal
    and dual optimal values are compared against L / (4N + 2).
    """
    pc.PEPContext("gd").set_as_current()
    builder = pep.PEPBuilder()
    step_size = 1 / pm.Parameter(name="L")
    num_steps = 4

    f = builder.declare_func(
        function.SmoothConvexFunction, "f", L=pm.Parameter(name="L")
    )
    x = builder.set_init_point("x_0")
    x_star = f.add_stationary_point("x_star")
    builder.set_initial_constraint(
        ((x - x_star) ** 2).le(1, name="initial_condition")
    )

    # Unroll the iterations a single time; every solve below reuses this PEP.
    for k in range(1, num_steps + 1):
        x = x - step_size * f.gradient(x)
        x.add_tag(f"x_{k}")
    builder.set_performance_metric(f(x) - f(x_star))

    for smoothness in (1, 4, 0.25):
        expected = smoothness / (4 * num_steps + 2)

        primal = builder.solve_primal(resolve_parameters={"L": smoothness})
        assert math.isclose(primal.primal_opt_value, expected, rel_tol=1e-3)

        dual = builder.solve_dual(resolve_parameters={"L": smoothness})
        assert math.isclose(dual.dual_opt_value, expected, rel_tol=1e-3)
|
73
|
+
|
35
74
|
|
36
75
|
def test_pgm_e2e():
|
37
76
|
ctx = pc.PEPContext("pgm").set_as_current()
|
@@ -64,6 +103,11 @@ def test_pgm_e2e():
|
|
64
103
|
h.function_value(p) - h.function_value(x_star)
|
65
104
|
)
|
66
105
|
|
67
|
-
result = pep_builder.
|
106
|
+
result = pep_builder.solve_primal()
|
68
107
|
expected_opt_value = 1 / (4 * i)
|
69
108
|
assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
|
109
|
+
|
110
|
+
dual_result = pep_builder.solve_dual()
|
111
|
+
assert math.isclose(
|
112
|
+
dual_result.dual_opt_value, expected_opt_value, rel_tol=1e-3
|
113
|
+
)
|
pepflow/expression_manager.py
CHANGED
@@ -22,6 +22,7 @@ import math
|
|
22
22
|
|
23
23
|
import numpy as np
|
24
24
|
|
25
|
+
from pepflow import parameter as pm
|
25
26
|
from pepflow import pep_context as pc
|
26
27
|
from pepflow import point as pt
|
27
28
|
from pepflow import scalar as sc
|
@@ -29,23 +30,41 @@ from pepflow import utils
|
|
29
30
|
|
30
31
|
|
31
32
|
def tag_and_coef_to_str(tag: str, v: float) -> str:
    """Render one signed term ``<sign> <coef>*<tag> `` of a linear combination.

    A coefficient of magnitude (close to) one is written without the ``coef*``
    prefix. A coefficient within 1e-5 of zero yields an empty string, so the
    term is dropped. Every non-empty result ends with a trailing space so
    terms can be concatenated directly.
    """
    # Formatted eagerly, exactly as before, even if the term ends up dropped.
    formatted_magnitude = utils.numerical_str(abs(v))
    sign = "+" if v >= 0 else "-"

    if math.isclose(abs(v), 1):
        return f"{sign} {tag} "
    if math.isclose(v, 0, abs_tol=1e-5):
        return ""
    return f"{sign} {formatted_magnitude}*{tag} "
|
40
41
|
|
41
42
|
|
42
43
|
class ExpressionManager:
|
43
|
-
|
44
|
+
"""
|
45
|
+
A class that handles the concrete representations of abstract
|
46
|
+
:class:`Point` and :class:`Scalar` objects managed by a particular
|
47
|
+
:class:`PEPContext` object.
|
48
|
+
|
49
|
+
Attributes:
|
50
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object which
|
51
|
+
manages the abstract :class:`Point` and :class:`Scalar` objects
|
52
|
+
of interest.
|
53
|
+
resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
|
54
|
+
maps the name of parameters to the numerical values.
|
55
|
+
"""
|
56
|
+
|
57
|
+
def __init__(
|
58
|
+
self,
|
59
|
+
pep_context: pc.PEPContext,
|
60
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
|
61
|
+
):
|
44
62
|
self.context = pep_context
|
45
63
|
self._basis_points = []
|
46
64
|
self._basis_point_uid_to_index = {}
|
47
65
|
self._basis_scalars = []
|
48
66
|
self._basis_scalar_uid_to_index = {}
|
67
|
+
self.resolve_parameters = resolve_parameters or {}
|
49
68
|
for point in self.context.points:
|
50
69
|
if point.is_basis:
|
51
70
|
self._basis_points.append(point)
|
@@ -74,88 +93,161 @@ class ExpressionManager:
|
|
74
93
|
|
75
94
|
@functools.cache
def eval_point(self, point: pt.Point | float | int):
    """
    Return the concrete representation of the given :class:`Point`,
    `float`, or `int`. Concrete representations of :class:`Point` objects
    are :class:`EvaluatedPoint` objects. Concrete representations of
    `float` or `int` arguments are themselves. A :class:`Parameter` is
    resolved to its value through `self.resolve_parameters`.

    Args:
        point (:class:`Point`, float, int): The abstract :class:`Point`,
            `float`, or `int` object whose concrete representation we want
            to find.

    Returns:
        :class:`EvaluatedPoint` | float | int: The concrete representation
        of the `point` argument.

    Raises:
        ValueError: If the point's expression uses an unknown operator.
    """
    # NOTE(review): `functools.cache` on a method keys on `self` and keeps
    # every ExpressionManager alive for the cache's lifetime (ruff B019).
    if utils.is_numerical(point):
        return point

    if isinstance(point, pm.Parameter):
        return point.get_value(self.resolve_parameters)

    if point.is_basis:
        # A basis point is the unit vector at its basis index.
        index = self.get_index_of_basis_point(point)
        array = np.zeros(self._num_basis_points)
        array[index] = 1
        return pt.EvaluatedPoint(vector=array)

    if isinstance(point.eval_expression, pt.ZeroPoint):
        return pt.EvaluatedPoint.zero(num_basis_points=self._num_basis_points)

    # Composite point: evaluate both operands recursively, then combine.
    op = point.eval_expression.op
    left_evaled_point = self.eval_point(point.eval_expression.left_point)
    right_evaled_point = self.eval_point(point.eval_expression.right_point)
    if op == utils.Op.ADD:
        return left_evaled_point + right_evaled_point
    if op == utils.Op.SUB:
        return left_evaled_point - right_evaled_point
    if op == utils.Op.MUL:
        return left_evaled_point * right_evaled_point
    if op == utils.Op.DIV:
        return left_evaled_point / right_evaled_point

    raise ValueError(f"Encountered unknown {op=} when evaluating the point.")
|
105
139
|
|
106
140
|
@functools.cache
def eval_scalar(self, scalar: sc.Scalar | float | int):
    """
    Return the concrete representation of the given :class:`Scalar`,
    `float`, or `int`. Concrete representations of :class:`Scalar` objects
    are :class:`EvaluatedScalar` objects. Concrete representations of
    `float` or `int` arguments are themselves. A :class:`Parameter` is
    resolved to its value through `self.resolve_parameters`.

    Args:
        scalar (:class:`Scalar`, float, int): The abstract :class:`Scalar`,
            `float`, or `int` object whose concrete representation we want
            to find.

    Returns:
        :class:`EvaluatedScalar` | float | int: The concrete representation
        of the `scalar` argument.

    Raises:
        ValueError: If the scalar's expression uses an unknown operator.
    """
    # NOTE(review): `functools.cache` on a method keys on `self` and keeps
    # every ExpressionManager alive for the cache's lifetime (ruff B019).
    if utils.is_numerical(scalar):
        return scalar
    if isinstance(scalar, pm.Parameter):
        return scalar.get_value(self.resolve_parameters)

    if scalar.is_basis:
        # A basis scalar is the unit vector at its basis index with no
        # point (matrix) part and no constant part.
        index = self.get_index_of_basis_scalar(scalar)
        array = np.zeros(self._num_basis_scalars)
        array[index] = 1
        matrix = np.zeros((self._num_basis_points, self._num_basis_points))
        return sc.EvaluatedScalar(
            vector=array,
            matrix=matrix,
            constant=float(0.0),
        )

    if isinstance(scalar.eval_expression, sc.ZeroScalar):
        return sc.EvaluatedScalar.zero(
            num_basis_scalars=self._num_basis_scalars,
            num_basis_points=self._num_basis_points,
        )

    op = scalar.eval_expression.op
    # The special inner product usage: Point * Point is a scalar whose
    # concrete form lives entirely in the matrix part, built by `utils.SOP`
    # from the two evaluated point vectors.
    if (
        op == utils.Op.MUL
        and isinstance(scalar.eval_expression.left_scalar, pt.Point)
        and isinstance(scalar.eval_expression.right_scalar, pt.Point)
    ):
        array = np.zeros(self._num_basis_scalars)
        return sc.EvaluatedScalar(
            vector=array,
            matrix=utils.SOP(
                self.eval_point(scalar.eval_expression.left_scalar).vector,
                self.eval_point(scalar.eval_expression.right_scalar).vector,
            ),
            constant=float(0.0),
        )

    # Composite scalar: evaluate both operands recursively, then combine.
    left_evaled_scalar = self.eval_scalar(scalar.eval_expression.left_scalar)
    right_evaled_scalar = self.eval_scalar(scalar.eval_expression.right_scalar)
    if op == utils.Op.ADD:
        return left_evaled_scalar + right_evaled_scalar
    if op == utils.Op.SUB:
        return left_evaled_scalar - right_evaled_scalar
    if op == utils.Op.MUL:
        return left_evaled_scalar * right_evaled_scalar
    if op == utils.Op.DIV:
        return left_evaled_scalar / right_evaled_scalar

    raise ValueError(f"Encountered unknown {op=} when evaluating the scalar.")
|
151
208
|
|
152
209
|
@functools.cache
def repr_point_by_basis(self, point: pt.Point) -> str:
    """
    Write `point` as a linear combination of the basis :class:`Point`
    objects of this manager's :class:`PEPContext`, using the basis tags.

    The point is first evaluated with `eval_point`. The resulting
    :class:`EvaluatedPoint` is then formatted by
    `repr_evaluated_point_by_basis`.

    Args:
        point (:class:`Point`): The :class:`Point` object to express in
            terms of the basis :class:`Point` objects.

    Returns:
        str: The tag-based linear-combination string for `point`.
    """
    assert isinstance(point, pt.Point)
    return self.repr_evaluated_point_by_basis(self.eval_point(point))
|
230
|
+
|
231
|
+
def repr_evaluated_point_by_basis(self, evaluated_point: pt.EvaluatedPoint) -> str:
|
232
|
+
"""
|
233
|
+
Express the given :class:`EvaluatedPoint` object as the linear
|
234
|
+
combination of the basis :class:`Point` objects of the
|
235
|
+
:class:`PEPContext` associated with this :class:`ExpressionManager`.
|
236
|
+
This linear combination is expressed as a `str` where, to refer to the
|
237
|
+
basis :class:`Point` objects, we use their tags.
|
156
238
|
|
239
|
+
Args:
|
240
|
+
evaluated_point (:class:`EvaluatedPoint`): The
|
241
|
+
:class:`EvaluatedPoint` object which we want to express in
|
242
|
+
terms of the basis :class:`Point` objects.
|
243
|
+
|
244
|
+
Returns:
|
245
|
+
str: The representation of `evaluated_point` in terms of
|
246
|
+
the basis :class:`Point` objects of the :class:`PEPContext`
|
247
|
+
associated with this :class:`ExpressionManager`.
|
248
|
+
"""
|
157
249
|
repr_str = ""
|
158
|
-
for i, v in enumerate(
|
250
|
+
for i, v in enumerate(evaluated_point.vector):
|
159
251
|
ith_tag = self.get_tag_of_basis_point_index(i)
|
160
252
|
repr_str += tag_and_coef_to_str(ith_tag, v)
|
161
253
|
|
@@ -169,28 +261,113 @@ class ExpressionManager:
|
|
169
261
|
return repr_str.strip()
|
170
262
|
|
171
263
|
@functools.cache
def repr_scalar_by_basis(
    self, scalar: sc.Scalar, greedy_square: bool = True
) -> str:
    """Write `scalar` in terms of the basis :class:`Point` and
    :class:`Scalar` objects of this manager's :class:`PEPContext`.

    The scalar is evaluated with `eval_scalar`. The result is then formatted
    by `repr_evaluated_scalar_by_basis`, which refers to basis objects by
    their tags. The formatted string combines linear terms in basis scalars
    with inner products of basis points.

    Args:
        scalar (:class:`Scalar`): The :class:`Scalar` object to express in
            terms of the basis :class:`Point` and :class:`Scalar` objects.
        greedy_square (bool): When true, inner-product terms are grouped
            into squared norms such as :math:`\\|a-b\\|^2` whenever
            possible. Otherwise the expanded form
            :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` is
            used. Defaults to `True`.

    Returns:
        str: The tag-based representation of `scalar`.
    """
    assert isinstance(scalar, sc.Scalar)
    return self.repr_evaluated_scalar_by_basis(
        self.eval_scalar(scalar), greedy_square=greedy_square
    )
|
300
|
+
|
301
|
+
def repr_evaluated_scalar_by_basis(
|
302
|
+
self, evaluated_scalar: sc.EvaluatedScalar, greedy_square: bool = True
|
303
|
+
) -> str:
|
304
|
+
"""Express the given :class:`EvaluatedScalar` object in terms of the
|
305
|
+
basis :class:`Point` and :class:`Scalar` objects of the
|
306
|
+
:class:`PEPContext` associated with this :class:`ExpressionManager`.
|
175
307
|
|
308
|
+
A :class:`Scalar` can be formed by linear combinations of basis
|
309
|
+
:class:`Scalar` objects. A :class:`Scalar` can also be formed through
|
310
|
+
the inner product of two basis :class:`Point` objects. This function
|
311
|
+
returns the representation of this :class:`Scalar` object in terms of
|
312
|
+
the basis :class:`Point` and :class:`Scalar` objects as a `str` where,
|
313
|
+
to refer to the basis :class:`Point` and :class:`Scalar` objects,
|
314
|
+
we use their tags.
|
315
|
+
|
316
|
+
Args:
|
317
|
+
evaluated_scalar (:class:`EvaluatedScalar`): The
|
318
|
+
:class:`EvaluatedScalar` object which we want to express in
|
319
|
+
terms of the basis :class:`Point` and :class:`Scalar` objects.
|
320
|
+
greedy_square (bool): If `greedy_square` is true, the function will
|
321
|
+
try to return :math:`\\|a-b\\|^2` whenever possible. If not,
|
322
|
+
the function will return
|
323
|
+
:math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` instead.
|
324
|
+
`True` by default.
|
325
|
+
|
326
|
+
Returns:
|
327
|
+
str: The representation of `evaluated_scalar` in terms of
|
328
|
+
the basis :class:`Point` and :class:`Scalar` objects of the
|
329
|
+
:class:`PEPContext` associated with this :class:`ExpressionManager`.
|
330
|
+
"""
|
176
331
|
repr_str = ""
|
177
|
-
if not math.isclose(evaluated_scalar.constant, 0):
|
178
|
-
repr_str +=
|
332
|
+
if not math.isclose(evaluated_scalar.constant, 0, abs_tol=1e-5):
|
333
|
+
repr_str += utils.numerical_str(evaluated_scalar.constant)
|
179
334
|
|
180
335
|
for i, v in enumerate(evaluated_scalar.vector):
|
181
336
|
# Note the tag is from scalar basis.
|
182
337
|
ith_tag = self.get_tag_of_basis_scalar_index(i)
|
183
338
|
repr_str += tag_and_coef_to_str(ith_tag, v)
|
184
339
|
|
185
|
-
|
186
|
-
|
340
|
+
if greedy_square:
|
341
|
+
diag_elem = np.diag(evaluated_scalar.matrix).copy()
|
342
|
+
for i in range(evaluated_scalar.matrix.shape[0]):
|
343
|
+
ith_tag = self.get_tag_of_basis_point_index(i)
|
344
|
+
# j starts from i+1 since we want to handle the diag elem at last.
|
345
|
+
for j in range(i + 1, evaluated_scalar.matrix.shape[0]):
|
346
|
+
jth_tag = self.get_tag_of_basis_point_index(j)
|
347
|
+
v = float(evaluated_scalar.matrix[i, j])
|
348
|
+
# We want to minimize the diagonal elements to zero greedily.
|
349
|
+
if diag_elem[i] * v > 0: # same sign with diagonal elem
|
350
|
+
diag_elem[i] -= v
|
351
|
+
diag_elem[j] -= v
|
352
|
+
repr_str += tag_and_coef_to_str(f"|{ith_tag}+{jth_tag}|^2", v)
|
353
|
+
else: # different sign
|
354
|
+
diag_elem[i] += v
|
355
|
+
diag_elem[j] += v
|
356
|
+
repr_str += tag_and_coef_to_str(f"|{ith_tag}-{jth_tag}|^2", -v)
|
357
|
+
# Handle the diagonal elements
|
358
|
+
repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", diag_elem[i])
|
359
|
+
else:
|
360
|
+
for i in range(evaluated_scalar.matrix.shape[0]):
|
187
361
|
ith_tag = self.get_tag_of_basis_point_index(i)
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
362
|
+
for j in range(i, evaluated_scalar.matrix.shape[0]):
|
363
|
+
jth_tag = self.get_tag_of_basis_point_index(j)
|
364
|
+
v = float(evaluated_scalar.matrix[i, j])
|
365
|
+
if i == j:
|
366
|
+
repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
|
367
|
+
else:
|
368
|
+
repr_str += tag_and_coef_to_str(
|
369
|
+
f"<{ith_tag}, {jth_tag}>", 2 * v
|
370
|
+
)
|
194
371
|
|
195
372
|
# Post processing
|
196
373
|
if repr_str == "":
|
@@ -200,3 +377,41 @@ class ExpressionManager:
|
|
200
377
|
if repr_str.startswith("- "):
|
201
378
|
repr_str = "-" + repr_str[2:]
|
202
379
|
return repr_str.strip()
|
380
|
+
|
381
|
+
|
382
|
+
def represent_matrix_by_basis(matrix: np.ndarray, ctx: pc.PEPContext) -> str:
    """Express the given matrix in terms of the basis :class:`Point` objects
    of the given :class:`PEPContext`.

    The concrete representation of the inner product of two abstract
    basis :class:`Point` objects is a matrix (the outer product of the
    basis vectors corresponding to the concrete representations of the abstract
    basis :class:`Point` objects). The given matrix can then be expressed
    as a linear combination of the inner products of abstract basis
    :class:`Point` objects. This is provided as a `str` where, to refer to
    the basis :class:`Point` objects, we use their tags.

    Args:
        matrix (np.ndarray): The matrix which we want to express in terms of
            the basis :class:`Point` objects of the given :class:`PEPContext`.
            Any array-like accepted by `np.asarray` (e.g. nested lists) also
            works. It is converted to a float array.
        ctx (:class:`PEPContext`): The :class:`PEPContext` whose basis
            :class:`Point` objects we consider.

    Returns:
        str: The representation of `matrix` in terms of the basis
        :class:`Point` objects of `ctx`.

    Raises:
        ValueError: If `matrix` is not square with one row/column per basis
            point of `ctx`, or if it is not symmetric.
    """
    em = ExpressionManager(ctx)
    # Accept array-likes. A float dtype keeps the in-place diagonal updates
    # of the greedy-square formatter from truncating on integer input.
    matrix = np.asarray(matrix, dtype=float)
    num_basis_points = len(em._basis_points)
    matrix_shape = (num_basis_points, num_basis_points)
    if matrix.shape != matrix_shape:
        raise ValueError(
            f"The valid matrix for given context should have shape {matrix_shape}"
        )
    if not np.allclose(matrix, matrix.T):
        raise ValueError("Input matrix must be symmetric.")

    # Wrap the matrix as a scalar with no basis-scalar or constant part so
    # the scalar formatter can render it.
    return em.repr_evaluated_scalar_by_basis(
        sc.EvaluatedScalar(
            vector=np.zeros(len(em._basis_scalars)), matrix=matrix, constant=0.0
        )
    )
|
@@ -97,7 +97,28 @@ def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
|
|
97
97
|
|
98
98
|
s = f(x) + x * f.gradient(x)
|
99
99
|
em = exm.ExpressionManager(pep_context)
|
100
|
-
assert
|
100
|
+
assert (
|
101
|
+
em.repr_scalar_by_basis(s, greedy_square=False) == "f(x) + <x, gradient_f(x)>"
|
102
|
+
)
|
103
|
+
assert (
|
104
|
+
em.repr_scalar_by_basis(s)
|
105
|
+
== "f(x) - 0.5*|x-gradient_f(x)|^2 + 0.5*|x|^2 + 0.5*|gradient_f(x)|^2"
|
106
|
+
)
|
107
|
+
|
108
|
+
|
109
|
+
def test_repr_scalar_by_basis2(pep_context: pc.PEPContext) -> None:
    """Check the expanded and square-completed renderings of f(x) - <x, grad f(x)>."""
    x = pt.Point(is_basis=True, tags=["x"])
    f = fc.Function(is_basis=True, tags=["f"])

    scalar = f(x) - x * f.gradient(x)
    manager = exm.ExpressionManager(pep_context)

    expanded = "f(x) - <x, gradient_f(x)>"
    completed = "f(x) + 0.5*|x-gradient_f(x)|^2 - 0.5*|x|^2 - 0.5*|gradient_f(x)|^2"
    assert manager.repr_scalar_by_basis(scalar, greedy_square=False) == expanded
    assert manager.repr_scalar_by_basis(scalar) == completed
|
101
122
|
|
102
123
|
|
103
124
|
def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
|
@@ -110,7 +131,20 @@ def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
|
|
110
131
|
interp_scalar = f.interpolate_ineq("x_i", "x_j")
|
111
132
|
em = exm.ExpressionManager(pep_context)
|
112
133
|
expected_repr = "-f(x_i) + f(x_j) + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)> + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)> + 0.5*|gradient_f(x_j)|^2"
|
113
|
-
assert em.repr_scalar_by_basis(interp_scalar) == expected_repr
|
134
|
+
assert em.repr_scalar_by_basis(interp_scalar, greedy_square=False) == expected_repr
|
135
|
+
expected_square_repr = "-f(x_i) + f(x_j) - 0.5*|x_i-gradient_f(x_j)|^2 + 0.5*|x_i|^2 + 0.5*|x_j-gradient_f(x_j)|^2 - 0.5*|x_j|^2 + 0.5*|gradient_f(x_i)-gradient_f(x_j)|^2"
|
136
|
+
assert em.repr_scalar_by_basis(interp_scalar) == expected_square_repr
|
114
137
|
|
115
138
|
|
116
139
|
# TODO add more tests about repr_scalar_by_basis
|
140
|
+
|
141
|
+
|
142
|
+
def test_represent_matrix_by_basis(pep_context: pc.PEPContext) -> None:
    """A symmetric 3x3 matrix is rendered using greedy squared-norm terms."""
    # Register three basis points, in order, in the current context.
    for tag in ("x_1", "x_2", "x_3"):
        pt.Point(is_basis=True, tags=[tag])

    symmetric = np.array([[0.5, 0.5, 0], [0.5, 2, 0], [0, 0, 3]])
    expected = "0.5*|x_1+x_2|^2 + 1.5*|x_2|^2 + 3*|x_3|^2"
    assert exm.represent_matrix_by_basis(symmetric, pep_context) == expected
|