pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pepflow/__init__.py +5 -1
- pepflow/constraint.py +58 -1
- pepflow/e2e_test.py +47 -3
- pepflow/expression_manager.py +272 -57
- pepflow/expression_manager_test.py +36 -2
- pepflow/function.py +180 -10
- pepflow/parameter.py +187 -0
- pepflow/parameter_test.py +128 -0
- pepflow/pep.py +254 -14
- pepflow/pep_context.py +116 -0
- pepflow/pep_context_test.py +21 -0
- pepflow/point.py +155 -49
- pepflow/point_test.py +12 -0
- pepflow/scalar.py +260 -47
- pepflow/scalar_test.py +15 -0
- pepflow/solver.py +170 -3
- pepflow/solver_test.py +50 -2
- pepflow/utils.py +39 -7
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/METADATA +12 -11
- pepflow-0.1.5.dist-info/RECORD +28 -0
- pepflow-0.1.4a1.dist-info/RECORD +0 -26
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/WHEEL +0 -0
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {pepflow-0.1.4a1.dist-info → pepflow-0.1.5.dist-info}/top_level.txt +0 -0
pepflow/function.py
CHANGED
@@ -36,6 +36,19 @@ if TYPE_CHECKING:
|
|
36
36
|
|
37
37
|
@attrs.frozen
|
38
38
|
class Triplet:
|
39
|
+
"""
|
40
|
+
A data class that represents, for some given function :math:`f`,
|
41
|
+
the tuple :math:`\\{x, f(x), \\nabla f(x)\\}`. We can also consider
|
42
|
+
a subgradient :math:`\\widetilde{\\nabla} f(x)` instead of the gradient.
|
43
|
+
|
44
|
+
Attributes:
|
45
|
+
point (:class:`Point`): The point :math:`x`.
|
46
|
+
function_value (:class:`Scalar`): The function value :math:`f(x)`.
|
47
|
+
gradient (:class:`Point`): The gradient :math:`\\nabla f(x)` or
|
48
|
+
a subgradient :math:`\\widetilde{\\nabla} f(x)`.
|
49
|
+
name (str): The unique name of the :class:`Triplet` object.
|
50
|
+
"""
|
51
|
+
|
39
52
|
point: pt.Point
|
40
53
|
function_value: sc.Scalar
|
41
54
|
gradient: pt.Point
|
@@ -53,7 +66,7 @@ class AddedFunc:
|
|
53
66
|
|
54
67
|
@attrs.frozen
|
55
68
|
class ScaledFunc:
|
56
|
-
"""Represents
|
69
|
+
"""Represents scalar * base_func."""
|
57
70
|
|
58
71
|
scale: float
|
59
72
|
base_func: Function
|
@@ -61,6 +74,25 @@ class ScaledFunc:
|
|
61
74
|
|
62
75
|
@attrs.mutable
|
63
76
|
class Function:
|
77
|
+
"""A :class:`Function` object represents a function.
|
78
|
+
|
79
|
+
:class:`Function` objects can be constructed as linear combinations
|
80
|
+
of other :class:`Function` objects. Let `a` and `b` be some numeric
|
81
|
+
data type. Let `f` and `g` be :class:`Function` objects. Then, we
|
82
|
+
can form a new :class:`Function` object: `a*f+b*g`.
|
83
|
+
|
84
|
+
A :class:`Function` object should never be explicitly constructed. Only
|
85
|
+
children of :class:`Function` such as :class:`ConvexFunction` or
|
86
|
+
:class:`SmoothConvexFunction` should be constructed. See their respective
|
87
|
+
documentation to see how.
|
88
|
+
|
89
|
+
Attributes:
|
90
|
+
is_basis (bool): `True` if this function is not formed through a linear
|
91
|
+
combination of other functions. `False` otherwise.
|
92
|
+
tags (list[str]): A list that contains tags that can be used to
|
93
|
+
identify the :class:`Function` object. Tags should be unique.
|
94
|
+
"""
|
95
|
+
|
64
96
|
is_basis: bool
|
65
97
|
|
66
98
|
composition: AddedFunc | ScaledFunc | None = None
|
@@ -79,11 +111,21 @@ class Function:
|
|
79
111
|
|
80
112
|
@property
|
81
113
|
def tag(self):
|
114
|
+
"""Returns the most recently added tag.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
str: The most recently added tag of this :class:`Function` object.
|
118
|
+
"""
|
82
119
|
if len(self.tags) == 0:
|
83
120
|
raise ValueError("Function should have a name.")
|
84
121
|
return self.tags[-1]
|
85
122
|
|
86
123
|
def add_tag(self, tag: str) -> None:
|
124
|
+
"""Add a new tag for this :class:`Function` object.
|
125
|
+
|
126
|
+
Args:
|
127
|
+
tag (str): The new tag to be added to the `tags` list.
|
128
|
+
"""
|
87
129
|
self.tags.append(tag)
|
88
130
|
|
89
131
|
def __repr__(self):
|
@@ -93,7 +135,7 @@ class Function:
|
|
93
135
|
|
94
136
|
def _repr_latex_(self):
|
95
137
|
s = repr(self)
|
96
|
-
return rf"
|
138
|
+
return rf"$\displaystyle {s}$"
|
97
139
|
|
98
140
|
def get_interpolation_constraints(self):
|
99
141
|
raise NotImplementedError(
|
@@ -161,6 +203,18 @@ class Function:
|
|
161
203
|
return triplet
|
162
204
|
|
163
205
|
def add_stationary_point(self, name: str) -> pt.Point:
|
206
|
+
"""
|
207
|
+
Return a stationary point for this :class:`Function` object.
|
208
|
+
A :class:`Function` object can only have one stationary point.
|
209
|
+
|
210
|
+
Args:
|
211
|
+
name (str): The tag for the :class:`Point` object which
|
212
|
+
will serve as the stationary point.
|
213
|
+
|
214
|
+
Returns:
|
215
|
+
:class:`Point`: The stationary point for this :class:`Function`
|
216
|
+
object.
|
217
|
+
"""
|
164
218
|
# assert we can only add one stationary point?
|
165
219
|
pep_context = pc.get_current_context()
|
166
220
|
if pep_context is None:
|
@@ -171,7 +225,7 @@ class Function:
|
|
171
225
|
)
|
172
226
|
point = pt.Point(is_basis=True)
|
173
227
|
point.add_tag(name)
|
174
|
-
desired_grad =
|
228
|
+
desired_grad = pt.Point.zero() # zero point
|
175
229
|
desired_grad.add_tag(f"gradient_{self.tag}({name})")
|
176
230
|
triplet = self.add_point_with_grad_restriction(point, desired_grad)
|
177
231
|
pep_context.add_stationary_triplet(self, triplet)
|
@@ -195,6 +249,9 @@ class Function:
|
|
195
249
|
if pep_context is None:
|
196
250
|
raise RuntimeError("Did you forget to create a context?")
|
197
251
|
|
252
|
+
if not isinstance(point, pt.Point):
|
253
|
+
raise ValueError("The Function can only take point as input.")
|
254
|
+
|
198
255
|
if self.is_basis:
|
199
256
|
for triplet in pep_context.triplets[self]:
|
200
257
|
if triplet.point.uid == point.uid:
|
@@ -232,14 +289,50 @@ class Function:
|
|
232
289
|
return Triplet(point, function_value, gradient, name=None)
|
233
290
|
|
234
291
|
def gradient(self, point: pt.Point) -> pt.Point:
|
292
|
+
"""
|
293
|
+
Returns a :class:`Point` object that is the gradient of the
|
294
|
+
:class:`Function` at the given :class:`Point`.
|
295
|
+
|
296
|
+
Args:
|
297
|
+
point (:class:`Point`): Any :class:`Point`.
|
298
|
+
|
299
|
+
Returns:
|
300
|
+
:class:`Point`: The gradient of the :class:`Function` at the
|
301
|
+
given :class:`Point`.
|
302
|
+
"""
|
235
303
|
triplet = self.generate_triplet(point)
|
236
304
|
return triplet.gradient
|
237
305
|
|
238
306
|
def subgradient(self, point: pt.Point) -> pt.Point:
|
307
|
+
"""
|
308
|
+
Returns a :class:`Point` object that is the subgradient of the
|
309
|
+
:class:`Function` at the given :class:`Point`.
|
310
|
+
|
311
|
+
Args:
|
312
|
+
point (:class:`Point`): Any :class:`Point`.
|
313
|
+
|
314
|
+
Returns:
|
315
|
+
:class:`Point`: The subgradient of the :class:`Function` at the
|
316
|
+
given :class:`Point`.
|
317
|
+
|
318
|
+
Note:
|
319
|
+
This method behaves exactly the same as `gradient`.
|
320
|
+
"""
|
239
321
|
triplet = self.generate_triplet(point)
|
240
322
|
return triplet.gradient
|
241
323
|
|
242
324
|
def function_value(self, point: pt.Point) -> sc.Scalar:
|
325
|
+
"""
|
326
|
+
Returns a :class:`Scalar` object that is the function value of the
|
327
|
+
:class:`Function` at the given :class:`Point`.
|
328
|
+
|
329
|
+
Args:
|
330
|
+
point (:class:`Point`): Any :class:`Point`.
|
331
|
+
|
332
|
+
Returns:
|
333
|
+
:class:`Scalar`: The function value of the :class:`Function` at the
|
334
|
+
given :class:`Point`.
|
335
|
+
"""
|
243
336
|
triplet = self.generate_triplet(point)
|
244
337
|
return triplet.function_value
|
245
338
|
|
@@ -247,7 +340,8 @@ class Function:
|
|
247
340
|
return self.function_value(point)
|
248
341
|
|
249
342
|
def __add__(self, other):
|
250
|
-
|
343
|
+
if not isinstance(other, Function):
|
344
|
+
return NotImplemented
|
251
345
|
return Function(
|
252
346
|
is_basis=False,
|
253
347
|
composition=AddedFunc(self, other),
|
@@ -255,7 +349,8 @@ class Function:
|
|
255
349
|
)
|
256
350
|
|
257
351
|
def __sub__(self, other):
|
258
|
-
|
352
|
+
if not isinstance(other, Function):
|
353
|
+
return NotImplemented
|
259
354
|
tag_other = other.tag
|
260
355
|
if isinstance(other.composition, AddedFunc):
|
261
356
|
tag_other = f"({other.tag})"
|
@@ -266,7 +361,8 @@ class Function:
|
|
266
361
|
)
|
267
362
|
|
268
363
|
def __mul__(self, other):
|
269
|
-
|
364
|
+
if not utils.is_numerical(other):
|
365
|
+
return NotImplemented
|
270
366
|
tag_self = self.tag
|
271
367
|
if isinstance(self.composition, AddedFunc):
|
272
368
|
tag_self = f"({self.tag})"
|
@@ -277,7 +373,8 @@ class Function:
|
|
277
373
|
)
|
278
374
|
|
279
375
|
def __rmul__(self, other):
|
280
|
-
|
376
|
+
if not utils.is_numerical(other):
|
377
|
+
return NotImplemented
|
281
378
|
tag_self = self.tag
|
282
379
|
if isinstance(self.composition, AddedFunc):
|
283
380
|
tag_self = f"({self.tag})"
|
@@ -298,7 +395,8 @@ class Function:
|
|
298
395
|
)
|
299
396
|
|
300
397
|
def __truediv__(self, other):
|
301
|
-
|
398
|
+
if not utils.is_numerical(other):
|
399
|
+
return NotImplemented
|
302
400
|
tag_self = self.tag
|
303
401
|
if isinstance(self.composition, AddedFunc):
|
304
402
|
tag_self = f"({self.tag})"
|
@@ -318,6 +416,22 @@ class Function:
|
|
318
416
|
|
319
417
|
|
320
418
|
class ConvexFunction(Function):
|
419
|
+
"""
|
420
|
+
The :class:`ConvexFunction` class is a child of :class:`Function`.
|
421
|
+
The :class:`ConvexFunction` class represents a closed, convex, and
|
422
|
+
proper (CCP) function, i.e., a convex function whose epigraph is a
|
423
|
+
non-empty closed set.
|
424
|
+
|
425
|
+
A CCP function typically has no parameters. We can instantiate a
|
426
|
+
:class:`ConvexFunction` object as follows:
|
427
|
+
|
428
|
+
Example:
|
429
|
+
>>> import pepflow as pf
|
430
|
+
>>> ctx = pf.PEPContext("example").set_as_current()
|
431
|
+
>>> pep_builder = pf.PEPBuilder()
|
432
|
+
>>> g = pep_builder.declare_func(pf.ConvexFunction, "g")
|
433
|
+
"""
|
434
|
+
|
321
435
|
def __init__(
|
322
436
|
self,
|
323
437
|
is_basis=True,
|
@@ -365,7 +479,16 @@ class ConvexFunction(Function):
|
|
365
479
|
def interpolate_ineq(
|
366
480
|
self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
|
367
481
|
) -> sc.Scalar:
|
368
|
-
"""Generate the interpolation inequality
|
482
|
+
"""Generate the interpolation inequality :class:`Scalar` by tags.
|
483
|
+
The interpolation inequality between two points :math:`p_1, p_2` for a
|
484
|
+
CCP function :math:`f` is
|
485
|
+
|
486
|
+
.. math:: f(p_2) - f(p_1) + \\langle \\nabla f(p_2), p_1 - p_2 \\rangle.
|
487
|
+
|
488
|
+
Args:
|
489
|
+
p1_tag (str): A tag of the :class:`Point` :math:`p_1`.
|
490
|
+
p2_tag (str): A tag of the :class:`Point` :math:`p_2`.
|
491
|
+
"""
|
369
492
|
if pep_context is None:
|
370
493
|
pep_context = pc.get_current_context()
|
371
494
|
if pep_context is None:
|
@@ -379,6 +502,29 @@ class ConvexFunction(Function):
|
|
379
502
|
return f2 - f1 + g2 * (x1 - x2)
|
380
503
|
|
381
504
|
def proximal_step(self, x_0: pt.Point, stepsize: numbers.Number) -> pt.Point:
|
505
|
+
""" Define the proximal operator as
|
506
|
+
|
507
|
+
.. math:: \\text{prox}_{\\gamma f}(x_0) := \\arg\\min_x \\left\\{ \\gamma f(x) + \\frac{1}{2} \\|x - x_0\\|^2 \\right\\}.
|
508
|
+
|
509
|
+
This function performs a proximal step with respect to some
|
510
|
+
:class:`Function` :math:`f` on the :class:`Point` :math:`x_0`
|
511
|
+
with stepsize :math:`\\gamma`:
|
512
|
+
|
513
|
+
.. math::
|
514
|
+
:nowrap:
|
515
|
+
|
516
|
+
\\begin{eqnarray}
|
517
|
+
x := \\text{prox}_{\\gamma f}(x_0) & := & \\arg\\min_x \\left\\{ \\gamma f(x) + \\frac{1}{2} \\|x - x_0\\|^2 \\right\\}, \\\\
|
518
|
+
& \\Updownarrow & \\\\
|
519
|
+
0 & = & \\gamma \\partial f(x) + x - x_0,\\\\
|
520
|
+
& \\Updownarrow & \\\\
|
521
|
+
x & = & x_0 - \\gamma \\widetilde{\\nabla} f(x) \\text{ where } \\widetilde{\\nabla} f(x)\\in\\partial f(x).
|
522
|
+
\\end{eqnarray}
|
523
|
+
|
524
|
+
Args:
|
525
|
+
x_0 (:class:`Point`): The initial point.
|
526
|
+
stepsize (int | float): The stepsize.
|
527
|
+
"""
|
382
528
|
gradient = pt.Point(is_basis=True)
|
383
529
|
gradient.add_tag(
|
384
530
|
f"gradient_{self.tag}(prox_{{{stepsize}*{self.tag}}}({x_0.tag}))"
|
@@ -398,6 +544,21 @@ class ConvexFunction(Function):
|
|
398
544
|
|
399
545
|
|
400
546
|
class SmoothConvexFunction(Function):
|
547
|
+
"""
|
548
|
+
The :class:`SmoothConvexFunction` class is a child of :class:`Function`.
|
549
|
+
The :class:`SmoothConvexFunction` class represents a smooth,
|
550
|
+
convex function.
|
551
|
+
|
552
|
+
A smooth, convex function has a smoothness parameter :math:`L`.
|
553
|
+
We can instantiate a :class:`SmoothConvexFunction` object as follows:
|
554
|
+
|
555
|
+
Example:
|
556
|
+
>>> import pepflow as pf
|
557
|
+
>>> ctx = pf.PEPContext("example").set_as_current()
|
558
|
+
>>> pep_builder = pf.PEPBuilder()
|
559
|
+
>>> f = pep_builder.declare_func(pf.SmoothConvexFunction, "f", L=1)
|
560
|
+
"""
|
561
|
+
|
401
562
|
def __init__(
|
402
563
|
self,
|
403
564
|
L,
|
@@ -445,7 +606,16 @@ class SmoothConvexFunction(Function):
|
|
445
606
|
def interpolate_ineq(
|
446
607
|
self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
|
447
608
|
) -> sc.Scalar:
|
448
|
-
"""Generate the interpolation inequality
|
609
|
+
"""Generate the interpolation inequality :class:`Scalar` by tags.
|
610
|
+
The interpolation inequality between two points :math:`p_1, p_2` for a
|
611
|
+
smooth, convex function :math:`f` is
|
612
|
+
|
613
|
+
.. math:: f(p_2) - f(p_1) + \\langle \\nabla f(p_2), p_1 - p_2 \\rangle + \\tfrac{1}{2L} \\lVert \\nabla f(p_1) - \\nabla f(p_2) \\rVert^2.
|
614
|
+
|
615
|
+
Args:
|
616
|
+
p1_tag (str): A tag of the :class:`Point` :math:`p_1`.
|
617
|
+
p2_tag (str): A tag of the :class:`Point` :math:`p_2`.
|
618
|
+
"""
|
449
619
|
if pep_context is None:
|
450
620
|
pep_context = pc.get_current_context()
|
451
621
|
if pep_context is None:
|
pepflow/parameter.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
from __future__ import annotations
|
21
|
+
|
22
|
+
import attrs
|
23
|
+
|
24
|
+
from pepflow import utils
|
25
|
+
|
26
|
+
# Sentinel value signalling that a parameter name could not be resolved.
|
27
|
+
NOT_FOUND = "__NOT_FOUND__"
|
28
|
+
|
29
|
+
|
30
|
+
@attrs.frozen
|
31
|
+
class ParameterRepresentation:
|
32
|
+
op: utils.Op
|
33
|
+
left_param: utils.NUMERICAL_TYPE | Parameter
|
34
|
+
right_param: utils.NUMERICAL_TYPE | Parameter
|
35
|
+
|
36
|
+
|
37
|
+
def eval_parameter(
|
38
|
+
param: Parameter | utils.NUMERICAL_TYPE,
|
39
|
+
resolve_parameters: dict[str, utils.NUMERICAL_TYPE],
|
40
|
+
) -> utils.NUMERICAL_TYPE:
|
41
|
+
if isinstance(param, Parameter):
|
42
|
+
return param.get_value(resolve_parameters)
|
43
|
+
if utils.is_numerical(param):
|
44
|
+
return param
|
45
|
+
raise ValueError(f"Encounter the unknown parameter type: {param} ({type(param)})")
|
46
|
+
|
47
|
+
|
48
|
+
@attrs.frozen
|
49
|
+
class Parameter:
|
50
|
+
# If name is None, it is a composite parameter.
|
51
|
+
name: str | None
|
52
|
+
|
53
|
+
eval_expression: ParameterRepresentation | None = None
|
54
|
+
|
55
|
+
def __attrs_post_init__(self):
|
56
|
+
if self.name is None and self.eval_expression is None:
|
57
|
+
raise ValueError(
|
58
|
+
"For a parameter, must specify a name or an eval_expression"
|
59
|
+
)
|
60
|
+
if self.name is None or self.eval_expression is None:
|
61
|
+
return
|
62
|
+
|
63
|
+
raise ValueError(
|
64
|
+
"For a parameter, only one of name or eval_expression should be None."
|
65
|
+
)
|
66
|
+
|
67
|
+
def __repr__(self):
|
68
|
+
if self.eval_expression is None:
|
69
|
+
return self.name
|
70
|
+
|
71
|
+
op = self.eval_expression.op
|
72
|
+
left_param = self.eval_expression.left_param
|
73
|
+
right_param = self.eval_expression.right_param
|
74
|
+
# TODO: Improve parentheses handling.
|
75
|
+
if op == utils.Op.ADD:
|
76
|
+
return f"({left_param}+{right_param})"
|
77
|
+
if op == utils.Op.SUB:
|
78
|
+
return f"({left_param}-{right_param})"
|
79
|
+
if op == utils.Op.MUL:
|
80
|
+
return f"({left_param}*{right_param})"
|
81
|
+
if op == utils.Op.DIV:
|
82
|
+
return f"({left_param}/{right_param})"
|
83
|
+
|
84
|
+
def get_value(
|
85
|
+
self, resolve_parameters: dict[str, utils.NUMERICAL_TYPE]
|
86
|
+
) -> utils.NUMERICAL_TYPE:
|
87
|
+
if self.eval_expression is None:
|
88
|
+
val = resolve_parameters.get(self.name, NOT_FOUND)
|
89
|
+
if val is NOT_FOUND:
|
90
|
+
raise ValueError(f"Cannot resolve Parameter named: {self.name}")
|
91
|
+
return val
|
92
|
+
op = self.eval_expression.op
|
93
|
+
left_param = eval_parameter(self.eval_expression.left_param, resolve_parameters)
|
94
|
+
right_param = eval_parameter(
|
95
|
+
self.eval_expression.right_param, resolve_parameters
|
96
|
+
)
|
97
|
+
|
98
|
+
if op == utils.Op.ADD:
|
99
|
+
return left_param + right_param
|
100
|
+
if op == utils.Op.SUB:
|
101
|
+
return left_param - right_param
|
102
|
+
if op == utils.Op.MUL:
|
103
|
+
return left_param * right_param
|
104
|
+
if op == utils.Op.DIV:
|
105
|
+
return left_param / right_param
|
106
|
+
|
107
|
+
raise ValueError(f"Encountered unknown {op=} when evaluation the point.")
|
108
|
+
|
109
|
+
def __add__(self, other):
|
110
|
+
if not utils.is_numerical_or_parameter(other):
|
111
|
+
return NotImplemented
|
112
|
+
return Parameter(
|
113
|
+
name=None,
|
114
|
+
eval_expression=ParameterRepresentation(
|
115
|
+
op=utils.Op.ADD, left_param=self, right_param=other
|
116
|
+
),
|
117
|
+
)
|
118
|
+
|
119
|
+
def __radd__(self, other):
|
120
|
+
if not utils.is_numerical_or_parameter(other):
|
121
|
+
return NotImplemented
|
122
|
+
return Parameter(
|
123
|
+
name=None,
|
124
|
+
eval_expression=ParameterRepresentation(
|
125
|
+
op=utils.Op.ADD, left_param=other, right_param=self
|
126
|
+
),
|
127
|
+
)
|
128
|
+
|
129
|
+
def __sub__(self, other):
|
130
|
+
if not utils.is_numerical_or_parameter(other):
|
131
|
+
return NotImplemented
|
132
|
+
return Parameter(
|
133
|
+
name=None,
|
134
|
+
eval_expression=ParameterRepresentation(
|
135
|
+
op=utils.Op.SUB, left_param=self, right_param=other
|
136
|
+
),
|
137
|
+
)
|
138
|
+
|
139
|
+
def __rsub__(self, other):
|
140
|
+
if not utils.is_numerical_or_parameter(other):
|
141
|
+
return NotImplemented
|
142
|
+
return Parameter(
|
143
|
+
name=None,
|
144
|
+
eval_expression=ParameterRepresentation(
|
145
|
+
op=utils.Op.SUB, left_param=other, right_param=self
|
146
|
+
),
|
147
|
+
)
|
148
|
+
|
149
|
+
def __mul__(self, other):
|
150
|
+
if not utils.is_numerical_or_parameter(other):
|
151
|
+
return NotImplemented
|
152
|
+
return Parameter(
|
153
|
+
name=None,
|
154
|
+
eval_expression=ParameterRepresentation(
|
155
|
+
op=utils.Op.MUL, left_param=self, right_param=other
|
156
|
+
),
|
157
|
+
)
|
158
|
+
|
159
|
+
def __rmul__(self, other):
|
160
|
+
if not utils.is_numerical_or_parameter(other):
|
161
|
+
return NotImplemented
|
162
|
+
return Parameter(
|
163
|
+
name=None,
|
164
|
+
eval_expression=ParameterRepresentation(
|
165
|
+
op=utils.Op.MUL, left_param=other, right_param=self
|
166
|
+
),
|
167
|
+
)
|
168
|
+
|
169
|
+
def __truediv__(self, other):
|
170
|
+
if not utils.is_numerical_or_parameter(other):
|
171
|
+
return NotImplemented
|
172
|
+
return Parameter(
|
173
|
+
name=None,
|
174
|
+
eval_expression=ParameterRepresentation(
|
175
|
+
op=utils.Op.DIV, left_param=self, right_param=other
|
176
|
+
),
|
177
|
+
)
|
178
|
+
|
179
|
+
def __rtruediv__(self, other):
|
180
|
+
if not utils.is_numerical_or_parameter(other):
|
181
|
+
return NotImplemented
|
182
|
+
return Parameter(
|
183
|
+
name=None,
|
184
|
+
eval_expression=ParameterRepresentation(
|
185
|
+
op=utils.Op.DIV, left_param=other, right_param=self
|
186
|
+
),
|
187
|
+
)
|
@@ -0,0 +1,128 @@
|
|
1
|
+
# Copyright: 2025 The PEPFlow Developers
|
2
|
+
#
|
3
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
4
|
+
# or more contributor license agreements. See the NOTICE file
|
5
|
+
# distributed with this work for additional information
|
6
|
+
# regarding copyright ownership. The ASF licenses this file
|
7
|
+
# to you under the Apache License, Version 2.0 (the
|
8
|
+
# "License"); you may not use this file except in compliance
|
9
|
+
# with the License. You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing,
|
14
|
+
# software distributed under the License is distributed on an
|
15
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
16
|
+
# KIND, either express or implied. See the License for the
|
17
|
+
# specific language governing permissions and limitations
|
18
|
+
# under the License.
|
19
|
+
|
20
|
+
from typing import Iterator
|
21
|
+
|
22
|
+
import numpy as np
|
23
|
+
import pytest
|
24
|
+
import sympy as sp
|
25
|
+
|
26
|
+
from pepflow import pep_context as pc
|
27
|
+
from pepflow.expression_manager import ExpressionManager
|
28
|
+
from pepflow.parameter import Parameter
|
29
|
+
from pepflow.point import Point
|
30
|
+
from pepflow.scalar import Scalar
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.fixture
|
34
|
+
def pep_context() -> Iterator[pc.PEPContext]:
|
35
|
+
"""Prepare the pep context and reset the context to None at the end."""
|
36
|
+
ctx = pc.PEPContext("test").set_as_current()
|
37
|
+
yield ctx
|
38
|
+
pc.set_current_context(None)
|
39
|
+
|
40
|
+
|
41
|
+
def test_parameter_interact_with_scalar(pep_context: pc.PEPContext):
|
42
|
+
pm1 = Parameter("pm1")
|
43
|
+
s1 = Scalar(is_basis=True, tags=["s1"])
|
44
|
+
|
45
|
+
_ = pm1 + s1
|
46
|
+
_ = s1 + pm1
|
47
|
+
_ = pm1 - s1
|
48
|
+
_ = s1 - pm1
|
49
|
+
_ = s1 * pm1
|
50
|
+
_ = pm1 * s1
|
51
|
+
_ = s1 / pm1
|
52
|
+
|
53
|
+
|
54
|
+
def test_parameter_interact_with_point(pep_context: pc.PEPContext):
|
55
|
+
pm1 = Parameter("pm1")
|
56
|
+
p1 = Point(is_basis=True, tags=["p1"])
|
57
|
+
|
58
|
+
_ = p1 * pm1
|
59
|
+
_ = pm1 * p1
|
60
|
+
_ = p1 / pm1
|
61
|
+
|
62
|
+
|
63
|
+
def test_parameter_composition_with_point_and_scalar(pep_context: pc.PEPContext):
|
64
|
+
pm1 = Parameter("pm1")
|
65
|
+
pm2 = Parameter("pm2")
|
66
|
+
p1 = Point(is_basis=True, tags=["p1"])
|
67
|
+
s1 = Scalar(is_basis=True, tags=["s1"])
|
68
|
+
|
69
|
+
s2 = s1 + pm1 + pm2 * p1**2
|
70
|
+
assert str(s2) == "s1+pm1+pm2*|p1|^2"
|
71
|
+
|
72
|
+
|
73
|
+
def test_parameter_composition(pep_context: pc.PEPContext):
|
74
|
+
pm1 = Parameter("pm1")
|
75
|
+
pm2 = Parameter("pm2")
|
76
|
+
|
77
|
+
pp = (pm1 + 2) * pm2
|
78
|
+
assert str(pp) == "((pm1+2)*pm2)"
|
79
|
+
assert pp.get_value({"pm1": 3, "pm2": 6}) == 30
|
80
|
+
|
81
|
+
pp2 = (pm1 + sp.Rational(1, 2)) * pm2
|
82
|
+
assert str(pp2) == "((pm1+1/2)*pm2)"
|
83
|
+
assert pp2.get_value({"pm1": sp.Rational(1, 3), "pm2": sp.Rational(6, 5)}) == 1
|
84
|
+
|
85
|
+
|
86
|
+
def test_expression_manager_eval_with_parameter(pep_context: pc.PEPContext):
|
87
|
+
pm1 = Parameter("pm1")
|
88
|
+
p1 = Point(is_basis=True, tags=["p1"])
|
89
|
+
p2 = Point(is_basis=True, tags=["p2"])
|
90
|
+
p3 = pm1 * p1 + p2 / 4
|
91
|
+
|
92
|
+
em = ExpressionManager(pep_context, {"pm1": 2.3})
|
93
|
+
np.testing.assert_allclose(em.eval_point(p3).vector, np.array([2.3, 0.25]))
|
94
|
+
|
95
|
+
em = ExpressionManager(pep_context, {"pm1": 3.4})
|
96
|
+
np.testing.assert_allclose(em.eval_point(p3).vector, np.array([3.4, 0.25]))
|
97
|
+
|
98
|
+
|
99
|
+
def test_expression_manager_eval_with_parameter_scalar(pep_context: pc.PEPContext):
|
100
|
+
pm1 = Parameter("pm1")
|
101
|
+
pm2 = Parameter("pm2")
|
102
|
+
p1 = Point(is_basis=True, tags=["p1"])
|
103
|
+
p2 = Point(is_basis=True, tags=["p2"])
|
104
|
+
s1 = Scalar(is_basis=True, tags=["s1"])
|
105
|
+
s2 = pm1 * p1 * p2 + pm2 + s1
|
106
|
+
|
107
|
+
em = ExpressionManager(pep_context, {"pm1": 2.4, "pm2": 4.3})
|
108
|
+
assert np.isclose(em.eval_scalar(s2).constant, 4.3)
|
109
|
+
np.testing.assert_allclose(em.eval_scalar(s2).vector, np.array([1]))
|
110
|
+
np.testing.assert_allclose(
|
111
|
+
em.eval_scalar(s2).matrix, np.array([[0, 1.2], [1.2, 0]])
|
112
|
+
)
|
113
|
+
|
114
|
+
|
115
|
+
def test_expression_manager_eval_composition(pep_context: pc.PEPContext):
|
116
|
+
pm1 = Parameter("pm1")
|
117
|
+
pm2 = Parameter("pm2")
|
118
|
+
p1 = Point(is_basis=True, tags=["p1"])
|
119
|
+
p2 = Point(is_basis=True, tags=["p2"])
|
120
|
+
s1 = Scalar(is_basis=True, tags=["s1"])
|
121
|
+
|
122
|
+
s2 = 1 / pm1 * p1 * p2 + (pm2 + 1) * s1
|
123
|
+
em = ExpressionManager(pep_context, {"pm1": 0.5, "pm2": 4.3})
|
124
|
+
assert np.isclose(em.eval_scalar(s2).constant, 0)
|
125
|
+
np.testing.assert_allclose(em.eval_scalar(s2).vector, np.array([5.3]))
|
126
|
+
np.testing.assert_allclose(
|
127
|
+
em.eval_scalar(s2).matrix, np.array([[0, 1.0], [1.0, 0]])
|
128
|
+
)
|