openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""SO(3) Lie group operations for rotation matrices.
|
|
2
|
+
|
|
3
|
+
This module provides exponential and logarithm maps for the SO(3) rotation
|
|
4
|
+
group, enabling axis-angle to rotation matrix conversions and vice versa.
|
|
5
|
+
|
|
6
|
+
Requires jaxlie: pip install openscvx[lie]
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
import jaxlie # noqa: F401 - validates jaxlie is installed
|
|
12
|
+
|
|
13
|
+
from ..expr import Expr, to_expr
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SO3Exp(Expr):
    """Exponential map so(3) -> SO(3): rotation vector to rotation matrix.

    Converts an axis-angle rotation vector ω into the corresponding 3×3
    rotation matrix (Rodrigues formula). The numerical implementation is
    delegated to jaxlie, which treats the small-angle regime robustly.

    The direction of ω is the rotation axis and its magnitude is the
    rotation angle in radians.

    Attributes:
        omega: Rotation vector expression of shape (3,)

    Example:
        Rotation of 90 degrees about the z-axis::

            import openscvx as ox
            import numpy as np

            omega = ox.Constant(np.array([0, 0, np.pi/2]))
            R = ox.lie.SO3Exp(omega)  # 3×3 rotation matrix

        Rotation angle as an optimization variable::

            theta = ox.State("theta", shape=(1,))
            axis = ox.Constant(np.array([0, 0, 1]))  # z-axis
            R = ox.lie.SO3Exp(axis * theta)

    See Also:
        - SO3Log: Inverse map (rotation matrix to rotation vector)
        - SE3Exp: Rigid-body counterpart that also carries translation
    """

    def __init__(self, omega):
        """Create the SO(3) exponential-map node.

        Args:
            omega: Rotation vector (axis × angle) of shape (3,); any value
                accepted by ``to_expr``.
        """
        self.omega = to_expr(omega)

    def children(self):
        return [self.omega]

    def canonicalize(self) -> "Expr":
        """Return a new SO3Exp built over the canonicalized rotation vector."""
        return SO3Exp(self.omega.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the rotation-vector shape and report the output shape.

        Returns:
            tuple: (3, 3), the shape of the resulting rotation matrix

        Raises:
            ValueError: If omega's shape is anything other than (3,)
        """
        vec_shape = self.omega.check_shape()
        if vec_shape == (3,):
            return (3, 3)
        raise ValueError(f"SO3Exp expects omega with shape (3,), got {vec_shape}")

    def __repr__(self):
        return "SO3Exp(" + repr(self.omega) + ")"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class SO3Log(Expr):
    """Logarithm map SO(3) -> so(3): rotation matrix to rotation vector.

    Converts a 3×3 rotation matrix into its axis-angle rotation vector ω.
    The numerical implementation is delegated to jaxlie.

    The direction of the resulting ω is the rotation axis and its magnitude
    is the rotation angle in radians.

    Attributes:
        rotation: Rotation matrix expression of shape (3, 3)

    Example:
        Recover the rotation vector of a rotation-matrix state::

            import openscvx as ox

            R = ox.State("R", shape=(3, 3))  # Rotation matrix state
            omega = ox.lie.SO3Log(R)  # 3D rotation vector

    See Also:
        - SO3Exp: Inverse map (rotation vector to rotation matrix)
        - SE3Log: Rigid-body counterpart of the logarithm map
    """

    def __init__(self, rotation):
        """Create the SO(3) logarithm-map node.

        Args:
            rotation: Rotation matrix of shape (3, 3); any value accepted by
                ``to_expr``.
        """
        self.rotation = to_expr(rotation)

    def children(self):
        return [self.rotation]

    def canonicalize(self) -> "Expr":
        """Return a new SO3Log built over the canonicalized rotation matrix."""
        return SO3Log(self.rotation.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the rotation-matrix shape and report the output shape.

        Returns:
            tuple: (3,), the shape of the resulting rotation vector

        Raises:
            ValueError: If rotation's shape is anything other than (3, 3)
        """
        mat_shape = self.rotation.check_shape()
        if mat_shape == (3, 3):
            return (3,)
        raise ValueError(f"SO3Log expects rotation with shape (3, 3), got {mat_shape}")

    def __repr__(self):
        return "SO3Log(" + repr(self.rotation) + ")"
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""Linear algebra operations for symbolic expressions.
|
|
2
|
+
|
|
3
|
+
This module provides essential linear algebra operations for matrix and vector
|
|
4
|
+
manipulation in optimization problems. Operations follow NumPy/JAX conventions
|
|
5
|
+
for shapes and broadcasting behavior.
|
|
6
|
+
|
|
7
|
+
Key Operations:
|
|
8
|
+
- **Matrix Operations:**
|
|
9
|
+
- `Transpose` - Matrix/tensor transposition (swaps last two dimensions)
|
|
10
|
+
- `Diag` - Construct diagonal matrix from vector
|
|
11
|
+
- **Reductions:**
|
|
12
|
+
- `Sum` - Sum all elements of an array (reduces to scalar)
|
|
13
|
+
- `Norm` - Euclidean (L2) norm and other norms of vectors/matrices
|
|
14
|
+
|
|
15
|
+
Note:
|
|
16
|
+
For array manipulation operations like stacking and concatenation, see the
|
|
17
|
+
`array` module.
|
|
18
|
+
|
|
19
|
+
Example:
|
|
20
|
+
Matrix transposition and diagonal matrices::
|
|
21
|
+
|
|
22
|
+
import openscvx as ox
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
# Transpose a matrix
|
|
26
|
+
A = ox.State("A", shape=(3, 4))
|
|
27
|
+
A_T = A.T # Result shape (4, 3)
|
|
28
|
+
|
|
29
|
+
# Create a diagonal matrix
|
|
30
|
+
v = ox.State("v", shape=(5,))
|
|
31
|
+
D = ox.Diag(v) # Result shape (5, 5)
|
|
32
|
+
|
|
33
|
+
Reduction operations::
|
|
34
|
+
|
|
35
|
+
x = ox.State("x", shape=(3, 4))
|
|
36
|
+
|
|
37
|
+
# Sum all elements
|
|
38
|
+
total = ox.Sum(x) # Result is scalar
|
|
39
|
+
|
|
40
|
+
# Compute norm
|
|
41
|
+
magnitude = ox.Norm(x) # Result is scalar
|
|
42
|
+
|
|
43
|
+
Computing kinetic energy with norms::
|
|
44
|
+
|
|
45
|
+
v = ox.State("v", shape=(3,)) # Velocity vector
|
|
46
|
+
m = 10.0 # Mass
|
|
47
|
+
kinetic_energy = 0.5 * m * ox.Norm(v)**2
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
import hashlib
|
|
51
|
+
from typing import Tuple
|
|
52
|
+
|
|
53
|
+
from .expr import Expr, to_expr
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Transpose(Expr):
    """Symbolic transpose that swaps the two trailing axes of an expression.

    For a matrix this exchanges rows and columns. For arrays with more than
    two dimensions only the last two axes are swapped. Scalars and 1-D
    vectors keep their shape under transposition.

    Canonicalization removes redundant double transposes, so (A.T).T
    reduces to A.

    Attributes:
        operand: Expression being transposed

    Example:
        Define Transpose expressions::

            A = Variable("A", shape=(3, 4))
            A_T = Transpose(A)  # or A.T, result shape (4, 3)
            v = Variable("v", shape=(5,))
            v_T = Transpose(v)  # result shape (5,) - vectors unchanged
    """

    def __init__(self, operand):
        """Create a transpose node.

        Args:
            operand: Expression to transpose; any value accepted by ``to_expr``.
        """
        self.operand = to_expr(operand)

    def children(self):
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Canonicalize the operand, collapsing a transpose of a transpose."""
        inner = self.operand.canonicalize()
        # (A.T).T == A: unwrap instead of stacking a second Transpose node.
        return inner.operand if isinstance(inner, Transpose) else Transpose(inner)

    def check_shape(self) -> Tuple[int, ...]:
        """Return the operand's shape with its last two dimensions swapped."""
        shape = self.operand.check_shape()
        ndim = len(shape)

        if ndim == 0:
            # A scalar is its own transpose.
            return ()
        if ndim == 1:
            # A 1-D vector is left as-is (no row/column distinction).
            return shape
        if ndim == 2:
            # (m, n) -> (n, m)
            rows, cols = shape
            return (cols, rows)
        # (..., m, n) -> (..., n, m): leading batch axes are untouched.
        return shape[:-2] + (shape[-1], shape[-2])

    def __repr__(self):
        return "(" + repr(self.operand) + ").T"
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class Diag(Expr):
    """Square diagonal matrix built from a 1-D vector.

    The entries of the vector become the main diagonal and every
    off-diagonal entry is zero, mirroring numpy.diag() / jax.numpy.diag()
    when given a vector.

    Note:
        Only vector -> diagonal-matrix construction is supported. Extracting
        the diagonal of a matrix is not implemented yet.

    Attributes:
        operand: 1-D vector expression placed on the diagonal

    Example:
        Define a Diag::

            v = Variable("v", shape=(3,))
            D = Diag(v)  # Creates a (3, 3) diagonal matrix
    """

    def __init__(self, operand):
        """Create a diagonal-matrix node.

        Args:
            operand: 1-D vector expression to place on the diagonal; any value
                accepted by ``to_expr``.
        """
        self.operand = to_expr(operand)

    def children(self):
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a new Diag built over the canonicalized operand."""
        return Diag(self.operand.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Map a vector shape (n,) to the square matrix shape (n, n).

        Raises:
            ValueError: If the operand is not one-dimensional
        """
        shape = self.operand.check_shape()
        if len(shape) == 1:
            (size,) = shape
            return (size, size)
        raise ValueError(f"Diag expects a 1D vector, got shape {shape}")

    def __repr__(self):
        return "diag(" + repr(self.operand) + ")"
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class Sum(Expr):
    """Reduction that adds up every element of an expression.

    All dimensions are collapsed, so the result is always a scalar.

    Attributes:
        operand: Expression whose elements are summed

    Example:
        Define a Sum expression::

            x = ox.State("x", shape=(3, 4))
            total = Sum(x)  # Creates Sum(x), result shape ()
    """

    def __init__(self, operand):
        """Create a sum-reduction node.

        Args:
            operand: Expression to reduce over all of its elements; any value
                accepted by ``to_expr``.
        """
        self.operand = to_expr(operand)

    def children(self):
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Return a new Sum built over the canonicalized operand.

        Returns:
            Expr: Canonical form of the sum expression
        """
        return Sum(self.operand.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the operand's shape; the reduction itself is always scalar."""
        # Called for its validation side effect only; the result is unused.
        self.operand.check_shape()
        return ()

    def __repr__(self):
        return "sum(" + repr(self.operand) + ")"
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class Norm(Expr):
    """Norm of an expression, reduced to a scalar.

    The kind of norm is selected by ``ord``. Whatever the input shape, the
    result is a scalar. The accepted ``ord`` values follow NumPy/SciPy
    conventions as interpreted by the lowering backend.

    Attributes:
        operand: Expression whose norm is taken
        ord: Norm order (default: "fro")
            - "fro": Frobenius norm (default)
            - "inf": Infinity norm
            - 1: L1 norm (sum of absolute values)
            - 2: L2 norm (Euclidean norm)
            - Other values as supported by the backend

    Example:
        Define Norms::

            x = Variable("x", shape=(3,))
            euclidean_norm = Norm(x, ord=2)  # L2 norm, result is scalar
            A = Variable("A", shape=(3, 4))
            frobenius_norm = Norm(A)  # Frobenius norm, result is scalar
    """

    def __init__(self, operand, ord="fro"):
        """Create a norm node.

        Args:
            operand: Expression whose norm is taken; any value accepted by
                ``to_expr``.
            ord: Norm order specification (default: "fro"); stored as given.
        """
        self.operand = to_expr(operand)
        self.ord = ord  # e.g. "fro", "inf", 1, 2, ...

    def children(self):
        return [self.operand]

    def canonicalize(self) -> "Expr":
        """Canonicalize the operand while carrying ``ord`` over unchanged."""
        return Norm(self.operand.canonicalize(), ord=self.ord)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the operand's shape; the norm itself is always scalar."""
        # Called for its validation side effect only; the result is unused.
        self.operand.check_shape()
        return ()

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed this Norm, including its ``ord``, into a structural hash.

        The order parameter is hashed so that norms differing only in
        ``ord`` do not collide.

        Args:
            hasher: A hashlib hash object to update
        """
        hasher.update(b"Norm")
        hasher.update(repr(self.ord).encode())
        self.operand._hash_into(hasher)

    def __repr__(self):
        return "norm(" + repr(self.operand) + ", ord=" + repr(self.ord) + ")"
|