pepflow 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,22 +18,53 @@
18
18
  # under the License.
19
19
 
20
20
  import functools
21
+ import math
21
22
 
22
23
  import numpy as np
23
24
 
25
+ from pepflow import parameter as pm
24
26
  from pepflow import pep_context as pc
25
27
  from pepflow import point as pt
26
28
  from pepflow import scalar as sc
27
29
  from pepflow import utils
28
30
 
29
31
 
32
def tag_and_coef_to_str(tag: str, v: float) -> str:
    """Render one signed term of a linear combination as text.

    Args:
        tag (str): The label of the term, e.g. the tag of a basis point.
        v (float): The coefficient multiplying `tag`.

    Returns:
        str: ``"<sign> <tag> "`` when ``|v|`` is close to 1,
        ``""`` when `v` is within ``1e-5`` of zero, and
        ``"<sign> <coef>*<tag> "`` otherwise. Every non-empty result ends
        with a space so that callers can simply concatenate terms.
    """
    magnitude = abs(v)
    coef_text = utils.numerical_str(magnitude)
    sign = "+" if v >= 0 else "-"

    # A unit coefficient is implicit: print "x" rather than "1*x".
    if math.isclose(magnitude, 1):
        return f"{sign} {tag} "
    # Negligible coefficients are dropped from the output entirely.
    if math.isclose(v, 0, abs_tol=1e-5):
        return ""
    return f"{sign} {coef_text}*{tag} "
41
+
42
+
30
43
  class ExpressionManager:
31
- def __init__(self, pep_context: pc.PEPContext):
44
+ """
45
+ A class that handles the concrete representations of abstract
46
+ :class:`Point` and :class:`Scalar` objects managed by a particular
47
+ :class:`PEPContext` object.
48
+
49
+ Attributes:
50
+ context (:class:`PEPContext`): The :class:`PEPContext` object which
51
+ manages the abstract :class:`Point` and :class:`Scalar` objects
52
+ of interest.
53
+ resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
54
+ maps the name of parameters to the numerical values.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ pep_context: pc.PEPContext,
60
+ resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
61
+ ):
32
62
  self.context = pep_context
33
63
  self._basis_points = []
34
64
  self._basis_point_uid_to_index = {}
35
65
  self._basis_scalars = []
36
66
  self._basis_scalar_uid_to_index = {}
67
+ self.resolve_parameters = resolve_parameters or {}
37
68
  for point in self.context.points:
38
69
  if point.is_basis:
39
70
  self._basis_points.append(point)
@@ -48,85 +79,339 @@ class ExpressionManager:
48
79
  self._num_basis_points = len(self._basis_points)
49
80
  self._num_basis_scalars = len(self._basis_scalars)
50
81
 
51
def get_index_of_basis_point(self, point: pt.Point) -> int:
    """Return the position of the basis `point` among this manager's basis points.

    Args:
        point (:class:`Point`): A basis point of the managed context.

    Returns:
        int: The index of `point`, looked up by its `uid`.

    Raises:
        KeyError: If `point.uid` is not in the uid-to-index map.
    """
    uid = point.uid
    return self._basis_point_uid_to_index[uid]
53
84
 
54
def get_index_of_basis_scalar(self, scalar: sc.Scalar) -> int:
    """Return the position of the basis `scalar` among this manager's basis scalars.

    Args:
        scalar (:class:`Scalar`): A basis scalar of the managed context.

    Returns:
        int: The index of `scalar`, looked up by its `uid`.

    Raises:
        KeyError: If `scalar.uid` is not in the uid-to-index map.
    """
    uid = scalar.uid
    return self._basis_scalar_uid_to_index[uid]
56
87
 
88
def get_tag_of_basis_point_index(self, index: int) -> str:
    """Return the tag of the basis point stored at position `index`.

    Args:
        index (int): Position in the basis-point list.

    Returns:
        str: The `tag` attribute of that basis point.
    """
    basis_point = self._basis_points[index]
    return basis_point.tag
90
+
91
def get_tag_of_basis_scalar_index(self, index: int) -> str:
    """Return the tag of the basis scalar stored at position `index`.

    Args:
        index (int): Position in the basis-scalar list.

    Returns:
        str: The `tag` attribute of that basis scalar.
    """
    basis_scalar = self._basis_scalars[index]
    return basis_scalar.tag
93
+
57
94
@functools.cache
def eval_point(self, point: pt.Point | float | int):
    """
    Return the concrete representation of the given :class:`Point`,
    `float`, or `int`. Concrete representations of :class:`Point` objects
    are :class:`EvaluatedPoint` objects. Concrete representations of
    `float` or `int` arguments are themselves. A :class:`Parameter` is
    resolved through its `get_value` method using `resolve_parameters`.

    Args:
        point (:class:`Point`, float, int): The abstract :class:`Point`,
            `float`, or `int` object whose concrete representation we want
            to find.

    Returns:
        :class:`EvaluatedPoint` | float | int: The concrete representation
        of the `point` argument.

    Raises:
        ValueError: If the expression of `point` uses an operator other
            than ADD, SUB, MUL or DIV.
    """
    if utils.is_numerical(point):
        return point

    # Parameters may appear as operands of point expressions (presumably as
    # scalar coefficients); resolve them to their numerical values.
    if isinstance(point, pm.Parameter):
        return point.get_value(self.resolve_parameters)

    if point.is_basis:
        # A basis point is the unit vector at its own index.
        index = self.get_index_of_basis_point(point)
        array = np.zeros(self._num_basis_points)
        array[index] = 1
        return pt.EvaluatedPoint(vector=array)

    if isinstance(point.eval_expression, pt.ZeroPoint):
        return pt.EvaluatedPoint.zero(num_basis_points=self._num_basis_points)

    # Composite point: evaluate both operands (recursively, cached) and combine.
    op = point.eval_expression.op
    left_evaled_point = self.eval_point(point.eval_expression.left_point)
    right_evaled_point = self.eval_point(point.eval_expression.right_point)
    if op == utils.Op.ADD:
        return left_evaled_point + right_evaled_point
    if op == utils.Op.SUB:
        return left_evaled_point - right_evaled_point
    if op == utils.Op.MUL:
        return left_evaled_point * right_evaled_point
    if op == utils.Op.DIV:
        return left_evaled_point / right_evaled_point

    raise ValueError(f"Encountered unknown {op=} when evaluating the point.")
87
139
 
88
140
@functools.cache
def eval_scalar(self, scalar: sc.Scalar | float | int):
    """
    Return the concrete representation of the given :class:`Scalar`,
    `float`, or `int`. Concrete representations of :class:`Scalar` objects
    are :class:`EvaluatedScalar` objects. Concrete representations of
    `float` or `int` arguments are themselves. A :class:`Parameter` is
    resolved through its `get_value` method using `resolve_parameters`.

    Args:
        scalar (:class:`Scalar`, float, int): The abstract :class:`Scalar`,
            `float`, or `int` object whose concrete representation we want
            to find.

    Returns:
        :class:`EvaluatedScalar` | float | int: The concrete representation
        of the `scalar` argument.

    Raises:
        ValueError: If the expression of `scalar` uses an operator other
            than ADD, SUB, MUL or DIV.
    """
    if utils.is_numerical(scalar):
        return scalar
    if isinstance(scalar, pm.Parameter):
        return scalar.get_value(self.resolve_parameters)

    if scalar.is_basis:
        # A basis scalar is the unit vector at its own index, with an
        # all-zero inner-product matrix and a zero constant.
        index = self.get_index_of_basis_scalar(scalar)
        array = np.zeros(self._num_basis_scalars)
        array[index] = 1
        matrix = np.zeros((self._num_basis_points, self._num_basis_points))
        return sc.EvaluatedScalar(
            vector=array,
            matrix=matrix,
            constant=float(0.0),
        )

    if isinstance(scalar.eval_expression, sc.ZeroScalar):
        return sc.EvaluatedScalar.zero(
            num_basis_scalars=self._num_basis_scalars,
            num_basis_points=self._num_basis_points,
        )

    op = scalar.eval_expression.op
    # The special inner product usage: Point * Point contributes only a
    # matrix term, built by utils.SOP from the two evaluated point vectors
    # (presumably a symmetric outer product).
    if (
        op == utils.Op.MUL
        and isinstance(scalar.eval_expression.left_scalar, pt.Point)
        and isinstance(scalar.eval_expression.right_scalar, pt.Point)
    ):
        array = np.zeros(self._num_basis_scalars)
        return sc.EvaluatedScalar(
            vector=array,
            matrix=utils.SOP(
                self.eval_point(scalar.eval_expression.left_scalar).vector,
                self.eval_point(scalar.eval_expression.right_scalar).vector,
            ),
            constant=float(0.0),
        )

    # Composite scalar: evaluate both operands (recursively, cached) and combine.
    left_evaled_scalar = self.eval_scalar(scalar.eval_expression.left_scalar)
    right_evaled_scalar = self.eval_scalar(scalar.eval_expression.right_scalar)
    if op == utils.Op.ADD:
        return left_evaled_scalar + right_evaled_scalar
    if op == utils.Op.SUB:
        return left_evaled_scalar - right_evaled_scalar
    if op == utils.Op.MUL:
        return left_evaled_scalar * right_evaled_scalar
    if op == utils.Op.DIV:
        return left_evaled_scalar / right_evaled_scalar

    raise ValueError(f"Encountered unknown {op=} when evaluating the scalar.")
208
+
209
@functools.cache
def repr_point_by_basis(self, point: pt.Point) -> str:
    """
    Write `point` as a linear combination of the basis :class:`Point`
    objects of the :class:`PEPContext` behind this
    :class:`ExpressionManager`, naming each basis point by its tag.

    The point is first evaluated with :meth:`eval_point`. The resulting
    :class:`EvaluatedPoint` is then rendered by
    :meth:`repr_evaluated_point_by_basis`.

    Args:
        point (:class:`Point`): The :class:`Point` object to express in
            terms of the basis :class:`Point` objects.

    Returns:
        str: The tag-based linear-combination string for `point`.
    """
    assert isinstance(point, pt.Point)
    return self.repr_evaluated_point_by_basis(self.eval_point(point))
230
+
231
def repr_evaluated_point_by_basis(self, evaluated_point: pt.EvaluatedPoint) -> str:
    """
    Write `evaluated_point` as a linear combination of the basis
    :class:`Point` objects of the :class:`PEPContext` behind this
    :class:`ExpressionManager`, naming each basis point by its tag.

    Coefficients within ``1e-5`` of zero are omitted. Unit coefficients are
    written without a factor. ``"0"`` is returned if no term remains.

    Args:
        evaluated_point (:class:`EvaluatedPoint`): The concrete point whose
            `vector` holds one coefficient per basis :class:`Point`.

    Returns:
        str: The tag-based linear-combination string, e.g.
        ``"x_0 - 0.5*gradient_f(x_0)"``.
    """
    text = "".join(
        [
            tag_and_coef_to_str(self.get_tag_of_basis_point_index(idx), coef)
            for idx, coef in enumerate(evaluated_point.vector)
        ]
    )

    if text == "":
        return "0"
    # A leading "+ " is redundant; a leading "- " is folded into the term.
    if text.startswith("+ "):
        text = text[2:]
    if text.startswith("- "):
        text = "-" + text[2:]
    return text.strip()
262
+
263
@functools.cache
def repr_scalar_by_basis(
    self, scalar: sc.Scalar, greedy_square: bool = True
) -> str:
    """Write `scalar` in terms of the basis :class:`Point` and
    :class:`Scalar` objects of the :class:`PEPContext` behind this
    :class:`ExpressionManager`, naming each basis object by its tag.

    A :class:`Scalar` combines basis :class:`Scalar` objects linearly and
    may also contain inner products of basis :class:`Point` objects. The
    scalar is evaluated with :meth:`eval_scalar` and rendered by
    :meth:`repr_evaluated_scalar_by_basis`.

    Args:
        scalar (:class:`Scalar`): The :class:`Scalar` object to express.
        greedy_square (bool): When true, prefer squared-norm forms such as
            :math:`\\|a-b\\|^2` wherever possible. Otherwise use
            :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2`.
            `True` by default.

    Returns:
        str: The tag-based string representation of `scalar`.
    """
    assert isinstance(scalar, sc.Scalar)
    concrete = self.eval_scalar(scalar)
    return self.repr_evaluated_scalar_by_basis(
        concrete, greedy_square=greedy_square
    )
300
+
301
def repr_evaluated_scalar_by_basis(
    self, evaluated_scalar: sc.EvaluatedScalar, greedy_square: bool = True
) -> str:
    """Express the given :class:`EvaluatedScalar` object in terms of the
    basis :class:`Point` and :class:`Scalar` objects of the
    :class:`PEPContext` associated with this :class:`ExpressionManager`.

    A :class:`Scalar` can be formed by linear combinations of basis
    :class:`Scalar` objects. A :class:`Scalar` can also be formed through
    the inner product of two basis :class:`Point` objects. This function
    returns the representation of this :class:`Scalar` object in terms of
    the basis :class:`Point` and :class:`Scalar` objects as a `str` where,
    to refer to the basis :class:`Point` and :class:`Scalar` objects,
    we use their tags.

    Terms are emitted in three groups, in this order:
    - the constant, unless it is within ``1e-5`` of zero;
    - the basis :class:`Scalar` terms;
    - the inner-product terms.

    The inner-product terms are read from the upper triangle of
    `evaluated_scalar.matrix`, so the matrix is assumed to be symmetric.
    Coefficients within ``1e-5`` of zero are omitted. ``"0"`` is returned
    when no term remains.

    Args:
        evaluated_scalar (:class:`EvaluatedScalar`): The
            :class:`EvaluatedScalar` object which we want to express in
            terms of the basis :class:`Point` and :class:`Scalar` objects.
        greedy_square (bool): If `greedy_square` is true, the function will
            try to return :math:`\\|a-b\\|^2` whenever possible. If not,
            the function will return
            :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` instead.
            `True` by default.

    Returns:
        str: The representation of `evaluated_scalar` in terms of
        the basis :class:`Point` and :class:`Scalar` objects of the
        :class:`PEPContext` associated with this :class:`ExpressionManager`.
    """
    repr_str = ""
    constant = evaluated_scalar.constant
    if not math.isclose(constant, 0, abs_tol=1e-5):
        # Emit the constant in the same "<sign> <magnitude> " layout as every
        # other term. Appending the bare number glued it to the next term
        # (e.g. "3+ f(x)"). The post-processing below strips a leading "+ "
        # and folds a leading "- " just as for the other terms.
        sign = "+" if constant >= 0 else "-"
        repr_str += f"{sign} {utils.numerical_str(abs(constant))} "

    for i, v in enumerate(evaluated_scalar.vector):
        # Note the tag is from scalar basis.
        ith_tag = self.get_tag_of_basis_scalar_index(i)
        repr_str += tag_and_coef_to_str(ith_tag, v)

    if greedy_square:
        # Rewrite each off-diagonal contribution 2*v*<a, b> either as
        #   v*|a+b|^2 - v*|a|^2 - v*|b|^2   or   -v*|a-b|^2 + v*|a|^2 + v*|b|^2.
        # The form is chosen by the sign of a's remaining diagonal coefficient,
        # greedily pushing that coefficient towards zero.
        diag_elem = np.diag(evaluated_scalar.matrix).copy()
        for i in range(evaluated_scalar.matrix.shape[0]):
            ith_tag = self.get_tag_of_basis_point_index(i)
            # j starts from i+1 since we want to handle the diag elem at last.
            for j in range(i + 1, evaluated_scalar.matrix.shape[0]):
                jth_tag = self.get_tag_of_basis_point_index(j)
                v = float(evaluated_scalar.matrix[i, j])
                # We want to minimize the diagonal elements to zero greedily.
                if diag_elem[i] * v > 0:  # same sign with diagonal elem
                    diag_elem[i] -= v
                    diag_elem[j] -= v
                    repr_str += tag_and_coef_to_str(f"|{ith_tag}+{jth_tag}|^2", v)
                else:  # different sign
                    diag_elem[i] += v
                    diag_elem[j] += v
                    repr_str += tag_and_coef_to_str(f"|{ith_tag}-{jth_tag}|^2", -v)
            # Row i is complete: no later iteration touches diag_elem[i].
            repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", diag_elem[i])
    else:
        for i in range(evaluated_scalar.matrix.shape[0]):
            ith_tag = self.get_tag_of_basis_point_index(i)
            for j in range(i, evaluated_scalar.matrix.shape[0]):
                jth_tag = self.get_tag_of_basis_point_index(j)
                v = float(evaluated_scalar.matrix[i, j])
                if i == j:
                    repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
                else:
                    # Entries (i, j) and (j, i) both contribute v*<a, b>.
                    repr_str += tag_and_coef_to_str(
                        f"<{ith_tag}, {jth_tag}>", 2 * v
                    )

    # Post processing
    if repr_str == "":
        return "0"
    if repr_str.startswith("+ "):
        repr_str = repr_str[2:]
    if repr_str.startswith("- "):
        repr_str = "-" + repr_str[2:]
    return repr_str.strip()
380
+
381
+
382
def represent_matrix_by_basis(matrix: np.ndarray, ctx: pc.PEPContext) -> str:
    """Express the given matrix in terms of the basis :class:`Point` objects
    of the given :class:`PEPContext`.

    The concrete representation of the inner product of two abstract
    basis :class:`Point` objects is a matrix (the outer product of the
    basis vectors corresponding to the concrete representations of the abstract
    basis :class:`Point` objects). The given matrix can then be expressed
    as a linear combination of the inner products of abstract basis
    :class:`Point` objects. This is provided as a `str` where, to refer to
    the basis :class:`Point` objects, we use their tags.

    Args:
        matrix (np.ndarray): The matrix which we want to express in terms of
            the basis :class:`Point` objects of the given :class:`PEPContext`.
            Any array-like accepted by `np.asarray` is converted to a float
            array first.
        ctx (:class:`PEPContext`): The :class:`PEPContext` whose basis
            :class:`Point` objects we consider.

    Returns:
        str: The representation of `matrix` in terms of the basis
        :class:`Point` objects of `ctx`.

    Raises:
        ValueError: If `matrix` is not of shape ``(n, n)``, where ``n`` is
            the number of basis points of `ctx`, or if it is not symmetric.
    """
    # Accept nested lists etc.; the checks below need `.shape` and `.T`.
    matrix = np.asarray(matrix, dtype=float)
    em = ExpressionManager(ctx)
    num_basis_points = len(em._basis_points)
    matrix_shape = (num_basis_points, num_basis_points)
    if matrix.shape != matrix_shape:
        raise ValueError(
            "The valid matrix for given context should have shape "
            f"{matrix_shape}, but got shape {matrix.shape}."
        )
    # The representation only reads the upper triangle, so asymmetric input
    # would be silently misrepresented.
    if not np.allclose(matrix, matrix.T):
        raise ValueError("Input matrix must be symmetric.")

    # Only the inner-product (matrix) part is populated: the scalar-basis
    # coefficients and the constant are zero.
    return em.repr_evaluated_scalar_by_basis(
        sc.EvaluatedScalar(
            vector=np.zeros(len(em._basis_scalars)), matrix=matrix, constant=0.0
        )
    )
@@ -0,0 +1,150 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+
21
+ from typing import Iterator
22
+
23
+ import numpy as np
24
+ import pytest
25
+
26
+ from pepflow import expression_manager as exm
27
+ from pepflow import function as fc
28
+ from pepflow import pep as pep
29
+ from pepflow import pep_context as pc
30
+ from pepflow import point as pt
31
+
32
+
33
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Install a fresh "test" PEPContext as current and unset it afterwards."""
    current = pc.PEPContext("test").set_as_current()
    yield current
    # Teardown: leave no current context behind for the next test.
    pc.set_current_context(None)
39
+
40
+
41
def test_repr_point_by_basis(pep_context: pc.PEPContext) -> None:
    """Two gradient steps x_{k+1} = x_k - 0.5 * grad f(x_k), starting at x_0."""
    iterate = pt.Point(is_basis=True, tags=["x_0"])
    func = fc.Function(is_basis=True, tags=["f"])
    step_size = 0.5
    for k in range(1, 3):
        iterate = iterate - step_size * func.gradient(iterate)
        iterate.add_tag(f"x_{k}")

    manager = exm.ExpressionManager(pep_context)
    np.testing.assert_allclose(manager.eval_point(iterate).vector, [1, -0.5, -0.5])
    expected = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    assert manager.repr_point_by_basis(iterate) == expected
54
+
55
+
56
def test_repr_point_by_basis_with_zero(pep_context: pc.PEPContext) -> None:
    """An unused basis point changes the vector but not the string form."""
    iterate = pt.Point(is_basis=True, tags=["x_0"])
    _ = pt.Point(is_basis=True, tags=["x_unused"])  # Extra, never-used basis point.
    func = fc.Function(is_basis=True, tags=["f"])
    step_size = 0.5
    for k in range(1, 3):
        iterate = iterate - step_size * func.gradient(iterate)
        iterate.add_tag(f"x_{k}")

    manager = exm.ExpressionManager(pep_context)
    # The vector gains a zero entry for x_unused ...
    np.testing.assert_allclose(
        manager.eval_point(iterate).vector, [1, 0, -0.5, -0.5]
    )
    # ... while the string omits the zero-coefficient term.
    expected = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    assert manager.repr_point_by_basis(iterate) == expected
72
+
73
+
74
def test_repr_point_by_basis_heavy_ball(pep_context: pc.PEPContext) -> None:
    """Heavy ball: x_{k+1} = x_k - grad f(x_k) + beta * (x_k - x_{k-1})."""
    prev_iterate = pt.Point(is_basis=True, tags=["x_{-1}"])
    iterate = pt.Point(is_basis=True, tags=["x_0"])
    func = fc.Function(is_basis=True, tags=["f"])

    momentum = 0.5
    for k in range(1, 3):
        new_iterate = (
            iterate - func.gradient(iterate) + momentum * (iterate - prev_iterate)
        )
        new_iterate.add_tag(f"x_{k}")
        prev_iterate, iterate = iterate, new_iterate

    manager = exm.ExpressionManager(pep_context)
    np.testing.assert_allclose(
        manager.eval_point(iterate).vector, [-0.75, 1.75, -1.5, -1]
    )
    expected = "-0.75*x_{-1} + 1.75*x_0 - 1.5*gradient_f(x_0) - gradient_f(x_1)"
    assert manager.repr_point_by_basis(iterate) == expected
92
+
93
+
94
def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
    """f(x) + <x, grad f(x)> in both inner-product and greedy-square form."""
    point = pt.Point(is_basis=True, tags=["x"])
    func = fc.Function(is_basis=True, tags=["f"])

    expr = func(point) + point * func.gradient(point)
    manager = exm.ExpressionManager(pep_context)

    plain = manager.repr_scalar_by_basis(expr, greedy_square=False)
    assert plain == "f(x) + <x, gradient_f(x)>"

    squared = manager.repr_scalar_by_basis(expr)
    assert squared == (
        "f(x) - 0.5*|x-gradient_f(x)|^2 + 0.5*|x|^2 + 0.5*|gradient_f(x)|^2"
    )
107
+
108
+
109
def test_repr_scalar_by_basis2(pep_context: pc.PEPContext) -> None:
    """f(x) - <x, grad f(x)> in both inner-product and greedy-square form."""
    point = pt.Point(is_basis=True, tags=["x"])
    func = fc.Function(is_basis=True, tags=["f"])

    expr = func(point) - point * func.gradient(point)
    manager = exm.ExpressionManager(pep_context)

    plain = manager.repr_scalar_by_basis(expr, greedy_square=False)
    assert plain == "f(x) - <x, gradient_f(x)>"

    squared = manager.repr_scalar_by_basis(expr)
    assert squared == (
        "f(x) + 0.5*|x-gradient_f(x)|^2 - 0.5*|x|^2 - 0.5*|gradient_f(x)|^2"
    )
122
+
123
+
124
def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
    """String forms of the L-smooth convex interpolation inequality at (x_i, x_j)."""
    point_i = pt.Point(is_basis=True, tags=["x_i"])
    point_j = pt.Point(is_basis=True, tags=["x_j"])
    func = fc.SmoothConvexFunction(is_basis=True, L=1)
    func.add_tag("f")
    # Evaluate f at both points so their function values become basis scalars.
    func(point_i)
    func(point_j)
    interp_scalar = func.interpolate_ineq("x_i", "x_j")
    manager = exm.ExpressionManager(pep_context)

    expected_repr = (
        "-f(x_i) + f(x_j) + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)>"
        " + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)>"
        " + 0.5*|gradient_f(x_j)|^2"
    )
    assert (
        manager.repr_scalar_by_basis(interp_scalar, greedy_square=False)
        == expected_repr
    )

    expected_square_repr = (
        "-f(x_i) + f(x_j) - 0.5*|x_i-gradient_f(x_j)|^2 + 0.5*|x_i|^2"
        " + 0.5*|x_j-gradient_f(x_j)|^2 - 0.5*|x_j|^2"
        " + 0.5*|gradient_f(x_i)-gradient_f(x_j)|^2"
    )
    assert manager.repr_scalar_by_basis(interp_scalar) == expected_square_repr
137
+
138
+
139
+ # TODO add more tests about repr_scalar_by_basis
140
+
141
+
142
def test_represent_matrix_by_basis(pep_context: pc.PEPContext) -> None:
    """A symmetric 3x3 matrix over x_1, x_2, x_3 rendered with greedy squares."""
    for tag in ("x_1", "x_2", "x_3"):
        pt.Point(is_basis=True, tags=[tag])
    coefficients = np.array([[0.5, 0.5, 0], [0.5, 2, 0], [0, 0, 3]])
    expected = "0.5*|x_1+x_2|^2 + 1.5*|x_2|^2 + 3*|x_3|^2"
    assert exm.represent_matrix_by_basis(coefficients, pep_context) == expected
+ )