openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
"""Arithmetic operations for symbolic expressions.
|
|
2
|
+
|
|
3
|
+
This module provides fundamental arithmetic operations that form the building blocks
|
|
4
|
+
of symbolic expressions in openscvx. These operations are created automatically through
|
|
5
|
+
operator overloading on Expr objects.
|
|
6
|
+
|
|
7
|
+
Arithmetic Operations:
|
|
8
|
+
|
|
9
|
+
- **Binary operations:** `Add`, `Sub`, `Mul`, `Div`, `MatMul`, `Power` - Standard arithmetic
|
|
10
|
+
- **Unary operations:** `Neg` - Negation (unary minus)
|
|
11
|
+
|
|
12
|
+
All arithmetic operations support:
|
|
13
|
+
- Automatic canonicalization (constant folding, identity elimination, flattening)
|
|
14
|
+
- Broadcasting following NumPy rules (except MatMul which follows linear algebra rules)
|
|
15
|
+
- Shape checking and validation
|
|
16
|
+
|
|
17
|
+
Example:
|
|
18
|
+
Arithmetic operations are created via operator overloading::
|
|
19
|
+
|
|
20
|
+
import openscvx as ox
|
|
21
|
+
|
|
22
|
+
x = ox.State("x", shape=(3,))
|
|
23
|
+
y = ox.State("y", shape=(3,))
|
|
24
|
+
|
|
25
|
+
# Element-wise operations
|
|
26
|
+
z = x + y # Creates Add(x, y)
|
|
27
|
+
w = x * 2 # Creates Mul(x, Constant(2))
|
|
28
|
+
neg_x = -x # Creates Neg(x)
|
|
29
|
+
|
|
30
|
+
# Matrix multiplication
|
|
31
|
+
A = ox.State("A", shape=(3, 3))
|
|
32
|
+
b = A @ x # Creates MatMul(A, x)
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from typing import Tuple
|
|
36
|
+
|
|
37
|
+
import numpy as np
|
|
38
|
+
|
|
39
|
+
from .expr import Constant, Expr, to_expr
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Add(Expr):
    """Addition operation for symbolic expressions.

    Represents element-wise addition of two or more expressions. Supports broadcasting
    following NumPy rules. Can be created using the + operator on Expr objects.

    Attributes:
        terms: List of expression operands to add together

    Example:
        Define an Add expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x + y + 5  # Creates Add(x, y, Constant(5))
    """

    def __init__(self, *args):
        """Initialize an addition operation.

        Args:
            *args: Two or more expressions (or values accepted by ``to_expr``)
                to add together

        Raises:
            ValueError: If fewer than two operands are provided
        """
        if len(args) < 2:
            raise ValueError("Add requires two or more operands")
        self.terms = [to_expr(a) for a in args]

    def children(self):
        return list(self.terms)

    def canonicalize(self) -> "Expr":
        """Canonicalize addition: flatten, fold constants, and eliminate zeros.

        Nested ``Add`` nodes are flattened into this one. Every constant term,
        including constants that came from a nested ``Add``, is summed (with
        broadcasting) into a single trailing constant. An all-zero constant is
        dropped when non-constant terms remain. If every term is constant, the
        folded constant is returned even when it is zero, so its shape is kept.

        Returns:
            Expr: Canonical form of the addition expression
        """
        terms = []
        const_vals = []

        for t in self.terms:
            c = t.canonicalize()
            # A canonical Add can itself carry a folded constant. Classify its
            # terms one by one so constants fold across nesting levels.
            for item in c.terms if isinstance(c, Add) else (c,):
                if isinstance(item, Constant):
                    const_vals.append(item.value)
                else:
                    terms.append(item)

        if const_vals:
            total = sum(const_vals)
            if not terms:
                # Everything folded into one constant. Keep it, even if zero,
                # so the broadcast shape of the operands is not lost.
                return Constant(total)
            # np.asarray also covers NumPy scalars (e.g. np.float64), which is
            # what summing 0-d arrays produces; they are not np.ndarray instances.
            if not np.all(np.asarray(total) == 0):
                terms.append(Constant(total))

        if not terms:
            return Constant(np.array(0))
        if len(terms) == 1:
            return terms[0]
        return Add(*terms)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of all operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Add shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        inner = " + ".join(repr(e) for e in self.terms)
        return f"({inner})"
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class Sub(Expr):
    """Subtraction operation for symbolic expressions.

    Represents element-wise subtraction (left - right). Supports broadcasting
    following NumPy rules. Can be created using the - operator on Expr objects.

    Attributes:
        left: Left-hand side expression (minuend)
        right: Right-hand side expression (subtrahend)

    Example:
        Define a Sub expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x - y  # Creates Sub(x, y)
    """

    def __init__(self, left, right):
        """Initialize a subtraction operation.

        Args:
            left: Expression to subtract from (minuend); plain values are
                wrapped with ``to_expr``
            right: Expression to subtract (subtrahend); plain values are
                wrapped with ``to_expr``
        """
        # Wrap operands like Add/Mul/Power do, so direct construction with raw
        # numbers or arrays (e.g. Sub(x, 2)) works.
        self.left = to_expr(left)
        self.right = to_expr(right)

    def children(self):
        return [self.left, self.right]

    def canonicalize(self) -> "Expr":
        """Canonicalize subtraction: fold constants if both sides are constants.

        Returns:
            Expr: Canonical form of the subtraction expression
        """
        left = self.left.canonicalize()
        right = self.right.canonicalize()
        if isinstance(left, Constant) and isinstance(right, Constant):
            return Constant(left.value - right.value)
        return Sub(left, right)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of all operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Sub shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        return f"({self.left!r} - {self.right!r})"
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class Mul(Expr):
    """Element-wise multiplication operation for symbolic expressions.

    Represents element-wise (Hadamard) multiplication of two or more expressions.
    Supports broadcasting following NumPy rules. Can be created using the * operator
    on Expr objects. For matrix multiplication, use MatMul or the @ operator.

    Attributes:
        factors: List of expression operands to multiply together

    Example:
        Define a Mul expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x * y * 2  # Creates Mul(x, y, Constant(2))
    """

    def __init__(self, *args):
        """Initialize an element-wise multiplication operation.

        Args:
            *args: Two or more expressions (or values accepted by ``to_expr``)
                to multiply together

        Raises:
            ValueError: If fewer than two operands are provided
        """
        if len(args) < 2:
            raise ValueError("Mul requires two or more operands")
        self.factors = [to_expr(a) for a in args]

    def children(self):
        return list(self.factors)

    def canonicalize(self) -> "Expr":
        """Canonicalize multiplication: flatten, fold constants, and eliminate ones.

        Nested ``Mul`` nodes are flattened into this one. Every constant factor,
        including constants that came from a nested ``Mul``, is multiplied
        element-wise (with broadcasting) into a single trailing constant. An
        all-ones constant is dropped when non-constant factors remain. If every
        factor is constant, the folded constant is returned even when it is one,
        so its shape is kept.

        Returns:
            Expr: Canonical form of the multiplication expression
        """
        factors = []
        const_vals = []

        for f in self.factors:
            c = f.canonicalize()
            # A canonical Mul can itself carry a folded constant. Classify its
            # factors one by one so constants fold across nesting levels.
            for item in c.factors if isinstance(c, Mul) else (c,):
                if isinstance(item, Constant):
                    const_vals.append(item.value)
                else:
                    factors.append(item)

        if const_vals:
            # Multiply constants element-wise (broadcasting), not reducing with prod
            prod = const_vals[0]
            for val in const_vals[1:]:
                prod = prod * val

            if not factors:
                # Everything folded into one constant. Keep it, even if it is
                # one, so the broadcast shape of the operands is not lost.
                return Constant(prod)

            # np.asarray covers both arrays and NumPy/Python scalars.
            if not np.all(np.asarray(prod) == 1):
                factors.append(Constant(prod))

        if not factors:
            return Constant(np.array(1))
        if len(factors) == 1:
            return factors[0]
        return Mul(*factors)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of all operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Mul shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        inner = " * ".join(repr(e) for e in self.factors)
        return f"({inner})"
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class Div(Expr):
    """Element-wise division operation for symbolic expressions.

    Represents element-wise division (left / right). Supports broadcasting
    following NumPy rules. Can be created using the / operator on Expr objects.

    Attributes:
        left: Numerator expression
        right: Denominator expression

    Example:
        Define a Div expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(3,))
            z = x / y  # Creates Div(x, y)
    """

    def __init__(self, left, right):
        """Initialize a division operation.

        Args:
            left: Expression for the numerator; plain values are wrapped with
                ``to_expr``
            right: Expression for the denominator; plain values are wrapped
                with ``to_expr``
        """
        # Wrap operands like Add/Mul/Power do, so direct construction with raw
        # numbers or arrays (e.g. Div(x, 2)) works.
        self.left = to_expr(left)
        self.right = to_expr(right)

    def children(self):
        return [self.left, self.right]

    def canonicalize(self) -> "Expr":
        """Canonicalize division: fold constants if both sides are constants.

        Returns:
            Expr: Canonical form of the division expression
        """
        lhs = self.left.canonicalize()
        rhs = self.right.canonicalize()
        if isinstance(lhs, Constant) and isinstance(rhs, Constant):
            return Constant(lhs.value / rhs.value)
        return Div(lhs, rhs)

    def check_shape(self) -> Tuple[int, ...]:
        """Check shape compatibility and compute broadcasted result shape like NumPy.

        Returns:
            tuple: The broadcasted shape of both operands

        Raises:
            ValueError: If operand shapes are not broadcastable
        """
        shapes = [child.check_shape() for child in self.children()]
        try:
            return np.broadcast_shapes(*shapes)
        except ValueError as e:
            raise ValueError(f"Div shapes not broadcastable: {shapes}") from e

    def __repr__(self):
        return f"({self.left!r} / {self.right!r})"
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class MatMul(Expr):
    """Matrix multiplication operation for symbolic expressions.

    Represents matrix multiplication following NumPy ``matmul`` rules.
    Can be created using the @ operator on Expr objects. Handles:

    - Matrix @ Matrix: (..., m, n) @ (..., n, k) -> (..., m, k), with leading
      batch dimensions broadcast
    - Matrix @ Vector: (..., m, n) @ (n,) -> (..., m)
    - Vector @ Matrix: (m,) @ (..., m, n) -> (..., n)
    - Vector @ Vector: (m,) @ (m,) -> scalar

    Attributes:
        left: Left-hand side expression
        right: Right-hand side expression

    Example:
        Define a MatMul expression:

            A = ox.State("A", shape=(3, 4))
            x = ox.State("x", shape=(4,))
            y = A @ x  # Creates MatMul(A, x), result shape (3,)
    """

    def __init__(self, left, right):
        """Initialize a matrix multiplication operation.

        Args:
            left: Left-hand side expression; plain values are wrapped with
                ``to_expr``
            right: Right-hand side expression; plain values are wrapped with
                ``to_expr``
        """
        # Wrap operands like Add/Mul/Power do, so direct construction with raw
        # arrays works.
        self.left = to_expr(left)
        self.right = to_expr(right)

    def children(self):
        return [self.left, self.right]

    def canonicalize(self) -> "Expr":
        """Canonicalize matrix multiplication by canonicalizing both operands.

        Returns:
            Expr: A new MatMul over the canonicalized operands
        """
        left = self.left.canonicalize()
        right = self.right.canonicalize()
        return MatMul(left, right)

    def check_shape(self) -> Tuple[int, ...]:
        """Check matrix multiplication shape compatibility and return result shape.

        Follows ``numpy.matmul``. A 1-D left operand acts as a row vector and a
        1-D right operand as a column vector; the promoted axis is removed from
        the result. Leading batch dimensions of N-D operands are broadcast.

        Returns:
            tuple: Shape of the matrix product

        Raises:
            ValueError: If an operand is 0-D, the contracted dimensions differ,
                or the batch dimensions are not broadcastable
        """
        L, R = self.left.check_shape(), self.right.check_shape()

        if len(L) == 0 or len(R) == 0:
            raise ValueError(f"MatMul requires at least 1D operands: {L} @ {R}")

        if len(L) == 1 and len(R) == 1:
            # Vector @ Vector -> scalar
            if L[0] != R[0]:
                raise ValueError(f"MatMul incompatible: {L} @ {R}")
            return ()

        if len(L) == 1:
            # Vector @ Matrix: (m,) @ (..., m, n) -> (..., n)
            if L[0] != R[-2]:
                raise ValueError(f"MatMul incompatible: {L} @ {R}")
            return tuple(R[:-2]) + (R[-1],)

        if len(R) == 1:
            # Matrix @ Vector: (..., m, n) @ (n,) -> (..., m)
            if L[-1] != R[0]:
                raise ValueError(f"MatMul incompatible: {L} @ {R}")
            return tuple(L[:-1])

        # Matrix @ Matrix: (..., m, n) @ (..., n, k) -> (..., m, k)
        if L[-1] != R[-2]:
            raise ValueError(f"MatMul incompatible: {L} @ {R}")
        try:
            # Batch (stack) dimensions broadcast like element-wise operations.
            batch = np.broadcast_shapes(tuple(L[:-2]), tuple(R[:-2]))
        except ValueError as e:
            raise ValueError(
                f"MatMul batch dimensions not broadcastable: {L} @ {R}"
            ) from e
        return tuple(batch) + (L[-2], R[-1])

    def __repr__(self):
        # Use "@" so matrix products are distinguishable from element-wise Mul.
        return f"({self.left!r} @ {self.right!r})"
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class Neg(Expr):
    """Negation operation for symbolic expressions.

    Represents element-wise negation (unary minus). Can be created using the
    unary - operator on Expr objects.

    Attributes:
        operand: Expression to negate

    Example:
        Define a Neg expression:

            x = ox.State("x", shape=(3,))
            y = -x  # Creates Neg(x)
    """

    def __init__(self, operand):
        """Initialize a negation operation.

        Args:
            operand: Expression to negate; plain values are wrapped with
                ``to_expr``
        """
        # Wrap the operand like Add/Mul/Power do, so direct construction with
        # raw numbers or arrays works.
        self.operand = to_expr(operand)

    def children(self):
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Canonicalize negation: fold constant negations.

        Returns:
            Expr: Canonical form of the negation expression
        """
        o = self.operand.canonicalize()
        if isinstance(o, Constant):
            return Constant(-o.value)
        return Neg(o)

    def check_shape(self) -> Tuple[int, ...]:
        """Negation preserves the shape of its operand."""
        return self.operand.check_shape()

    def __repr__(self):
        return f"(-{self.operand!r})"
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class Power(Expr):
    """Element-wise exponentiation of symbolic expressions.

    Models ``base ** exponent`` applied element by element, with operand shapes
    combined under NumPy broadcasting rules. Instances are normally produced by
    applying the ``**`` operator to Expr objects.

    Attributes:
        base: Expression being raised to a power
        exponent: Expression giving the power

    Example:
        Define a Power expression:

            x = ox.State("x", shape=(3,))
            y = x ** 2  # Creates Power(x, Constant(2))
    """

    def __init__(self, base, exponent):
        """Build a power node from a base and an exponent.

        Args:
            base: Base expression, or any value accepted by ``to_expr``
            exponent: Exponent expression, or any value accepted by ``to_expr``
        """
        self.base, self.exponent = to_expr(base), to_expr(exponent)

    def children(self):
        return [self.base, self.exponent]

    def canonicalize(self) -> "Expr":
        """Return a new Power whose base and exponent are each canonicalized.

        Returns:
            Expr: Canonical form of the power expression
        """
        return Power(self.base.canonicalize(), self.exponent.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Broadcast the base and exponent shapes following NumPy rules.

        Returns:
            tuple: Broadcast shape of base and exponent

        Raises:
            ValueError: If the two shapes cannot be broadcast together
        """
        operand_shapes = [operand.check_shape() for operand in self.children()]
        try:
            result = np.broadcast_shapes(*operand_shapes)
        except ValueError as err:
            raise ValueError(
                f"Power shapes not broadcastable: {operand_shapes}"
            ) from err
        return result

    def __repr__(self):
        return f"({self.base!r})**({self.exponent!r})"
|