pepflow 0.1.4a1__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pepflow/scalar.py CHANGED
@@ -34,28 +34,78 @@ if TYPE_CHECKING:
34
34
 
35
35
 
36
36
  def is_numerical_or_scalar(val: Any) -> bool:
37
- return utils.is_numerical(val) or isinstance(val, Scalar)
37
+ return utils.is_numerical_or_parameter(val) or isinstance(val, Scalar)
38
38
 
39
39
 
40
40
  def is_numerical_or_evaluatedscalar(val: Any) -> bool:
41
- return utils.is_numerical(val) or isinstance(val, EvaluatedScalar)
41
+ return utils.is_numerical_or_parameter(val) or isinstance(val, EvaluatedScalar)
42
42
 
43
43
 
44
44
  @attrs.frozen
45
- class EvalExpressionScalar:
45
+ class ScalarRepresentation:
46
46
  op: utils.Op
47
47
  left_scalar: Point | Scalar | float
48
48
  right_scalar: Point | Scalar | float
49
49
 
50
50
 
51
+ @attrs.frozen
52
+ class ZeroScalar:
53
+ """A special class to represent 0 in scalar."""
54
+
55
+ pass
56
+
57
+
51
58
  @attrs.frozen
52
59
  class EvaluatedScalar:
60
+ """
61
+ The concrete representation of the abstract :class:`Scalar`.
62
+
63
+ Each abstract basis :class:`Scalar` object has a unique concrete
64
+ representation as a unit vector. The concrete representations of
65
+ linear combinations of abstract basis :class:`Scalar` objects
66
+ are linear combinations of the unit vectors. This information is
67
+ stored in the `vector` attribute.
68
+
69
+ Abstract :class:`Scalar` objects can be formed through taking the
70
+ inner product of two abstract :class:`Point` objects. The
71
+ concrete representation of an abstract :class:`Scalar` object formed
72
+ this way is the outer product of the concrete representations of the
73
+ two abstract :class:`Point` objects, i.e., a matrix. This information
74
+ is stored in the `matrix` attribute.
75
+
76
+ Abstract :class:`Scalar` objects can be added or subtracted with
77
+ numeric data types. This information is stored in the `constant`
78
+ attribute.
79
+
80
+ :class:`EvaluatedScalar` objects can be constructed as linear combinations
81
+ of other :class:`EvaluatedScalar` objects. Let `a` and `b` be some numeric
82
+ data type. Let `u` and `v` be :class:`EvaluatedScalar` objects. Then, we
83
+ can form a new :class:`EvaluatedScalar` object: `a*u+b*v`.
84
+
85
+ Attributes:
86
+ vector (np.ndarray): The vector component of the concrete
87
+ representation of the abstract :class:`Scalar`.
88
+ matrix (np.ndarray): The matrix component of the concrete
89
+ representation of the abstract :class:`Scalar`.
90
+ constant (float): The constant component of the concrete
91
+ representation of the abstract :class:`Scalar`.
92
+ """
93
+
53
94
  vector: np.ndarray
54
95
  matrix: np.ndarray
55
96
  constant: float
56
97
 
98
+ @classmethod
99
+ def zero(cls, num_basis_scalars: int, num_basis_points: int):
100
+ return EvaluatedScalar(
101
+ vector=np.zeros(num_basis_scalars),
102
+ matrix=np.zeros((num_basis_points, num_basis_points)),
103
+ constant=0.0,
104
+ )
105
+
57
106
  def __add__(self, other):
58
- assert is_numerical_or_evaluatedscalar(other)
107
+ if not is_numerical_or_evaluatedscalar(other):
108
+ return NotImplemented
59
109
  if utils.is_numerical(other):
60
110
  return EvaluatedScalar(
61
111
  vector=self.vector, matrix=self.matrix, constant=self.constant + other
@@ -68,7 +118,8 @@ class EvaluatedScalar:
68
118
  )
69
119
 
70
120
  def __radd__(self, other):
71
- assert is_numerical_or_evaluatedscalar(other)
121
+ if not is_numerical_or_evaluatedscalar(other):
122
+ return NotImplemented
72
123
  if utils.is_numerical(other):
73
124
  return EvaluatedScalar(
74
125
  vector=self.vector, matrix=self.matrix, constant=other + self.constant
@@ -81,7 +132,8 @@ class EvaluatedScalar:
81
132
  )
82
133
 
83
134
  def __sub__(self, other):
84
- assert is_numerical_or_evaluatedscalar(other)
135
+ if not is_numerical_or_evaluatedscalar(other):
136
+ return NotImplemented
85
137
  if utils.is_numerical(other):
86
138
  return EvaluatedScalar(
87
139
  vector=self.vector, matrix=self.matrix, constant=self.constant - other
@@ -94,7 +146,8 @@ class EvaluatedScalar:
94
146
  )
95
147
 
96
148
  def __rsub__(self, other):
97
- assert is_numerical_or_evaluatedscalar(other)
149
+ if not is_numerical_or_evaluatedscalar(other):
150
+ return NotImplemented
98
151
  if utils.is_numerical(other):
99
152
  return EvaluatedScalar(
100
153
  vector=-self.vector, matrix=-self.matrix, constant=other - self.constant
@@ -107,7 +160,8 @@ class EvaluatedScalar:
107
160
  )
108
161
 
109
162
  def __mul__(self, other):
110
- assert utils.is_numerical(other)
163
+ if not utils.is_numerical(other):
164
+ return NotImplemented
111
165
  return EvaluatedScalar(
112
166
  vector=self.vector * other,
113
167
  matrix=self.matrix * other,
@@ -115,7 +169,8 @@ class EvaluatedScalar:
115
169
  )
116
170
 
117
171
  def __rmul__(self, other):
118
- assert utils.is_numerical(other)
172
+ if not utils.is_numerical(other):
173
+ return NotImplemented
119
174
  return EvaluatedScalar(
120
175
  vector=other * self.vector,
121
176
  matrix=other * self.matrix,
@@ -126,7 +181,8 @@ class EvaluatedScalar:
126
181
  return self.__rmul__(other=-1)
127
182
 
128
183
  def __truediv__(self, other):
129
- assert utils.is_numerical(other)
184
+ if not utils.is_numerical(other):
185
+ return NotImplemented
130
186
  return EvaluatedScalar(
131
187
  vector=self.vector / other,
132
188
  matrix=self.matrix / other,
@@ -136,11 +192,27 @@ class EvaluatedScalar:
136
192
 
137
193
  @attrs.frozen
138
194
  class Scalar:
195
+ """
196
+ A :class:`Scalar` object represents a linear combination of function values,
197
+ inner products of points, and constant scalar values.
198
+
199
+ :class:`Scalar` objects can be constructed as linear combinations of
200
+ other :class:`Scalar` objects. Let `a` and `b` be some numeric data type.
201
+ Let `x` and `y` be :class:`Scalar` objects. Then, we can form a new
202
+ :class:`Scalar` object: `a*x+b*y`.
203
+
204
+ Attributes:
205
+ is_basis (bool): True if this scalar is not formed through a linear
206
+ combination of other scalars. False otherwise.
207
+ tags (list[str]): A list that contains tags that can be used to
208
+ identify the :class:`Scalar` object. Tags should be unique.
209
+ """
210
+
139
211
  # If true, the scalar is the basis for the evaluations of F
140
212
  is_basis: bool
141
213
 
142
- # Not sure on this yet
143
- eval_expression: EvalExpressionScalar | None = None
214
+ # The representation of scalar used for evaluation.
215
+ eval_expression: ScalarRepresentation | ZeroScalar | None = None
144
216
 
145
217
  # Human tagged value for the scalar
146
218
  tags: list[str] = attrs.field(factory=list)
@@ -159,13 +231,27 @@ class Scalar:
159
231
  raise RuntimeError("Did you forget to create a context?")
160
232
  pep_context.add_scalar(self)
161
233
 
234
+ @staticmethod
235
+ def zero() -> Scalar:
236
+ return Scalar(is_basis=False, eval_expression=ZeroScalar(), tags=["0"])
237
+
162
238
  @property
163
239
  def tag(self):
240
+ """Returns the most recently added tag.
241
+
242
+ Returns:
243
+ str: The most recently added tag of this :class:`Scalar` object.
244
+ """
164
245
  if len(self.tags) == 0:
165
246
  raise ValueError("Scalar should have a name.")
166
247
  return self.tags[-1]
167
248
 
168
249
  def add_tag(self, tag: str):
250
+ """Add a new tag for this :class:`Scalar` object.
251
+
252
+ Args:
253
+ tag (str): The new tag to be added to the `tags` list.
254
+ """
169
255
  self.tags.append(tag)
170
256
 
171
257
  def __repr__(self):
@@ -174,93 +260,100 @@ class Scalar:
174
260
  return super().__repr__()
175
261
 
176
262
  def _repr_latex_(self):
177
- s = repr(self)
178
- s = s.replace("star", r"\star")
179
- s = s.replace("gradient_", r"\nabla ")
180
- return rf"$\\displaystyle {s}$"
263
+ return utils.str_to_latex(repr(self))
181
264
 
182
265
  def __add__(self, other):
183
- assert is_numerical_or_scalar(other)
184
- if utils.is_numerical(other):
185
- tag_other = f"{other:.4g}"
266
+ if not is_numerical_or_scalar(other):
267
+ return NotImplemented
268
+ if utils.is_numerical_or_parameter(other):
269
+ tag_other = utils.numerical_str(other)
186
270
  else:
187
271
  tag_other = other.tag
188
272
  return Scalar(
189
273
  is_basis=False,
190
- eval_expression=EvalExpressionScalar(utils.Op.ADD, self, other),
274
+ eval_expression=ScalarRepresentation(utils.Op.ADD, self, other),
191
275
  tags=[f"{self.tag}+{tag_other}"],
192
276
  )
193
277
 
194
278
  def __radd__(self, other):
195
- assert is_numerical_or_scalar(other)
196
- if utils.is_numerical(other):
197
- tag_other = f"{other:.4g}"
279
+ if not is_numerical_or_scalar(other):
280
+ return NotImplemented
281
+ if utils.is_numerical_or_parameter(other):
282
+ tag_other = utils.numerical_str(other)
198
283
  else:
199
284
  tag_other = other.tag
200
285
  return Scalar(
201
286
  is_basis=False,
202
- eval_expression=EvalExpressionScalar(utils.Op.ADD, other, self),
287
+ eval_expression=ScalarRepresentation(utils.Op.ADD, other, self),
203
288
  tags=[f"{tag_other}+{self.tag}"],
204
289
  )
205
290
 
206
291
  def __sub__(self, other):
207
- assert is_numerical_or_scalar(other)
208
- if utils.is_numerical(other):
209
- tag_other = f"{other:.4g}"
292
+ if not is_numerical_or_scalar(other):
293
+ return NotImplemented
294
+ if utils.is_numerical_or_parameter(other):
295
+ tag_other = utils.numerical_str(other)
210
296
  else:
211
297
  tag_other = utils.parenthesize_tag(other)
212
298
  return Scalar(
213
299
  is_basis=False,
214
- eval_expression=EvalExpressionScalar(utils.Op.SUB, self, other),
300
+ eval_expression=ScalarRepresentation(utils.Op.SUB, self, other),
215
301
  tags=[f"{self.tag}-{tag_other}"],
216
302
  )
217
303
 
218
304
  def __rsub__(self, other):
219
- assert is_numerical_or_scalar(other)
305
+ if not is_numerical_or_scalar(other):
306
+ return NotImplemented
220
307
  tag_self = utils.parenthesize_tag(self)
221
- if utils.is_numerical(other):
222
- tag_other = f"{other:.4g}"
308
+ if utils.is_numerical_or_parameter(other):
309
+ tag_other = utils.numerical_str(other)
223
310
  else:
224
311
  tag_other = other.tag
225
312
  return Scalar(
226
313
  is_basis=False,
227
- eval_expression=EvalExpressionScalar(utils.Op.SUB, other, self),
314
+ eval_expression=ScalarRepresentation(utils.Op.SUB, other, self),
228
315
  tags=[f"{tag_other}-{tag_self}"],
229
316
  )
230
317
 
231
318
  def __mul__(self, other):
232
- assert utils.is_numerical(other)
319
+ if not utils.is_numerical_or_parameter(other):
320
+ return NotImplemented
233
321
  tag_self = utils.parenthesize_tag(self)
322
+ tag_other = utils.numerical_str(other)
234
323
  return Scalar(
235
324
  is_basis=False,
236
- eval_expression=EvalExpressionScalar(utils.Op.MUL, self, other),
237
- tags=[f"{tag_self}*{other:.4g}"],
325
+ eval_expression=ScalarRepresentation(utils.Op.MUL, self, other),
326
+ tags=[f"{tag_self}*{tag_other}"],
238
327
  )
239
328
 
240
329
  def __rmul__(self, other):
241
- assert utils.is_numerical(other)
330
+ if not utils.is_numerical_or_parameter(other):
331
+ return NotImplemented
242
332
  tag_self = utils.parenthesize_tag(self)
333
+ tag_other = utils.numerical_str(other)
243
334
  return Scalar(
244
335
  is_basis=False,
245
- eval_expression=EvalExpressionScalar(utils.Op.MUL, other, self),
246
- tags=[f"{other:.4g}*{tag_self}"],
336
+ eval_expression=ScalarRepresentation(utils.Op.MUL, other, self),
337
+ tags=[f"{tag_other}*{tag_self}"],
247
338
  )
248
339
 
249
340
  def __neg__(self):
250
341
  tag_self = utils.parenthesize_tag(self)
251
342
  return Scalar(
252
343
  is_basis=False,
253
- eval_expression=EvalExpressionScalar(utils.Op.MUL, -1, self),
344
+ eval_expression=ScalarRepresentation(utils.Op.MUL, -1, self),
254
345
  tags=[f"-{tag_self}"],
255
346
  )
256
347
 
257
348
  def __truediv__(self, other):
258
- assert utils.is_numerical(other)
349
+ if not utils.is_numerical_or_parameter(other):
350
+ return NotImplemented
259
351
  tag_self = utils.parenthesize_tag(self)
352
+ tag_other = f"1/{utils.numerical_str(other)}"
260
353
  return Scalar(
261
354
  is_basis=False,
262
- eval_expression=EvalExpressionScalar(utils.Op.DIV, self, other),
263
- tags=[f"1/{other:.4g}*{tag_self}"],
355
+ eval_expression=ScalarRepresentation(utils.Op.DIV, self, other),
356
+ tags=[f"{tag_other}*{tag_self}"],
264
357
  )
265
358
 
266
359
  def __hash__(self):
@@ -271,22 +364,101 @@ class Scalar:
271
364
  return NotImplemented
272
365
  return self.uid == other.uid
273
366
 
274
- def le(self, other, name: str) -> ctr.Constraint:
367
+ def le(self, other: Scalar | float | int, name: str) -> ctr.Constraint:
368
+ """
369
+ Generate a :class:`Constraint` object that represents the inequality
370
+ `self` <= `other`.
371
+
372
+ Args:
373
+ other (:class:`Scalar` | float | int): The other side of the
374
+ relation.
375
+ name (str): The name of the generated :class:`Constraint` object.
376
+
377
+ Returns:
378
+ :class:`Constraint`: An object that represents the inequality
379
+ `self` <= `other`.
380
+ """
275
381
  return ctr.Constraint(self - other, comparator=utils.Comparator.LT, name=name)
276
382
 
277
- def lt(self, other, name: str) -> ctr.Constraint:
383
+ def lt(self, other: Scalar | float | int, name: str) -> ctr.Constraint:
384
+ """
385
+ Generate a :class:`Constraint` object that represents the inequality
386
+ `self` < `other`.
387
+
388
+ Args:
389
+ other (:class:`Scalar` | float | int): The other side of the
390
+ relation.
391
+ name (str): The name of the generated :class:`Constraint` object.
392
+
393
+ Returns:
394
+ :class:`Constraint`: An object that represents the inequality
395
+ `self` < `other`.
396
+ """
278
397
  return ctr.Constraint(self - other, comparator=utils.Comparator.LT, name=name)
279
398
 
280
- def ge(self, other, name: str) -> ctr.Constraint:
399
+ def ge(self, other: Scalar | float | int, name: str) -> ctr.Constraint:
400
+ """
401
+ Generate a :class:`Constraint` object that represents the inequality
402
+ `self` >= `other`.
403
+
404
+ Args:
405
+ other (:class:`Scalar` | float | int): The other side of the
406
+ relation.
407
+ name (str): The name of the generated :class:`Constraint` object.
408
+
409
+ Returns:
410
+ :class:`Constraint`: An object that represents the inequality
411
+ `self` >= `other`.
412
+ """
281
413
  return ctr.Constraint(self - other, comparator=utils.Comparator.GT, name=name)
282
414
 
283
- def gt(self, other, name: str) -> ctr.Constraint:
415
+ def gt(self, other: Scalar | float | int, name: str) -> ctr.Constraint:
416
+ """
417
+ Generate a :class:`Constraint` object that represents the inequality
418
+ `self` > `other`.
419
+
420
+ Args:
421
+ other (:class:`Scalar` | float | int): The other side of the
422
+ relation.
423
+ name (str): The name of the generated :class:`Constraint` object.
424
+
425
+ Returns:
426
+ :class:`Constraint`: An object that represents the inequality
427
+ `self` > `other`.
428
+ """
284
429
  return ctr.Constraint(self - other, comparator=utils.Comparator.GT, name=name)
285
430
 
286
- def eq(self, other, name: str) -> ctr.Constraint:
431
+ def eq(self, other: Scalar | float | int, name: str) -> ctr.Constraint:
432
+ """
433
+ Generate a :class:`Constraint` object that represents the equality
434
+ `self` = `other`.
435
+
436
+ Args:
437
+ other (:class:`Scalar` | float | int): The other side of the
438
+ relation.
439
+ name (str): The name of the generated :class:`Constraint` object.
440
+
441
+ Returns:
442
+ :class:`Constraint`: An object that represents the equality
443
+ `self` = `other`.
444
+ """
287
445
  return ctr.Constraint(self - other, comparator=utils.Comparator.EQ, name=name)
288
446
 
289
447
  def eval(self, ctx: pc.PEPContext | None = None) -> EvaluatedScalar:
448
+ """
449
+ Return the concrete representation of this :class:`Scalar`.
450
+ Concrete representations of :class:`Scalar` objects are
451
+ :class:`EvaluatedScalar` objects.
452
+
453
+ Args:
454
+ ctx (:class:`PEPContext` | None): The :class:`PEPContext` object
455
+ we consider. `None` if we consider the current global
456
+ :class:`PEPContext` object.
457
+
458
+ Returns:
459
+ :class:`EvaluatedScalar`: The concrete representation of
460
+ this :class:`Scalar`.
461
+ """
290
462
  from pepflow.expression_manager import ExpressionManager
291
463
 
292
464
  # Note this can be inefficient.
@@ -296,3 +468,44 @@ class Scalar:
296
468
  raise RuntimeError("Did you forget to create a context?")
297
469
  em = ExpressionManager(ctx)
298
470
  return em.eval_scalar(self)
471
+
472
+ def repr_by_basis(
473
+ self, ctx: pc.PEPContext | None = None, greedy_square: bool = True
474
+ ) -> str:
475
+ """Express this :class:`Scalar` object in terms of the basis
476
+ :class:`Point` and :class:`Scalar` objects of the given
477
+ :class:`PEPContext`.
478
+
479
+ A :class:`Scalar` can be formed by linear combinations of basis
480
+ :class:`Scalar` objects. A :class:`Scalar` can also be formed through
481
+ the inner product of two basis :class:`Point` objects. This function
482
+ returns the representation of this :class:`Scalar` object in terms of
483
+ the basis :class:`Point` and :class:`Scalar` objects as a `str` where,
484
+ to refer to the basis :class:`Point` and :class:`Scalar` objects, we
485
+ use their tags.
486
+
487
+ Args:
488
+ ctx (:class:`PEPContext`): The :class:`PEPContext` object
489
+ whose basis :class:`Point` and :class:`Scalar` objects we
490
+ consider. `None` if we consider the current global
491
+ `PEPContext` object.
492
+ greedy_square (bool): If `greedy_square` is true, the function will
493
+ try to return :math:`\\|a-b\\|^2` whenever possible. If not,
494
+ the function will return
495
+ :math:`\\|a\\|^2 - 2 * \\langle a, b \\rangle + \\|b\\|^2` instead.
496
+ `True` by default.
497
+
498
+ Returns:
499
+ str: The representation of this :class:`Scalar` object in terms of
500
+ the basis :class:`Point` and :class:`Scalar` objects of the given
501
+ :class:`PEPContext`.
502
+ """
503
+ from pepflow.expression_manager import ExpressionManager
504
+
505
+ # Note this can be inefficient.
506
+ if ctx is None:
507
+ ctx = pc.get_current_context()
508
+ if ctx is None:
509
+ raise RuntimeError("Did you forget to create a context?")
510
+ em = ExpressionManager(ctx)
511
+ return em.repr_scalar_by_basis(self, greedy_square=greedy_square)
pepflow/scalar_test.py CHANGED
@@ -205,3 +205,18 @@ def test_expression_manager_eval_scalar(pep_context: pc.PEPContext):
205
205
  ]
206
206
  ),
207
207
  )
208
+
209
+
210
+ def test_zero_scalar(pep_context):
211
+ _ = scalar.Scalar(is_basis=True, tags=["s1"])
212
+ _ = point.Point(is_basis=True, tags=["p1"])
213
+ s0 = scalar.Scalar.zero()
214
+
215
+ pm = exm.ExpressionManager(pep_context)
216
+ np.testing.assert_allclose(pm.eval_scalar(s0).vector, np.array([0]))
217
+ np.testing.assert_allclose(pm.eval_scalar(s0).matrix, np.array([[0]]))
218
+
219
+ _ = point.Point(is_basis=True, tags=["p2"])
220
+ pm = exm.ExpressionManager(pep_context)
221
+ np.testing.assert_allclose(pm.eval_scalar(s0).vector, np.array([0]))
222
+ np.testing.assert_allclose(pm.eval_scalar(s0).matrix, np.array([[0, 0], [0, 0]]))