openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,630 @@
1
+ """State and dynamics augmentation for continuous-time constraint satisfaction.
2
+
3
+ This module provides utilities for augmenting trajectory optimization problems with
4
+ additional states and dynamics to handle continuous-time constraint satisfaction (CTCS).
5
+ The CTCS method enforces path constraints continuously along the trajectory rather than
6
+ just at discretization nodes.
7
+
8
+ Key functionality:
9
+ - CTCS constraint grouping: Sort and group CTCS constraints by time intervals
10
+ - Constraint separation: Separate CTCS, nodal, and convex constraints
11
+ - Vector decomposition: Decompose vector constraints into scalar components
12
+ - Time augmentation: Add time state with appropriate dynamics and constraints
13
+ - CTCS dynamics augmentation: Add augmented states and time dilation control
14
+
15
+ The augmentation process transforms the original dynamics x_dot = f(x, u) into an
16
+ augmented system with additional states for constraint satisfaction and time dilation.
17
+
18
+ Architecture:
19
+ The CTCS method works by:
20
+
21
+ 1. Grouping constraints by time interval and assigning index (idx)
22
+ 2. Creating augmented states (one per constraint group)
23
+ 3. Adding penalty dynamics: aug_dot = penalty(constraint_violation)
24
+ 4. Adding time dilation control to slow down near constraint boundaries
25
+
26
+ Example:
27
+ Augmenting dynamics with CTCS constraints::
28
+
29
+ import openscvx as ox
30
+
31
+ # Define problem
32
+ x = ox.State("x", shape=(3,))
33
+ u = ox.Control("u", shape=(2,))
34
+
35
+ # Create dynamics
36
+ xdot = u @ A # Some dynamics expression
37
+
38
+ # Define path constraint
39
+ path_constraint = (ox.Norm(x) <= 1.0).over((0, 50)) # CTCS constraint
40
+
41
+ # Augment dynamics with CTCS
42
+ from openscvx.symbolic.augmentation import augment_dynamics_with_ctcs
43
+
44
+ xdot_aug, states_aug, controls_aug = augment_dynamics_with_ctcs(
45
+ xdot=xdot,
46
+ states=[x, ox.State("time", shape=(1,))],  # a state named "time" is required
47
+ controls=[u],
48
+ constraints_ctcs=[path_constraint],
49
+ N=50
50
+ )
51
+ # xdot_aug now includes augmented state dynamics
52
+ # states_aug includes original states + augmented states
53
+ # controls_aug includes original controls + time dilation
54
+ """
55
+
56
+ from typing import Dict, List, Optional, Tuple
57
+
58
+ import numpy as np
59
+
60
+ from openscvx.symbolic.constraint_set import ConstraintSet
61
+ from openscvx.symbolic.expr import (
62
+ CTCS,
63
+ Add,
64
+ Concat,
65
+ Constraint,
66
+ CrossNodeConstraint,
67
+ Expr,
68
+ Index,
69
+ NodalConstraint,
70
+ )
71
+ from openscvx.symbolic.expr.control import Control
72
+ from openscvx.symbolic.expr.state import State
73
+
74
+
75
+ def sort_ctcs_constraints(
76
+ constraints_ctcs: List[CTCS],
77
+ ) -> Tuple[List[CTCS], List[Tuple[int, int]], int]:
78
+ """Sort and group CTCS constraints by time interval and assign indices.
79
+
80
+ Groups CTCS constraints by their time intervals (nodes) and assigns a unique
81
+ index (idx) to each group. Constraints with the same time interval can share
82
+ an augmented state (same idx), while constraints with different intervals must
83
+ have different augmented states.
84
+
85
+ Grouping rules:
86
+ - Constraints with the same node interval can share an idx
87
+ - Constraints with different node intervals must have different idx values
88
+ - idx values must form a contiguous block starting from 0
89
+ - Unspecified idx values are automatically assigned
90
+ - User-specified idx values are validated for consistency
91
+
92
+ Args:
93
+ constraints_ctcs: List of CTCS constraints to sort and group
94
+
95
+ Returns:
96
+ Tuple of:
97
+ - List of CTCS constraints with idx assigned to each
98
+ - List of node intervals (start, end) in ascending idx order
99
+ - Number of augmented states needed (number of unique idx values)
100
+
101
+ Raises:
102
+ ValueError: If user-specified idx values are inconsistent or non-contiguous
103
+
104
+ Example:
105
+ Sort CTCS constraints by interval and index:
106
+
107
+ constraint1 = (x <= 5).over((0, 50)) # Auto-assigned idx
108
+ constraint2 = (y <= 10).over((0, 50)) # Same interval, same idx
109
+ constraint3 = (z <= 15).over((20, 80)) # Different interval, different idx
110
+ sorted_ctcs, intervals, n_aug = sort_ctcs_constraints([c1, c2, c3])
111
+ # constraint1.idx = 0, constraint2.idx = 0, constraint3.idx = 1
112
+ # intervals = [(0, 50), (20, 80)]
113
+ # n_aug = 2
114
+ """
115
+ idx_to_nodes: Dict[int, Tuple[int, int]] = {}
116
+ next_idx = 0
117
+
118
+ for c in constraints_ctcs:
119
+ key = c.nodes
120
+
121
+ if c.idx is not None:
122
+ # User supplied an identifier: ensure it always points to the same interval
123
+ if c.idx in idx_to_nodes:
124
+ if idx_to_nodes[c.idx] != key:
125
+ raise ValueError(
126
+ f"idx={c.idx} was first used with interval={idx_to_nodes[c.idx]}, "
127
+ f"but now you gave it interval={key}"
128
+ )
129
+ else:
130
+ # When idx is explicitly provided, always create a separate group
131
+ # even if nodes are the same - this allows multiple constraint groups
132
+ # with the same node interval but different idx values
133
+ idx_to_nodes[c.idx] = key
134
+ else:
135
+ # No identifier: see if this interval already has one
136
+ for existing_id, nodes in idx_to_nodes.items():
137
+ if nodes == key:
138
+ c.idx = existing_id
139
+ break
140
+ else:
141
+ # Brand-new interval: pick the next free auto-id
142
+ while next_idx in idx_to_nodes:
143
+ next_idx += 1
144
+ c.idx = next_idx
145
+ idx_to_nodes[next_idx] = key
146
+ next_idx += 1
147
+
148
+ # Validate that idx values form a contiguous block starting from 0
149
+ ordered_ids = sorted(idx_to_nodes.keys())
150
+ expected_ids = list(range(len(ordered_ids)))
151
+ if ordered_ids != expected_ids:
152
+ raise ValueError(
153
+ f"CTCS constraint idx values must form a contiguous block starting from 0. "
154
+ f"Got {ordered_ids}, expected {expected_ids}"
155
+ )
156
+
157
+ # Extract intervals in ascending idx order
158
+ node_intervals = [idx_to_nodes[i] for i in ordered_ids]
159
+ num_augmented_states = len(ordered_ids)
160
+
161
+ return constraints_ctcs, node_intervals, num_augmented_states
162
+
163
+
164
def separate_constraints(constraint_set: ConstraintSet, n_nodes: int) -> ConstraintSet:
    """Sort the raw constraints in ``constraint_set.unsorted`` into categories.

    Every entry of ``unsorted`` is routed by type:

    - ``CTCS``: rejected if its inner constraint uses node references
      (``.at(k)``). Otherwise a falsy ``nodes`` interval (e.g. ``None``) is
      replaced by ``(0, n_nodes)`` and the wrapper is appended to ``ctcs``.
    - ``NodalConstraint``: rejected if its inner constraint uses node
      references. Otherwise it is appended to ``nodal_convex`` or ``nodal``
      according to the inner constraint's ``is_convex``.
    - bare ``Constraint`` with node references: wrapped in
      ``CrossNodeConstraint`` and appended to ``cross_node_convex`` or
      ``cross_node``.
    - bare ``Constraint`` without node references: wrapped in a
      ``NodalConstraint`` over nodes ``0 .. n_nodes - 1`` and appended to
      ``nodal_convex`` or ``nodal``.

    ``unsorted`` is then emptied. Every CTCS currently in
    ``constraint_set.ctcs`` (including any that were there before this call)
    whose ``check_nodally`` flag is set contributes an extra
    ``NodalConstraint`` over nodes ``start .. end - 1`` of its interval.
    Finally, each cross-node constraint is passed to
    ``validate_cross_node_constraint``.

    Args:
        constraint_set: Set whose ``unsorted`` list holds the raw constraints;
            it is modified in place.
        n_nodes: Total number of trajectory nodes.

    Returns:
        The same ``ConstraintSet`` instance, now categorized.

    Raises:
        ValueError: If an entry is not a ``Constraint``, ``NodalConstraint``
            or ``CTCS``.
        ValueError: If a ``NodalConstraint`` or ``CTCS`` wraps a constraint
            containing node references.
        Errors raised by ``validate_cross_node_constraint`` propagate as-is.

    Example:
        Categorize a mixed set::

            x = ox.State("x", shape=(3,))
            constraint_set = ConstraintSet(unsorted=[
                (x <= 5).over((0, 50)),       # CTCS
                (x >= 0).at([0, 10, 20]),     # NodalConstraint
                ox.Norm(x) <= 1,              # bare -> all nodes
                x.at(5) - x.at(4) <= 0.1,     # bare with node refs -> cross-node
            ])
            separate_constraints(constraint_set, n_nodes=50)
    """
    from openscvx.symbolic.lower import _contains_node_reference

    def _route(item, convex, convex_bucket, general_bucket):
        # Convex constraints are collected separately from the rest.
        target = convex_bucket if convex else general_bucket
        target.append(item)

    for raw in constraint_set.unsorted:
        if isinstance(raw, CTCS):
            if _contains_node_reference(raw.constraint):
                raise ValueError(
                    "CTCS constraints cannot contain NodeReferences (.at(k)). "
                    "Cross-node constraints should be specified as bare Constraint objects. "
                    f"Constraint: {raw.constraint}"
                )
            # A falsy interval means "the whole horizon".
            raw.nodes = raw.nodes or (0, n_nodes)
            constraint_set.ctcs.append(raw)
            continue

        if isinstance(raw, NodalConstraint):
            # An explicit .at([...]) wrapper around an expression that already
            # references nodes via .at(k) is ambiguous, so it is refused.
            if _contains_node_reference(raw.constraint):
                raise ValueError(
                    "Cross-node constraints should not use .at([...]) wrapper. "
                    "The constraint already references specific nodes via .at(k) inside the "
                    "expression. Remove the outer .at([...]) wrapper and use the bare "
                    "constraint directly. "
                    f"Constraint: {raw.constraint}"
                )
            _route(
                raw,
                raw.constraint.is_convex,
                constraint_set.nodal_convex,
                constraint_set.nodal,
            )
            continue

        if not isinstance(raw, Constraint):
            raise ValueError(
                "Constraints must be `Constraint`, `NodalConstraint`, or `CTCS`, "
                f"got {type(raw).__name__}"
            )

        if _contains_node_reference(raw):
            _route(
                CrossNodeConstraint(raw),
                raw.is_convex,
                constraint_set.cross_node_convex,
                constraint_set.cross_node,
            )
        else:
            _route(
                NodalConstraint(raw, list(range(n_nodes))),
                raw.is_convex,
                constraint_set.nodal_convex,
                constraint_set.nodal,
            )

    constraint_set.unsorted = []

    # CTCS wrappers flagged check_nodally are also enforced at the discrete
    # nodes of their half-open interval [start, end).
    for inner, interval in get_nodal_constraints_from_ctcs(constraint_set.ctcs):
        window = list(range(interval[0], interval[1]))
        _route(
            NodalConstraint(inner, window),
            inner.is_convex,
            constraint_set.nodal_convex,
            constraint_set.nodal,
        )

    from openscvx.symbolic.preprocessing import validate_cross_node_constraint

    for cross in constraint_set.cross_node + constraint_set.cross_node_convex:
        validate_cross_node_constraint(cross, n_nodes)

    return constraint_set
284
+
285
+
286
def decompose_vector_nodal_constraints(
    constraints_nodal: List[NodalConstraint],
) -> List[NodalConstraint]:
    """Split vector-valued nodal constraints into one constraint per element.

    Nonconvex nodal constraints are later lowered to JAX functions, and that
    path expects a scalar residual per constraint. This function therefore
    replaces each constraint whose residual (``constraint.lhs``) has a
    non-empty shape, including shape ``(1,)``, with one indexed constraint
    per element. Every replacement keeps the original node list and the
    original ``rhs``.

    Constraints whose residual is scalar (shape ``()``) are passed through
    unchanged. Constraints whose shape cannot be determined are also passed
    through unchanged. CTCS constraints are not handled here.

    Args:
        constraints_nodal: Canonicalized ``NodalConstraint`` objects, i.e. of
            the form ``residual <= 0`` or ``residual == 0`` with the residual
            on the left-hand side.

    Returns:
        New list of ``NodalConstraint`` objects with vector constraints
        expanded element-by-element (row-major order for ndim > 1).

    Example:
        Split a vector constraint::

            x = ox.State("x", shape=(3,))
            constraint = (x <= 5).at([0, 10, 20])
            decomposed = decompose_vector_nodal_constraints([constraint])
            # 3 constraints: x[0] <= 5, x[1] <= 5, x[2] <= 5
    """
    decomposed_constraints = []

    for nodal_constraint in constraints_nodal:
        constraint = nodal_constraint.constraint
        nodes = nodal_constraint.nodes

        # Only the shape analysis is best-effort. If it fails, the original
        # constraint is kept. Errors while building the scalar pieces are not
        # swallowed, so a half-decomposed constraint can never be emitted
        # together with its undecomposed original.
        try:
            residual_shape = tuple(constraint.lhs.check_shape())
        except Exception:
            decomposed_constraints.append(nodal_constraint)
            continue

        if len(residual_shape) == 0:
            # Already scalar - keep as is.
            decomposed_constraints.append(nodal_constraint)
            continue

        # Decompose every vector shape, including (1,), so that vmap does not
        # add an extra dimension when stacking results.
        for flat_i in range(int(np.prod(residual_shape))):
            if len(residual_shape) == 1:
                element = flat_i
            else:
                # A single flat integer would select a whole row of a
                # multi-dimensional residual, so convert to a per-axis index.
                # NOTE(review): assumes Index accepts a tuple of ints as a
                # NumPy-style multi-axis index - confirm against Index.
                element = tuple(int(k) for k in np.unravel_index(flat_i, residual_shape))
            indexed_lhs = Index(constraint.lhs, element)
            indexed_constraint = constraint.__class__(indexed_lhs, constraint.rhs)
            decomposed_constraints.append(NodalConstraint(indexed_constraint, nodes))

    return decomposed_constraints
352
+
353
+
354
def get_nodal_constraints_from_ctcs(
    constraints_ctcs: List[CTCS],
) -> List[tuple[Constraint, tuple[int, int]]]:
    """Collect the inner constraints of CTCS wrappers flagged ``check_nodally``.

    A CTCS wrapper with ``check_nodally`` set asks for its inner constraint
    to be enforced both continuously and at the discrete nodes. This function
    returns, in input order, the inner constraint and the wrapper's ``nodes``
    interval for each such wrapper. Unflagged wrappers are skipped.

    Args:
        constraints_ctcs: CTCS constraint wrappers.

    Returns:
        List of ``(constraint, nodes)`` pairs, where ``nodes`` is the
        wrapper's ``(start, end)`` interval.

    Example:
        Pick out a nodally-checked CTCS constraint::

            x = ox.State("x", shape=(3,))
            constraint = (x <= 5).over((10, 50), check_nodally=True)
            get_nodal_constraints_from_ctcs([constraint])
            # -> [(x <= 5, (10, 50))]
    """
    return [
        (wrapper.constraint, wrapper.nodes)
        for wrapper in constraints_ctcs
        if wrapper.check_nodally
    ]
386
+
387
+
388
def augment_with_time_state(
    states: List[State],
    constraints: ConstraintSet,
    time_initial: float | tuple,
    time_final: float | tuple,
    time_min: float,
    time_max: float,
    N: int,
    time_scaling_min: Optional[float] = None,
    time_scaling_max: Optional[float] = None,
) -> Tuple[List[State], ConstraintSet]:
    """Add a ``"time"`` state (and its bound constraints) if none exists.

    If ``states`` has no state named ``"time"``, this function does the
    following:

    - creates a ``(1,)``-shaped one with bounds ``[time_min, time_max]``;
    - sets its boundary conditions and a linearly interpolated guess;
    - appends the state to a copy of ``states``;
    - appends two ``CTCS`` constraints enforcing ``time <= max`` and
      ``min <= time`` to ``constraints.unsorted``.

    Args:
        states: State objects. The list itself is not modified; a copy is
            returned.
        constraints: ConstraintSet whose ``unsorted`` list is appended to in
            place.
        time_initial: Initial time boundary condition. Either a number
            (fixed) or a pair such as ``("free", guess)``.
        time_final: Final time boundary condition, same format as
            ``time_initial``.
        time_min: Lower bound for the time state along the trajectory.
        time_max: Upper bound for the time state along the trajectory.
        N: Number of nodes. The guess has shape ``(N, 1)``, linearly spaced
            between the initial and final guess values.
        time_scaling_min: If given, stored as ``time_state.scaling_min``
            (one-element array).
        time_scaling_max: If given, stored as ``time_state.scaling_max``
            (one-element array).

    Returns:
        Tuple of:
            - New states list (original states, plus the time state if it
              was created)
            - The same ``ConstraintSet`` instance

    Note:
        If a state named ``"time"`` already exists, it is not modified and
        no constraints are added.

    Example:
        Add a free-final-time state::

            x = ox.State("x", shape=(3,))
            constraints = ConstraintSet()
            states_aug, constraints = augment_with_time_state(
                states=[x],
                constraints=constraints,
                time_initial=0.0,
                time_final=("free", 10.0),
                time_min=0.0,
                time_max=100.0,
                N=50,
            )
            # states_aug == [x, <time state>]
    """

    def _guess_value(boundary):
        # A boundary entry is either a plain number (fixed value) or a pair
        # like ("free", guess). Testing for the pair form, rather than for
        # Python int/float, keeps NumPy scalars such as np.int64 working.
        if isinstance(boundary, (tuple, list)):
            return boundary[1]
        return boundary

    # Work on a copy so the caller's list is not mutated.
    states_aug = list(states)

    existing = next((state for state in states_aug if state.name == "time"), None)
    if existing is not None:
        # A user-provided time state is left untouched.
        return states_aug, constraints

    time_state = State("time", shape=(1,))
    time_state.min = np.array([time_min])
    time_state.max = np.array([time_max])

    time_state.initial = [time_initial]
    time_state.final = [time_final]

    # Linear interpolation between the initial and final guesses, read back
    # through the state so any normalization done by its setters is respected.
    guess_start = _guess_value(time_state.initial[0])
    guess_end = _guess_value(time_state.final[0])
    time_state.guess = np.linspace(guess_start, guess_end, N).reshape(-1, 1)

    # Transfer scaling_min/max from the Time object if provided.
    if time_scaling_min is not None:
        time_state.scaling_min = np.array([time_scaling_min])
    if time_scaling_max is not None:
        time_state.scaling_max = np.array([time_scaling_max])

    states_aug.append(time_state)

    # Enforce the time bounds continuously along the trajectory.
    constraints.unsorted.append(CTCS(time_state <= time_state.max))
    constraints.unsorted.append(CTCS(time_state.min <= time_state))

    return states_aug, constraints
493
+
494
+
495
def augment_dynamics_with_ctcs(
    xdot: Expr,
    states: List[State],
    controls: List[Control],
    constraints_ctcs: List[CTCS],
    N: int,
    licq_min: float = 0.0,
    licq_max: float = 1e-4,
    time_dilation_factor_min: float = 0.3,
    time_dilation_factor_max: float = 3.0,
) -> Tuple[Expr, List[State], List[Control]]:
    """Add CTCS augmented states and a time-dilation control.

    CTCS constraints are grouped by their ``idx``. Each group's dynamics
    expression is the group's single CTCS wrapper, or the ``Add`` of its
    wrappers when there are several. One ``(1,)``-shaped state
    ``_ctcs_aug_{k}`` is created per group. Its bounds, initial value and
    guess are set from ``licq_min``/``licq_max``, and its final value is
    free. The group expressions are concatenated after ``xdot``.

    A control ``_time_dilation`` of shape ``(1,)`` is always appended. Its
    bounds are ``time_dilation_factor_min/max`` times the final-time value of
    the ``"time"`` state, and that value is also its guess. This function
    only creates the control; it does not insert it into ``xdot``.

    Args:
        xdot: Original dynamics expression.
        states: State variables; must include a state named ``"time"``.
        controls: Control variables.
        constraints_ctcs: CTCS constraints whose ``idx`` has already been
            assigned (see ``sort_ctcs_constraints``).
        N: Number of discretization nodes (rows of the generated guesses).
        licq_min: Lower bound, initial value and guess of augmented states.
        licq_max: Upper bound of augmented states.
        time_dilation_factor_min: Lower time-dilation bound as a multiple of
            the final time.
        time_dilation_factor_max: Upper time-dilation bound as a multiple of
            the final time.

    Returns:
        Tuple of:
            - ``Concat(xdot, *group_exprs)`` if there are CTCS constraints,
              otherwise ``xdot`` unchanged
            - New states list (original states + augmented states)
            - New controls list (original controls + ``_time_dilation``)

    Raises:
        ValueError: If no state named ``"time"`` is found in ``states``.

    Example:
        Add one CTCS penalty state::

            x = ox.State("x", shape=(3,))
            u = ox.Control("u", shape=(2,))
            time = ox.State("time", shape=(1,))
            xdot = u @ A  # Some dynamics
            constraint = (ox.Norm(x) <= 1.0).over((0, 50))
            xdot_aug, states_aug, controls_aug = augment_dynamics_with_ctcs(
                xdot=xdot,
                states=[x, time],
                controls=[u],
                constraints_ctcs=[constraint],
                N=50,
            )
            # states_aug: x, time, _ctcs_aug_0
            # controls_aug: u, _time_dilation
    """
    # Locate the time state first, so a missing one fails before any new
    # variables are created.
    time_state = next((state for state in states if state.name == "time"), None)
    if time_state is None:
        raise ValueError("No state named 'time' found in states list")

    states_augmented = list(states)
    controls_augmented = list(controls)

    if constraints_ctcs:
        # Group by idx. The CTCS wrappers are kept intact so their node
        # interval information reaches the lowerer.
        penalty_groups: Dict[int, List[Expr]] = {}
        for ctcs in constraints_ctcs:
            penalty_groups.setdefault(ctcs.idx, []).append(ctcs)

        augmented_state_exprs = []
        for group_idx in sorted(penalty_groups):
            terms = penalty_groups[group_idx]
            augmented_state_exprs.append(terms[0] if len(terms) == 1 else Add(*terms))

        # One scalar augmented state per group, named by position.
        for position in range(len(penalty_groups)):
            aug_var = State(f"_ctcs_aug_{position}", shape=(1,))
            aug_var.initial = np.array([licq_min])  # start inside the bounds
            aug_var.final = [("free", 0)]
            aug_var.min = np.array([licq_min])
            aug_var.max = np.array([licq_max])
            aug_var.guess = np.full([N, 1], licq_min)  # N x 1 guess for this state
            states_augmented.append(aug_var)

        xdot_aug = Concat(xdot, *augmented_state_exprs)
    else:
        xdot_aug = xdot

    time_dilation = Control("_time_dilation", shape=(1,))

    # The final-time entry may be a plain number or a ("free", guess) pair,
    # the same two forms augment_with_time_state accepts. Use the guess in
    # the latter case instead of multiplying a tuple by a float.
    final_entry = time_state.final[0]
    time_final = final_entry[1] if isinstance(final_entry, (tuple, list)) else final_entry

    time_dilation.min = np.array([time_dilation_factor_min * time_final])
    time_dilation.max = np.array([time_dilation_factor_max * time_final])
    time_dilation.guess = np.ones([N, 1]) * time_final

    controls_augmented.append(time_dilation)

    return xdot_aug, states_augmented, controls_augmented