pepflow 0.1.3a1__py3-none-any.whl → 0.1.4a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +1 -0
- pepflow/constraint_test.py +71 -0
- pepflow/e2e_test.py +69 -0
- pepflow/expression_manager.py +72 -2
- pepflow/expression_manager_test.py +116 -0
- pepflow/function.py +142 -48
- pepflow/function_test.py +249 -108
- pepflow/interactive_constraint.py +165 -75
- pepflow/pep.py +18 -3
- pepflow/pep_context.py +12 -7
- pepflow/pep_context_test.py +23 -21
- pepflow/pep_test.py +8 -0
- pepflow/point.py +43 -8
- pepflow/point_test.py +106 -308
- pepflow/scalar.py +39 -1
- pepflow/scalar_test.py +207 -0
- pepflow/solver_test.py +7 -7
- pepflow/utils.py +14 -1
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/METADATA +19 -1
- pepflow-0.1.4a1.dist-info/RECORD +26 -0
- pepflow-0.1.3a1.dist-info/RECORD +0 -22
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/WHEEL +0 -0
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.3a1.dist-info → pepflow-0.1.4a1.dist-info}/top_level.txt +0 -0
pepflow/__init__.py
CHANGED
@@ -35,6 +35,7 @@ from .pep_context import set_current_context as set_current_context
|
|
35
35
|
# Function, Point, Scalar
|
36
36
|
from .function import Function as Function
|
37
37
|
from .function import SmoothConvexFunction as SmoothConvexFunction
|
38
|
+
from .function import ConvexFunction as ConvexFunction
|
38
39
|
from .function import Triplet as Triplet
|
39
40
|
from .point import EvaluatedPoint as EvaluatedPoint
|
40
41
|
from .point import Point as Point
|
pepflow/constraint_test.py
ADDED
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
|
21
|
+
from typing import Iterator
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
from pepflow import expression_manager as exm
|
27
|
+
from pepflow import pep as pep
|
28
|
+
from pepflow import pep_context as pc
|
29
|
+
from pepflow import scalar, utils
|
30
|
+
|
31
|
+
|
32
|
+
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Yield a fresh "test" context set as current; clear the current context afterwards."""
    context = pc.PEPContext("test")
    active_context = context.set_as_current()
    yield active_context
    # Teardown: leave no current context behind for the next test.
    pc.set_current_context(None)
|
38
|
+
|
39
|
+
|
40
|
+
def test_constraint(pep_context: pc.PEPContext):
    """Check the scalar and comparator stored by each constraint builder.

    Every constraint is built from ``2*s1 + s2/4 + 5`` compared against 5, so
    the evaluated constraint scalar is expected to have vector ``[2, 0.25]``
    and constant ``0`` in all cases; only the comparator differs.
    """
    first = scalar.Scalar(is_basis=True, tags=["s1"])
    second = scalar.Scalar(is_basis=True, tags=["s2"])
    combined = 2 * first + second / 4 + 5

    # (constraint, comparator the test expects it to carry)
    cases = [
        (combined.le(5, name="c1"), utils.Comparator.LT),
        (combined.lt(5, name="c2"), utils.Comparator.LT),
        (combined.ge(5, name="c3"), utils.Comparator.GT),
        (combined.gt(5, name="c4"), utils.Comparator.GT),
        (combined.eq(5, name="c5"), utils.Comparator.EQ),
    ]

    manager = exm.ExpressionManager(pep_context)
    expected_vector = np.array([2, 0.25])

    for constraint, expected_comparator in cases:
        evaluated = manager.eval_scalar(constraint.scalar)
        np.testing.assert_allclose(evaluated.vector, expected_vector)
        np.testing.assert_allclose(evaluated.constant, 0)
        assert constraint.comparator == expected_comparator
|
pepflow/e2e_test.py
ADDED
@@ -0,0 +1,69 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
from pepflow import function, pep
|
4
|
+
from pepflow import pep_context as pc
|
5
|
+
|
6
|
+
|
7
|
+
def test_gd_e2e():
    """End-to-end check of gradient descent on a smooth convex function.

    Builds ``N`` gradient steps with step size ``eta`` once, then re-solves the
    problem for every iterate ``x_1 .. x_N`` by swapping only the performance
    metric. Each optimal value is compared with ``1 / (4 * i + 2)``.
    """
    ctx = pc.PEPContext("gd").set_as_current()
    try:
        pep_builder = pep.PEPBuilder()
        eta = 1
        N = 9

        f = pep_builder.declare_func(function.SmoothConvexFunction, "f", L=1)
        x = pep_builder.set_init_point("x_0")
        x_star = f.add_stationary_point("x_star")
        pep_builder.set_initial_constraint(
            ((x - x_star) ** 2).le(1, name="initial_condition")
        )

        # We first build the algorithm with the largest number of iterations.
        for i in range(N):
            x = x - eta * f.gradient(x)
            x.add_tag(f"x_{i + 1}")

        # To achieve the sweep, we can just update the performance_metric.
        for i in range(1, N + 1):
            p = ctx.get_by_tag(f"x_{i}")
            pep_builder.set_performance_metric(
                f.function_value(p) - f.function_value(x_star)
            )
            result = pep_builder.solve()
            expected_opt_value = 1 / (4 * i + 2)
            assert math.isclose(
                result.primal_opt_value, expected_opt_value, rel_tol=1e-3
            )
    finally:
        # The current context is global state; clear it even when an assertion
        # fails so later tests do not silently reuse the "gd" context.
        pc.set_current_context(None)
|
34
|
+
|
35
|
+
|
36
|
+
def test_pgm_e2e():
    """End-to-end check of the proximal gradient method on ``h = f + g``.

    ``f`` is declared as a smooth convex function and ``g`` as a convex
    function. Each iteration takes a gradient step on ``f`` followed by a
    proximal step on ``g``. The optimal value for iterate ``x_i`` is compared
    with ``1 / (4 * i)``.
    """
    ctx = pc.PEPContext("pgm").set_as_current()
    try:
        pep_builder = pep.PEPBuilder()
        eta = 1
        N = 1

        f = pep_builder.declare_func(function.SmoothConvexFunction, "f", L=1)
        g = pep_builder.declare_func(function.ConvexFunction, "g")

        h = f + g

        x = pep_builder.set_init_point("x_0")
        x_star = h.add_stationary_point("x_star")
        pep_builder.set_initial_constraint(
            ((x - x_star) ** 2).le(1, name="initial_condition")
        )

        # We first build the algorithm with the largest number of iterations.
        for i in range(N):
            y = x - eta * f.gradient(x)
            y.add_tag(f"y_{i + 1}")
            x = g.proximal_step(y, eta)
            x.add_tag(f"x_{i + 1}")

        # To achieve the sweep, we can just update the performance_metric.
        for i in range(1, N + 1):
            p = ctx.get_by_tag(f"x_{i}")
            pep_builder.set_performance_metric(
                h.function_value(p) - h.function_value(x_star)
            )

            result = pep_builder.solve()
            expected_opt_value = 1 / (4 * i)
            assert math.isclose(
                result.primal_opt_value, expected_opt_value, rel_tol=1e-3
            )
    finally:
        # The current context is global state; clear it even when an assertion
        # fails so later tests do not silently reuse the "pgm" context.
        pc.set_current_context(None)
|
pepflow/expression_manager.py
CHANGED
@@ -18,6 +18,7 @@
|
|
18
18
|
# under the License.
|
19
19
|
|
20
20
|
import functools
|
21
|
+
import math
|
21
22
|
|
22
23
|
import numpy as np
|
23
24
|
|
@@ -27,6 +28,17 @@ from pepflow import scalar as sc
|
|
27
28
|
from pepflow import utils
|
28
29
|
|
29
30
|
|
31
|
+
def tag_and_coef_to_str(tag: str, v: float, *, abs_tol: float = 1e-9) -> str:
    """Format one signed term ``coefficient*tag`` of a linear combination.

    Args:
        tag: Name of the basis element.
        v: Coefficient of that basis element.
        abs_tol: Coefficients with ``|v| <= abs_tol`` are treated as zero.

    Returns:
        ``""`` for a (numerically) zero coefficient, ``"+ tag "`` / ``"- tag "``
        when the magnitude is 1, otherwise ``"+ c*tag "`` / ``"- c*tag "`` with
        ``c`` printed with three significant digits. The trailing space lets
        callers concatenate terms directly.
    """
    # An explicit abs_tol is required here: with the default abs_tol=0,
    # math.isclose(v, 0) is true only when v == 0 exactly, so round-off such
    # as 1e-17 would be printed as a real term.
    if math.isclose(v, 0, abs_tol=abs_tol):
        return ""

    sign = "+" if v > 0 else "-"
    magnitude = abs(v)
    if math.isclose(magnitude, 1):
        # Unit coefficients are implicit: print "x", not "1*x".
        return f"{sign} {tag} "
    return f"{sign} {magnitude:.3g}*{tag} "
|
40
|
+
|
41
|
+
|
30
42
|
class ExpressionManager:
|
31
43
|
def __init__(self, pep_context: pc.PEPContext):
|
32
44
|
self.context = pep_context
|
@@ -48,12 +60,18 @@ class ExpressionManager:
|
|
48
60
|
self._num_basis_points = len(self._basis_points)
|
49
61
|
self._num_basis_scalars = len(self._basis_scalars)
|
50
62
|
|
51
|
-
def get_index_of_basis_point(self, point: pt.Point):
|
63
|
+
def get_index_of_basis_point(self, point: pt.Point) -> int:
|
52
64
|
return self._basis_point_uid_to_index[point.uid]
|
53
65
|
|
54
|
-
def get_index_of_basis_scalar(self, scalar: sc.Scalar):
|
66
|
+
def get_index_of_basis_scalar(self, scalar: sc.Scalar) -> int:
|
55
67
|
return self._basis_scalar_uid_to_index[scalar.uid]
|
56
68
|
|
69
|
+
def get_tag_of_basis_point_index(self, index: int) -> str:
    """Return the ``tag`` of the basis point stored at position ``index``."""
    basis_point = self._basis_points[index]
    return basis_point.tag
|
71
|
+
|
72
|
+
def get_tag_of_basis_scalar_index(self, index: int) -> str:
    """Return the ``tag`` of the basis scalar stored at position ``index``."""
    basis_scalar = self._basis_scalars[index]
    return basis_scalar.tag
|
74
|
+
|
57
75
|
@functools.cache
|
58
76
|
def eval_point(self, point: pt.Point | float | int):
|
59
77
|
if utils.is_numerical(point):
|
@@ -130,3 +148,55 @@ class ExpressionManager:
|
|
130
148
|
) / self.eval_scalar(scalar.eval_expression.right_scalar)
|
131
149
|
|
132
150
|
raise ValueError("This should never happen!")
|
151
|
+
|
152
|
+
@functools.cache
def repr_point_by_basis(self, point: pt.Point) -> str:
    """Render ``point`` as a signed sum of basis-point tags.

    The coefficients come from ``self.eval_point(point).vector``; entry ``i``
    is paired with the tag of basis point ``i``. Returns ``"0"`` when every
    term formats to the empty string.
    """
    assert isinstance(point, pt.Point)
    coefficients = self.eval_point(point).vector

    terms = [
        tag_and_coef_to_str(self.get_tag_of_basis_point_index(index), coef)
        for index, coef in enumerate(coefficients)
    ]
    rendered = "".join(terms)

    if not rendered:
        return "0"
    # Tidy the first term: no leading "+", and a leading minus is attached
    # directly to the term ("-x" rather than "- x").
    if rendered.startswith("+ "):
        rendered = rendered[2:]
    if rendered.startswith("- "):
        rendered = "-" + rendered[2:]
    return rendered.strip()
|
170
|
+
|
171
|
+
@functools.cache
def repr_scalar_by_basis(self, scalar: sc.Scalar) -> str:
    """Render ``scalar`` as a readable expression over the basis tags.

    The output is assembled in this order from ``self.eval_scalar(scalar)``:
    the constant, the linear terms over the basis scalars (``.vector``), and
    the quadratic terms over pairs of basis points (``.matrix``), written as
    ``|x|^2`` on the diagonal and ``<x, y>`` off the diagonal.

    Returns:
        The expression string, or ``"0"`` when nothing is printed.
    """
    assert isinstance(scalar, sc.Scalar)
    evaluated_scalar = self.eval_scalar(scalar)

    repr_str = ""
    # abs_tol is needed: with the default abs_tol=0, isclose(c, 0) is true only
    # for an exact zero, so round-off constants would be printed.
    if not math.isclose(evaluated_scalar.constant, 0, abs_tol=1e-9):
        # Trailing space separates the constant from the next signed term,
        # giving "5 + x" instead of "5+ x".
        repr_str += f"{evaluated_scalar.constant:.3g} "

    for i, v in enumerate(evaluated_scalar.vector):
        # Note the tag is from scalar basis.
        ith_tag = self.get_tag_of_basis_scalar_index(i)
        repr_str += tag_and_coef_to_str(ith_tag, v)

    num_points = evaluated_scalar.matrix.shape[0]
    for i in range(num_points):
        ith_tag = self.get_tag_of_basis_point_index(i)
        # Only the upper triangle (j >= i) is visited; off-diagonal entries are
        # doubled. NOTE(review): this presumes the matrix is symmetric, confirm.
        for j in range(i, num_points):
            v = evaluated_scalar.matrix[i, j]
            if i == j:
                repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
                continue
            jth_tag = self.get_tag_of_basis_point_index(j)
            repr_str += tag_and_coef_to_str(f"<{ith_tag}, {jth_tag}>", 2 * v)

    # Post processing
    if repr_str == "":
        return "0"
    if repr_str.startswith("+ "):
        repr_str = repr_str[2:]
    if repr_str.startswith("- "):
        repr_str = "-" + repr_str[2:]
    return repr_str.strip()
|
pepflow/expression_manager_test.py
ADDED
@@ -0,0 +1,116 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
|
21
|
+
from typing import Iterator
|
22
|
+
|
23
|
+
import numpy as np
|
24
|
+
import pytest
|
25
|
+
|
26
|
+
from pepflow import expression_manager as exm
|
27
|
+
from pepflow import function as fc
|
28
|
+
from pepflow import pep as pep
|
29
|
+
from pepflow import pep_context as pc
|
30
|
+
from pepflow import point as pt
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Yield a fresh "test" context set as current; clear the current context afterwards."""
    context = pc.PEPContext("test")
    active_context = context.set_as_current()
    yield active_context
    # Teardown: leave no current context behind for the next test.
    pc.set_current_context(None)
|
39
|
+
|
40
|
+
|
41
|
+
def test_repr_point_by_basis(pep_context: pc.PEPContext) -> None:
    """Two gradient steps of size 0.5 are printed as a sum over basis tags."""
    iterate = pt.Point(is_basis=True, tags=["x_0"])
    func = fc.Function(is_basis=True, tags=["f"])
    step_size = 0.5
    for k in range(2):
        descent = step_size * func.gradient(iterate)
        iterate = iterate - descent
        iterate.add_tag(f"x_{k + 1}")

    manager = exm.ExpressionManager(pep_context)
    coefficients = manager.eval_point(iterate).vector
    np.testing.assert_allclose(coefficients, [1, -0.5, -0.5])

    expected = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    assert manager.repr_point_by_basis(iterate) == expected
|
54
|
+
|
55
|
+
|
56
|
+
def test_repr_point_by_basis_with_zero(pep_context: pc.PEPContext) -> None:
    """An unused basis point adds a zero coefficient but no printed term."""
    iterate = pt.Point(is_basis=True, tags=["x_0"])
    _ = pt.Point(is_basis=True, tags=["x_unused"])  # Add this extra point.
    func = fc.Function(is_basis=True, tags=["f"])
    step_size = 0.5
    for k in range(2):
        descent = step_size * func.gradient(iterate)
        iterate = iterate - descent
        iterate.add_tag(f"x_{k + 1}")

    manager = exm.ExpressionManager(pep_context)
    # The vector gains a zero entry for the unused point, while the string
    # representation stays the same as without it.
    coefficients = manager.eval_point(iterate).vector
    np.testing.assert_allclose(coefficients, [1, 0, -0.5, -0.5])

    expected = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    assert manager.repr_point_by_basis(iterate) == expected
|
72
|
+
|
73
|
+
|
74
|
+
def test_repr_point_by_basis_heavy_ball(pep_context: pc.PEPContext) -> None:
    """Two heavy-ball steps (momentum 0.5) are printed over four basis tags."""
    previous = pt.Point(is_basis=True, tags=["x_{-1}"])
    current = pt.Point(is_basis=True, tags=["x_0"])
    func = fc.Function(is_basis=True, tags=["f"])

    beta = 0.5
    for k in range(2):
        gradient_step = current - func.gradient(current)
        momentum = beta * (current - previous)
        upcoming = gradient_step + momentum
        upcoming.add_tag(f"x_{k + 1}")
        previous, current = current, upcoming

    manager = exm.ExpressionManager(pep_context)
    coefficients = manager.eval_point(current).vector
    np.testing.assert_allclose(coefficients, [-0.75, 1.75, -1.5, -1])

    expected = "-0.75*x_{-1} + 1.75*x_0 - 1.5*gradient_f(x_0) - gradient_f(x_1)"
    assert manager.repr_point_by_basis(current) == expected
|
92
|
+
|
93
|
+
|
94
|
+
def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
    """A function value plus an inner product prints as ``f(x) + <x, gradient_f(x)>``."""
    point = pt.Point(is_basis=True, tags=["x"])
    func = fc.Function(is_basis=True, tags=["f"])

    value = func(point)
    inner_product = point * func.gradient(point)
    total = value + inner_product

    manager = exm.ExpressionManager(pep_context)
    rendered = manager.repr_scalar_by_basis(total)
    assert rendered == "f(x) + <x, gradient_f(x)>"
|
101
|
+
|
102
|
+
|
103
|
+
def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
    """The smooth-convex interpolation scalar for (x_i, x_j) prints as expected."""
    xi = pt.Point(is_basis=True, tags=["x_i"])
    xj = pt.Point(is_basis=True, tags=["x_j"])
    func = fc.SmoothConvexFunction(is_basis=True, L=1)
    func.add_tag("f")

    # Evaluate f at both points first; the returned values are not used
    # directly, interpolate_ineq below looks everything up by tag.
    for basis_point in (xi, xj):
        func(basis_point)

    interp_scalar = func.interpolate_ineq("x_i", "x_j")
    manager = exm.ExpressionManager(pep_context)
    expected_repr = "-f(x_i) + f(x_j) + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)> + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)> + 0.5*|gradient_f(x_j)|^2"
    assert manager.repr_scalar_by_basis(interp_scalar) == expected_repr
|
114
|
+
|
115
|
+
|
116
|
+
# TODO add more tests about repr_scalar_by_basis
|
pepflow/function.py
CHANGED
@@ -19,7 +19,9 @@
|
|
19
19
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
|
+
import numbers
|
22
23
|
import uuid
|
24
|
+
from typing import TYPE_CHECKING
|
23
25
|
|
24
26
|
import attrs
|
25
27
|
|
@@ -28,6 +30,9 @@ from pepflow import point as pt
|
|
28
30
|
from pepflow import scalar as sc
|
29
31
|
from pepflow import utils
|
30
32
|
|
33
|
+
if TYPE_CHECKING:
|
34
|
+
from pepflow.constraint import Constraint
|
35
|
+
|
31
36
|
|
32
37
|
@attrs.frozen
|
33
38
|
class Triplet:
|
@@ -57,7 +62,6 @@ class ScaledFunc:
|
|
57
62
|
@attrs.mutable
|
58
63
|
class Function:
|
59
64
|
is_basis: bool
|
60
|
-
reuse_gradient: bool
|
61
65
|
|
62
66
|
composition: AddedFunc | ScaledFunc | None = None
|
63
67
|
|
@@ -87,6 +91,10 @@ class Function:
|
|
87
91
|
return self.tag
|
88
92
|
return super().__repr__()
|
89
93
|
|
94
|
+
def _repr_latex_(self):
    """Return ``repr(self)`` wrapped as inline display-style LaTeX math.

    This is the hook rich front ends such as Jupyter call to render an object
    as LaTeX.
    """
    s = repr(self)
    # Raw string: a single backslash yields the \displaystyle command. Two
    # backslashes here would emit a LaTeX line break followed by the literal
    # word "displaystyle".
    return rf"$\displaystyle {s}$"
|
97
|
+
|
90
98
|
def get_interpolation_constraints(self):
|
91
99
|
raise NotImplementedError(
|
92
100
|
"This method should be implemented in the children class."
|
@@ -154,11 +162,19 @@ class Function:
|
|
154
162
|
|
155
163
|
def add_stationary_point(self, name: str) -> pt.Point:
|
156
164
|
# assert we can only add one stationary point?
|
165
|
+
pep_context = pc.get_current_context()
|
166
|
+
if pep_context is None:
|
167
|
+
raise RuntimeError("Did you forget to create a context?")
|
168
|
+
if len(pep_context.stationary_triplets[self]) > 0:
|
169
|
+
raise ValueError(
|
170
|
+
"You are trying to add a stationary point to a function that already has a stationary point."
|
171
|
+
)
|
157
172
|
point = pt.Point(is_basis=True)
|
158
173
|
point.add_tag(name)
|
159
174
|
desired_grad = 0 * point
|
160
175
|
desired_grad.add_tag(f"gradient_{self.tag}({name})")
|
161
|
-
self.add_point_with_grad_restriction(point, desired_grad)
|
176
|
+
triplet = self.add_point_with_grad_restriction(point, desired_grad)
|
177
|
+
pep_context.add_stationary_triplet(self, triplet)
|
162
178
|
return point
|
163
179
|
|
164
180
|
# The following the old gradient(opt) = 0 constraint style.
|
@@ -180,42 +196,22 @@ class Function:
|
|
180
196
|
raise RuntimeError("Did you forget to create a context?")
|
181
197
|
|
182
198
|
if self.is_basis:
|
183
|
-
generate_new_basis = True
|
184
|
-
instances_of_point = 0
|
185
199
|
for triplet in pep_context.triplets[self]:
|
186
200
|
if triplet.point.uid == point.uid:
|
187
|
-
|
188
|
-
generate_new_basis = False
|
189
|
-
previous_triplet = triplet
|
201
|
+
return triplet
|
190
202
|
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
gradient.add_tag(f"gradient_{self.tag}({point.tag})")
|
203
|
+
function_value = sc.Scalar(is_basis=True)
|
204
|
+
function_value.add_tag(f"{self.tag}({point.tag})")
|
205
|
+
gradient = pt.Point(is_basis=True)
|
206
|
+
gradient.add_tag(f"gradient_{self.tag}({point.tag})")
|
196
207
|
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
elif not generate_new_basis and self.reuse_gradient:
|
205
|
-
function_value = previous_triplet.function_value
|
206
|
-
gradient = previous_triplet.gradient
|
207
|
-
elif not generate_new_basis and not self.reuse_gradient:
|
208
|
-
function_value = previous_triplet.function_value
|
209
|
-
gradient = pt.Point(is_basis=True)
|
210
|
-
gradient.add_tag(f"gradient_{self.tag}({point.tag})")
|
211
|
-
|
212
|
-
new_triplet = Triplet(
|
213
|
-
point,
|
214
|
-
previous_triplet.function_value,
|
215
|
-
gradient,
|
216
|
-
name=f"{point.tag}_{function_value.tag}_{gradient.tag}_{instances_of_point}",
|
217
|
-
)
|
218
|
-
self.add_triplet_to_func(new_triplet)
|
208
|
+
new_triplet = Triplet(
|
209
|
+
point,
|
210
|
+
function_value,
|
211
|
+
gradient,
|
212
|
+
name=f"{point.tag}_{function_value.tag}_{gradient.tag}",
|
213
|
+
)
|
214
|
+
self.add_triplet_to_func(new_triplet)
|
219
215
|
else:
|
220
216
|
if isinstance(self.composition, AddedFunc):
|
221
217
|
left_triplet = self.composition.left_func.generate_triplet(point)
|
@@ -247,57 +243,69 @@ class Function:
|
|
247
243
|
triplet = self.generate_triplet(point)
|
248
244
|
return triplet.function_value
|
249
245
|
|
246
|
+
def __call__(self, point: pt.Point) -> sc.Scalar:
    """Evaluate the function at ``point``; shorthand for ``function_value(point)``."""
    value = self.function_value(point)
    return value
|
248
|
+
|
250
249
|
def __add__(self, other):
|
251
250
|
assert isinstance(other, Function)
|
252
251
|
return Function(
|
253
252
|
is_basis=False,
|
254
|
-
reuse_gradient=self.reuse_gradient and other.reuse_gradient,
|
255
253
|
composition=AddedFunc(self, other),
|
256
254
|
tags=[f"{self.tag}+{other.tag}"],
|
257
255
|
)
|
258
256
|
|
259
257
|
def __sub__(self, other):
|
260
258
|
assert isinstance(other, Function)
|
259
|
+
tag_other = other.tag
|
260
|
+
if isinstance(other.composition, AddedFunc):
|
261
|
+
tag_other = f"({other.tag})"
|
261
262
|
return Function(
|
262
263
|
is_basis=False,
|
263
|
-
reuse_gradient=self.reuse_gradient and other.reuse_gradient,
|
264
264
|
composition=AddedFunc(self, -other),
|
265
|
-
tags=[f"{self.tag}-{
|
265
|
+
tags=[f"{self.tag}-{tag_other}"],
|
266
266
|
)
|
267
267
|
|
268
268
|
def __mul__(self, other):
|
269
269
|
assert utils.is_numerical(other)
|
270
|
+
tag_self = self.tag
|
271
|
+
if isinstance(self.composition, AddedFunc):
|
272
|
+
tag_self = f"({self.tag})"
|
270
273
|
return Function(
|
271
274
|
is_basis=False,
|
272
|
-
reuse_gradient=self.reuse_gradient,
|
273
275
|
composition=ScaledFunc(scale=other, base_func=self),
|
274
|
-
tags=[f"{other:.4g}*{
|
276
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
275
277
|
)
|
276
278
|
|
277
279
|
def __rmul__(self, other):
|
278
280
|
assert utils.is_numerical(other)
|
281
|
+
tag_self = self.tag
|
282
|
+
if isinstance(self.composition, AddedFunc):
|
283
|
+
tag_self = f"({self.tag})"
|
279
284
|
return Function(
|
280
285
|
is_basis=False,
|
281
|
-
reuse_gradient=self.reuse_gradient,
|
282
286
|
composition=ScaledFunc(scale=other, base_func=self),
|
283
|
-
tags=[f"{other:.4g}*{
|
287
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
284
288
|
)
|
285
289
|
|
286
290
|
def __neg__(self):
|
291
|
+
tag_self = self.tag
|
292
|
+
if isinstance(self.composition, AddedFunc):
|
293
|
+
tag_self = f"({self.tag})"
|
287
294
|
return Function(
|
288
295
|
is_basis=False,
|
289
|
-
reuse_gradient=self.reuse_gradient,
|
290
296
|
composition=ScaledFunc(scale=-1, base_func=self),
|
291
|
-
tags=[f"-{
|
297
|
+
tags=[f"-{tag_self}"],
|
292
298
|
)
|
293
299
|
|
294
300
|
def __truediv__(self, other):
|
295
301
|
assert utils.is_numerical(other)
|
302
|
+
tag_self = self.tag
|
303
|
+
if isinstance(self.composition, AddedFunc):
|
304
|
+
tag_self = f"({self.tag})"
|
296
305
|
return Function(
|
297
306
|
is_basis=False,
|
298
|
-
reuse_gradient=self.reuse_gradient,
|
299
307
|
composition=ScaledFunc(scale=1 / other, base_func=self),
|
300
|
-
tags=[f"1/{other:.4g}*{
|
308
|
+
tags=[f"1/{other:.4g}*{tag_self}"],
|
301
309
|
)
|
302
310
|
|
303
311
|
def __hash__(self):
|
@@ -309,10 +317,96 @@ class Function:
|
|
309
317
|
return self.uid == other.uid
|
310
318
|
|
311
319
|
|
320
|
+
class ConvexFunction(Function):
    """A ``Function`` whose interpolation constraints are the convexity inequalities."""

    def __init__(self, is_basis=True, composition=None):
        super().__init__(is_basis=is_basis, composition=composition)

    def convex_interpolability_constraints(
        self, triplet_i: Triplet, triplet_j: Triplet
    ) -> Constraint:
        """Build ``f_j - f_i + <g_j, x_i - x_j> <= 0`` for one ordered pair of triplets.

        The constraint is named ``"<function tag>:<x_i tag>,<x_j tag>"``.
        """
        x_i = triplet_i.point
        x_j = triplet_j.point

        value_gap = triplet_j.function_value - triplet_i.function_value
        linear_term = triplet_j.gradient * (x_i - x_j)

        total = value_gap + linear_term
        return total.le(0, name=f"{self.tag}:{x_i.tag},{x_j.tag}")

    def get_interpolation_constraints(
        self, pep_context: pc.PEPContext | None = None
    ) -> list[Constraint]:
        """Return the pairwise constraints over all triplets recorded for this function.

        Uses the current context when ``pep_context`` is not given.

        Raises:
            RuntimeError: If no context is given and none is current.
        """
        context = pep_context if pep_context is not None else pc.get_current_context()
        if context is None:
            raise RuntimeError("Did you forget to create a context?")

        triplets = context.triplets[self]
        # Every ordered pair of triplets that do not compare equal.
        return [
            self.convex_interpolability_constraints(first, second)
            for first in triplets
            for second in triplets
            if not (first == second)
        ]

    def interpolate_ineq(
        self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
    ) -> sc.Scalar:
        """Generate the interpolation inequality scalar by tags."""
        context = pep_context if pep_context is not None else pc.get_current_context()
        if context is None:
            raise RuntimeError("Did you forget to specify a context?")

        # TODO: we definitely need a more robust tag system
        x1 = context.get_by_tag(p1_tag)
        x2 = context.get_by_tag(p2_tag)
        f1 = context.get_by_tag(f"{self.tag}({p1_tag})")
        f2 = context.get_by_tag(f"{self.tag}({p2_tag})")
        g2 = context.get_by_tag(f"gradient_{self.tag}({p2_tag})")

        value_gap = f2 - f1
        linear_term = g2 * (x1 - x2)
        return value_gap + linear_term

    def proximal_step(self, x_0: pt.Point, stepsize: numbers.Number) -> pt.Point:
        """Return ``x = x_0 - stepsize * g`` and record ``(x, f(x), g)`` as a triplet.

        ``g`` and the function value at ``x`` are created as new basis elements,
        tagged after the proximal point, and the resulting triplet is added to
        this function.
        """
        gradient = pt.Point(is_basis=True)
        # All three tags share the name of the proximal point.
        prox_tag = f"prox_{{{stepsize}*{self.tag}}}({x_0.tag})"
        gradient.add_tag(f"gradient_{self.tag}({prox_tag})")

        function_value = sc.Scalar(is_basis=True)
        function_value.add_tag(f"{self.tag}({prox_tag})")

        x = x_0 - stepsize * gradient
        x.add_tag(prox_tag)

        triplet_name = f"{x.tag}_{function_value.tag}_{gradient.tag}"
        self.add_triplet_to_func(
            Triplet(x, function_value, gradient, name=triplet_name)
        )
        return x
|
398
|
+
|
399
|
+
|
312
400
|
class SmoothConvexFunction(Function):
|
313
|
-
def __init__(
|
401
|
+
def __init__(
|
402
|
+
self,
|
403
|
+
L,
|
404
|
+
is_basis=True,
|
405
|
+
composition=None,
|
406
|
+
):
|
314
407
|
super().__init__(
|
315
|
-
is_basis=is_basis,
|
408
|
+
is_basis=is_basis,
|
409
|
+
composition=composition,
|
316
410
|
)
|
317
411
|
self.L = L
|
318
412
|
|
@@ -350,7 +444,7 @@ class SmoothConvexFunction(Function):
|
|
350
444
|
|
351
445
|
def interpolate_ineq(
|
352
446
|
self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
|
353
|
-
) ->
|
447
|
+
) -> sc.Scalar:
|
354
448
|
"""Generate the interpolation inequality scalar by tags."""
|
355
449
|
if pep_context is None:
|
356
450
|
pep_context = pc.get_current_context()
|