openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,492 @@
1
+ """Symbolic problem preprocessing and augmentation pipeline.
2
+
3
+ This module provides the main preprocessing pipeline for trajectory optimization problems,
4
+ transforming user-specified symbolic dynamics and constraints into an augmented form
5
+ ready for compilation to executable code.
6
+
7
+ The preprocessing pipeline is purely symbolic - no code generation occurs here. Instead,
8
+ it performs validation, canonicalization, and augmentation to prepare the problem for
9
+ efficient numerical solution.
10
+
11
+ Key functionality:
12
+ - Problem validation: Check shapes, variable names, constraint placement
13
+ - Time handling: Auto-create time state or validate user-provided time
14
+ - Canonicalization: Simplify expressions algebraically
15
+ - Parameter collection: Extract parameter values from expressions
16
+ - Constraint separation: Categorize constraints by type (CTCS, nodal, convex)
17
+ - CTCS augmentation: Add augmented states and time dilation for path constraints
18
+ - Propagation dynamics: Optionally extend dynamics for post-solution propagation
19
+
20
+ The preprocessing pipeline is purely symbolic - no code generation occurs here.
21
+
22
+ Pipeline stages:
23
+ 1. Time handling & validation
24
+ 2. Expression validation (shapes, names, constraint structure)
25
+ 3. Canonicalization & parameter collection
26
+ 4. Constraint separation & CTCS augmentation
27
+ 5. Propagation dynamics creation
28
+
29
+ See `preprocess_symbolic_problem()` for the main entry point.
30
+ """
31
+
32
+ from typing import Dict, List, Optional, Tuple
33
+
34
+ import numpy as np
35
+
36
+ from openscvx.symbolic.augmentation import (
37
+ augment_dynamics_with_ctcs,
38
+ augment_with_time_state,
39
+ decompose_vector_nodal_constraints,
40
+ separate_constraints,
41
+ sort_ctcs_constraints,
42
+ )
43
+ from openscvx.symbolic.constraint_set import ConstraintSet
44
+ from openscvx.symbolic.expr import Constant, Parameter, traverse
45
+ from openscvx.symbolic.expr.control import Control
46
+ from openscvx.symbolic.expr.state import State
47
+ from openscvx.symbolic.preprocessing import (
48
+ collect_and_assign_slices,
49
+ convert_dynamics_dict_to_expr,
50
+ validate_and_normalize_constraint_nodes,
51
+ validate_constraints_at_root,
52
+ validate_dynamics_dict,
53
+ validate_dynamics_dict_dimensions,
54
+ validate_dynamics_dimension,
55
+ validate_shapes,
56
+ validate_time_parameters,
57
+ validate_variable_names,
58
+ )
59
+ from openscvx.symbolic.problem import SymbolicProblem
60
+ from openscvx.symbolic.time import Time
61
+
62
+
63
def preprocess_symbolic_problem(
    dynamics: dict,
    constraints: ConstraintSet,
    states: List[State],
    controls: List[Control],
    N: int,
    time: Time,
    licq_min: float = 0.0,
    licq_max: float = 1e-4,
    time_dilation_factor_min: float = 0.3,
    time_dilation_factor_max: float = 3.0,
    dynamics_prop_extra: Optional[dict] = None,
    states_prop_extra: Optional[List[State]] = None,
    byof: Optional[dict] = None,
) -> SymbolicProblem:
    """Preprocess and augment a symbolic trajectory optimization problem.

    This is the main preprocessing pipeline that transforms a user-specified symbolic
    problem into an augmented form ready for compilation. It performs validation,
    canonicalization, constraint separation, and CTCS augmentation in a series of
    well-defined phases.

    The function is purely symbolic - no code generation or compilation occurs. The
    output is a SymbolicProblem dataclass that can be lowered to JAX or CVXPy by
    downstream compilation functions.

    Pipeline phases:
        1. Time handling & validation: Auto-create or validate time state
        2. Expression validation: Validate shapes, names, constraints
        3. Canonicalization & parameter collection: Simplify and extract parameters
        4. Constraint separation & augmentation: Sort constraints and add CTCS states
        5. Propagation dynamics creation: Optionally add extra states for simulation

    Side effects:
        The ``constraints`` object passed in may be modified in place: its
        ``unsorted`` list is replaced by canonicalized expressions and then drained
        into the categorized lists by ``separate_constraints``. The input
        ``dynamics`` dict is copied and is not mutated.

    Args:
        dynamics: Dictionary mapping state names to dynamics expressions.
            Example: {"x": v, "v": u}
        constraints: ConstraintSet with raw constraints in `unsorted` field.
            Create with: ConstraintSet(unsorted=[c1, c2, c3])
        states: List of user-defined State objects (should NOT include time or CTCS states)
        controls: List of user-defined Control objects (should NOT include time dilation)
        N: Number of discretization nodes in the trajectory
        time: Time configuration object specifying time bounds and constraints
        licq_min: Minimum bound for CTCS augmented states (default: 0.0)
        licq_max: Maximum bound for CTCS augmented states (default: 1e-4)
        time_dilation_factor_min: Minimum factor for time dilation control (default: 0.3)
        time_dilation_factor_max: Maximum factor for time dilation control (default: 3.0)
        dynamics_prop_extra: Optional dictionary of additional dynamics for
            propagation-only states. Must be given together with
            ``states_prop_extra`` (default: None).
        states_prop_extra: Optional list of additional State objects for
            propagation only. Must be given together with
            ``dynamics_prop_extra`` (default: None).
        byof: Optional dict of raw JAX functions for expert users. If byof contains
            a "dynamics" key, it should map state names to raw JAX functions with
            signature f(x, u, node, params) -> xdot_component. States in byof["dynamics"]
            should NOT appear in the symbolic dynamics dict.

    Returns:
        SymbolicProblem dataclass with:
            - dynamics: Augmented dynamics (user + time + CTCS penalties)
            - states: Augmented states (user + time + CTCS augmented)
            - controls: Augmented controls (user + time dilation)
            - constraints: ConstraintSet with is_categorized=True
            - parameters: Dict of extracted parameter values
            - N: Number of discretization nodes (passed through)
            - node_intervals: List of (start, end) tuples for CTCS intervals
            - dynamics_prop: Propagation dynamics
            - states_prop: Propagation states
            - controls_prop: Propagation controls

    Raises:
        ValueError: If exactly one of ``dynamics_prop_extra`` / ``states_prop_extra``
            is provided (non-empty) while the other is None, or if validation fails
            at any later stage.

    Example:
        Basic usage with CTCS constraint::

            import openscvx as ox
            from openscvx.symbolic.constraint_set import ConstraintSet

            x = ox.State("x", shape=(2,))
            v = ox.State("v", shape=(2,))
            u = ox.Control("u", shape=(2,))

            dynamics = {"x": v, "v": u}
            constraints = ConstraintSet(unsorted=[
                (ox.Norm(x) <= 5.0).over((0, 50))
            ])

            problem = preprocess_symbolic_problem(
                dynamics=dynamics,
                constraints=constraints,
                states=[x, v],
                controls=[u],
                N=50,
                time=ox.Time(initial=0.0, final=10.0)
            )

            # problem.states: user states, then time, then CTCS augmented states
            print([s.name for s in problem.states])

        With propagation-only states::

            distance = ox.State("distance", shape=(1,))
            dynamics_extra = {"distance": ox.Norm(v)}

            problem = preprocess_symbolic_problem(
                dynamics=dynamics,
                constraints=constraints,
                states=[x, v],
                controls=[u],
                N=50,
                time=ox.Time(initial=0.0, final=10.0),
                dynamics_prop_extra=dynamics_extra,
                states_prop_extra=[distance]
            )

            # Propagation states include distance for post-solve simulation
            print([s.name for s in problem.states_prop])
    """

    # Reject half-specified propagation extras before touching any input (in
    # particular before `constraints` is mutated below). Supplying only one of
    # the pair used to be silently ignored, dropping the user's extra states.
    if states_prop_extra and dynamics_prop_extra is None:
        raise ValueError(
            "states_prop_extra was given without dynamics_prop_extra; "
            "both must be provided to add propagation-only states."
        )
    if dynamics_prop_extra and states_prop_extra is None:
        raise ValueError(
            "dynamics_prop_extra was given without states_prop_extra; "
            "both must be provided to add propagation-only states."
        )

    # ==================== PHASE 1: Time Handling & Validation ====================

    # Validate time handling approach and get processed parameters
    (
        has_time_state,
        time_initial,
        time_final,
        # NOTE(review): time_derivative is returned but not used; the time rate is
        # hard-coded to 1.0 below. Confirm this is intended.
        time_derivative,
        time_min,
        time_max,
    ) = validate_time_parameters(states, time)

    # Augment states with time state if needed (auto-create approach)
    if not has_time_state:
        states, constraints = augment_with_time_state(
            states,
            constraints,
            time_initial,
            time_final,
            time_min,
            time_max,
            N,
            time_scaling_min=getattr(time, "scaling_min", None),
            time_scaling_max=getattr(time, "scaling_max", None),
        )

    # Add time derivative to dynamics dict (if not already present).
    # Copy first so the caller's dict is never mutated.
    dynamics = dict(dynamics)
    if "time" not in dynamics:
        dynamics["time"] = 1.0

    # Extract byof dynamics for validation
    byof_dynamics = byof.get("dynamics", {}) if byof else {}

    # Validate dynamics dict matches state names;
    # byof_dynamics states must not appear in the symbolic dynamics dict.
    validate_dynamics_dict(dynamics, states, byof_dynamics=byof_dynamics)

    # Inject zero placeholders for byof dynamics states.
    # These are replaced with the actual byof functions at lowering time.
    for state in states:
        if state.name in byof_dynamics:
            dynamics[state.name] = Constant(np.zeros(state.shape))

    # Validate dynamics dimensions AFTER injecting placeholders
    validate_dynamics_dict_dimensions(dynamics, states)

    # Convert dynamics dict to concatenated expression
    dynamics, dynamics_concat = convert_dynamics_dict_to_expr(dynamics, states)

    # ==================== PHASE 2: Expression Validation ====================

    # Validate all expressions (use unsorted constraints)
    all_exprs = [dynamics_concat] + constraints.unsorted
    validate_variable_names(all_exprs)
    collect_and_assign_slices(states, controls)
    validate_shapes(all_exprs)
    validate_constraints_at_root(constraints.unsorted)
    validate_and_normalize_constraint_nodes(constraints.unsorted, N)
    validate_dynamics_dimension(dynamics_concat, states)

    # ==================== PHASE 3: Canonicalization & Parameter Collection ====================

    # Canonicalize all expressions after validation
    dynamics_concat = dynamics_concat.canonicalize()
    constraints.unsorted = [expr.canonicalize() for expr in constraints.unsorted]

    # Collect parameter values from all constraints and dynamics.
    # The first value seen for a given parameter name wins.
    parameters = {}

    def collect_param_values(expr):
        if isinstance(expr, Parameter):
            if expr.name not in parameters:
                parameters[expr.name] = expr.value

    # Collect from dynamics
    traverse(dynamics_concat, collect_param_values)

    # Collect from constraints
    for constraint in constraints.unsorted:
        traverse(constraint, collect_param_values)

    # ==================== PHASE 4: Constraint Separation & Augmentation ====================

    # Sort and separate constraints by type (drains unsorted -> fills categories)
    separate_constraints(constraints, N)

    # Decompose vector-valued nodal constraints into scalar constraints.
    # This is necessary for non-convex nodal constraints that get lowered to JAX.
    constraints.nodal = decompose_vector_nodal_constraints(constraints.nodal)

    # Sort CTCS constraints by their idx to get node_intervals
    constraints.ctcs, node_intervals, _ = sort_ctcs_constraints(constraints.ctcs)

    # Augment dynamics, states, and controls with CTCS constraints, time dilation
    dynamics_aug, states_aug, controls_aug = augment_dynamics_with_ctcs(
        dynamics_concat,
        states,
        controls,
        constraints.ctcs,
        N,
        licq_min=licq_min,
        licq_max=licq_max,
        time_dilation_factor_min=time_dilation_factor_min,
        time_dilation_factor_max=time_dilation_factor_max,
    )

    # Assign slices to augmented states and controls in canonical order
    collect_and_assign_slices(states_aug, controls_aug)

    # ==================== PHASE 5: Create Propagation Dynamics ====================

    # By default, propagation dynamics are the same as optimization dynamics.
    # Use deepcopy to avoid reference issues when lowering.
    from copy import deepcopy

    dynamics_prop = deepcopy(dynamics_aug)
    states_prop = list(states_aug)  # Shallow copy of list is fine for states
    controls_prop = list(controls_aug)

    # If user provided extra propagation states, extend propagation dynamics
    if dynamics_prop_extra is not None and states_prop_extra is not None:
        (
            dynamics_prop,
            states_prop,
            controls_prop,
            parameters,
        ) = add_propagation_states(
            dynamics_extra=dynamics_prop_extra,
            states_extra=states_prop_extra,
            dynamics_opt=dynamics_prop,
            states_opt=states_prop,
            controls_opt=controls_prop,
            parameters=parameters,
        )

    # ==================== Return SymbolicProblem ====================

    return SymbolicProblem(
        dynamics=dynamics_aug,
        states=states_aug,
        controls=controls_aug,
        constraints=constraints,
        parameters=parameters,
        N=N,
        node_intervals=node_intervals,
        dynamics_prop=dynamics_prop,
        states_prop=states_prop,
        controls_prop=controls_prop,
    )
335
+
336
+
337
def add_propagation_states(
    dynamics_extra: dict,
    states_extra: List[State],
    dynamics_opt: object,
    states_opt: List[State],
    controls_opt: List[Control],
    parameters: Dict[str, object],
) -> Tuple:
    """Extend optimization dynamics with additional propagation-only states.

    This function augments the optimization dynamics with extra states that are only
    needed for post-solution trajectory propagation and simulation. These states
    don't affect the optimization but are useful for computing derived quantities
    like distance traveled, energy consumed, or accumulated cost.

    Propagation-only states are NOT part of the optimization problem - they are
    integrated forward after solving using the optimized state and control trajectories.

    The user specifies only the ADDITIONAL states and their dynamics. These are
    appended after all optimization states (user states + time + CTCS augmented states).

    State ordering in propagation dynamics:
        [user_states, time, ctcs_aug_states, extra_prop_states]

    Side effects:
        The input list/dict containers are copied and not mutated, but the State
        objects in ``states_extra`` have their ``_slice`` attribute overwritten so
        they index past the optimization states. ``collect_and_assign_slices`` is
        also called with ``controls_opt`` (for validation), which may reassign the
        controls' slices.

    Args:
        dynamics_extra: Dictionary mapping extra state names to dynamics expressions.
            Only specify NEW states, not optimization states. Example: {"distance": speed}
        states_extra: List of extra State objects for propagation only
        dynamics_opt: Augmented optimization dynamics expression (from preprocessing)
        states_opt: Augmented optimization states (user + time + CTCS augmented);
            each must already have a ``_slice`` assigned.
        controls_opt: Augmented optimization controls (user + time dilation)
        parameters: Dictionary of parameter values from optimization preprocessing

    Returns:
        Tuple containing:
            - dynamics_prop (Expr): Extended dynamics (optimization + extra)
            - states_prop (List[State]): Extended states (optimization + extra)
            - controls_prop (List[Control]): New list with the same controls as controls_opt
            - parameters_updated (Dict): Updated parameters including any from extra dynamics

    Raises:
        ValueError: If extra states conflict with optimization state names or if
            validation fails

    Example:
        Adding distance and energy tracking for propagation::

            import openscvx as ox
            import numpy as np

            distance = ox.State("distance", shape=(1,))
            distance.initial = np.array([0.0])

            energy = ox.State("energy", shape=(1,))
            energy.initial = np.array([0.0])

            # Assume v and u are optimization states/controls
            dynamics_extra = {
                "distance": ox.Norm(v),
                "energy": ox.Norm(u)**2,
            }

            dyn_prop, states_prop, controls_prop, params = add_propagation_states(
                dynamics_extra=dynamics_extra,
                states_extra=[distance, energy],
                dynamics_opt=dynamics_aug,
                states_opt=states_aug,
                controls_opt=controls_aug,
                parameters=parameters
            )

    Note:
        The extra states should have initial conditions set, as they will be
        integrated from these initial values during propagation.
    """

    # Copy containers so the caller's list/dict objects are not mutated
    # (the State objects themselves still get new slices below).
    states_extra = list(states_extra)
    dynamics_extra = dict(dynamics_extra)
    parameters = dict(parameters)

    # ==================== PHASE 1: Validate Extra States ====================

    # Extra states must not reuse optimization state names
    opt_state_names = {s.name for s in states_opt}
    extra_state_names = {s.name for s in states_extra}
    conflicts = opt_state_names & extra_state_names
    if conflicts:
        raise ValueError(
            f"Extra propagation states conflict with optimization states: {conflicts}. "
            f"Only specify additional states, not optimization states."
        )

    # Validate dynamics dict for extra states
    validate_dynamics_dict(dynamics_extra, states_extra)
    validate_dynamics_dict_dimensions(dynamics_extra, states_extra)

    # ==================== PHASE 2: Process Extra Dynamics ====================

    # Convert extra dynamics to expression
    _, dynamics_extra_concat = convert_dynamics_dict_to_expr(dynamics_extra, states_extra)

    # Validate and canonicalize
    validate_variable_names([dynamics_extra_concat])

    # Temporarily assign slices for validation (extra-state slices are
    # recalculated below so they follow the optimization states)
    collect_and_assign_slices(states_extra, controls_opt)
    validate_shapes([dynamics_extra_concat])
    validate_dynamics_dimension(dynamics_extra_concat, states_extra)
    dynamics_extra_concat = dynamics_extra_concat.canonicalize()

    # Collect any new parameter values from extra dynamics (existing names win)
    def collect_param_values(expr):
        if isinstance(expr, Parameter):
            if expr.name not in parameters:
                parameters[expr.name] = expr.value

    traverse(dynamics_extra_concat, collect_param_values)

    # ==================== PHASE 3: Concatenate with Optimization Dynamics ====================

    from openscvx.symbolic.expr import Concat

    # Concatenate: {opt dynamics, extra dynamics}
    dynamics_prop = Concat(dynamics_opt, dynamics_extra_concat)

    # Assign slices to extra states ONLY (optimization state slices are untouched).
    # Start after the furthest optimization slice rather than trusting that the
    # last list entry holds the highest index.
    start_idx = max((s._slice.stop for s in states_opt), default=0)
    for state in states_extra:
        # NOTE(review): assumes 1-D state shapes, as the original did.
        end_idx = start_idx + state.shape[0]
        state._slice = slice(start_idx, end_idx)
        start_idx = end_idx

    # Append extra states after the optimization states
    states_prop = [*states_opt, *states_extra]

    # Propagation uses same controls as optimization (fresh list, no aliasing)
    controls_prop = list(controls_opt)

    # ==================== Return Symbolic Outputs ====================

    return (
        dynamics_prop,
        states_prop,
        controls_prop,
        parameters,
    )
@@ -0,0 +1,92 @@
1
+ """Container for categorized symbolic constraints.
2
+
3
+ This module provides a dataclass to hold all symbolic constraint types in a
4
+ structured way before they are lowered to JAX/CVXPy.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import TYPE_CHECKING, List, Union
9
+
10
+ if TYPE_CHECKING:
11
+ from openscvx.symbolic.expr import CTCS, Constraint, CrossNodeConstraint, NodalConstraint
12
+
13
+
14
@dataclass
class ConstraintSet:
    """Grouped storage for symbolic constraints prior to lowering.

    Holds every symbolic constraint kind in a named list, so downstream code can
    address each category directly instead of re-filtering a flat list. Once
    lowering happens, the constraints move into LoweredJaxConstraints and
    LoweredCvxpyConstraints; this container only exists before that step.

    Two lifecycle stages are supported:

    1. **Before preprocessing**: user-supplied constraints sit in `unsorted`
    2. **After preprocessing**: `unsorted` is empty and every constraint has a category

    `is_categorized` reports which stage an instance is in.

    Attributes:
        unsorted: Constraints as supplied, before categorization. Empty after preprocessing.
        ctcs: CTCS (continuous-time) constraints.
        nodal: Non-convex nodal constraints (will be lowered to JAX).
        nodal_convex: Convex nodal constraints (will be lowered to CVXPy).
        cross_node: Non-convex cross-node constraints (will be lowered to JAX).
        cross_node_convex: Convex cross-node constraints (will be lowered to CVXPy).

    Example:
        Before preprocessing (raw constraints)::

            constraints = ConstraintSet(unsorted=[c1, c2, c3])
            assert not constraints.is_categorized

        After preprocessing (categorized)::

            # preprocess_symbolic_problem drains unsorted -> fills categories
            assert constraints.is_categorized
            for c in constraints.nodal:
                # Process non-convex nodal constraints
                pass
    """

    # Constraints exactly as the user supplied them; emptied by preprocessing.
    unsorted: List[Union["Constraint", "CTCS"]] = field(default_factory=list)

    # Per-category lists, filled in by preprocessing.
    ctcs: List["CTCS"] = field(default_factory=list)
    nodal: List["NodalConstraint"] = field(default_factory=list)
    nodal_convex: List["NodalConstraint"] = field(default_factory=list)
    cross_node: List["CrossNodeConstraint"] = field(default_factory=list)
    cross_node_convex: List["CrossNodeConstraint"] = field(default_factory=list)

    def _all_lists(self):
        """Return every constraint list as a tuple, starting with `unsorted`."""
        return (
            self.unsorted,
            self.ctcs,
            self.nodal,
            self.nodal_convex,
            self.cross_node,
            self.cross_node_convex,
        )

    @property
    def is_categorized(self) -> bool:
        """Whether categorization is complete, i.e. nothing remains in `unsorted`.

        Preprocessing empties `unsorted` and distributes its contents across the
        category lists, so this turns True afterwards. A set that never held any
        constraints also counts as categorized.
        """
        return not self.unsorted

    def __bool__(self) -> bool:
        """Truthy when at least one of the constraint lists is non-empty."""
        return any(self._all_lists())

    def __len__(self) -> int:
        """Count of constraints summed over every list, `unsorted` included."""
        return sum(len(group) for group in self._all_lists())
+ )