pepflow 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/e2e_test.py +34 -0
- pepflow/function.py +44 -5
- pepflow/function_test.py +75 -0
- pepflow/pep.py +9 -1
- pepflow/pep_context.py +12 -3
- pepflow/pep_context_test.py +19 -21
- pepflow/point.py +54 -8
- pepflow/point_test.py +89 -131
- pepflow/scalar.py +50 -1
- pepflow/scalar_test.py +250 -0
- pepflow/solver_test.py +7 -7
- pepflow/utils.py +14 -1
- {pepflow-0.1.2.dist-info → pepflow-0.1.4.dist-info}/METADATA +1 -1
- pepflow-0.1.4.dist-info/RECORD +24 -0
- pepflow-0.1.2.dist-info/RECORD +0 -22
- {pepflow-0.1.2.dist-info → pepflow-0.1.4.dist-info}/WHEEL +0 -0
- {pepflow-0.1.2.dist-info → pepflow-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.2.dist-info → pepflow-0.1.4.dist-info}/top_level.txt +0 -0
pepflow/e2e_test.py
ADDED
@@ -0,0 +1,34 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
from pepflow import function, pep
|
4
|
+
from pepflow import pep_context as pc
|
5
|
+
|
6
|
+
|
7
|
+
def test_gd_e2e():
|
8
|
+
ctx = pc.PEPContext("gd").set_as_current()
|
9
|
+
pep_builder = pep.PEPBuilder()
|
10
|
+
eta = 1
|
11
|
+
N = 9
|
12
|
+
|
13
|
+
f = pep_builder.declare_func(function.SmoothConvexFunction, L=1)
|
14
|
+
f.add_tag("f")
|
15
|
+
x = pep_builder.set_init_point("x_0")
|
16
|
+
x_star = f.add_stationary_point("x_star")
|
17
|
+
pep_builder.set_initial_constraint(
|
18
|
+
((x - x_star) ** 2).le(1, name="initial_condition")
|
19
|
+
)
|
20
|
+
|
21
|
+
# We first build the algorithm with the largest number of iterations.
|
22
|
+
for i in range(N):
|
23
|
+
x = x - eta * f.gradient(x)
|
24
|
+
x.add_tag(f"x_{i + 1}")
|
25
|
+
|
26
|
+
# To achieve the sweep, we can just update the performance_metric.
|
27
|
+
for i in range(1, N + 1):
|
28
|
+
p = ctx.get_by_tag(f"x_{i}")
|
29
|
+
pep_builder.set_performance_metric(
|
30
|
+
f.function_value(p) - f.function_value(x_star)
|
31
|
+
)
|
32
|
+
result = pep_builder.solve()
|
33
|
+
expected_opt_value = 1 / (4 * i + 2)
|
34
|
+
assert math.isclose(result.primal_opt_value, expected_opt_value, rel_tol=1e-3)
|
pepflow/function.py
CHANGED
@@ -87,6 +87,10 @@ class Function:
|
|
87
87
|
return self.tag
|
88
88
|
return super().__repr__()
|
89
89
|
|
90
|
+
def _repr_latex_(self):
|
91
|
+
s = repr(self)
|
92
|
+
return rf"$\\displaystyle {s}$"
|
93
|
+
|
90
94
|
def get_interpolation_constraints(self):
|
91
95
|
raise NotImplementedError(
|
92
96
|
"This method should be implemented in the children class."
|
@@ -247,6 +251,9 @@ class Function:
|
|
247
251
|
triplet = self.generate_triplet(point)
|
248
252
|
return triplet.function_value
|
249
253
|
|
254
|
+
def __call__(self, point: pt.Point) -> sc.Scalar:
|
255
|
+
return self.function_value(point)
|
256
|
+
|
250
257
|
def __add__(self, other):
|
251
258
|
assert isinstance(other, Function)
|
252
259
|
return Function(
|
@@ -258,46 +265,61 @@ class Function:
|
|
258
265
|
|
259
266
|
def __sub__(self, other):
|
260
267
|
assert isinstance(other, Function)
|
268
|
+
tag_other = other.tag
|
269
|
+
if isinstance(other.composition, AddedFunc):
|
270
|
+
tag_other = f"({other.tag})"
|
261
271
|
return Function(
|
262
272
|
is_basis=False,
|
263
273
|
reuse_gradient=self.reuse_gradient and other.reuse_gradient,
|
264
274
|
composition=AddedFunc(self, -other),
|
265
|
-
tags=[f"{self.tag}-{
|
275
|
+
tags=[f"{self.tag}-{tag_other}"],
|
266
276
|
)
|
267
277
|
|
268
278
|
def __mul__(self, other):
|
269
279
|
assert utils.is_numerical(other)
|
280
|
+
tag_self = self.tag
|
281
|
+
if isinstance(self.composition, AddedFunc):
|
282
|
+
tag_self = f"({self.tag})"
|
270
283
|
return Function(
|
271
284
|
is_basis=False,
|
272
285
|
reuse_gradient=self.reuse_gradient,
|
273
286
|
composition=ScaledFunc(scale=other, base_func=self),
|
274
|
-
tags=[f"{other:.4g}*{
|
287
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
275
288
|
)
|
276
289
|
|
277
290
|
def __rmul__(self, other):
|
278
291
|
assert utils.is_numerical(other)
|
292
|
+
tag_self = self.tag
|
293
|
+
if isinstance(self.composition, AddedFunc):
|
294
|
+
tag_self = f"({self.tag})"
|
279
295
|
return Function(
|
280
296
|
is_basis=False,
|
281
297
|
reuse_gradient=self.reuse_gradient,
|
282
298
|
composition=ScaledFunc(scale=other, base_func=self),
|
283
|
-
tags=[f"{other:.4g}*{
|
299
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
284
300
|
)
|
285
301
|
|
286
302
|
def __neg__(self):
|
303
|
+
tag_self = self.tag
|
304
|
+
if isinstance(self.composition, AddedFunc):
|
305
|
+
tag_self = f"({self.tag})"
|
287
306
|
return Function(
|
288
307
|
is_basis=False,
|
289
308
|
reuse_gradient=self.reuse_gradient,
|
290
309
|
composition=ScaledFunc(scale=-1, base_func=self),
|
291
|
-
tags=[f"-{
|
310
|
+
tags=[f"-{tag_self}"],
|
292
311
|
)
|
293
312
|
|
294
313
|
def __truediv__(self, other):
|
295
314
|
assert utils.is_numerical(other)
|
315
|
+
tag_self = self.tag
|
316
|
+
if isinstance(self.composition, AddedFunc):
|
317
|
+
tag_self = f"({self.tag})"
|
296
318
|
return Function(
|
297
319
|
is_basis=False,
|
298
320
|
reuse_gradient=self.reuse_gradient,
|
299
321
|
composition=ScaledFunc(scale=1 / other, base_func=self),
|
300
|
-
tags=[f"1/{other:.4g}*{
|
322
|
+
tags=[f"1/{other:.4g}*{tag_self}"],
|
301
323
|
)
|
302
324
|
|
303
325
|
def __hash__(self):
|
@@ -347,3 +369,20 @@ class SmoothConvexFunction(Function):
|
|
347
369
|
self.smooth_convex_interpolability_constraints(i, j)
|
348
370
|
)
|
349
371
|
return interpolation_constraints
|
372
|
+
|
373
|
+
def interpolate_ineq(
|
374
|
+
self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
|
375
|
+
) -> pt.Scalar:
|
376
|
+
"""Generate the interpolation inequality scalar by tags."""
|
377
|
+
if pep_context is None:
|
378
|
+
pep_context = pc.get_current_context()
|
379
|
+
if pep_context is None:
|
380
|
+
raise RuntimeError("Did you forget to specify a context?")
|
381
|
+
# TODO: we definitely need a more robust tag system
|
382
|
+
x1 = pep_context.get_by_tag(p1_tag)
|
383
|
+
x2 = pep_context.get_by_tag(p2_tag)
|
384
|
+
f1 = pep_context.get_by_tag(f"{self.tag}({p1_tag})")
|
385
|
+
f2 = pep_context.get_by_tag(f"{self.tag}({p2_tag})")
|
386
|
+
g1 = pep_context.get_by_tag(f"gradient_{self.tag}({p1_tag})")
|
387
|
+
g2 = pep_context.get_by_tag(f"gradient_{self.tag}({p2_tag})")
|
388
|
+
return f2 - f1 + g2 * (x1 - x2) + 1 / 2 * (g1 - g2) ** 2
|
pepflow/function_test.py
CHANGED
@@ -17,11 +17,86 @@
|
|
17
17
|
# specific language governing permissions and limitations
|
18
18
|
# under the License.
|
19
19
|
|
20
|
+
from typing import Iterator
|
21
|
+
|
20
22
|
import numpy as np
|
23
|
+
import pytest
|
21
24
|
|
22
25
|
from pepflow import expression_manager as exm
|
23
26
|
from pepflow import function as fc
|
24
27
|
from pepflow import pep as pep
|
28
|
+
from pepflow import pep_context as pc
|
29
|
+
from pepflow import point
|
30
|
+
|
31
|
+
|
32
|
+
@pytest.fixture
|
33
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
34
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
35
|
+
ctx = pc.PEPContext("test").set_as_current()
|
36
|
+
yield ctx
|
37
|
+
pc.set_current_context(None)
|
38
|
+
|
39
|
+
|
40
|
+
def test_function_add_tag(pep_context: pc.PEPContext) -> None:
|
41
|
+
f1 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f1"])
|
42
|
+
f2 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f2"])
|
43
|
+
|
44
|
+
f_add = f1 + f2
|
45
|
+
assert f_add.tag == "f1+f2"
|
46
|
+
|
47
|
+
f_sub = f1 - f2
|
48
|
+
assert f_sub.tag == "f1-f2"
|
49
|
+
|
50
|
+
f_sub = f1 - (f2 + f1)
|
51
|
+
assert f_sub.tag == "f1-(f2+f1)"
|
52
|
+
|
53
|
+
f_sub = f1 - (f2 - f1)
|
54
|
+
assert f_sub.tag == "f1-(f2-f1)"
|
55
|
+
|
56
|
+
|
57
|
+
def test_function_mul_tag(pep_context: pc.PEPContext) -> None:
|
58
|
+
f = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
|
59
|
+
|
60
|
+
f_mul = f * 0.1
|
61
|
+
assert f_mul.tag == "0.1*f"
|
62
|
+
|
63
|
+
f_rmul = 0.1 * f
|
64
|
+
assert f_rmul.tag == "0.1*f"
|
65
|
+
|
66
|
+
f_neg = -f
|
67
|
+
assert f_neg.tag == "-f"
|
68
|
+
|
69
|
+
f_truediv = f / 0.1
|
70
|
+
assert f_truediv.tag == "1/0.1*f"
|
71
|
+
|
72
|
+
|
73
|
+
def test_function_add_and_mul_tag(pep_context: pc.PEPContext) -> None:
|
74
|
+
f1 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f1"])
|
75
|
+
f2 = fc.Function(is_basis=True, reuse_gradient=False, tags=["f2"])
|
76
|
+
|
77
|
+
f_add_mul = (f1 + f2) * 0.1
|
78
|
+
assert f_add_mul.tag == "0.1*(f1+f2)"
|
79
|
+
|
80
|
+
f_add_mul = f1 + f2 * 0.1
|
81
|
+
assert f_add_mul.tag == "f1+0.1*f2"
|
82
|
+
|
83
|
+
f_neg_add = -(f1 + f2)
|
84
|
+
assert f_neg_add.tag == "-(f1+f2)"
|
85
|
+
|
86
|
+
f_rmul_add = 0.1 * (f1 + f2)
|
87
|
+
assert f_rmul_add.tag == "0.1*(f1+f2)"
|
88
|
+
|
89
|
+
f_rmul_add = f1 + 5 * (f2 + 3 * f1)
|
90
|
+
assert f_rmul_add.tag == "f1+5*(f2+3*f1)"
|
91
|
+
|
92
|
+
f_multiple_add = f1 + f1 + f1 + f1 + f1 + f1
|
93
|
+
assert f_multiple_add.tag == "f1+f1+f1+f1+f1+f1"
|
94
|
+
|
95
|
+
|
96
|
+
def test_function_call(pep_context: pc.PEPContext) -> None:
|
97
|
+
f = fc.Function(is_basis=True, reuse_gradient=False, tags=["f"])
|
98
|
+
x = point.Point(is_basis=True, eval_expression=None, tags=["x"])
|
99
|
+
assert f.function_value(x) == f(x)
|
25
100
|
|
26
101
|
|
27
102
|
def test_function_repr():
|
pepflow/pep.py
CHANGED
@@ -87,6 +87,13 @@ class PEPBuilder:
|
|
87
87
|
# We should think about a better choice like manager.
|
88
88
|
self.relaxed_constraints = []
|
89
89
|
|
90
|
+
def clear_setup(self):
|
91
|
+
self.init_conditions.clear()
|
92
|
+
self.functions.clear()
|
93
|
+
self.interpolation_constraints.clear()
|
94
|
+
self.performance_metric = None
|
95
|
+
self.relaxed_constraints.clear()
|
96
|
+
|
90
97
|
@contextlib.contextmanager
|
91
98
|
def make_context(
|
92
99
|
self, name: str, override: bool = False
|
@@ -94,7 +101,8 @@ class PEPBuilder:
|
|
94
101
|
if not override and name in self.pep_context_dict:
|
95
102
|
raise KeyError(f"There is already a context {name} in the builder")
|
96
103
|
try:
|
97
|
-
|
104
|
+
self.clear_setup()
|
105
|
+
ctx = pc.PEPContext(name)
|
98
106
|
self.pep_context_dict[name] = ctx
|
99
107
|
pc.set_current_context(ctx)
|
100
108
|
yield ctx
|
pepflow/pep_context.py
CHANGED
@@ -33,6 +33,8 @@ if TYPE_CHECKING:
|
|
33
33
|
|
34
34
|
# A global variable for storing the current context that is used for points or scalars.
|
35
35
|
CURRENT_CONTEXT: PEPContext | None = None
|
36
|
+
# Keep the track of all previous created context
|
37
|
+
GLOBAL_CONTEXT_DICT: dict[str, PEPContext] = {}
|
36
38
|
|
37
39
|
|
38
40
|
def get_current_context() -> PEPContext | None:
|
@@ -46,12 +48,19 @@ def set_current_context(ctx: PEPContext | None):
|
|
46
48
|
|
47
49
|
|
48
50
|
class PEPContext:
|
49
|
-
def __init__(self):
|
51
|
+
def __init__(self, name: str):
|
52
|
+
self.name = name
|
50
53
|
self.points: list[Point] = []
|
51
54
|
self.scalars: list[Scalar] = []
|
52
55
|
self.triplets: dict[Function, list[Triplet]] = defaultdict(list)
|
53
56
|
self.opt_conditions: dict[Function, list[Constraint]] = defaultdict(list)
|
54
57
|
|
58
|
+
GLOBAL_CONTEXT_DICT[name] = self
|
59
|
+
|
60
|
+
def set_as_current(self) -> PEPContext:
|
61
|
+
set_current_context(self)
|
62
|
+
return self
|
63
|
+
|
55
64
|
def add_point(self, point: Point):
|
56
65
|
self.points.append(point)
|
57
66
|
|
@@ -66,10 +75,10 @@ class PEPContext:
|
|
66
75
|
|
67
76
|
def get_by_tag(self, tag: str) -> Point | Scalar:
|
68
77
|
for p in self.points:
|
69
|
-
if
|
78
|
+
if tag in p.tags:
|
70
79
|
return p
|
71
80
|
for s in self.scalars:
|
72
|
-
if
|
81
|
+
if tag in s.tags:
|
73
82
|
return s
|
74
83
|
raise ValueError("Cannot find the point or scalar of given tag")
|
75
84
|
|
pepflow/pep_context_test.py
CHANGED
@@ -17,17 +17,25 @@
|
|
17
17
|
# specific language governing permissions and limitations
|
18
18
|
# under the License.
|
19
19
|
|
20
|
+
from typing import Iterator
|
21
|
+
|
20
22
|
import pandas as pd
|
23
|
+
import pytest
|
21
24
|
|
22
25
|
from pepflow import pep_context as pc
|
23
26
|
from pepflow.function import SmoothConvexFunction
|
24
27
|
from pepflow.point import Point
|
25
28
|
|
26
29
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
+
@pytest.fixture
|
31
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
32
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
33
|
+
ctx = pc.PEPContext("test").set_as_current()
|
34
|
+
yield ctx
|
35
|
+
pc.set_current_context(None)
|
36
|
+
|
30
37
|
|
38
|
+
def test_tracked_points(pep_context: pc.PEPContext):
|
31
39
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
32
40
|
f.add_tag("f")
|
33
41
|
|
@@ -41,16 +49,11 @@ def test_tracked_points():
|
|
41
49
|
_ = f.generate_triplet(p3)
|
42
50
|
_ = f.generate_triplet(p_star)
|
43
51
|
|
44
|
-
assert
|
45
|
-
assert
|
46
|
-
|
47
|
-
pc.set_current_context(None)
|
52
|
+
assert pep_context.order_of_point(f) == ["x_1", "x_2", "x_3", "x_*"]
|
53
|
+
assert pep_context.tracked_point(f) == [p1, p3, p2, p_star]
|
48
54
|
|
49
55
|
|
50
|
-
def test_triplets_to_dataframe():
|
51
|
-
ctx = pc.PEPContext()
|
52
|
-
pc.set_current_context(ctx)
|
53
|
-
|
56
|
+
def test_triplets_to_dataframe(pep_context: pc.PEPContext):
|
54
57
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
55
58
|
f.add_tag("f")
|
56
59
|
|
@@ -62,7 +65,7 @@ def test_triplets_to_dataframe():
|
|
62
65
|
_ = f.generate_triplet(p2)
|
63
66
|
_ = f.generate_triplet(p3)
|
64
67
|
|
65
|
-
func_to_df, func_to_order =
|
68
|
+
func_to_df, func_to_order = pep_context.triplets_to_df_and_order()
|
66
69
|
expected_df = pd.DataFrame(
|
67
70
|
{
|
68
71
|
"constraint_name": [
|
@@ -83,20 +86,15 @@ def test_triplets_to_dataframe():
|
|
83
86
|
pd.testing.assert_frame_equal(func_to_df[f], expected_df)
|
84
87
|
assert func_to_order[f] == ["x1", "x2", "x3"]
|
85
88
|
|
86
|
-
pc.set_current_context(None)
|
87
|
-
|
88
|
-
|
89
|
-
def test_get_by_tag():
|
90
|
-
ctx = pc.PEPContext()
|
91
|
-
pc.set_current_context(ctx)
|
92
89
|
|
90
|
+
def test_get_by_tag(pep_context: pc.PEPContext):
|
93
91
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
94
92
|
f.add_tag("f")
|
95
93
|
p1 = Point(is_basis=True, tags=["x1"])
|
96
94
|
|
97
95
|
triplet = f.generate_triplet(p1)
|
98
96
|
|
99
|
-
assert
|
100
|
-
assert
|
101
|
-
assert
|
97
|
+
assert pep_context.get_by_tag("x1") == p1
|
98
|
+
assert pep_context.get_by_tag("f(x1)") == triplet.function_value
|
99
|
+
assert pep_context.get_by_tag("gradient_f(x1)") == triplet.gradient
|
102
100
|
pc.set_current_context(None)
|
pepflow/point.py
CHANGED
@@ -134,75 +134,110 @@ class Point:
|
|
134
134
|
return self.tag
|
135
135
|
return super().__repr__()
|
136
136
|
|
137
|
+
def _repr_latex_(self):
|
138
|
+
s = repr(self)
|
139
|
+
s = s.replace("star", r"\star")
|
140
|
+
s = s.replace("gradient_", r"\nabla ")
|
141
|
+
s = s.replace("|", r"\|")
|
142
|
+
return rf"$\\displaystyle {s}$"
|
143
|
+
|
137
144
|
# TODO: add a validator that `is_basis` and `eval_expression` are properly setup.
|
138
145
|
def __add__(self, other):
|
139
|
-
assert
|
146
|
+
assert isinstance(other, Point)
|
140
147
|
return Point(
|
141
148
|
is_basis=False,
|
142
149
|
eval_expression=EvalExpressionPoint(utils.Op.ADD, self, other),
|
150
|
+
tags=[f"{self.tag}+{other.tag}"],
|
143
151
|
)
|
144
152
|
|
145
153
|
def __radd__(self, other):
|
146
|
-
|
154
|
+
# TODO: come up with better way to handle this
|
155
|
+
if other == 0:
|
156
|
+
return self
|
157
|
+
assert isinstance(other, Point)
|
147
158
|
return Point(
|
148
159
|
is_basis=False,
|
149
160
|
eval_expression=EvalExpressionPoint(utils.Op.ADD, other, self),
|
161
|
+
tags=[f"{other.tag}+{self.tag}"],
|
150
162
|
)
|
151
163
|
|
152
164
|
def __sub__(self, other):
|
153
|
-
assert
|
165
|
+
assert isinstance(other, Point)
|
166
|
+
tag_other = utils.parenthesize_tag(other)
|
154
167
|
return Point(
|
155
168
|
is_basis=False,
|
156
169
|
eval_expression=EvalExpressionPoint(utils.Op.SUB, self, other),
|
170
|
+
tags=[f"{self.tag}-{tag_other}"],
|
157
171
|
)
|
158
172
|
|
159
173
|
def __rsub__(self, other):
|
160
|
-
assert
|
174
|
+
assert isinstance(other, Point)
|
175
|
+
tag_self = utils.parenthesize_tag(self)
|
161
176
|
return Point(
|
162
177
|
is_basis=False,
|
163
178
|
eval_expression=EvalExpressionPoint(utils.Op.SUB, other, self),
|
179
|
+
tags=[f"{other.tag}-{tag_self}"],
|
164
180
|
)
|
165
181
|
|
166
182
|
def __mul__(self, other):
|
167
183
|
# TODO allow the other to be point so that we return a scalar.
|
168
184
|
assert is_numerical_or_point(other)
|
185
|
+
tag_self = utils.parenthesize_tag(self)
|
169
186
|
if utils.is_numerical(other):
|
170
187
|
return Point(
|
171
188
|
is_basis=False,
|
172
189
|
eval_expression=EvalExpressionPoint(utils.Op.MUL, self, other),
|
190
|
+
tags=[f"{tag_self}*{other:.4g}"],
|
173
191
|
)
|
174
192
|
else:
|
193
|
+
tag_other = utils.parenthesize_tag(other)
|
175
194
|
return Scalar(
|
176
195
|
is_basis=False,
|
177
|
-
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
|
196
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
|
197
|
+
tags=[f"{tag_self}*{tag_other}"],
|
178
198
|
)
|
179
199
|
|
180
200
|
def __rmul__(self, other):
|
181
201
|
# TODO allow the other to be point so that we return a scalar.
|
182
202
|
assert is_numerical_or_point(other)
|
203
|
+
tag_self = utils.parenthesize_tag(self)
|
183
204
|
if utils.is_numerical(other):
|
184
205
|
return Point(
|
185
206
|
is_basis=False,
|
186
207
|
eval_expression=EvalExpressionPoint(utils.Op.MUL, other, self),
|
208
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
187
209
|
)
|
188
210
|
else:
|
211
|
+
tag_other = utils.parenthesize_tag(other)
|
189
212
|
return Scalar(
|
190
213
|
is_basis=False,
|
191
|
-
eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
|
214
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
|
215
|
+
tags=[f"{tag_other}*{tag_self}"],
|
192
216
|
)
|
193
217
|
|
194
218
|
def __pow__(self, power):
|
195
219
|
assert power == 2
|
196
|
-
return
|
220
|
+
return Scalar(
|
221
|
+
is_basis=False,
|
222
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, self),
|
223
|
+
tags=[rf"|{self.tag}|^{power}"],
|
224
|
+
)
|
197
225
|
|
198
226
|
def __neg__(self):
|
199
|
-
|
227
|
+
tag_self = utils.parenthesize_tag(self)
|
228
|
+
return Point(
|
229
|
+
is_basis=False,
|
230
|
+
eval_expression=EvalExpressionPoint(utils.Op.MUL, -1, self),
|
231
|
+
tags=[f"-{tag_self}"],
|
232
|
+
)
|
200
233
|
|
201
234
|
def __truediv__(self, other):
|
202
235
|
assert utils.is_numerical(other)
|
236
|
+
tag_self = utils.parenthesize_tag(self)
|
203
237
|
return Point(
|
204
238
|
is_basis=False,
|
205
239
|
eval_expression=EvalExpressionPoint(utils.Op.DIV, self, other),
|
240
|
+
tags=[f"1/{other:.4g}*{tag_self}"],
|
206
241
|
)
|
207
242
|
|
208
243
|
def __hash__(self):
|
@@ -212,3 +247,14 @@ class Point:
|
|
212
247
|
if not isinstance(other, Point):
|
213
248
|
return NotImplemented
|
214
249
|
return self.uid == other.uid
|
250
|
+
|
251
|
+
def eval(self, ctx: pc.PEPContext | None = None) -> np.ndarray:
|
252
|
+
from pepflow.expression_manager import ExpressionManager
|
253
|
+
|
254
|
+
# Note this can be inefficient.
|
255
|
+
if ctx is None:
|
256
|
+
ctx = pc.get_current_context()
|
257
|
+
if ctx is None:
|
258
|
+
raise RuntimeError("Did you forget to create a context?")
|
259
|
+
em = ExpressionManager(ctx)
|
260
|
+
return em.eval_point(self).vector
|
pepflow/point_test.py
CHANGED
@@ -18,47 +18,109 @@
|
|
18
18
|
# under the License.
|
19
19
|
|
20
20
|
import time
|
21
|
+
from typing import Iterator
|
21
22
|
|
22
23
|
import numpy as np
|
24
|
+
import pytest
|
23
25
|
|
24
26
|
from pepflow import expression_manager as exm
|
25
27
|
from pepflow import function as fc
|
26
28
|
from pepflow import pep as pep
|
27
|
-
from pepflow import
|
29
|
+
from pepflow import pep_context as pc
|
30
|
+
from pepflow import point, scalar
|
28
31
|
|
29
32
|
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
33
|
+
@pytest.fixture
|
34
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
35
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
36
|
+
ctx = pc.PEPContext("test").set_as_current()
|
37
|
+
yield ctx
|
38
|
+
pc.set_current_context(None)
|
39
|
+
|
40
|
+
|
41
|
+
def test_point_add_tag(pep_context: pc.PEPContext) -> None:
|
42
|
+
p1 = point.Point(is_basis=True, eval_expression=None, tags=["p1"])
|
43
|
+
p2 = point.Point(is_basis=True, eval_expression=None, tags=["p2"])
|
44
|
+
|
45
|
+
p_add = p1 + p2
|
46
|
+
assert p_add.tag == "p1+p2"
|
47
|
+
|
48
|
+
p_sub = p1 - p2
|
49
|
+
assert p_sub.tag == "p1-p2"
|
50
|
+
|
51
|
+
p_sub = p1 - (p2 + p1)
|
52
|
+
assert p_sub.tag == "p1-(p2+p1)"
|
53
|
+
|
54
|
+
p_sub = p1 - (p2 - p1)
|
55
|
+
assert p_sub.tag == "p1-(p2-p1)"
|
56
|
+
|
57
|
+
|
58
|
+
def test_point_mul_tag(pep_context: pc.PEPContext) -> None:
|
59
|
+
p = point.Point(is_basis=True, eval_expression=None, tags=["p"])
|
60
|
+
|
61
|
+
p_mul = p * 0.1
|
62
|
+
assert p_mul.tag == "p*0.1"
|
63
|
+
|
64
|
+
p_rmul = 0.1 * p
|
65
|
+
assert p_rmul.tag == "0.1*p"
|
66
|
+
|
67
|
+
p_pow = p**2
|
68
|
+
assert p_pow.tag == "|p|^2"
|
69
|
+
|
70
|
+
p_neg = -p
|
71
|
+
assert p_neg.tag == "-p"
|
72
|
+
|
73
|
+
p_truediv = p / 0.1
|
74
|
+
assert p_truediv.tag == "1/0.1*p"
|
75
|
+
|
76
|
+
|
77
|
+
def test_point_add_and_mul_tag(pep_context: pc.PEPContext) -> None:
|
78
|
+
p1 = point.Point(is_basis=True, eval_expression=None, tags=["p1"])
|
79
|
+
p2 = point.Point(is_basis=True, eval_expression=None, tags=["p2"])
|
80
|
+
|
81
|
+
p_add_mul = (p1 + p2) * 0.1
|
82
|
+
assert p_add_mul.tag == "(p1+p2)*0.1"
|
83
|
+
|
84
|
+
p_add_mul = (p1 + p2) * (p1 + p2)
|
85
|
+
assert p_add_mul.tag == "(p1+p2)*(p1+p2)"
|
86
|
+
|
87
|
+
p_add_pow = (p1 + p2) ** 2
|
88
|
+
assert p_add_pow.tag == "|p1+p2|^2"
|
89
|
+
|
90
|
+
p_add_mul = p1 + p2 * 0.1
|
91
|
+
assert p_add_mul.tag == "p1+p2*0.1"
|
92
|
+
|
93
|
+
p_neg_add = -(p1 + p2)
|
94
|
+
assert p_neg_add.tag == "-(p1+p2)"
|
95
|
+
|
96
|
+
p_rmul_add = 0.1 * (p1 + p2)
|
97
|
+
assert p_rmul_add.tag == "0.1*(p1+p2)"
|
98
|
+
|
99
|
+
|
100
|
+
def test_point_hash_different(pep_context: pc.PEPContext) -> None:
|
101
|
+
p1 = point.Point(is_basis=True, eval_expression=None)
|
102
|
+
p2 = point.Point(is_basis=True, eval_expression=None)
|
35
103
|
assert p1.uid != p2.uid
|
36
104
|
|
37
105
|
|
38
|
-
def test_scalar_hash_different():
|
39
|
-
|
40
|
-
|
41
|
-
s1 = scalar.Scalar(is_basis=True, eval_expression=None)
|
42
|
-
s2 = scalar.Scalar(is_basis=True, eval_expression=None)
|
106
|
+
def test_scalar_hash_different(pep_context: pc.PEPContext) -> None:
|
107
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None)
|
108
|
+
s2 = scalar.Scalar(is_basis=True, eval_expression=None)
|
43
109
|
assert s1.uid != s2.uid
|
44
110
|
|
45
111
|
|
46
|
-
def test_point_tag():
|
47
|
-
|
48
|
-
|
49
|
-
p1 = point.Point(is_basis=True, eval_expression=None)
|
50
|
-
p1.add_tag(tag="my_tag")
|
112
|
+
def test_point_tag(pep_context: pc.PEPContext) -> None:
|
113
|
+
p1 = point.Point(is_basis=True, eval_expression=None)
|
114
|
+
p1.add_tag(tag="my_tag")
|
51
115
|
assert p1.tags == ["my_tag"]
|
52
116
|
assert p1.tag == "my_tag"
|
53
117
|
|
54
118
|
|
55
|
-
def test_point_repr():
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
p1.add_tag("my_tag")
|
61
|
-
assert str(p1) == "my_tag"
|
119
|
+
def test_point_repr(pep_context: pc.PEPContext) -> None:
|
120
|
+
p1 = point.Point(is_basis=True)
|
121
|
+
assert str(p1) is not None # it should be fine without tag
|
122
|
+
p1.add_tag("my_tag")
|
123
|
+
assert str(p1) == "my_tag"
|
62
124
|
|
63
125
|
|
64
126
|
def test_scalar_tag():
|
@@ -73,7 +135,7 @@ def test_scalar_tag():
|
|
73
135
|
def test_scalar_repr():
|
74
136
|
pep_builder = pep.PEPBuilder()
|
75
137
|
with pep_builder.make_context("test"):
|
76
|
-
s1 = scalar.Scalar(is_basis=True)
|
138
|
+
s1 = scalar.Scalar(is_basis=True, tags=["s1"])
|
77
139
|
print(s1) # it should be fine without tag
|
78
140
|
s1.add_tag("my_tag")
|
79
141
|
assert str(s1) == "my_tag"
|
@@ -136,8 +198,8 @@ def test_expression_manager_on_basis_scalar():
|
|
136
198
|
def test_expression_manager_eval_point():
|
137
199
|
pep_builder = pep.PEPBuilder()
|
138
200
|
with pep_builder.make_context("test") as ctx:
|
139
|
-
p1 = point.Point(is_basis=True)
|
140
|
-
p2 = point.Point(is_basis=True)
|
201
|
+
p1 = point.Point(is_basis=True, tags=["p1"])
|
202
|
+
p2 = point.Point(is_basis=True, tags=["p2"])
|
141
203
|
p3 = 2 * p1 + p2 / 4
|
142
204
|
p4 = p3 + p1
|
143
205
|
|
@@ -146,114 +208,10 @@ def test_expression_manager_eval_point():
|
|
146
208
|
np.testing.assert_allclose(pm.eval_point(p4).vector, np.array([3, 0.25]))
|
147
209
|
|
148
210
|
|
149
|
-
def test_expression_manager_eval_scalar():
|
150
|
-
pep_builder = pep.PEPBuilder()
|
151
|
-
with pep_builder.make_context("test") as ctx:
|
152
|
-
s1 = scalar.Scalar(is_basis=True)
|
153
|
-
s2 = scalar.Scalar(is_basis=True)
|
154
|
-
s3 = 2 * s1 + s2 / 4 + 5
|
155
|
-
s4 = s3 + s1
|
156
|
-
s5 = s4 + 5
|
157
|
-
|
158
|
-
p1 = point.Point(is_basis=True)
|
159
|
-
p2 = point.Point(is_basis=True)
|
160
|
-
s6 = p1 * p2
|
161
|
-
|
162
|
-
p3 = point.Point(is_basis=True)
|
163
|
-
p4 = point.Point(is_basis=True)
|
164
|
-
s7 = 5 * p3 * p4
|
165
|
-
|
166
|
-
s8 = s6 + s7
|
167
|
-
|
168
|
-
pm = exm.ExpressionManager(ctx)
|
169
|
-
|
170
|
-
np.testing.assert_allclose(pm.eval_scalar(s3).vector, np.array([2, 0.25]))
|
171
|
-
np.testing.assert_allclose(pm.eval_scalar(s3).constant, 5)
|
172
|
-
np.testing.assert_allclose(pm.eval_scalar(s4).vector, np.array([3, 0.25]))
|
173
|
-
np.testing.assert_allclose(pm.eval_scalar(s5).vector, np.array([3, 0.25]))
|
174
|
-
np.testing.assert_allclose(pm.eval_scalar(s5).constant, 10)
|
175
|
-
|
176
|
-
np.testing.assert_allclose(pm.eval_point(p1).vector, np.array([1, 0, 0, 0]))
|
177
|
-
np.testing.assert_allclose(pm.eval_point(p2).vector, np.array([0, 1, 0, 0]))
|
178
|
-
np.testing.assert_allclose(pm.eval_point(p3).vector, np.array([0, 0, 1, 0]))
|
179
|
-
np.testing.assert_allclose(pm.eval_point(p4).vector, np.array([0, 0, 0, 1]))
|
180
|
-
|
181
|
-
np.testing.assert_allclose(
|
182
|
-
pm.eval_scalar(s6).matrix,
|
183
|
-
np.array(
|
184
|
-
[
|
185
|
-
[0.0, 0.5, 0.0, 0.0],
|
186
|
-
[0.5, 0.0, 0.0, 0.0],
|
187
|
-
[0.0, 0.0, 0.0, 0.0],
|
188
|
-
[0.0, 0.0, 0.0, 0.0],
|
189
|
-
]
|
190
|
-
),
|
191
|
-
)
|
192
|
-
np.testing.assert_allclose(
|
193
|
-
pm.eval_scalar(s7).matrix,
|
194
|
-
np.array(
|
195
|
-
[
|
196
|
-
[0.0, 0.0, 0.0, 0.0],
|
197
|
-
[0.0, 0.0, 0.0, 0.0],
|
198
|
-
[0.0, 0.0, 0.0, 2.5],
|
199
|
-
[0.0, 0.0, 2.5, 0.0],
|
200
|
-
]
|
201
|
-
),
|
202
|
-
)
|
203
|
-
|
204
|
-
np.testing.assert_allclose(
|
205
|
-
pm.eval_scalar(s8).matrix,
|
206
|
-
np.array(
|
207
|
-
[
|
208
|
-
[0.0, 0.5, 0.0, 0.0],
|
209
|
-
[0.5, 0.0, 0.0, 0.0],
|
210
|
-
[0.0, 0.0, 0.0, 2.5],
|
211
|
-
[0.0, 0.0, 2.5, 0.0],
|
212
|
-
]
|
213
|
-
),
|
214
|
-
)
|
215
|
-
|
216
|
-
|
217
|
-
def test_constraint():
|
218
|
-
pep_builder = pep.PEPBuilder()
|
219
|
-
with pep_builder.make_context("test") as ctx:
|
220
|
-
s1 = scalar.Scalar(is_basis=True)
|
221
|
-
s2 = scalar.Scalar(is_basis=True)
|
222
|
-
s3 = 2 * s1 + s2 / 4 + 5
|
223
|
-
|
224
|
-
c1 = s3.le(5, name="c1")
|
225
|
-
c2 = s3.lt(5, name="c2")
|
226
|
-
c3 = s3.ge(5, name="c3")
|
227
|
-
c4 = s3.gt(5, name="c4")
|
228
|
-
c5 = s3.eq(5, name="c5")
|
229
|
-
|
230
|
-
pm = exm.ExpressionManager(ctx)
|
231
|
-
|
232
|
-
np.testing.assert_allclose(pm.eval_scalar(c1.scalar).vector, np.array([2, 0.25]))
|
233
|
-
np.testing.assert_allclose(pm.eval_scalar(c1.scalar).constant, 0)
|
234
|
-
assert c1.comparator == utils.Comparator.LT
|
235
|
-
|
236
|
-
np.testing.assert_allclose(pm.eval_scalar(c2.scalar).vector, np.array([2, 0.25]))
|
237
|
-
np.testing.assert_allclose(pm.eval_scalar(c2.scalar).constant, 0)
|
238
|
-
assert c2.comparator == utils.Comparator.LT
|
239
|
-
|
240
|
-
np.testing.assert_allclose(pm.eval_scalar(c3.scalar).vector, np.array([2, 0.25]))
|
241
|
-
np.testing.assert_allclose(pm.eval_scalar(c3.scalar).constant, 0)
|
242
|
-
assert c3.comparator == utils.Comparator.GT
|
243
|
-
|
244
|
-
np.testing.assert_allclose(pm.eval_scalar(c4.scalar).vector, np.array([2, 0.25]))
|
245
|
-
np.testing.assert_allclose(pm.eval_scalar(c4.scalar).constant, 0)
|
246
|
-
assert c4.comparator == utils.Comparator.GT
|
247
|
-
|
248
|
-
np.testing.assert_allclose(pm.eval_scalar(c5.scalar).vector, np.array([2, 0.25]))
|
249
|
-
np.testing.assert_allclose(pm.eval_scalar(c5.scalar).constant, 0)
|
250
|
-
assert c5.comparator == utils.Comparator.EQ
|
251
|
-
|
252
|
-
|
253
211
|
def test_expression_manager_eval_point_large_scale():
|
254
212
|
pep_builder = pep.PEPBuilder()
|
255
213
|
with pep_builder.make_context("test") as ctx:
|
256
|
-
all_basis = [point.Point(is_basis=True) for
|
214
|
+
all_basis = [point.Point(is_basis=True, tags=[f"p_{i}"]) for i in range(100)]
|
257
215
|
p = all_basis[0]
|
258
216
|
for i in range(len(all_basis)):
|
259
217
|
for j in range(i + 1, len(all_basis)):
|
pepflow/scalar.py
CHANGED
@@ -173,56 +173,94 @@ class Scalar:
|
|
173
173
|
return self.tag
|
174
174
|
return super().__repr__()
|
175
175
|
|
176
|
+
def _repr_latex_(self):
|
177
|
+
s = repr(self)
|
178
|
+
s = s.replace("star", r"\star")
|
179
|
+
s = s.replace("gradient_", r"\nabla ")
|
180
|
+
return rf"$\\displaystyle {s}$"
|
181
|
+
|
176
182
|
def __add__(self, other):
|
177
183
|
assert is_numerical_or_scalar(other)
|
184
|
+
if utils.is_numerical(other):
|
185
|
+
tag_other = f"{other:.4g}"
|
186
|
+
else:
|
187
|
+
tag_other = other.tag
|
178
188
|
return Scalar(
|
179
189
|
is_basis=False,
|
180
190
|
eval_expression=EvalExpressionScalar(utils.Op.ADD, self, other),
|
191
|
+
tags=[f"{self.tag}+{tag_other}"],
|
181
192
|
)
|
182
193
|
|
183
194
|
def __radd__(self, other):
|
184
195
|
assert is_numerical_or_scalar(other)
|
196
|
+
if utils.is_numerical(other):
|
197
|
+
tag_other = f"{other:.4g}"
|
198
|
+
else:
|
199
|
+
tag_other = other.tag
|
185
200
|
return Scalar(
|
186
201
|
is_basis=False,
|
187
202
|
eval_expression=EvalExpressionScalar(utils.Op.ADD, other, self),
|
203
|
+
tags=[f"{tag_other}+{self.tag}"],
|
188
204
|
)
|
189
205
|
|
190
206
|
def __sub__(self, other):
|
191
207
|
assert is_numerical_or_scalar(other)
|
208
|
+
if utils.is_numerical(other):
|
209
|
+
tag_other = f"{other:.4g}"
|
210
|
+
else:
|
211
|
+
tag_other = utils.parenthesize_tag(other)
|
192
212
|
return Scalar(
|
193
213
|
is_basis=False,
|
194
214
|
eval_expression=EvalExpressionScalar(utils.Op.SUB, self, other),
|
215
|
+
tags=[f"{self.tag}-{tag_other}"],
|
195
216
|
)
|
196
217
|
|
197
218
|
def __rsub__(self, other):
|
198
219
|
assert is_numerical_or_scalar(other)
|
220
|
+
tag_self = utils.parenthesize_tag(self)
|
221
|
+
if utils.is_numerical(other):
|
222
|
+
tag_other = f"{other:.4g}"
|
223
|
+
else:
|
224
|
+
tag_other = other.tag
|
199
225
|
return Scalar(
|
200
226
|
is_basis=False,
|
201
227
|
eval_expression=EvalExpressionScalar(utils.Op.SUB, other, self),
|
228
|
+
tags=[f"{tag_other}-{tag_self}"],
|
202
229
|
)
|
203
230
|
|
204
231
|
def __mul__(self, other):
|
205
232
|
assert utils.is_numerical(other)
|
233
|
+
tag_self = utils.parenthesize_tag(self)
|
206
234
|
return Scalar(
|
207
235
|
is_basis=False,
|
208
236
|
eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
|
237
|
+
tags=[f"{tag_self}*{other:.4g}"],
|
209
238
|
)
|
210
239
|
|
211
240
|
def __rmul__(self, other):
|
212
241
|
assert utils.is_numerical(other)
|
242
|
+
tag_self = utils.parenthesize_tag(self)
|
213
243
|
return Scalar(
|
214
244
|
is_basis=False,
|
215
245
|
eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
|
246
|
+
tags=[f"{other:.4g}*{tag_self}"],
|
216
247
|
)
|
217
248
|
|
218
249
|
def __neg__(self):
|
219
|
-
|
250
|
+
tag_self = utils.parenthesize_tag(self)
|
251
|
+
return Scalar(
|
252
|
+
is_basis=False,
|
253
|
+
eval_expression=EvalExpressionScalar(utils.Op.MUL, -1, self),
|
254
|
+
tags=[f"-{tag_self}"],
|
255
|
+
)
|
220
256
|
|
221
257
|
def __truediv__(self, other):
|
222
258
|
assert utils.is_numerical(other)
|
259
|
+
tag_self = utils.parenthesize_tag(self)
|
223
260
|
return Scalar(
|
224
261
|
is_basis=False,
|
225
262
|
eval_expression=EvalExpressionScalar(utils.Op.DIV, self, other),
|
263
|
+
tags=[f"1/{other:.4g}*{tag_self}"],
|
226
264
|
)
|
227
265
|
|
228
266
|
def __hash__(self):
|
@@ -247,3 +285,14 @@ class Scalar:
|
|
247
285
|
|
248
286
|
def eq(self, other, name: str) -> ctr.Constraint:
|
249
287
|
return ctr.Constraint(self - other, comparator=utils.Comparator.EQ, name=name)
|
288
|
+
|
289
|
+
def eval(self, ctx: pc.PEPContext | None = None) -> EvaluatedScalar:
|
290
|
+
from pepflow.expression_manager import ExpressionManager
|
291
|
+
|
292
|
+
# Note this can be inefficient.
|
293
|
+
if ctx is None:
|
294
|
+
ctx = pc.get_current_context()
|
295
|
+
if ctx is None:
|
296
|
+
raise RuntimeError("Did you forget to create a context?")
|
297
|
+
em = ExpressionManager(ctx)
|
298
|
+
return em.eval_scalar(self)
|
pepflow/scalar_test.py
ADDED
@@ -0,0 +1,250 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
from pepflow import expression_manager as exm
|
24
|
+
from pepflow import pep as pep
|
25
|
+
from pepflow import point, scalar, utils
|
26
|
+
|
27
|
+
|
28
|
+
def test_scalar_add_tag():
|
29
|
+
pep_builder = pep.PEPBuilder()
|
30
|
+
with pep_builder.make_context("test"):
|
31
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s1"])
|
32
|
+
s2 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s2"])
|
33
|
+
|
34
|
+
s_add = s1 + s2
|
35
|
+
assert s_add.tag == "s1+s2"
|
36
|
+
|
37
|
+
s_add = s1 + 0.1
|
38
|
+
assert s_add.tag == "s1+0.1"
|
39
|
+
|
40
|
+
s_radd = 0.1 + s1
|
41
|
+
assert s_radd.tag == "0.1+s1"
|
42
|
+
|
43
|
+
s_sub = s1 - s2
|
44
|
+
assert s_sub.tag == "s1-s2"
|
45
|
+
|
46
|
+
s_sub = s1 - (s2 + s1)
|
47
|
+
assert s_sub.tag == "s1-(s2+s1)"
|
48
|
+
|
49
|
+
s_sub = s1 - (s2 - s1)
|
50
|
+
assert s_sub.tag == "s1-(s2-s1)"
|
51
|
+
|
52
|
+
s_sub = s1 - 0.1
|
53
|
+
assert s_sub.tag == "s1-0.1"
|
54
|
+
|
55
|
+
s_rsub = 0.1 - s1
|
56
|
+
assert s_rsub.tag == "0.1-s1"
|
57
|
+
|
58
|
+
|
59
|
+
def test_scalar_mul_tag():
|
60
|
+
pep_builder = pep.PEPBuilder()
|
61
|
+
with pep_builder.make_context("test"):
|
62
|
+
s = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s"])
|
63
|
+
|
64
|
+
s_mul = s * 0.1
|
65
|
+
assert s_mul.tag == "s*0.1"
|
66
|
+
|
67
|
+
s_rmul = 0.1 * s
|
68
|
+
assert s_rmul.tag == "0.1*s"
|
69
|
+
|
70
|
+
s_neg = -s
|
71
|
+
assert s_neg.tag == "-s"
|
72
|
+
|
73
|
+
s_truediv = s / 0.1
|
74
|
+
assert s_truediv.tag == "1/0.1*s"
|
75
|
+
|
76
|
+
|
77
|
+
def test_scalar_add_and_mul_tag():
|
78
|
+
pep_builder = pep.PEPBuilder()
|
79
|
+
with pep_builder.make_context("test"):
|
80
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s1"])
|
81
|
+
s2 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s2"])
|
82
|
+
|
83
|
+
s_add_mul = (s1 + s2) * 0.1
|
84
|
+
assert s_add_mul.tag == "(s1+s2)*0.1"
|
85
|
+
|
86
|
+
s_add_mul = s1 + s2 * 0.1
|
87
|
+
assert s_add_mul.tag == "s1+s2*0.1"
|
88
|
+
|
89
|
+
s_neg_add = -(s1 + s2)
|
90
|
+
assert s_neg_add.tag == "-(s1+s2)"
|
91
|
+
|
92
|
+
s_rmul_add = 0.1 * (s1 + s2)
|
93
|
+
assert s_rmul_add.tag == "0.1*(s1+s2)"
|
94
|
+
|
95
|
+
|
96
|
+
def test_scalar_hash_different():
|
97
|
+
pep_builder = pep.PEPBuilder()
|
98
|
+
with pep_builder.make_context("test"):
|
99
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None)
|
100
|
+
s2 = scalar.Scalar(is_basis=True, eval_expression=None)
|
101
|
+
assert s1.uid != s2.uid
|
102
|
+
|
103
|
+
|
104
|
+
def test_scalar_tag():
|
105
|
+
pep_builder = pep.PEPBuilder()
|
106
|
+
with pep_builder.make_context("test"):
|
107
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None)
|
108
|
+
s1.add_tag(tag="my_tag")
|
109
|
+
assert s1.tags == ["my_tag"]
|
110
|
+
assert s1.tag == "my_tag"
|
111
|
+
|
112
|
+
|
113
|
+
def test_scalar_repr():
|
114
|
+
pep_builder = pep.PEPBuilder()
|
115
|
+
with pep_builder.make_context("test"):
|
116
|
+
s1 = scalar.Scalar(is_basis=True, tags=["s1"])
|
117
|
+
print(s1) # it should be fine without tag
|
118
|
+
s1.add_tag("my_tag")
|
119
|
+
assert str(s1) == "my_tag"
|
120
|
+
|
121
|
+
|
122
|
+
def test_scalar_in_a_list():
|
123
|
+
pep_builder = pep.PEPBuilder()
|
124
|
+
with pep_builder.make_context("test"):
|
125
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None)
|
126
|
+
s2 = scalar.Scalar(is_basis=True, eval_expression=None)
|
127
|
+
s3 = scalar.Scalar(is_basis=True, eval_expression=None)
|
128
|
+
assert s1 in [s1, s2]
|
129
|
+
assert s3 not in [s1, s2]
|
130
|
+
|
131
|
+
|
132
|
+
def test_expression_manager_on_basis_scalar():
|
133
|
+
pep_builder = pep.PEPBuilder()
|
134
|
+
with pep_builder.make_context("test") as ctx:
|
135
|
+
s1 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s1"])
|
136
|
+
s2 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s2"])
|
137
|
+
pm = exm.ExpressionManager(ctx)
|
138
|
+
|
139
|
+
np.testing.assert_allclose(pm.eval_scalar(s1).vector, np.array([1, 0]))
|
140
|
+
np.testing.assert_allclose(pm.eval_scalar(s2).vector, np.array([0, 1]))
|
141
|
+
|
142
|
+
s3 = scalar.Scalar(is_basis=True, eval_expression=None, tags=["s3"]) # noqa: F841
|
143
|
+
pm = exm.ExpressionManager(ctx)
|
144
|
+
|
145
|
+
np.testing.assert_allclose(pm.eval_scalar(s1).vector, np.array([1, 0, 0]))
|
146
|
+
np.testing.assert_allclose(pm.eval_scalar(s2).vector, np.array([0, 1, 0]))
|
147
|
+
|
148
|
+
|
149
|
+
def test_expression_manager_eval_scalar():
|
150
|
+
pep_builder = pep.PEPBuilder()
|
151
|
+
with pep_builder.make_context("test") as ctx:
|
152
|
+
s1 = scalar.Scalar(is_basis=True, tags=["s1"])
|
153
|
+
s2 = scalar.Scalar(is_basis=True, tags=["s2"])
|
154
|
+
s3 = 2 * s1 + s2 / 4 + 5
|
155
|
+
s4 = s3 + s1
|
156
|
+
s5 = s4 + 5
|
157
|
+
|
158
|
+
p1 = point.Point(is_basis=True, tags=["p1"])
|
159
|
+
p2 = point.Point(is_basis=True, tags=["p2"])
|
160
|
+
s6 = p1 * p2
|
161
|
+
|
162
|
+
p3 = point.Point(is_basis=True, tags=["p3"])
|
163
|
+
p4 = point.Point(is_basis=True, tags=["p4"])
|
164
|
+
s7 = 5 * p3 * p4
|
165
|
+
|
166
|
+
s8 = s6 + s7
|
167
|
+
|
168
|
+
pm = exm.ExpressionManager(ctx)
|
169
|
+
|
170
|
+
np.testing.assert_allclose(pm.eval_scalar(s3).vector, np.array([2, 0.25]))
|
171
|
+
np.testing.assert_allclose(pm.eval_scalar(s3).constant, 5)
|
172
|
+
np.testing.assert_allclose(pm.eval_scalar(s4).vector, np.array([3, 0.25]))
|
173
|
+
np.testing.assert_allclose(pm.eval_scalar(s5).vector, np.array([3, 0.25]))
|
174
|
+
np.testing.assert_allclose(pm.eval_scalar(s5).constant, 10)
|
175
|
+
|
176
|
+
np.testing.assert_allclose(pm.eval_point(p1).vector, np.array([1, 0, 0, 0]))
|
177
|
+
np.testing.assert_allclose(pm.eval_point(p2).vector, np.array([0, 1, 0, 0]))
|
178
|
+
np.testing.assert_allclose(pm.eval_point(p3).vector, np.array([0, 0, 1, 0]))
|
179
|
+
np.testing.assert_allclose(pm.eval_point(p4).vector, np.array([0, 0, 0, 1]))
|
180
|
+
|
181
|
+
np.testing.assert_allclose(
|
182
|
+
pm.eval_scalar(s6).matrix,
|
183
|
+
np.array(
|
184
|
+
[
|
185
|
+
[0.0, 0.5, 0.0, 0.0],
|
186
|
+
[0.5, 0.0, 0.0, 0.0],
|
187
|
+
[0.0, 0.0, 0.0, 0.0],
|
188
|
+
[0.0, 0.0, 0.0, 0.0],
|
189
|
+
]
|
190
|
+
),
|
191
|
+
)
|
192
|
+
np.testing.assert_allclose(
|
193
|
+
pm.eval_scalar(s7).matrix,
|
194
|
+
np.array(
|
195
|
+
[
|
196
|
+
[0.0, 0.0, 0.0, 0.0],
|
197
|
+
[0.0, 0.0, 0.0, 0.0],
|
198
|
+
[0.0, 0.0, 0.0, 2.5],
|
199
|
+
[0.0, 0.0, 2.5, 0.0],
|
200
|
+
]
|
201
|
+
),
|
202
|
+
)
|
203
|
+
|
204
|
+
np.testing.assert_allclose(
|
205
|
+
pm.eval_scalar(s8).matrix,
|
206
|
+
np.array(
|
207
|
+
[
|
208
|
+
[0.0, 0.5, 0.0, 0.0],
|
209
|
+
[0.5, 0.0, 0.0, 0.0],
|
210
|
+
[0.0, 0.0, 0.0, 2.5],
|
211
|
+
[0.0, 0.0, 2.5, 0.0],
|
212
|
+
]
|
213
|
+
),
|
214
|
+
)
|
215
|
+
|
216
|
+
|
217
|
+
def test_constraint():
|
218
|
+
pep_builder = pep.PEPBuilder()
|
219
|
+
with pep_builder.make_context("test") as ctx:
|
220
|
+
s1 = scalar.Scalar(is_basis=True, tags=["s1"])
|
221
|
+
s2 = scalar.Scalar(is_basis=True, tags=["s2"])
|
222
|
+
s3 = 2 * s1 + s2 / 4 + 5
|
223
|
+
|
224
|
+
c1 = s3.le(5, name="c1")
|
225
|
+
c2 = s3.lt(5, name="c2")
|
226
|
+
c3 = s3.ge(5, name="c3")
|
227
|
+
c4 = s3.gt(5, name="c4")
|
228
|
+
c5 = s3.eq(5, name="c5")
|
229
|
+
|
230
|
+
pm = exm.ExpressionManager(ctx)
|
231
|
+
|
232
|
+
np.testing.assert_allclose(pm.eval_scalar(c1.scalar).vector, np.array([2, 0.25]))
|
233
|
+
np.testing.assert_allclose(pm.eval_scalar(c1.scalar).constant, 0)
|
234
|
+
assert c1.comparator == utils.Comparator.LT
|
235
|
+
|
236
|
+
np.testing.assert_allclose(pm.eval_scalar(c2.scalar).vector, np.array([2, 0.25]))
|
237
|
+
np.testing.assert_allclose(pm.eval_scalar(c2.scalar).constant, 0)
|
238
|
+
assert c2.comparator == utils.Comparator.LT
|
239
|
+
|
240
|
+
np.testing.assert_allclose(pm.eval_scalar(c3.scalar).vector, np.array([2, 0.25]))
|
241
|
+
np.testing.assert_allclose(pm.eval_scalar(c3.scalar).constant, 0)
|
242
|
+
assert c3.comparator == utils.Comparator.GT
|
243
|
+
|
244
|
+
np.testing.assert_allclose(pm.eval_scalar(c4.scalar).vector, np.array([2, 0.25]))
|
245
|
+
np.testing.assert_allclose(pm.eval_scalar(c4.scalar).constant, 0)
|
246
|
+
assert c4.comparator == utils.Comparator.GT
|
247
|
+
|
248
|
+
np.testing.assert_allclose(pm.eval_scalar(c5.scalar).vector, np.array([2, 0.25]))
|
249
|
+
np.testing.assert_allclose(pm.eval_scalar(c5.scalar).constant, 0)
|
250
|
+
assert c5.comparator == utils.Comparator.EQ
|
pepflow/solver_test.py
CHANGED
@@ -27,8 +27,8 @@ from pepflow import solver as ps
|
|
27
27
|
def test_cvx_solver_case1():
|
28
28
|
pep_builder = pep.PEPBuilder()
|
29
29
|
with pep_builder.make_context("test"):
|
30
|
-
p1 = pp.Point(is_basis=True)
|
31
|
-
s1 = pp.Scalar(is_basis=True)
|
30
|
+
p1 = pp.Point(is_basis=True, tags=["p1"])
|
31
|
+
s1 = pp.Scalar(is_basis=True, tags=["s1"])
|
32
32
|
s2 = -(1 + p1 * p1)
|
33
33
|
constraints = [(p1 * p1).gt(1, name="x^2 >= 1"), s1.gt(0, name="s1 > 0")]
|
34
34
|
|
@@ -50,9 +50,9 @@ def test_cvx_solver_case1():
|
|
50
50
|
def test_cvx_solver_case2():
|
51
51
|
pep_builder = pep.PEPBuilder()
|
52
52
|
with pep_builder.make_context("test"):
|
53
|
-
p1 = pp.Point(is_basis=True)
|
54
|
-
s1 = pp.Scalar(is_basis=True)
|
55
|
-
s2 = -
|
53
|
+
p1 = pp.Point(is_basis=True, tags=["p1"])
|
54
|
+
s1 = pp.Scalar(is_basis=True, tags=["s1"])
|
55
|
+
s2 = -p1 * p1 + 2
|
56
56
|
constraints = [(p1 * p1).lt(1, name="x^2 <= 1"), s1.gt(0, name="s1 > 0")]
|
57
57
|
|
58
58
|
solver = ps.CVXSolver(
|
@@ -61,10 +61,10 @@ def test_cvx_solver_case2():
|
|
61
61
|
context=pep_builder.get_context("test"),
|
62
62
|
)
|
63
63
|
|
64
|
-
# It is a simple `min_x
|
64
|
+
# It is a simple `min_x x^2-2; s.t. x^2 <= 1` problem.
|
65
65
|
problem = solver.build_problem()
|
66
66
|
result = problem.solve()
|
67
|
-
assert abs(-result) < 1e-6
|
67
|
+
assert abs(-result + 2) < 1e-6
|
68
68
|
|
69
69
|
assert np.isclose(solver.dual_var_manager.dual_value("x^2 <= 1"), 0)
|
70
70
|
assert solver.dual_var_manager.dual_value("s1 > 0") == 0
|
pepflow/utils.py
CHANGED
@@ -21,10 +21,15 @@ from __future__ import annotations
|
|
21
21
|
|
22
22
|
import enum
|
23
23
|
import numbers
|
24
|
-
from typing import Any
|
24
|
+
from typing import TYPE_CHECKING, Any
|
25
25
|
|
26
26
|
import numpy as np
|
27
27
|
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from pepflow.function import Function
|
30
|
+
from pepflow.point import Point
|
31
|
+
from pepflow.scalar import Scalar
|
32
|
+
|
28
33
|
|
29
34
|
def SOP(v, w):
|
30
35
|
"""Symmetric Outer Product."""
|
@@ -50,3 +55,11 @@ class Comparator(enum.Enum):
|
|
50
55
|
|
51
56
|
def is_numerical(val: Any) -> bool:
|
52
57
|
return isinstance(val, numbers.Number)
|
58
|
+
|
59
|
+
|
60
|
+
def parenthesize_tag(val: Point | Scalar | Function) -> str:
|
61
|
+
tmp_tag = val.tag
|
62
|
+
if not val.is_basis:
|
63
|
+
if val.eval_expression.op == Op.ADD or val.eval_expression.op == Op.SUB:
|
64
|
+
tmp_tag = f"({val.tag})"
|
65
|
+
return tmp_tag
|
@@ -0,0 +1,24 @@
|
|
1
|
+
pepflow/__init__.py,sha256=tLnOlZ1y_mIodRl5Fr5HMLPP5M7h_ad76s4in5OgiHE,1930
|
2
|
+
pepflow/constants.py,sha256=t29CDRE8kw773zgKS0ZZCYGegwagaDDdLfpSaeDpK14,871
|
3
|
+
pepflow/constraint.py,sha256=n-01dcQplvsXB7V4fceJBImbwSr-Wa9k9tE7ZcVmi3o,1153
|
4
|
+
pepflow/e2e_test.py,sha256=9UfWefn5kq9Jn7L7tMRPtnkUMaac3uJIbU8-qcUYvd0,1080
|
5
|
+
pepflow/expression_manager.py,sha256=64FbdMjcUrIqtbFZyJoxoDZ42j1bl1CeGxhJnKcGTD8,5222
|
6
|
+
pepflow/function.py,sha256=Ghgm2yyKsq7WSQfTufub5h8kD7j7ouwYfLd2hdRBw-w,14212
|
7
|
+
pepflow/function_test.py,sha256=dm4qcEiB8sshqRQ4zphQlJP-lE6IcVrwf1CDTwUmz9M,6903
|
8
|
+
pepflow/interactive_constraint.py,sha256=HMwAo80tJyCKwWKryxNGk1LAh-qlnQk5OeULLwKZUOQ,9826
|
9
|
+
pepflow/pep.py,sha256=fLjEc3dtpDKiVXV2T5vQGJKRDYim6piuNQZHsmxiKc4,6082
|
10
|
+
pepflow/pep_context.py,sha256=9an5fwdjqft7P_7MkRQ1TWYI-5_qw3VSD9qLLQGdX6Y,4719
|
11
|
+
pepflow/pep_context_test.py,sha256=FqMsIznBcM4NmBU9mb1tGh4vUwmuJ-MWa3GMA_pg6Ik,3252
|
12
|
+
pepflow/pep_test.py,sha256=aWwP3CqXNJzQUhJqC5sNcsM83i_asJaZewpGNqC1G7M,2504
|
13
|
+
pepflow/point.py,sha256=vSKCuivJL3zyUCnV48rfM5XpAGsVbdQ8YWdO3fKrViI,8514
|
14
|
+
pepflow/point_test.py,sha256=ojWJ-mOS2XJsTKuNhlMGBmOzPYSrknaTGIJi6mXLoEk,10708
|
15
|
+
pepflow/scalar.py,sha256=8E-D5pgZCDpyEX13VwrFARo87v-HWAUs4gAD_-Ty_g8,9653
|
16
|
+
pepflow/scalar_test.py,sha256=T2-KGjWUsZGWZII64DEmUiVnu8o4Kr9Mi0YCgSUvYVc,8386
|
17
|
+
pepflow/solver.py,sha256=WzeN_IWNBs9IpE212jenhYMWFuuwH890h0vaFmJRM6I,4312
|
18
|
+
pepflow/solver_test.py,sha256=-aCEe-oQ26xJUWR64b-CIIfFOK_pNJnMlOly2bagk68,2457
|
19
|
+
pepflow/utils.py,sha256=Xd-DwtUoUSMw_i6xjYsOxsHmn8l3ZbYCyptXEBfoyZk,1718
|
20
|
+
pepflow-0.1.4.dist-info/licenses/LICENSE,sha256=na5oVXAps-5f1hLGG4SYnwFdavQeXgYUeN-E3MxOA_s,11361
|
21
|
+
pepflow-0.1.4.dist-info/METADATA,sha256=eMSRF3dk0eWgwicdFSsnJVxjsTcS9rG5Jp-OF9orDO4,1434
|
22
|
+
pepflow-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
pepflow-0.1.4.dist-info/top_level.txt,sha256=0YEPCZQQa6yIAIwMumzDg4pj7AME8aXu2sXkuq8xM6M,8
|
24
|
+
pepflow-0.1.4.dist-info/RECORD,,
|
pepflow-0.1.2.dist-info/RECORD
DELETED
@@ -1,22 +0,0 @@
|
|
1
|
-
pepflow/__init__.py,sha256=tLnOlZ1y_mIodRl5Fr5HMLPP5M7h_ad76s4in5OgiHE,1930
|
2
|
-
pepflow/constants.py,sha256=t29CDRE8kw773zgKS0ZZCYGegwagaDDdLfpSaeDpK14,871
|
3
|
-
pepflow/constraint.py,sha256=n-01dcQplvsXB7V4fceJBImbwSr-Wa9k9tE7ZcVmi3o,1153
|
4
|
-
pepflow/expression_manager.py,sha256=64FbdMjcUrIqtbFZyJoxoDZ42j1bl1CeGxhJnKcGTD8,5222
|
5
|
-
pepflow/function.py,sha256=yQcx78Ml6O34pVTcqYXxmO75l7IGdVvfwaY01lwsOHc,12566
|
6
|
-
pepflow/function_test.py,sha256=947Jv0i7HtQ-LOi9-Hd7YsIGbX6BjnHsW3N4Q9HZE-A,4833
|
7
|
-
pepflow/interactive_constraint.py,sha256=HMwAo80tJyCKwWKryxNGk1LAh-qlnQk5OeULLwKZUOQ,9826
|
8
|
-
pepflow/pep.py,sha256=IKrYW6DBIekKOjGfL7rc6t1xU06VJEyL_gxj1sH_m_Q,5824
|
9
|
-
pepflow/pep_context.py,sha256=IgrikL5Eqi0RfrFaVKn9lsodqtjj9A_RWFb82Qk-SYQ,4443
|
10
|
-
pepflow/pep_context_test.py,sha256=1kFkmG56JVkmjaT9ic9byV-F9WOs9ZNQz5P1P5v5lLw,3092
|
11
|
-
pepflow/pep_test.py,sha256=aWwP3CqXNJzQUhJqC5sNcsM83i_asJaZewpGNqC1G7M,2504
|
12
|
-
pepflow/point.py,sha256=2iKmU6K_e3DmskCW9faQBxTaqPm3MUQPHTZvKVh3_yM,6711
|
13
|
-
pepflow/point_test.py,sha256=jwbvjfwHP_9oY7KLp4UE_uSEbAVlDTXAQjr8d0z-eME,12602
|
14
|
-
pepflow/scalar.py,sha256=vVA9qZjKIvM6RlZ4wN5MEFyG9eTFx7ItQaXtPY7ws7c,7846
|
15
|
-
pepflow/solver.py,sha256=WzeN_IWNBs9IpE212jenhYMWFuuwH890h0vaFmJRM6I,4312
|
16
|
-
pepflow/solver_test.py,sha256=BC37ggL_S8NDtnq8nPwRR2xItMCYMLs-n931u7f34YE,2414
|
17
|
-
pepflow/utils.py,sha256=tYcpOOWAieNay5Ly7v7dRll69RPYCUGMEuW-Bwp2z68,1321
|
18
|
-
pepflow-0.1.2.dist-info/licenses/LICENSE,sha256=na5oVXAps-5f1hLGG4SYnwFdavQeXgYUeN-E3MxOA_s,11361
|
19
|
-
pepflow-0.1.2.dist-info/METADATA,sha256=tZyuaGvCHycaEKDj2RUQhhAguXM1F9JQYMpWbA-jRSc,1434
|
20
|
-
pepflow-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
21
|
-
pepflow-0.1.2.dist-info/top_level.txt,sha256=0YEPCZQQa6yIAIwMumzDg4pj7AME8aXu2sXkuq8xM6M,8
|
22
|
-
pepflow-0.1.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|