pepflow 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/parameter.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ from __future__ import annotations
21
+
22
+ import attrs
23
+
24
+ from pepflow import utils
25
+
26
# Sentinel returned by ``dict.get`` when a parameter name cannot be resolved.
NOT_FOUND = "__NOT_FOUND__"
28
+
29
+
30
@attrs.frozen
class ParameterRepresentation:
    """Binary expression node that defines a composite :class:`Parameter`.

    Attributes:
        op (:class:`utils.Op`): The arithmetic operation joining the operands.
        left_param: Left operand, either a number or a :class:`Parameter`.
        right_param: Right operand, either a number or a :class:`Parameter`.
    """

    op: utils.Op = attrs.field()
    left_param: utils.NUMERICAL_TYPE | Parameter = attrs.field()
    right_param: utils.NUMERICAL_TYPE | Parameter = attrs.field()
35
+
36
+
37
def eval_parameter(
    param: Parameter | utils.NUMERICAL_TYPE,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE],
) -> utils.NUMERICAL_TYPE:
    """Reduce ``param`` to a number.

    Args:
        param: Either a :class:`Parameter`, which is evaluated through
            :meth:`Parameter.get_value`, or a value that is already
            numerical, which is returned unchanged.
        resolve_parameters: Mapping from parameter names to numerical values.

    Returns:
        The numerical value of ``param``.

    Raises:
        ValueError: If ``param`` is neither a :class:`Parameter` nor numerical.
    """
    is_parameter = isinstance(param, Parameter)
    # `utils.is_numerical` is only consulted for non-Parameter inputs.
    if not is_parameter and not utils.is_numerical(param):
        raise ValueError(f"Encounter the unknown parameter type: {param} ({type(param)})")
    return param.get_value(resolve_parameters) if is_parameter else param
46
+
47
+
48
@attrs.frozen
class Parameter:
    """A symbolic scalar that is resolved to a number at evaluation time.

    A parameter is either *named* (``name`` is set and ``eval_expression`` is
    ``None``) or *composite* (``name`` is ``None`` and ``eval_expression``
    records a binary operation on parameters and/or numbers). Composite
    parameters are produced by the arithmetic operators below.

    Attributes:
        name (str | None): Key used to look up the value in
            ``resolve_parameters``; ``None`` for a composite parameter.
        eval_expression (:class:`ParameterRepresentation` | None): The binary
            operation defining a composite parameter; ``None`` for a named one.
    """

    # If name is None, it is a composite parameter.
    name: str | None

    eval_expression: ParameterRepresentation | None = None

    def __attrs_post_init__(self):
        # Exactly one of `name` / `eval_expression` must be provided.
        if self.name is None and self.eval_expression is None:
            raise ValueError(
                "For a parameter, must specify a name or an eval_expression"
            )
        if self.name is None or self.eval_expression is None:
            return

        raise ValueError(
            "For a parameter, only one of name or eval_expression should be None."
        )

    def __repr__(self):
        if self.eval_expression is None:
            return self.name

        op = self.eval_expression.op
        left_param = self.eval_expression.left_param
        right_param = self.eval_expression.right_param
        # TODO having a better parentheses handling.
        if op == utils.Op.ADD:
            return f"({left_param}+{right_param})"
        if op == utils.Op.SUB:
            return f"({left_param}-{right_param})"
        if op == utils.Op.MUL:
            return f"({left_param}*{right_param})"
        if op == utils.Op.DIV:
            return f"({left_param}/{right_param})"
        # Falling through would return None, which makes repr()/str() fail
        # with an opaque TypeError; report the unknown op explicitly instead.
        raise ValueError(
            f"Encountered unknown {op=} when representing the parameter."
        )

    def get_value(
        self, resolve_parameters: dict[str, utils.NUMERICAL_TYPE]
    ) -> utils.NUMERICAL_TYPE:
        """Evaluate this parameter numerically.

        Args:
            resolve_parameters: Mapping from parameter names to numerical values.

        Returns:
            For a named parameter, its value in ``resolve_parameters``; for a
            composite parameter, the result of applying its operation to the
            recursively evaluated operands.

        Raises:
            ValueError: If a named parameter is missing from
                ``resolve_parameters``, an operand has an unsupported type, or
                the expression uses an unknown op.
        """
        if self.eval_expression is None:
            val = resolve_parameters.get(self.name, NOT_FOUND)
            if val is NOT_FOUND:
                raise ValueError(f"Cannot resolve Parameter named: {self.name}")
            return val
        op = self.eval_expression.op
        left_param = eval_parameter(self.eval_expression.left_param, resolve_parameters)
        right_param = eval_parameter(
            self.eval_expression.right_param, resolve_parameters
        )

        if op == utils.Op.ADD:
            return left_param + right_param
        if op == utils.Op.SUB:
            return left_param - right_param
        if op == utils.Op.MUL:
            return left_param * right_param
        if op == utils.Op.DIV:
            return left_param / right_param

        raise ValueError(
            f"Encountered unknown {op=} when evaluating the parameter."
        )

    def _compose(self, other, op, reflected=False):
        """Build the composite parameter ``self <op> other``.

        When ``reflected`` is true the operands are swapped
        (``other <op> self``), as needed by the ``__r*__`` operators.
        Returns ``NotImplemented`` when ``other`` is neither numerical nor a
        parameter, so Python can fall back to the other operand's method.
        """
        if not utils.is_numerical_or_parameter(other):
            return NotImplemented
        left, right = (other, self) if reflected else (self, other)
        return Parameter(
            name=None,
            eval_expression=ParameterRepresentation(
                op=op, left_param=left, right_param=right
            ),
        )

    def __add__(self, other):
        return self._compose(other, utils.Op.ADD)

    def __radd__(self, other):
        return self._compose(other, utils.Op.ADD, reflected=True)

    def __sub__(self, other):
        return self._compose(other, utils.Op.SUB)

    def __rsub__(self, other):
        return self._compose(other, utils.Op.SUB, reflected=True)

    def __mul__(self, other):
        return self._compose(other, utils.Op.MUL)

    def __rmul__(self, other):
        return self._compose(other, utils.Op.MUL, reflected=True)

    def __truediv__(self, other):
        return self._compose(other, utils.Op.DIV)

    def __rtruediv__(self, other):
        return self._compose(other, utils.Op.DIV, reflected=True)
@@ -0,0 +1,128 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ from typing import Iterator
21
+
22
+ import numpy as np
23
+ import pytest
24
+ import sympy as sp
25
+
26
+ from pepflow import pep_context as pc
27
+ from pepflow.expression_manager import ExpressionManager
28
+ from pepflow.parameter import Parameter
29
+ from pepflow.point import Point
30
+ from pepflow.scalar import Scalar
31
+
32
+
33
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Install a PEPContext named "test" as current, then clear it on teardown."""
    context = pc.PEPContext("test").set_as_current()
    yield context
    # Leave no global context behind for the next test.
    pc.set_current_context(None)
39
+
40
+
41
def test_parameter_interact_with_scalar(pep_context: pc.PEPContext):
    """Mixed Parameter/Scalar arithmetic builds expressions without raising."""
    param = Parameter("pm1")
    scalar = Scalar(is_basis=True, tags=["s1"])

    # Each expression is evaluated in the order listed; results are discarded.
    _ = [
        param + scalar,
        scalar + param,
        param - scalar,
        scalar - param,
        scalar * param,
        param * scalar,
        scalar / param,
    ]
52
+
53
+
54
def test_parameter_interact_with_point(pep_context: pc.PEPContext):
    """Scaling a Point by a Parameter works from either side and by division."""
    param = Parameter("pm1")
    point = Point(is_basis=True, tags=["p1"])

    _ = [point * param, param * point, point / param]
61
+
62
+
63
def test_parameter_composition_with_point_and_scalar(pep_context: pc.PEPContext):
    """Parameters mix with points and scalars and render in the scalar string."""
    pm1, pm2 = Parameter("pm1"), Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    s1 = Scalar(is_basis=True, tags=["s1"])

    combined = s1 + pm1 + pm2 * p1**2
    assert str(combined) == "s1+pm1+pm2*|p1|^2"
71
+
72
+
73
def test_parameter_composition(pep_context: pc.PEPContext):
    """Composite parameters render with parentheses and evaluate exactly."""
    pm1, pm2 = Parameter("pm1"), Parameter("pm2")

    product = (pm1 + 2) * pm2
    assert str(product) == "((pm1+2)*pm2)"
    assert product.get_value({"pm1": 3, "pm2": 6}) == 30

    half = sp.Rational(1, 2)
    rational_product = (pm1 + half) * pm2
    assert str(rational_product) == "((pm1+1/2)*pm2)"
    rational_values = {"pm1": sp.Rational(1, 3), "pm2": sp.Rational(6, 5)}
    assert rational_product.get_value(rational_values) == 1
84
+
85
+
86
def test_expression_manager_eval_with_parameter(pep_context: pc.PEPContext):
    """The same point expression re-evaluates under different parameter values."""
    pm1 = Parameter("pm1")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    p3 = pm1 * p1 + p2 / 4

    for pm1_value in (2.3, 3.4):
        em = ExpressionManager(pep_context, {"pm1": pm1_value})
        np.testing.assert_allclose(
            em.eval_point(p3).vector, np.array([pm1_value, 0.25])
        )
97
+
98
+
99
def test_expression_manager_eval_with_parameter_scalar(pep_context: pc.PEPContext):
    """Parameters resolve into the constant, vector and matrix parts of a scalar."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    s1 = Scalar(is_basis=True, tags=["s1"])
    s2 = pm1 * p1 * p2 + pm2 + s1

    em = ExpressionManager(pep_context, {"pm1": 2.4, "pm2": 4.3})
    evaluated = em.eval_scalar(s2)
    assert np.isclose(evaluated.constant, 4.3)
    np.testing.assert_allclose(evaluated.vector, np.array([1]))
    # The inner product p1*p2 is split symmetrically: 2.4 / 2 on each off-diagonal.
    np.testing.assert_allclose(evaluated.matrix, np.array([[0, 1.2], [1.2, 0]]))
113
+
114
+
115
def test_expression_manager_eval_composition(pep_context: pc.PEPContext):
    """Composite parameters used as coefficients evaluate correctly."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    s1 = Scalar(is_basis=True, tags=["s1"])

    s2 = 1 / pm1 * p1 * p2 + (pm2 + 1) * s1
    em = ExpressionManager(pep_context, {"pm1": 0.5, "pm2": 4.3})
    evaluated = em.eval_scalar(s2)
    assert np.isclose(evaluated.constant, 0)
    np.testing.assert_allclose(evaluated.vector, np.array([5.3]))
    np.testing.assert_allclose(evaluated.matrix, np.array([[0, 1.0], [1.0, 0]]))
pepflow/pep.py CHANGED
@@ -20,6 +20,7 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  import contextlib
23
+ from collections import defaultdict
23
24
  from typing import TYPE_CHECKING, Any, Iterator
24
25
 
25
26
  import attrs
@@ -30,22 +31,50 @@ from pepflow import pep_context as pc
30
31
  from pepflow import point as pt
31
32
  from pepflow import scalar as sc
32
33
  from pepflow import solver as ps
34
+ from pepflow import utils
33
35
  from pepflow.constants import PSD_CONSTRAINT
34
36
 
35
37
  if TYPE_CHECKING:
36
38
  from pepflow.constraint import Constraint
37
39
  from pepflow.function import Function
38
- from pepflow.solver import DualVariableManager
40
+ from pepflow.solver import DualVariableManager, PrimalVariableManager
39
41
 
40
42
 
41
43
  @attrs.frozen
42
44
  class PEPResult:
45
+ """
46
+ A data class object that contains the results of solving the Primal
47
+ PEP.
48
+
49
+ Attributes:
50
+ primal_opt_value (float): The objective value of the solved Primal PEP.
51
+ dual_var_manager (:class:`DualVariableManager`): A manager object which
52
+ provides access to the dual variables associated with the
53
+ constraints of the Primal PEP.
54
+ solver_status (Any): States whether the solver managed to solve the
55
+ Primal PEP successfully.
56
+ context (:class:`PEPContext`): The :class:`PEPContext` object used to
57
+ solve the Primal PEP.
58
+
59
+ """
60
+
43
61
  primal_opt_value: float
44
62
  dual_var_manager: DualVariableManager
45
63
  solver_status: Any
46
64
  context: pc.PEPContext
47
65
 
48
66
  def get_function_dual_variables(self) -> dict[Function, np.ndarray]:
67
+ """
68
+ Return a dictionary which contains the associated dual variables of the
69
+ interpolation constraints for Primal PEP.
70
+
71
+ Returns:
72
+ dict[:class:`Function`, np.ndarray]: A dictionary where the keys
73
+ are :class:`Function` objects, and the values are the dual
74
+ variables associated with the interpolation constraints of the
75
+ :class:`Function` key.
76
+ """
77
+
49
78
  def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
50
79
  # Check if we need to update the order.
51
80
  return (
@@ -68,31 +97,90 @@ class PEPResult:
68
97
 
69
98
  return df_dict_matrix
70
99
 
71
- def get_psd_dual_matrix(self):
100
+ def get_psd_dual_matrix(self) -> np.ndarray:
101
+ """
102
+ Return the PSD dual variable matrix associated with the constraint
103
+ that the Primal PEP decision variable :math:`G` is PSD.
104
+
105
+ Returns:
106
+ np.ndarray: The PSD dual variable matrix associated with the
107
+ constraint that the Primal PEP decision variable :math:`G` is PSD.
108
+ """
72
109
  return np.array(self.dual_var_manager.dual_value(PSD_CONSTRAINT))
73
110
 
74
111
 
112
@attrs.frozen
class DualPEPResult:
    """
    Immutable container for the outcome of solving the Dual PEP.

    Attributes:
        dual_opt_value (float): Optimal objective value of the Dual PEP.
        primal_var_manager (:class:`PrimalVariableManager`): Gives access to
            the Dual PEP's primal variables (the Primal PEP's dual variables).
        solver_status (Any): Status reported by the solver for this problem.
        context (:class:`PEPContext`): The context the Dual PEP was solved in.
    """

    dual_opt_value: float = attrs.field()
    primal_var_manager: PrimalVariableManager = attrs.field()
    solver_status: Any = attrs.field()
    context: pc.PEPContext = attrs.field()
133
+
134
+
75
135
  class PEPBuilder:
76
- """The main class for PEP primal formulation."""
136
+ """The main class for Primal and Dual PEP formulation.
137
+
138
+ Attributes:
139
+ init_conditions (list[:class:`Constraint`]): A list of all the initial
140
+ conditions associated with this PEP.
141
+ functions (list[:class:`Function`]): A list of all the functions
142
+ associated with this PEP.
143
+ performance_metric (:class:`Scalar`): The performance metric for this
144
+ PEP.
145
+ relaxed_constraints (list[str]): A list of names of the constraints
146
+ that will be ignored when the Primal or Dual PEP is constructed.
147
+ dual_val_constraint (dict[str, list[tuple[str, float]]]): A dictionary
148
+ of the form `{constraint_name: [op, val]}`. The `constraint_name`
149
+ is the name of the constraint the dual variable is associated with.
150
+ The `op` is a string for the type of relation, i.e., `lt`, `gt`,
151
+ or `eq`. The `val` is the value for the other side of the
152
+ constraint. For example, consider `{"f:x_1,x_0", [("eq", 0)]}`.
153
+ Denote the associated dual variable as :math:`\\lambda_{1,0}`.
154
+ Then, this means to add a constraint of the form
155
+ :math:`\\lambda_{1,0} = 0` to the Dual PEP. Because it is hard to
156
+ judge if the constraint associated with `constraint_name` is
157
+ active, we suggest to not add dual variable constraints manually
158
+ but instead use the interactive dashboard.
159
+ """
77
160
 
78
161
  def __init__(self):
79
162
  self.pep_context_dict: dict[str, pc.PEPContext] = {}
80
163
 
81
- self.init_conditions = [] #: list["constraint"] =[]
82
- self.functions = [] #: list["function"] = []
83
- self.interpolation_constraints = [] #: list["constraint"] = []
84
- self.performance_metric = None # scalar
164
+ self.init_conditions: list[Constraint] = []
165
+ self.functions: list[Function] = []
166
+ self.performance_metric: sc.Scalar | None = None
85
167
 
86
168
  # Contain the name for the constraints that should be removed.
87
169
  # We should think about a better choice like manager.
88
- self.relaxed_constraints = []
170
+ self.relaxed_constraints: list[str] = []
171
+
172
+ # `dual_val_constraint` has the data structure: {constraint_name: [op, val]}.
173
+ # Because it is hard to judge if the dual_val_constraint is applied or not,
174
+ # we recommend that do not use manually but through the interactive dashboard.
175
+ self.dual_val_constraint: dict[str, list[tuple[str, float]]] = defaultdict(list)
89
176
 
90
177
  def clear_setup(self):
178
+ """Resets the :class:`PEPBuilder` object."""
91
179
  self.init_conditions.clear()
92
180
  self.functions.clear()
93
- self.interpolation_constraints.clear()
94
181
  self.performance_metric = None
95
182
  self.relaxed_constraints.clear()
183
+ self.dual_val_constraint.clear()
96
184
 
97
185
  @contextlib.contextmanager
98
186
  def make_context(
@@ -130,20 +218,116 @@ class PEPBuilder:
130
218
  return point
131
219
 
132
220
  def set_initial_constraint(self, constraint):
221
+ """
222
+ Set an initial condition.
223
+
224
+ Args:
225
+ constraint (:class:`Constraint`): A :class:`Constraint` object that
226
+ represents the desired initial condition.
227
+ """
133
228
  self.init_conditions.append(constraint)
134
229
 
135
230
  def set_performance_metric(self, metric: sc.Scalar):
231
+ """
232
+ Set the performance metric.
233
+
234
+ Args:
235
+ metric (:class:`Scalar`): A :class:`Scalar` object that
236
+ represents the desired performance metric.
237
+ """
136
238
  self.performance_metric = metric
137
239
 
138
240
  def set_relaxed_constraints(self, relaxed_constraints: list[str]):
241
+ """
242
+ Set the constraints that will be ignored.
243
+
244
+ Args:
245
+ relaxed_constraints (list[str]): A list of names of constraints
246
+ that will be ignored.
247
+ """
139
248
  self.relaxed_constraints.extend(relaxed_constraints)
140
249
 
141
- def declare_func(self, function_class, **kwargs):
250
+ def add_dual_val_constraint(
251
+ self, constraint_name: str, op: str, val: float
252
+ ) -> None:
253
+ if op not in ["lt", "gt", "eq"]:
254
+ raise ValueError(f"op must be one of `lt`, `gt`, or `eq` but get {op}")
255
+ if not utils.is_numerical(val):
256
+ raise ValueError("Value must be some numerical value.")
257
+
258
+ self.dual_val_constraint[constraint_name].append((op, val))
259
+
260
+ def declare_func(self, function_class: type[Function], tag: str, **kwargs):
261
+ """
262
+ Declare a function.
263
+
264
+ Args:
265
+ function_class (type[:class:`Function`]): The type of function we want to
266
+ declare. Examples include :class:`ConvexFunction` or
267
+ :class:`SmoothConvexFunction`.
268
+ tag (str): A tag that will be added to the :class:`Function`'s
269
+ `tags` list. It can be used to identify the :class:`Function`
270
+ object.
271
+ **kwargs: The other parameters needed to declare the function. For
272
+ example, :class:`SmoothConvexFunction` will require a
273
+ smoothness parameter `L`.
274
+ """
142
275
  func = function_class(is_basis=True, composition=None, **kwargs)
276
+ func.add_tag(tag)
143
277
  self.functions.append(func)
144
278
  return func
145
279
 
146
- def solve(self, context: pc.PEPContext | None = None, **kwargs):
280
+ def get_func_by_tag(self, tag: str):
281
+ """
282
+ Return the :class:`Function` object associated with the provided `tag`.
283
+
284
+ Args:
285
+ tag (str): The `tag` of the :class:`Function` object we want to
286
+ retrieve.
287
+
288
+ Returns:
289
+ :class:`Function`: The :class:`Function` object associated with
290
+ the `tag`.
291
+
292
+ Note:
293
+ Currently, only basis :class:`Function` objects can be retrieved.
294
+ This will be updated eventually.
295
+ """
296
+ # TODO: Add support to return composite functions as well. Right now we can only return base functions
297
+ for f in self.functions:
298
+ if tag in f.tags:
299
+ return f
300
+ raise ValueError("Cannot find the function of given tag.")
301
+
302
+ def solve(
303
+ self,
304
+ context: pc.PEPContext | None = None,
305
+ resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
306
+ ):
307
+ return self.solve_primal(context, resolve_parameters=resolve_parameters)
308
+
309
+ def solve_primal(
310
+ self,
311
+ context: pc.PEPContext | None = None,
312
+ resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
313
+ ):
314
+ """
315
+ Solve the Primal PEP associated with this :class:`PEPBuilder` object
316
+ using the given :class:`PEPContext` object.
317
+
318
+ Args:
319
+ context (:class:`PEPContext`): The :class:`PEPContext` object used
320
+ to solve the Primal PEP associated with this
321
+ :class:`PEPBuilder` object. `None` if we consider the current
322
+ global :class:`PEPContext` object.
323
+ resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
324
+ maps the name of parameters to the numerical values.
325
+
326
+ Returns:
327
+ :class:`PEPResult`: A :class:`PEPResult` object that contains the
328
+ information obtained after solving the Primal PEP associated with
329
+ this :class:`PEPBuilder` object.
330
+ """
147
331
  if context is None:
148
332
  context = pc.get_current_context()
149
333
  if context is None:
@@ -151,23 +335,86 @@ class PEPBuilder:
151
335
 
152
336
  all_constraints: list[Constraint] = [*self.init_conditions]
153
337
  for f in self.functions:
154
- all_constraints.extend(f.get_interpolation_constraints())
155
- all_constraints.extend(context.opt_conditions[f])
338
+ all_constraints.extend(f.get_interpolation_constraints(context))
156
339
 
157
340
  # for now, we heavily rely on the CVX. We can make a wrapper class to avoid
158
341
  # direct dependency in the future.
159
- solver = ps.CVXSolver(
342
+ solver = ps.CVXPrimalSolver(
160
343
  perf_metric=self.performance_metric,
161
344
  constraints=[
162
345
  c for c in all_constraints if c.name not in self.relaxed_constraints
163
346
  ],
164
347
  context=context,
165
348
  )
166
- problem = solver.build_problem()
167
- result = problem.solve(**kwargs)
349
+ problem = solver.build_problem(resolve_parameters=resolve_parameters)
350
+ result = problem.solve()
168
351
  return PEPResult(
169
352
  primal_opt_value=result,
170
353
  dual_var_manager=solver.dual_var_manager,
171
354
  solver_status=problem.status,
172
355
  context=context,
173
356
  )
357
+
358
+ def solve_dual(
359
+ self,
360
+ context: pc.PEPContext | None = None,
361
+ resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
362
+ ):
363
+ """
364
+ Solve the Dual PEP associated with this :class:`PEPBuilder` object
365
+ using the given :class:`PEPContext` object.
366
+
367
+ Args:
368
+ context (:class:`PEPContext`): The :class:`PEPContext` object used
369
+ to solve the Dual PEP associated with this :class:`PEPBuilder`
370
+ object. `None` if we consider the current global
371
+ :class:`PEPContext` object.
372
+ resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
373
+ maps the name of parameters to the numerical values.
374
+
375
+ Returns:
376
+ :class:`DualPEPResult`: A :class:`DualPEPResult` object that
377
+ contains the information obtained after solving the Dual PEP
378
+ associated with this :class:`PEPBuilder` object.
379
+ """
380
+ if context is None:
381
+ context = pc.get_current_context()
382
+ if context is None:
383
+ raise RuntimeError("Did you forget to create a context?")
384
+
385
+ all_constraints: list[Constraint] = [*self.init_conditions]
386
+ for f in self.functions:
387
+ all_constraints.extend(f.get_interpolation_constraints(context))
388
+
389
+ # TODO: Consider a better API and interface to adding constraint for primal
390
+ # variable of dual problem. We can add `extra_primal_val_constraints` to add
391
+ # more constraints on primal var in dual PEP, i.e. dual var in primal PEP.
392
+ constraints = []
393
+ for c in all_constraints:
394
+ if c.name in self.relaxed_constraints:
395
+ continue
396
+ for op, val in self.dual_val_constraint[c.name]:
397
+ if op == "lt":
398
+ c.dual_lt(val)
399
+ elif op == "gt":
400
+ c.dual_gt(val)
401
+ elif op == "eq":
402
+ c.dual_eq(val)
403
+ else:
404
+ raise ValueError(f"Unknown op when construct the {c}")
405
+ constraints.append(c)
406
+
407
+ dual_solver = ps.CVXDualSolver(
408
+ perf_metric=self.performance_metric,
409
+ constraints=constraints,
410
+ context=context,
411
+ )
412
+ problem = dual_solver.build_problem(resolve_parameters=resolve_parameters)
413
+ result = problem.solve()
414
+
415
+ return DualPEPResult(
416
+ dual_opt_value=result,
417
+ primal_var_manager=dual_solver.primal_var_manager,
418
+ solver_status=problem.status,
419
+ context=context,
420
+ )