pepflow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/point_test.py ADDED
@@ -0,0 +1,329 @@
1
+ import time
2
+
3
+ import numpy as np
4
+
5
+ from pepflow import expression_manager as exm
6
+ from pepflow import function as fc
7
+ from pepflow import pep as pep
8
+ from pepflow import point, scalar, utils
9
+
10
+
11
def test_point_hash_different():
    """Two independently created basis points must get distinct uids."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        first, second = (
            point.Point(is_basis=True, eval_expression=None) for _ in range(2)
        )
        assert first.uid != second.uid
17
+
18
+
19
def test_scalar_hash_different():
    """Two independently created basis scalars must get distinct uids."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        first, second = (
            scalar.Scalar(is_basis=True, eval_expression=None) for _ in range(2)
        )
        assert first.uid != second.uid
25
+
26
+
27
def test_point_tag():
    """A tag added through ``add_tag`` shows up in the point's ``tags``."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        tagged = point.Point(is_basis=True, eval_expression=None)
        tagged.add_tag(tag="my_tag")
        assert tagged.tags == ["my_tag"]
33
+
34
+
35
def test_scalar_tag():
    """A tag added through ``add_tag`` shows up in the scalar's ``tags``."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        tagged = scalar.Scalar(is_basis=True, eval_expression=None)
        tagged.add_tag(tag="my_tag")
        assert tagged.tags == ["my_tag"]
41
+
42
+
43
def test_point_in_a_list():
    """List membership distinguishes points by their own identity/uid."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        a, b, c = [point.Point(is_basis=True, eval_expression=None) for _ in range(3)]
        members = [a, b]
        assert a in members
        assert c not in members
51
+
52
+
53
def test_scalar_in_a_list():
    """List membership distinguishes scalars by their own identity/uid."""
    builder = pep.PEPBuilder()
    with builder.make_context("test"):
        a, b, c = [
            scalar.Scalar(is_basis=True, eval_expression=None) for _ in range(3)
        ]
        members = [a, b]
        assert a in members
        assert c not in members
61
+
62
+
63
def test_expression_manager_on_basis_point():
    """Basis points evaluate to unit vectors sized by the current basis count."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        first = point.Point(is_basis=True, eval_expression=None, tags=["p1"])
        second = point.Point(is_basis=True, eval_expression=None, tags=["p2"])
        manager = exm.ExpressionManager(ctx)
        for pt, unit in ((first, [1, 0]), (second, [0, 1])):
            np.testing.assert_allclose(manager.eval_point(pt).vector, np.array(unit))

        # Add a third basis point, rebuild the manager, and expect wider vectors.
        third = point.Point(is_basis=True, eval_expression=None, tags=["p3"])  # noqa: F841
        manager = exm.ExpressionManager(ctx)
        for pt, unit in ((first, [1, 0, 0]), (second, [0, 1, 0])):
            np.testing.assert_allclose(manager.eval_point(pt).vector, np.array(unit))
78
+
79
+
80
def test_expression_manager_on_basis_scalar():
    """Basis scalars evaluate to unit vectors sized by the current basis count."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        first = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s1"])
        second = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s2"])
        manager = exm.ExpressionManager(ctx)
        for sc_, unit in ((first, [1, 0]), (second, [0, 1])):
            np.testing.assert_allclose(manager.eval_scalar(sc_).vector, np.array(unit))

        # Add a third basis scalar, rebuild the manager, and expect wider vectors.
        third = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s3"])  # noqa: F841
        manager = exm.ExpressionManager(ctx)
        for sc_, unit in ((first, [1, 0, 0]), (second, [0, 1, 0])):
            np.testing.assert_allclose(manager.eval_scalar(sc_).vector, np.array(unit))
95
+
96
+
97
def test_expression_manager_eval_point():
    """Linear combinations of basis points evaluate to their coefficient vectors."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        a = point.Point(is_basis=True)
        b = point.Point(is_basis=True)
        combo = 2 * a + b / 4
        shifted = combo + a

        manager = exm.ExpressionManager(ctx)
        expectations = ((combo, [2, 0.25]), (shifted, [3, 0.25]))
        for pt, coeffs in expectations:
            np.testing.assert_allclose(manager.eval_point(pt).vector, np.array(coeffs))
108
+
109
+
110
def test_expression_manager_eval_scalar():
    """Check vector/constant parts of scalar sums and Gram parts of point products."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        s1 = scalar.Scalar(is_basis=True)
        s2 = scalar.Scalar(is_basis=True)
        s3 = 2 * s1 + s2 / 4 + 5
        s4 = s3 + s1
        s5 = s4 + 5

        p1 = point.Point(is_basis=True)
        p2 = point.Point(is_basis=True)
        s6 = p1 * p2

        p3 = point.Point(is_basis=True)
        p4 = point.Point(is_basis=True)
        s7 = 5 * p3 * p4

        s8 = s6 + s7

        manager = exm.ExpressionManager(ctx)

        for sc_, coeffs in ((s3, [2, 0.25]), (s4, [3, 0.25]), (s5, [3, 0.25])):
            np.testing.assert_allclose(manager.eval_scalar(sc_).vector, np.array(coeffs))
        np.testing.assert_allclose(manager.eval_scalar(s3).constant, 5)
        np.testing.assert_allclose(manager.eval_scalar(s5).constant, 10)

        # The four basis points map onto the rows of the 4x4 identity.
        identity = np.eye(4)
        for idx, pt in enumerate((p1, p2, p3, p4)):
            np.testing.assert_allclose(manager.eval_point(pt).vector, identity[idx])

        # A product of two distinct basis points splits its coefficient
        # symmetrically across the two off-diagonal entries.
        pair_01 = np.zeros((4, 4))
        pair_01[0, 1] = pair_01[1, 0] = 0.5
        pair_23 = np.zeros((4, 4))
        pair_23[2, 3] = pair_23[3, 2] = 2.5

        np.testing.assert_allclose(manager.eval_scalar(s6).matrix, pair_01)
        np.testing.assert_allclose(manager.eval_scalar(s7).matrix, pair_23)
        np.testing.assert_allclose(manager.eval_scalar(s8).matrix, pair_01 + pair_23)
176
+
177
+
178
def test_constraint():
    """Comparison helpers subtract the right-hand side and tag a comparator."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        a = scalar.Scalar(is_basis=True)
        b = scalar.Scalar(is_basis=True)
        expr = 2 * a + b / 4 + 5

        c1 = expr.le(5, name="c1")
        c2 = expr.lt(5, name="c2")
        c3 = expr.ge(5, name="c3")
        c4 = expr.gt(5, name="c4")
        c5 = expr.eq(5, name="c5")

        manager = exm.ExpressionManager(ctx)

        # Strict and non-strict variants share one comparator each.
        cases = (
            (c1, utils.Comparator.LT),
            (c2, utils.Comparator.LT),
            (c3, utils.Comparator.GT),
            (c4, utils.Comparator.GT),
            (c5, utils.Comparator.EQ),
        )
        for con, comparator in cases:
            evaluated = manager.eval_scalar(con.scalar)
            np.testing.assert_allclose(evaluated.vector, np.array([2, 0.25]))
            np.testing.assert_allclose(evaluated.constant, 0)
            assert con.comparator == comparator
212
+
213
+
214
def test_expression_manager_eval_point_large_scale():
    """Evaluating every point of a large expression graph must stay fast.

    Builds 100 basis points, accumulates pairwise combinations of them into a
    long chain of derived points, then times one full evaluation pass through
    ``ExpressionManager`` over all points registered in the context.
    """
    pep_builder = pep.PEPBuilder()
    with pep_builder.make_context("test") as ctx:
        all_basis = [point.Point(is_basis=True) for _ in range(100)]
        p = all_basis[0]
        for i, p_i in enumerate(all_basis):
            for p_j in all_basis[i + 1 :]:
                p += p_i * 2 + p_j
        pm = exm.ExpressionManager(ctx)

        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump or be too coarse for a 0.1 s budget.
        start = time.perf_counter()
        for pp in ctx.points:
            pm.eval_point(pp)
        elapsed = time.perf_counter() - start

        assert elapsed < 0.1
228
+
229
+
230
def test_function_generate_triplet():
    """Check gradient/function-value coordinates produced by generate_triplet.

    With ``reuse_gradient=True`` two calls at the same point give identical
    gradient coordinates. With ``reuse_gradient=False`` each call yields new
    gradient coordinates, while the function-value coordinates stay the same.
    """
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        f = fc.Function(is_basis=True, reuse_gradient=True)
        g = fc.Function(is_basis=True, reuse_gradient=True)
        h = 5 * f + 5 * g

        f1 = fc.Function(is_basis=True, reuse_gradient=False)
        g1 = fc.Function(is_basis=True, reuse_gradient=False)
        h1 = 5 * f1 + 5 * g1

        x = point.Point(is_basis=True)
        triplets = [
            h.generate_triplet(x),
            h.generate_triplet(x),
            h1.generate_triplet(x),
            h1.generate_triplet(x),
        ]

        manager = exm.ExpressionManager(ctx)

        np.testing.assert_allclose(
            manager.eval_point(x).vector, np.array([1, 0, 0, 0, 0, 0, 0])
        )

        expected = [
            ([0, 5, 5, 0, 0, 0, 0], [5, 5, 0, 0]),
            ([0, 5, 5, 0, 0, 0, 0], [5, 5, 0, 0]),
            ([0, 0, 0, 5, 5, 0, 0], [0, 0, 5, 5]),
            ([0, 0, 0, 0, 0, 5, 5], [0, 0, 5, 5]),
        ]
        for (_, value, grad), (grad_coeffs, value_coeffs) in zip(triplets, expected):
            np.testing.assert_allclose(
                manager.eval_point(grad).vector, np.array(grad_coeffs)
            )
            np.testing.assert_allclose(
                manager.eval_scalar(value).vector, np.array(value_coeffs)
            )
281
+
282
+
283
def test_function_add_stationary_point():
    """add_stationary_point yields a basis point plus an equality constraint."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        func = fc.Function(is_basis=True, reuse_gradient=True)
        x_opt = func.add_stationary_point()

        manager = exm.ExpressionManager(ctx)

        np.testing.assert_allclose(manager.eval_point(x_opt).vector, np.array([1, 0]))

        first = func.constraints[0]
        evaluated = manager.eval_scalar(first.scalar)
        np.testing.assert_allclose(evaluated.matrix, [[0, 0], [0, 1]])
        np.testing.assert_allclose(evaluated.vector, [0])
        np.testing.assert_allclose(evaluated.constant, 0)
        assert first.comparator == utils.Comparator.EQ
299
+
300
+
301
def test_smooth_interpolability_constraints():
    """Check the second interpolation constraint of SmoothConvexFunction(L=1)."""
    builder = pep.PEPBuilder()
    with builder.make_context("test") as ctx:
        func = fc.SmoothConvexFunction(L=1)
        func.add_stationary_point()

        start = point.Point(is_basis=True)
        _, _, _ = func.generate_triplet(start)

        interpolation = func.get_interpolation_constraints()

        manager = exm.ExpressionManager(ctx)
        evaluated = manager.eval_scalar(interpolation[1].scalar)

        np.testing.assert_allclose(evaluated.vector, [1, -1])
        np.testing.assert_allclose(
            evaluated.matrix,
            [
                [0.0, -0.5, 0.0, 0.0],
                [-0.5, 0.5, 0.5, -0.5],
                [0.0, 0.5, 0.0, 0.0],
                [0.0, -0.5, 0.0, 0.5],
            ],
        )
        np.testing.assert_allclose(evaluated.constant, 0)
pepflow/scalar.py ADDED
@@ -0,0 +1,219 @@
1
+ from __future__ import annotations
2
+
3
+ import uuid
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ import attrs
7
+ import numpy as np
8
+
9
+ from pepflow import constraint as ctr
10
+ from pepflow import pep_context as pc
11
+ from pepflow import utils
12
+
13
+ if TYPE_CHECKING:
14
+ from pepflow.point import Point
15
+
16
+
17
def is_numerical_or_scalar(val: Any) -> bool:
    """Tell whether ``val`` is a plain number or a symbolic ``Scalar``.

    Mirrors ``a or b``: the numeric check's result is returned when truthy,
    otherwise the ``isinstance`` result.
    """
    numeric = utils.is_numerical(val)
    if numeric:
        return numeric
    return isinstance(val, Scalar)
19
+
20
+
21
def is_numerical_or_evaluatedscalar(val: Any) -> bool:
    """Tell whether ``val`` is a plain number or an ``EvaluatedScalar``.

    Mirrors ``a or b``: the numeric check's result is returned when truthy,
    otherwise the ``isinstance`` result.
    """
    numeric = utils.is_numerical(val)
    if numeric:
        return numeric
    return isinstance(val, EvaluatedScalar)
23
+
24
+
25
@attrs.frozen
class EvalExpressionScalar:
    """Immutable record of one binary operation that yields a derived scalar.

    ``Scalar``'s arithmetic operators in this module build these with the
    operands exactly as given (a ``Scalar`` or a plain number). The ``Point``
    option in the annotations presumably serves point-by-point products built
    elsewhere; confirm in ``pepflow.point``.
    """

    # Operation to apply (Scalar uses ADD, SUB, MUL and DIV).
    op: utils.Op = attrs.field()
    # Left-hand operand.
    left_scalar: Point | Scalar | float = attrs.field()
    # Right-hand operand.
    right_scalar: Point | Scalar | float = attrs.field()
30
+
31
+
32
@attrs.frozen
class EvaluatedScalar:
    """Numeric coordinates of a scalar expression.

    ``pepflow.solver`` turns an instance into
    ``vec_var @ vector + trace(matrix_var @ matrix) + constant``, where
    ``vec_var`` holds the basis-scalar values and ``matrix_var`` is the Gram
    matrix of the basis points. Instances support ``+``/``-`` with numbers or
    other ``EvaluatedScalar`` values, and ``*``/``/`` by numbers. Every
    operation returns a new instance.
    """

    # Coefficients of the basis scalars (np.ndarray: np.array is a factory
    # function, not a type, so it is not a valid annotation).
    vector: np.ndarray
    # Coefficients paired with the Gram matrix of the basis points.
    matrix: np.ndarray
    # Constant offset.
    constant: float

    def __add__(self, other):
        """Return ``self + other``; a plain number only shifts ``constant``."""
        assert is_numerical_or_evaluatedscalar(other)
        if utils.is_numerical(other):
            return EvaluatedScalar(
                vector=self.vector, matrix=self.matrix, constant=self.constant + other
            )
        else:
            return EvaluatedScalar(
                vector=self.vector + other.vector,
                matrix=self.matrix + other.matrix,
                constant=self.constant + other.constant,
            )

    def __radd__(self, other):
        """Return ``other + self`` (lets ``sum()`` start from 0)."""
        assert is_numerical_or_evaluatedscalar(other)
        if utils.is_numerical(other):
            return EvaluatedScalar(
                vector=self.vector, matrix=self.matrix, constant=other + self.constant
            )
        else:
            return EvaluatedScalar(
                vector=other.vector + self.vector,
                matrix=other.matrix + self.matrix,
                constant=other.constant + self.constant,
            )

    def __sub__(self, other):
        """Return ``self - other``; a plain number only shifts ``constant``."""
        assert is_numerical_or_evaluatedscalar(other)
        if utils.is_numerical(other):
            return EvaluatedScalar(
                vector=self.vector, matrix=self.matrix, constant=self.constant - other
            )
        else:
            return EvaluatedScalar(
                vector=self.vector - other.vector,
                matrix=self.matrix - other.matrix,
                constant=self.constant - other.constant,
            )

    def __rsub__(self, other):
        """Return ``other - self``; coefficients are negated for a number."""
        assert is_numerical_or_evaluatedscalar(other)
        if utils.is_numerical(other):
            return EvaluatedScalar(
                vector=-self.vector, matrix=-self.matrix, constant=other - self.constant
            )
        else:
            return EvaluatedScalar(
                vector=other.vector - self.vector,
                matrix=other.matrix - self.matrix,
                constant=other.constant - self.constant,
            )

    def __mul__(self, other):
        """Scale every component by the number ``other``."""
        assert utils.is_numerical(other)
        return EvaluatedScalar(
            vector=self.vector * other,
            matrix=self.matrix * other,
            constant=self.constant * other,
        )

    def __rmul__(self, other):
        """Scale every component by the number ``other`` (``other * self``)."""
        assert utils.is_numerical(other)
        return EvaluatedScalar(
            vector=other * self.vector,
            matrix=other * self.matrix,
            constant=other * self.constant,
        )

    def __neg__(self):
        """Return ``-self``."""
        return self.__rmul__(other=-1)

    def __truediv__(self, other):
        """Divide every component by the number ``other``."""
        assert utils.is_numerical(other)
        return EvaluatedScalar(
            vector=self.vector / other,
            matrix=self.matrix / other,
            constant=self.constant / other,
        )
116
+
117
+
118
@attrs.frozen
class Scalar:
    """A symbolic scalar registered in the active PEP context.

    Basis scalars (``is_basis=True``) carry no ``eval_expression``; every
    other scalar is produced by arithmetic and records the producing
    operation. Construction registers the instance with the current context
    and raises ``RuntimeError`` when there is none. Equality and hashing rely
    on ``uid`` alone.
    """

    # Whether this scalar is a basis element for evaluations.
    is_basis: bool

    # Operation that produced this scalar; None for basis scalars.
    eval_expression: EvalExpressionScalar | None = None

    # Human-readable labels attached to the scalar.
    tags: list[str] = attrs.field(factory=list)

    # Unique identifier generated on construction (not an __init__ argument).
    uid: uuid.UUID = attrs.field(factory=uuid.uuid4, init=False)

    def __attrs_post_init__(self):
        # Basis scalars must not carry an expression; derived ones must.
        assert (self.eval_expression is None) == bool(self.is_basis)

        context = pc.get_current_context()
        if context is None:
            raise RuntimeError("Did you forget to create a context?")
        context.add_scalar(self)

    def add_tag(self, tag: str):
        """Append ``tag`` to this scalar's ``tags`` list in place."""
        self.tags.append(tag)

    def __add__(self, other):
        assert is_numerical_or_scalar(other)
        expression = EvalExpressionScalar(utils.Op.ADD, self, other)
        return Scalar(is_basis=False, eval_expression=expression)

    def __radd__(self, other):
        assert is_numerical_or_scalar(other)
        expression = EvalExpressionScalar(utils.Op.ADD, other, self)
        return Scalar(is_basis=False, eval_expression=expression)

    def __sub__(self, other):
        assert is_numerical_or_scalar(other)
        expression = EvalExpressionScalar(utils.Op.SUB, self, other)
        return Scalar(is_basis=False, eval_expression=expression)

    def __rsub__(self, other):
        assert is_numerical_or_scalar(other)
        expression = EvalExpressionScalar(utils.Op.SUB, other, self)
        return Scalar(is_basis=False, eval_expression=expression)

    def __mul__(self, other):
        # Only scaling by a number is supported.
        assert utils.is_numerical(other)
        expression = EvalExpressionScalar(utils.Op.MUL, self, other)
        return Scalar(is_basis=False, eval_expression=expression)

    def __rmul__(self, other):
        assert utils.is_numerical(other)
        expression = EvalExpressionScalar(utils.Op.MUL, other, self)
        return Scalar(is_basis=False, eval_expression=expression)

    def __neg__(self):
        return self.__rmul__(-1)

    def __truediv__(self, other):
        assert utils.is_numerical(other)
        expression = EvalExpressionScalar(utils.Op.DIV, self, other)
        return Scalar(is_basis=False, eval_expression=expression)

    def __hash__(self):
        return hash(self.uid)

    def __eq__(self, other):
        if isinstance(other, Scalar):
            return self.uid == other.uid
        return NotImplemented

    # The constraint helpers below all encode ``self - other <cmp> 0``.
    # Strict and non-strict inequalities share one comparator (LT or GT);
    # the solver maps them to ``<= 0`` and ``>= 0`` respectively.

    def le(self, other, name: str) -> ctr.Constraint:
        difference = self - other
        return ctr.Constraint(difference, comparator=utils.Comparator.LT, name=name)

    def lt(self, other, name: str) -> ctr.Constraint:
        difference = self - other
        return ctr.Constraint(difference, comparator=utils.Comparator.LT, name=name)

    def ge(self, other, name: str) -> ctr.Constraint:
        difference = self - other
        return ctr.Constraint(difference, comparator=utils.Comparator.GT, name=name)

    def gt(self, other, name: str) -> ctr.Constraint:
        difference = self - other
        return ctr.Constraint(difference, comparator=utils.Comparator.GT, name=name)

    def eq(self, other, name: str) -> ctr.Constraint:
        difference = self - other
        return ctr.Constraint(difference, comparator=utils.Comparator.EQ, name=name)
pepflow/solver.py ADDED
@@ -0,0 +1,98 @@
1
+ import warnings
2
+
3
+ import cvxpy
4
+
5
+ from pepflow import constraint as ctr
6
+ from pepflow import expression_manager as exm
7
+ from pepflow import pep_context as pc
8
+ from pepflow import scalar as sc
9
+ from pepflow import utils
10
+
11
+
12
def evaled_scalar_to_cvx_express(
    eval_scalar: sc.EvaluatedScalar, vec_var: cvxpy.Variable, matrix_var: cvxpy.Variable
) -> cvxpy.Expression:
    """Translate an evaluated scalar into a cvxpy expression.

    The result is ``vec_var @ vector + trace(matrix_var @ matrix) + constant``,
    which is affine in both ``vec_var`` and ``matrix_var``.
    """
    linear_part = vec_var @ eval_scalar.vector
    gram_part = cvxpy.trace(matrix_var @ eval_scalar.matrix)
    return linear_part + gram_part + eval_scalar.constant
20
+
21
+
22
class DualVariableManager:
    """Registry of named cvxpy constraints that exposes their dual values."""

    def __init__(self, named_constraints: list[tuple[str, cvxpy.Constraint]]):
        # Maps constraint name -> cvxpy constraint, in insertion order.
        self.named_constraints = {}
        for name, c in named_constraints:
            # Go through add_constraint so duplicate names are rejected.
            self.add_constraint(name, c)

    def cvx_constraints(self) -> list[cvxpy.Constraint]:
        """Return all registered constraints in insertion order."""
        return list(self.named_constraints.values())

    def clear(self) -> None:
        """Remove every registered constraint."""
        self.named_constraints.clear()

    def add_constraint(self, name: str, constraint: cvxpy.Constraint) -> None:
        """Register ``constraint`` under ``name``.

        Raises:
            KeyError: if a constraint with the same name is already registered.
        """
        if name in self.named_constraints:
            raise KeyError(f"There is already a constraint named {name}")
        self.named_constraints[name] = constraint

    def dual_value(self, name: str) -> float | None:
        """Return the cvxpy dual value of the constraint called ``name``.

        For matrix constraints such as the PSD constraint the value is
        presumably an array rather than a float (cvxpy behaviour; confirm).
        Returns None and emits a ``UserWarning`` when the constraint has no
        dual value yet, e.g. because the problem has not been solved.

        Raises:
            KeyError: if no constraint with that name is registered.
        """
        try:
            constraint = self.named_constraints[name]
        except KeyError:
            raise KeyError(f"Cannot find the constraint named {name}") from None
        dual_value = constraint.dual_value
        if dual_value is None:
            # stacklevel=2 attributes the warning to the caller's line
            # instead of to this module.
            warnings.warn("Did you forget to solve the problem first?", stacklevel=2)
            return None
        return dual_value
47
+
48
+
49
+ class CVXSolver:
50
+ def __init__(
51
+ self,
52
+ perf_metric: sc.Scalar,
53
+ constraints: list[ctr.Constraint],
54
+ context: pc.PEPContext,
55
+ ):
56
+ self.perf_metric = perf_metric
57
+ self.constraints = constraints
58
+ self.dual_var_manager = DualVariableManager([])
59
+ self.context = context
60
+
61
+ def build_problem(self) -> cvxpy.Problem:
62
+ em = exm.ExpressionManager(self.context)
63
+ f_var = cvxpy.Variable(em._num_basis_scalars)
64
+ g_var = cvxpy.Variable(
65
+ (em._num_basis_points, em._num_basis_points), symmetric=True
66
+ )
67
+
68
+ # Evaluate all poiints and scalars in advance to store it in cache.
69
+ for point in self.context.points:
70
+ em.eval_point(point)
71
+ for scalar in self.context.scalars:
72
+ em.eval_scalar(scalar)
73
+
74
+ self.dual_var_manager.clear()
75
+ self.dual_var_manager.add_constraint("PSD of Grammian Matrix", g_var >> 0)
76
+ for c in self.constraints:
77
+ exp = evaled_scalar_to_cvx_express(em.eval_scalar(c.scalar), f_var, g_var)
78
+ if c.comparator == utils.Comparator.GT:
79
+ self.dual_var_manager.add_constraint(c.name, exp >= 0)
80
+ elif c.comparator == utils.Comparator.LT:
81
+ self.dual_var_manager.add_constraint(c.name, exp <= 0)
82
+ elif c.comparator == utils.Comparator.EQ:
83
+ self.dual_var_manager.add_constraint(c.name, exp == 0)
84
+ else:
85
+ raise ValueError(f"Unknown comparator {c.comparator}")
86
+
87
+ obj = evaled_scalar_to_cvx_express(
88
+ em.eval_scalar(self.perf_metric), f_var, g_var
89
+ )
90
+
91
+ return cvxpy.Problem(
92
+ cvxpy.Maximize(obj), self.dual_var_manager.cvx_constraints()
93
+ )
94
+
95
+ def solve(self, **kwargs):
96
+ problem = self.build_problem()
97
+ result = problem.solve(**kwargs)
98
+ return result
pepflow/solver_test.py ADDED
@@ -0,0 +1,51 @@
1
+ import numpy as np
2
+
3
+ from pepflow import pep as pep
4
+ from pepflow import point as pp
5
+ from pepflow import solver as ps
6
+
7
+
8
def test_cvx_solver_case1():
    """Maximize -(1 + x^2) s.t. x^2 >= 1; the optimum is -2 with dual 1."""
    pep_builder = pep.PEPBuilder()
    with pep_builder.make_context("test"):
        p1 = pp.Point(is_basis=True)
        s1 = pp.Scalar(is_basis=True)
        s2 = -(1 + p1 * p1)
        constraints = [(p1 * p1).gt(1, name="x^2 >= 1"), s1.gt(0, name="s1 > 0")]

        solver = ps.CVXSolver(
            perf_metric=s2,
            constraints=constraints,
            context=pep_builder.get_context("test"),
        )

        # It is a simple `min_x 1 + x^2; s.t. x^2 >= 1` problem.
        problem = solver.build_problem()
        result = problem.solve()
        assert abs(-result - 2) < 1e-6

        assert np.isclose(solver.dual_var_manager.dual_value("x^2 >= 1"), 1)
        # Solver duals are floating point; compare with a tolerance, not ==.
        assert np.isclose(solver.dual_var_manager.dual_value("s1 > 0"), 0)
29
+
30
+
31
def test_cvx_solver_case2():
    """Maximize -(x - 1)(x - 2) s.t. x^2 <= 1; the optimum is 0 with dual 0."""
    pep_builder = pep.PEPBuilder()
    with pep_builder.make_context("test"):
        p1 = pp.Point(is_basis=True)
        s1 = pp.Scalar(is_basis=True)
        s2 = -(p1 - 1) * (p1 - 2)
        constraints = [(p1 * p1).lt(1, name="x^2 <= 1"), s1.gt(0, name="s1 > 0")]

        solver = ps.CVXSolver(
            perf_metric=s2,
            constraints=constraints,
            context=pep_builder.get_context("test"),
        )

        # It is a simple `min_x (x-1)(x-2); s.t. x^2 <= 1` problem.
        problem = solver.build_problem()
        result = problem.solve()
        assert abs(-result) < 1e-6

        assert np.isclose(solver.dual_var_manager.dual_value("x^2 <= 1"), 0)
        # Solver duals are floating point; compare with a tolerance, not ==.
        assert np.isclose(solver.dual_var_manager.dual_value("s1 > 0"), 0)