openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,632 @@
|
|
|
1
|
+
"""Array manipulation operations for symbolic expressions.
|
|
2
|
+
|
|
3
|
+
This module provides operations for indexing, slicing, concatenating, and stacking
|
|
4
|
+
symbolic expressions. These are structural operations that manipulate array shapes
|
|
5
|
+
and combine or extract array elements, as opposed to mathematical transformations.
|
|
6
|
+
|
|
7
|
+
Key Operations:
|
|
8
|
+
|
|
9
|
+
- **Indexing and Slicing:**
|
|
10
|
+
- `Index` - NumPy-style indexing and slicing to extract subarrays
|
|
11
|
+
|
|
12
|
+
- **Concatenation:**
|
|
13
|
+
- `Concat` - Concatenate expressions along the first dimension (axis 0)
|
|
14
|
+
|
|
15
|
+
- **Stacking:**
|
|
16
|
+
- `Stack` - Stack expressions along a new first dimension
|
|
17
|
+
- `Hstack` - Horizontal stacking (along columns for 2D arrays)
|
|
18
|
+
- `Vstack` - Vertical stacking (along rows for 2D arrays)
|
|
19
|
+
|
|
20
|
+
- **Block Matrix Construction:**
|
|
21
|
+
- `Block` - Assemble block matrices from nested arrays (like numpy.block)
|
|
22
|
+
|
|
23
|
+
All operations follow NumPy conventions for shapes and indexing behavior, enabling
|
|
24
|
+
familiar array manipulation patterns in symbolic optimization problems.
|
|
25
|
+
|
|
26
|
+
Example:
|
|
27
|
+
Indexing and slicing arrays::
|
|
28
|
+
|
|
29
|
+
import openscvx as ox
|
|
30
|
+
|
|
31
|
+
x = ox.State("x", shape=(10,))
|
|
32
|
+
first_half = x[0:5] # Slice: Index(x, slice(0, 5))
|
|
33
|
+
element = x[3] # Single element: Index(x, 3)
|
|
34
|
+
|
|
35
|
+
A = ox.State("A", shape=(5, 4))
|
|
36
|
+
row = A[2, :] # Extract row
|
|
37
|
+
col = A[:, 1] # Extract column
|
|
38
|
+
|
|
39
|
+
Concatenating expressions::
|
|
40
|
+
|
|
41
|
+
from openscvx.symbolic.expr.array import Concat
|
|
42
|
+
|
|
43
|
+
x = ox.State("x", shape=(3,))
|
|
44
|
+
y = ox.State("y", shape=(4,))
|
|
45
|
+
combined = Concat(x, y) # Result shape (7,)
|
|
46
|
+
|
|
47
|
+
Stacking to build matrices::
|
|
48
|
+
|
|
49
|
+
from openscvx.symbolic.expr.array import Stack, Hstack, Vstack
|
|
50
|
+
|
|
51
|
+
# Stack vectors into a matrix
|
|
52
|
+
v1 = ox.State("v1", shape=(3,))
|
|
53
|
+
v2 = ox.State("v2", shape=(3,))
|
|
54
|
+
v3 = ox.State("v3", shape=(3,))
|
|
55
|
+
matrix = Stack([v1, v2, v3]) # Result shape (3, 3)
|
|
56
|
+
|
|
57
|
+
# Horizontal stacking (concatenate along columns)
|
|
58
|
+
A = ox.State("A", shape=(3, 4))
|
|
59
|
+
B = ox.State("B", shape=(3, 2))
|
|
60
|
+
wide = Hstack([A, B]) # Result shape (3, 6)
|
|
61
|
+
|
|
62
|
+
# Vertical stacking (concatenate along rows)
|
|
63
|
+
C = ox.State("C", shape=(2, 4))
|
|
64
|
+
tall = Vstack([A, C]) # Result shape (5, 4)
|
|
65
|
+
|
|
66
|
+
Building rotation matrices with Block (recommended)::
|
|
67
|
+
|
|
68
|
+
import openscvx as ox
|
|
69
|
+
from openscvx.symbolic.expr.array import Block
|
|
70
|
+
|
|
71
|
+
theta = ox.Variable("theta", shape=(1,))
|
|
72
|
+
R = Block([
|
|
73
|
+
[ox.Cos(theta), -ox.Sin(theta)],
|
|
74
|
+
[ox.Sin(theta), ox.Cos(theta)]
|
|
75
|
+
]) # 2D rotation matrix, shape (2, 2)
|
|
76
|
+
|
|
77
|
+
Building rotation matrices with stacking (alternative)::
|
|
78
|
+
|
|
79
|
+
import openscvx as ox
|
|
80
|
+
from openscvx.symbolic.expr.array import Stack, Hstack
|
|
81
|
+
|
|
82
|
+
theta = ox.Variable("theta", shape=(1,))
|
|
83
|
+
R = Stack([
|
|
84
|
+
Hstack([ox.Cos(theta), -ox.Sin(theta)]),
|
|
85
|
+
Hstack([ox.Sin(theta), ox.Cos(theta)])
|
|
86
|
+
]) # 2D rotation matrix, shape (2, 2)
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
import hashlib
|
|
90
|
+
from typing import Tuple, Union
|
|
91
|
+
|
|
92
|
+
import numpy as np
|
|
93
|
+
|
|
94
|
+
from .expr import Expr, to_expr
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class Index(Expr):
    """Indexing and slicing operation for symbolic expressions.

    Represents indexing or slicing of an expression using NumPy-style indexing.
    Can be created using square bracket notation on Expr objects.

    Attributes:
        base: Expression to index into
        index: Index specification (int, slice, or tuple of indices/slices)

    Example:
        Define an Index expression::

            x = ox.State("x", shape=(10,))
            y = x[0:5]  # Creates Index(x, slice(0, 5))
            z = x[3]  # Creates Index(x, 3)
    """

    def __init__(self, base: Expr, index: Union[int, slice, tuple]):
        """Initialize an indexing operation.

        Args:
            base: Expression to index into
            index: NumPy-style index (int, slice, or tuple of indices/slices)
        """
        self.base = base
        self.index = index

    def children(self):
        """Return the single child of this node: the indexed base expression."""
        return [self.base]

    def canonicalize(self) -> "Expr":
        """Canonicalize index by canonicalizing the base expression.

        Returns:
            Expr: Canonical form of the indexing expression
        """
        base = self.base.canonicalize()
        return Index(base, self.index)

    def check_shape(self) -> Tuple[int, ...]:
        """Compute the shape after indexing.

        NumPy's own indexing rules are applied to a placeholder array with the
        base expression's shape, so the result shape is exactly what NumPy
        would produce for the same index.

        Returns:
            Tuple[int, ...]: Shape of the indexed result.

        Raises:
            ValueError: If ``index`` is not valid for the base shape.
        """
        base_shape = self.base.check_shape()
        # A 0-d array broadcast to ``base_shape`` is a read-only view whose
        # strides are all zero: it has the right shape for NumPy's indexing
        # rules without allocating memory proportional to the base size.
        dummy = np.broadcast_to(np.zeros(()), base_shape)
        try:
            result = dummy[self.index]
        except Exception as e:
            raise ValueError(f"Bad index {self.index} for shape {base_shape}") from e
        return result.shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Hash Index including its index specification.

        Args:
            hasher: A hashlib hash object to update
        """
        import sys

        hasher.update(b"Index")
        # Hash the index specification via its repr. NumPy summarizes arrays
        # with more than ``threshold`` elements ("[0, 1, 2, ..., 998, 999]"),
        # which would let different large fancy-index arrays hash identically.
        # Disable summarization so the full contents always contribute; reprs
        # of small indices are unchanged, so their hashes stay the same.
        with np.printoptions(threshold=sys.maxsize):
            index_repr = repr(self.index)
        hasher.update(index_repr.encode())
        # Hash the base expression
        self.base._hash_into(hasher)

    def __repr__(self):
        return f"{self.base!r}[{self.index!r}]"
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class Concat(Expr):
    """Concatenation operation for symbolic expressions.

    Concatenates a sequence of expressions along their first dimension. All inputs
    must have the same rank and matching dimensions except for the first dimension.
    Scalar (0-d) operands are treated as length-1 vectors.

    Attributes:
        exprs: List of expressions to concatenate

    Example:
        Define a Concat expression::

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(4,))
            z = Concat(x, y)  # Creates Concat(x, y), result shape (7,)
    """

    def __init__(self, *exprs: Expr):
        """Initialize a concatenation operation.

        Args:
            *exprs: Expressions to concatenate along the first dimension. Raw
                (non-Expr) values are wrapped via ``to_expr``.
        """
        self.exprs = [to_expr(e) for e in exprs]

    def children(self):
        """Return a copy of the operand list."""
        return list(self.exprs)

    def canonicalize(self) -> "Expr":
        """Canonicalize concatenation by canonicalizing all operands.

        Returns:
            Expr: Canonical form of the concatenation expression
        """
        exprs = [e.canonicalize() for e in self.exprs]
        return Concat(*exprs)

    def check_shape(self) -> Tuple[int, ...]:
        """Check concatenation shape compatibility and return result shape.

        Returns:
            Tuple[int, ...]: ``(sum of first dims, *shared trailing dims)``.

        Raises:
            ValueError: If there are no operands, operand ranks differ, or any
                dimension other than the first differs between operands.
        """
        # Without this guard an empty Concat crashed with IndexError on shapes[0].
        if not self.exprs:
            raise ValueError("Concat requires at least one expression")

        shapes = [e.check_shape() for e in self.exprs]
        # Scalars participate as length-1 vectors.
        shapes = [(1,) if len(s) == 0 else s for s in shapes]
        rank = len(shapes[0])
        if any(len(s) != rank for s in shapes):
            raise ValueError(f"Concat rank mismatch: {shapes}")
        if any(s[1:] != shapes[0][1:] for s in shapes[1:]):
            raise ValueError(f"Concat non-0 dims differ: {shapes}")
        return (sum(s[0] for s in shapes),) + shapes[0][1:]

    def __repr__(self):
        inner = ", ".join(repr(e) for e in self.exprs)
        return f"Concat({inner})"
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class Stack(Expr):
    """Stack expressions along a new leading axis.

    Every element of ``rows`` must share a single shape ``S``; the stacked
    expression then has shape ``(len(rows), *S)``. This mirrors
    ``numpy.array([row1, row2, ...])`` or ``jax.numpy.stack(rows, axis=0)``.

    Attributes:
        rows: Expressions being stacked, one per entry of the new leading axis.

    Example:
        Combine three vectors into a matrix::

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(3,))
            z = Variable("z", shape=(3,))
            stacked = Stack([x, y, z])  # shape (3, 3)
            # Row i of ``stacked`` is the i-th input:
            # [[x[0], x[1], x[2]],
            #  [y[0], y[1], y[2]],
            #  [z[0], z[1], z[2]]]
    """

    def __init__(self, rows):
        """Create a stack node.

        Args:
            rows: Iterable of expressions to stack along a new first axis.
                Raw (non-Expr) values are wrapped via ``to_expr``. All must
                have the same shape.
        """
        self.rows = list(map(to_expr, rows))

    def children(self):
        """Return the stored list of stacked expressions (not a copy)."""
        return self.rows

    def canonicalize(self) -> "Expr":
        """Return a new Stack whose rows have each been canonicalized."""
        canonical_rows = [r.canonicalize() for r in self.rows]
        return Stack(canonical_rows)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate that every row has the same shape and return the stacked shape.

        Returns:
            Tuple[int, ...]: ``(number_of_rows, *row_shape)``.

        Raises:
            ValueError: If there are no rows, or a row's shape differs from row 0's.
        """
        if len(self.rows) == 0:
            raise ValueError("Stack requires at least one row")

        reference, *remaining = [r.check_shape() for r in self.rows]
        for position, candidate in enumerate(remaining, start=1):
            if candidate != reference:
                raise ValueError(
                    f"Stack row {position} has shape {candidate}, but row 0 has shape {reference}"
                )

        return (len(self.rows),) + reference

    def __repr__(self):
        return "Stack([" + ", ".join(map(repr, self.rows)) + "])"
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class Hstack(Expr):
    """Horizontal stacking operation for symbolic expressions.

    Concatenates expressions horizontally (along columns for 2D arrays).
    This is analogous to numpy.hstack() or jax.numpy.hstack().

    Behavior depends on input dimensionality:
    - Scalars: promoted to length-1 vectors first (like numpy.hstack)
    - 1D arrays: Concatenates along axis 0 (making a longer vector)
    - 2D arrays: Concatenates along axis 1 (columns), rows must match
    - Higher-D: Concatenates along axis 1, all other dimensions must match

    Attributes:
        arrays: List of expressions to stack horizontally

    Example:
        1D case: concatenate vectors::

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(2,))
            h = Hstack([x, y])  # Result shape (5,)

        2D case: concatenate matrices horizontally::

            A = Variable("A", shape=(3, 4))
            B = Variable("B", shape=(3, 2))
            C = Hstack([A, B])  # Result shape (3, 6)
    """

    def __init__(self, arrays):
        """Initialize a horizontal stack operation.

        Args:
            arrays: List of expressions to concatenate horizontally. Raw
                (non-Expr) values are wrapped via ``to_expr``.
        """
        self.arrays = [to_expr(arr) for arr in arrays]

    def children(self):
        """Return the stored list of stacked expressions (not a copy)."""
        return self.arrays

    def canonicalize(self) -> "Expr":
        """Return a new Hstack whose operands have each been canonicalized."""
        arrays = [arr.canonicalize() for arr in self.arrays]
        return Hstack(arrays)

    def check_shape(self) -> Tuple[int, ...]:
        """Horizontal stack concatenates arrays along the second axis (columns).

        Returns:
            Tuple[int, ...]: ``(total_length,)`` for 1D (or scalar) inputs,
            otherwise ``(rows, total_cols, *trailing_dims)``.

        Raises:
            ValueError: If there are no operands, ranks differ, or the
                non-concatenated dimensions do not match.
        """
        if not self.arrays:
            raise ValueError("Hstack requires at least one array")

        # Promote scalars to length-1 vectors, as numpy.hstack does via
        # atleast_1d. Without this, an all-scalar Hstack had ndim 0, fell into
        # the 2D branch below and crashed with IndexError on shape[0].
        array_shapes = [arr.check_shape() for arr in self.arrays]
        array_shapes = [(1,) if len(s) == 0 else s for s in array_shapes]

        # All arrays must have the same number of dimensions
        first_ndim = len(array_shapes[0])
        for i, shape in enumerate(array_shapes[1:], 1):
            if len(shape) != first_ndim:
                raise ValueError(
                    f"Hstack array {i} has {len(shape)} dimensions, but array 0 has {first_ndim}"
                )

        # For 1D arrays, hstack concatenates along axis 0
        if first_ndim == 1:
            total_length = sum(shape[0] for shape in array_shapes)
            return (total_length,)

        # For 2D+ arrays, all dimensions except the second must match
        first_shape = array_shapes[0]
        for i, shape in enumerate(array_shapes[1:], 1):
            if shape[0] != first_shape[0]:
                raise ValueError(
                    f"Hstack array {i} has {shape[0]} rows, but array 0 has {first_shape[0]} rows"
                )
            if shape[2:] != first_shape[2:]:
                raise ValueError(
                    f"Hstack array {i} has trailing dimensions {shape[2:]}, "
                    f"but array 0 has {first_shape[2:]}"
                )

        # Result shape: concatenate along axis 1 (columns)
        total_cols = sum(shape[1] for shape in array_shapes)
        return (first_shape[0], total_cols) + first_shape[2:]

    def __repr__(self):
        arrays_repr = ", ".join(repr(arr) for arr in self.arrays)
        return f"Hstack([{arrays_repr}])"
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
class Vstack(Expr):
    """Vertical stacking operation for symbolic expressions.

    Concatenates expressions vertically (along rows for 2D arrays).
    This is analogous to numpy.vstack() or jax.numpy.vstack().

    As with numpy.vstack, inputs are first promoted to at least 2D: scalars
    become (1, 1) and 1D vectors of length n become (1, n) row vectors. After
    promotion, all inputs must have the same number of dimensions, and all
    dimensions except the first must match. The result concatenates along
    axis 0 (rows).

    Attributes:
        arrays: List of expressions to stack vertically

    Example:
        Stack vectors to create a matrix::

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(3,))
            v = Vstack([x, y])  # Result shape (2, 3)

        Stack matrices vertically::

            A = Variable("A", shape=(3, 4))
            B = Variable("B", shape=(2, 4))
            C = Vstack([A, B])  # Result shape (5, 4)
    """

    def __init__(self, arrays):
        """Initialize a vertical stack operation.

        Args:
            arrays: List of expressions to concatenate vertically.
                All must have matching dimensions except the first (after
                promotion to at least 2D). Raw (non-Expr) values are wrapped
                via ``to_expr``.
        """
        self.arrays = [to_expr(arr) for arr in arrays]

    def children(self):
        """Return the stored list of stacked expressions (not a copy)."""
        return self.arrays

    def canonicalize(self) -> "Expr":
        """Return a new Vstack whose operands have each been canonicalized."""
        arrays = [arr.canonicalize() for arr in self.arrays]
        return Vstack(arrays)

    @staticmethod
    def _atleast_2d_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return ``shape`` promoted the way ``numpy.atleast_2d`` promotes arrays."""
        if len(shape) == 0:
            return (1, 1)
        if len(shape) == 1:
            return (1, shape[0])
        return tuple(shape)

    def check_shape(self) -> Tuple[int, ...]:
        """Vertical stack concatenates arrays along the first axis (rows).

        Returns:
            Tuple[int, ...]: ``(total_rows, *shared_trailing_dims)`` computed on
            the promoted (at least 2D) operand shapes.

        Raises:
            ValueError: If there are no operands, promoted ranks differ, or any
                dimension other than the first differs.
        """
        if not self.arrays:
            raise ValueError("Vstack requires at least one array")

        # Promote like numpy.atleast_2d. Without this, 1D inputs were summed
        # into one long vector ((3,), (3,) -> (6,)) instead of the (2, 3)
        # matrix numpy.vstack produces and this class documents.
        array_shapes = [self._atleast_2d_shape(arr.check_shape()) for arr in self.arrays]

        # All arrays must have the same number of dimensions
        first_ndim = len(array_shapes[0])
        for i, shape in enumerate(array_shapes[1:], 1):
            if len(shape) != first_ndim:
                raise ValueError(
                    f"Vstack array {i} has {len(shape)} dimensions, but array 0 has {first_ndim}"
                )

        # All dimensions except the first must match
        first_shape = array_shapes[0]
        for i, shape in enumerate(array_shapes[1:], 1):
            if shape[1:] != first_shape[1:]:
                raise ValueError(
                    f"Vstack array {i} has trailing dimensions {shape[1:]}, "
                    f"but array 0 has {first_shape[1:]}"
                )

        # Result shape: concatenate along axis 0 (rows)
        total_rows = sum(shape[0] for shape in array_shapes)
        return (total_rows,) + first_shape[1:]

    def __repr__(self):
        arrays_repr = ", ".join(repr(arr) for arr in self.arrays)
        return f"Vstack([{arrays_repr}])"
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class Block(Expr):
    """Block matrix/tensor construction from nested arrays of expressions.

    Assembles a block matrix (or N-D tensor) from a nested list of expressions,
    analogous to numpy.block(). Each inner list represents a row of blocks, and
    blocks within the same row are concatenated horizontally, while rows are
    stacked vertically.

    This provides a convenient way to construct matrices from sub-expressions
    without manually nesting Stack/Hstack/Vstack operations.

    Attributes:
        blocks: Nested list of expressions forming the block structure (each
            expression can be a scalar, 1D, 2D, or N-D tensor)

    Example:
        Build a 2D rotation matrix::

            import openscvx as ox
            from openscvx.symbolic.expr.array import Block

            theta = ox.Variable("theta", shape=(1,))
            R = Block([
                [ox.Cos(theta), -ox.Sin(theta)],
                [ox.Sin(theta), ox.Cos(theta)]
            ])  # Result shape (2, 2)

        Build a block diagonal matrix::

            A = ox.State("A", shape=(2, 2))
            B = ox.State("B", shape=(3, 3))
            zeros_23 = ox.Constant(np.zeros((2, 3)))
            zeros_32 = ox.Constant(np.zeros((3, 2)))
            block_diag = Block([
                [A, zeros_23],
                [zeros_32, B]
            ])  # Result shape (5, 5)

        Build from scalars and expressions::

            x = ox.State("x", shape=(1,))
            y = ox.State("y", shape=(1,))
            # Scalars are automatically promoted to 1D arrays
            M = Block([
                [x, 0],
                [0, y]
            ])  # Result shape (2, 2)

    Note:
        - All blocks in the same row must have the same height (first dimension)
        - All blocks in the same column must have the same width (second dimension)
        - For N-D tensors (3D+), all trailing dimensions must match across all blocks
        - Scalar values and raw Python lists are automatically wrapped via to_expr()
        - 1D arrays are treated as row vectors when determining block dimensions
        - N-D tensors are supported for JAX lowering; CVXPy only supports 2D blocks
    """

    def __init__(self, blocks):
        """Initialize a block matrix construction.

        Args:
            blocks: A nested list of expressions. Can be either:
                - 2D: [[row1_blocks], [row2_blocks], ...] for multiple rows
                - 1D: [block1, block2, ...] for a single row (auto-promoted to [[...]])
                Raw values (numbers, lists, numpy arrays) are automatically
                converted to Constant expressions.

        Raises:
            ValueError: If blocks is empty, rows have differing numbers of
                blocks, or rows contain no blocks.
        """
        if not blocks:
            raise ValueError("Block requires at least one row")

        # Auto-promote 1D list to 2D (matching numpy.block behavior)
        # e.g., Block([a, b]) -> Block([[a, b]])
        if not isinstance(blocks[0], (list, tuple)):
            blocks = [blocks]

        # Convert all blocks to expressions
        self.blocks = [[to_expr(block) for block in row] for row in self.blocks_source(blocks)]

        # Validate consistent row lengths
        row_lengths = [len(row) for row in self.blocks]
        if len(set(row_lengths)) > 1:
            raise ValueError(
                f"All rows must have the same number of blocks. Got row lengths: {row_lengths}"
            )
        # Rows are now known to be equal length; reject empty rows up front.
        # Otherwise e.g. Block([[]]) constructs fine and later fails inside
        # check_shape's max() with an unrelated "empty sequence" error.
        if row_lengths[0] == 0:
            raise ValueError("Block rows must contain at least one block")

    @staticmethod
    def blocks_source(blocks):
        """Return ``blocks`` unchanged (the already row-promoted nested list)."""
        return blocks

    def children(self):
        """Return all block expressions in row-major order."""
        return [block for row in self.blocks for block in row]

    def canonicalize(self) -> "Expr":
        """Canonicalize by recursively canonicalizing all blocks.

        If the block contains only a single element ([[a]]) that is already at
        least 2D, returns the canonicalized element directly to simplify the
        expression tree. Lower-rank single elements stay wrapped: check_shape
        promotes a scalar to (1, 1) and a length-n vector to (1, n), so
        unwrapping them would silently change the expression's shape.
        """
        canonical_blocks = [[block.canonicalize() for block in row] for row in self.blocks]

        # Unwrap single-element blocks only when doing so preserves the shape
        if len(canonical_blocks) == 1 and len(canonical_blocks[0]) == 1:
            only = canonical_blocks[0][0]
            if len(only.check_shape()) >= 2:
                return only

        return Block(canonical_blocks)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate block dimensions and compute output shape.

        For 2D blocks, returns (total_rows, total_cols). For N-D blocks,
        returns the shape after assembling blocks along the first two axes,
        with trailing dimensions preserved.

        Returns:
            Tuple representing the assembled block array shape

        Raises:
            ValueError: If block dimensions are incompatible
        """
        n_block_rows = len(self.blocks)
        n_block_cols = len(self.blocks[0])

        # Get shapes of all blocks
        block_shapes = [[block.check_shape() for block in row] for row in self.blocks]

        # Determine the maximum dimensionality across all blocks
        max_ndim = max(len(shape) for row in block_shapes for shape in row)
        max_ndim = max(max_ndim, 2)  # At least 2D for block assembly

        # Normalize shapes: pad to max_ndim by prepending 1s
        # Scalars () -> (1, 1, ...), 1D (n,) -> (1, n, ...), etc.
        def normalize_shape(shape):
            if len(shape) == 0:
                return (1,) * max_ndim
            elif len(shape) < max_ndim:
                # Prepend 1s to match max_ndim
                return (1,) * (max_ndim - len(shape)) + shape
            else:
                return shape

        normalized_shapes = [[normalize_shape(shape) for shape in row] for row in block_shapes]

        # Validate trailing dimensions (dims 2+) match across ALL blocks
        if max_ndim > 2:
            trailing_shape = normalized_shapes[0][0][2:]
            for i, row_shapes in enumerate(normalized_shapes):
                for j, shape in enumerate(row_shapes):
                    if shape[2:] != trailing_shape:
                        raise ValueError(
                            f"Block[{i}][{j}] has trailing dimensions {shape[2:]}, "
                            f"but Block[0][0] has {trailing_shape}. "
                            "All blocks must have matching dimensions beyond the first two."
                        )

        # Compute row heights (first dimension of each row must match)
        row_heights = []
        for i, row_shapes in enumerate(normalized_shapes):
            heights = [s[0] for s in row_shapes]
            if len(set(heights)) > 1:
                raise ValueError(
                    f"Block row {i} has inconsistent heights: {heights}. "
                    "All blocks in a row must have the same height."
                )
            row_heights.append(heights[0])

        # Compute column widths (second dimension of each column must match)
        col_widths = []
        for j in range(n_block_cols):
            widths = [normalized_shapes[i][j][1] for i in range(n_block_rows)]
            if len(set(widths)) > 1:
                raise ValueError(
                    f"Block column {j} has inconsistent widths: {widths}. "
                    "All blocks in a column must have the same width."
                )
            col_widths.append(widths[0])

        total_rows = sum(row_heights)
        total_cols = sum(col_widths)

        # Return shape with trailing dimensions if present
        if max_ndim > 2:
            return (total_rows, total_cols) + normalized_shapes[0][0][2:]
        return (total_rows, total_cols)

    def __repr__(self):
        rows_repr = []
        for row in self.blocks:
            blocks_repr = ", ".join(repr(block) for block in row)
            rows_repr.append(f"[{blocks_repr}]")
        inner = ", ".join(rows_repr)
        return f"Block([{inner}])"
|