pepflow 0.1.4__py3-none-any.whl → 0.1.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +1 -0
- pepflow/constraint_test.py +71 -0
- pepflow/e2e_test.py +37 -2
- pepflow/expression_manager.py +72 -2
- pepflow/expression_manager_test.py +116 -0
- pepflow/function.py +115 -43
- pepflow/function_test.py +180 -114
- pepflow/interactive_constraint.py +165 -75
- pepflow/pep.py +9 -2
- pepflow/pep_context.py +6 -6
- pepflow/pep_context_test.py +4 -0
- pepflow/pep_test.py +8 -0
- pepflow/point_test.py +31 -191
- pepflow/scalar_test.py +91 -134
- {pepflow-0.1.4.dist-info → pepflow-0.1.4a1.dist-info}/METADATA +19 -1
- pepflow-0.1.4a1.dist-info/RECORD +26 -0
- pepflow-0.1.4.dist-info/RECORD +0 -24
- {pepflow-0.1.4.dist-info → pepflow-0.1.4a1.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.4a1.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.4a1.dist-info}/top_level.txt +0 -0
pepflow/__init__.py
CHANGED
@@ -35,6 +35,7 @@ from .pep_context import set_current_context as set_current_context
|
|
35
35
|
# Function, Point, Scalar
|
36
36
|
from .function import Function as Function
|
37
37
|
from .function import SmoothConvexFunction as SmoothConvexFunction
|
38
|
+
from .function import ConvexFunction as ConvexFunction
|
38
39
|
from .function import Triplet as Triplet
|
39
40
|
from .point import EvaluatedPoint as EvaluatedPoint
|
40
41
|
from .point import Point as Point
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
|
21
|
+
from typing import Iterator
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
from pepflow import expression_manager as exm
|
27
|
+
from pepflow import pep as pep
|
28
|
+
from pepflow import pep_context as pc
|
29
|
+
from pepflow import scalar, utils
|
30
|
+
|
31
|
+
|
32
|
+
@pytest.fixture
|
33
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
34
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
35
|
+
ctx = pc.PEPContext("test").set_as_current()
|
36
|
+
yield ctx
|
37
|
+
pc.set_current_context(None)
|
38
|
+
|
39
|
+
|
40
|
+
def test_constraint(pep_context: pc.PEPContext):
    """Check the scalar and comparator stored by each constraint builder.

    The expression ``2*s1 + s2/4 + 5`` is compared against ``5`` with every
    builder (``le``, ``lt``, ``ge``, ``gt``, ``eq``). In each case the
    constraint's scalar must evaluate to the vector ``[2, 0.25]`` with a zero
    constant, and the comparator must be the one listed next to it below.
    """
    s1 = scalar.Scalar(is_basis=True, tags=["s1"])
    s2 = scalar.Scalar(is_basis=True, tags=["s2"])
    s3 = 2 * s1 + s2 / 4 + 5

    # All constraints are created, in this order, before the manager is built.
    constraints_and_comparators = [
        (s3.le(5, name="c1"), utils.Comparator.LT),
        (s3.lt(5, name="c2"), utils.Comparator.LT),
        (s3.ge(5, name="c3"), utils.Comparator.GT),
        (s3.gt(5, name="c4"), utils.Comparator.GT),
        (s3.eq(5, name="c5"), utils.Comparator.EQ),
    ]

    pm = exm.ExpressionManager(pep_context)

    for constraint, expected_comparator in constraints_and_comparators:
        np.testing.assert_allclose(
            pm.eval_scalar(constraint.scalar).vector, np.array([2, 0.25])
        )
        np.testing.assert_allclose(pm.eval_scalar(constraint.scalar).constant, 0)
        assert constraint.comparator == expected_comparator
|
pepflow/e2e_test.py
CHANGED
@@ -10,8 +10,7 @@ def test_gd_e2e():
|
|
10
10
|
eta = 1
|
11
11
|
N = 9
|
12
12
|
|
13
|
-
f = pep_builder.declare_func(function.SmoothConvexFunction, L=1)
|
14
|
-
f.add_tag("f")
|
13
|
+
f = pep_builder.declare_func(function.SmoothConvexFunction, "f", L=1)
|
15
14
|
x = pep_builder.set_init_point("x_0")
|
16
15
|
x_star = f.add_stationary_point("x_star")
|
17
16
|
pep_builder.set_initial_constraint(
|
@@ -32,3 +31,39 @@ def test_gd_e2e():
|
|
32
31
|
result = pep_builder.solve()
|
33
32
|
expected_opt_value = 1 / (4 * i + 2)
|
34
33
|
assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
|
34
|
+
|
35
|
+
|
36
|
+
def test_pgm_e2e():
|
37
|
+
ctx = pc.PEPContext("pgm").set_as_current()
|
38
|
+
pep_builder = pep.PEPBuilder()
|
39
|
+
eta = 1
|
40
|
+
N = 1
|
41
|
+
|
42
|
+
f = pep_builder.declare_func(function.SmoothConvexFunction, "f", L=1)
|
43
|
+
g = pep_builder.declare_func(function.ConvexFunction, "g")
|
44
|
+
|
45
|
+
h = f + g
|
46
|
+
|
47
|
+
x = pep_builder.set_init_point("x_0")
|
48
|
+
x_star = h.add_stationary_point("x_star")
|
49
|
+
pep_builder.set_initial_constraint(
|
50
|
+
((x - x_star) ** 2).le(1, name="initial_condition")
|
51
|
+
)
|
52
|
+
|
53
|
+
# We first build the algorithm with the largest number of iterations.
|
54
|
+
for i in range(N):
|
55
|
+
y = x - eta * f.gradient(x)
|
56
|
+
y.add_tag(f"y_{i + 1}")
|
57
|
+
x = g.proximal_step(y, eta)
|
58
|
+
x.add_tag(f"x_{i + 1}")
|
59
|
+
|
60
|
+
# To achieve the sweep, we can just update the performance_metric.
|
61
|
+
for i in range(1, N + 1):
|
62
|
+
p = ctx.get_by_tag(f"x_{i}")
|
63
|
+
pep_builder.set_performance_metric(
|
64
|
+
h.function_value(p) - h.function_value(x_star)
|
65
|
+
)
|
66
|
+
|
67
|
+
result = pep_builder.solve()
|
68
|
+
expected_opt_value = 1 / (4 * i)
|
69
|
+
assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
|
pepflow/expression_manager.py
CHANGED
@@ -18,6 +18,7 @@
|
|
18
18
|
# under the License.
|
19
19
|
|
20
20
|
import functools
|
21
|
+
import math
|
21
22
|
|
22
23
|
import numpy as np
|
23
24
|
|
@@ -27,6 +28,17 @@ from pepflow import scalar as sc
|
|
27
28
|
from pepflow import utils
|
28
29
|
|
29
30
|
|
31
|
+
def tag_and_coef_to_str(tag: str, v: float, zero_tol: float = 1e-9) -> str:
    """Format one signed ``coefficient*tag`` term of a linear combination.

    Args:
        tag: Name of the basis element the coefficient multiplies.
        v: The coefficient.
        zero_tol: Absolute tolerance below which ``v`` is treated as zero and
            the term is omitted. Pass ``0.0`` to drop only exact zeros.

    Returns:
        ``""`` when the coefficient is (numerically) zero, ``"<sign> <tag> "``
        when its magnitude is 1, and ``"<sign> <coef>*<tag> "`` otherwise,
        where ``<coef>`` is the magnitude formatted with ``.3g``. The trailing
        space lets consecutive terms be concatenated directly.
    """
    # math.isclose(v, 0) with the default abs_tol=0.0 is only true for an
    # exact zero, so floating-point residue (e.g. 1e-17) would be printed as a
    # term. An explicit absolute tolerance is required when comparing to zero.
    if math.isclose(v, 0.0, abs_tol=zero_tol):
        return ""

    sign = "+" if v >= 0 else "-"
    magnitude = abs(v)
    if math.isclose(magnitude, 1):
        # A unit coefficient is implied by the tag alone.
        return f"{sign} {tag} "
    return f"{sign} {magnitude:.3g}*{tag} "
|
40
|
+
|
41
|
+
|
30
42
|
class ExpressionManager:
|
31
43
|
def __init__(self, pep_context: pc.PEPContext):
|
32
44
|
self.context = pep_context
|
@@ -48,12 +60,18 @@ class ExpressionManager:
|
|
48
60
|
self._num_basis_points = len(self._basis_points)
|
49
61
|
self._num_basis_scalars = len(self._basis_scalars)
|
50
62
|
|
51
|
-
def get_index_of_basis_point(self, point: pt.Point):
|
63
|
+
def get_index_of_basis_point(self, point: pt.Point) -> int:
|
52
64
|
return self._basis_point_uid_to_index[point.uid]
|
53
65
|
|
54
|
-
def get_index_of_basis_scalar(self, scalar: sc.Scalar):
|
66
|
+
def get_index_of_basis_scalar(self, scalar: sc.Scalar) -> int:
|
55
67
|
return self._basis_scalar_uid_to_index[scalar.uid]
|
56
68
|
|
69
|
+
def get_tag_of_basis_point_index(self, index: int) -> str:
|
70
|
+
return self._basis_points[index].tag
|
71
|
+
|
72
|
+
def get_tag_of_basis_scalar_index(self, index: int) -> str:
|
73
|
+
return self._basis_scalars[index].tag
|
74
|
+
|
57
75
|
@functools.cache
|
58
76
|
def eval_point(self, point: pt.Point | float | int):
|
59
77
|
if utils.is_numerical(point):
|
@@ -130,3 +148,55 @@ class ExpressionManager:
|
|
130
148
|
) / self.eval_scalar(scalar.eval_expression.right_scalar)
|
131
149
|
|
132
150
|
raise ValueError("This should never happen!")
|
151
|
+
|
152
|
+
@functools.cache
def repr_point_by_basis(self, point: pt.Point) -> str:
    """Return ``point`` written as a signed sum over the basis-point tags.

    The coefficient vector comes from ``self.eval_point(point).vector``; the
    entry at position ``k`` is paired with the tag of the ``k``-th basis
    point. Returns ``"0"`` when every term renders as empty.
    """
    assert isinstance(point, pt.Point)
    coefficients = self.eval_point(point).vector

    text = "".join(
        tag_and_coef_to_str(self.get_tag_of_basis_point_index(idx), coef)
        for idx, coef in enumerate(coefficients)
    )

    # Tidy the leading sign: drop a leading "+ ", and glue a leading "- " onto
    # the first term.
    if text == "":
        return "0"
    if text.startswith("+ "):
        text = text[2:]
    if text.startswith("- "):
        text = "-" + text[2:]
    return text.strip()
|
170
|
+
|
171
|
+
@functools.cache
def repr_scalar_by_basis(self, scalar: sc.Scalar) -> str:
    """Return ``scalar`` written as a signed sum over the basis tags.

    The evaluated scalar contributes three groups of terms, in this order:
    its constant, one term per basis scalar (from ``.vector``), and one term
    per pair of basis points (from the upper triangle of ``.matrix``).
    Diagonal entries render as ``|p|^2``; an off-diagonal entry ``(i, j)``
    renders as ``<p_i, p_j>`` with twice the stored value.

    Returns:
        The formatted expression, or ``"0"`` when no term is printed.
    """
    assert isinstance(scalar, sc.Scalar)
    evaluated_scalar = self.eval_scalar(scalar)

    repr_str = ""
    # Comparing to zero needs an absolute tolerance: with the default
    # abs_tol=0.0, isclose is only true for an exact zero.
    if not math.isclose(evaluated_scalar.constant, 0, abs_tol=1e-9):
        # The trailing space keeps the constant from fusing with the next
        # term, which starts with its sign ("5 + 2*s1", not "5+ 2*s1").
        repr_str += f"{evaluated_scalar.constant:.3g} "

    for i, v in enumerate(evaluated_scalar.vector):
        # Note the tag is from scalar basis.
        ith_tag = self.get_tag_of_basis_scalar_index(i)
        repr_str += tag_and_coef_to_str(ith_tag, v)

    num_points = evaluated_scalar.matrix.shape[0]
    for i in range(num_points):
        ith_tag = self.get_tag_of_basis_point_index(i)
        for j in range(i, num_points):
            v = evaluated_scalar.matrix[i, j]
            if i == j:
                repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
                continue
            jth_tag = self.get_tag_of_basis_point_index(j)
            # Only the upper triangle is visited, so the (i, j) entry is
            # doubled to stand in for (j, i) as well.
            # NOTE(review): this presumes the matrix is symmetric - confirm.
            repr_str += tag_and_coef_to_str(f"<{ith_tag}, {jth_tag}>", 2 * v)

    # Post processing: tidy the leading sign of the first printed term.
    if repr_str == "":
        return "0"
    if repr_str.startswith("+ "):
        repr_str = repr_str[2:]
    if repr_str.startswith("- "):
        repr_str = "-" + repr_str[2:]
    return repr_str.strip()
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
|
21
|
+
from typing import Iterator
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
from pepflow import expression_manager as exm
|
27
|
+
from pepflow import function as fc
|
28
|
+
from pepflow import pep as pep
|
29
|
+
from pepflow import pep_context as pc
|
30
|
+
from pepflow import point as pt
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.fixture
|
34
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
35
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
36
|
+
ctx = pc.PEPContext("test").set_as_current()
|
37
|
+
yield ctx
|
38
|
+
pc.set_current_context(None)
|
39
|
+
|
40
|
+
|
41
|
+
def test_repr_point_by_basis(pep_context: pc.PEPContext) -> None:
|
42
|
+
x = pt.Point(is_basis=True, tags=["x_0"])
|
43
|
+
f = fc.Function(is_basis=True, tags=["f"])
|
44
|
+
L = 0.5
|
45
|
+
for i in range(2):
|
46
|
+
x = x - L * f.gradient(x)
|
47
|
+
x.add_tag(f"x_{i + 1}")
|
48
|
+
|
49
|
+
em = exm.ExpressionManager(pep_context)
|
50
|
+
np.testing.assert_allclose(em.eval_point(x).vector, [1, -0.5, -0.5])
|
51
|
+
assert (
|
52
|
+
em.repr_point_by_basis(x) == "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
|
53
|
+
)
|
54
|
+
|
55
|
+
|
56
|
+
def test_repr_point_by_basis_with_zero(pep_context: pc.PEPContext) -> None:
|
57
|
+
x = pt.Point(is_basis=True, tags=["x_0"])
|
58
|
+
_ = pt.Point(is_basis=True, tags=["x_unused"]) # Add this extra point.
|
59
|
+
f = fc.Function(is_basis=True, tags=["f"])
|
60
|
+
L = 0.5
|
61
|
+
for i in range(2):
|
62
|
+
x = x - L * f.gradient(x)
|
63
|
+
x.add_tag(f"x_{i + 1}")
|
64
|
+
|
65
|
+
em = exm.ExpressionManager(pep_context)
|
66
|
+
# Note the vector representation of point is different from previous case
|
67
|
+
# But the string representation is still the same.
|
68
|
+
np.testing.assert_allclose(em.eval_point(x).vector, [1, 0, -0.5, -0.5])
|
69
|
+
assert (
|
70
|
+
em.repr_point_by_basis(x) == "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
|
71
|
+
)
|
72
|
+
|
73
|
+
|
74
|
+
def test_repr_point_by_basis_heavy_ball(pep_context: pc.PEPContext) -> None:
|
75
|
+
x_prev = pt.Point(is_basis=True, tags=["x_{-1}"])
|
76
|
+
x = pt.Point(is_basis=True, tags=["x_0"])
|
77
|
+
f = fc.Function(is_basis=True, tags=["f"])
|
78
|
+
|
79
|
+
beta = 0.5
|
80
|
+
for i in range(2):
|
81
|
+
x_next = x - f.gradient(x) + beta * (x - x_prev)
|
82
|
+
x_next.add_tag(f"x_{i + 1}")
|
83
|
+
x_prev = x
|
84
|
+
x = x_next
|
85
|
+
|
86
|
+
em = exm.ExpressionManager(pep_context)
|
87
|
+
np.testing.assert_allclose(em.eval_point(x).vector, [-0.75, 1.75, -1.5, -1])
|
88
|
+
assert (
|
89
|
+
em.repr_point_by_basis(x)
|
90
|
+
== "-0.75*x_{-1} + 1.75*x_0 - 1.5*gradient_f(x_0) - gradient_f(x_1)"
|
91
|
+
)
|
92
|
+
|
93
|
+
|
94
|
+
def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
|
95
|
+
x = pt.Point(is_basis=True, tags=["x"])
|
96
|
+
f = fc.Function(is_basis=True, tags=["f"])
|
97
|
+
|
98
|
+
s = f(x) + x * f.gradient(x)
|
99
|
+
em = exm.ExpressionManager(pep_context)
|
100
|
+
assert em.repr_scalar_by_basis(s) == "f(x) + <x, gradient_f(x)>"
|
101
|
+
|
102
|
+
|
103
|
+
def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
|
104
|
+
xi = pt.Point(is_basis=True, tags=["x_i"])
|
105
|
+
xj = pt.Point(is_basis=True, tags=["x_j"])
|
106
|
+
f = fc.SmoothConvexFunction(is_basis=True, L=1)
|
107
|
+
f.add_tag("f")
|
108
|
+
fi = f(xi) # noqa: F841
|
109
|
+
fj = f(xj) # noqa: F841
|
110
|
+
interp_scalar = f.interpolate_ineq("x_i", "x_j")
|
111
|
+
em = exm.ExpressionManager(pep_context)
|
112
|
+
expected_repr = "-f(x_i) + f(x_j) + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)> + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)> + 0.5*|gradient_f(x_j)|^2"
|
113
|
+
assert em.repr_scalar_by_basis(interp_scalar) == expected_repr
|
114
|
+
|
115
|
+
|
116
|
+
# TODO add more tests about repr_scalar_by_basis
|
pepflow/function.py
CHANGED
@@ -19,7 +19,9 @@
|
|
19
19
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
|
+
import numbers
|
22
23
|
import uuid
|
24
|
+
from typing import TYPE_CHECKING
|
23
25
|
|
24
26
|
import attrs
|
25
27
|
|
@@ -28,6 +30,9 @@ from pepflow import point as pt
|
|
28
30
|
from pepflow import scalar as sc
|
29
31
|
from pepflow import utils
|
30
32
|
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from pepflow.constraint import Constraint
|
35
|
+
|
31
36
|
|
32
37
|
@attrs.frozen
|
33
38
|
class Triplet:
|
@@ -57,7 +62,6 @@ class ScaledFunc:
|
|
57
62
|
@attrs.mutable
|
58
63
|
class Function:
|
59
64
|
is_basis: bool
|
60
|
-
reuse_gradient: bool
|
61
65
|
|
62
66
|
composition: AddedFunc | ScaledFunc | None = None
|
63
67
|
|
@@ -158,11 +162,19 @@ class Function:
|
|
158
162
|
|
159
163
|
def add_stationary_point(self, name: str) -> pt.Point:
|
160
164
|
# assert we can only add one stationary point?
|
165
|
+
pep_context = pc.get_current_context()
|
166
|
+
if pep_context is None:
|
167
|
+
raise RuntimeError("Did you forget to create a context?")
|
168
|
+
if len(pep_context.stationary_triplets[self]) > 0:
|
169
|
+
raise ValueError(
|
170
|
+
"You are trying to add a stationary point to a function that already has a stationary point."
|
171
|
+
)
|
161
172
|
point = pt.Point(is_basis=True)
|
162
173
|
point.add_tag(name)
|
163
174
|
desired_grad = 0 * point
|
164
175
|
desired_grad.add_tag(f"gradient_{self.tag}({name})")
|
165
|
-
self.add_point_with_grad_restriction(point, desired_grad)
|
176
|
+
triplet = self.add_point_with_grad_restriction(point, desired_grad)
|
177
|
+
pep_context.add_stationary_triplet(self, triplet)
|
166
178
|
return point
|
167
179
|
|
168
180
|
# The following the old gradient(opt) = 0 constraint style.
|
@@ -184,42 +196,22 @@ class Function:
|
|
184
196
|
raise RuntimeError("Did you forget to create a context?")
|
185
197
|
|
186
198
|
if self.is_basis:
|
187
|
-
generate_new_basis = True
|
188
|
-
instances_of_point = 0
|
189
199
|
for triplet in pep_context.triplets[self]:
|
190
200
|
if triplet.point.uid == point.uid:
|
191
|
-
|
192
|
-
generate_new_basis = False
|
193
|
-
previous_triplet = triplet
|
201
|
+
return triplet
|
194
202
|
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
gradient.add_tag(f"gradient_{self.tag}({point.tag})")
|
203
|
+
function_value = sc.Scalar(is_basis=True)
|
204
|
+
function_value.add_tag(f"{self.tag}({point.tag})")
|
205
|
+
gradient = pt.Point(is_basis=True)
|
206
|
+
gradient.add_tag(f"gradient_{self.tag}({point.tag})")
|
200
207
|
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
elif not generate_new_basis and self.reuse_gradient:
|
209
|
-
function_value = previous_triplet.function_value
|
210
|
-
gradient = previous_triplet.gradient
|
211
|
-
elif not generate_new_basis and not self.reuse_gradient:
|
212
|
-
function_value = previous_triplet.function_value
|
213
|
-
gradient = pt.Point(is_basis=True)
|
214
|
-
gradient.add_tag(f"gradient_{self.tag}({point.tag})")
|
215
|
-
|
216
|
-
new_triplet = Triplet(
|
217
|
-
point,
|
218
|
-
previous_triplet.function_value,
|
219
|
-
gradient,
|
220
|
-
name=f"{point.tag}_{function_value.tag}_{gradient.tag}_{instances_of_point}",
|
221
|
-
)
|
222
|
-
self.add_triplet_to_func(new_triplet)
|
208
|
+
new_triplet = Triplet(
|
209
|
+
point,
|
210
|
+
function_value,
|
211
|
+
gradient,
|
212
|
+
name=f"{point.tag}_{function_value.tag}_{gradient.tag}",
|
213
|
+
)
|
214
|
+
self.add_triplet_to_func(new_triplet)
|
223
215
|
else:
|
224
216
|
if isinstance(self.composition, AddedFunc):
|
225
217
|
left_triplet = self.composition.left_func.generate_triplet(point)
|
@@ -258,7 +250,6 @@ class Function:
|
|
258
250
|
assert isinstance(other, Function)
|
259
251
|
return Function(
|
260
252
|
is_basis=False,
|
261
|
-
reuse_gradient=self.reuse_gradient and other.reuse_gradient,
|
262
253
|
composition=AddedFunc(self, other),
|
263
254
|
tags=[f"{self.tag}+{other.tag}"],
|
264
255
|
)
|
@@ -270,7 +261,6 @@ class Function:
|
|
270
261
|
tag_other = f"({other.tag})"
|
271
262
|
return Function(
|
272
263
|
is_basis=False,
|
273
|
-
reuse_gradient=self.reuse_gradient and other.reuse_gradient,
|
274
264
|
composition=AddedFunc(self, -other),
|
275
265
|
tags=[f"{self.tag}-{tag_other}"],
|
276
266
|
)
|
@@ -282,7 +272,6 @@ class Function:
|
|
282
272
|
tag_self = f"({self.tag})"
|
283
273
|
return Function(
|
284
274
|
is_basis=False,
|
285
|
-
reuse_gradient=self.reuse_gradient,
|
286
275
|
composition=ScaledFunc(scale=other, base_func=self),
|
287
276
|
tags=[f"{other:.4g}*{tag_self}"],
|
288
277
|
)
|
@@ -294,7 +283,6 @@ class Function:
|
|
294
283
|
tag_self = f"({self.tag})"
|
295
284
|
return Function(
|
296
285
|
is_basis=False,
|
297
|
-
reuse_gradient=self.reuse_gradient,
|
298
286
|
composition=ScaledFunc(scale=other, base_func=self),
|
299
287
|
tags=[f"{other:.4g}*{tag_self}"],
|
300
288
|
)
|
@@ -305,7 +293,6 @@ class Function:
|
|
305
293
|
tag_self = f"({self.tag})"
|
306
294
|
return Function(
|
307
295
|
is_basis=False,
|
308
|
-
reuse_gradient=self.reuse_gradient,
|
309
296
|
composition=ScaledFunc(scale=-1, base_func=self),
|
310
297
|
tags=[f"-{tag_self}"],
|
311
298
|
)
|
@@ -317,7 +304,6 @@ class Function:
|
|
317
304
|
tag_self = f"({self.tag})"
|
318
305
|
return Function(
|
319
306
|
is_basis=False,
|
320
|
-
reuse_gradient=self.reuse_gradient,
|
321
307
|
composition=ScaledFunc(scale=1 / other, base_func=self),
|
322
308
|
tags=[f"1/{other:.4g}*{tag_self}"],
|
323
309
|
)
|
@@ -331,10 +317,96 @@ class Function:
|
|
331
317
|
return self.uid == other.uid
|
332
318
|
|
333
319
|
|
320
|
+
class ConvexFunction(Function):
    """A ``Function`` constrained by pairwise convexity inequalities.

    For every ordered pair of distinct triplets ``(x_i, f_i, g_i)`` and
    ``(x_j, f_j, g_j)`` recorded for this function, the class produces the
    constraint ``f_j - f_i + <g_j, x_i - x_j> <= 0``.
    """

    def __init__(self, is_basis=True, composition=None):
        super().__init__(is_basis=is_basis, composition=composition)

    def convex_interpolability_constraints(
        self, triplet_i: Triplet, triplet_j: Triplet
    ) -> Constraint:
        """Build ``f_j - f_i + <g_j, x_i - x_j> <= 0`` for one pair of triplets.

        The constraint is named ``"<function tag>:<x_i tag>,<x_j tag>"``.
        """
        x_i, f_i = triplet_i.point, triplet_i.function_value
        x_j, f_j, g_j = (
            triplet_j.point,
            triplet_j.function_value,
            triplet_j.gradient,
        )

        value_gap = f_j - f_i
        inner_product = g_j * (x_i - x_j)
        lhs = value_gap + inner_product

        return lhs.le(0, name=f"{self.tag}:{x_i.tag},{x_j.tag}")

    def get_interpolation_constraints(
        self, pep_context: pc.PEPContext | None = None
    ) -> list[Constraint]:
        """Return one constraint per ordered pair of distinct triplets.

        Uses the current context when ``pep_context`` is not given.

        Raises:
            RuntimeError: If no context is given and none is current.
        """
        if pep_context is None:
            pep_context = pc.get_current_context()
        if pep_context is None:
            raise RuntimeError("Did you forget to create a context?")

        return [
            self.convex_interpolability_constraints(first, second)
            for first in pep_context.triplets[self]
            for second in pep_context.triplets[self]
            if not first == second
        ]

    def interpolate_ineq(
        self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
    ) -> sc.Scalar:
        """Generate the interpolation inequality scalar by tags.

        Looks up the two points, their function values and the gradient at
        the second point by tag, then returns ``f2 - f1 + g2 * (x1 - x2)``.

        Raises:
            RuntimeError: If no context is given and none is current.
        """
        context = pep_context
        if context is None:
            context = pc.get_current_context()
        if context is None:
            raise RuntimeError("Did you forget to specify a context?")

        # TODO: we definitely need a more robust tag system
        x1 = context.get_by_tag(p1_tag)
        x2 = context.get_by_tag(p2_tag)
        f1 = context.get_by_tag(f"{self.tag}({p1_tag})")
        f2 = context.get_by_tag(f"{self.tag}({p2_tag})")
        g2 = context.get_by_tag(f"gradient_{self.tag}({p2_tag})")

        value_gap = f2 - f1
        inner_product = g2 * (x1 - x2)
        return value_gap + inner_product

    def proximal_step(self, x_0: pt.Point, stepsize: numbers.Number) -> pt.Point:
        """Return ``x_0 - stepsize * g``, where ``g`` is a fresh basis point.

        A new basis point ``g`` (tagged as this function's gradient at the
        prox point) and a new basis scalar (tagged as its value there) are
        created. The resulting point, value and gradient are registered
        together as a triplet of this function.
        """
        grad_at_prox = pt.Point(is_basis=True)
        grad_at_prox.add_tag(
            f"gradient_{self.tag}(prox_{{{stepsize}*{self.tag}}}({x_0.tag}))"
        )

        value_at_prox = sc.Scalar(is_basis=True)
        value_at_prox.add_tag(
            f"{self.tag}(prox_{{{stepsize}*{self.tag}}}({x_0.tag}))"
        )

        prox_point = x_0 - stepsize * grad_at_prox
        prox_point.add_tag(f"prox_{{{stepsize}*{self.tag}}}({x_0.tag})")

        self.add_triplet_to_func(
            Triplet(
                prox_point,
                value_at_prox,
                grad_at_prox,
                name=f"{prox_point.tag}_{value_at_prox.tag}_{grad_at_prox.tag}",
            )
        )
        return prox_point
|
398
|
+
|
399
|
+
|
334
400
|
class SmoothConvexFunction(Function):
|
335
|
-
def __init__(
|
401
|
+
def __init__(
|
402
|
+
self,
|
403
|
+
L,
|
404
|
+
is_basis=True,
|
405
|
+
composition=None,
|
406
|
+
):
|
336
407
|
super().__init__(
|
337
|
-
is_basis=is_basis,
|
408
|
+
is_basis=is_basis,
|
409
|
+
composition=composition,
|
338
410
|
)
|
339
411
|
self.L = L
|
340
412
|
|
@@ -372,7 +444,7 @@ class SmoothConvexFunction(Function):
|
|
372
444
|
|
373
445
|
def interpolate_ineq(
|
374
446
|
self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
|
375
|
-
) ->
|
447
|
+
) -> sc.Scalar:
|
376
448
|
"""Generate the interpolation inequality scalar by tags."""
|
377
449
|
if pep_context is None:
|
378
450
|
pep_context = pc.get_current_context()
|