pepflow 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +6 -1
- pepflow/constraint.py +58 -1
- pepflow/constraint_test.py +71 -0
- pepflow/e2e_test.py +83 -4
- pepflow/expression_manager.py +329 -44
- pepflow/expression_manager_test.py +150 -0
- pepflow/function.py +294 -52
- pepflow/function_test.py +180 -114
- pepflow/interactive_constraint.py +165 -75
- pepflow/parameter.py +187 -0
- pepflow/parameter_test.py +128 -0
- pepflow/pep.py +263 -16
- pepflow/pep_context.py +122 -6
- pepflow/pep_context_test.py +25 -0
- pepflow/pep_test.py +8 -0
- pepflow/point.py +155 -49
- pepflow/point_test.py +40 -188
- pepflow/scalar.py +260 -47
- pepflow/scalar_test.py +102 -130
- pepflow/solver.py +170 -3
- pepflow/solver_test.py +50 -2
- pepflow/utils.py +39 -7
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/METADATA +24 -5
- pepflow-0.1.5.dist-info/RECORD +28 -0
- pepflow-0.1.4.dist-info/RECORD +0 -24
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4.dist-info → pepflow-0.1.5.dist-info}/top_level.txt +0 -0
pepflow/pep_context.py
CHANGED
@@ -26,7 +26,6 @@ import natsort
|
|
26
26
|
import pandas as pd
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
|
-
from pepflow.constraint import Constraint
|
30
29
|
from pepflow.function import Function, Triplet
|
31
30
|
from pepflow.point import Point
|
32
31
|
from pepflow.scalar import Scalar
|
@@ -38,26 +37,54 @@ GLOBAL_CONTEXT_DICT: dict[str, PEPContext] = {}
|
|
38
37
|
|
39
38
|
|
40
39
|
def get_current_context() -> PEPContext | None:
|
40
|
+
"""
|
41
|
+
Return the current global :class:`PEPContext`.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
:class:`PEPContext`: The current global :class:`PEPContext`.
|
45
|
+
"""
|
41
46
|
return CURRENT_CONTEXT
|
42
47
|
|
43
48
|
|
44
49
|
def set_current_context(ctx: PEPContext | None):
|
50
|
+
"""
|
51
|
+
Change the current global :class:`PEPContext`.
|
52
|
+
|
53
|
+
Args:
|
54
|
+
ctx (:class:`PEPContext`): The :class:`PEPContext` to set as the new
|
55
|
+
global :class:`PEPContext`.
|
56
|
+
"""
|
45
57
|
global CURRENT_CONTEXT
|
46
58
|
assert ctx is None or isinstance(ctx, PEPContext)
|
47
59
|
CURRENT_CONTEXT = ctx
|
48
60
|
|
49
61
|
|
50
62
|
class PEPContext:
|
63
|
+
"""
|
64
|
+
A :class:`PEPContext` object is a context manager which maintains
|
65
|
+
the abstract mathematical objects of the Primal and Dual PEP.
|
66
|
+
|
67
|
+
Attributes:
|
68
|
+
name (str): The unique name of the :class:`PEPContext` object.
|
69
|
+
"""
|
70
|
+
|
51
71
|
def __init__(self, name: str):
|
52
72
|
self.name = name
|
53
73
|
self.points: list[Point] = []
|
54
74
|
self.scalars: list[Scalar] = []
|
55
75
|
self.triplets: dict[Function, list[Triplet]] = defaultdict(list)
|
56
|
-
self.
|
76
|
+
# self.triplets will contain all stationary_triplets. They are not mutually exclusive.
|
77
|
+
self.stationary_triplets: dict[Function, list[Triplet]] = defaultdict(list)
|
57
78
|
|
58
79
|
GLOBAL_CONTEXT_DICT[name] = self
|
59
80
|
|
60
81
|
def set_as_current(self) -> PEPContext:
|
82
|
+
"""
|
83
|
+
Set this :class:`PEPContext` object as the global context.
|
84
|
+
|
85
|
+
Returns:
|
86
|
+
:class:`PEPContext`: This :class:`PEPContext` object.
|
87
|
+
"""
|
61
88
|
set_current_context(self)
|
62
89
|
return self
|
63
90
|
|
@@ -70,35 +97,98 @@ class PEPContext:
|
|
70
97
|
def add_triplet(self, function: Function, triplet: Triplet):
|
71
98
|
self.triplets[function].append(triplet)
|
72
99
|
|
73
|
-
def
|
74
|
-
self.
|
100
|
+
def add_stationary_triplet(self, function: Function, stationary_triplet: Triplet):
|
101
|
+
self.stationary_triplets[function].append(stationary_triplet)
|
75
102
|
|
76
103
|
def get_by_tag(self, tag: str) -> Point | Scalar:
|
104
|
+
"""
|
105
|
+
Under this :class:`PEPContext`, get the :class:`Point` or
|
106
|
+
:class:`Scalar` object associated with the provided `tag`.
|
107
|
+
|
108
|
+
Args:
|
109
|
+
tag (str): The tag of the :class:`Point` or :class:`Scalar` object
|
110
|
+
we want to retrieve.
|
111
|
+
|
112
|
+
Returns:
|
113
|
+
:class:`Point` | :class:`Scalar`: The :class:`Point` or
|
114
|
+
:class:`Scalar` object associated with the provided `tag`.
|
115
|
+
"""
|
77
116
|
for p in self.points:
|
78
117
|
if tag in p.tags:
|
79
118
|
return p
|
80
119
|
for s in self.scalars:
|
81
120
|
if tag in s.tags:
|
82
121
|
return s
|
83
|
-
raise ValueError("Cannot find the point or
|
122
|
+
raise ValueError("Cannot find the point, scalar, or function of given tag.")
|
84
123
|
|
85
124
|
def clear(self):
|
125
|
+
"""Reset this :class:`PEPContext` object."""
|
86
126
|
self.points.clear()
|
87
127
|
self.scalars.clear()
|
88
128
|
self.triplets.clear()
|
89
|
-
self.
|
129
|
+
self.stationary_triplets.clear()
|
90
130
|
|
91
131
|
def tracked_point(self, func: Function) -> list[Point]:
|
132
|
+
"""
|
133
|
+
Each function :math:`f` used in Primal and Dual PEP is associated with
|
134
|
+
a set of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}` visited by
|
135
|
+
the considered algorithm. We can also consider a subgradient
|
136
|
+
:math:`\\widetilde{\\nabla} f(x)` instead of the gradient. This
|
137
|
+
function returns a list of the visited points :math:`\\{x_i\\}` under
|
138
|
+
this :class:`PEPContext`.
|
139
|
+
|
140
|
+
Args:
|
141
|
+
func (:class:`Function`): The function associated with the set
|
142
|
+
of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}`.
|
143
|
+
|
144
|
+
Returns:
|
145
|
+
list[:class:`Point`]: The list of the visited points
|
146
|
+
:math:`\\{x_i\\}`.
|
147
|
+
"""
|
92
148
|
return natsort.natsorted(
|
93
149
|
[t.point for t in self.triplets[func]], key=lambda x: x.tag
|
94
150
|
)
|
95
151
|
|
96
152
|
def tracked_grad(self, func: Function) -> list[Point]:
|
153
|
+
"""
|
154
|
+
Each function :math:`f` used in Primal and Dual PEP is associated with
|
155
|
+
a set of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}` visited by
|
156
|
+
the considered algorithm. We can also consider a subgradient
|
157
|
+
:math:`\\widetilde{\\nabla} f(x)` instead of the gradient
|
158
|
+
:math:`\\nabla f(x_i)`. This function returns a list of the visited
|
159
|
+
gradients :math:`\\{\\nabla f(x_i)\\}` under this :class:`PEPContext`.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
func (:class:`Function`): The function associated with the set
|
163
|
+
of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}`.
|
164
|
+
|
165
|
+
Returns:
|
166
|
+
list[:class:`Point`]: The list of the visited gradients
|
167
|
+
:math:`\\{\\nabla f(x_i)\\}`.
|
168
|
+
|
169
|
+
"""
|
97
170
|
return natsort.natsorted(
|
98
171
|
[t.gradient for t in self.triplets[func]], key=lambda x: x.tag
|
99
172
|
)
|
100
173
|
|
101
174
|
def tracked_func_value(self, func: Function) -> list[Scalar]:
|
175
|
+
"""
|
176
|
+
Each function :math:`f` used in Primal and Dual PEP is associated with
|
177
|
+
a set of triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}` visited by
|
178
|
+
the considered algorithm. We can also consider a subgradient
|
179
|
+
:math:`\\widetilde{\\nabla} f(x)` instead of the gradient
|
180
|
+
:math:`\\nabla f(x_i)`. This function returns a list of the visited
|
181
|
+
function values :math:`\\{f(x_i)\\}` under this :class:`PEPContext`.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
func (:class:`Function`): The function associated with the set of
|
185
|
+
triplets :math:`\\{x_i, f(x_i), \\nabla f(x_i)\\}`.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
list[:class:`Scalar`]: The list of the visited function values
|
189
|
+
:math:`\\{f(x_i)\\}`.
|
190
|
+
|
191
|
+
"""
|
102
192
|
return natsort.natsorted(
|
103
193
|
[t.function_value for t in self.triplets[func]], key=lambda x: x.tag
|
104
194
|
)
|
@@ -134,3 +224,29 @@ class PEPContext:
|
|
134
224
|
func_to_order[func] = order
|
135
225
|
|
136
226
|
return func_to_df, func_to_order
|
227
|
+
|
228
|
+
def basis_points(self) -> list[Point]:
|
229
|
+
"""
|
230
|
+
Return a list of the basis :class:`Point` objects managed by this
|
231
|
+
:class:`PEPContext`.
|
232
|
+
|
233
|
+
Returns:
|
234
|
+
list[:class:`Point`]: A list of the basis :class:`Point` objects
|
235
|
+
managed by this :class:`PEPContext`.
|
236
|
+
"""
|
237
|
+
return [
|
238
|
+
p for p in self.points if p.is_basis
|
239
|
+
] # Note the order is always the same as added time
|
240
|
+
|
241
|
+
def basis_scalars(self) -> list[Scalar]:
|
242
|
+
"""
|
243
|
+
Return a list of the basis :class:`Scalar` objects managed by this
|
244
|
+
:class:`PEPContext`.
|
245
|
+
|
246
|
+
Returns:
|
247
|
+
list[:class:`Scalar`]: A list of the basis :class:`Scalar` objects
|
248
|
+
managed by this :class:`PEPContext`.
|
249
|
+
"""
|
250
|
+
return [
|
251
|
+
s for s in self.scalars if s.is_basis
|
252
|
+
] # Note the order is always the same as added time
|
pepflow/pep_context_test.py
CHANGED
@@ -25,6 +25,7 @@ import pytest
|
|
25
25
|
from pepflow import pep_context as pc
|
26
26
|
from pepflow.function import SmoothConvexFunction
|
27
27
|
from pepflow.point import Point
|
28
|
+
from pepflow.scalar import Scalar
|
28
29
|
|
29
30
|
|
30
31
|
@pytest.fixture
|
@@ -91,10 +92,34 @@ def test_get_by_tag(pep_context: pc.PEPContext):
|
|
91
92
|
f = SmoothConvexFunction(L=1, is_basis=True)
|
92
93
|
f.add_tag("f")
|
93
94
|
p1 = Point(is_basis=True, tags=["x1"])
|
95
|
+
p2 = Point(is_basis=True, tags=["x2"])
|
96
|
+
p3 = p1 + p2
|
94
97
|
|
95
98
|
triplet = f.generate_triplet(p1)
|
99
|
+
_ = f.generate_triplet(p2)
|
96
100
|
|
97
101
|
assert pep_context.get_by_tag("x1") == p1
|
98
102
|
assert pep_context.get_by_tag("f(x1)") == triplet.function_value
|
99
103
|
assert pep_context.get_by_tag("gradient_f(x1)") == triplet.gradient
|
104
|
+
assert pep_context.get_by_tag("x1+x2") == p3
|
100
105
|
pc.set_current_context(None)
|
106
|
+
|
107
|
+
|
108
|
+
def test_basis_points(pep_context: pc.PEPContext):
|
109
|
+
p1 = Point(is_basis=True, tags=["x1"])
|
110
|
+
p2 = Point(is_basis=True, tags=["x2"])
|
111
|
+
_ = p1 + p2 # not basis
|
112
|
+
ps = Point(is_basis=True, tags=["x_star"])
|
113
|
+
p0 = Point(is_basis=True, tags=["x0"])
|
114
|
+
|
115
|
+
assert pep_context.basis_points() == [p1, p2, ps, p0]
|
116
|
+
|
117
|
+
|
118
|
+
def test_basis_scalars(pep_context: pc.PEPContext):
|
119
|
+
p1 = Point(is_basis=True, tags=["x1"])
|
120
|
+
p2 = Point(is_basis=True, tags=["x2"])
|
121
|
+
_ = p1 * p2 # not basis
|
122
|
+
s1 = Scalar(is_basis=True, tags=["s2"])
|
123
|
+
s2 = Scalar(is_basis=True, tags=["s1"])
|
124
|
+
|
125
|
+
assert pep_context.basis_scalars() == [s1, s2]
|
pepflow/pep_test.py
CHANGED
@@ -19,6 +19,7 @@
|
|
19
19
|
|
20
20
|
import pytest
|
21
21
|
|
22
|
+
from pepflow import function as fc
|
22
23
|
from pepflow import pep
|
23
24
|
from pepflow import pep_context as pc
|
24
25
|
|
@@ -75,3 +76,10 @@ class TestPEPBuilder:
|
|
75
76
|
|
76
77
|
with builder.make_context("test", override=True):
|
77
78
|
pass
|
79
|
+
|
80
|
+
def test_get_func_by_tag(self) -> None:
|
81
|
+
builder = pep.PEPBuilder()
|
82
|
+
with builder.make_context("test"):
|
83
|
+
f = builder.declare_func(fc.SmoothConvexFunction, "f", L=1)
|
84
|
+
|
85
|
+
assert builder.get_func_by_tag("f") == f
|
pepflow/point.py
CHANGED
@@ -27,37 +27,65 @@ import numpy as np
|
|
27
27
|
|
28
28
|
from pepflow import pep_context as pc
|
29
29
|
from pepflow import utils
|
30
|
-
from pepflow.scalar import
|
30
|
+
from pepflow.scalar import Scalar, ScalarRepresentation
|
31
31
|
|
32
32
|
|
33
33
|
def is_numerical_or_point(val: Any) -> bool:
|
34
|
-
return utils.
|
34
|
+
return utils.is_numerical_or_parameter(val) or isinstance(val, Point)
|
35
35
|
|
36
36
|
|
37
37
|
def is_numerical_or_evaluatedpoint(val: Any) -> bool:
|
38
|
-
return utils.
|
38
|
+
return utils.is_numerical_or_parameter(val) or isinstance(val, EvaluatedPoint)
|
39
39
|
|
40
40
|
|
41
41
|
@attrs.frozen
|
42
|
-
class
|
42
|
+
class PointRepresentation:
|
43
43
|
op: utils.Op
|
44
44
|
left_point: Point | float
|
45
45
|
right_point: Point | float
|
46
46
|
|
47
47
|
|
48
|
+
@attrs.frozen
|
49
|
+
class ZeroPoint:
|
50
|
+
"""A special class to represent 0 in Point."""
|
51
|
+
|
52
|
+
pass
|
53
|
+
|
54
|
+
|
48
55
|
@attrs.frozen
|
49
56
|
class EvaluatedPoint:
|
57
|
+
"""
|
58
|
+
The concrete representation of the abstract :class:`Point`.
|
59
|
+
|
60
|
+
Each abstract basis :class:`Point` object has a unique concrete
|
61
|
+
representation as a unit vector. The concrete representations of
|
62
|
+
linear combinations of abstract basis :class:`Point` objects are
|
63
|
+
linear combinations of the unit vectors. This information is stored
|
64
|
+
in the `vector` attribute.
|
65
|
+
|
66
|
+
:class:`EvaluatedPoint` objects can be constructed as linear combinations
|
67
|
+
of other :class:`EvaluatedPoint` objects. Let `a` and `b` be some numeric
|
68
|
+
data type. Let `x` and `y` be :class:`EvaluatedPoint` objects. Then, we
|
69
|
+
can form a new :class:`EvaluatedPoint` object: `a*x+b*y`.
|
70
|
+
|
71
|
+
Attributes:
|
72
|
+
vector (np.ndarray): The concrete representation of an
|
73
|
+
abstract :class:`Point`.
|
74
|
+
"""
|
75
|
+
|
50
76
|
vector: np.ndarray
|
51
77
|
|
78
|
+
@classmethod
|
79
|
+
def zero(cls, num_basis_points: int):
|
80
|
+
return EvaluatedPoint(vector=np.zeros(num_basis_points))
|
81
|
+
|
52
82
|
def __add__(self, other):
|
53
83
|
if isinstance(other, EvaluatedPoint):
|
54
84
|
return EvaluatedPoint(vector=self.vector + other.vector)
|
55
85
|
elif utils.is_numerical(other):
|
56
86
|
return EvaluatedPoint(vector=self.vector + other)
|
57
87
|
else:
|
58
|
-
|
59
|
-
f"Unsupported add operation between EvaluatedPoint and {type(other)}"
|
60
|
-
)
|
88
|
+
return NotImplemented
|
61
89
|
|
62
90
|
def __radd__(self, other):
|
63
91
|
return self.__add__(other)
|
@@ -68,9 +96,7 @@ class EvaluatedPoint:
|
|
68
96
|
elif utils.is_numerical(other):
|
69
97
|
return EvaluatedPoint(vector=self.vector - other)
|
70
98
|
else:
|
71
|
-
|
72
|
-
f"Unsupported sub operation between EvaluatedPoint and {type(other)}"
|
73
|
-
)
|
99
|
+
return NotImplemented
|
74
100
|
|
75
101
|
def __rsub__(self, other):
|
76
102
|
if isinstance(other, EvaluatedPoint):
|
@@ -78,30 +104,51 @@ class EvaluatedPoint:
|
|
78
104
|
elif utils.is_numerical(other):
|
79
105
|
return EvaluatedPoint(vector=other - self.vector)
|
80
106
|
else:
|
81
|
-
|
82
|
-
f"Unsupported sub operation between EvaluatedPoint and {type(other)}"
|
83
|
-
)
|
107
|
+
return NotImplemented
|
84
108
|
|
85
109
|
def __mul__(self, other):
|
86
|
-
|
110
|
+
if not utils.is_numerical(other):
|
111
|
+
return NotImplemented
|
87
112
|
return EvaluatedPoint(vector=self.vector * other)
|
88
113
|
|
89
114
|
def __rmul__(self, other):
|
90
|
-
|
115
|
+
if not utils.is_numerical(other):
|
116
|
+
return NotImplemented
|
91
117
|
return EvaluatedPoint(vector=other * self.vector)
|
92
118
|
|
93
119
|
def __truediv__(self, other):
|
94
|
-
|
120
|
+
if not utils.is_numerical(other):
|
121
|
+
return NotImplemented
|
95
122
|
return EvaluatedPoint(vector=self.vector / other)
|
96
123
|
|
97
124
|
|
98
125
|
@attrs.frozen
|
99
126
|
class Point:
|
127
|
+
"""
|
128
|
+
A :class:`Point` object represents an element of a pre-Hilbert space.
|
129
|
+
Examples include a point or a gradient.
|
130
|
+
|
131
|
+
:class:`Point` objects can be constructed as linear combinations of
|
132
|
+
other :class:`Point` objects. Let `a` and `b` be some numeric data type.
|
133
|
+
Let `x` and `y` be :class:`Point` objects. Then, we can form a new
|
134
|
+
:class:`Point` object: `a*x+b*y`.
|
135
|
+
|
136
|
+
The inner product of two :class:`Point` objects can also be taken.
|
137
|
+
Let `x` and `y` be :class:`Point` objects. Then, their inner product is
|
138
|
+
`x*y` and returns a :class:`Scalar` object.
|
139
|
+
|
140
|
+
Attributes:
|
141
|
+
is_basis (bool): `True` if this point is not formed through a linear
|
142
|
+
combination of other points. `False` otherwise.
|
143
|
+
tags (list[str]): A list that contains tags that can be used to
|
144
|
+
identify the :class:`Point` object. Tags should be unique.
|
145
|
+
"""
|
146
|
+
|
100
147
|
# If true, the point is the basis for the evaluations of G
|
101
148
|
is_basis: bool
|
102
149
|
|
103
|
-
#
|
104
|
-
eval_expression:
|
150
|
+
# The representation of point used for evaluation.
|
151
|
+
eval_expression: PointRepresentation | ZeroPoint | None = None
|
105
152
|
|
106
153
|
# Human tagged value for the Point
|
107
154
|
tags: list[str] = attrs.field(factory=list)
|
@@ -120,13 +167,27 @@ class Point:
|
|
120
167
|
raise RuntimeError("Did you forget to create a context?")
|
121
168
|
pep_context.add_point(self)
|
122
169
|
|
170
|
+
@staticmethod
|
171
|
+
def zero() -> Point:
|
172
|
+
return Point(is_basis=False, eval_expression=ZeroPoint(), tags=["0"])
|
173
|
+
|
123
174
|
@property
|
124
175
|
def tag(self):
|
176
|
+
"""Returns the most recently added tag.
|
177
|
+
|
178
|
+
Returns:
|
179
|
+
str: The most recently added tag of this :class:`Point` object.
|
180
|
+
"""
|
125
181
|
if len(self.tags) == 0:
|
126
182
|
raise ValueError("Point should have a name.")
|
127
183
|
return self.tags[-1]
|
128
184
|
|
129
185
|
def add_tag(self, tag: str) -> None:
|
186
|
+
"""Add a new tag for this :class:`Point` object.
|
187
|
+
|
188
|
+
Args:
|
189
|
+
tag (str): The new tag to be added to the `tags` list.
|
190
|
+
"""
|
130
191
|
self.tags.append(tag)
|
131
192
|
|
132
193
|
def __repr__(self):
|
@@ -135,18 +196,15 @@ class Point:
|
|
135
196
|
return super().__repr__()
|
136
197
|
|
137
198
|
def _repr_latex_(self):
|
138
|
-
|
139
|
-
s = s.replace("star", r"\star")
|
140
|
-
s = s.replace("gradient_", r"\nabla ")
|
141
|
-
s = s.replace("|", r"\|")
|
142
|
-
return rf"$\\displaystyle {s}$"
|
199
|
+
return utils.str_to_latex(repr(self))
|
143
200
|
|
144
201
|
# TODO: add a validator that `is_basis` and `eval_expression` are properly setup.
|
145
202
|
def __add__(self, other):
|
146
|
-
|
203
|
+
if not isinstance(other, Point):
|
204
|
+
return NotImplemented
|
147
205
|
return Point(
|
148
206
|
is_basis=False,
|
149
|
-
eval_expression=
|
207
|
+
eval_expression=PointRepresentation(utils.Op.ADD, self, other),
|
150
208
|
tags=[f"{self.tag}+{other.tag}"],
|
151
209
|
)
|
152
210
|
|
@@ -154,72 +212,78 @@ class Point:
|
|
154
212
|
# TODO: come up with better way to handle this
|
155
213
|
if other == 0:
|
156
214
|
return self
|
157
|
-
|
215
|
+
if not isinstance(other, Point):
|
216
|
+
return NotImplemented
|
158
217
|
return Point(
|
159
218
|
is_basis=False,
|
160
|
-
eval_expression=
|
219
|
+
eval_expression=PointRepresentation(utils.Op.ADD, other, self),
|
161
220
|
tags=[f"{other.tag}+{self.tag}"],
|
162
221
|
)
|
163
222
|
|
164
223
|
def __sub__(self, other):
|
165
|
-
|
224
|
+
if not isinstance(other, Point):
|
225
|
+
return NotImplemented
|
166
226
|
tag_other = utils.parenthesize_tag(other)
|
167
227
|
return Point(
|
168
228
|
is_basis=False,
|
169
|
-
eval_expression=
|
229
|
+
eval_expression=PointRepresentation(utils.Op.SUB, self, other),
|
170
230
|
tags=[f"{self.tag}-{tag_other}"],
|
171
231
|
)
|
172
232
|
|
173
233
|
def __rsub__(self, other):
|
174
|
-
|
234
|
+
if not isinstance(other, Point):
|
235
|
+
return NotImplemented
|
175
236
|
tag_self = utils.parenthesize_tag(self)
|
176
237
|
return Point(
|
177
238
|
is_basis=False,
|
178
|
-
eval_expression=
|
239
|
+
eval_expression=PointRepresentation(utils.Op.SUB, other, self),
|
179
240
|
tags=[f"{other.tag}-{tag_self}"],
|
180
241
|
)
|
181
242
|
|
182
243
|
def __mul__(self, other):
|
183
|
-
|
184
|
-
|
244
|
+
if not is_numerical_or_point(other):
|
245
|
+
return NotImplemented
|
185
246
|
tag_self = utils.parenthesize_tag(self)
|
186
|
-
if utils.
|
247
|
+
if utils.is_numerical_or_parameter(other):
|
248
|
+
tag_other = utils.numerical_str(other)
|
187
249
|
return Point(
|
188
250
|
is_basis=False,
|
189
|
-
eval_expression=
|
190
|
-
tags=[f"{tag_self}*{
|
251
|
+
eval_expression=PointRepresentation(utils.Op.MUL, self, other),
|
252
|
+
tags=[f"{tag_self}*{tag_other}"],
|
191
253
|
)
|
192
254
|
else:
|
193
255
|
tag_other = utils.parenthesize_tag(other)
|
194
256
|
return Scalar(
|
195
257
|
is_basis=False,
|
196
|
-
eval_expression=
|
258
|
+
eval_expression=ScalarRepresentation(utils.Op.MUL, self, other),
|
197
259
|
tags=[f"{tag_self}*{tag_other}"],
|
198
260
|
)
|
199
261
|
|
200
262
|
def __rmul__(self, other):
|
201
|
-
|
202
|
-
|
263
|
+
if not is_numerical_or_point(other):
|
264
|
+
return NotImplemented
|
203
265
|
tag_self = utils.parenthesize_tag(self)
|
204
|
-
if utils.
|
266
|
+
if utils.is_numerical_or_parameter(other):
|
267
|
+
tag_other = utils.numerical_str(other)
|
205
268
|
return Point(
|
206
269
|
is_basis=False,
|
207
|
-
eval_expression=
|
208
|
-
tags=[f"{
|
270
|
+
eval_expression=PointRepresentation(utils.Op.MUL, other, self),
|
271
|
+
tags=[f"{tag_other}*{tag_self}"],
|
209
272
|
)
|
210
273
|
else:
|
211
274
|
tag_other = utils.parenthesize_tag(other)
|
212
275
|
return Scalar(
|
213
276
|
is_basis=False,
|
214
|
-
eval_expression=
|
277
|
+
eval_expression=ScalarRepresentation(utils.Op.MUL, other, self),
|
215
278
|
tags=[f"{tag_other}*{tag_self}"],
|
216
279
|
)
|
217
280
|
|
218
281
|
def __pow__(self, power):
|
219
|
-
|
282
|
+
if power != 2:
|
283
|
+
return NotImplemented
|
220
284
|
return Scalar(
|
221
285
|
is_basis=False,
|
222
|
-
eval_expression=
|
286
|
+
eval_expression=ScalarRepresentation(utils.Op.MUL, self, self),
|
223
287
|
tags=[rf"|{self.tag}|^{power}"],
|
224
288
|
)
|
225
289
|
|
@@ -227,17 +291,19 @@ class Point:
|
|
227
291
|
tag_self = utils.parenthesize_tag(self)
|
228
292
|
return Point(
|
229
293
|
is_basis=False,
|
230
|
-
eval_expression=
|
294
|
+
eval_expression=PointRepresentation(utils.Op.MUL, -1, self),
|
231
295
|
tags=[f"-{tag_self}"],
|
232
296
|
)
|
233
297
|
|
234
298
|
def __truediv__(self, other):
|
235
|
-
|
299
|
+
if not utils.is_numerical_or_parameter(other):
|
300
|
+
return NotImplemented
|
236
301
|
tag_self = utils.parenthesize_tag(self)
|
302
|
+
tag_other = f"1/{utils.numerical_str(other)}"
|
237
303
|
return Point(
|
238
304
|
is_basis=False,
|
239
|
-
eval_expression=
|
240
|
-
tags=[f"
|
305
|
+
eval_expression=PointRepresentation(utils.Op.DIV, self, other),
|
306
|
+
tags=[f"{tag_other}*{tag_self}"],
|
241
307
|
)
|
242
308
|
|
243
309
|
def __hash__(self):
|
@@ -249,6 +315,20 @@ class Point:
|
|
249
315
|
return self.uid == other.uid
|
250
316
|
|
251
317
|
def eval(self, ctx: pc.PEPContext | None = None) -> np.ndarray:
|
318
|
+
"""
|
319
|
+
Return the concrete representation of this :class:`Point`.
|
320
|
+
Concrete representations of :class:`Point` objects are
|
321
|
+
:class:`EvaluatedPoint` objects.
|
322
|
+
|
323
|
+
Args:
|
324
|
+
ctx (:class:`PEPContext` | None): The :class:`PEPContext` object
|
325
|
+
we consider. `None` if we consider the current global
|
326
|
+
:class:`PEPContext` object.
|
327
|
+
|
328
|
+
Returns:
|
329
|
+
:class:`EvaluatedPoint`: The concrete representation of
|
330
|
+
this :class:`Point`.
|
331
|
+
"""
|
252
332
|
from pepflow.expression_manager import ExpressionManager
|
253
333
|
|
254
334
|
# Note this can be inefficient.
|
@@ -258,3 +338,29 @@ class Point:
|
|
258
338
|
raise RuntimeError("Did you forget to create a context?")
|
259
339
|
em = ExpressionManager(ctx)
|
260
340
|
return em.eval_point(self).vector
|
341
|
+
|
342
|
+
def repr_by_basis(self, ctx: pc.PEPContext | None = None) -> str:
|
343
|
+
"""
|
344
|
+
Express this :class:`Point` object as the linear combination of
|
345
|
+
the basis :class:`Point` objects of the given :class:`PEPContext`.
|
346
|
+
This linear combination is expressed as a `str` where, to refer to
|
347
|
+
the basis :class:`Point` objects, we use their tags.
|
348
|
+
|
349
|
+
Args:
|
350
|
+
ctx (:class:`PEPContext`): The :class:`PEPContext` object
|
351
|
+
whose basis :class:`Point` objects we consider. `None` if
|
352
|
+
we consider the current global :class:`PEPContext` object.
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
str: The representation of this :class:`Point` object in terms of
|
356
|
+
the basis :class:`Point` objects of the given :class:`PEPContext`.
|
357
|
+
"""
|
358
|
+
from pepflow.expression_manager import ExpressionManager
|
359
|
+
|
360
|
+
# Note this can be inefficient.
|
361
|
+
if ctx is None:
|
362
|
+
ctx = pc.get_current_context()
|
363
|
+
if ctx is None:
|
364
|
+
raise RuntimeError("Did you forget to create a context?")
|
365
|
+
em = ExpressionManager(ctx)
|
366
|
+
return em.repr_point_by_basis(self)
|