pepflow 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +6 -1
- pepflow/constraint.py +58 -1
- pepflow/constraint_test.py +71 -0
- pepflow/e2e_test.py +83 -4
- pepflow/expression_manager.py +329 -44
- pepflow/expression_manager_test.py +150 -0
- pepflow/function.py +294 -52
- pepflow/function_test.py +180 -114
- pepflow/interactive_constraint.py +165 -75
- pepflow/parameter.py +187 -0
- pepflow/parameter_test.py +128 -0
- pepflow/pep.py +263 -16
- pepflow/pep_context.py +122 -6
- pepflow/pep_context_test.py +25 -0
- pepflow/pep_test.py +8 -0
- pepflow/point.py +155 -49
- pepflow/point_test.py +40 -188
- pepflow/scalar.py +260 -47
- pepflow/scalar_test.py +102 -130
- pepflow/solver.py +170 -3
- pepflow/solver_test.py +50 -2
- pepflow/utils.py +39 -7
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/METADATA +24 -5
- pepflow-0.1.5.dist-info/RECORD +28 -0
- pepflow-0.1.4.dist-info/RECORD +0 -24
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/top_level.txt +0 -0
pepflow/expression_manager.py
CHANGED
@@ -18,22 +18,53 @@
|
|
18
18
|
# under the License.
|
19
19
|
|
20
20
|
import functools
|
21
|
+
import math
|
21
22
|
|
22
23
|
import numpy as np
|
23
24
|
|
25
|
+
from pepflow import parameter as pm
|
24
26
|
from pepflow import pep_context as pc
|
25
27
|
from pepflow import point as pt
|
26
28
|
from pepflow import scalar as sc
|
27
29
|
from pepflow import utils
|
28
30
|
|
29
31
|
|
32
|
+
def tag_and_coef_to_str(tag: str, v: float) -> str:
    """Format a single ``coefficient * tag`` term of a signed linear combination.

    The returned fragment always starts with its sign (``"+ "`` or ``"- "``)
    and ends with one trailing space, so fragments can be concatenated
    directly. A coefficient whose magnitude is close to one is left out
    (only the tag is printed). A coefficient within ``1e-5`` of zero yields
    an empty string.

    Args:
        tag (str): Label of the term, e.g. the tag of a basis point.
        v (float): Signed coefficient multiplying ``tag``.

    Returns:
        str: The formatted fragment, or ``""`` when ``v`` is (nearly) zero.
    """
    magnitude = abs(v)
    magnitude_str = utils.numerical_str(magnitude)
    sign = "+" if v >= 0 else "-"

    if math.isclose(magnitude, 1):
        return f"{sign} {tag} "
    if math.isclose(v, 0, abs_tol=1e-5):
        return ""
    return f"{sign} {magnitude_str}*{tag} "
|
41
|
+
|
42
|
+
|
30
43
|
class ExpressionManager:
|
31
|
-
|
44
|
+
"""
|
45
|
+
A class that handles the concrete representations of abstract
|
46
|
+
:class:`Point` and :class:`Scalar` objects managed by a particular
|
47
|
+
:class:`PEPContext` object.
|
48
|
+
|
49
|
+
Attributes:
|
50
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object which
|
51
|
+
manages the abstract :class:`Point` and :class:`Scalar` objects
|
52
|
+
of interest.
|
53
|
+
resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
|
54
|
+
maps the name of parameters to the numerical values.
|
55
|
+
"""
|
56
|
+
|
57
|
+
def __init__(
|
58
|
+
self,
|
59
|
+
pep_context: pc.PEPContext,
|
60
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
|
61
|
+
):
|
32
62
|
self.context = pep_context
|
33
63
|
self._basis_points = []
|
34
64
|
self._basis_point_uid_to_index = {}
|
35
65
|
self._basis_scalars = []
|
36
66
|
self._basis_scalar_uid_to_index = {}
|
67
|
+
self.resolve_parameters = resolve_parameters or {}
|
37
68
|
for point in self.context.points:
|
38
69
|
if point.is_basis:
|
39
70
|
self._basis_points.append(point)
|
@@ -48,85 +79,339 @@ class ExpressionManager:
|
|
48
79
|
self._num_basis_points = len(self._basis_points)
|
49
80
|
self._num_basis_scalars = len(self._basis_scalars)
|
50
81
|
|
51
|
-
def get_index_of_basis_point(self, point: pt.Point) -> int:
    """Return the position of a basis point in this manager's basis ordering.

    Args:
        point (:class:`Point`): A basis :class:`Point` of the managed context.

    Returns:
        int: The index of `point` among the basis points.

    Raises:
        KeyError: If the uid of `point` is not registered as a basis point.
    """
    uid_to_index = self._basis_point_uid_to_index
    return uid_to_index[point.uid]
|
53
84
|
|
54
|
-
def get_index_of_basis_scalar(self, scalar: sc.Scalar) -> int:
    """Return the position of a basis scalar in this manager's basis ordering.

    Args:
        scalar (:class:`Scalar`): A basis :class:`Scalar` of the managed context.

    Returns:
        int: The index of `scalar` among the basis scalars.

    Raises:
        KeyError: If the uid of `scalar` is not registered as a basis scalar.
    """
    uid_to_index = self._basis_scalar_uid_to_index
    return uid_to_index[scalar.uid]
|
56
87
|
|
88
|
+
def get_tag_of_basis_point_index(self, index: int) -> str:
    """Return the tag of the basis point stored at position `index`."""
    basis_point = self._basis_points[index]
    return basis_point.tag
|
90
|
+
|
91
|
+
def get_tag_of_basis_scalar_index(self, index: int) -> str:
    """Return the tag of the basis scalar stored at position `index`."""
    basis_scalar = self._basis_scalars[index]
    return basis_scalar.tag
|
93
|
+
|
57
94
|
@functools.cache
|
58
95
|
def eval_point(self, point: pt.Point | float | int):
|
96
|
+
"""
|
97
|
+
Return the concrete representation of the given :class:`Point`,
|
98
|
+
`float`, or `int`. Concrete representations of :class:`Point` objects
|
99
|
+
are :class:`EvaluatedPoint` objects. Concrete representations of
|
100
|
+
`float` or `int` arguments are themselves.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
point (:class:`Point`, float, int): The abstract :class:`Point`,
|
104
|
+
`float`, or `int` object whose concrete representation we want
|
105
|
+
to find.
|
106
|
+
|
107
|
+
Returns:
|
108
|
+
:class:`EvaluatedPoint` | float | int: The concrete representation
|
109
|
+
the `point` argument.
|
110
|
+
"""
|
59
111
|
if utils.is_numerical(point):
|
60
112
|
return point
|
61
113
|
|
62
|
-
|
114
|
+
if isinstance(point, pm.Parameter):
|
115
|
+
return point.get_value(self.resolve_parameters)
|
116
|
+
|
63
117
|
if point.is_basis:
|
64
118
|
index = self.get_index_of_basis_point(point)
|
119
|
+
array = np.zeros(self._num_basis_points)
|
65
120
|
array[index] = 1
|
66
121
|
return pt.EvaluatedPoint(vector=array)
|
67
122
|
|
123
|
+
if isinstance(point.eval_expression, pt.ZeroPoint):
|
124
|
+
return pt.EvaluatedPoint.zero(num_basis_points=self._num_basis_points)
|
125
|
+
|
68
126
|
op = point.eval_expression.op
|
127
|
+
left_evaled_point = self.eval_point(point.eval_expression.left_point)
|
128
|
+
right_evaled_point = self.eval_point(point.eval_expression.right_point)
|
69
129
|
if op == utils.Op.ADD:
|
70
|
-
return
|
71
|
-
point.eval_expression.right_point
|
72
|
-
)
|
130
|
+
return left_evaled_point + right_evaled_point
|
73
131
|
if op == utils.Op.SUB:
|
74
|
-
return
|
75
|
-
point.eval_expression.right_point
|
76
|
-
)
|
132
|
+
return left_evaled_point - right_evaled_point
|
77
133
|
if op == utils.Op.MUL:
|
78
|
-
return
|
79
|
-
point.eval_expression.right_point
|
80
|
-
)
|
134
|
+
return left_evaled_point * right_evaled_point
|
81
135
|
if op == utils.Op.DIV:
|
82
|
-
return
|
83
|
-
point.eval_expression.right_point
|
84
|
-
)
|
136
|
+
return left_evaled_point / right_evaled_point
|
85
137
|
|
86
|
-
raise ValueError("
|
138
|
+
raise ValueError(f"Encountered unknown {op=} when evaluation the point.")
|
87
139
|
|
88
140
|
@functools.cache
|
89
141
|
def eval_scalar(self, scalar: sc.Scalar | float | int):
|
142
|
+
"""
|
143
|
+
Return the concrete representation of the given :class:`Scalar`,
|
144
|
+
`float`, or `int`. Concrete representations of :class:`Scalar` objects
|
145
|
+
are :class:`EvaluatedScalar` objects. Concrete representations of
|
146
|
+
`float` or `int` arguments are themselves.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
scalar (:class:`Point`, float, int): The abstract :class:`Scalar`,
|
150
|
+
`float`, or `int` object whose concrete representation we want
|
151
|
+
to find.
|
152
|
+
|
153
|
+
Returns:
|
154
|
+
:class:`EvaluatedScalar` | float | int: The concrete representation
|
155
|
+
the `scalar` argument.
|
156
|
+
"""
|
90
157
|
if utils.is_numerical(scalar):
|
91
158
|
return scalar
|
159
|
+
if isinstance(scalar, pm.Parameter):
|
160
|
+
return scalar.get_value(self.resolve_parameters)
|
92
161
|
|
93
|
-
array = np.zeros(self._num_basis_scalars)
|
94
162
|
if scalar.is_basis:
|
95
163
|
index = self.get_index_of_basis_scalar(scalar)
|
164
|
+
array = np.zeros(self._num_basis_scalars)
|
96
165
|
array[index] = 1
|
166
|
+
matrix = np.zeros((self._num_basis_points, self._num_basis_points))
|
97
167
|
return sc.EvaluatedScalar(
|
98
168
|
vector=array,
|
99
|
-
matrix=
|
169
|
+
matrix=matrix,
|
100
170
|
constant=float(0.0),
|
101
171
|
)
|
172
|
+
|
173
|
+
if isinstance(scalar.eval_expression, sc.ZeroScalar):
|
174
|
+
return sc.EvaluatedScalar.zero(
|
175
|
+
num_basis_scalars=self._num_basis_scalars,
|
176
|
+
num_basis_points=self._num_basis_points,
|
177
|
+
)
|
178
|
+
|
102
179
|
op = scalar.eval_expression.op
|
180
|
+
# The special inner product usage.
|
181
|
+
if (
|
182
|
+
op == utils.Op.MUL
|
183
|
+
and isinstance(scalar.eval_expression.left_scalar, pt.Point)
|
184
|
+
and isinstance(scalar.eval_expression.right_scalar, pt.Point)
|
185
|
+
):
|
186
|
+
array = np.zeros(self._num_basis_scalars)
|
187
|
+
return sc.EvaluatedScalar(
|
188
|
+
vector=array,
|
189
|
+
matrix=utils.SOP(
|
190
|
+
self.eval_point(scalar.eval_expression.left_scalar).vector,
|
191
|
+
self.eval_point(scalar.eval_expression.right_scalar).vector,
|
192
|
+
),
|
193
|
+
constant=float(0.0),
|
194
|
+
)
|
195
|
+
|
196
|
+
left_evaled_scalar = self.eval_scalar(scalar.eval_expression.left_scalar)
|
197
|
+
right_evaled_scalar = self.eval_scalar(scalar.eval_expression.right_scalar)
|
103
198
|
if op == utils.Op.ADD:
|
104
|
-
return
|
105
|
-
scalar.eval_expression.left_scalar
|
106
|
-
) + self.eval_scalar(scalar.eval_expression.right_scalar)
|
199
|
+
return left_evaled_scalar + right_evaled_scalar
|
107
200
|
if op == utils.Op.SUB:
|
108
|
-
return
|
109
|
-
scalar.eval_expression.left_scalar
|
110
|
-
) - self.eval_scalar(scalar.eval_expression.right_scalar)
|
201
|
+
return left_evaled_scalar - right_evaled_scalar
|
111
202
|
if op == utils.Op.MUL:
|
112
|
-
|
113
|
-
scalar.eval_expression.right_scalar, pt.Point
|
114
|
-
):
|
115
|
-
return sc.EvaluatedScalar(
|
116
|
-
vector=np.zeros(self._num_basis_scalars),
|
117
|
-
matrix=utils.SOP(
|
118
|
-
self.eval_point(scalar.eval_expression.left_scalar).vector,
|
119
|
-
self.eval_point(scalar.eval_expression.right_scalar).vector,
|
120
|
-
),
|
121
|
-
constant=float(0.0),
|
122
|
-
)
|
123
|
-
else:
|
124
|
-
return self.eval_scalar(
|
125
|
-
scalar.eval_expression.left_scalar
|
126
|
-
) * self.eval_scalar(scalar.eval_expression.right_scalar)
|
203
|
+
return left_evaled_scalar * right_evaled_scalar
|
127
204
|
if op == utils.Op.DIV:
|
128
|
-
return
|
129
|
-
|
130
|
-
|
205
|
+
return left_evaled_scalar / right_evaled_scalar
|
206
|
+
|
207
|
+
raise ValueError(f"Encountered unknown {op=} when evaluation the scalar.")
|
208
|
+
|
209
|
+
@functools.cache
def repr_point_by_basis(self, point: pt.Point) -> str:
    """
    Express the given :class:`Point` object as a linear combination of the
    basis :class:`Point` objects of the :class:`PEPContext` associated with
    this :class:`ExpressionManager`.

    The combination is returned as a `str` in which basis :class:`Point`
    objects are referred to by their tags.

    Args:
        point (:class:`Point`): The :class:`Point` object which we want
            to express in terms of the basis :class:`Point` objects.

    Returns:
        str: The representation of `point` in terms of the basis
        :class:`Point` objects of the :class:`PEPContext` associated
        with this :class:`ExpressionManager`.
    """
    assert isinstance(point, pt.Point)
    return self.repr_evaluated_point_by_basis(self.eval_point(point))
|
230
|
+
|
231
|
+
def repr_evaluated_point_by_basis(self, evaluated_point: pt.EvaluatedPoint) -> str:
    """
    Express the given :class:`EvaluatedPoint` object as a linear combination
    of the basis :class:`Point` objects of the :class:`PEPContext`
    associated with this :class:`ExpressionManager`.

    Each entry of ``evaluated_point.vector`` is paired with the tag of the
    basis point at the same index and formatted by ``tag_and_coef_to_str``.

    Args:
        evaluated_point (:class:`EvaluatedPoint`): The
            :class:`EvaluatedPoint` object which we want to express in
            terms of the basis :class:`Point` objects.

    Returns:
        str: The representation of `evaluated_point` in terms of the basis
        :class:`Point` objects, or ``"0"`` if every term is omitted.
    """
    terms = "".join(
        tag_and_coef_to_str(self.get_tag_of_basis_point_index(idx), coef)
        for idx, coef in enumerate(evaluated_point.vector)
    )
    if not terms:
        return "0"

    # Drop a leading "+ " and attach a leading "- " directly to the first term.
    terms = terms.removeprefix("+ ")
    if terms.startswith("- "):
        terms = "-" + terms[2:]
    return terms.strip()
|
262
|
+
|
263
|
+
@functools.cache
def repr_scalar_by_basis(
    self, scalar: sc.Scalar, greedy_square: bool = True
) -> str:
    r"""Express the given :class:`Scalar` object in terms of the basis
    :class:`Point` and :class:`Scalar` objects of the :class:`PEPContext`
    associated with this :class:`ExpressionManager`.

    A :class:`Scalar` can be formed by linear combinations of basis
    :class:`Scalar` objects. It can also be formed through the inner
    product of two basis :class:`Point` objects. This function evaluates
    `scalar` and returns its representation as a `str` in which basis
    :class:`Point` and :class:`Scalar` objects are referred to by their tags.

    Args:
        scalar (:class:`Scalar`): The :class:`Scalar` object which we want
            to express in terms of the basis :class:`Point` and
            :class:`Scalar` objects.
        greedy_square (bool): If `greedy_square` is true, the function will
            try to return :math:`\|a-b\|^2` whenever possible. If not,
            the function will return
            :math:`\|a\|^2 - 2 * \langle a, b \rangle + \|b\|^2` instead.
            `True` by default.

    Returns:
        str: The representation of `scalar` in terms of the basis
        :class:`Point` and :class:`Scalar` objects of the
        :class:`PEPContext` associated with this
        :class:`ExpressionManager`.
    """
    assert isinstance(scalar, sc.Scalar)
    return self.repr_evaluated_scalar_by_basis(
        self.eval_scalar(scalar), greedy_square=greedy_square
    )
|
300
|
+
|
301
|
+
def repr_evaluated_scalar_by_basis(
    self, evaluated_scalar: sc.EvaluatedScalar, greedy_square: bool = True
) -> str:
    """Express the given :class:`EvaluatedScalar` object in terms of the
    basis :class:`Point` and :class:`Scalar` objects of the
    :class:`PEPContext` associated with this :class:`ExpressionManager`.

    The representation concatenates three parts, in this order:

    1. the constant;
    2. the linear combination of basis :class:`Scalar` tags taken from
       ``evaluated_scalar.vector``;
    3. the quadratic part taken from ``evaluated_scalar.matrix``, written
       with basis :class:`Point` tags.

    A constant within ``1e-5`` of zero is omitted. The other terms are
    formatted, and possibly omitted, by ``tag_and_coef_to_str``.

    Args:
        evaluated_scalar (:class:`EvaluatedScalar`): The
            :class:`EvaluatedScalar` object which we want to express in
            terms of the basis :class:`Point` and :class:`Scalar` objects.
        greedy_square (bool): If `greedy_square` is true, the function will
            try to return :math:`\\|a-b\\|^2` whenever possible. If not,
            the function will return
            :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` instead.
            `True` by default.

    Returns:
        str: The representation of `evaluated_scalar` in terms of
        the basis :class:`Point` and :class:`Scalar` objects of the
        :class:`PEPContext` associated with this :class:`ExpressionManager`,
        or ``"0"`` if every part is omitted.
    """
    repr_str = ""
    constant = evaluated_scalar.constant
    if not math.isclose(constant, 0, abs_tol=1e-5):
        # Format the constant like every other term: a leading sign and a
        # trailing space. A bare number would be glued to the next term,
        # e.g. "1" + "+ f(x) " -> "1+ f(x)".
        sign = "+" if constant >= 0 else "-"
        repr_str += f"{sign} {utils.numerical_str(abs(constant))} "

    for i, v in enumerate(evaluated_scalar.vector):
        # Note the tag is from scalar basis.
        ith_tag = self.get_tag_of_basis_scalar_index(i)
        repr_str += tag_and_coef_to_str(ith_tag, v)

    matrix = evaluated_scalar.matrix
    n = matrix.shape[0]
    if greedy_square:
        # A symmetric matrix contributes 2*m_ij*<a_i, a_j> per pair i < j,
        # which can be rewritten either as
        #    m_ij*|a_i+a_j|^2 - m_ij*|a_i|^2 - m_ij*|a_j|^2   or
        #   -m_ij*|a_i-a_j|^2 + m_ij*|a_i|^2 + m_ij*|a_j|^2.
        # The form is picked so that diag_elem[i] moves towards zero; the
        # remaining diagonal weight is emitted afterwards.
        diag_elem = np.diag(matrix).copy()
        for i in range(n):
            ith_tag = self.get_tag_of_basis_point_index(i)
            # j starts from i+1 since we want to handle the diag elem at last;
            # later rows only touch indices > i, so diag_elem[i] is final here.
            for j in range(i + 1, n):
                jth_tag = self.get_tag_of_basis_point_index(j)
                v = float(matrix[i, j])
                if diag_elem[i] * v > 0:  # same sign with diagonal elem
                    diag_elem[i] -= v
                    diag_elem[j] -= v
                    repr_str += tag_and_coef_to_str(f"|{ith_tag}+{jth_tag}|^2", v)
                else:  # different sign (or zero)
                    diag_elem[i] += v
                    diag_elem[j] += v
                    repr_str += tag_and_coef_to_str(f"|{ith_tag}-{jth_tag}|^2", -v)
            # Handle the diagonal element.
            repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", diag_elem[i])
    else:
        for i in range(n):
            ith_tag = self.get_tag_of_basis_point_index(i)
            for j in range(i, n):
                jth_tag = self.get_tag_of_basis_point_index(j)
                v = float(matrix[i, j])
                if i == j:
                    repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
                else:
                    # Both (i, j) and (j, i) entries contribute, hence 2 * v.
                    repr_str += tag_and_coef_to_str(
                        f"<{ith_tag}, {jth_tag}>", 2 * v
                    )

    # Post processing
    if repr_str == "":
        return "0"
    if repr_str.startswith("+ "):
        repr_str = repr_str[2:]
    if repr_str.startswith("- "):
        repr_str = "-" + repr_str[2:]
    return repr_str.strip()
|
380
|
+
|
381
|
+
|
382
|
+
def represent_matrix_by_basis(matrix: np.ndarray, ctx: pc.PEPContext) -> str:
    """Express the given matrix in terms of the basis :class:`Point` objects
    of the given :class:`PEPContext`.

    The concrete representation of the inner product of two abstract
    basis :class:`Point` objects is a matrix: the outer product of the
    basis vectors corresponding to the concrete representations of the
    abstract basis :class:`Point` objects. The given matrix can then be
    expressed as a linear combination of the inner products of abstract
    basis :class:`Point` objects. This is provided as a `str` in which basis
    :class:`Point` objects are referred to by their tags.

    Args:
        matrix (np.ndarray): The matrix which we want to express in terms of
            the basis :class:`Point` objects of the given :class:`PEPContext`.
            Any array-like convertible to a float array is accepted.
        ctx (:class:`PEPContext`): The :class:`PEPContext` whose basis
            :class:`Point` objects we consider.

    Returns:
        str: The representation of `matrix` in terms of the basis
        :class:`Point` objects of `ctx`.

    Raises:
        ValueError: If `matrix` does not have shape ``(n, n)``, where ``n``
            is the number of basis :class:`Point` objects of `ctx`, or if it
            is not symmetric.
    """
    # Accept array-likes (e.g. nested lists). Use a float dtype so the
    # in-place diagonal updates done during formatting stay in floating
    # point regardless of the input dtype.
    matrix = np.asarray(matrix, dtype=float)

    em = ExpressionManager(ctx)
    num_basis_points = len(em._basis_points)
    matrix_shape = (num_basis_points, num_basis_points)
    if matrix.shape != matrix_shape:
        raise ValueError(
            f"The valid matrix for given context should have shape {matrix_shape}, "
            f"but got shape {matrix.shape}."
        )
    if not np.allclose(matrix, matrix.T):
        raise ValueError("Input matrix must be symmetric.")

    return em.repr_evaluated_scalar_by_basis(
        sc.EvaluatedScalar(
            vector=np.zeros(len(em._basis_scalars)), matrix=matrix, constant=0.0
        )
    )
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
|
21
|
+
from typing import Iterator
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
from pepflow import expression_manager as exm
|
27
|
+
from pepflow import function as fc
|
28
|
+
from pepflow import pep as pep
|
29
|
+
from pepflow import pep_context as pc
|
30
|
+
from pepflow import point as pt
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Yield a fresh "test" context installed as the current one.

    After the test, the current context is reset to None.
    """
    context = pc.PEPContext("test").set_as_current()
    yield context
    # Teardown: leave no current context behind for the next test.
    pc.set_current_context(None)
|
39
|
+
|
40
|
+
|
41
|
+
def test_repr_point_by_basis(pep_context: pc.PEPContext) -> None:
    """Two gradient steps x_{k+1} = x_k - 0.5 * grad f(x_k), starting at x_0."""
    x = pt.Point(is_basis=True, tags=["x_0"])
    f = fc.Function(is_basis=True, tags=["f"])
    step_size = 0.5
    for k in range(2):
        x = x - step_size * f.gradient(x)
        x.add_tag(f"x_{k + 1}")

    manager = exm.ExpressionManager(pep_context)
    expected_vector = [1, -0.5, -0.5]
    expected_repr = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    np.testing.assert_allclose(manager.eval_point(x).vector, expected_vector)
    assert manager.repr_point_by_basis(x) == expected_repr
|
54
|
+
|
55
|
+
|
56
|
+
def test_repr_point_by_basis_with_zero(pep_context: pc.PEPContext) -> None:
    """An unused basis point adds a zero coordinate but no term to the string."""
    x = pt.Point(is_basis=True, tags=["x_0"])
    _ = pt.Point(is_basis=True, tags=["x_unused"])  # Add this extra point.
    f = fc.Function(is_basis=True, tags=["f"])
    step_size = 0.5
    for k in range(2):
        x = x - step_size * f.gradient(x)
        x.add_tag(f"x_{k + 1}")

    manager = exm.ExpressionManager(pep_context)
    # The vector gains a zero entry for x_unused, yet the string
    # representation matches the case without the extra point.
    expected_vector = [1, 0, -0.5, -0.5]
    expected_repr = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    np.testing.assert_allclose(manager.eval_point(x).vector, expected_vector)
    assert manager.repr_point_by_basis(x) == expected_repr
|
72
|
+
|
73
|
+
|
74
|
+
def test_repr_point_by_basis_heavy_ball(pep_context: pc.PEPContext) -> None:
    """Heavy-ball iterates x_{k+1} = x_k - grad f(x_k) + beta * (x_k - x_{k-1})."""
    x_prev = pt.Point(is_basis=True, tags=["x_{-1}"])
    x = pt.Point(is_basis=True, tags=["x_0"])
    f = fc.Function(is_basis=True, tags=["f"])

    beta = 0.5
    for k in range(2):
        x_next = x - f.gradient(x) + beta * (x - x_prev)
        x_next.add_tag(f"x_{k + 1}")
        x_prev, x = x, x_next

    manager = exm.ExpressionManager(pep_context)
    expected_vector = [-0.75, 1.75, -1.5, -1]
    expected_repr = "-0.75*x_{-1} + 1.75*x_0 - 1.5*gradient_f(x_0) - gradient_f(x_1)"
    np.testing.assert_allclose(manager.eval_point(x).vector, expected_vector)
    assert manager.repr_point_by_basis(x) == expected_repr
|
92
|
+
|
93
|
+
|
94
|
+
def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
    """f(x) + <x, grad f(x)>, rendered with and without greedy squares."""
    x = pt.Point(is_basis=True, tags=["x"])
    f = fc.Function(is_basis=True, tags=["f"])

    s = f(x) + x * f.gradient(x)
    manager = exm.ExpressionManager(pep_context)

    plain_repr = manager.repr_scalar_by_basis(s, greedy_square=False)
    assert plain_repr == "f(x) + <x, gradient_f(x)>"

    square_repr = manager.repr_scalar_by_basis(s)
    assert square_repr == (
        "f(x) - 0.5*|x-gradient_f(x)|^2 + 0.5*|x|^2 + 0.5*|gradient_f(x)|^2"
    )
|
107
|
+
|
108
|
+
|
109
|
+
def test_repr_scalar_by_basis2(pep_context: pc.PEPContext) -> None:
    """f(x) - <x, grad f(x)>, rendered with and without greedy squares."""
    x = pt.Point(is_basis=True, tags=["x"])
    f = fc.Function(is_basis=True, tags=["f"])

    s = f(x) - x * f.gradient(x)
    manager = exm.ExpressionManager(pep_context)

    plain_repr = manager.repr_scalar_by_basis(s, greedy_square=False)
    assert plain_repr == "f(x) - <x, gradient_f(x)>"

    square_repr = manager.repr_scalar_by_basis(s)
    assert square_repr == (
        "f(x) + 0.5*|x-gradient_f(x)|^2 - 0.5*|x|^2 - 0.5*|gradient_f(x)|^2"
    )
|
122
|
+
|
123
|
+
|
124
|
+
def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
    """The smooth-convex interpolation inequality between x_i and x_j."""
    xi = pt.Point(is_basis=True, tags=["x_i"])
    xj = pt.Point(is_basis=True, tags=["x_j"])
    f = fc.SmoothConvexFunction(is_basis=True, L=1)
    f.add_tag("f")
    fi = f(xi)  # noqa: F841
    fj = f(xj)  # noqa: F841
    interp_scalar = f.interpolate_ineq("x_i", "x_j")
    manager = exm.ExpressionManager(pep_context)

    expected_repr = (
        "-f(x_i) + f(x_j)"
        " + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)>"
        " + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)>"
        " + 0.5*|gradient_f(x_j)|^2"
    )
    assert (
        manager.repr_scalar_by_basis(interp_scalar, greedy_square=False)
        == expected_repr
    )

    expected_square_repr = (
        "-f(x_i) + f(x_j)"
        " - 0.5*|x_i-gradient_f(x_j)|^2 + 0.5*|x_i|^2"
        " + 0.5*|x_j-gradient_f(x_j)|^2 - 0.5*|x_j|^2"
        " + 0.5*|gradient_f(x_i)-gradient_f(x_j)|^2"
    )
    assert manager.repr_scalar_by_basis(interp_scalar) == expected_square_repr
|
137
|
+
|
138
|
+
|
139
|
+
# TODO add more tests about repr_scalar_by_basis
|
140
|
+
|
141
|
+
|
142
|
+
def test_represent_matrix_by_basis(pep_context: pc.PEPContext) -> None:
    """A 3x3 symmetric matrix over basis points x_1, x_2, x_3."""
    for tag in ("x_1", "x_2", "x_3"):
        _ = pt.Point(is_basis=True, tags=[tag])

    gram = np.array(
        [
            [0.5, 0.5, 0],
            [0.5, 2, 0],
            [0, 0, 3],
        ]
    )
    rendered = exm.represent_matrix_by_basis(gram, pep_context)
    assert rendered == "0.5*|x_1+x_2|^2 + 1.5*|x_2|^2 + 3*|x_3|^2"
|