pepflow 0.1.3a1__py3-none-any.whl → 0.1.4a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/__init__.py CHANGED
@@ -35,6 +35,7 @@ from .pep_context import set_current_context as set_current_context
35
35
  # Function, Point, Scalar
36
36
  from .function import Function as Function
37
37
  from .function import SmoothConvexFunction as SmoothConvexFunction
38
+ from .function import ConvexFunction as ConvexFunction
38
39
  from .function import Triplet as Triplet
39
40
  from .point import EvaluatedPoint as EvaluatedPoint
40
41
  from .point import Point as Point
@@ -0,0 +1,71 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+
21
+ from typing import Iterator
22
+
23
+ import numpy as np
24
+ import pytest
25
+
26
+ from pepflow import expression_manager as exm
27
+ from pepflow import pep as pep
28
+ from pepflow import pep_context as pc
29
+ from pepflow import scalar, utils
30
+
31
+
32
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Yield a fresh PEP context named "test" installed as the current context.

    After the test finishes, the global current context is cleared again.
    """
    context = pc.PEPContext("test").set_as_current()
    yield context
    pc.set_current_context(None)
38
+
39
+
40
def test_constraint(pep_context: pc.PEPContext):
    """Each comparison helper yields a constraint on ``s3 - 5`` with the right comparator.

    ``le``/``lt`` map to ``Comparator.LT``, ``ge``/``gt`` to ``Comparator.GT``
    and ``eq`` to ``Comparator.EQ``; in all cases the constrained scalar is
    ``2*s1 + 0.25*s2`` with zero constant.
    """
    s1 = scalar.Scalar(is_basis=True, tags=["s1"])
    s2 = scalar.Scalar(is_basis=True, tags=["s2"])
    s3 = 2 * s1 + s2 / 4 + 5

    cases = [
        (s3.le(5, name="c1"), utils.Comparator.LT),
        (s3.lt(5, name="c2"), utils.Comparator.LT),
        (s3.ge(5, name="c3"), utils.Comparator.GT),
        (s3.gt(5, name="c4"), utils.Comparator.GT),
        (s3.eq(5, name="c5"), utils.Comparator.EQ),
    ]

    manager = exm.ExpressionManager(pep_context)
    expected_vector = np.array([2, 0.25])
    for constraint, expected_comparator in cases:
        evaluated = manager.eval_scalar(constraint.scalar)
        np.testing.assert_allclose(evaluated.vector, expected_vector)
        np.testing.assert_allclose(evaluated.constant, 0)
        assert constraint.comparator == expected_comparator
pepflow/e2e_test.py ADDED
@@ -0,0 +1,69 @@
1
+ import math
2
+
3
+ from pepflow import function, pep
4
+ from pepflow import pep_context as pc
5
+
6
+
7
def test_gd_e2e():
    """End-to-end PEP for gradient descent on a 1-smooth convex function.

    Builds N steps of ``x_{k+1} = x_k - eta * grad f(x_k)`` once. It then sweeps
    the performance metric ``f(x_k) - f(x_star)`` over k = 1..N and checks
    each solved optimum against ``1 / (4k + 2)`` (relative tolerance 1e-3).

    The test installs its own global PEP context. That context is cleared in
    ``finally`` so it does not leak into tests that run afterwards in the
    same process, whether or not an assertion fails.
    """
    ctx = pc.PEPContext("gd").set_as_current()
    try:
        pep_builder = pep.PEPBuilder()
        eta = 1
        N = 9

        f = pep_builder.declare_func(function.SmoothConvexFunction, "f", L=1)
        x = pep_builder.set_init_point("x_0")
        x_star = f.add_stationary_point("x_star")
        pep_builder.set_initial_constraint(
            ((x - x_star) ** 2).le(1, name="initial_condition")
        )

        # We first build the algorithm with the largest number of iterations.
        for i in range(N):
            x = x - eta * f.gradient(x)
            x.add_tag(f"x_{i + 1}")

        # To achieve the sweep, we can just update the performance_metric.
        for i in range(1, N + 1):
            p = ctx.get_by_tag(f"x_{i}")
            pep_builder.set_performance_metric(
                f.function_value(p) - f.function_value(x_star)
            )
            result = pep_builder.solve()
            expected_opt_value = 1 / (4 * i + 2)
            assert math.isclose(
                result.primal_opt_value, expected_opt_value, rel_tol=1e-3
            )
    finally:
        # Clear the process-wide current context so later tests start clean.
        pc.set_current_context(None)
34
+
35
+
36
def test_pgm_e2e():
    """End-to-end PEP for the proximal gradient method on ``h = f + g``.

    ``f`` is 1-smooth convex and ``g`` is convex. One step is
    ``y = x - eta * grad f(x)`` followed by ``x = prox_{eta*g}(y)``. The
    performance metric ``h(x_k) - h(x_star)`` is swept over k = 1..N and
    compared against ``1 / (4k)`` (relative tolerance 1e-3).

    The global PEP context installed here is cleared in ``finally`` so it
    cannot leak into tests that run afterwards in the same process.
    """
    ctx = pc.PEPContext("pgm").set_as_current()
    try:
        pep_builder = pep.PEPBuilder()
        eta = 1
        N = 1

        f = pep_builder.declare_func(function.SmoothConvexFunction, "f", L=1)
        g = pep_builder.declare_func(function.ConvexFunction, "g")

        h = f + g

        x = pep_builder.set_init_point("x_0")
        x_star = h.add_stationary_point("x_star")
        pep_builder.set_initial_constraint(
            ((x - x_star) ** 2).le(1, name="initial_condition")
        )

        # We first build the algorithm with the largest number of iterations.
        for i in range(N):
            y = x - eta * f.gradient(x)
            y.add_tag(f"y_{i + 1}")
            x = g.proximal_step(y, eta)
            x.add_tag(f"x_{i + 1}")

        # To achieve the sweep, we can just update the performance_metric.
        for i in range(1, N + 1):
            p = ctx.get_by_tag(f"x_{i}")
            pep_builder.set_performance_metric(
                h.function_value(p) - h.function_value(x_star)
            )

            result = pep_builder.solve()
            expected_opt_value = 1 / (4 * i)
            assert math.isclose(
                result.primal_opt_value, expected_opt_value, rel_tol=1e-3
            )
    finally:
        # Clear the process-wide current context so later tests start clean.
        pc.set_current_context(None)
@@ -18,6 +18,7 @@
18
18
  # under the License.
19
19
 
20
20
  import functools
21
+ import math
21
22
 
22
23
  import numpy as np
23
24
 
@@ -27,6 +28,17 @@ from pepflow import scalar as sc
27
28
  from pepflow import utils
28
29
 
29
30
 
31
def tag_and_coef_to_str(tag: str, v: float, *, abs_tol: float = 1e-12) -> str:
    """Format one signed ``coefficient * tag`` term of a linear combination.

    A coefficient of magnitude 1 renders as ``"+ tag "`` / ``"- tag "``. Any
    other non-zero coefficient renders as ``"+ 0.5*tag "`` (3 significant
    digits). Every non-empty term ends with a space so terms can be
    concatenated directly; callers strip the leading sign and trailing
    whitespace afterwards.

    Args:
        tag: Name of the basis element the coefficient multiplies.
        v: The coefficient.
        abs_tol: Coefficients with ``|v| <= abs_tol`` are treated as zero and
            produce an empty string. ``math.isclose(v, 0)`` with only the
            default relative tolerance matches nothing but an exact ``0``, so
            an absolute tolerance is needed to hide floating-point residue
            such as ``0.1 * 3 - 0.3``.

    Returns:
        The formatted term, or ``""`` for a (numerically) zero coefficient.
    """
    if math.isclose(v, 0, abs_tol=abs_tol):
        return ""
    sign = "+" if v >= 0 else "-"
    if math.isclose(abs(v), 1):
        return f"{sign} {tag} "
    return f"{sign} {abs(v):.3g}*{tag} "
40
+
41
+
30
42
  class ExpressionManager:
31
43
  def __init__(self, pep_context: pc.PEPContext):
32
44
  self.context = pep_context
@@ -48,12 +60,18 @@ class ExpressionManager:
48
60
  self._num_basis_points = len(self._basis_points)
49
61
  self._num_basis_scalars = len(self._basis_scalars)
50
62
 
51
def get_index_of_basis_point(self, point: pt.Point) -> int:
    """Return the position of basis ``point`` in the point basis, keyed by its uid."""
    lookup = self._basis_point_uid_to_index
    return lookup[point.uid]

def get_index_of_basis_scalar(self, scalar: sc.Scalar) -> int:
    """Return the position of basis ``scalar`` in the scalar basis, keyed by its uid."""
    lookup = self._basis_scalar_uid_to_index
    return lookup[scalar.uid]
56
68
 
69
def get_tag_of_basis_point_index(self, index: int) -> str:
    """Return the tag of the basis point stored at position ``index``."""
    basis = self._basis_points
    return basis[index].tag

def get_tag_of_basis_scalar_index(self, index: int) -> str:
    """Return the tag of the basis scalar stored at position ``index``."""
    basis = self._basis_scalars
    return basis[index].tag
74
+
57
75
  @functools.cache
58
76
  def eval_point(self, point: pt.Point | float | int):
59
77
  if utils.is_numerical(point):
@@ -130,3 +148,55 @@ class ExpressionManager:
130
148
  ) / self.eval_scalar(scalar.eval_expression.right_scalar)
131
149
 
132
150
  raise ValueError("This should never happen!")
151
+
152
@functools.cache
def repr_point_by_basis(self, point: pt.Point) -> str:
    """Render ``point`` as a signed linear combination of basis point tags.

    Zero coefficients are omitted, a leading ``"+ "`` is dropped and a
    leading ``"- "`` is tightened to ``"-"``. A point whose coefficients are
    all zero renders as ``"0"``.
    """
    assert isinstance(point, pt.Point)
    coefficients = self.eval_point(point).vector

    terms = [
        tag_and_coef_to_str(self.get_tag_of_basis_point_index(idx), coef)
        for idx, coef in enumerate(coefficients)
    ]
    text = "".join(terms)

    if not text:
        return "0"
    if text.startswith("+ "):
        text = text[2:]
    if text.startswith("- "):
        text = "-" + text[2:]
    return text.strip()
170
+
171
@functools.cache
def repr_scalar_by_basis(self, scalar: sc.Scalar) -> str:
    """Render ``scalar`` as a human-readable expression over the basis.

    The output lists, in order:
      * the constant term;
      * the linear terms over the basis *scalars*;
      * the quadratic terms over the basis *points*: ``|p_i|^2`` for diagonal
        entries and ``<p_i, p_j>`` (i < j) for off-diagonal ones.

    Zero terms are omitted and an all-zero scalar renders as ``"0"``.
    """
    assert isinstance(scalar, sc.Scalar)
    evaluated_scalar = self.eval_scalar(scalar)

    repr_str = ""
    constant = evaluated_scalar.constant
    # Emit the constant as a signed term like every other term, so it is
    # separated from the following term by " + " / " - " (it used to be glued
    # on, e.g. "5+ f(x)"). An absolute tolerance is required because
    # math.isclose(c, 0) with only a relative tolerance matches exact 0 only.
    if not math.isclose(constant, 0, abs_tol=1e-12):
        sign = "+" if constant >= 0 else "-"
        repr_str += f"{sign} {abs(constant):.3g} "

    for i, v in enumerate(evaluated_scalar.vector):
        # Note the tag is from scalar basis.
        ith_tag = self.get_tag_of_basis_scalar_index(i)
        repr_str += tag_and_coef_to_str(ith_tag, v)

    num_points = evaluated_scalar.matrix.shape[0]
    for i in range(num_points):
        ith_tag = self.get_tag_of_basis_point_index(i)
        for j in range(i, num_points):
            v = evaluated_scalar.matrix[i, j]
            if i == j:
                repr_str += tag_and_coef_to_str(f"|{ith_tag}|^2", v)
            else:
                # Only the upper triangle is visited and doubled, which
                # assumes the matrix is stored symmetrically.
                jth_tag = self.get_tag_of_basis_point_index(j)
                repr_str += tag_and_coef_to_str(f"<{ith_tag}, {jth_tag}>", 2 * v)

    # Post processing
    if repr_str == "":
        return "0"
    if repr_str.startswith("+ "):
        repr_str = repr_str[2:]
    if repr_str.startswith("- "):
        repr_str = "-" + repr_str[2:]
    return repr_str.strip()
@@ -0,0 +1,116 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+
21
+ from typing import Iterator
22
+
23
+ import numpy as np
24
+ import pytest
25
+
26
+ from pepflow import expression_manager as exm
27
+ from pepflow import function as fc
28
+ from pepflow import pep as pep
29
+ from pepflow import pep_context as pc
30
+ from pepflow import point as pt
31
+
32
+
33
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Provide a current PEP context named "test"; clear the global one afterwards."""
    context = pc.PEPContext("test").set_as_current()
    yield context
    pc.set_current_context(None)
39
+
40
+
41
def test_repr_point_by_basis(pep_context: pc.PEPContext) -> None:
    """Two gradient steps with step 0.5 render as a combination of x_0 and gradients."""
    point = pt.Point(is_basis=True, tags=["x_0"])
    func = fc.Function(is_basis=True, tags=["f"])
    step = 0.5
    for k in range(1, 3):
        point = point - step * func.gradient(point)
        point.add_tag(f"x_{k}")

    manager = exm.ExpressionManager(pep_context)
    np.testing.assert_allclose(manager.eval_point(point).vector, [1, -0.5, -0.5])
    expected = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    assert manager.repr_point_by_basis(point) == expected
54
+
55
+
56
def test_repr_point_by_basis_with_zero(pep_context: pc.PEPContext) -> None:
    """An unused basis point adds a zero coefficient that the string repr omits."""
    point = pt.Point(is_basis=True, tags=["x_0"])
    _unused = pt.Point(is_basis=True, tags=["x_unused"])  # Extra basis point.
    func = fc.Function(is_basis=True, tags=["f"])
    step = 0.5
    for k in range(1, 3):
        point = point - step * func.gradient(point)
        point.add_tag(f"x_{k}")

    manager = exm.ExpressionManager(pep_context)
    # The vector now carries a 0 for x_unused, but the rendered string is
    # identical to the case without the extra point.
    np.testing.assert_allclose(manager.eval_point(point).vector, [1, 0, -0.5, -0.5])
    expected = "x_0 - 0.5*gradient_f(x_0) - 0.5*gradient_f(x_1)"
    assert manager.repr_point_by_basis(point) == expected
72
+
73
+
74
def test_repr_point_by_basis_heavy_ball(pep_context: pc.PEPContext) -> None:
    """Two heavy-ball steps (unit step, momentum 0.5) render with negative leading term."""
    x_prev = pt.Point(is_basis=True, tags=["x_{-1}"])
    x = pt.Point(is_basis=True, tags=["x_0"])
    f = fc.Function(is_basis=True, tags=["f"])

    beta = 0.5
    for k in range(1, 3):
        x_next = x - f.gradient(x) + beta * (x - x_prev)
        x_next.add_tag(f"x_{k}")
        x_prev, x = x, x_next

    manager = exm.ExpressionManager(pep_context)
    np.testing.assert_allclose(manager.eval_point(x).vector, [-0.75, 1.75, -1.5, -1])
    expected = "-0.75*x_{-1} + 1.75*x_0 - 1.5*gradient_f(x_0) - gradient_f(x_1)"
    assert manager.repr_point_by_basis(x) == expected
92
+
93
+
94
def test_repr_scalar_by_basis(pep_context: pc.PEPContext) -> None:
    """A function value plus an inner product renders as ``f(x) + <x, gradient_f(x)>``."""
    x = pt.Point(is_basis=True, tags=["x"])
    f = fc.Function(is_basis=True, tags=["f"])

    expression = f(x) + x * f.gradient(x)
    manager = exm.ExpressionManager(pep_context)
    rendered = manager.repr_scalar_by_basis(expression)
    assert rendered == "f(x) + <x, gradient_f(x)>"
101
+
102
+
103
def test_repr_scalar_by_basis_interpolation(pep_context: pc.PEPContext) -> None:
    """The smooth-convex interpolation inequality between x_i and x_j renders fully."""
    xi = pt.Point(is_basis=True, tags=["x_i"])
    xj = pt.Point(is_basis=True, tags=["x_j"])
    f = fc.SmoothConvexFunction(is_basis=True, L=1)
    f.add_tag("f")
    # Evaluate f at both points so their triplets (and tags) exist.
    _f_i = f(xi)
    _f_j = f(xj)

    interp_scalar = f.interpolate_ineq("x_i", "x_j")
    manager = exm.ExpressionManager(pep_context)
    expected_repr = (
        "-f(x_i) + f(x_j) + <x_i, gradient_f(x_j)> - <x_j, gradient_f(x_j)>"
        " + 0.5*|gradient_f(x_i)|^2 - <gradient_f(x_i), gradient_f(x_j)>"
        " + 0.5*|gradient_f(x_j)|^2"
    )
    assert manager.repr_scalar_by_basis(interp_scalar) == expected_repr
114
+
115
+
116
+ # TODO add more tests about repr_scalar_by_basis
pepflow/function.py CHANGED
@@ -19,7 +19,9 @@
19
19
 
20
20
  from __future__ import annotations
21
21
 
22
+ import numbers
22
23
  import uuid
24
+ from typing import TYPE_CHECKING
23
25
 
24
26
  import attrs
25
27
 
@@ -28,6 +30,9 @@ from pepflow import point as pt
28
30
  from pepflow import scalar as sc
29
31
  from pepflow import utils
30
32
 
33
+ if TYPE_CHECKING:
34
+ from pepflow.constraint import Constraint
35
+
31
36
 
32
37
  @attrs.frozen
33
38
  class Triplet:
@@ -57,7 +62,6 @@ class ScaledFunc:
57
62
  @attrs.mutable
58
63
  class Function:
59
64
  is_basis: bool
60
- reuse_gradient: bool
61
65
 
62
66
  composition: AddedFunc | ScaledFunc | None = None
63
67
 
@@ -87,6 +91,10 @@ class Function:
87
91
  return self.tag
88
92
  return super().__repr__()
89
93
 
94
+ def _repr_latex_(self):
95
+ s = repr(self)
96
+ return rf"$\\displaystyle {s}$"
97
+
90
98
  def get_interpolation_constraints(self):
91
99
  raise NotImplementedError(
92
100
  "This method should be implemented in the children class."
@@ -154,11 +162,19 @@ class Function:
154
162
 
155
163
def add_stationary_point(self, name: str) -> pt.Point:
    """Create a basis point tagged ``name`` at which this function's gradient is zero.

    The point is registered through ``add_point_with_grad_restriction`` with a
    zero gradient. The resulting triplet is recorded as this function's
    stationary triplet in the current context.

    Raises:
        RuntimeError: if no PEP context is current.
        ValueError: if this function already has a stationary point.
    """
    context = pc.get_current_context()
    if context is None:
        raise RuntimeError("Did you forget to create a context?")
    if len(context.stationary_triplets[self]) > 0:
        raise ValueError(
            "You are trying to add a stationary point to a function that already has a stationary point."
        )

    stationary = pt.Point(is_basis=True)
    stationary.add_tag(name)
    zero_gradient = 0 * stationary
    zero_gradient.add_tag(f"gradient_{self.tag}({name})")

    triplet = self.add_point_with_grad_restriction(stationary, zero_gradient)
    context.add_stationary_triplet(self, triplet)
    return stationary
163
179
 
164
180
  # The following the old gradient(opt) = 0 constraint style.
@@ -180,42 +196,22 @@ class Function:
180
196
  raise RuntimeError("Did you forget to create a context?")
181
197
 
182
198
  if self.is_basis:
183
- generate_new_basis = True
184
- instances_of_point = 0
185
199
  for triplet in pep_context.triplets[self]:
186
200
  if triplet.point.uid == point.uid:
187
- instances_of_point += 1
188
- generate_new_basis = False
189
- previous_triplet = triplet
201
+ return triplet
190
202
 
191
- if generate_new_basis:
192
- function_value = sc.Scalar(is_basis=True)
193
- function_value.add_tag(f"{self.tag}({point.tag})")
194
- gradient = pt.Point(is_basis=True)
195
- gradient.add_tag(f"gradient_{self.tag}({point.tag})")
203
+ function_value = sc.Scalar(is_basis=True)
204
+ function_value.add_tag(f"{self.tag}({point.tag})")
205
+ gradient = pt.Point(is_basis=True)
206
+ gradient.add_tag(f"gradient_{self.tag}({point.tag})")
196
207
 
197
- new_triplet = Triplet(
198
- point,
199
- function_value,
200
- gradient,
201
- name=f"{point.tag}_{function_value.tag}_{gradient.tag}",
202
- )
203
- self.add_triplet_to_func(new_triplet)
204
- elif not generate_new_basis and self.reuse_gradient:
205
- function_value = previous_triplet.function_value
206
- gradient = previous_triplet.gradient
207
- elif not generate_new_basis and not self.reuse_gradient:
208
- function_value = previous_triplet.function_value
209
- gradient = pt.Point(is_basis=True)
210
- gradient.add_tag(f"gradient_{self.tag}({point.tag})")
211
-
212
- new_triplet = Triplet(
213
- point,
214
- previous_triplet.function_value,
215
- gradient,
216
- name=f"{point.tag}_{function_value.tag}_{gradient.tag}_{instances_of_point}",
217
- )
218
- self.add_triplet_to_func(new_triplet)
208
+ new_triplet = Triplet(
209
+ point,
210
+ function_value,
211
+ gradient,
212
+ name=f"{point.tag}_{function_value.tag}_{gradient.tag}",
213
+ )
214
+ self.add_triplet_to_func(new_triplet)
219
215
  else:
220
216
  if isinstance(self.composition, AddedFunc):
221
217
  left_triplet = self.composition.left_func.generate_triplet(point)
@@ -247,57 +243,69 @@ class Function:
247
243
  triplet = self.generate_triplet(point)
248
244
  return triplet.function_value
249
245
 
246
def __call__(self, point: pt.Point) -> sc.Scalar:
    """Evaluate the function at ``point``; shorthand for ``function_value``."""
    return self.function_value(point)

def _tag_as_operand(self) -> str:
    """Return ``self.tag``, wrapped in parentheses when this function is a sum.

    Used when the tag is embedded in a larger expression (subtraction,
    scaling, negation), so that e.g. ``2 * (f + g)`` is tagged ``"2*(f+g)"``
    rather than the misleading ``"2*f+g"``.
    """
    if isinstance(self.composition, AddedFunc):
        return f"({self.tag})"
    return self.tag

def __add__(self, other):
    """Return ``self + other`` as a new composite (non-basis) function."""
    assert isinstance(other, Function)
    return Function(
        is_basis=False,
        composition=AddedFunc(self, other),
        tags=[f"{self.tag}+{other.tag}"],
    )

def __sub__(self, other):
    """Return ``self - other``, represented as ``self + (-other)``."""
    assert isinstance(other, Function)
    tag_other = other._tag_as_operand()
    return Function(
        is_basis=False,
        composition=AddedFunc(self, -other),
        tags=[f"{self.tag}-{tag_other}"],
    )

def __mul__(self, other):
    """Return ``self`` scaled by the numeric factor ``other``."""
    assert utils.is_numerical(other)
    tag_self = self._tag_as_operand()
    return Function(
        is_basis=False,
        composition=ScaledFunc(scale=other, base_func=self),
        tags=[f"{other:.4g}*{tag_self}"],
    )

def __rmul__(self, other):
    """Return ``other * self``; scaling commutes, so this defers to ``__mul__``."""
    return self.__mul__(other)

def __neg__(self):
    """Return ``-self``, i.e. ``self`` scaled by -1."""
    tag_self = self._tag_as_operand()
    return Function(
        is_basis=False,
        composition=ScaledFunc(scale=-1, base_func=self),
        tags=[f"-{tag_self}"],
    )

def __truediv__(self, other):
    """Return ``self`` scaled by ``1 / other`` for a numeric ``other``."""
    assert utils.is_numerical(other)
    tag_self = self._tag_as_operand()
    return Function(
        is_basis=False,
        composition=ScaledFunc(scale=1 / other, base_func=self),
        tags=[f"1/{other:.4g}*{tag_self}"],
    )
302
310
 
303
311
  def __hash__(self):
@@ -309,10 +317,96 @@ class Function:
309
317
  return self.uid == other.uid
310
318
 
311
319
 
320
class ConvexFunction(Function):
    """A convex function with no smoothness assumption.

    Its interpolation constraints are only the subgradient inequalities
    between ordered pairs of distinct triplets. ``proximal_step`` introduces
    proximal points together with their subgradient and function value.
    """

    def __init__(
        self,
        is_basis=True,
        composition=None,
    ):
        super().__init__(is_basis=is_basis, composition=composition)

    def convex_interpolability_constraints(
        self, triplet_i: Triplet, triplet_j: Triplet
    ) -> Constraint:
        """Return the constraint ``f_j - f_i + <g_j, x_i - x_j> <= 0``.

        The constraint is named ``"<function tag>:<x_i tag>,<x_j tag>"``.
        """
        x_i, f_i = triplet_i.point, triplet_i.function_value
        x_j, f_j, g_j = triplet_j.point, triplet_j.function_value, triplet_j.gradient
        lhs = (f_j - f_i) + g_j * (x_i - x_j)
        return lhs.le(0, name=f"{self.tag}:{x_i.tag},{x_j.tag}")

    def get_interpolation_constraints(
        self, pep_context: pc.PEPContext | None = None
    ) -> list[Constraint]:
        """Return one constraint per ordered pair of distinct triplets of this function.

        Uses the current context when ``pep_context`` is not given.

        Raises:
            RuntimeError: if no context is given and none is current.
        """
        if pep_context is None:
            pep_context = pc.get_current_context()
        if pep_context is None:
            raise RuntimeError("Did you forget to create a context?")
        triplets = pep_context.triplets[self]
        return [
            self.convex_interpolability_constraints(t_i, t_j)
            for t_i in triplets
            for t_j in triplets
            if t_i != t_j
        ]

    def interpolate_ineq(
        self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
    ) -> sc.Scalar:
        """Generate the interpolation inequality scalar by tags.

        Looks up ``x1``/``x2`` by tag and ``f(x1)``, ``f(x2)`` and
        ``gradient_f(x2)`` by their derived tags. Returns
        ``f(x2) - f(x1) + <gradient_f(x2), x1 - x2>``.

        Raises:
            RuntimeError: if no context is given and none is current.
        """
        context = pep_context if pep_context is not None else pc.get_current_context()
        if context is None:
            raise RuntimeError("Did you forget to specify a context?")
        # TODO: we definitely need a more robust tag system
        x1 = context.get_by_tag(p1_tag)
        x2 = context.get_by_tag(p2_tag)
        f1 = context.get_by_tag(f"{self.tag}({p1_tag})")
        f2 = context.get_by_tag(f"{self.tag}({p2_tag})")
        g2 = context.get_by_tag(f"gradient_{self.tag}({p2_tag})")
        return f2 - f1 + g2 * (x1 - x2)

    def proximal_step(self, x_0: pt.Point, stepsize: numbers.Number) -> pt.Point:
        """Return the proximal point ``prox_{stepsize*f}(x_0)`` as a new point.

        The result is modelled as ``x = x_0 - stepsize * g``. Here ``g`` is a
        fresh basis point standing for a (sub)gradient of this function at
        ``x``, and a fresh basis scalar stands for the function value at
        ``x``. The triplet ``(x, value, g)`` is passed to
        ``add_triplet_to_func``, presumably so it takes part in the
        interpolation constraints.
        """
        prox_tag = f"prox_{{{stepsize}*{self.tag}}}({x_0.tag})"
        subgradient = pt.Point(is_basis=True)
        subgradient.add_tag(f"gradient_{self.tag}({prox_tag})")
        value = sc.Scalar(is_basis=True)
        value.add_tag(f"{self.tag}({prox_tag})")
        x = x_0 - stepsize * subgradient
        x.add_tag(prox_tag)
        triplet = Triplet(
            x,
            value,
            subgradient,
            name=f"{x.tag}_{value.tag}_{subgradient.tag}",
        )
        self.add_triplet_to_func(triplet)
        return x
398
+
399
+
312
400
  class SmoothConvexFunction(Function):
313
def __init__(
    self,
    L,
    is_basis=True,
    composition=None,
):
    """Create a smooth convex function with smoothness parameter ``L``.

    Args:
        L: Smoothness constant (presumably the gradient Lipschitz constant),
            stored as ``self.L``.
        is_basis: Forwarded to ``Function``.
        composition: Forwarded to ``Function``.
    """
    super().__init__(is_basis=is_basis, composition=composition)
    self.L = L
318
412
 
@@ -350,7 +444,7 @@ class SmoothConvexFunction(Function):
350
444
 
351
445
  def interpolate_ineq(
352
446
  self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
353
- ) -> pt.Scalar:
447
+ ) -> sc.Scalar:
354
448
  """Generate the interpolation inequality scalar by tags."""
355
449
  if pep_context is None:
356
450
  pep_context = pc.get_current_context()