pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/point.py CHANGED
@@ -27,37 +27,65 @@ import numpy as np
27
27
 
28
28
  from pepflow import pep_context as pc
29
29
  from pepflow import utils
30
- from pepflow.scalar import EvalExpressionScalar, Scalar
30
+ from pepflow.scalar import Scalar, ScalarRepresentation
31
31
 
32
32
 
33
33
  def is_numerical_or_point(val: Any) -> bool:
34
- return utils.is_numerical(val) or isinstance(val, Point)
34
+ return utils.is_numerical_or_parameter(val) or isinstance(val, Point)
35
35
 
36
36
 
37
37
  def is_numerical_or_evaluatedpoint(val: Any) -> bool:
38
- return utils.is_numerical(val) or isinstance(val, EvaluatedPoint)
38
+ return utils.is_numerical_or_parameter(val) or isinstance(val, EvaluatedPoint)
39
39
 
40
40
 
41
41
  @attrs.frozen
42
- class EvalExpressionPoint:
42
+ class PointRepresentation:
43
43
  op: utils.Op
44
44
  left_point: Point | float
45
45
  right_point: Point | float
46
46
 
47
47
 
48
+ @attrs.frozen
49
+ class ZeroPoint:
50
+ """A special class to represent 0 in Point."""
51
+
52
+ pass
53
+
54
+
48
55
  @attrs.frozen
49
56
  class EvaluatedPoint:
57
+ """
58
+ The concrete representation of the abstract :class:`Point`.
59
+
60
+ Each abstract basis :class:`Point` object has a unique concrete
61
+ representation as a unit vector. The concrete representations of
62
+ linear combinations of abstract basis :class:`Point` objects are
63
+ linear combinations of the unit vectors. This information is stored
64
+ in the `vector` attribute.
65
+
66
+ :class:`EvaluatedPoint` objects can be constructed as linear combinations
67
+ of other :class:`EvaluatedPoint` objects. Let `a` and `b` be some numeric
68
+ data type. Let `x` and `y` be :class:`EvaluatedPoint` objects. Then, we
69
+ can form a new :class:`EvaluatedPoint` object: `a*x+b*y`.
70
+
71
+ Attributes:
72
+ vector (np.ndarray): The concrete representation of an
73
+ abstract :class:`Point`.
74
+ """
75
+
50
76
  vector: np.ndarray
51
77
 
78
+ @classmethod
79
+ def zero(cls, num_basis_points: int):
80
+ return EvaluatedPoint(vector=np.zeros(num_basis_points))
81
+
52
82
  def __add__(self, other):
53
83
  if isinstance(other, EvaluatedPoint):
54
84
  return EvaluatedPoint(vector=self.vector + other.vector)
55
85
  elif utils.is_numerical(other):
56
86
  return EvaluatedPoint(vector=self.vector + other)
57
87
  else:
58
- raise ValueError(
59
- f"Unsupported add operation between EvaluatedPoint and {type(other)}"
60
- )
88
+ return NotImplemented
61
89
 
62
90
  def __radd__(self, other):
63
91
  return self.__add__(other)
@@ -68,9 +96,7 @@ class EvaluatedPoint:
68
96
  elif utils.is_numerical(other):
69
97
  return EvaluatedPoint(vector=self.vector - other)
70
98
  else:
71
- raise ValueError(
72
- f"Unsupported sub operation between EvaluatedPoint and {type(other)}"
73
- )
99
+ return NotImplemented
74
100
 
75
101
  def __rsub__(self, other):
76
102
  if isinstance(other, EvaluatedPoint):
@@ -78,30 +104,51 @@ class EvaluatedPoint:
78
104
  elif utils.is_numerical(other):
79
105
  return EvaluatedPoint(vector=other - self.vector)
80
106
  else:
81
- raise ValueError(
82
- f"Unsupported sub operation between EvaluatedPoint and {type(other)}"
83
- )
107
+ return NotImplemented
84
108
 
85
109
  def __mul__(self, other):
86
- assert utils.is_numerical(other)
110
+ if not utils.is_numerical(other):
111
+ return NotImplemented
87
112
  return EvaluatedPoint(vector=self.vector * other)
88
113
 
89
114
  def __rmul__(self, other):
90
- assert utils.is_numerical(other)
115
+ if not utils.is_numerical(other):
116
+ return NotImplemented
91
117
  return EvaluatedPoint(vector=other * self.vector)
92
118
 
93
119
  def __truediv__(self, other):
94
- assert utils.is_numerical(other)
120
+ if not utils.is_numerical(other):
121
+ return NotImplemented
95
122
  return EvaluatedPoint(vector=self.vector / other)
96
123
 
97
124
 
98
125
  @attrs.frozen
99
126
  class Point:
127
+ """
128
+ A :class:`Point` object represents an element of a pre-Hilbert space.
129
+ Examples include a point or a gradient.
130
+
131
+ :class:`Point` objects can be constructed as linear combinations of
132
+ other :class:`Point` objects. Let `a` and `b` be some numeric data type.
133
+ Let `x` and `y` be :class:`Point` objects. Then, we can form a new
134
+ :class:`Point` object: `a*x+b*y`.
135
+
136
+ The inner product of two :class:`Point` objects can also be taken.
137
+ Let `x` and `y` be :class:`Point` objects. Then, their inner product is
138
+ `x*y` and returns a :class:`Scalar` object.
139
+
140
+ Attributes:
141
+ is_basis (bool): `True` if this point is not formed through a linear
142
+ combination of other points. `False` otherwise.
143
+ tags (list[str]): A list that contains tags that can be used to
144
+ identify the :class:`Point` object. Tags should be unique.
145
+ """
146
+
100
147
  # If true, the point is the basis for the evaluations of G
101
148
  is_basis: bool
102
149
 
103
- # How to evaluate the point.
104
- eval_expression: EvalExpressionPoint | None = None
150
+ # The representation of point used for evaluation.
151
+ eval_expression: PointRepresentation | ZeroPoint | None = None
105
152
 
106
153
  # Human tagged value for the Point
107
154
  tags: list[str] = attrs.field(factory=list)
@@ -120,13 +167,27 @@ class Point:
120
167
  raise RuntimeError("Did you forget to create a context?")
121
168
  pep_context.add_point(self)
122
169
 
170
+ @staticmethod
171
+ def zero() -> Point:
172
+ return Point(is_basis=False, eval_expression=ZeroPoint(), tags=["0"])
173
+
123
174
  @property
124
175
  def tag(self):
176
+ """Returns the most recently added tag.
177
+
178
+ Returns:
179
+ str: The most recently added tag of this :class:`Point` object.
180
+ """
125
181
  if len(self.tags) == 0:
126
182
  raise ValueError("Point should have a name.")
127
183
  return self.tags[-1]
128
184
 
129
185
  def add_tag(self, tag: str) -> None:
186
+ """Add a new tag for this :class:`Point` object.
187
+
188
+ Args:
189
+ tag (str): The new tag to be added to the `tags` list.
190
+ """
130
191
  self.tags.append(tag)
131
192
 
132
193
  def __repr__(self):
@@ -135,18 +196,15 @@ class Point:
135
196
  return super().__repr__()
136
197
 
137
198
  def _repr_latex_(self):
138
- s = repr(self)
139
- s = s.replace("star", r"\star")
140
- s = s.replace("gradient_", r"\nabla ")
141
- s = s.replace("|", r"\|")
142
- return rf"$\\displaystyle {s}$"
199
+ return utils.str_to_latex(repr(self))
143
200
 
144
201
  # TODO: add a validator that `is_basis` and `eval_expression` are properly setup.
145
202
  def __add__(self, other):
146
- assert isinstance(other, Point)
203
+ if not isinstance(other, Point):
204
+ return NotImplemented
147
205
  return Point(
148
206
  is_basis=False,
149
- eval_expression=EvalExpressionPoint(utils.Op.ADD, self, other),
207
+ eval_expression=PointRepresentation(utils.Op.ADD, self, other),
150
208
  tags=[f"{self.tag}+{other.tag}"],
151
209
  )
152
210
 
@@ -154,72 +212,78 @@ class Point:
154
212
  # TODO: come up with better way to handle this
155
213
  if other == 0:
156
214
  return self
157
- assert isinstance(other, Point)
215
+ if not isinstance(other, Point):
216
+ return NotImplemented
158
217
  return Point(
159
218
  is_basis=False,
160
- eval_expression=EvalExpressionPoint(utils.Op.ADD, other, self),
219
+ eval_expression=PointRepresentation(utils.Op.ADD, other, self),
161
220
  tags=[f"{other.tag}+{self.tag}"],
162
221
  )
163
222
 
164
223
  def __sub__(self, other):
165
- assert isinstance(other, Point)
224
+ if not isinstance(other, Point):
225
+ return NotImplemented
166
226
  tag_other = utils.parenthesize_tag(other)
167
227
  return Point(
168
228
  is_basis=False,
169
- eval_expression=EvalExpressionPoint(utils.Op.SUB, self, other),
229
+ eval_expression=PointRepresentation(utils.Op.SUB, self, other),
170
230
  tags=[f"{self.tag}-{tag_other}"],
171
231
  )
172
232
 
173
233
  def __rsub__(self, other):
174
- assert isinstance(other, Point)
234
+ if not isinstance(other, Point):
235
+ return NotImplemented
175
236
  tag_self = utils.parenthesize_tag(self)
176
237
  return Point(
177
238
  is_basis=False,
178
- eval_expression=EvalExpressionPoint(utils.Op.SUB, other, self),
239
+ eval_expression=PointRepresentation(utils.Op.SUB, other, self),
179
240
  tags=[f"{other.tag}-{tag_self}"],
180
241
  )
181
242
 
182
243
  def __mul__(self, other):
183
- # TODO allow the other to be point so that we return a scalar.
184
- assert is_numerical_or_point(other)
244
+ if not is_numerical_or_point(other):
245
+ return NotImplemented
185
246
  tag_self = utils.parenthesize_tag(self)
186
- if utils.is_numerical(other):
247
+ if utils.is_numerical_or_parameter(other):
248
+ tag_other = utils.numerical_str(other)
187
249
  return Point(
188
250
  is_basis=False,
189
- eval_expression=EvalExpressionPoint(utils.Op.MUL, self, other),
190
- tags=[f"{tag_self}*{other:.4g}"],
251
+ eval_expression=PointRepresentation(utils.Op.MUL, self, other),
252
+ tags=[f"{tag_self}*{tag_other}"],
191
253
  )
192
254
  else:
193
255
  tag_other = utils.parenthesize_tag(other)
194
256
  return Scalar(
195
257
  is_basis=False,
196
- eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
258
+ eval_expression=ScalarRepresentation(utils.Op.MUL, self, other),
197
259
  tags=[f"{tag_self}*{tag_other}"],
198
260
  )
199
261
 
200
262
  def __rmul__(self, other):
201
- # TODO allow the other to be point so that we return a scalar.
202
- assert is_numerical_or_point(other)
263
+ if not is_numerical_or_point(other):
264
+ return NotImplemented
203
265
  tag_self = utils.parenthesize_tag(self)
204
- if utils.is_numerical(other):
266
+ if utils.is_numerical_or_parameter(other):
267
+ tag_other = utils.numerical_str(other)
205
268
  return Point(
206
269
  is_basis=False,
207
- eval_expression=EvalExpressionPoint(utils.Op.MUL, other, self),
208
- tags=[f"{other:.4g}*{tag_self}"],
270
+ eval_expression=PointRepresentation(utils.Op.MUL, other, self),
271
+ tags=[f"{tag_other}*{tag_self}"],
209
272
  )
210
273
  else:
211
274
  tag_other = utils.parenthesize_tag(other)
212
275
  return Scalar(
213
276
  is_basis=False,
214
- eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
277
+ eval_expression=ScalarRepresentation(utils.Op.MUL, other, self),
215
278
  tags=[f"{tag_other}*{tag_self}"],
216
279
  )
217
280
 
218
281
  def __pow__(self, power):
219
- assert power == 2
282
+ if power != 2:
283
+ return NotImplemented
220
284
  return Scalar(
221
285
  is_basis=False,
222
- eval_expression=EvalExpressionScalar(utils.Op.MUL, self, self),
286
+ eval_expression=ScalarRepresentation(utils.Op.MUL, self, self),
223
287
  tags=[rf"|{self.tag}|^{power}"],
224
288
  )
225
289
 
@@ -227,17 +291,19 @@ class Point:
227
291
  tag_self = utils.parenthesize_tag(self)
228
292
  return Point(
229
293
  is_basis=False,
230
- eval_expression=EvalExpressionPoint(utils.Op.MUL, -1, self),
294
+ eval_expression=PointRepresentation(utils.Op.MUL, -1, self),
231
295
  tags=[f"-{tag_self}"],
232
296
  )
233
297
 
234
298
  def __truediv__(self, other):
235
- assert utils.is_numerical(other)
299
+ if not utils.is_numerical_or_parameter(other):
300
+ return NotImplemented
236
301
  tag_self = utils.parenthesize_tag(self)
302
+ tag_other = f"1/{utils.numerical_str(other)}"
237
303
  return Point(
238
304
  is_basis=False,
239
- eval_expression=EvalExpressionPoint(utils.Op.DIV, self, other),
240
- tags=[f"1/{other:.4g}*{tag_self}"],
305
+ eval_expression=PointRepresentation(utils.Op.DIV, self, other),
306
+ tags=[f"{tag_other}*{tag_self}"],
241
307
  )
242
308
 
243
309
  def __hash__(self):
@@ -249,6 +315,20 @@ class Point:
249
315
  return self.uid == other.uid
250
316
 
251
317
  def eval(self, ctx: pc.PEPContext | None = None) -> np.ndarray:
318
+ """
319
+ Return the concrete representation of this :class:`Point`.
320
+ Concrete representations of :class:`Point` objects are
321
+ :class:`EvaluatedPoint` objects.
322
+
323
+ Args:
324
+ ctx (:class:`PEPContext` | None): The :class:`PEPContext` object
325
+ we consider. `None` if we consider the current global
326
+ :class:`PEPContext` object.
327
+
328
+ Returns:
329
+ np.ndarray: The `vector` of the :class:`EvaluatedPoint` representing
330
+ this :class:`Point`.
331
+ """
252
332
  from pepflow.expression_manager import ExpressionManager
253
333
 
254
334
  # Note this can be inefficient.
@@ -258,3 +338,29 @@ class Point:
258
338
  raise RuntimeError("Did you forget to create a context?")
259
339
  em = ExpressionManager(ctx)
260
340
  return em.eval_point(self).vector
341
+
342
+ def repr_by_basis(self, ctx: pc.PEPContext | None = None) -> str:
343
+ """
344
+ Express this :class:`Point` object as the linear combination of
345
+ the basis :class:`Point` objects of the given :class:`PEPContext`.
346
+ This linear combination is expressed as a `str` where, to refer to
347
+ the basis :class:`Point` objects, we use their tags.
348
+
349
+ Args:
350
+ ctx (:class:`PEPContext`): The :class:`PEPContext` object
351
+ whose basis :class:`Point` objects we consider. `None` if
352
+ we consider the current global :class:`PEPContext` object.
353
+
354
+ Returns:
355
+ str: The representation of this :class:`Point` object in terms of
356
+ the basis :class:`Point` objects of the given :class:`PEPContext`.
357
+ """
358
+ from pepflow.expression_manager import ExpressionManager
359
+
360
+ # Note this can be inefficient.
361
+ if ctx is None:
362
+ ctx = pc.get_current_context()
363
+ if ctx is None:
364
+ raise RuntimeError("Did you forget to create a context?")
365
+ em = ExpressionManager(ctx)
366
+ return em.repr_point_by_basis(self)
pepflow/point_test.py CHANGED
@@ -162,3 +162,15 @@ def test_expression_manager_eval_point_large_scale(pep_context):
162
162
  pm.eval_point(pp)
163
163
 
164
164
  assert (time.time() - t) < 0.5
165
+
166
+
167
+ def test_zero_point(pep_context):
168
+ _ = point.Point(is_basis=True, tags=["p1"])
169
+ p0 = point.Point.zero()
170
+
171
+ pm = exm.ExpressionManager(pep_context)
172
+ np.testing.assert_allclose(pm.eval_point(p0).vector, np.array([0]))
173
+
174
+ _ = point.Point(is_basis=True, tags=["p2"])
175
+ pm = exm.ExpressionManager(pep_context)
176
+ np.testing.assert_allclose(pm.eval_point(p0).vector, np.array([0, 0]))