openscvx 0.3.2.dev170__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

Files changed (79) hide show
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,796 @@
1
+ """Specialized constraint types for trajectory optimization.
2
+
3
+ This module provides advanced constraint specification mechanisms that extend the
4
+ basic Equality and Inequality constraints. These specialized constraint types enable
5
+ precise control over when and how constraints are enforced in discretized trajectory
6
+ optimization problems.
7
+
8
+ Key constraint types:
9
+ - **NodalConstraint:** Enforces constraints only at specific discrete time points (nodes) along
10
+ the trajectory. Useful for waypoint constraints, boundary conditions, and reducing computational
11
+ cost by selective enforcement.
12
+ - **CTCS (Continuous-Time Constraint Satisfaction):** Guarantees strict constraint satisfaction
13
+ throughout the entire continuous trajectory, not just at discrete nodes. Works by augmenting the
14
+ state vector with additional states whose dynamics integrate constraint violation penalties.
15
+ Essential for safety-critical applications where inter-node violations could be catastrophic.
16
+
17
+ Example:
18
+ Nodal constraints for waypoints::
19
+
20
+ import openscvx as ox
21
+
22
+ x = ox.State("x", shape=(3,))
23
+ target = [10, 5, 0]
24
+
25
+ # Enforce position constraint only at specific nodes
26
+ waypoint_constraint = (x == target).at([0, 10, 20])
27
+
28
+ Continuous-time constraint for obstacle avoidance::
29
+
30
+ obstacle_center = ox.Parameter("obs", shape=(2,), value=[5, 5])
31
+ obstacle_radius = 2.0
32
+
33
+ # Distance from obstacle must be > radius for ALL time
34
+ distance = ox.Norm(x[:2] - obstacle_center)
35
+ safety_constraint = (distance >= obstacle_radius).over((0, 100))
36
+ """
37
+
38
+ import hashlib
39
+ import struct
40
+ from typing import Optional, Tuple, Union
41
+
42
+ import numpy as np
43
+
44
+ from .arithmetic import Sub
45
+ from .expr import Constant, Expr
46
+ from .linalg import Sum
47
+
48
+
49
class Constraint(Expr):
    """Abstract base class for optimization constraints.

    A constraint relates a left-hand side expression to a right-hand side
    expression. The concrete subclass (``Equality`` or ``Inequality``) decides
    which relation is meant; this base class holds the behaviour shared by both.

    Attributes:
        lhs: Left-hand side expression.
        rhs: Right-hand side expression.
        is_convex: Flag indicating whether the constraint has been marked as
            convex (see ``convex()``). Starts out as ``False``.

    Note:
        ``canonicalize()`` rewrites a constraint into the standard form
        ``(lhs - rhs) {op} 0``.
    """

    def __init__(self, lhs: Expr, rhs: Expr):
        """Initialize a constraint.

        Args:
            lhs: Left-hand side expression.
            rhs: Right-hand side expression.
        """
        self.lhs = lhs
        self.rhs = rhs
        self.is_convex = False

    def children(self):
        """Return the two operands ``[lhs, rhs]`` as child expressions."""
        return [self.lhs, self.rhs]

    def canonicalize(self) -> "Expr":
        """Canonicalize the constraint to standard form ``(lhs - rhs) {op} 0``.

        The result is built with ``type(self)``, so the same implementation
        serves every subclass and the relation type is preserved.

        Returns:
            A new constraint of the same class whose left-hand side is the
            canonicalized difference and whose right-hand side is a zero
            constant. The ``is_convex`` flag is carried over.
        """
        diff = Sub(self.lhs, self.rhs)
        canon_diff = diff.canonicalize()
        new_constraint = type(self)(canon_diff, Constant(np.array(0)))
        # The constructor resets is_convex to False, so restore the flag.
        new_constraint.is_convex = self.is_convex
        return new_constraint

    def check_shape(self) -> Tuple[int, ...]:
        """Check that both operands are broadcastable against each other.

        Returns:
            The empty tuple ``()``; a constraint always reports a scalar shape,
            vector operands being interpreted element-wise.

        Raises:
            ValueError: If the operand shapes cannot be broadcast together.
        """
        L_shape = self.lhs.check_shape()
        R_shape = self.rhs.check_shape()

        # Only the compatibility check matters; the broadcast shape is unused.
        try:
            np.broadcast_shapes(L_shape, R_shape)
        except ValueError as e:
            constraint_type = type(self).__name__
            raise ValueError(f"{constraint_type} not broadcastable: {L_shape} vs {R_shape}") from e

        return ()

    def at(self, nodes: Union[int, list, tuple]):
        """Apply this constraint only at specific discrete nodes.

        Args:
            nodes: Node indices where the constraint should be enforced. Either
                an iterable of integers or a single integer (Python ``int`` or
                NumPy integer scalar), which is treated as a one-element list.

        Returns:
            NodalConstraint wrapping this constraint with the node specification.
        """
        # A lone index (plain or NumPy integer) is not iterable, so wrap it
        # before the list() conversion below.
        if isinstance(nodes, (int, np.integer)):
            nodes = [nodes]
        return NodalConstraint(self, list(nodes))

    def over(
        self,
        interval: tuple[int, int],
        penalty: str = "squared_relu",
        idx: Optional[int] = None,
        check_nodally: bool = False,
    ):
        """Apply this constraint over a continuous interval using CTCS.

        Args:
            interval: Tuple of (start, end) node indices for the continuous interval.
            penalty: Penalty function type ("squared_relu", "huber", "smooth_relu").
            idx: Optional grouping index for multiple augmented states.
            check_nodally: Whether to also enforce this constraint nodally.

        Returns:
            CTCS constraint wrapping this constraint with the interval specification.
        """
        return CTCS(self, penalty=penalty, nodes=interval, idx=idx, check_nodally=check_nodally)

    def convex(self) -> "Constraint":
        """Mark this constraint as convex for CVXPy lowering.

        Returns:
            Self, with ``is_convex`` set to True (enables method chaining).
        """
        self.is_convex = True
        return self
148
+
149
+
150
class Equality(Constraint):
    """Constraint stating that two expressions are equal: ``lhs == rhs``.

    Instances can be created with the ``==`` operator on Expr objects.

    Example:
        Define an Equality constraint:

            x = ox.State("x", shape=(3,))
            constraint = x == 0  # Creates Equality(x, Constant(0))
    """

    def __repr__(self):
        """Render the constraint as ``<lhs> == <rhs>`` using operand reprs."""
        left = repr(self.lhs)
        right = repr(self.rhs)
        return f"{left} == {right}"
165
+
166
+
167
class Inequality(Constraint):
    """Constraint stating that one expression is bounded by another: ``lhs <= rhs``.

    Instances can be created with the ``<=`` operator on Expr objects.

    Example:
        Define an Inequality constraint:

            x = ox.State("x", shape=(3,))
            constraint = x <= 10  # Creates Inequality(x, Constant(10))
    """

    def __repr__(self):
        """Render the constraint as ``<lhs> <= <rhs>`` using operand reprs."""
        left = repr(self.lhs)
        right = repr(self.rhs)
        return f"{left} <= {right}"
182
+
183
+
184
class NodalConstraint(Expr):
    """Constraint wrapper that restricts enforcement to chosen trajectory nodes.

    Instead of enforcing a constraint at every node of the discretized
    trajectory, a NodalConstraint pairs the constraint with an explicit list of
    node indices. Typical uses:

    - Waypoint constraints (e.g., pass through point X at node 10)
    - Boundary conditions at non-standard locations
    - Reducing computational cost by checking constraints less frequently
    - Periodic enforcement (e.g., every 5th node)

    The mathematical definition of the constraint stays separate from the
    specification of where it applies.

    Note:
        Bare Constraint objects (without .at() or .over()) are automatically
        converted to NodalConstraints applied at all nodes during preprocessing.

    Attributes:
        constraint: The wrapped Constraint (Equality or Inequality).
        nodes: List of integer node indices where the constraint is enforced.

    Example:
        Enforce a position constraint only at nodes 0, 10, and 20:

            x = State("x", shape=(3,))
            target = [10, 5, 0]
            constraint = (x == target).at([0, 10, 20])

        The same thing, constructing the wrapper directly:

            constraint = NodalConstraint(x == target, nodes=[0, 10, 20])

        Periodic enforcement (every 10th node):

            velocity_limit = (vel <= 100).at(list(range(0, 100, 10)))

        Bare constraints are applied at all nodes, so these are equivalent:

            constraint1 = vel <= 100  # Auto-converted to all nodes
            constraint2 = (vel <= 100).at(list(range(n_nodes)))
    """

    def __init__(self, constraint: Constraint, nodes: list[int]):
        """Initialize a NodalConstraint.

        Args:
            constraint: The Constraint (Equality or Inequality) to enforce.
            nodes: List of integer node indices where the constraint applies.
                NumPy integers are converted to Python integers.

        Raises:
            TypeError: If constraint is not a Constraint instance.
            TypeError: If nodes is not a list.
            TypeError: If any node index is not an integer.

        Note:
            Bounds checking for cross-node constraints (those containing
            NodeReference) is performed later in the pipeline when N is known,
            via validate_cross_node_constraint_bounds() in preprocessing.py.
        """
        if not isinstance(constraint, Constraint):
            raise TypeError("NodalConstraint must wrap a Constraint")
        if not isinstance(nodes, list):
            raise TypeError("nodes must be a list of integers")

        python_nodes = []
        for index in nodes:
            # NumPy integers are normalised; Python ints pass through untouched.
            if isinstance(index, np.integer):
                python_nodes.append(int(index))
                continue
            if isinstance(index, int):
                python_nodes.append(index)
                continue
            raise TypeError("all node indices must be integers")

        self.constraint = constraint
        self.nodes = python_nodes

    def children(self):
        """Return the wrapped constraint as the only child.

        Returns:
            list: Single-element list containing the wrapped constraint.
        """
        return [self.constraint]

    def canonicalize(self) -> "Expr":
        """Canonicalize the wrapped constraint, keeping the node list.

        Returns:
            NodalConstraint: A new wrapper around the canonicalized constraint
            with the same node indices.
        """
        return NodalConstraint(self.constraint.canonicalize(), self.nodes)

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the wrapped constraint's shape.

        The wrapper only records where the constraint applies, so validation is
        delegated to the wrapped constraint.

        Returns:
            tuple: Empty tuple () representing scalar shape.
        """
        # Called for validation only; the scalar shape is returned regardless.
        self.constraint.check_shape()
        return ()

    def convex(self) -> "NodalConstraint":
        """Mark the underlying constraint as convex for CVXPy lowering.

        Returns:
            Self, after setting the wrapped constraint's convex flag (enables
            method chaining).

        Example:
            Mark a constraint as convex:

                constraint = (x <= 10).at([0, 5, 10]).convex()
        """
        self.constraint.convex()
        return self

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed this NodalConstraint, including its node list, into a hasher.

        Args:
            hasher: A hashlib hash object to update.
        """
        hasher.update(b"NodalConstraint")
        # Each node is a big-endian 32-bit signed int; the trailing separator
        # marks where the node list ends.
        packed_nodes = b"".join(struct.pack(">i", index) for index in self.nodes)
        hasher.update(packed_nodes)
        hasher.update(b"|")
        self.constraint._hash_into(hasher)

    def __repr__(self):
        """Return a string showing the wrapped constraint and node indices.

        Returns:
            str: ``NodalConstraint(<constraint>, nodes=<nodes>)``.
        """
        inner = repr(self.constraint)
        return f"NodalConstraint({inner}, nodes={self.nodes})"
332
+
333
+
334
class CrossNodeConstraint(Expr):
    """A single constraint coupling specific trajectory nodes via .at(k) references.

    A NodalConstraint applies one constraint pattern at several nodes (via
    vmapping). A CrossNodeConstraint, by contrast, is one constraint whose node
    indices are fixed inside the expression through NodeReference nodes.

    CrossNodeConstraint is created automatically when a bare Constraint contains
    NodeReference nodes (from .at(k) calls). Users should NOT wrap cross-node
    constraints manually - they are auto-detected during constraint separation.

    **Key differences from NodalConstraint:**

    - **NodalConstraint**: Same constraint evaluated at multiple nodes via vmapping.
      Signature: (x, u, node, params) -> scalar, vmapped to (N, n_x) inputs.
    - **CrossNodeConstraint**: Single constraint coupling specific fixed nodes.
      Signature: (X, U, params) -> scalar, operates on full trajectory arrays.

    **Lowering:**

    - **Non-convex**: Lowered to JAX with automatic differentiation for SCP linearization
    - **Convex**: Lowered to CVXPy and solved directly by the convex solver

    Attributes:
        constraint: The wrapped Constraint containing NodeReference nodes.

    Example:
        Rate limit constraint (auto-detected as CrossNodeConstraint):

            position = State("pos", shape=(3,))

            # This creates a CrossNodeConstraint automatically:
            rate_limit = position.at(5) - position.at(4) <= 0.1

            # Mark as convex if the constraint is convex:
            rate_limit_convex = (position.at(5) - position.at(4) <= 0.1).convex()

        Creating multiple cross-node constraints with a loop:

            constraints = []
            for k in range(1, N):
                # Each iteration creates one CrossNodeConstraint
                rate_limit = position.at(k) - position.at(k-1) <= max_step
                constraints.append(rate_limit)

    Note:
        Do NOT use .at([...]) on cross-node constraints. The nodes are already
        specified via .at(k) inside the expression. Using .at([...]) will raise
        an error during constraint separation.
    """

    def __init__(self, constraint: Constraint):
        """Initialize a CrossNodeConstraint.

        Args:
            constraint: The Constraint containing NodeReference nodes. It is
                expected to contain at least one NodeReference (from .at(k)
                calls); only the Constraint type is checked here.

        Raises:
            TypeError: If constraint is not a Constraint instance.
        """
        if isinstance(constraint, Constraint):
            self.constraint = constraint
        else:
            raise TypeError("CrossNodeConstraint must wrap a Constraint")

    @property
    def is_convex(self) -> bool:
        """Whether the underlying constraint is marked as convex.

        Returns:
            bool: The wrapped constraint's ``is_convex`` flag.
        """
        wrapped = self.constraint
        return wrapped.is_convex

    def children(self):
        """Return the wrapped constraint as the only child.

        Returns:
            list: Single-element list containing the wrapped constraint.
        """
        return [self.constraint]

    def canonicalize(self) -> "Expr":
        """Canonicalize the wrapped constraint.

        Returns:
            CrossNodeConstraint: A new wrapper around the canonicalized constraint.
        """
        return CrossNodeConstraint(self.constraint.canonicalize())

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the wrapped constraint's shape.

        Returns:
            tuple: Empty tuple () representing scalar shape.
        """
        # Called for validation only; the scalar shape is returned regardless.
        self.constraint.check_shape()
        return ()

    def convex(self) -> "CrossNodeConstraint":
        """Mark the underlying constraint as convex for CVXPy lowering.

        Returns:
            Self, after setting the wrapped constraint's convex flag.
        """
        self.constraint.convex()
        return self

    def __repr__(self):
        """Return a string showing the wrapped constraint.

        Returns:
            str: ``CrossNodeConstraint(<constraint>)``.
        """
        inner = repr(self.constraint)
        return f"CrossNodeConstraint({inner})"
451
+
452
+
453
+ # CTCS STUFF
454
+
455
+
456
class CTCS(Expr):
    """Continuous-Time Constraint Satisfaction using augmented state dynamics.

    CTCS enables strict continuous-time constraint enforcement in discretized
    trajectory optimization by augmenting the state vector with additional states
    whose dynamics are the constraint violation penalties. By constraining these
    augmented states to remain at zero throughout the trajectory, the original
    constraints are satisfied continuously, not just at discrete nodes.

    **How it works:**

    1. Each constraint (in canonical form: lhs <= 0) is wrapped in a penalty function
    2. Augmented states s_aug_i are added with dynamics:
       ds_aug_i/dt = sum(penalty_j(lhs_j)) for all CTCS constraints j in group i
    3. Each augmented state is constrained: s_aug_i(t) = 0 for all t
    4. Since s_aug_i integrates the penalties, s_aug_i = 0 implies all penalties in
       the group are zero, i.e. all constraints in the group hold continuously

    **Grouping and augmented states:**

    - CTCS constraints with the **same node interval** are grouped into a single
      augmented state by default (their penalties are summed)
    - CTCS constraints with **different node intervals** create separate augmented states
    - The `idx` parameter explicitly assigns constraints to specific augmented
      states, allowing manual control over grouping
    - Each unique group creates one augmented state named `_ctcs_aug_0`,
      `_ctcs_aug_1`, etc.

    This is particularly useful for:

    - Path constraints that must hold throughout the entire trajectory
    - Obstacle avoidance where violation between nodes could be catastrophic
    - State limits that should be respected continuously (e.g., altitude > 0)
    - Ensuring feasible trajectories between discretization points

    **Penalty functions** (applied to constraint violations):

    - **squared_relu**: Square(PositivePart(lhs)) - smooth, differentiable (default)
    - **huber**: Huber(PositivePart(lhs)) - less sensitive to outliers than squared
    - **smooth_relu**: SmoothReLU(lhs) - smooth approximation of ReLU

    Attributes:
        constraint: The wrapped Constraint (typically Inequality) to enforce continuously.
        penalty: Penalty function type ('squared_relu', 'huber', or 'smooth_relu').
        nodes: Optional (start, end) tuple of Python ints specifying the
            enforcement interval, or None to enforce over the entire trajectory.
        idx: Optional grouping index for managing multiple augmented states.
            CTCS constraints with the same idx and nodes are grouped together,
            sharing an augmented state. If None, auto-assigned based on node intervals.
        check_nodally: Whether to also enforce the constraint at discrete nodes
            (creates both continuous and nodal constraints).

    Example:
        Single augmented state (default behavior - same node interval):

            altitude = State("alt", shape=(1,))
            constraints = [
                (altitude >= 10).over((0, 10)),    # Both constraints share
                (altitude <= 1000).over((0, 10)),  # one augmented state
            ]

        Multiple augmented states (different node intervals):

            constraints = [
                (altitude >= 10).over((0, 5)),   # Creates _ctcs_aug_0
                (altitude >= 20).over((5, 10)),  # Creates _ctcs_aug_1
            ]

        Manual grouping with the idx parameter:

            constraints = [
                (altitude >= 10).over((0, 10), idx=0),    # Group 0
                (velocity <= 100).over((0, 10), idx=1),   # Group 1 (separate state)
                (altitude <= 1000).over((0, 10), idx=0),  # Also group 0
            ]
    """

    def __init__(
        self,
        constraint: Constraint,
        penalty: str = "squared_relu",
        nodes: Optional[Tuple[int, int]] = None,
        idx: Optional[int] = None,
        check_nodally: bool = False,
    ):
        """Initialize a CTCS constraint.

        Args:
            constraint: The Constraint to enforce continuously (typically an Inequality).
            penalty: Penalty function type. Options:
                - 'squared_relu': Square(PositivePart(lhs)) - default
                - 'huber': Huber(PositivePart(lhs))
                - 'smooth_relu': SmoothReLU(lhs)
                The name is not validated here; an unknown name is reported by
                ``penalty_expr()``.
            nodes: Optional (start, end) tuple of node indices defining the
                enforcement interval. None means the entire trajectory. Must
                satisfy start < end. NumPy integers are accepted and converted
                to Python integers.
            idx: Optional grouping index for multiple augmented states. If None,
                constraints are auto-grouped by their node intervals.
            check_nodally: If True, also enforce the constraint at discrete
                nodes. Defaults to False.

        Raises:
            TypeError: If constraint is not a Constraint instance.
            ValueError: If nodes is neither None nor a 2-tuple of integers.
            ValueError: If nodes[0] >= nodes[1] (invalid interval).
        """
        if not isinstance(constraint, Constraint):
            raise TypeError("CTCS must wrap a Constraint")

        if nodes is not None:
            if not isinstance(nodes, tuple) or len(nodes) != 2:
                raise ValueError(
                    "CTCS constraints must specify nodes as a tuple of (start, end) or None "
                    "for all nodes"
                )
            if not all(isinstance(n, (int, np.integer)) for n in nodes):
                raise ValueError("CTCS node indices must be integers")
            # Normalise NumPy integers (e.g. from np.arange) to Python ints so
            # the stored interval is the same whichever integer type was given.
            nodes = tuple(int(n) if isinstance(n, np.integer) else n for n in nodes)
            if nodes[0] >= nodes[1]:
                raise ValueError("CTCS node range must have start < end")

        self.constraint = constraint
        self.penalty = penalty
        self.nodes = nodes  # (start, end) node range or None for all nodes
        self.idx = idx  # Optional grouping index for multiple augmented states
        self.check_nodally = check_nodally

    def children(self):
        """Return the wrapped constraint as the only child.

        Returns:
            list: Single-element list containing the wrapped constraint.
        """
        return [self.constraint]

    def canonicalize(self) -> "Expr":
        """Canonicalize the inner constraint while preserving CTCS parameters.

        Returns:
            CTCS: A new CTCS with canonicalized inner constraint and the same
            penalty, nodes, idx and check_nodally settings.
        """
        canon_constraint = self.constraint.canonicalize()
        return CTCS(
            canon_constraint,
            penalty=self.penalty,
            nodes=self.nodes,
            idx=self.idx,
            check_nodally=self.check_nodally,
        )

    def check_shape(self) -> Tuple[int, ...]:
        """Validate the wrapped constraint and the generated penalty expression.

        Returns:
            tuple: Empty tuple () representing scalar shape.

        Raises:
            ValueError: If the penalty expression cannot be built or validated,
                or if its shape is not scalar. Errors raised directly by the
                wrapped constraint's own ``check_shape()`` propagate unchanged.
        """
        self.constraint.check_shape()

        try:
            penalty_expr = self.penalty_expr()
            penalty_shape = penalty_expr.check_shape()

            # penalty_expr() wraps the penalty in Sum, so a scalar is expected.
            if penalty_shape != ():
                raise ValueError(
                    f"CTCS penalty expression should be scalar, but got shape {penalty_shape}"
                )
        except Exception as e:
            # Any failure above, including the non-scalar check, is reported
            # as a ValueError with a CTCS-specific prefix.
            raise ValueError(f"CTCS penalty expression validation failed: {e}") from e

        return ()

    def _hash_into(self, hasher: "hashlib._Hash") -> None:
        """Feed this CTCS, including all its parameters, into a hasher.

        Args:
            hasher: A hashlib hash object to update.
        """
        hasher.update(b"CTCS")
        hasher.update(self.penalty.encode())
        # Interval endpoints as two big-endian 32-bit signed ints, or a marker.
        if self.nodes is not None:
            hasher.update(struct.pack(">ii", self.nodes[0], self.nodes[1]))
        else:
            hasher.update(b"None")
        if self.idx is not None:
            hasher.update(struct.pack(">i", self.idx))
        else:
            hasher.update(b"None")
        hasher.update(b"1" if self.check_nodally else b"0")
        self.constraint._hash_into(hasher)

    def over(self, interval: tuple[int, int]) -> "CTCS":
        """Return a copy of this CTCS constraint with a different interval.

        Args:
            interval: Tuple of (start, end) node indices defining the enforcement interval.

        Returns:
            CTCS: New CTCS constraint with the specified interval; penalty, idx
            and check_nodally are carried over.

        Example:
            Define constraint over range:

                constraint = (altitude >= 10).over((0, 50))

            Update interval to cover a different range:

                constraint_updated = constraint.over((50, 100))
        """
        return CTCS(
            self.constraint,
            penalty=self.penalty,
            nodes=interval,
            idx=self.idx,
            check_nodally=self.check_nodally,
        )

    def __repr__(self):
        """Return a string showing constraint, penalty type, and set options.

        Returns:
            str: ``CTCS(...)`` listing nodes, idx and check_nodally only when
            they differ from None / False.
        """
        parts = [f"{self.constraint!r}", f"penalty={self.penalty!r}"]
        if self.nodes is not None:
            parts.append(f"nodes={self.nodes}")
        if self.idx is not None:
            parts.append(f"idx={self.idx}")
        if self.check_nodally:
            parts.append(f"check_nodally={self.check_nodally}")
        return f"CTCS({', '.join(parts)})"

    def penalty_expr(self) -> Expr:
        """Build the penalty expression for this CTCS constraint.

        The selected penalty function is applied to the wrapped constraint's
        left-hand side and the result is wrapped in ``Sum``. The penalty is
        intended to be zero when the constraint holds and positive when it is
        violated.

        Returns:
            Expr: ``Sum`` of the penalty function applied to the constraint's lhs.

        Raises:
            ValueError: If an unknown penalty type is specified.

        Note:
            Only ``constraint.lhs`` is read; the right-hand side is not
            consulted. The result therefore matches the constraint only when it
            is in canonical form (lhs <= 0), as produced by ``canonicalize()``.
        """
        lhs = self.constraint.lhs

        # The math expression classes are imported lazily inside each branch.
        if self.penalty == "squared_relu":
            from openscvx.symbolic.expr.math import PositivePart, Square

            penalty = Square(PositivePart(lhs))
        elif self.penalty == "huber":
            from openscvx.symbolic.expr.math import Huber, PositivePart

            penalty = Huber(PositivePart(lhs))
        elif self.penalty == "smooth_relu":
            from openscvx.symbolic.expr.math import SmoothReLU

            penalty = SmoothReLU(lhs)
        else:
            raise ValueError(f"Unknown penalty {self.penalty!r}")

        return Sum(penalty)
750
+
751
+
752
def ctcs(
    constraint: Constraint,
    penalty: str = "squared_relu",
    nodes: Optional[Tuple[int, int]] = None,
    idx: Optional[int] = None,
    check_nodally: bool = False,
) -> CTCS:
    """Create a CTCS (Continuous-Time Constraint Satisfaction) constraint.

    Functional-style convenience wrapper: every argument is forwarded unchanged
    to the CTCS constructor.

    Args:
        constraint: The Constraint to enforce continuously.
        penalty: Penalty function type ('squared_relu', 'huber', or 'smooth_relu').
            Defaults to 'squared_relu'.
        nodes: Optional (start, end) tuple of node indices for the enforcement
            interval. None enforces over the entire trajectory.
        idx: Optional grouping index for multiple augmented states.
        check_nodally: Whether to also enforce the constraint at discrete nodes.
            Defaults to False.

    Returns:
        CTCS: A CTCS constraint wrapping the input constraint.

    Example:
        Using the helper function:

            from openscvx.symbolic.expr.constraint import ctcs
            altitude_constraint = ctcs(
                altitude >= 10,
                penalty="huber",
                nodes=(0, 100),
                check_nodally=True,
            )

        Equivalent to using the CTCS constructor:

            altitude_constraint = CTCS(
                altitude >= 10, penalty="huber", nodes=(0, 100), check_nodally=True
            )

        Also equivalent to using the .over() method on the constraint:

            altitude_constraint = (altitude >= 10).over(
                (0, 100), penalty="huber", check_nodally=True
            )
    """
    return CTCS(
        constraint,
        penalty=penalty,
        nodes=nodes,
        idx=idx,
        check_nodally=check_nodally,
    )