pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/function.py CHANGED
@@ -36,6 +36,19 @@ if TYPE_CHECKING:
36
36
 
37
37
  @attrs.frozen
38
38
  class Triplet:
39
+ """
40
+ A data class that represents, for some given function :math:`f`,
41
+ the tuple :math:`\\{x, f(x), \\nabla f(x)\\}`. We can also consider
42
+ a subgradient :math:`\\widetilde{\\nabla} f(x)` instead of the gradient.
43
+
44
+ Attributes:
45
+ point (:class:`Point`): The point :math:`x`.
46
+ function_value (:class:`Scalar`): The function value :math:`f(x)`.
47
+ gradient (:class:`Point`): The gradient :math:`\\nabla f(x)` or
48
+ a subgradient :math:`\\widetilde{\\nabla} f(x)`.
49
+ name (str): The unique name of the :class:`Triplet` object.
50
+ """
51
+
39
52
  point: pt.Point
40
53
  function_value: sc.Scalar
41
54
  gradient: pt.Point
@@ -53,7 +66,7 @@ class AddedFunc:
53
66
 
54
67
  @attrs.frozen
55
68
  class ScaledFunc:
56
- """Represents scale * base_func."""
69
+ """Represents scalar * base_func."""
57
70
 
58
71
  scale: float
59
72
  base_func: Function
@@ -61,6 +74,25 @@ class ScaledFunc:
61
74
 
62
75
  @attrs.mutable
63
76
  class Function:
77
+ """A :class:`Function` object represents a function.
78
+
79
+ :class:`Function` objects can be constructed as linear combinations
80
+ of other :class:`Function` objects. Let `a` and `b` be some numeric
81
+ data type. Let `f` and `g` be :class:`Function` objects. Then, we
82
+ can form a new :class:`Function` object: `a*f+b*g`.
83
+
84
+ A :class:`Function` object should never be explicitly constructed. Only
85
+ children of :class:`Function` such as :class:`ConvexFunction` or
86
+ :class:`SmoothConvexFunction` should be constructed. See their respective
87
+ documentation to see how.
88
+
89
+ Attributes:
90
+ is_basis (bool): `True` if this function is not formed through a linear
91
+ combination of other functions. `False` otherwise.
92
+ tags (list[str]): A list that contains tags that can be used to
93
+ identify the :class:`Function` object. Tags should be unique.
94
+ """
95
+
64
96
  is_basis: bool
65
97
 
66
98
  composition: AddedFunc | ScaledFunc | None = None
@@ -79,11 +111,21 @@ class Function:
79
111
 
80
112
  @property
81
113
  def tag(self):
114
+ """Returns the most recently added tag.
115
+
116
+ Returns:
117
+ str: The most recently added tag of this :class:`Function` object.
118
+ """
82
119
  if len(self.tags) == 0:
83
120
  raise ValueError("Function should have a name.")
84
121
  return self.tags[-1]
85
122
 
86
123
  def add_tag(self, tag: str) -> None:
124
+ """Add a new tag for this :class:`Function` object.
125
+
126
+ Args:
127
+ tag (str): The new tag to be added to the `tags` list.
128
+ """
87
129
  self.tags.append(tag)
88
130
 
89
131
  def __repr__(self):
@@ -93,7 +135,7 @@ class Function:
93
135
 
94
136
  def _repr_latex_(self):
95
137
  s = repr(self)
96
- return rf"$\\displaystyle {s}$"
138
+ return rf"$\displaystyle {s}$"
97
139
 
98
140
  def get_interpolation_constraints(self):
99
141
  raise NotImplementedError(
@@ -161,6 +203,18 @@ class Function:
161
203
  return triplet
162
204
 
163
205
  def add_stationary_point(self, name: str) -> pt.Point:
206
+ """
207
+ Return a stationary point for this :class:`Function` object.
208
+ A :class:`Function` object can only have one stationary point.
209
+
210
+ Args:
211
+ name (str): The tag for the :class:`Point` object which
212
+ will serve as the stationary point.
213
+
214
+ Returns:
215
+ :class:`Point`: The stationary point for this :class:`Function`
216
+ object.
217
+ """
164
218
  # assert we can only add one stationary point?
165
219
  pep_context = pc.get_current_context()
166
220
  if pep_context is None:
@@ -171,7 +225,7 @@ class Function:
171
225
  )
172
226
  point = pt.Point(is_basis=True)
173
227
  point.add_tag(name)
174
- desired_grad = 0 * point
228
+ desired_grad = pt.Point.zero() # zero point
175
229
  desired_grad.add_tag(f"gradient_{self.tag}({name})")
176
230
  triplet = self.add_point_with_grad_restriction(point, desired_grad)
177
231
  pep_context.add_stationary_triplet(self, triplet)
@@ -195,6 +249,9 @@ class Function:
195
249
  if pep_context is None:
196
250
  raise RuntimeError("Did you forget to create a context?")
197
251
 
252
+ if not isinstance(point, pt.Point):
253
+ raise ValueError("The Function can only take point as input.")
254
+
198
255
  if self.is_basis:
199
256
  for triplet in pep_context.triplets[self]:
200
257
  if triplet.point.uid == point.uid:
@@ -232,14 +289,50 @@ class Function:
232
289
  return Triplet(point, function_value, gradient, name=None)
233
290
 
234
291
  def gradient(self, point: pt.Point) -> pt.Point:
292
+ """
293
+ Returns a :class:`Point` object that is the gradient of the
294
+ :class:`Function` at the given :class:`Point`.
295
+
296
+ Args:
297
+ point (:class:`Point`): Any :class:`Point`.
298
+
299
+ Returns:
300
+ :class:`Point`: The gradient of the :class:`Function` at the
301
+ given :class:`Point`.
302
+ """
235
303
  triplet = self.generate_triplet(point)
236
304
  return triplet.gradient
237
305
 
238
306
  def subgradient(self, point: pt.Point) -> pt.Point:
307
+ """
308
+ Returns a :class:`Point` object that is the subgradient of the
309
+ :class:`Function` at the given :class:`Point`.
310
+
311
+ Args:
312
+ point (:class:`Point`): Any :class:`Point`.
313
+
314
+ Returns:
315
+ :class:`Point`: The subgradient of the :class:`Function` at the
316
+ given :class:`Point`.
317
+
318
+ Note:
319
+ The method `gradient` is exactly the same.
320
+ """
239
321
  triplet = self.generate_triplet(point)
240
322
  return triplet.gradient
241
323
 
242
324
  def function_value(self, point: pt.Point) -> sc.Scalar:
325
+ """
326
+ Returns a :class:`Scalar` object that is the function value of the
327
+ :class:`Function` at the given :class:`Point`.
328
+
329
+ Args:
330
+ point (:class:`Point`): Any :class:`Point`.
331
+
332
+ Returns:
333
+ :class:`Point`: The function value of the :class:`Function` at the
334
+ given :class:`Point`.
335
+ """
243
336
  triplet = self.generate_triplet(point)
244
337
  return triplet.function_value
245
338
 
@@ -247,7 +340,8 @@ class Function:
247
340
  return self.function_value(point)
248
341
 
249
342
  def __add__(self, other):
250
- assert isinstance(other, Function)
343
+ if not isinstance(other, Function):
344
+ return NotImplemented
251
345
  return Function(
252
346
  is_basis=False,
253
347
  composition=AddedFunc(self, other),
@@ -255,7 +349,8 @@ class Function:
255
349
  )
256
350
 
257
351
  def __sub__(self, other):
258
- assert isinstance(other, Function)
352
+ if not isinstance(other, Function):
353
+ return NotImplemented
259
354
  tag_other = other.tag
260
355
  if isinstance(other.composition, AddedFunc):
261
356
  tag_other = f"({other.tag})"
@@ -266,7 +361,8 @@ class Function:
266
361
  )
267
362
 
268
363
  def __mul__(self, other):
269
- assert utils.is_numerical(other)
364
+ if not utils.is_numerical(other):
365
+ return NotImplemented
270
366
  tag_self = self.tag
271
367
  if isinstance(self.composition, AddedFunc):
272
368
  tag_self = f"({self.tag})"
@@ -277,7 +373,8 @@ class Function:
277
373
  )
278
374
 
279
375
  def __rmul__(self, other):
280
- assert utils.is_numerical(other)
376
+ if not utils.is_numerical(other):
377
+ return NotImplemented
281
378
  tag_self = self.tag
282
379
  if isinstance(self.composition, AddedFunc):
283
380
  tag_self = f"({self.tag})"
@@ -298,7 +395,8 @@ class Function:
298
395
  )
299
396
 
300
397
  def __truediv__(self, other):
301
- assert utils.is_numerical(other)
398
+ if not utils.is_numerical(other):
399
+ return NotImplemented
302
400
  tag_self = self.tag
303
401
  if isinstance(self.composition, AddedFunc):
304
402
  tag_self = f"({self.tag})"
@@ -318,6 +416,22 @@ class Function:
318
416
 
319
417
 
320
418
  class ConvexFunction(Function):
419
+ """
420
+ The :class:`ConvexFunction` class is a child of :class:`Function.`
421
+ The :class:`ConvexFunction` class represents a closed, convex, and
422
+ proper (CCP) function, i.e., a convex function whose epigraph is a
423
+ non-empty closed set.
424
+
425
+ A CCP function typically has no parameters. We can instantiate a
426
+ :class:`ConvexFunction` object as follows:
427
+
428
+ Example:
429
+ >>> import pepflow as pf
430
+ >>> ctx = pf.PEPContext("example").set_as_current()
431
+ >>> pep_builder = pf.PEPBuilder()
432
+ >>> g = pep_builder.declare_func(pf.ConvexFunction, "g")
433
+ """
434
+
321
435
  def __init__(
322
436
  self,
323
437
  is_basis=True,
@@ -365,7 +479,16 @@ class ConvexFunction(Function):
365
479
  def interpolate_ineq(
366
480
  self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
367
481
  ) -> sc.Scalar:
368
- """Generate the interpolation inequality scalar by tags."""
482
+ """Generate the interpolation inequality :class:`Scalar` by tags.
483
+ The interpolation inequality between two points :math:`p_1, p_2` for a
484
+ CCP function :math:`f` is
485
+
486
+ .. math:: f(p_2) - f(p_1) + \\langle \\nabla f(p_2), p_1 - p_2 \\rangle.
487
+
488
+ Args:
489
+ p1_tag (str): A tag of the :class:`Point` :math:`p_1`.
490
+ p2_tag (str): A tag of the :class:`Point` :math:`p_2`.
491
+ """
369
492
  if pep_context is None:
370
493
  pep_context = pc.get_current_context()
371
494
  if pep_context is None:
@@ -379,6 +502,29 @@ class ConvexFunction(Function):
379
502
  return f2 - f1 + g2 * (x1 - x2)
380
503
 
381
504
  def proximal_step(self, x_0: pt.Point, stepsize: numbers.Number) -> pt.Point:
505
+ """ Define the proximal operator as
506
+
507
+ .. math:: \\text{prox}_{\\gamma f}(x_0) := \\arg\\min_x \\left\\{ \\gamma f(x) + \\frac{1}{2} \\|x - x_0\\|^2 \\right\\}.
508
+
509
+ This function performs a proximal step with respect to some
510
+ :class:`Function` :math:`f` on the :class:`Point` :math:`x_0`
511
+ with stepsize :math:`\\gamma`:
512
+
513
+ .. math::
514
+ :nowrap:
515
+
516
+ \\begin{eqnarray}
517
+ x := \\text{prox}_{\\gamma f}(x_0) & := & \\arg\\min_x \\left\\{ \\gamma f(x) + \\frac{1}{2} \\|x - x_0\\|^2 \\right\\}, \\\\
518
+ & \\Updownarrow & \\\\
519
+ 0 & = & \\gamma \\partial f(x) + x - x_0,\\\\
520
+ & \\Updownarrow & \\\\
521
+ x & = & x_0 - \\gamma \\widetilde{\\nabla} f(x) \\text{ where } \\widetilde{\\nabla} f(x)\\in\\partial f(x).
522
+ \\end{eqnarray}
523
+
524
+ Args:
525
+ x_0 (:class:`Point`): The initial point.
526
+ stepsize (int | float): The stepsize.
527
+ """
382
528
  gradient = pt.Point(is_basis=True)
383
529
  gradient.add_tag(
384
530
  f"gradient_{self.tag}(prox_{{{stepsize}*{self.tag}}}({x_0.tag}))"
@@ -398,6 +544,21 @@ class ConvexFunction(Function):
398
544
 
399
545
 
400
546
  class SmoothConvexFunction(Function):
547
+ """
548
+ The :class:`SmoothConvexFunction` class is a child of :class:`Function.`
549
+ The :class:`SmoothConvexFunction` class represents a smooth,
550
+ convex function.
551
+
552
+ A smooth, convex function has a smoothness parameter :math:`L`.
553
+ We can instantiate a :class:`SmoothConvexFunction` object as follows:
554
+
555
+ Example:
556
+ >>> import pepflow as pf
557
+ >>> ctx = pf.PEPContext("example").set_as_current()
558
+ >>> pep_builder = pf.PEPBuilder()
559
+ >>> f = pep_builder.declare_func(pf.SmoothConvexFunction, "f", L=1)
560
+ """
561
+
401
562
  def __init__(
402
563
  self,
403
564
  L,
@@ -445,7 +606,16 @@ class SmoothConvexFunction(Function):
445
606
  def interpolate_ineq(
446
607
  self, p1_tag: str, p2_tag: str, pep_context: pc.PEPContext | None = None
447
608
  ) -> sc.Scalar:
448
- """Generate the interpolation inequality scalar by tags."""
609
+ """Generate the interpolation inequality :class:`Scalar` by tags.
610
+ The interpolation inequality between two points :math:`p_1, p_2` for a
611
+ smooth, convex function :math:`f` is
612
+
613
+ .. math:: f(p_2) - f(p_1) + \\langle \\nabla f(p_2), p_1 - p_2 \\rangle + \\tfrac{1}{2} \\lVert \\nabla f(p_1) - \\nabla f(p_2) \\rVert^2.
614
+
615
+ Args:
616
+ p1_tag (str): A tag of the :class:`Point` :math:`p_1`.
617
+ p2_tag (str): A tag of the :class:`Point` :math:`p_2`.
618
+ """
449
619
  if pep_context is None:
450
620
  pep_context = pc.get_current_context()
451
621
  if pep_context is None:
pepflow/parameter.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ from __future__ import annotations
21
+
22
+ import attrs
23
+
24
+ from pepflow import utils
25
+
26
# Sentinel returned by ``dict.get`` when a parameter name has no entry in the
# ``resolve_parameters`` mapping; compared by identity in ``Parameter.get_value``.
NOT_FOUND = "__NOT_FOUND__"
28
+
29
+
30
@attrs.frozen
class ParameterRepresentation:
    """One binary operation stored inside a composite :class:`Parameter`.

    Attributes:
        op (utils.Op): The operation combining the two operands.
        left_param: The left operand, either a number or a :class:`Parameter`.
        right_param: The right operand, either a number or a :class:`Parameter`.
    """

    op: utils.Op
    left_param: utils.NUMERICAL_TYPE | Parameter
    right_param: utils.NUMERICAL_TYPE | Parameter
35
+
36
+
37
def eval_parameter(
    param: Parameter | utils.NUMERICAL_TYPE,
    resolve_parameters: dict[str, utils.NUMERICAL_TYPE],
) -> utils.NUMERICAL_TYPE:
    """Evaluate ``param`` to a concrete number.

    Args:
        param: A :class:`Parameter` (named or composite) or a plain number.
        resolve_parameters: Mapping from parameter names to their values,
            used to resolve named parameters.

    Returns:
        ``param.get_value(resolve_parameters)`` for a :class:`Parameter`;
        otherwise ``param`` itself, provided ``utils.is_numerical`` accepts it.

    Raises:
        ValueError: If ``param`` is neither a :class:`Parameter` nor numerical.
    """
    if isinstance(param, Parameter):
        # Parameters know how to resolve themselves (recursively if composite).
        return param.get_value(resolve_parameters)
    if not utils.is_numerical(param):
        raise ValueError(
            f"Encounter the unknown parameter type: {param} ({type(param)})"
        )
    return param
46
+
47
+
48
@attrs.frozen
class Parameter:
    """A symbolic numeric parameter whose value is supplied at evaluation time.

    A :class:`Parameter` is either

    * *named*: ``name`` is set and ``eval_expression`` is ``None``; its value
      is looked up by name in the ``resolve_parameters`` mapping, or
    * *composite*: ``name`` is ``None`` and ``eval_expression`` records a
      binary operation on numbers and/or other parameters.

    Composite parameters are produced by ``+``, ``-``, ``*`` and ``/`` between
    a :class:`Parameter` and a number or another :class:`Parameter`.

    Attributes:
        name (str | None): The lookup name, or ``None`` for a composite
            parameter.
        eval_expression (ParameterRepresentation | None): The operation a
            composite parameter represents, or ``None`` for a named one.
    """

    # If name is None, it is a composite parameter.
    name: str | None

    eval_expression: ParameterRepresentation | None = None

    def __attrs_post_init__(self):
        # Exactly one of ``name`` and ``eval_expression`` must be provided.
        if self.name is None and self.eval_expression is None:
            raise ValueError(
                "For a parameter, must specify a name or an eval_expression"
            )
        if self.name is None or self.eval_expression is None:
            return

        raise ValueError(
            "For a parameter, only one of name or eval_expression should be None."
        )

    def __repr__(self):
        if self.eval_expression is None:
            return self.name

        op = self.eval_expression.op
        left_param = self.eval_expression.left_param
        right_param = self.eval_expression.right_param
        # TODO having a better parentheses handling.
        if op == utils.Op.ADD:
            return f"({left_param}+{right_param})"
        if op == utils.Op.SUB:
            return f"({left_param}-{right_param})"
        if op == utils.Op.MUL:
            return f"({left_param}*{right_param})"
        if op == utils.Op.DIV:
            return f"({left_param}/{right_param})"
        # Falling through would return None, which ``repr()`` rejects with an
        # opaque TypeError; report the actual problem instead.
        raise ValueError(
            f"Encountered unknown {op=} when representing the parameter."
        )

    def get_value(
        self, resolve_parameters: dict[str, utils.NUMERICAL_TYPE]
    ) -> utils.NUMERICAL_TYPE:
        """Evaluate this parameter to a concrete number.

        Args:
            resolve_parameters: Mapping from parameter names to values.

        Returns:
            For a named parameter, the value stored under ``name``; for a
            composite parameter, the result of applying the recorded operation
            to the recursively evaluated operands.

        Raises:
            ValueError: If a named parameter is missing from
                ``resolve_parameters``, an operand has an unknown type, or the
                recorded operation is unknown.
        """
        if self.eval_expression is None:
            val = resolve_parameters.get(self.name, NOT_FOUND)
            if val is NOT_FOUND:
                raise ValueError(f"Cannot resolve Parameter named: {self.name}")
            return val
        op = self.eval_expression.op
        left_param = eval_parameter(self.eval_expression.left_param, resolve_parameters)
        right_param = eval_parameter(
            self.eval_expression.right_param, resolve_parameters
        )

        if op == utils.Op.ADD:
            return left_param + right_param
        if op == utils.Op.SUB:
            return left_param - right_param
        if op == utils.Op.MUL:
            return left_param * right_param
        if op == utils.Op.DIV:
            return left_param / right_param

        raise ValueError(
            f"Encountered unknown {op=} when evaluating the parameter."
        )

    def _binary_op(self, other, op, *, reflected=False):
        """Build the composite parameter ``self op other`` (``other op self`` if reflected).

        Returns ``NotImplemented`` when ``other`` is neither numerical nor a
        :class:`Parameter`, so Python falls back to the other operand's
        reflected operator.
        """
        if not utils.is_numerical_or_parameter(other):
            return NotImplemented
        left, right = (other, self) if reflected else (self, other)
        return Parameter(
            name=None,
            eval_expression=ParameterRepresentation(
                op=op, left_param=left, right_param=right
            ),
        )

    def __add__(self, other):
        return self._binary_op(other, utils.Op.ADD)

    def __radd__(self, other):
        return self._binary_op(other, utils.Op.ADD, reflected=True)

    def __sub__(self, other):
        return self._binary_op(other, utils.Op.SUB)

    def __rsub__(self, other):
        return self._binary_op(other, utils.Op.SUB, reflected=True)

    def __mul__(self, other):
        return self._binary_op(other, utils.Op.MUL)

    def __rmul__(self, other):
        return self._binary_op(other, utils.Op.MUL, reflected=True)

    def __truediv__(self, other):
        return self._binary_op(other, utils.Op.DIV)

    def __rtruediv__(self, other):
        return self._binary_op(other, utils.Op.DIV, reflected=True)
@@ -0,0 +1,128 @@
1
+ # Copyright: 2025 The PEPFlow Developers
2
+ #
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ from typing import Iterator
21
+
22
+ import numpy as np
23
+ import pytest
24
+ import sympy as sp
25
+
26
+ from pepflow import pep_context as pc
27
+ from pepflow.expression_manager import ExpressionManager
28
+ from pepflow.parameter import Parameter
29
+ from pepflow.point import Point
30
+ from pepflow.scalar import Scalar
31
+
32
+
33
@pytest.fixture
def pep_context() -> Iterator[pc.PEPContext]:
    """Install a fresh "test" PEP context as current; clear it after the test."""
    context = pc.PEPContext("test").set_as_current()
    yield context
    # Reset the global current context so tests do not leak state.
    pc.set_current_context(None)
39
+
40
+
41
def test_parameter_interact_with_scalar(pep_context: pc.PEPContext):
    """Mixing a Parameter with a Scalar in +, -, * and / yields a Scalar."""
    pm1 = Parameter("pm1")
    s1 = Scalar(is_basis=True, tags=["s1"])

    results = [
        pm1 + s1,
        s1 + pm1,
        pm1 - s1,
        s1 - pm1,
        s1 * pm1,
        pm1 * s1,
        s1 / pm1,
    ]
    # Previously the results were discarded; check they are actual Scalars.
    for result in results:
        assert isinstance(result, Scalar)
52
+
53
+
54
def test_parameter_interact_with_point(pep_context: pc.PEPContext):
    """Scaling a Point by a Parameter (either side, or dividing) yields a Point."""
    pm1 = Parameter("pm1")
    p1 = Point(is_basis=True, tags=["p1"])

    results = [
        p1 * pm1,
        pm1 * p1,
        p1 / pm1,
    ]
    for result in results:
        assert isinstance(result, Point)
61
+
62
+
63
def test_parameter_composition_with_point_and_scalar(pep_context: pc.PEPContext):
    """A Scalar expression mixing Parameters and a Point norm prints as expected."""
    pm1, pm2 = Parameter("pm1"), Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    s1 = Scalar(is_basis=True, tags=["s1"])

    combined = s1 + pm1 + pm2 * p1**2
    expected = "s1+pm1+pm2*|p1|^2"
    assert str(combined) == expected
71
+
72
+
73
def test_parameter_composition(pep_context: pc.PEPContext):
    """Composite Parameters print with parentheses and evaluate exactly."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")

    integer_combo = (pm1 + 2) * pm2
    assert str(integer_combo) == "((pm1+2)*pm2)"
    assert integer_combo.get_value({"pm1": 3, "pm2": 6}) == 30

    half = sp.Rational(1, 2)
    rational_combo = (pm1 + half) * pm2
    assert str(rational_combo) == "((pm1+1/2)*pm2)"
    rational_values = {"pm1": sp.Rational(1, 3), "pm2": sp.Rational(6, 5)}
    assert rational_combo.get_value(rational_values) == 1
84
+
85
+
86
def test_expression_manager_eval_with_parameter(pep_context: pc.PEPContext):
    """The same Point expression re-evaluates under different parameter values."""
    pm1 = Parameter("pm1")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    p3 = pm1 * p1 + p2 / 4

    for value in (2.3, 3.4):
        manager = ExpressionManager(pep_context, {"pm1": value})
        np.testing.assert_allclose(
            manager.eval_point(p3).vector, np.array([value, 0.25])
        )
97
+
98
+
99
def test_expression_manager_eval_with_parameter_scalar(pep_context: pc.PEPContext):
    """Parameters feed the constant, vector and matrix parts of an evaluated Scalar."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    s1 = Scalar(is_basis=True, tags=["s1"])
    s2 = pm1 * p1 * p2 + pm2 + s1

    manager = ExpressionManager(pep_context, {"pm1": 2.4, "pm2": 4.3})
    evaluated = manager.eval_scalar(s2)
    assert np.isclose(evaluated.constant, 4.3)
    np.testing.assert_allclose(evaluated.vector, np.array([1]))
    # The inner product's coefficient is split symmetrically: 2.4 / 2 = 1.2.
    np.testing.assert_allclose(evaluated.matrix, np.array([[0, 1.2], [1.2, 0]]))
113
+
114
+
115
def test_expression_manager_eval_composition(pep_context: pc.PEPContext):
    """Composite parameter coefficients (1/pm1, pm2+1) evaluate correctly."""
    pm1 = Parameter("pm1")
    pm2 = Parameter("pm2")
    p1 = Point(is_basis=True, tags=["p1"])
    p2 = Point(is_basis=True, tags=["p2"])
    s1 = Scalar(is_basis=True, tags=["s1"])

    s2 = 1 / pm1 * p1 * p2 + (pm2 + 1) * s1
    manager = ExpressionManager(pep_context, {"pm1": 0.5, "pm2": 4.3})
    evaluated = manager.eval_scalar(s2)
    assert np.isclose(evaluated.constant, 0)
    np.testing.assert_allclose(evaluated.vector, np.array([5.3]))
    np.testing.assert_allclose(evaluated.matrix, np.array([[0, 1.0], [1.0, 0]]))