pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +5 -1
- pepflow/constraint.py +58 -1
- pepflow/e2e_test.py +47 -3
- pepflow/expression_manager.py +272 -57
- pepflow/expression_manager_test.py +36 -2
- pepflow/function.py +180 -10
- pepflow/parameter.py +187 -0
- pepflow/parameter_test.py +128 -0
- pepflow/pep.py +254 -14
- pepflow/pep_context.py +116 -0
- pepflow/pep_context_test.py +21 -0
- pepflow/point.py +155 -49
- pepflow/point_test.py +12 -0
- pepflow/scalar.py +260 -47
- pepflow/scalar_test.py +15 -0
- pepflow/solver.py +170 -3
- pepflow/solver_test.py +50 -2
- pepflow/utils.py +39 -7
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/METADATA +12 -11
- pepflow-0.1.5.dist-info/RECORD +28 -0
- pepflow-0.1.4a1.dist-info/RECORD +0 -26
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/top_level.txt +0 -0
pepflow/pep.py
CHANGED
@@ -20,6 +20,7 @@
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
22
|
import contextlib
|
23
|
+
from collections import defaultdict
|
23
24
|
from typing import TYPE_CHECKING, Any, Iterator
|
24
25
|
|
25
26
|
import attrs
|
@@ -30,22 +31,50 @@ from pepflow import pep_context as pc
|
|
30
31
|
from pepflow import point as pt
|
31
32
|
from pepflow import scalar as sc
|
32
33
|
from pepflow import solver as ps
|
34
|
+
from pepflow import utils
|
33
35
|
from pepflow.constants import PSD_CONSTRAINT
|
34
36
|
|
35
37
|
if TYPE_CHECKING:
|
36
38
|
from pepflow.constraint import Constraint
|
37
39
|
from pepflow.function import Function
|
38
|
-
from pepflow.solver import DualVariableManager
|
40
|
+
from pepflow.solver import DualVariableManager, PrimalVariableManager
|
39
41
|
|
40
42
|
|
41
43
|
@attrs.frozen
|
42
44
|
class PEPResult:
|
45
|
+
"""
|
46
|
+
A data class object that contains the results of solving the Primal
|
47
|
+
PEP.
|
48
|
+
|
49
|
+
Attributes:
|
50
|
+
primal_opt_value (float): The objective value of the solved Primal PEP.
|
51
|
+
dual_var_manager (:class:`DualVariableManager`): A manager object which
|
52
|
+
provides access to the dual variables associated with the
|
53
|
+
constraints of the Primal PEP.
|
54
|
+
solver_status (Any): States whether the solver managed to solve the
|
55
|
+
Primal PEP successfully.
|
56
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object used to
|
57
|
+
solve the Primal PEP.
|
58
|
+
|
59
|
+
"""
|
60
|
+
|
43
61
|
primal_opt_value: float
|
44
62
|
dual_var_manager: DualVariableManager
|
45
63
|
solver_status: Any
|
46
64
|
context: pc.PEPContext
|
47
65
|
|
48
66
|
def get_function_dual_variables(self) -> dict[Function, np.ndarray]:
|
67
|
+
"""
|
68
|
+
Return a dictionary which contains the associated dual variables of the
|
69
|
+
interpolation constraints for Primal PEP.
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
dict[:class:`Function`, np.ndarray]: A dictionary where the keys
|
73
|
+
are :class:`Function` objects, and the values are the dual
|
74
|
+
variables associated with the interpolation constraints of the
|
75
|
+
:class:`Function` key.
|
76
|
+
"""
|
77
|
+
|
49
78
|
def get_matrix_of_dual_value(df: pd.DataFrame) -> np.ndarray:
|
50
79
|
# Check if we need to update the order.
|
51
80
|
return (
|
@@ -68,31 +97,90 @@ class PEPResult:
|
|
68
97
|
|
69
98
|
return df_dict_matrix
|
70
99
|
|
71
|
-
def get_psd_dual_matrix(self):
|
100
|
+
def get_psd_dual_matrix(self) -> np.ndarray:
|
101
|
+
"""
|
102
|
+
Return the PSD dual variable matrix associated with the constraint
|
103
|
+
that the Primal PEP decision variable :math:`G` is PSD.
|
104
|
+
|
105
|
+
Returns:
|
106
|
+
np.ndarray: The PSD dual variable matrix associated with the
|
107
|
+
constraint that the Primal PEP decision variable :math:`G` is PSD.
|
108
|
+
"""
|
72
109
|
return np.array(self.dual_var_manager.dual_value(PSD_CONSTRAINT))
|
73
110
|
|
74
111
|
|
112
|
+
@attrs.frozen
|
113
|
+
class DualPEPResult:
|
114
|
+
"""
|
115
|
+
A data class object that contains the results of solving the Dual
|
116
|
+
PEP.
|
117
|
+
|
118
|
+
Attributes:
|
119
|
+
dual_opt_value (float): The objective value of the solved Dual PEP.
|
120
|
+
primal_var_manager (:class:`PrimalVariableManager`): A manager object
|
121
|
+
which provides access to the primal variables of Dual PEP.
|
122
|
+
solver_status (Any): States whether the solver managed to solve the
|
123
|
+
Dual PEP successfully.
|
124
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object used to
|
125
|
+
solve the Dual PEP.
|
126
|
+
|
127
|
+
"""
|
128
|
+
|
129
|
+
dual_opt_value: float
|
130
|
+
primal_var_manager: PrimalVariableManager
|
131
|
+
solver_status: Any
|
132
|
+
context: pc.PEPContext
|
133
|
+
|
134
|
+
|
75
135
|
class PEPBuilder:
|
76
|
-
"""The main class for PEP
|
136
|
+
"""The main class for Primal and Dual PEP formulation.
|
137
|
+
|
138
|
+
Attributes:
|
139
|
+
init_conditions (list[:class:`Constraint`]): A list of all the initial
|
140
|
+
conditions associated with this PEP.
|
141
|
+
functions (list[:class:`Function`]): A list of all the functions
|
142
|
+
associated with this PEP.
|
143
|
+
performance_metric (:class:`Scalar`): The performance metric for this
|
144
|
+
PEP.
|
145
|
+
relaxed_constraints (list[str]): A list of names of the constraints
|
146
|
+
that will be ignored when the Primal or Dual PEP is constructed.
|
147
|
+
dual_val_constraint (dict[str, list[tuple[str, float]]]): A dictionary
|
148
|
+
of the form `{constraint_name: [(op, val), ...]}`. The `constraint_name`
|
149
|
+
is the name of the constraint the dual variable is associated with.
|
150
|
+
The `op` is a string for the type of relation, i.e., `lt`, `gt`,
|
151
|
+
or `eq`. The `val` is the value for the other side of the
|
152
|
+
constraint. For example, consider `{"f:x_1,x_0": [("eq", 0)]}`.
|
153
|
+
Denote the associated dual variable as :math:`\\lambda_{1,0}`.
|
154
|
+
Then, this means to add a constraint of the form
|
155
|
+
:math:`\\lambda_{1,0} = 0` to the Dual PEP. Because it is hard to
|
156
|
+
judge if the constraint associated with `constraint_name` is
|
157
|
+
active, we recommend that you do not add dual variable constraints manually
|
158
|
+
but instead use the interactive dashboard.
|
159
|
+
"""
|
77
160
|
|
78
161
|
def __init__(self):
|
79
162
|
self.pep_context_dict: dict[str, pc.PEPContext] = {}
|
80
163
|
|
81
|
-
self.init_conditions
|
82
|
-
self.functions
|
83
|
-
self.
|
84
|
-
self.performance_metric = None # scalar
|
164
|
+
self.init_conditions: list[Constraint] = []
|
165
|
+
self.functions: list[Function] = []
|
166
|
+
self.performance_metric: sc.Scalar | None = None
|
85
167
|
|
86
168
|
# Contain the name for the constraints that should be removed.
|
87
169
|
# We should think about a better choice like manager.
|
88
|
-
self.relaxed_constraints = []
|
170
|
+
self.relaxed_constraints: list[str] = []
|
171
|
+
|
172
|
+
# `dual_val_constraint` has the data structure: {constraint_name: [(op, val), ...]}.
|
173
|
+
# Because it is hard to judge if the dual_val_constraint is applied or not,
|
174
|
+
# we recommend not setting it manually but only through the interactive dashboard.
|
175
|
+
self.dual_val_constraint: dict[str, list[tuple[str, float]]] = defaultdict(list)
|
89
176
|
|
90
177
|
def clear_setup(self):
|
178
|
+
"""Resets the :class:`PEPBuilder` object."""
|
91
179
|
self.init_conditions.clear()
|
92
180
|
self.functions.clear()
|
93
|
-
self.interpolation_constraints.clear()
|
94
181
|
self.performance_metric = None
|
95
182
|
self.relaxed_constraints.clear()
|
183
|
+
self.dual_val_constraint.clear()
|
96
184
|
|
97
185
|
@contextlib.contextmanager
|
98
186
|
def make_context(
|
@@ -130,28 +218,116 @@ class PEPBuilder:
|
|
130
218
|
return point
|
131
219
|
|
132
220
|
def set_initial_constraint(self, constraint):
|
221
|
+
"""
|
222
|
+
Set an initial condition.
|
223
|
+
|
224
|
+
Args:
|
225
|
+
constraint (:class:`Constraint`): A :class:`Constraint` object that
|
226
|
+
represents the desired initial condition.
|
227
|
+
"""
|
133
228
|
self.init_conditions.append(constraint)
|
134
229
|
|
135
230
|
def set_performance_metric(self, metric: sc.Scalar):
|
231
|
+
"""
|
232
|
+
Set the performance metric.
|
233
|
+
|
234
|
+
Args:
|
235
|
+
metric (:class:`Scalar`): A :class:`Scalar` object that
|
236
|
+
represents the desired performance metric.
|
237
|
+
"""
|
136
238
|
self.performance_metric = metric
|
137
239
|
|
138
240
|
def set_relaxed_constraints(self, relaxed_constraints: list[str]):
|
241
|
+
"""
|
242
|
+
Set the constraints that will be ignored.
|
243
|
+
|
244
|
+
Args:
|
245
|
+
relaxed_constraints (list[str]): A list of names of constraints
|
246
|
+
that will be ignored.
|
247
|
+
"""
|
139
248
|
self.relaxed_constraints.extend(relaxed_constraints)
|
140
249
|
|
250
|
+
def add_dual_val_constraint(
|
251
|
+
self, constraint_name: str, op: str, val: float
|
252
|
+
) -> None:
|
253
|
+
if op not in ["lt", "gt", "eq"]:
|
254
|
+
raise ValueError(f"op must be one of `lt`, `gt`, or `eq` but get {op}")
|
255
|
+
if not utils.is_numerical(val):
|
256
|
+
raise ValueError("Value must be some numerical value.")
|
257
|
+
|
258
|
+
self.dual_val_constraint[constraint_name].append((op, val))
|
259
|
+
|
141
260
|
def declare_func(self, function_class: type[Function], tag: str, **kwargs):
|
261
|
+
"""
|
262
|
+
Declare a function.
|
263
|
+
|
264
|
+
Args:
|
265
|
+
function_class (type[:class:`Function`]): The type of function we want to
|
266
|
+
declare. Examples include :class:`ConvexFunction` or
|
267
|
+
:class:`SmoothConvexFunction`.
|
268
|
+
tag (str): A tag that will be added to the :class:`Function`'s
|
269
|
+
`tags` list. It can be used to identify the :class:`Function`
|
270
|
+
object.
|
271
|
+
**kwargs: The other parameters needed to declare the function. For
|
272
|
+
example, :class:`SmoothConvexFunction` will require a
|
273
|
+
smoothness parameter `L`.
|
274
|
+
"""
|
142
275
|
func = function_class(is_basis=True, composition=None, **kwargs)
|
143
276
|
func.add_tag(tag)
|
144
277
|
self.functions.append(func)
|
145
278
|
return func
|
146
279
|
|
147
280
|
def get_func_by_tag(self, tag: str):
|
281
|
+
"""
|
282
|
+
Return the :class:`Function` object associated with the provided `tag`.
|
283
|
+
|
284
|
+
Args:
|
285
|
+
tag (str): The `tag` of the :class:`Function` object we want to
|
286
|
+
retrieve.
|
287
|
+
|
288
|
+
Returns:
|
289
|
+
:class:`Function`: The :class:`Function` object associated with
|
290
|
+
the `tag`.
|
291
|
+
|
292
|
+
Note:
|
293
|
+
Currently, only basis :class:`Function` objects can be retrieved.
|
294
|
+
This will be updated eventually.
|
295
|
+
"""
|
148
296
|
# TODO: Add support to return composite functions as well. Right now we can only return base functions
|
149
297
|
for f in self.functions:
|
150
298
|
if tag in f.tags:
|
151
299
|
return f
|
152
300
|
raise ValueError("Cannot find the function of given tag.")
|
153
301
|
|
154
|
-
def solve(
|
302
|
+
def solve(
|
303
|
+
self,
|
304
|
+
context: pc.PEPContext | None = None,
|
305
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
|
306
|
+
):
|
307
|
+
return self.solve_primal(context, resolve_parameters=resolve_parameters)
|
308
|
+
|
309
|
+
def solve_primal(
|
310
|
+
self,
|
311
|
+
context: pc.PEPContext | None = None,
|
312
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
|
313
|
+
):
|
314
|
+
"""
|
315
|
+
Solve the Primal PEP associated with this :class:`PEPBuilder` object
|
316
|
+
using the given :class:`PEPContext` object.
|
317
|
+
|
318
|
+
Args:
|
319
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object used
|
320
|
+
to solve the Primal PEP associated with this
|
321
|
+
:class:`PEPBuilder` object. `None` if we consider the current
|
322
|
+
global :class:`PEPContext` object.
|
323
|
+
resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
|
324
|
+
maps parameter names to their numerical values.
|
325
|
+
|
326
|
+
Returns:
|
327
|
+
:class:`PEPResult`: A :class:`PEPResult` object that contains the
|
328
|
+
information obtained after solving the Primal PEP associated with
|
329
|
+
this :class:`PEPBuilder` object.
|
330
|
+
"""
|
155
331
|
if context is None:
|
156
332
|
context = pc.get_current_context()
|
157
333
|
if context is None:
|
@@ -159,22 +335,86 @@ class PEPBuilder:
|
|
159
335
|
|
160
336
|
all_constraints: list[Constraint] = [*self.init_conditions]
|
161
337
|
for f in self.functions:
|
162
|
-
all_constraints.extend(f.get_interpolation_constraints())
|
338
|
+
all_constraints.extend(f.get_interpolation_constraints(context))
|
163
339
|
|
164
340
|
# for now, we heavily rely on the CVX. We can make a wrapper class to avoid
|
165
341
|
# direct dependency in the future.
|
166
|
-
solver = ps.
|
342
|
+
solver = ps.CVXPrimalSolver(
|
167
343
|
perf_metric=self.performance_metric,
|
168
344
|
constraints=[
|
169
345
|
c for c in all_constraints if c.name not in self.relaxed_constraints
|
170
346
|
],
|
171
347
|
context=context,
|
172
348
|
)
|
173
|
-
problem = solver.build_problem()
|
174
|
-
result = problem.solve(
|
349
|
+
problem = solver.build_problem(resolve_parameters=resolve_parameters)
|
350
|
+
result = problem.solve()
|
175
351
|
return PEPResult(
|
176
352
|
primal_opt_value=result,
|
177
353
|
dual_var_manager=solver.dual_var_manager,
|
178
354
|
solver_status=problem.status,
|
179
355
|
context=context,
|
180
356
|
)
|
357
|
+
|
358
|
+
def solve_dual(
|
359
|
+
self,
|
360
|
+
context: pc.PEPContext | None = None,
|
361
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE] | None = None,
|
362
|
+
):
|
363
|
+
"""
|
364
|
+
Solve the Dual PEP associated with this :class:`PEPBuilder` object
|
365
|
+
using the given :class:`PEPContext` object.
|
366
|
+
|
367
|
+
Args:
|
368
|
+
context (:class:`PEPContext`): The :class:`PEPContext` object used
|
369
|
+
to solve the Dual PEP associated with this :class:`PEPBuilder`
|
370
|
+
object. `None` if we consider the current global
|
371
|
+
:class:`PEPContext` object.
|
372
|
+
resolve_parameters (dict[str, :class:`NUMERICAL_TYPE`]): A dictionary that
|
373
|
+
maps parameter names to their numerical values.
|
374
|
+
|
375
|
+
Returns:
|
376
|
+
:class:`DualPEPResult`: A :class:`DualPEPResult` object that
|
377
|
+
contains the information obtained after solving the Dual PEP
|
378
|
+
associated with this :class:`PEPBuilder` object.
|
379
|
+
"""
|
380
|
+
if context is None:
|
381
|
+
context = pc.get_current_context()
|
382
|
+
if context is None:
|
383
|
+
raise RuntimeError("Did you forget to create a context?")
|
384
|
+
|
385
|
+
all_constraints: list[Constraint] = [*self.init_conditions]
|
386
|
+
for f in self.functions:
|
387
|
+
all_constraints.extend(f.get_interpolation_constraints(context))
|
388
|
+
|
389
|
+
# TODO: Consider a better API and interface to adding constraint for primal
|
390
|
+
# variable of dual problem. We can add `extra_primal_val_constraints` to add
|
391
|
+
# more constraints on primal var in dual PEP, i.e. dual var in primal PEP.
|
392
|
+
constraints = []
|
393
|
+
for c in all_constraints:
|
394
|
+
if c.name in self.relaxed_constraints:
|
395
|
+
continue
|
396
|
+
for op, val in self.dual_val_constraint[c.name]:
|
397
|
+
if op == "lt":
|
398
|
+
c.dual_lt(val)
|
399
|
+
elif op == "gt":
|
400
|
+
c.dual_gt(val)
|
401
|
+
elif op == "eq":
|
402
|
+
c.dual_eq(val)
|
403
|
+
else:
|
404
|
+
raise ValueError(f"Unknown op when construct the {c}")
|
405
|
+
constraints.append(c)
|
406
|
+
|
407
|
+
dual_solver = ps.CVXDualSolver(
|
408
|
+
perf_metric=self.performance_metric,
|
409
|
+
constraints=constraints,
|
410
|
+
context=context,
|
411
|
+
)
|
412
|
+
problem = dual_solver.build_problem(resolve_parameters=resolve_parameters)
|
413
|
+
result = problem.solve()
|
414
|
+
|
415
|
+
return DualPEPResult(
|
416
|
+
dual_opt_value=result,
|
417
|
+
primal_var_manager=dual_solver.primal_var_manager,
|
418
|
+
solver_status=problem.status,
|
419
|
+
context=context,
|
420
|
+
)
|
pepflow/pep_context.py
CHANGED
@@ -37,16 +37,37 @@ GLOBAL_CONTEXT_DICT: dict[str, PEPContext] = {}
|
|
37
37
|
|
38
38
|
|
39
39
|
def get_current_context() -> PEPContext | None:
|
40
|
+
"""
|
41
|
+
Return the current global :class:`PEPContext`.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
:class:`PEPContext`: The current global :class:`PEPContext`.
|
45
|
+
"""
|
40
46
|
return CURRENT_CONTEXT
|
41
47
|
|
42
48
|
|
43
49
|
def set_current_context(ctx: PEPContext | None):
|
50
|
+
"""
|
51
|
+
Change the current global :class:`PEPContext`.
|
52
|
+
|
53
|
+
Args:
|
54
|
+
ctx (:class:`PEPContext`): The :class:`PEPContext` to set as the new
|
55
|
+
global :class:`PEPContext`.
|
56
|
+
"""
|
44
57
|
global CURRENT_CONTEXT
|
45
58
|
assert ctx is None or isinstance(ctx, PEPContext)
|
46
59
|
CURRENT_CONTEXT = ctx
|
47
60
|
|
48
61
|
|
49
62
|
class PEPContext:
|
63
|
+
"""
|
64
|
+
A :class:`PEPContext` object is a context manager which maintains
|
65
|
+
the abstract mathematical objects of the Primal and Dual PEP.
|
66
|
+
|
67
|
+
Attributes:
|
68
|
+
name (str): The unique name of the :class:`PEPContext` object.
|
69
|
+
"""
|
70
|
+
|
50
71
|
def __init__(self, name: str):
|
51
72
|
self.name = name
|
52
73
|
self.points: list[Point] = []
|
@@ -58,6 +79,12 @@ class PEPContext:
|
|
58
79
|
GLOBAL_CONTEXT_DICT[name] = self
|
59
80
|
|
60
81
|
def set_as_current(self) -> PEPContext:
|
82
|
+
"""
|
83
|
+
Set this :class:`PEPContext` object as the global context.
|
84
|
+
|
85
|
+
Returns:
|
86
|
+
:class:`PEPContext`: This :class:`PEPContext` object.
|
87
|
+
"""
|
61
88
|
set_current_context(self)
|
62
89
|
return self
|
63
90
|
|
@@ -74,6 +101,18 @@ class PEPContext:
|
|
74
101
|
self.stationary_triplets[function].append(stationary_triplet)
|
75
102
|
|
76
103
|
def get_by_tag(self, tag: str) -> Point | Scalar:
|
104
|
+
"""
|
105
|
+
Under this :class:`PEPContext`, get the :class:`Point` or
|
106
|
+
:class:`Scalar` object associated with the provided `tag`.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
tag (str): The tag of the :class:`Point` or :class:`Scalar` object
|
110
|
+
we want to retrieve.
|
111
|
+
|
112
|
+
Returns:
|
113
|
+
:class:`Point` | :class:`Scalar`: The :class:`Point` or
|
114
|
+
:class:`Scalar` object associated with the provided `tag`.
|
115
|
+
"""
|
77
116
|
for p in self.points:
|
78
117
|
if tag in p.tags:
|
79
118
|
return p
|
@@ -83,22 +122,73 @@ class PEPContext:
|
|
83
122
|
raise ValueError("Cannot find the point, scalar, or function of given tag.")
|
84
123
|
|
85
124
|
def clear(self):
|
125
|
+
"""Reset this :class:`PEPContext` object."""
|
86
126
|
self.points.clear()
|
87
127
|
self.scalars.clear()
|
88
128
|
self.triplets.clear()
|
89
129
|
self.stationary_triplets.clear()
|
90
130
|
|
91
131
|
def tracked_point(self, func: Function) -> list[Point]:
|
132
|
+
"""
|
133
|
+
Each function :math:`f` used in Primal and Dual PEP is associated with
|
134
|
+
a set of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}` visited by
|
135
|
+
the considered algorithm. We can also consider a subgradient
|
136
|
+
:math:`\\widetilde{\\nabla} f(x)` instead of the gradient. This
|
137
|
+
function returns a list of the visited points :math:`\\{x_i\\}` under
|
138
|
+
this :class:`PEPContext`.
|
139
|
+
|
140
|
+
Args:
|
141
|
+
func (:class:`Function`): The function associated with the set
|
142
|
+
of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}`.
|
143
|
+
|
144
|
+
Returns:
|
145
|
+
list[:class:`Point`]: The list of the visited points
|
146
|
+
:math:`\\{x_i\\}`.
|
147
|
+
"""
|
92
148
|
return natsort.natsorted(
|
93
149
|
[t.point for t in self.triplets[func]], key=lambda x: x.tag
|
94
150
|
)
|
95
151
|
|
96
152
|
def tracked_grad(self, func: Function) -> list[Point]:
|
153
|
+
"""
|
154
|
+
Each function :math:`f` used in Primal and Dual PEP is associated with
|
155
|
+
a set of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}` visited by
|
156
|
+
the considered algorithm. We can also consider a subgradient
|
157
|
+
:math:`\\widetilde{\\nabla} f(x)` instead of the gradient
|
158
|
+
:math:`\\nabla f(x_i)`. This function returns a list of the visited
|
159
|
+
gradients :math:`\\{\\nabla f(x_i)\\}` under this :class:`PEPContext`.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
func (:class:`Function`): The function associated with the set
|
163
|
+
of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}`.
|
164
|
+
|
165
|
+
Returns:
|
166
|
+
list[:class:`Point`]: The list of the visited gradients
|
167
|
+
:math:`\\{\\nabla f(x_i)\\}`.
|
168
|
+
|
169
|
+
"""
|
97
170
|
return natsort.natsorted(
|
98
171
|
[t.gradient for t in self.triplets[func]], key=lambda x: x.tag
|
99
172
|
)
|
100
173
|
|
101
174
|
def tracked_func_value(self, func: Function) -> list[Scalar]:
|
175
|
+
"""
|
176
|
+
Each function :math:`f` used in Primal and Dual PEP is associated with
|
177
|
+
a set of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}` visited by
|
178
|
+
the considered algorithm. We can also consider a subgradient
|
179
|
+
:math:`\\widetilde{\\nabla} f(x)` instead of the gradient
|
180
|
+
:math:`\\nabla f(x_i)`. This function returns a list of the visited
|
181
|
+
function values :math:`\\{f(x_i)\\}` under this :class:`PEPContext`.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
func (:class:`Function`): The function associated with the set of
|
185
|
+
triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}`.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
list[:class:`Scalar`]: The list of the visited function values
|
189
|
+
:math:`\\{f(x_i)\\}`.
|
190
|
+
|
191
|
+
"""
|
102
192
|
return natsort.natsorted(
|
103
193
|
[t.function_value for t in self.triplets[func]], key=lambda x: x.tag
|
104
194
|
)
|
@@ -134,3 +224,29 @@ class PEPContext:
|
|
134
224
|
func_to_order[func] = order
|
135
225
|
|
136
226
|
return func_to_df, func_to_order
|
227
|
+
|
228
|
+
def basis_points(self) -> list[Point]:
|
229
|
+
"""
|
230
|
+
Return a list of the basis :class:`Point` objects managed by this
|
231
|
+
:class:`PEPContext`.
|
232
|
+
|
233
|
+
Returns:
|
234
|
+
list[:class:`Point`]: A list of the basis :class:`Point` objects
|
235
|
+
managed by this :class:`PEPContext`.
|
236
|
+
"""
|
237
|
+
return [
|
238
|
+
p for p in self.points if p.is_basis
|
239
|
+
] # Note the order is always the same as added time
|
240
|
+
|
241
|
+
def basis_scalars(self) -> list[Scalar]:
|
242
|
+
"""
|
243
|
+
Return a list of the basis :class:`Scalar` objects managed by this
|
244
|
+
:class:`PEPContext`.
|
245
|
+
|
246
|
+
Returns:
|
247
|
+
list[:class:`Scalar`]: A list of the basis :class:`Scalar` objects
|
248
|
+
managed by this :class:`PEPContext`.
|
249
|
+
"""
|
250
|
+
return [
|
251
|
+
s for s in self.scalars if s.is_basis
|
252
|
+
] # Note the order is always the same as added time
|
pepflow/pep_context_test.py
CHANGED
@@ -25,6 +25,7 @@ import pytest
|
|
25
25
|
from pepflow import pep_context as pc
|
26
26
|
from pepflow.function import SmoothConvexFunction
|
27
27
|
from pepflow.point import Point
|
28
|
+
from pepflow.scalar import Scalar
|
28
29
|
|
29
30
|
|
30
31
|
@pytest.fixture
|
@@ -102,3 +103,23 @@ def test_get_by_tag(pep_context: pc.PEPContext):
|
|
102
103
|
assert pep_context.get_by_tag("gradient_f(x1)") == triplet.gradient
|
103
104
|
assert pep_context.get_by_tag("x1+x2") == p3
|
104
105
|
pc.set_current_context(None)
|
106
|
+
|
107
|
+
|
108
|
+
def test_basis_points(pep_context: pc.PEPContext):
|
109
|
+
p1 = Point(is_basis=True, tags=["x1"])
|
110
|
+
p2 = Point(is_basis=True, tags=["x2"])
|
111
|
+
_ = p1 + p2 # not basis
|
112
|
+
ps = Point(is_basis=True, tags=["x_star"])
|
113
|
+
p0 = Point(is_basis=True, tags=["x0"])
|
114
|
+
|
115
|
+
assert pep_context.basis_points() == [p1, p2, ps, p0]
|
116
|
+
|
117
|
+
|
118
|
+
def test_basis_scalars(pep_context: pc.PEPContext):
|
119
|
+
p1 = Point(is_basis=True, tags=["x1"])
|
120
|
+
p2 = Point(is_basis=True, tags=["x2"])
|
121
|
+
_ = p1 * p2 # not basis
|
122
|
+
s1 = Scalar(is_basis=True, tags=["s2"])
|
123
|
+
s2 = Scalar(is_basis=True, tags=["s1"])
|
124
|
+
|
125
|
+
assert pep_context.basis_scalars() == [s1, s2]
|