pepflow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +0 -0
- pepflow/constraint.py +19 -0
- pepflow/expression_manager.py +113 -0
- pepflow/function.py +183 -0
- pepflow/pep.py +109 -0
- pepflow/pep_context.py +30 -0
- pepflow/pep_test.py +58 -0
- pepflow/point.py +184 -0
- pepflow/point_test.py +329 -0
- pepflow/scalar.py +219 -0
- pepflow/solver.py +98 -0
- pepflow/solver_test.py +51 -0
- pepflow/utils.py +51 -0
- pepflow-0.1.0.dist-info/METADATA +51 -0
- pepflow-0.1.0.dist-info/RECORD +18 -0
- pepflow-0.1.0.dist-info/WHEEL +5 -0
- pepflow-0.1.0.dist-info/licenses/LICENSE +202 -0
- pepflow-0.1.0.dist-info/top_level.txt +1 -0
pepflow/__init__.py
ADDED
File without changes
|
pepflow/constraint.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
import attrs
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from pepflow.point import Scalar
|
9
|
+
|
10
|
+
from pepflow import utils
|
11
|
+
|
12
|
+
|
13
|
+
@attrs.frozen
class Constraint:
    """An immutable, named constraint of the form ``scalar <comparator> 0``."""

    # Left-hand side of the relation: a Scalar expression or a plain number.
    scalar: Scalar | float
    # Relation between ``scalar`` and zero (members defined in utils.Comparator).
    comparator: utils.Comparator
    # Identifier of the constraint; PEPBuilder.solve drops every constraint whose
    # name appears in its ``relaxed_constraints`` list.
    name: str
|
@@ -0,0 +1,113 @@
|
|
1
|
+
import functools
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from pepflow import pep_context as pc
|
6
|
+
from pepflow import point as pt
|
7
|
+
from pepflow import scalar as sc
|
8
|
+
from pepflow import utils
|
9
|
+
|
10
|
+
|
11
|
+
class ExpressionManager:
    """Translate the points and scalars of a PEP context into numeric form.

    Each basis point of ``pep_context`` gets one coordinate in a vector of
    length ``num basis points`` and each basis scalar one coordinate in a
    vector of length ``num basis scalars``, in the order in which they appear
    in ``pep_context.points`` / ``pep_context.scalars``. Non-basis points and
    scalars are evaluated recursively from their ``eval_expression``.

    Evaluations are memoized per manager instance. A ``functools.cache`` on
    the methods would instead be one process-wide cache keyed on ``self`` that
    keeps every manager, its context and all evaluated expressions alive.
    """

    def __init__(self, pep_context: pc.PEPContext):
        self.context = pep_context

        self._basis_points = [p for p in pep_context.points if p.is_basis]
        self._basis_point_uid_to_index = {
            p.uid: index for index, p in enumerate(self._basis_points)
        }
        self._basis_scalars = [s for s in pep_context.scalars if s.is_basis]
        self._basis_scalar_uid_to_index = {
            s.uid: index for index, s in enumerate(self._basis_scalars)
        }

        self._num_basis_points = len(self._basis_points)
        self._num_basis_scalars = len(self._basis_scalars)

        # Memo tables for eval_point / eval_scalar, keyed by the expression.
        self._point_cache = {}
        self._scalar_cache = {}

    def get_index_of_basis_point(self, point: pt.Point):
        """Return the coordinate index of basis ``point`` (KeyError if unknown)."""
        return self._basis_point_uid_to_index[point.uid]

    def get_index_of_basis_scalar(self, scalar: sc.Scalar):
        """Return the coordinate index of basis ``scalar`` (KeyError if unknown)."""
        return self._basis_scalar_uid_to_index[scalar.uid]

    def eval_point(self, point: pt.Point | float | int):
        """Return ``point`` as an ``EvaluatedPoint`` over the basis points.

        Numbers are returned unchanged. A basis point maps to its unit vector;
        a derived point applies its expression's operator to the evaluations
        of its operands. Results are memoized per manager.

        Raises:
            KeyError: if a basis point is not registered in the context.
            ValueError: if the expression's operator is not ADD/SUB/MUL/DIV.
        """
        if utils.is_numerical(point):
            return point
        if point not in self._point_cache:
            self._point_cache[point] = self._compute_point(point)
        return self._point_cache[point]

    def _compute_point(self, point: pt.Point) -> pt.EvaluatedPoint:
        """Evaluate a non-numerical point without consulting the memo table."""
        if point.is_basis:
            array = np.zeros(self._num_basis_points)
            array[self.get_index_of_basis_point(point)] = 1
            return pt.EvaluatedPoint(vector=array)

        expr = point.eval_expression
        op = expr.op
        if op == utils.Op.ADD:
            return self.eval_point(expr.left_point) + self.eval_point(
                expr.right_point
            )
        if op == utils.Op.SUB:
            return self.eval_point(expr.left_point) - self.eval_point(
                expr.right_point
            )
        if op == utils.Op.MUL:
            return self.eval_point(expr.left_point) * self.eval_point(
                expr.right_point
            )
        if op == utils.Op.DIV:
            return self.eval_point(expr.left_point) / self.eval_point(
                expr.right_point
            )

        raise ValueError("This should never happen!")

    def eval_scalar(self, scalar: sc.Scalar | float | int):
        """Return ``scalar`` as an ``EvaluatedScalar``.

        Numbers are returned unchanged. A basis scalar maps to its unit vector
        over the basis scalars with a zero basis-point matrix and zero
        constant. The product of two points is evaluated through ``utils.SOP``
        of their evaluated vectors; other expressions apply their operator to
        the evaluations of their operands. Results are memoized per manager.

        Raises:
            KeyError: if a basis scalar is not registered in the context.
            ValueError: if the expression's operator is not ADD/SUB/MUL/DIV.
        """
        if utils.is_numerical(scalar):
            return scalar
        if scalar not in self._scalar_cache:
            self._scalar_cache[scalar] = self._compute_scalar(scalar)
        return self._scalar_cache[scalar]

    def _compute_scalar(self, scalar: sc.Scalar) -> sc.EvaluatedScalar:
        """Evaluate a non-numerical scalar without consulting the memo table."""
        if scalar.is_basis:
            array = np.zeros(self._num_basis_scalars)
            array[self.get_index_of_basis_scalar(scalar)] = 1
            return sc.EvaluatedScalar(
                vector=array,
                matrix=np.zeros((self._num_basis_points, self._num_basis_points)),
                constant=float(0.0),
            )

        expr = scalar.eval_expression
        op = expr.op
        if op == utils.Op.ADD:
            return self.eval_scalar(expr.left_scalar) + self.eval_scalar(
                expr.right_scalar
            )
        if op == utils.Op.SUB:
            return self.eval_scalar(expr.left_scalar) - self.eval_scalar(
                expr.right_scalar
            )
        if op == utils.Op.MUL:
            if isinstance(expr.left_scalar, pt.Point) and isinstance(
                expr.right_scalar, pt.Point
            ):
                # Product of two points: no linear part and no constant, only
                # the basis-point matrix term.
                return sc.EvaluatedScalar(
                    vector=np.zeros(self._num_basis_scalars),
                    matrix=utils.SOP(
                        self.eval_point(expr.left_scalar).vector,
                        self.eval_point(expr.right_scalar).vector,
                    ),
                    constant=float(0.0),
                )
            return self.eval_scalar(expr.left_scalar) * self.eval_scalar(
                expr.right_scalar
            )
        if op == utils.Op.DIV:
            return self.eval_scalar(expr.left_scalar) / self.eval_scalar(
                expr.right_scalar
            )

        raise ValueError("This should never happen!")
|
pepflow/function.py
ADDED
@@ -0,0 +1,183 @@
|
|
1
|
+
import uuid
|
2
|
+
|
3
|
+
import attrs
|
4
|
+
|
5
|
+
from pepflow import point as pt
|
6
|
+
from pepflow import scalar as sc
|
7
|
+
from pepflow import utils
|
8
|
+
|
9
|
+
|
10
|
+
class Function:
    """A function in a PEP: either a basis function or a linear combination.

    A *basis* function records every oracle call as a triplet
    ``(point, function value, gradient)`` in ``self.triplets``; subclasses turn
    those triplets into interpolation constraints. A non-basis function is
    described by ``self.composition``, a ``{function: weight}`` mapping, and its
    oracle outputs are the weighted sums of its components' outputs.

    Functions hash and compare by a per-instance UUID, so they can be used as
    keys of ``composition``.
    """

    def __init__(
        self,
        is_basis: bool,
        reuse_gradient: bool,
        composition: dict | None = None,
        tag: str | None = None,
    ):
        """Create a function.

        Args:
            is_basis: True for a basis function (``composition`` must be None),
                False for a combination (``composition`` must be a dict).
            reuse_gradient: if True, repeated oracle calls at the same point
                return the recorded gradient; if False, each repeated call
                records a new gradient point with the same function value.
            composition: ``{function: weight}`` for non-basis functions.
            tag: optional human-readable label.
        """
        self.is_basis = is_basis
        self.reuse_gradient = reuse_gradient
        self.tag = tag
        # Must be a real UUID: ``attrs.field(...)`` only declares a field inside
        # an attrs class body; assigned here it would store a field-spec object.
        self.uid = uuid.uuid4()
        self.triplets = []  # list of (point, function value, gradient) tuples
        self.constraints = []  # extra constraints, e.g. from add_stationary_point

        if is_basis:
            assert composition is None
            self.composition = {self: 1}
        else:
            assert isinstance(composition, dict)
            self.composition = composition  # {function: weight}

    def add_tag(self, tag: str) -> None:
        """Set (replace) the function's tag."""
        self.tag = tag
        return None

    def get_interpolation_constraints(self):
        """Return the interpolation constraints; implemented by subclasses."""
        raise NotImplementedError(
            "This method should be implemented in the children class."
        )

    def add_triplet(self, triplet: tuple) -> None:
        # NOTE(review): placeholder that returns the NotImplemented sentinel
        # rather than raising; triplets are recorded by generate_triplet.
        return NotImplemented

    def add_stationary_point(self) -> pt.Point:
        """Create a new basis point whose gradient is constrained to zero.

        Appends the constraint ``grad**2 == 0`` to ``self.constraints`` and
        returns the new point.
        """
        point = pt.Point(is_basis=True)
        _, _, grad = self.generate_triplet(point)
        self.constraints.append(
            # ``__hash__()`` is called (not just referenced) so the name holds
            # the hash value, matching the naming of interpolation constraints.
            (grad**2).eq(0, name=str(self.__hash__()) + " stationary point")
        )
        return point

    def generate_triplet(self, point: pt.Point) -> tuple:
        """Return ``(point, function value, gradient)`` for an oracle call.

        For a basis function, a call at a point not seen before creates a new
        basis scalar (value) and basis point (gradient) and records them. A
        repeated call reuses the recorded triplet when ``reuse_gradient`` is
        True; otherwise it keeps the recorded value, creates a new gradient
        point and records that new triplet.

        For a non-basis function, the value and gradient are the
        weight-scaled sums over ``self.composition``.
        """
        func_value = 0
        grad = 0

        if self.is_basis:
            generate_new_basis = True
            for triplet in self.triplets:
                if triplet[0].uid == point.uid and self.reuse_gradient:
                    func_value = triplet[1]
                    grad = triplet[2]
                    generate_new_basis = False
                    break
                elif triplet[0].uid == point.uid and not self.reuse_gradient:
                    func_value = triplet[1]
                    grad = pt.Point(is_basis=True)
                    generate_new_basis = False
                    self.triplets.append((point, func_value, grad))
                    break
            if generate_new_basis:
                func_value = sc.Scalar(is_basis=True)
                grad = pt.Point(is_basis=True)
                self.triplets.append((point, func_value, grad))
        else:
            for function, weights in self.composition.items():
                _, func_value_slice, grad_slice = function.generate_triplet(point)
                func_value += weights * func_value_slice
                grad += weights * grad_slice

        return (point, func_value, grad)

    def gradient(self, point: pt.Point) -> pt.Point:
        """Return the gradient at ``point`` (see generate_triplet)."""
        _, _, grad = self.generate_triplet(point)
        return grad

    def subgradient(self, point: pt.Point) -> pt.Point:
        """Return the (sub)gradient at ``point`` (see generate_triplet)."""
        _, _, subgrad = self.generate_triplet(point)
        return subgrad

    def function_value(self, point: pt.Point) -> sc.Scalar:
        """Return the function value at ``point`` (see generate_triplet)."""
        _, func_value, _ = self.generate_triplet(point)
        return func_value

    def __add__(self, other):
        """Return the non-basis function ``self + other``."""
        assert isinstance(other, Function)
        merged_composition = utils.merge_dict(self.composition, other.composition)
        pruned_composition = utils.prune_dict(merged_composition)
        return Function(
            is_basis=False,
            reuse_gradient=self.reuse_gradient and other.reuse_gradient,
            composition=pruned_composition,
            tag=None,
        )

    def __sub__(self, other):
        return self.__add__(-other)

    def __mul__(self, other):
        return self.__rmul__(other=other)

    def __rmul__(self, other):
        """Return the non-basis function ``other * self`` for a number ``other``."""
        assert utils.is_numerical(other)
        scaled_composition = dict()
        for key, value in self.composition.items():
            scaled_composition[key] = value * other
        pruned_composition = utils.prune_dict(scaled_composition)
        return Function(
            is_basis=False,
            reuse_gradient=self.reuse_gradient,
            composition=pruned_composition,
            tag=None,
        )

    def __neg__(self):
        return self.__mul__(other=-1)

    def __truediv__(self, other):
        """Return the non-basis function ``self / other`` for a number ``other``."""
        assert utils.is_numerical(other)
        scaled_composition = dict()
        for key, value in self.composition.items():
            scaled_composition[key] = value / other
        pruned_composition = utils.prune_dict(scaled_composition)
        return Function(
            is_basis=False,
            reuse_gradient=self.reuse_gradient,
            composition=pruned_composition,
            tag=None,
        )

    def __hash__(self):
        return hash(self.uid)

    def __eq__(self, other):
        if not isinstance(other, Function):
            return NotImplemented
        return self.uid == other.uid
|
147
|
+
|
148
|
+
|
149
|
+
class SmoothConvexFunction(Function):
    """An L-smooth convex function.

    For every ordered pair ``(i, j)`` of distinct recorded triplets
    ``(x, f, g)`` it contributes the interpolation condition

        f_j - f_i + <g_j, x_i - x_j> + ||g_i - g_j||^2 / (2 L) <= 0.
    """

    def __init__(self, L, is_basis=True, composition=None, reuse_gradient=True):
        super().__init__(
            is_basis=is_basis,
            composition=composition,
            reuse_gradient=reuse_gradient,
        )
        # Smoothness constant.
        self.L = L

    def smooth_convex_interpolability_constraints(self, triplet_i, triplet_j):
        """Return the interpolation constraint for the ordered pair (i, j).

        Each triplet is ``(point, function value, gradient)``. The constraint
        is named ``"<hash(self)>:<hash(point_i)>,<hash(point_j)>"``.
        """
        x_i, f_i, g_i = triplet_i
        x_j, f_j, g_j = triplet_j

        value_gap = f_j - f_i
        linear_part = g_j * (x_i - x_j)
        quadratic_part = 1 / (2 * self.L) * (g_i - g_j) ** 2

        return (value_gap + linear_part + quadratic_part).le(
            0,
            name=f"{self.__hash__()}:{x_i.__hash__()},{x_j.__hash__()}",
        )

    def get_interpolation_constraints(self):
        """Return one constraint per ordered pair of distinct recorded triplets."""
        recorded = self.triplets
        return [
            self.smooth_convex_interpolability_constraints(first, second)
            for i, first in enumerate(recorded)
            for j, second in enumerate(recorded)
            if i != j
        ]
|
pepflow/pep.py
ADDED
@@ -0,0 +1,109 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any
|
5
|
+
|
6
|
+
import attrs
|
7
|
+
|
8
|
+
from pepflow import pep_context as pc
|
9
|
+
from pepflow import point as pt
|
10
|
+
from pepflow import scalar as sc
|
11
|
+
from pepflow import solver as ps
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from pepflow.solver import DualVariableManager
|
15
|
+
|
16
|
+
|
17
|
+
@attrs.frozen
class PEPResult:
    """Immutable outcome of ``PEPBuilder.solve``."""

    # Value returned by ``problem.solve(...)``, i.e. the primal optimal value.
    primal_opt_value: float
    # Dual variable manager of the CVXSolver that built the problem.
    dual_var_manager: DualVariableManager
    # ``problem.status`` as reported after solving.
    solver_status: Any
|
22
|
+
|
23
|
+
|
24
|
+
class PEPBuilder:
    """The main class for PEP primal formulation.

    It owns named ``PEPContext`` objects, the declared functions, the initial
    conditions and the performance metric, and assembles them into a solver
    problem in :meth:`solve`.
    """

    def __init__(self):
        self.pep_context_dict: dict[str, pc.PEPContext] = {}

        self.init_conditions = []  #: list["constraint"] =[]
        self.functions = []  #: list["function"] = []
        self.interpolation_constraints = []  #: list["constraint"] = []
        self.performance_metric = None  # scalar

        # Contain the name for the constraints that should be removed.
        # We should think about a better choice like manager.
        self.relaxed_constraints = []

    @contextlib.contextmanager
    def make_context(
        self, name: str, override: bool = False
    ) -> Iterator[pc.PEPContext]:
        """Create a context named ``name`` and make it current inside ``with``.

        Args:
            name: key under which the context is stored in ``pep_context_dict``.
            override: if True, replace an existing context with the same name.

        Yields:
            The newly created context.

        Raises:
            KeyError: if ``name`` already exists and ``override`` is False.
        """
        if not override and name in self.pep_context_dict:
            raise KeyError(f"There is already a context {name} in the builder")
        # Restore whatever was current before, instead of unconditionally
        # clearing it, so a nested ``make_context`` does not leave the outer
        # ``with`` block without an active context.
        previous = pc.get_current_context()
        try:
            ctx = pc.PEPContext()
            self.pep_context_dict[name] = ctx
            pc.set_current_context(ctx)
            yield ctx
        finally:
            pc.set_current_context(previous)

    def get_context(self, name: str) -> pc.PEPContext:
        """Return the context ``name`` and make it the current context.

        Raises:
            KeyError: if no context with that name exists.
        """
        if name not in self.pep_context_dict:
            raise KeyError(f"Cannot find a context named {name} in the builder.")
        ctx = self.pep_context_dict[name]
        pc.set_current_context(ctx)
        return ctx

    def clear_context(self, name: str) -> None:
        """Remove the context ``name`` (KeyError if it does not exist)."""
        if name not in self.pep_context_dict:
            raise KeyError(f"Cannot find a context named {name} in the builder.")
        del self.pep_context_dict[name]

    def clear_all_context(self) -> None:
        """Remove every stored context."""
        self.pep_context_dict.clear()

    def set_init_point(self, tag: str | None = None) -> pt.Point:
        """Create a new basis point, tagged with ``tag`` when one is given."""
        point = pt.Point(is_basis=True)
        # Skip None so the point's tag list only ever holds real tags.
        if tag is not None:
            point.add_tag(tag)
        return point

    def set_initial_constraint(self, constraint):
        """Append ``constraint`` to the initial conditions."""
        self.init_conditions.append(constraint)

    def set_performance_metric(self, metric: sc.Scalar):
        """Set the scalar the PEP maximizes/minimizes (passed to the solver)."""
        self.performance_metric = metric

    def declare_func(self, function_class, **kwargs):
        """Instantiate a basis ``function_class``, register it and return it."""
        func = function_class(is_basis=True, composition=None, **kwargs)
        self.functions.append(func)
        return func

    def solve(self, context: pc.PEPContext | None = None, **kwargs):
        """Build the PEP for ``context`` and solve it with CVX.

        Args:
            context: context whose points/scalars define the problem; defaults
                to the current context.
            **kwargs: forwarded to ``problem.solve``.

        Returns:
            A ``PEPResult`` with the value returned by ``problem.solve``, the
            solver's dual variable manager and the problem status.

        Raises:
            RuntimeError: if no context is given and none is current.
        """
        if context is None:
            context = pc.get_current_context()
        if context is None:
            raise RuntimeError("Did you forget to create a context?")

        all_constraints = [*self.init_conditions]
        for f in self.functions:
            if f.is_basis:
                all_constraints.extend(f.get_interpolation_constraints())
                all_constraints.extend(f.constraints)

        # Build the relaxed-name set once instead of scanning the list per
        # constraint.
        relaxed = set(self.relaxed_constraints)

        # for now, we heavily rely on the CVX. We can make a wrapper class to avoid
        # direct dependency in the future.
        solver = ps.CVXSolver(
            perf_metric=self.performance_metric,
            constraints=[c for c in all_constraints if c.name not in relaxed],
            context=context,
        )
        problem = solver.build_problem()
        result = problem.solve(**kwargs)
        return PEPResult(
            primal_opt_value=result,
            dual_var_manager=solver.dual_var_manager,
            solver_status=problem.status,
        )
|
pepflow/pep_context.py
ADDED
@@ -0,0 +1,30 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
# Module-level storage for the active context: newly created points (and,
# presumably, scalars) register themselves with it. It is read and written via
# get_current_context() / set_current_context(); None means no active context.
CURRENT_CONTEXT: PEPContext | None = None
|
5
|
+
|
6
|
+
|
7
|
+
def get_current_context() -> PEPContext | None:
    """Return the active ``PEPContext``, or None when no context is active."""
    active = CURRENT_CONTEXT
    return active
|
9
|
+
|
10
|
+
|
11
|
+
def set_current_context(ctx: PEPContext | None):
    """Make ``ctx`` the active context; pass None to deactivate.

    Only ``PEPContext`` instances or None are accepted (checked by assert).
    """
    global CURRENT_CONTEXT
    if ctx is not None:
        assert isinstance(ctx, PEPContext)
    CURRENT_CONTEXT = ctx
|
15
|
+
|
16
|
+
|
17
|
+
class PEPContext:
    """Bookkeeping for one PEP: the points and scalars registered with it.

    Points register themselves here on creation while this context is the
    active one.
    """

    def __init__(self):
        self.points, self.scalars = [], []

    def add_point(self, point):
        """Record ``point`` as belonging to this context."""
        self.points.append(point)

    def add_scalar(self, scalar):
        """Record ``scalar`` as belonging to this context."""
        self.scalars.append(scalar)

    def clear(self):
        """Forget all registered points and scalars by binding fresh lists."""
        self.points, self.scalars = [], []
|
pepflow/pep_test.py
ADDED
@@ -0,0 +1,58 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from pepflow import pep
|
4
|
+
from pepflow import pep_context as pc
|
5
|
+
|
6
|
+
|
7
|
+
class TestPEPBuilder:
    """Tests for the context-management API of ``pep.PEPBuilder``."""

    def test_make_context(self) -> None:
        builder = pep.PEPBuilder()
        assert pc.get_current_context() is None

        with builder.make_context("test") as ctx:
            assert ctx is pc.get_current_context()

        assert pc.get_current_context() is None

    def test_get_context(self) -> None:
        builder = pep.PEPBuilder()
        with builder.make_context("test") as ctx:
            prev_ctx = ctx

        try:
            assert builder.get_context("test") is prev_ctx
            # get_context also activates the returned context globally.
            assert pc.get_current_context() is prev_ctx
        finally:
            # get_context leaves the global context set; reset it so tests that
            # expect no active context are not affected by test ordering.
            pc.set_current_context(None)

    def test_clear_context(self) -> None:
        builder = pep.PEPBuilder()
        with builder.make_context("test"):
            pass

        assert "test" in builder.pep_context_dict.keys()
        builder.clear_context("test")
        assert "test" not in builder.pep_context_dict.keys()

    def test_clear_all_context(self) -> None:
        builder = pep.PEPBuilder()
        with builder.make_context("test"):
            pass
        with builder.make_context("test2"):
            pass

        assert len(builder.pep_context_dict) == 2
        builder.clear_all_context()
        assert len(builder.pep_context_dict) == 0

    def test_make_context_twice(self) -> None:
        builder = pep.PEPBuilder()
        with builder.make_context("test"):
            pass

        assert "test" in builder.pep_context_dict.keys()

        with pytest.raises(
            KeyError, match="There is already a context test in the builder"
        ):
            with builder.make_context("test"):
                pass

        with builder.make_context("test", override=True):
            pass
|
pepflow/point.py
ADDED
@@ -0,0 +1,184 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import uuid
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import attrs
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from pepflow import pep_context as pc
|
10
|
+
from pepflow import utils
|
11
|
+
from pepflow.scalar import EvalExpressionScalar, Scalar
|
12
|
+
|
13
|
+
|
14
|
+
def is_numerical_or_point(val: Any) -> bool:
    """Return whether ``val`` is a plain number or a symbolic ``Point``."""
    numerical = utils.is_numerical(val)
    return numerical or isinstance(val, Point)
|
16
|
+
|
17
|
+
|
18
|
+
def is_numerical_or_evaluatedpoint(val: Any) -> bool:
    """Return whether ``val`` is a plain number or an ``EvaluatedPoint``."""
    numerical = utils.is_numerical(val)
    return numerical or isinstance(val, EvaluatedPoint)
|
20
|
+
|
21
|
+
|
22
|
+
@attrs.frozen
class EvalExpressionPoint:
    """Immutable recipe ``left_point <op> right_point`` for a derived Point."""

    # Binary operator; ExpressionManager.eval_point handles ADD, SUB, MUL, DIV.
    op: utils.Op
    # Left operand: a Point or a number.
    left_point: Point | float
    # Right operand: a Point or a number.
    right_point: Point | float
|
27
|
+
|
28
|
+
|
29
|
+
@attrs.frozen
class EvaluatedPoint:
    """Numeric form of a point: its coefficient vector over the basis points.

    Produced, for example, by ``ExpressionManager.eval_point``, where
    ``vector[i]`` is the coefficient of the i-th basis point. Supports
    addition and subtraction with other evaluated points or numbers (a number
    is broadcast to every coordinate by numpy) and scaling by numbers.
    """

    # ``np.ndarray`` is the array type; ``np.array`` is the constructor function
    # and is not a meaningful annotation.
    vector: np.ndarray

    def __add__(self, other):
        if isinstance(other, EvaluatedPoint):
            return EvaluatedPoint(vector=self.vector + other.vector)
        if utils.is_numerical(other):
            return EvaluatedPoint(vector=self.vector + other)
        raise ValueError(
            f"Unsupported add operation between EvaluatedPoint and {type(other)}"
        )

    def __radd__(self, other):
        # Addition is commutative, so delegate to __add__.
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, EvaluatedPoint):
            return EvaluatedPoint(vector=self.vector - other.vector)
        if utils.is_numerical(other):
            return EvaluatedPoint(vector=self.vector - other)
        raise ValueError(
            f"Unsupported sub operation between EvaluatedPoint and {type(other)}"
        )

    def __rsub__(self, other):
        # Computes ``other - self``.
        if isinstance(other, EvaluatedPoint):
            return EvaluatedPoint(vector=other.vector - self.vector)
        if utils.is_numerical(other):
            return EvaluatedPoint(vector=other - self.vector)
        raise ValueError(
            f"Unsupported sub operation between EvaluatedPoint and {type(other)}"
        )

    def __mul__(self, other):
        """Scale by a number (asserted)."""
        assert utils.is_numerical(other)
        return EvaluatedPoint(vector=self.vector * other)

    def __rmul__(self, other):
        """Scale by a number (asserted)."""
        assert utils.is_numerical(other)
        return EvaluatedPoint(vector=other * self.vector)

    def __truediv__(self, other):
        """Divide by a number (asserted)."""
        assert utils.is_numerical(other)
        return EvaluatedPoint(vector=self.vector / other)
|
77
|
+
|
78
|
+
|
79
|
+
@attrs.frozen
class Point:
    """A symbolic vector of the PEP.

    A point is either a *basis* point (``eval_expression is None``), an
    independent vector for the evaluation of G, or a *derived* point whose
    value is given by ``eval_expression``. Every point registers itself with
    the current ``PEPContext`` when created.

    Arithmetic with numbers or other points yields new derived points, except
    that multiplying two points (or squaring one) yields a ``Scalar``.
    """

    # True for an independent basis vector; False for a derived point.
    is_basis: bool

    # Recipe for evaluating a derived point; must be None for basis points.
    eval_expression: EvalExpressionPoint | None = None

    # Human-readable tags attached to the point.
    tags: list[str] = attrs.field(factory=list)

    # Automatically generated identifier; drives hashing and equality.
    uid: uuid.UUID = attrs.field(factory=uuid.uuid4, init=False)

    def __attrs_post_init__(self):
        # A basis point carries no expression; a derived point needs one.
        if not self.is_basis:
            assert self.eval_expression is not None
        else:
            assert self.eval_expression is None

        active = pc.get_current_context()
        if active is None:
            raise RuntimeError("Did you forget to create a context?")
        active.add_point(self)

    def add_tag(self, tag: str) -> None:
        """Append ``tag`` to this point's tags."""
        self.tags.append(tag)

    def _derived(self, op, left, right):
        """Build the non-basis point evaluated as ``left <op> right``."""
        return Point(
            is_basis=False,
            eval_expression=EvalExpressionPoint(op, left, right),
        )

    def __add__(self, other):
        assert is_numerical_or_point(other)
        return self._derived(utils.Op.ADD, self, other)

    def __radd__(self, other):
        assert is_numerical_or_point(other)
        return self._derived(utils.Op.ADD, other, self)

    def __sub__(self, other):
        assert is_numerical_or_point(other)
        return self._derived(utils.Op.SUB, self, other)

    def __rsub__(self, other):
        assert is_numerical_or_point(other)
        return self._derived(utils.Op.SUB, other, self)

    def __mul__(self, other):
        """``self * number`` is a Point; ``self * point`` is a Scalar."""
        assert is_numerical_or_point(other)
        if not utils.is_numerical(other):
            return Scalar(
                is_basis=False,
                eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
            )
        return self._derived(utils.Op.MUL, self, other)

    def __rmul__(self, other):
        """``number * self`` is a Point; ``point * self`` is a Scalar."""
        assert is_numerical_or_point(other)
        if not utils.is_numerical(other):
            return Scalar(
                is_basis=False,
                eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
            )
        return self._derived(utils.Op.MUL, other, self)

    def __pow__(self, power):
        """Only ``point ** 2`` is supported; it yields the Scalar ``self * self``."""
        assert power == 2
        return self.__rmul__(self)

    def __neg__(self):
        return self.__rmul__(other=-1)

    def __truediv__(self, other):
        assert utils.is_numerical(other)
        return self._derived(utils.Op.DIV, self, other)

    def __hash__(self):
        return hash(self.uid)

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.uid == other.uid
|