openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,757 @@
1
+ """Validation and preprocessing utilities for symbolic expressions.
2
+
3
+ This module provides preprocessing and validation functions for symbolic expressions
4
+ in trajectory optimization problems. These utilities ensure that expressions are
5
+ well-formed and constraints are properly specified before compilation to solvers.
6
+
7
+ The preprocessing pipeline includes:
8
+ - Shape validation: Ensure all expressions have compatible shapes
9
+ - Variable name validation: Check for unique, non-reserved variable names
10
+ - Constraint validation: Verify constraints appear only at root level
11
+ - Dynamics validation: Check that dynamics match state dimensions
12
+ - Time parameter validation: Validate time configuration
13
+ - Slice assignment: Assign contiguous memory slices to variables
14
+
15
+ These functions are typically called automatically during problem construction,
16
+ but can also be used manually for debugging or custom problem setups.
17
+
18
+ Example:
19
+ Validating expressions before problem construction::
20
+
21
+ import openscvx as ox
22
+
23
+ x = ox.State("x", shape=(3,))
24
+ u = ox.Control("u", shape=(2,))
25
+
26
+ # Build dynamics and constraints
27
+ dynamics = {
28
+ "x": u # Will fail validation - dimension mismatch!
29
+ }
30
+
31
+ # Validate dimensions before creating problem
32
+ from openscvx.symbolic.preprocessing import validate_dynamics_dict_dimensions
33
+
34
+ try:
35
+ validate_dynamics_dict_dimensions(dynamics, [x])
36
+ except ValueError as e:
37
+ print(f"Validation error: {e}")
38
+ """
39
+
40
+ from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
41
+
42
+ if TYPE_CHECKING:
43
+ from openscvx.symbolic.time import Time
44
+
45
+ import numpy as np
46
+
47
+ from openscvx.symbolic.expr import (
48
+ CTCS,
49
+ Concat,
50
+ Constant,
51
+ Constraint,
52
+ Control,
53
+ CrossNodeConstraint,
54
+ Expr,
55
+ NodalConstraint,
56
+ State,
57
+ traverse,
58
+ )
59
+
60
+
61
def validate_shapes(exprs: Union[Expr, list[Expr]]) -> None:
    """Run shape checking on one expression or on every expression in a list.

    Args:
        exprs: A single expression, or a list/tuple of expressions. Any value
            that is not a list or tuple is treated as a single expression.

    Raises:
        ValueError: Propagated from an expression's ``check_shape`` when its
            shapes are inconsistent.
    """
    if not isinstance(exprs, (list, tuple)):
        exprs = [exprs]

    for expr in exprs:
        # check_shape raises on a malformed expression; the returned shape
        # itself is not needed here.
        expr.check_shape()
73
+
74
+
75
+ # TODO: (norrisg) allow `traverse` to take a list of visitors, that way we can combine steps
76
def validate_variable_names(
    exprs: Iterable[Expr],
    *,
    reserved_prefix: str = "_",
    reserved_names: Set[str] = None,
) -> None:
    """Check State/Control names for uniqueness and reserved-name conflicts.

    Every State and Control reachable from ``exprs`` is inspected once per
    distinct object (identity is tracked with ``id``). Each distinct variable
    must satisfy:

    1. Its name is not shared by another, different variable object.
    2. Its name does not start with ``reserved_prefix``.
    3. Its name is not in ``reserved_names``.

    Args:
        exprs: Iterable of expression trees to scan for variables.
        reserved_prefix: Prefix that user variable names may not start with
            (default: "_").
        reserved_names: Optional collection of names that may not be used
            (default: None, meaning no explicit reserved names).

    Raises:
        ValueError: On the first variable that violates any rule above.

    Example:
        x1 = ox.State("x", shape=(3,))
        x2 = ox.State("x", shape=(2,))  # Same name, different object
        validate_variable_names([x1 + x2])  # Raises ValueError: Duplicate name 'x'

        bad = ox.State("_internal", shape=(2,))
        validate_variable_names([bad])  # Raises ValueError: Reserved prefix '_'
    """
    forbidden = set() if not reserved_names else set(reserved_names)
    names_taken = set()
    checked_ids = set()

    def check_variable(node) -> None:
        # Only State and Control nodes carry user-facing variable names.
        if not isinstance(node, (State, Control)):
            return

        key = id(node)
        if key in checked_ids:
            # The same object reached again via another path was already validated.
            return

        label = node.name

        if label in names_taken:
            raise ValueError(f"Duplicate variable name: {label!r}")

        if label.startswith(reserved_prefix):
            raise ValueError(
                f"Variable name {label!r} is reserved (cannot start with {reserved_prefix!r})"
            )

        if label in forbidden:
            raise ValueError(f"Variable name {label!r} collides with reserved name")

        names_taken.add(label)
        checked_ids.add(key)

    for root in exprs:
        traverse(root, check_variable)
139
+
140
+
141
def collect_and_assign_slices(
    states: List[State], controls: List[Control], *, start_index: int = 0
) -> Tuple[list[State], list[Control]]:
    """Assign contiguous slices of the flat decision vector to states and controls.

    States and controls are laid out independently, and both layouts begin at
    ``start_index``. Within each group a variable either:

    - already carries a user-specified ``_slice`` (manual), or
    - has ``_slice is None`` and is auto-assigned the next free range.

    Manual slices are validated first. Each must have length equal to the
    variable's flattened size (``prod(shape)``). Sorted by start, together they
    must tile the range beginning at ``start_index`` with no gaps or overlaps.
    Auto-assigned variables are then packed, in list order, directly after the
    last manual slice (or at ``start_index`` if there are no manual slices).

    Args:
        states: State objects in canonical order. Their ``_slice`` attributes
            are updated in place.
        controls: Control objects in canonical order. Their ``_slice``
            attributes are updated in place.
        start_index: Index at which both layouts begin (default: 0).

    Returns:
        The same ``(states, controls)`` list objects, with slices assigned.

    Raises:
        ValueError: If a manual slice has the wrong length, if the manual slices
            do not begin at ``start_index``, or if they are not contiguous and
            non-overlapping.

    Example:
        x = ox.State("x", shape=(3,))
        u = ox.Control("u", shape=(2,))
        states, controls = collect_and_assign_slices([x], [u])
        print(x._slice)  # slice(0, 3)
        print(u._slice)  # slice(0, 2)
    """

    def assign(vars_list, first_index):
        # Partition into user-specified (manual) and unassigned (auto) variables.
        manual = [v for v in vars_list if v._slice is not None]
        auto = [v for v in vars_list if v._slice is None]

        offset = first_index
        if manual:
            # 1) Each manual slice must span exactly the variable's flattened size.
            for v in manual:
                dim = int(np.prod(v.shape))
                sl = v._slice
                if (sl.stop - sl.start) != dim:
                    raise ValueError(
                        f"Manual slice for {v.name!r} is length {sl.stop - sl.start}, "
                        f"but variable has shape {v.shape} (dim {dim})"
                    )

            # 2) Sorted by start, manual slices must tile [first_index, ...) exactly.
            manual.sort(key=lambda v: v._slice.start)
            if manual[0]._slice.start != first_index:
                # Report the actual required start; it is only 0 by default.
                raise ValueError(f"User-defined slices must start at index {first_index}")
            for v in manual:
                sl = v._slice
                dim = int(np.prod(v.shape))
                if sl.start != offset or sl.stop != offset + dim:
                    raise ValueError(
                        f"Manual slice for {v.name!r} must be contiguous and non-overlapping"
                    )
                offset += dim

        # 3) Pack the remaining variables, in list order, after the manual block.
        for v in auto:
            dim = int(np.prod(v.shape))
            v._slice = slice(offset, offset + dim)
            offset += dim

    # States (x) and controls (u) are laid out independently; both start at start_index.
    assign(states, start_index)
    assign(controls, start_index)

    return states, controls
221
+
222
+
223
+ def _traverse_with_depth(expr: Expr, visit: Callable[[Expr, int], None], depth: int = 0):
224
+ """Depth-first traversal of an expression tree with depth tracking.
225
+
226
+ Internal helper function that extends the standard traverse function to track
227
+ the depth of each node in the tree. Used for constraint validation.
228
+
229
+ Args:
230
+ expr: Root expression node to start traversal from
231
+ visit: Callback function applied to each (node, depth) pair during traversal
232
+ depth: Current depth level (default: 0)
233
+ """
234
+ visit(expr, depth)
235
+ for child in expr.children():
236
+ _traverse_with_depth(child, visit, depth + 1)
237
+
238
+
239
def validate_constraints_at_root(exprs: Union[Expr, list[Expr]]):
    """Ensure constraints and constraint wrappers only appear as root expressions.

    A ``Constraint`` or a constraint wrapper (``CTCS``, ``NodalConstraint``,
    ``CrossNodeConstraint``) may not be nested below the root of an expression
    tree. The subtree beneath a root-level wrapper is not inspected, so the
    constraint held inside a wrapper (e.g. the ``x <= 5`` in ``CTCS(x <= 5)``)
    is allowed.

    Args:
        exprs: A single expression, or a list/tuple of expressions. Each is
            checked as its own root.

    Raises:
        ValueError: On the first (pre-order) nested wrapper or constraint found
            at depth > 0.

    Example:
        x = ox.State("x", shape=(3,))
        constraint = x <= 5
        validate_constraints_at_root([constraint])  # OK - constraint at root

        bad_expr = ox.Sum(x <= 5)  # Constraint nested inside Sum
        validate_constraints_at_root([bad_expr])  # Raises ValueError
    """
    # Wrapper types that, like bare constraints, may only appear as roots.
    wrapper_types = (CTCS, NodalConstraint, CrossNodeConstraint)

    def check_node(node: Expr, depth: int) -> None:
        if depth > 0 and isinstance(node, wrapper_types):
            raise ValueError(
                f"Nested constraint wrapper found at depth {depth!r}: {node!r}; "
                "constraint wrappers must only appear as top-level roots"
            )
        if depth > 0 and isinstance(node, Constraint):
            raise ValueError(
                f"Nested Constraint found at depth {depth!r}: {node!r}; "
                "constraints must only appear as top-level roots"
            )

        # A (root-level) wrapper may legitimately hold a constraint, so its
        # subtree is not descended into.
        if isinstance(node, wrapper_types):
            return

        for child in node.children():
            check_node(child, depth + 1)

    roots = exprs if isinstance(exprs, (list, tuple)) else [exprs]
    for root in roots:
        check_node(root, 0)
297
+
298
+
299
def validate_and_normalize_constraint_nodes(exprs: Union[Expr, list[Expr]], n_nodes: int):
    """Range-check constraint node specs and fill in default CTCS intervals.

    Only root-level ``CTCS`` and ``NodalConstraint`` objects are examined; any
    other expression is skipped.

    For ``CTCS``:
        - ``nodes is None`` is replaced in place with ``(0, n_nodes)``, i.e. the
          whole trajectory.
        - Otherwise the interval is rejected if its start is ``>= n_nodes`` or
          its end is ``> n_nodes``. Tuple shape and ordering (start < end) are
          expected to have been validated by the constructor, not here.

    For ``NodalConstraint``:
        - Every entry of ``nodes`` must lie in ``[0, n_nodes)``.

    Args:
        exprs: A single expression, or a list/tuple of expressions.
        n_nodes: Total number of nodes in the trajectory.

    Raises:
        ValueError: If a CTCS interval or a nodal index is out of range.

    Example:
        x = ox.State("x", shape=(3,))
        constraint = (x <= 5).at([0, 10, 20])  # NodalConstraint
        validate_and_normalize_constraint_nodes([constraint], n_nodes=50)  # OK

        ctcs_constraint = (x <= 5).over((0, 100))  # CTCS
        validate_and_normalize_constraint_nodes([ctcs_constraint], n_nodes=50)
        # Raises ValueError: Range exceeds trajectory length
    """
    roots = exprs if isinstance(exprs, (list, tuple)) else [exprs]

    for item in roots:
        if isinstance(item, CTCS):
            if item.nodes is None:
                # No interval given: apply over the entire trajectory.
                item.nodes = (0, n_nodes)
            elif item.nodes[0] >= n_nodes or item.nodes[1] > n_nodes:
                raise ValueError(
                    f"CTCS node range {item.nodes} exceeds trajectory length {n_nodes}"
                )
            continue

        if not isinstance(item, NodalConstraint):
            continue

        # Index format is checked at construction; only the trajectory range is checked here.
        for idx in item.nodes:
            if idx < 0 or idx >= n_nodes:
                raise ValueError(f"NodalConstraint node {idx} is out of range [0, {n_nodes})")
350
+
351
+
352
def validate_cross_node_constraint(cross_node_constraint, n_nodes: int) -> None:
    """Check node-index bounds and ``.at()`` consistency of a cross-node constraint.

    A single walk over both sides of the wrapped constraint collects:

    - every ``NodeReference`` (a variable accessed via ``.at(k)``), whose
      children are not descended into, and
    - every bare ``State``/``Control`` not wrapped in a ``NodeReference``.

    Two checks then run on the collected data:

    1. **Bounds**: each referenced index, after mapping a negative ``k`` to
       ``n_nodes + k``, must lie in ``[0, n_nodes)``.
    2. **Consistency**: if any ``NodeReference`` is present, no bare
       State/Control may appear. Mixing the two forms causes shape mismatches
       during lowering:
       - Variables with .at(k) extract single-node values: X[k, :] -> shape (n_x,)
       - Variables without .at() expect full trajectory: X[:, :] -> shape (N, n_x)

    Args:
        cross_node_constraint: The ``CrossNodeConstraint`` to validate.
        n_nodes: Total number of trajectory nodes.

    Raises:
        TypeError: If ``cross_node_constraint`` is not a ``CrossNodeConstraint``.
        ValueError: If a referenced node index is out of bounds.
        ValueError: If the constraint mixes ``.at()`` and bare variables.

    Example:
        Valid cross-node constraint:

            from openscvx.symbolic.expr import CrossNodeConstraint

            position = State("pos", shape=(3,))

            # Valid: all variables use .at(), indices in bounds
            constraint = CrossNodeConstraint(position.at(5) - position.at(4) <= 0.1)
            validate_cross_node_constraint(constraint, n_nodes=10)  # OK

        Invalid - out of bounds:

            # Invalid: node 10 is out of bounds for n_nodes=10
            bad_bounds = CrossNodeConstraint(position.at(0) == position.at(10))
            validate_cross_node_constraint(bad_bounds, n_nodes=10)  # Raises ValueError

        Invalid - mixed .at() usage:

            velocity = State("vel", shape=(3,))
            # Invalid: position uses .at(), velocity doesn't
            bad_mixed = CrossNodeConstraint(position.at(5) - velocity <= 0.1)
            validate_cross_node_constraint(bad_mixed, n_nodes=10)  # Raises ValueError
    """
    from openscvx.symbolic.expr import Control, CrossNodeConstraint, NodeReference, State

    if not isinstance(cross_node_constraint, CrossNodeConstraint):
        raise TypeError(
            f"Expected CrossNodeConstraint, got {type(cross_node_constraint).__name__}. "
            f"Bare constraints with NodeReferences should be wrapped in CrossNodeConstraint "
            f"by separate_constraints() before validation."
        )

    inner = cross_node_constraint.constraint

    refs = []  # (index as written, index after negative-index normalization)
    bare_names = []  # names of State/Control nodes used without .at()

    # Named `collect` rather than `traverse` so it does not shadow the
    # module-level `traverse` import.
    def collect(node) -> None:
        if isinstance(node, NodeReference):
            raw = node.node_idx
            refs.append((raw, raw if raw >= 0 else n_nodes + raw))
            # The wrapped variable is accounted for by the reference itself.
            return
        if isinstance(node, (State, Control)):
            bare_names.append(node.name)
            return
        for child in node.children():
            collect(child)

    collect(inner.lhs)
    collect(inner.rhs)

    for raw, norm in refs:
        if norm < 0 or norm >= n_nodes:
            raise ValueError(
                f"Cross-node constraint references invalid node index {raw}. "
                f"Node indices must be in range [0, {n_nodes}) "
                f"(or negative indices in range [-{n_nodes}, -1]). "
                f"Constraint: {inner}"
            )

    if refs and bare_names:
        raise ValueError(
            f"Cross-node constraint contains NodeReferences (variables with .at(k)) "
            f"but also has variables without .at(): {bare_names}. "
            f"All state/control variables in cross-node constraints must use .at(k). "
            f"For example, if you use 'position.at(5)', you must also use 'velocity.at(4)' "
            f"instead of just 'velocity'. "
            f"Constraint: {inner}"
        )
457
+
458
+
459
def validate_dynamics_dimension(
    dynamics_expr: Union[Expr, list[Expr]], states: Union[State, list[State]]
) -> None:
    """Check that the dynamics' combined length equals the total state size.

    Every dynamics expression must have a 1-D shape (as reported by
    ``check_shape``). The sum of their lengths must equal the sum of
    ``prod(state.shape)`` over all states, so that x_dot = f(x, u, t) is
    well-formed.

    Args:
        dynamics_expr: One dynamics expression, or a list/tuple of them, that
            together describe x_dot for all states.
        states: One State, or a list/tuple of States.

    Raises:
        ValueError: If any dynamics expression is not 1-D, or if the combined
            dynamics length differs from the total state dimension.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(2,))
        dynamics = ox.Concat(x * 2, y + 1)  # Shape (5,) - matches total state dim
        validate_dynamics_dimension(dynamics, [x, y])  # OK

        bad_dynamics = x  # Shape (3,) - doesn't match total dim of 5
        validate_dynamics_dimension(bad_dynamics, [x, y])  # Raises ValueError
    """
    exprs_seq = dynamics_expr if isinstance(dynamics_expr, (list, tuple)) else [dynamics_expr]
    state_seq = states if isinstance(states, (list, tuple)) else [states]

    expected_dim = sum(int(np.prod(s.shape)) for s in state_seq)

    # Length of each dynamics expression, recorded once and reused in the error message.
    per_expr_dims = []
    for idx, expr in enumerate(exprs_seq):
        shape = expr.check_shape()
        if len(shape) != 1:
            label = "Dynamics expression" if len(exprs_seq) <= 1 else f"Dynamics expression {idx}"
            raise ValueError(f"{label} must be 1-dimensional (vector), but got shape {shape}")
        per_expr_dims.append(shape[0])

    actual_dim = sum(per_expr_dims)
    if actual_dim == expected_dim:
        return

    states_desc = [(s.name, s.shape) for s in state_seq]
    if len(exprs_seq) == 1:
        raise ValueError(
            f"Dynamics dimension mismatch: dynamics has dimension {actual_dim}, "
            f"but total state dimension is {expected_dim}. "
            f"States: {states_desc}"
        )
    raise ValueError(
        f"Dynamics dimension mismatch: {len(exprs_seq)} dynamics expressions "
        f"have combined dimension {actual_dim} {per_expr_dims}, "
        f"but total state dimension is {expected_dim}. "
        f"States: {states_desc}"
    )
526
+
527
+
528
def validate_dynamics_dict(
    dynamics: Dict[str, Expr],
    states: List[State],
    byof_dynamics: Optional[Dict[str, callable]] = None,
) -> None:
    """Check that symbolic and byof dynamics together cover exactly the states.

    The keys of ``dynamics`` and of ``byof_dynamics`` (if given) must be
    disjoint. Their union must equal the set of state names: no state may be
    missing and no extra key may be present.

    Args:
        dynamics: Mapping from state name to symbolic dynamics expression.
        states: State objects whose ``name`` attributes define the expected keys.
        byof_dynamics: Optional mapping from state name to a raw function
            ("bring your own function" dynamics). States listed here must not
            also appear in ``dynamics``.

    Raises:
        ValueError: If a state appears in both mappings, or if the combined
            keys do not match the state names exactly.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(2,))
        dynamics = {"x": x * 2, "y": y + 1}
        validate_dynamics_dict(dynamics, [x, y])  # OK

        bad_dynamics = {"x": x * 2}  # Missing "y"
        validate_dynamics_dict(bad_dynamics, [x, y])  # Raises ValueError

        # With byof_dynamics (expert user mode)
        dynamics = {"x": x * 2}  # Only symbolic for x
        byof_dynamics = {"y": some_jax_fn}  # Raw JAX for y
        validate_dynamics_dict(dynamics, [x, y], byof_dynamics)  # OK
    """
    expected = {s.name for s in states}
    symbolic = set(dynamics)
    byof = set(byof_dynamics) if byof_dynamics else set()

    # A state may be driven by symbolic or byof dynamics, never both.
    doubly_defined = symbolic & byof
    if doubly_defined:
        raise ValueError(
            f"States defined in both symbolic and byof dynamics: {doubly_defined}\n"
            "Each state must have dynamics in exactly one place."
        )

    provided = symbolic | byof
    missing = expected - provided
    extra = provided - expected
    if not missing and not extra:
        return

    parts = ["Mismatch between state names and dynamics keys.\n"]
    if missing:
        parts.append(f" States missing from dynamics: {missing}\n")
    if extra:
        parts.append(f" Extra keys in dynamics: {extra}\n")
    raise ValueError("".join(parts))
587
+
588
+
589
def validate_dynamics_dict_dimensions(dynamics: Dict[str, Expr], states: List[State]) -> None:
    """Check each state's dynamics expression against that state's shape.

    For every state, ``dynamics[state.name]`` is looked up and its shape is
    compared with ``state.shape``. A plain Python ``int``/``float`` is treated
    as having shape ``()``; anything else is asked for its shape via
    ``check_shape()``. Before comparison, a scalar shape ``()`` on either side
    is normalized to ``(1,)``, matching Concat behavior.

    Args:
        dynamics: Mapping from state name to its dynamics expression (or a
            plain int/float). Every state name in ``states`` is looked up
            directly, so a missing key raises ``KeyError``.
        states: State objects to check.

    Raises:
        ValueError: If a dynamics shape does not match its state's shape.

    Example:
        x = ox.State("x", shape=(3,))
        y = ox.State("y", shape=(2,))
        u = ox.Control("u", shape=(3,))
        dynamics = {"x": u, "y": y + 1}
        validate_dynamics_dict_dimensions(dynamics, [x, y])  # OK

        bad_dynamics = {"x": u, "y": u}  # y dynamics has wrong shape
        validate_dynamics_dict_dimensions(bad_dynamics, [x, y])  # Raises ValueError
    """

    def as_vector_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Map the scalar shape () to (1,); leave other shapes untouched."""
        return shape if len(shape) else (1,)

    for state in states:
        expr = dynamics[state.name]
        target = state.shape

        # Raw Python numbers have no check_shape(); they are scalars.
        got = () if isinstance(expr, (int, float)) else expr.check_shape()

        if as_vector_shape(got) != as_vector_shape(target):
            raise ValueError(
                f"Dynamics for state '{state.name}' has shape {got}, "
                f"but state has shape {target}"
            )
637
+
638
+
639
def validate_time_parameters(
    states: List[State],
    time: "Time",
) -> Tuple[
    bool,
    Union[float, tuple, None],
    Union[float, tuple, None],
    float,
    Union[float, None],
    Union[float, None],
]:
    """Validate the Time configuration and extract time parameters.

    Two ways of handling time are supported:

    1. Auto-created time (recommended): no state is named ``"time"``. The
       values are read from the ``Time`` object and a time state is expected
       to be created automatically downstream.
    2. User-provided time (advanced): some state is named ``"time"``. The
       ``Time`` object is still required to be a ``Time`` instance but its
       values are ignored; the user controls the time dynamics.

    Args:
        states: State objects to scan for a state named ``"time"``.
        time: Time configuration object. Always type-checked; its values are
            only used when no ``"time"`` state exists.

    Returns:
        Tuple ``(has_time_state, time_initial, time_final, time_derivative,
        time_min, time_max)``:
            - has_time_state: True if a state named ``"time"`` was found.
            - time_initial: ``time.initial`` (None if user-provided time).
            - time_final: ``time.final`` (None if user-provided time).
            - time_derivative: 1.0 for auto-created time (None if user-provided).
            - time_min: ``time.min`` (None if user-provided time).
            - time_max: ``time.max`` (None if user-provided time).

    Raises:
        ValueError: If ``time`` is not a ``Time`` instance (note: raised as a
            ValueError, not a TypeError).

    Example:
        # Approach 1: Auto-create time
        x = ox.State("x", shape=(3,))
        time_obj = ox.Time(initial=0.0, final=10.0)
        validate_time_parameters([x], time_obj)
        # (False, 0.0, 10.0, 1.0, time_obj.min, time_obj.max)

        # Approach 2: User-provided time
        x = ox.State("x", shape=(3,))
        time_state = ox.State("time", shape=())
        validate_time_parameters([x, time_state], time_obj)
        # (True, None, None, None, None, None)
    """
    from openscvx.symbolic.time import Time

    if not isinstance(time, Time):
        raise ValueError(f"Expected Time object, but got {type(time).__name__}")

    if any(s.name == "time" for s in states):
        # The user owns the time state; nothing is taken from the Time object.
        return True, None, None, None, None, None

    # Auto-created time always advances at unit rate.
    return False, time.initial, time.final, 1.0, time.min, time.max
711
+
712
+
713
def convert_dynamics_dict_to_expr(
    dynamics: Dict[str, Expr], states: List[State]
) -> Tuple[Dict[str, Expr], Expr]:
    """Build a single concatenated dynamics expression in canonical state order.

    Plain Python ``int``/``float`` values in ``dynamics`` are wrapped in
    ``Constant``. The (possibly converted) expressions are then concatenated
    in the order given by ``states`` to form x_dot = f(x, u, t). The input
    dictionary is not mutated; a new dict is returned.

    Args:
        dynamics: Mapping from state name to dynamics expression (or a plain
            int/float). Every state name in ``states`` must be a key.
        states: State objects defining the canonical concatenation order.

    Returns:
        Tuple of:
            - A new dict with the same keys, in the same order, with scalars
              replaced by ``Constant`` expressions.
            - ``Concat`` of the dynamics expressions ordered by ``states``.

    Example:
        Convert dynamics dict to a single expression:

            x = ox.State("x", shape=(3,))
            y = ox.State("y", shape=(1,))
            dynamics_dict = {"x": x * 2, "y": 1.0}  # Scalar for the 1-element state y
            converted_dict, concat_expr = convert_dynamics_dict_to_expr(
                dynamics_dict, [x, y]
            )
            # converted_dict["y"] is now Constant(1.0)
            # concat_expr is Concat(x * 2, Constant(1.0))
    """
    converted = {
        name: Constant(value) if isinstance(value, (int, float)) else value
        for name, value in dynamics.items()
    }

    ordered = [converted[s.name] for s in states]
    return converted, Concat(*ordered)