pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/__init__.py CHANGED
@@ -21,6 +21,7 @@
21
21
  from .constants import PSD_CONSTRAINT as PSD_CONSTRAINT
22
22
  from .constraint import Constraint as Constraint
23
23
  from .expression_manager import ExpressionManager as ExpressionManager
24
+ from .expression_manager import represent_matrix_by_basis as represent_matrix_by_basis
24
25
 
25
26
  # interactive_constraint
26
27
  from .interactive_constraint import launch as launch
@@ -28,6 +29,7 @@ from .interactive_constraint import launch as launch
28
29
  # pep
29
30
  from .pep import PEPBuilder as PEPBuilder
30
31
  from .pep import PEPResult as PEPResult
32
+ from .pep import DualPEPResult as DualPEPResult
31
33
  from .pep_context import PEPContext as PEPContext
32
34
  from .pep_context import get_current_context as get_current_context
33
35
  from .pep_context import set_current_context as set_current_context
@@ -43,8 +45,10 @@ from .scalar import EvaluatedScalar as EvaluatedScalar
43
45
  from .scalar import Scalar as Scalar
44
46
 
45
47
  # Solver
46
- from .solver import CVXSolver as CVXSolver
48
+ from .solver import CVXPrimalSolver as CVXPrimalSolver
49
+ from .solver import CVXDualSolver as CVXDualSolver
47
50
  from .solver import DualVariableManager as DualVariableManager
51
+ from .solver import PrimalVariableManager as PrimalVariableManager
48
52
 
49
53
  # Others
50
54
  from .utils import SOP as SOP
pepflow/constraint.py CHANGED
@@ -31,8 +31,65 @@ from pepflow import utils
31
31
 
32
32
  @attrs.frozen
33
33
  class Constraint:
34
- """It represents `expression relation 0`."""
34
+ """A :class:`Constraint` object that represents inequalities and
35
+ equalities of :class:`Scalar` objects.
36
+
37
+ Denote an arbitrary :class:`Scalar` object as `x`. Constraints represent:
38
+ `x <= 0`, `x >= 0`, and `x = 0`.
39
+
40
+ Attributes:
41
+ scalar (:class:`Scalar`): The :class:`Scalar` object involved in
42
+ the inequality or equality.
43
+ comparator (:class:`Comparator`): :class:`Comparator` is an enumeration
44
+ that can be either `GT`, `LT`, or `EQ`. They represent `>=`, `<=`,
45
+ and `=` respectively.
46
+ name (str): The unique name of the :class:`Comparator` object.
47
+ associated_dual_var_constraints (list[tuple[utils.Comparator, float]]):
48
+ A list of all the constraints imposed on the associated dual
49
+ variable of this :class:`Constraint` object.
50
+ """
35
51
 
36
52
  scalar: Scalar | float
37
53
  comparator: utils.Comparator
38
54
  name: str
55
+
56
+ # Used to represent the constraint on primal variable in dual PEP.
57
+ associated_dual_var_constraints: list[tuple[utils.Comparator, float]] = attrs.field(
58
+ factory=list
59
+ )
60
+
61
+ def dual_lt(self, val: float) -> None:
62
+ """
63
+ Denote the associated dual variable of this constraint as `lambd`.
64
+ This generates a relation of the form `lambd <= val`.
65
+
66
+ Args:
67
+ val (float): The other object in the relation.
68
+ """
69
+ if not utils.is_numerical(val):
70
+ raise ValueError(f"The input {val=} must be a numerical value")
71
+ self.associated_dual_var_constraints.append((utils.Comparator.LT, val))
72
+
73
+ def dual_gt(self, val: float) -> None:
74
+ """
75
+ Denote the associated dual variable of this constraint as `lambd`.
76
+ This generates a relation of the form `lambd >= val`.
77
+
78
+ Args:
79
+ val (float): The other object in the relation.
80
+ """
81
+ if not utils.is_numerical(val):
82
+ raise ValueError(f"The input {val=} must be a numerical value")
83
+ self.associated_dual_var_constraints.append((utils.Comparator.GT, val))
84
+
85
+ def dual_eq(self, val: float) -> None:
86
+ """
87
+ Denote the associated dual variable of this constraint as `lambd`.
88
+ This generates a relation of the form `lambd = val`.
89
+
90
+ Args:
91
+ val (float): The other object in the relation.
92
+ """
93
+ if not utils.is_numerical(val):
94
+ raise ValueError(f"The input {val=} must be a numerical value")
95
+ self.associated_dual_var_constraint.append((utils.Comparator.EQ, val))
pepflow/e2e_test.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import math
2
2
 
3
- from pepflow import function, pep
3
+ from pepflow import function
4
+ from pepflow import parameter as pm
5
+ from pepflow import pep
4
6
  from pepflow import pep_context as pc
5
7
 
6
8
 
@@ -28,10 +30,47 @@ def test_gd_e2e():
28
30
  pep_builder.set_performance_metric(
29
31
  f.function_value(p) - f.function_value(x_star)
30
32
  )
31
- result = pep_builder.solve()
33
+ result = pep_builder.solve_primal()
32
34
  expected_opt_value = 1 / (4 * i + 2)
33
35
  assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
34
36
 
37
+ dual_result = pep_builder.solve_dual()
38
+ assert math.isclose(
39
+ dual_result.dual_opt_value, expected_opt_value, rel_tol=1e-3
40
+ )
41
+
42
+
43
+ def test_gd_diff_stepsize_e2e():
44
+ pc.PEPContext("gd").set_as_current()
45
+ pep_builder = pep.PEPBuilder()
46
+ eta = 1 / pm.Parameter(name="L")
47
+ N = 4
48
+
49
+ f = pep_builder.declare_func(
50
+ function.SmoothConvexFunction, "f", L=pm.Parameter(name="L")
51
+ )
52
+ x = pep_builder.set_init_point("x_0")
53
+ x_star = f.add_stationary_point("x_star")
54
+ pep_builder.set_initial_constraint(
55
+ ((x - x_star) ** 2).le(1, name="initial_condition")
56
+ )
57
+
58
+ # We first build the algorithm with the largest number of iterations.
59
+ for i in range(N):
60
+ x = x - eta * f.gradient(x)
61
+ x.add_tag(f"x_{i + 1}")
62
+ pep_builder.set_performance_metric(f(x) - f(x_star))
63
+
64
+ for l_val in [1, 4, 0.25]:
65
+ result = pep_builder.solve_primal(resolve_parameters={"L": l_val})
66
+ expected_opt_value = l_val / (4 * N + 2)
67
+ assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
68
+
69
+ dual_result = pep_builder.solve_dual(resolve_parameters={"L": l_val})
70
+ assert math.isclose(
71
+ dual_result.dual_opt_value, expected_opt_value, rel_tol=1e-3
72
+ )
73
+
35
74
 
36
75
  def test_pgm_e2e():
37
76
  ctx = pc.PEPContext("pgm").set_as_current()
@@ -64,6 +103,11 @@ def test_pgm_e2e():
64
103
  h.function_value(p) - h.function_value(x_star)
65
104
  )
66
105
 
67
- result = pep_builder.solve()
106
+ result = pep_builder.solve_primal()
68
107
  expected_opt_value = 1 / (4 * i)
69
108
  assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
109
+
110
+ dual_result = pep_builder.solve_dual()
111
+ assert math.isclose(
112
+ dual_result.dual_opt_value, expected_opt_value, rel_tol=1e-3
113
+ )
@@ -22,6 +22,7 @@ import math
22
22
 
23
23
  import numpy as np
24
24
 
25
+ from pepflow import parameter as pm
25
26
  from pepflow import pep_context as pc
26
27
  from pepflow import point as pt
27
28
  from pepflow import scalar as sc
@@ -29,23 +30,41 @@ from pepflow import utils
29
30
 
30
31
 
31
32
  def tag_and_coef_to_str(tag: str, v: float) -> str:
32
- coef = f"{abs(v):.3g}"
33
+ coef = utils.numerical_str(abs(v))
33
34
  sign = "+" if v >= 0 else "-"
34
35
  if math.isclose(abs(v), 1):
35
36
  return f"{sign} {tag} "
36
- elif math.isclose(v, 0):
37
+ elif math.isclose(v, 0, abs_tol=1e-5):
37
38
  return ""
38
39
  else:
39
40
  return f"{sign} {coef}*{tag} "
40
41
 
41
42
 
42
43
  class ExpressionManager:
43
- def __init__(self, pep_context: pc.PEPContext):
44
+ """
45
+ A class that handles the concrete representations of abstract
46
+ :class:`Point` and :class:`Scalar` objects managed by a particular
47
+ :class:`PEPContext` object.
48
+
49
+ Attributes:
50
+ context (:class:`PEPContext`): The :class:`PEPContext` object which
51
+ manages the abstract :class:`Point` and :class:`Scalar` objects
52
+ of interest.
53
+ resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
54
+ maps the name of parameters to the numerical values.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ pep_context: pc.PEPContext,
60
+ resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
61
+ ):
44
62
  self.context = pep_context
45
63
  self._basis_points = []
46
64
  self._basis_point_uid_to_index = {}
47
65
  self._basis_scalars = []
48
66
  self._basis_scalar_uid_to_index = {}
67
+ self.resolve_parameters = resolve_parameters or {}
49
68
  for point in self.context.points:
50
69
  if point.is_basis:
51
70
  self._basis_points.append(point)
@@ -74,88 +93,161 @@ class ExpressionManager:
74
93
 
75
94
  @functools.cache
76
95
  def eval_point(self, point: pt.Point | float | int):
96
+ """
97
+ Return the concrete representation of the given :class:`Point`,
98
+ `float`, or `int`. Concrete representations of :class:`Point` objects
99
+ are :class:`EvaluatedPoint` objects. Concrete representations of
100
+ `float` or `int` arguments are themselves.
101
+
102
+ Args:
103
+ point (:class:`Point`, float, int): The abstract :class:`Point`,
104
+ `float`, or `int` object whose concrete representation we want
105
+ to find.
106
+
107
+ Returns:
108
+ :class:`EvaluatedPoint` | float | int: The concrete representation
109
+ the `point` argument.
110
+ """
77
111
  if utils.is_numerical(point):
78
112
  return point
79
113
 
80
- array = np.zeros(self._num_basis_points)
114
+ if isinstance(point, pm.Parameter):
115
+ return point.get_value(self.resolve_parameters)
116
+
81
117
  if point.is_basis:
82
118
  index = self.get_index_of_basis_point(point)
119
+ array = np.zeros(self._num_basis_points)
83
120
  array[index] = 1
84
121
  return pt.EvaluatedPoint(vector=array)
85
122
 
123
+ if isinstance(point.eval_expression, pt.ZeroPoint):
124
+ return pt.EvaluatedPoint.zero(num_basis_points=self._num_basis_points)
125
+
86
126
  op = point.eval_expression.op
127
+ left_evaled_point = self.eval_point(point.eval_expression.left_point)
128
+ right_evaled_point = self.eval_point(point.eval_expression.right_point)
87
129
  if op == utils.Op.ADD:
88
- return self.eval_point(point.eval_expression.left_point) + self.eval_point(
89
- point.eval_expression.right_point
90
- )
130
+ return left_evaled_point + right_evaled_point
91
131
  if op == utils.Op.SUB:
92
- return self.eval_point(point.eval_expression.left_point) - self.eval_point(
93
- point.eval_expression.right_point
94
- )
132
+ return left_evaled_point - right_evaled_point
95
133
  if op == utils.Op.MUL:
96
- return self.eval_point(point.eval_expression.left_point) * self.eval_point(
97
- point.eval_expression.right_point
98
- )
134
+ return left_evaled_point * right_evaled_point
99
135
  if op == utils.Op.DIV:
100
- return self.eval_point(point.eval_expression.left_point) / self.eval_point(
101
- point.eval_expression.right_point
102
- )
136
+ return left_evaled_point / right_evaled_point
103
137
 
104
- raise ValueError("This should never happen!")
138
+ raise ValueError(f"Encountered unknown {op=} when evaluation the point.")
105
139
 
106
140
  @functools.cache
107
141
  def eval_scalar(self, scalar: sc.Scalar | float | int):
142
+ """
143
+ Return the concrete representation of the given :class:`Scalar`,
144
+ `float`, or `int`. Concrete representations of :class:`Scalar` objects
145
+ are :class:`EvaluatedScalar` objects. Concrete representations of
146
+ `float` or `int` arguments are themselves.
147
+
148
+ Args:
149
+ scalar (:class:`Point`, float, int): The abstract :class:`Scalar`,
150
+ `float`, or `int` object whose concrete representation we want
151
+ to find.
152
+
153
+ Returns:
154
+ :class:`EvaluatedScalar` | float | int: The concrete representation
155
+ the `scalar` argument.
156
+ """
108
157
  if utils.is_numerical(scalar):
109
158
  return scalar
159
+ if isinstance(scalar, pm.Parameter):
160
+ return scalar.get_value(self.resolve_parameters)
110
161
 
111
- array = np.zeros(self._num_basis_scalars)
112
162
  if scalar.is_basis:
113
163
  index = self.get_index_of_basis_scalar(scalar)
164
+ array = np.zeros(self._num_basis_scalars)
114
165
  array[index] = 1
166
+ matrix = np.zeros((self._num_basis_points, self._num_basis_points))
115
167
  return sc.EvaluatedScalar(
116
168
  vector=array,
117
- matrix=np.zeros((self._num_basis_points, self._num_basis_points)),
169
+ matrix=matrix,
118
170
  constant=float(0.0),
119
171
  )
172
+
173
+ if isinstance(scalar.eval_expression, sc.ZeroScalar):
174
+ return sc.EvaluatedScalar.zero(
175
+ num_basis_scalars=self._num_basis_scalars,
176
+ num_basis_points=self._num_basis_points,
177
+ )
178
+
120
179
  op = scalar.eval_expression.op
180
+ # The special inner product usage.
181
+ if (
182
+ op == utils.Op.MUL
183
+ and isinstance(scalar.eval_expression.left_scalar, pt.Point)
184
+ and isinstance(scalar.eval_expression.right_scalar, pt.Point)
185
+ ):
186
+ array = np.zeros(self._num_basis_scalars)
187
+ return sc.EvaluatedScalar(
188
+ vector=array,
189
+ matrix=utils.SOP(
190
+ self.eval_point(scalar.eval_expression.left_scalar).vector,
191
+ self.eval_point(scalar.eval_expression.right_scalar).vector,
192
+ ),
193
+ constant=float(0.0),
194
+ )
195
+
196
+ left_evaled_scalar = self.eval_scalar(scalar.eval_expression.left_scalar)
197
+ right_evaled_scalar = self.eval_scalar(scalar.eval_expression.right_scalar)
121
198
  if op == utils.Op.ADD:
122
- return self.eval_scalar(
123
- scalar.eval_expression.left_scalar
124
- ) + self.eval_scalar(scalar.eval_expression.right_scalar)
199
+ return left_evaled_scalar + right_evaled_scalar
125
200
  if op == utils.Op.SUB:
126
- return self.eval_scalar(
127
- scalar.eval_expression.left_scalar
128
- ) - self.eval_scalar(scalar.eval_expression.right_scalar)
201
+ return left_evaled_scalar - right_evaled_scalar
129
202
  if op == utils.Op.MUL:
130
- if isinstance(scalar.eval_expression.left_scalar, pt.Point) and isinstance(
131
- scalar.eval_expression.right_scalar, pt.Point
132
- ):
133
- return sc.EvaluatedScalar(
134
- vector=np.zeros(self._num_basis_scalars),
135
- matrix=utils.SOP(
136
- self.eval_point(scalar.eval_expression.left_scalar).vector,
137
- self.eval_point(scalar.eval_expression.right_scalar).vector,
138
- ),
139
- constant=float(0.0),
140
- )
141
- else:
142
- return self.eval_scalar(
143
- scalar.eval_expression.left_scalar
144
- ) * self.eval_scalar(scalar.eval_expression.right_scalar)
203
+ return left_evaled_scalar * right_evaled_scalar
145
204
  if op == utils.Op.DIV:
146
- return self.eval_scalar(
147
- scalar.eval_expression.left_scalar
148
- ) / self.eval_scalar(scalar.eval_expression.right_scalar)
205
+ return left_evaled_scalar / right_evaled_scalar
149
206
 
150
- raise ValueError("This should never happen!")
207
+ raise ValueError(f"Encountered unknown {op=} when evaluation the scalar.")
151
208
 
152
209
  @functools.cache
153
210
  def repr_point_by_basis(self, point: pt.Point) -> str:
211
+ """
212
+ Express the given :class:`Point` object as the linear combination of
213
+ the basis :class:`Point` objects of the :class:`PEPContext` associated
214
+ with this :class:`ExpressionManager`. This linear combination is
215
+ expressed as a `str` where, to refer to the basis :class:`Point`
216
+ objects, we use their tags.
217
+
218
+ Args:
219
+ point (:class:`Point`): The :class:`Point` object which we want
220
+ to express in terms of the basis :class:`Point` objects.
221
+
222
+ Returns:
223
+ str: The representation of `point` in terms of the basis
224
+ :class:`Point` objects of the :class:`PEPContext` associated
225
+ with this :class:`ExpressionManager`.
226
+ """
154
227
  assert isinstance(point, pt.Point)
155
- repr_array = self.eval_point(point).vector
228
+ evaluated_point = self.eval_point(point)
229
+ return self.repr_evaluated_point_by_basis(evaluated_point)
230
+
231
+ def repr_evaluated_point_by_basis(self, evaluated_point: pt.EvaluatedPoint) -> str:
232
+ """
233
+ Express the given :class:`EvaluatedPoint` object as the linear
234
+ combination of the basis :class:`Point` objects of the
235
+ :class:`PEPContext` associated with this :class:`ExpressionManager`.
236
+ This linear combination is expressed as a `str` where, to refer to the
237
+ basis :class:`Point` objects, we use their tags.
156
238
 
239
+ Args:
240
+ evaluated_point (:class:`EvaluatedPoint`): The
241
+ :class:`EvaluatedPoint` object which we want to express in
242
+ terms of the basis :class:`Point` objects.
243
+
244
+ Returns:
245
+ str: The representation of `evaluated_point` in terms of
246
+ the basis :class:`Point` objects of the :class:`PEPContext`
247
+ associated with this :class:`ExpressionManager`.
248
+ """
157
249
  repr_str = ""
158
- for i, v in enumerate(repr_array):
250
+ for i, v in enumerate(evaluated_point.vector):
159
251
  ith_tag = self.get_tag_of_basis_point_index(i)
160
252
  repr_str += tag_and_coef_to_str(ith_tag, v)
161
253
 
@@ -169,28 +261,113 @@ class ExpressionManager:
169
261
  return repr_str.strip()
170
262
 
171
263
  @functools.cache
172
- def repr_scalar_by_basis(self, scalar: sc.Scalar) -> str:
264
+ def repr_scalar_by_basis(
265
+ self, scalar: sc.Scalar, greedy_square: bool = True
266
+ ) -> str:
267
+ """Express the given :class:`Scalar` object in terms of the basis
268
+ :class:`Point` and :class:`Scalar` objects of the :class:`PEPContext`
269
+ associated with this :class:`ExpressionManager`.
270
+
271
+ A :class:`Scalar` can be formed by linear combinations of basis
272
+ :class:`Scalar` objects. A :class:`Scalar` can also be formed through
273
+ the inner product of two basis :class:`Point` objects. This function
274
+ returns the representation of this :class:`Scalar` object in terms of
275
+ the basis :class:`Point` and :class:`Scalar` objects as a `str` where,
276
+ to refer to the basis :class:`Point` and :class:`Scalar` objects,
277
+ we use their tags.
278
+
279
+ Args:
280
+ scalar (:class:`Scalar`): The :class:`Scalar` object which we want
281
+ to express in terms of the basis :class:`Point` and
282
+ :class:`Scalar` objects.
283
+ greedy_square (bool): If `greedy_square` is true, the function will
284
+ try to return :math:`\\|a-b\\|^2` whenever possible. If not,
285
+ the function will return
286
+ :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` instead.
287
+ `True` by default.
288
+
289
+ Returns:
290
+ str: The representation of `scalar` in terms of the basis
291
+ :class:`Point` and :class:`Scalar` objects of the
292
+ :class:`PEPContext` associated with this
293
+ :class:`ExpressionManager`.
294
+ """
173
295
  assert isinstance(scalar, sc.Scalar)
174
296
  evaluated_scalar = self.eval_scalar(scalar)
297
+ return self.repr_evaluated_scalar_by_basis(
298
+ evaluated_scalar, greedy_square=greedy_square
299
+ )
300
+
301
+ def repr_evaluated_scalar_by_basis(
302
+ self, evaluated_scalar: sc.EvaluatedScalar, greedy_square: bool = True
303
+ ) -> str:
304
+ """Express the given :class:`EvaluatedScalar` object in terms of the
305
+ basis :class:`Point` and :class:`Scalar` objects of the
306
+ :class:`PEPContext` associated with this :class:`ExpressionManager`.
175
307
 
308
+ A :class:`Scalar` can be formed by linear combinations of basis
309
+ :class:`Scalar` objects. A :class:`Scalar` can also be formed through
310
+ the inner product of two basis :class:`Point` objects. This function
311
+ returns the representation of this :class:`Scalar` object in terms of
312
+ the basis :class:`Point` and :class:`Scalar` objects as a `str` where,
313
+ to refer to the basis :class:`Point` and :class:`Scalar` objects,
314
+ we use their tags.
315
+
316
+ Args:
317
+ evaluated_scalar (:class:`EvaluatedScalar`): The
318
+ :class:`EvaluatedScalar` object which we want to express in
319
+ terms of the basis :class:`Point` and :class:`Scalar` objects.
320
+ greedy_square (bool): If `greedy_square` is true, the function will
321
+ try to return :math:`\\|a-b\\|^2` whenever possible. If not,
322
+ the function will return
323
+ :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` instead.
324
+ `True` by default.
325
+
326
+ Returns:
327
+ str: The representation of `evaluated_scalar` in terms of
328
+ the basis :class:`Point` and :class:`Scalar` objects of the
329
+ :class:`PEPContext` associated with this :class:`ExpressionManager`.
330
+ """
176
331
  repr_str = ""
177
- if not math.isclose(evaluated_scalar.constant, 0):
178
- repr_str += f"{evaluated_scalar.constant:.3g}"
332
+ if not math.isclose(evaluated_scalar.constant, 0, abs_tol=1e-5):
333
+ repr_str += utils.numerical_str(evaluated_scalar.constant)
179
334
 
180
335
  for i, v in enumerate(evaluated_scalar.vector):
181
336
  # Note the tag is from scalar basis.
182
337
  ith_tag = self.get_tag_of_basis_scalar_index(i)
183
338
  repr_str += tag_and_coef_to_str(ith_tag, v)
184
339
 
185
- for i in range(evaluated_scalar.matrix.shape[0]):
186
- for j in range(i, evaluated_scalar.matrix.shape[0]):
340
+ if greedy_square:
341
+ diag_elem = np.diag(evaluated_scalar.matrix).copy()
342
+ for i in range(evaluated_scalar.matrix.shape[0]):
343
+ ith_tag = self.get_tag_of_basis_point_index(i)
344
+ # j starts from i+1 since we want to handle the diag elem at last.
345
+ for j in range(i + 1, evaluated_scalar.matrix.shape[0]):
346
+ jth_tag = self.get_tag_of_basis_point_index(j)
347
+ v = float(evaluated_scalar.matrix[i, j])
348
+ # We want to minimize the diagonal elements to zero greedily.
349
+ if diag_elem[i] * v > 0: # same sign with diagonal elem
350
+ diag_elem[i] -= v
351
+ diag_elem[j] -= v
352
+ repr_str += tag_and_coef_to_str(f"|{ith_tag}+{jth_tag}|^2", v)
353
+ else: # different sign
354
+ diag_elem[i] += v
355
+ diag_elem[j] += v
356
+ repr_str += tag_and_coef_to_str(f"|{ith_tag}-{jth_tag}|^2", -v)
357
+ # Handle the diagonal elements
358
+ repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", diag_elem[i])
359
+ else:
360
+ for i in range(evaluated_scalar.matrix.shape[0]):
187
361
  ith_tag = self.get_tag_of_basis_point_index(i)
188
- v = evaluated_scalar.matrix[i, j]
189
- if i == j:
190
- repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
191
- continue
192
- jth_tag = self.get_tag_of_basis_point_index(j)
193
- repr_str += tag_and_coef_to_str(f"<{ith_tag}, {jth_tag}>", 2 * v)
362
+ for j in range(i, evaluated_scalar.matrix.shape[0]):
363
+ jth_tag = self.get_tag_of_basis_point_index(j)
364
+ v = float(evaluated_scalar.matrix[i, j])
365
+ if i == j:
366
+ repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
367
+ else:
368
+ repr_str += tag_and_coef_to_str(
369
+ f"<{ith_tag}, {jth_tag}>", 2 * v
370
+ )
194
371
 
195
372
  # Post processing
196
373
  if repr_str == "":
@@ -200,3 +377,41 @@ class ExpressionManager:
200
377
  if repr_str.startswith("- "):
201
378
  repr_str = "-" + repr_str[2:]
202
379
  return repr_str.strip()
380
+
381
+
382
+ def represent_matrix_by_basis(matrix: np.ndarray, ctx: pc.PEPContext) -> str:
383
+ """Express the given matrix in terms of the basis :class:`Point` objects
384
+ of the given :class:`PEPContext`.
385
+
386
+ The concrete representation of the inner product of two abstract
387
+ basis :class:`Point` objects is a matrix (the outer product of the
388
+ basis vectors corresponding to the concrete representations of the abstract
389
+ basis :class:`Point` objects). The given matrix can then be expressed
390
+ as a linear combination of the inner products of abstract basis
391
+ :class:`Point` objects. This is provided as a `str` where, to refer to
392
+ the basis :class:`Point` objects, we use their tags.
393
+
394
+ Args:
395
+ matrix (np.ndarray): The matrix which we want to express in terms of
396
+ the basis :class:`Point` objects of the given :class:`PEPContext`.
397
+ ctx (:class:`PEPContext`): The :class:`PEPContext` whose basis
398
+ :class:`Point` objects we consider.
399
+
400
+ Returns:
401
+ str: The representation of `matrix` in terms of the basis
402
+ :class:`Point` objects of `ctx`.
403
+ """
404
+ em = ExpressionManager(ctx)
405
+ matrix_shape = (len(em._basis_points), len(em._basis_points))
406
+ if matrix.shape != matrix_shape:
407
+ raise ValueError(
408
+ f"The valid matrix for given context should have shape {matrix_shape}"
409
+ )
410
+ if not np.allclose(matrix, matrix.T):
411
+ raise ValueError("Input matrix must be symmetric.")
412
+
413
+ return em.repr_evaluated_scalar_by_basis(
414
+ sc.EvaluatedScalar(
415
+ vector=np.zeros(len(em._basis_scalars)), matrix=matrix, constant=0.0
416
+ )
417
+ )
@@ -97,7 +97,28 @@ def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
97
97
 
98
98
  s = f(x) + x * f.gradient(x)
99
99
  em = exm.ExpressionManager(pep_context)
100
- assert em.repr_scalar_by_basis(s) == "f(x) + <x, gradient_f(x)>"
100
+ assert (
101
+ em.repr_scalar_by_basis(s, greedy_square=False) == "f(x) + <x, gradient_f(x)>"
102
+ )
103
+ assert (
104
+ em.repr_scalar_by_basis(s)
105
+ == "f(x) - 0.5*|x-gradient_f(x)|^2 + 0.5*|x|^2 + 0.5*|gradient_f(x)|^2"
106
+ )
107
+
108
+
109
+ def test_repr_scalar_by_basis2(pep_context: pc.PEPContext) -> None:
110
+ x = pt.Point(is_basis=True, tags=["x"])
111
+ f = fc.Function(is_basis=True, tags=["f"])
112
+
113
+ s = f(x) - x * f.gradient(x)
114
+ em = exm.ExpressionManager(pep_context)
115
+ assert (
116
+ em.repr_scalar_by_basis(s, greedy_square=False) == "f(x) - <x, gradient_f(x)>"
117
+ )
118
+ assert (
119
+ em.repr_scalar_by_basis(s)
120
+ == "f(x) + 0.5*|x-gradient_f(x)|^2 - 0.5*|x|^2 - 0.5*|gradient_f(x)|^2"
121
+ )
101
122
 
102
123
 
103
124
  def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
@@ -110,7 +131,20 @@ def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
110
131
  interp_scalar = f.interpolate_ineq("x_i", "x_j")
111
132
  em = exm.ExpressionManager(pep_context)
112
133
  expected_repr = "-f(x_i) + f(x_j) + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)> + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)> + 0.5*|gradient_f(x_j)|^2"
113
- assert em.repr_scalar_by_basis(interp_scalar) == expected_repr
134
+ assert em.repr_scalar_by_basis(interp_scalar, greedy_square=False) == expected_repr
135
+ expected_square_repr = "-f(x_i) + f(x_j) - 0.5*|x_i-gradient_f(x_j)|^2 + 0.5*|x_i|^2 + 0.5*|x_j-gradient_f(x_j)|^2 - 0.5*|x_j|^2 + 0.5*|gradient_f(x_i)-gradient_f(x_j)|^2"
136
+ assert em.repr_scalar_by_basis(interp_scalar) == expected_square_repr
114
137
 
115
138
 
116
139
  # TODO add more tests about repr_scalar_by_basis
140
+
141
+
142
+ def test_represent_matrix_by_basis(pep_context: pc.PEPContext) -> None:
143
+ _ = pt.Point(is_basis=True, tags=["x_1"])
144
+ _ = pt.Point(is_basis=True, tags=["x_2"])
145
+ _ = pt.Point(is_basis=True, tags=["x_3"])
146
+ matrix = np.array([[0.5, 0.5, 0], [0.5, 2, 0], [0, 0, 3]])
147
+ assert (
148
+ exm.represent_matrix_by_basis(matrix, pep_context)
149
+ == "0.5*|x_1+x_2|^2 + 1.5*|x_2|^2 + 3*|x_3|^2"
150
+ )