openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79):
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,632 @@
1
+ """Array manipulation operations for symbolic expressions.
2
+
3
+ This module provides operations for indexing, slicing, concatenating, and stacking
4
+ symbolic expressions. These are structural operations that manipulate array shapes
5
+ and combine or extract array elements, as opposed to mathematical transformations.
6
+
7
+ Key Operations:
8
+
9
+ - **Indexing and Slicing:**
10
+ - `Index` - NumPy-style indexing and slicing to extract subarrays
11
+
12
+ - **Concatenation:**
13
+ - `Concat` - Concatenate expressions along the first dimension (axis 0)
14
+
15
+ - **Stacking:**
16
+ - `Stack` - Stack expressions along a new first dimension
17
+ - `Hstack` - Horizontal stacking (along columns for 2D arrays)
18
+ - `Vstack` - Vertical stacking (along rows for 2D arrays)
19
+
20
+ - **Block Matrix Construction:**
21
+ - `Block` - Assemble block matrices from nested arrays (like numpy.block)
22
+
23
+ All operations follow NumPy conventions for shapes and indexing behavior, enabling
24
+ familiar array manipulation patterns in symbolic optimization problems.
25
+
26
+ Example:
27
+ Indexing and slicing arrays::
28
+
29
+ import openscvx as ox
30
+
31
+ x = ox.State("x", shape=(10,))
32
+ first_half = x[0:5] # Slice: Index(x, slice(0, 5))
33
+ element = x[3] # Single element: Index(x, 3)
34
+
35
+ A = ox.State("A", shape=(5, 4))
36
+ row = A[2, :] # Extract row
37
+ col = A[:, 1] # Extract column
38
+
39
+ Concatenating expressions::
40
+
41
+ from openscvx.symbolic.expr.array import Concat
42
+
43
+ x = ox.State("x", shape=(3,))
44
+ y = ox.State("y", shape=(4,))
45
+ combined = Concat(x, y) # Result shape (7,)
46
+
47
+ Stacking to build matrices::
48
+
49
+ from openscvx.symbolic.expr.array import Stack, Hstack, Vstack
50
+
51
+ # Stack vectors into a matrix
52
+ v1 = ox.State("v1", shape=(3,))
53
+ v2 = ox.State("v2", shape=(3,))
54
+ v3 = ox.State("v3", shape=(3,))
55
+ matrix = Stack([v1, v2, v3]) # Result shape (3, 3)
56
+
57
+ # Horizontal stacking (concatenate along columns)
58
+ A = ox.State("A", shape=(3, 4))
59
+ B = ox.State("B", shape=(3, 2))
60
+ wide = Hstack([A, B]) # Result shape (3, 6)
61
+
62
+ # Vertical stacking (concatenate along rows)
63
+ C = ox.State("C", shape=(2, 4))
64
+ tall = Vstack([A, C]) # Result shape (5, 4)
65
+
66
+ Building rotation matrices with Block (recommended)::
67
+
68
+ import openscvx as ox
69
+ from openscvx.symbolic.expr.array import Block
70
+
71
+ theta = ox.Variable("theta", shape=(1,))
72
+ R = Block([
73
+ [ox.Cos(theta), -ox.Sin(theta)],
74
+ [ox.Sin(theta), ox.Cos(theta)]
75
+ ]) # 2D rotation matrix, shape (2, 2)
76
+
77
+ Building rotation matrices with stacking (alternative)::
78
+
79
+ import openscvx as ox
80
+ from openscvx.symbolic.expr.array import Stack, Hstack
81
+
82
+ theta = ox.Variable("theta", shape=(1,))
83
+ R = Stack([
84
+ Hstack([ox.Cos(theta), -ox.Sin(theta)]),
85
+ Hstack([ox.Sin(theta), ox.Cos(theta)])
86
+ ]) # 2D rotation matrix, shape (2, 2)
87
+ """
88
+
89
+ import hashlib
90
+ from typing import Tuple, Union
91
+
92
+ import numpy as np
93
+
94
+ from .expr import Expr, to_expr
95
+
96
+
97
class Index(Expr):
    """Symbolic node for NumPy-style indexing or slicing of an expression.

    Holds the expression being indexed together with the index object applied
    to it. The resulting shape is whatever NumPy would produce when the same
    index is applied to an array of the base expression's shape.

    Attributes:
        base: Expression being indexed
        index: Index object applied to ``base`` (int, slice, or tuple of them)

    Example:
        Create Index nodes with bracket notation:

            x = ox.State("x", shape=(10,))
            y = x[0:5]  # Index(x, slice(0, 5))
            z = x[3]    # Index(x, 3)
    """

    def __init__(self, base: Expr, index: Union[int, slice, tuple]):
        """Store the indexed expression and the index applied to it.

        Args:
            base: Expression being indexed
            index: NumPy-style index (int, slice, or tuple of indices/slices)
        """
        self.base = base
        self.index = index

    def children(self):
        """Return the single child expression (the indexed base)."""
        return [self.base]

    def canonicalize(self) -> "Expr":
        """Return a new Index over the canonicalized base, index unchanged.

        Returns:
            Expr: Canonical form of the indexing expression
        """
        return Index(self.base.canonicalize(), self.index)

    def check_shape(self) -> Tuple[int, ...]:
        """Return the shape produced by applying the index to the base shape.

        Raises:
            ValueError: If NumPy rejects the index for the base shape
        """
        parent_shape = self.base.check_shape()
        # Let NumPy do the shape arithmetic by indexing a placeholder array.
        probe = np.zeros(parent_shape)
        try:
            selected = probe[self.index]
        except Exception as err:
            raise ValueError(f"Bad index {self.index} for shape {parent_shape}") from err
        else:
            return selected.shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed the node tag, the index repr, and then the base into the hasher.

        Args:
            hasher: A hashlib hash object to update
        """
        # The index is hashed through its repr so any index object is accepted.
        for chunk in (b"Index", repr(self.index).encode()):
            hasher.update(chunk)
        self.base._hash_into(hasher)

    def __repr__(self):
        return "{!r}[{!r}]".format(self.base, self.index)
161
+
162
+
163
class Concat(Expr):
    """Concatenation operation for symbolic expressions.

    Concatenates a sequence of expressions along their first dimension. All inputs
    must have the same rank and matching dimensions except for the first dimension.
    Scalar (0-D) operands are treated as length-1 vectors.

    Attributes:
        exprs: List of expressions to concatenate

    Example:
        Define a Concat expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(4,))
            z = Concat(x, y)  # Creates Concat(x, y), result shape (7,)
    """

    def __init__(self, *exprs: Expr):
        """Initialize a concatenation operation.

        Args:
            *exprs: Expressions to concatenate along the first dimension.
                Each operand is passed through ``to_expr``.
        """
        self.exprs = [to_expr(e) for e in exprs]

    def children(self):
        """Return a new list holding the operands in order."""
        return list(self.exprs)

    def canonicalize(self) -> "Expr":
        """Canonicalize concatenation by canonicalizing all operands.

        Returns:
            Expr: Canonical form of the concatenation expression
        """
        return Concat(*(e.canonicalize() for e in self.exprs))

    def check_shape(self) -> Tuple[int, ...]:
        """Check concatenation shape compatibility and return the result shape.

        Returns:
            Tuple whose first entry is the sum of the operands' first
            dimensions, followed by the shared trailing dimensions.

        Raises:
            ValueError: If there are no operands, if the operands differ in
                rank, or if any dimension other than the first differs.
        """
        # Fail with the same exception type as Stack/Hstack/Vstack instead of
        # an IndexError from shapes[0] below.
        if not self.exprs:
            raise ValueError("Concat requires at least one expression")

        shapes = [e.check_shape() for e in self.exprs]
        # Scalars take part in the concatenation as length-1 vectors.
        shapes = [(1,) if len(s) == 0 else s for s in shapes]

        first = shapes[0]
        if any(len(s) != len(first) for s in shapes):
            raise ValueError(f"Concat rank mismatch: {shapes}")
        if any(s[1:] != first[1:] for s in shapes[1:]):
            raise ValueError(f"Concat non-0 dims differ: {shapes}")
        return (sum(s[0] for s in shapes),) + first[1:]

    def __repr__(self):
        inner = ", ".join(repr(e) for e in self.exprs)
        return f"Concat({inner})"
215
+
216
+
217
class Stack(Expr):
    """Stack equally-shaped expressions along a new leading dimension.

    Every input must report the same shape. The result has shape
    ``(num_rows, *row_shape)``, comparable to ``numpy.array([row1, row2, ...])``
    or ``jax.numpy.stack(rows, axis=0)``.

    Attributes:
        rows: List of expressions to stack, one per entry of the new axis

    Example:
        Combine three vectors into a matrix:

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(3,))
            z = Variable("z", shape=(3,))
            stacked = Stack([x, y, z])  # shape (3, 3)
            # Equivalent to: [[x[0], x[1], x[2]],
            #                 [y[0], y[1], y[2]],
            #                 [z[0], z[1], z[2]]]
    """

    def __init__(self, rows):
        """Store the expressions to stack.

        Args:
            rows: Iterable of expressions (or raw values, which are passed
                through ``to_expr``) to stack along a new first dimension.
        """
        self.rows = [to_expr(entry) for entry in rows]

    def children(self):
        """Return the stacked expressions in order."""
        return self.rows

    def canonicalize(self) -> "Expr":
        """Return a Stack built from the canonicalized rows."""
        return Stack([entry.canonicalize() for entry in self.rows])

    def check_shape(self) -> Tuple[int, ...]:
        """Return ``(num_rows, *row_shape)`` after checking all rows agree.

        Raises:
            ValueError: If there are no rows or any row's shape differs from
                the shape of row 0.
        """
        if not self.rows:
            raise ValueError("Stack requires at least one row")

        # Shapes are collected for every row before any comparison happens.
        reference, *remaining = [entry.check_shape() for entry in self.rows]

        for position, shape in enumerate(remaining, start=1):
            if shape == reference:
                continue
            raise ValueError(
                f"Stack row {position} has shape {shape}, but row 0 has shape {reference}"
            )

        return (len(self.rows),) + reference

    def __repr__(self):
        return "Stack([" + ", ".join(map(repr, self.rows)) + "])"
279
+
280
+
281
class Hstack(Expr):
    """Horizontal stacking operation for symbolic expressions.

    Concatenates expressions horizontally (along columns for 2D arrays).
    This is analogous to numpy.hstack() or jax.numpy.hstack().

    Behavior depends on input dimensionality:
    - 0D scalars: Treated as length-1 vectors, then handled as 1D arrays
    - 1D arrays: Concatenates along axis 0 (making a longer vector)
    - 2D arrays: Concatenates along axis 1 (columns), rows must match
    - Higher-D: Concatenates along axis 1, all other dimensions must match

    Attributes:
        arrays: List of expressions to stack horizontally

    Example:
        1D case: concatenate vectors:

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(2,))
            h = Hstack([x, y])  # Result shape (5,)

        2D case: concatenate matrices horizontally:

            A = Variable("A", shape=(3, 4))
            B = Variable("B", shape=(3, 2))
            C = Hstack([A, B])  # Result shape (3, 6)
    """

    def __init__(self, arrays):
        """Initialize a horizontal stack operation.

        Args:
            arrays: List of expressions to concatenate horizontally. Each
                entry is passed through ``to_expr``.
        """
        self.arrays = [to_expr(arr) for arr in arrays]

    def children(self):
        """Return the stacked expressions in order."""
        return self.arrays

    def canonicalize(self) -> "Expr":
        """Return an Hstack built from the canonicalized operands."""
        arrays = [arr.canonicalize() for arr in self.arrays]
        return Hstack(arrays)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate operand shapes and return the horizontally stacked shape.

        Returns:
            ``(total_length,)`` for 0D/1D operands; otherwise the first
            operand's shape with axis 1 replaced by the sum of all operands'
            axis-1 sizes.

        Raises:
            ValueError: If there are no operands, if operands differ in
                number of dimensions, or if any dimension other than axis 1
                differs (for 2D+ operands).
        """
        if not self.arrays:
            raise ValueError("Hstack requires at least one array")

        # Promote 0-D operands to length-1 vectors (as Concat does and as
        # numpy.hstack does via atleast_1d); otherwise shape[0] below would
        # raise IndexError on an empty shape tuple.
        array_shapes = [
            (1,) if len(shape) == 0 else shape
            for shape in (arr.check_shape() for arr in self.arrays)
        ]

        # All arrays must have the same number of dimensions
        first_shape = array_shapes[0]
        first_ndim = len(first_shape)
        for i, shape in enumerate(array_shapes[1:], 1):
            if len(shape) != first_ndim:
                raise ValueError(
                    f"Hstack array {i} has {len(shape)} dimensions, but array 0 has {first_ndim}"
                )

        # For 1D arrays, hstack concatenates along axis 0
        if first_ndim == 1:
            return (sum(shape[0] for shape in array_shapes),)

        # For 2D+ arrays, all dimensions except the second must match
        for i, shape in enumerate(array_shapes[1:], 1):
            if shape[0] != first_shape[0]:
                raise ValueError(
                    f"Hstack array {i} has {shape[0]} rows, but array 0 has {first_shape[0]} rows"
                )
            if shape[2:] != first_shape[2:]:
                raise ValueError(
                    f"Hstack array {i} has trailing dimensions {shape[2:]}, "
                    f"but array 0 has {first_shape[2:]}"
                )

        # Result shape: concatenate along axis 1 (columns)
        total_cols = sum(shape[1] for shape in array_shapes)
        return (first_shape[0], total_cols) + first_shape[2:]

    def __repr__(self):
        arrays_repr = ", ".join(repr(arr) for arr in self.arrays)
        return f"Hstack([{arrays_repr}])"
364
+
365
+
366
class Vstack(Expr):
    """Vertical stacking operation for symbolic expressions.

    Concatenates expressions vertically (along rows for 2D arrays).
    This is analogous to numpy.vstack() or jax.numpy.vstack().

    All input expressions must have the same number of dimensions, and all
    dimensions except the first must match. Inputs with fewer than two
    dimensions are first promoted to 2D the way numpy.vstack() does it
    (a scalar becomes ``(1, 1)``, a vector of length ``n`` becomes ``(1, n)``),
    and the result concatenates along axis 0 (rows).

    Attributes:
        arrays: List of expressions to stack vertically

    Example:
        Stack vectors to create a matrix:

            x = Variable("x", shape=(3,))
            y = Variable("y", shape=(3,))
            v = Vstack([x, y])  # Result shape (2, 3)

        Stack matrices vertically:

            A = Variable("A", shape=(3, 4))
            B = Variable("B", shape=(2, 4))
            C = Vstack([A, B])  # Result shape (5, 4)
    """

    def __init__(self, arrays):
        """Initialize a vertical stack operation.

        Args:
            arrays: List of expressions to concatenate vertically.
                All must have matching dimensions except the first. Each
                entry is passed through ``to_expr``.
        """
        self.arrays = [to_expr(arr) for arr in arrays]

    def children(self):
        """Return the stacked expressions in order."""
        return self.arrays

    def canonicalize(self) -> "Expr":
        """Return a Vstack built from the canonicalized operands."""
        arrays = [arr.canonicalize() for arr in self.arrays]
        return Vstack(arrays)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate operand shapes and return the vertically stacked shape.

        Returns:
            The (promoted) first operand's shape with axis 0 replaced by the
            sum of all operands' axis-0 sizes. The result always has at least
            two dimensions.

        Raises:
            ValueError: If there are no operands, if operands differ in
                number of dimensions, or if any dimension other than the
                first differs.
        """
        if not self.arrays:
            raise ValueError("Vstack requires at least one array")

        raw_shapes = [arr.check_shape() for arr in self.arrays]

        # All arrays must have the same number of dimensions
        first_ndim = len(raw_shapes[0])
        for i, shape in enumerate(raw_shapes[1:], 1):
            if len(shape) != first_ndim:
                raise ValueError(
                    f"Vstack array {i} has {len(shape)} dimensions, but array 0 has {first_ndim}"
                )

        # Promote scalars and vectors to row matrices so that stacking k
        # vectors of length n yields (k, n) as documented, not (k * n,).
        # NOTE(review): assumes the lowerers follow numpy.vstack semantics - confirm.
        array_shapes = [(1,) * (2 - len(shape)) + shape for shape in raw_shapes]

        # All dimensions except the first must match
        first_shape = array_shapes[0]
        for i, shape in enumerate(array_shapes[1:], 1):
            if shape[1:] != first_shape[1:]:
                raise ValueError(
                    f"Vstack array {i} has trailing dimensions {shape[1:]}, "
                    f"but array 0 has {first_shape[1:]}"
                )

        # Result shape: concatenate along axis 0 (rows)
        total_rows = sum(shape[0] for shape in array_shapes)
        return (total_rows,) + first_shape[1:]

    def __repr__(self):
        arrays_repr = ", ".join(repr(arr) for arr in self.arrays)
        return f"Vstack([{arrays_repr}])"
440
+
441
+
442
class Block(Expr):
    """Block matrix/tensor construction from nested arrays of expressions.

    Assembles a block matrix (or N-D tensor) from a nested list of expressions,
    analogous to numpy.block(). Each inner list represents a row of blocks, and
    blocks within the same row are concatenated horizontally, while rows are
    stacked vertically.

    This provides a convenient way to construct matrices from sub-expressions
    without manually nesting Stack/Hstack/Vstack operations.

    Attributes:
        blocks: Nested list of expressions forming the block structure (each
            expression can be a scalar, 1D, 2D, or N-D tensor)

    Example:
        Build a 2D rotation matrix::

            import openscvx as ox
            from openscvx.symbolic.expr.array import Block

            theta = ox.Variable("theta", shape=(1,))
            R = Block([
                [ox.Cos(theta), -ox.Sin(theta)],
                [ox.Sin(theta), ox.Cos(theta)]
            ])  # Result shape (2, 2)

        Build a block diagonal matrix::

            A = ox.State("A", shape=(2, 2))
            B = ox.State("B", shape=(3, 3))
            zeros_23 = ox.Constant(np.zeros((2, 3)))
            zeros_32 = ox.Constant(np.zeros((3, 2)))
            block_diag = Block([
                [A, zeros_23],
                [zeros_32, B]
            ])  # Result shape (5, 5)

        Build from scalars and expressions::

            x = ox.State("x", shape=(1,))
            y = ox.State("y", shape=(1,))
            # Scalars are automatically promoted to 1D arrays
            M = Block([
                [x, 0],
                [0, y]
            ])  # Result shape (2, 2)

    Note:
        - All blocks in the same row must have the same height (first dimension)
        - All blocks in the same column must have the same width (second dimension)
        - For N-D tensors (3D+), all trailing dimensions must match across all blocks
        - Scalar values and raw Python lists are automatically wrapped via to_expr()
        - 1D arrays are treated as row vectors when determining block dimensions
        - N-D tensors are supported for JAX lowering; CVXPy only supports 2D blocks
    """

    def __init__(self, blocks):
        """Initialize a block matrix construction.

        Args:
            blocks: A nested list of expressions. Can be either:
                - 2D: [[row1_blocks], [row2_blocks], ...] for multiple rows
                - 1D: [block1, block2, ...] for a single row (auto-promoted to [[...]])
                Raw values (numbers, lists, numpy arrays) are automatically
                converted to Constant expressions.

        Raises:
            ValueError: If blocks is empty, if rows contain different numbers
                of blocks, or if the rows contain no blocks at all
        """
        if not blocks:
            raise ValueError("Block requires at least one row")

        # Auto-promote 1D list to 2D (matching numpy.block behavior)
        # e.g., Block([a, b]) -> Block([[a, b]])
        if not isinstance(blocks[0], (list, tuple)):
            blocks = [blocks]

        # Convert all blocks to expressions
        self.blocks = [[to_expr(block) for block in row] for row in blocks]

        # Validate consistent row lengths
        row_lengths = [len(row) for row in self.blocks]
        if len(set(row_lengths)) > 1:
            raise ValueError(
                f"All rows must have the same number of blocks. Got row lengths: {row_lengths}"
            )

        # Reject e.g. Block([[]]) here; otherwise check_shape() fails later
        # with an opaque "max() arg is an empty sequence" error.
        if row_lengths[0] == 0:
            raise ValueError("Block rows must contain at least one block")

    def children(self):
        """Return all block expressions in row-major order."""
        return [block for row in self.blocks for block in row]

    def canonicalize(self) -> "Expr":
        """Canonicalize by recursively canonicalizing all blocks.

        If the block contains only a single element ([[a]]), returns the
        canonicalized element directly to simplify the expression tree.
        """
        canonical_blocks = [[block.canonicalize() for block in row] for row in self.blocks]

        # Unwrap single-element blocks
        # NOTE(review): check_shape() pads blocks to at least 2D, so for a
        # 0-D/1-D element the unwrapped expression reports a lower rank than
        # this Block did - confirm that downstream code expects this.
        if len(canonical_blocks) == 1 and len(canonical_blocks[0]) == 1:
            return canonical_blocks[0][0]

        return Block(canonical_blocks)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate block dimensions and compute output shape.

        For 2D blocks, returns (total_rows, total_cols). For N-D blocks,
        returns the shape after assembling blocks along the first two axes,
        with trailing dimensions preserved.

        Returns:
            Tuple representing the assembled block array shape

        Raises:
            ValueError: If block dimensions are incompatible
        """
        n_block_rows = len(self.blocks)
        n_block_cols = len(self.blocks[0])

        # Get shapes of all blocks
        block_shapes = [[block.check_shape() for block in row] for row in self.blocks]

        # Determine the maximum dimensionality across all blocks
        max_ndim = max(len(shape) for row in block_shapes for shape in row)
        max_ndim = max(max_ndim, 2)  # At least 2D for block assembly

        # Normalize shapes: pad to max_ndim by prepending 1s.
        # With max_ndim == 2: scalars () -> (1, 1) and 1D (n,) -> (1, n).
        def normalize_shape(shape):
            return (1,) * (max_ndim - len(shape)) + shape

        normalized_shapes = [[normalize_shape(shape) for shape in row] for row in block_shapes]

        # Validate trailing dimensions (dims 2+) match across ALL blocks
        trailing_shape = normalized_shapes[0][0][2:]
        if max_ndim > 2:
            for i, row_shapes in enumerate(normalized_shapes):
                for j, shape in enumerate(row_shapes):
                    if shape[2:] != trailing_shape:
                        raise ValueError(
                            f"Block[{i}][{j}] has trailing dimensions {shape[2:]}, "
                            f"but Block[0][0] has {trailing_shape}. "
                            f"All blocks must have matching dimensions beyond the first two."
                        )

        # Compute row heights (first dimension of each row must match)
        row_heights = []
        for i, row_shapes in enumerate(normalized_shapes):
            heights = [s[0] for s in row_shapes]
            if len(set(heights)) > 1:
                raise ValueError(
                    f"Block row {i} has inconsistent heights: {heights}. "
                    f"All blocks in a row must have the same height."
                )
            row_heights.append(heights[0])

        # Compute column widths (second dimension of each column must match)
        col_widths = []
        for j in range(n_block_cols):
            widths = [normalized_shapes[i][j][1] for i in range(n_block_rows)]
            if len(set(widths)) > 1:
                raise ValueError(
                    f"Block column {j} has inconsistent widths: {widths}. "
                    f"All blocks in a column must have the same width."
                )
            col_widths.append(widths[0])

        # trailing_shape is empty for purely 2D assemblies, so this covers
        # both the 2D and the N-D case.
        return (sum(row_heights), sum(col_widths)) + trailing_shape

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Hash the Block including its grid layout.

        children() flattens the grid, so the flattened list alone cannot tell
        a 2x2 grid from a 1x4 grid over the same blocks. The grid dimensions
        are therefore hashed explicitly before the blocks.

        Args:
            hasher: A hashlib hash object to update
        """
        hasher.update(b"Block")
        hasher.update(f"{len(self.blocks)}x{len(self.blocks[0])}".encode())
        for row in self.blocks:
            for block in row:
                block._hash_into(hasher)

    def __repr__(self):
        rows_repr = []
        for row in self.blocks:
            blocks_repr = ", ".join(repr(block) for block in row)
            rows_repr.append(f"[{blocks_repr}]")
        inner = ", ".join(rows_repr)
        return f"Block([{inner}])"