openscvx 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

openscvx/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.1'
21
- __version_tuple__ = version_tuple = (0, 1, 1)
20
+ __version__ = version = '0.1.2'
21
+ __version_tuple__ = version_tuple = (0, 1, 2)
File without changes
@@ -0,0 +1,122 @@
1
+ from typing import Callable, List, Tuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ from openscvx.constraints.violation import CTCSViolation
7
+ from openscvx.dynamics import Dynamics
8
+
9
def build_augmented_dynamics(
    dynamics_non_augmented: Dynamics,
    violations: List[CTCSViolation],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Dynamics:
    """Return a new ``Dynamics`` whose state is extended by the CTCS violations.

    The augmented ``f`` evaluates the user dynamics on the true-state and
    true-control slices and appends each violation's output after it. ``A``
    and ``B`` are the matching Jacobian callables, each taking
    ``(x, u, node)``.

    Args:
        dynamics_non_augmented: The user-supplied dynamics (and optional
            Jacobians) acting on the true state/control only.
        violations: One aggregated violation per augmented state group.
        idx_x_true: Slice selecting the true state from the augmented state.
        idx_u_true: Slice selecting the true control from the augmented control.

    Returns:
        A ``Dynamics`` holding the augmented ``f``, ``A`` and ``B``.
    """
    f_aug = get_augmented_dynamics(
        dynamics_non_augmented.f, violations, idx_x_true, idx_u_true
    )
    jac_x, jac_u = get_jacobians(
        f_aug, dynamics_non_augmented, violations, idx_x_true, idx_u_true
    )
    return Dynamics(f=f_aug, A=jac_x, B=jac_u)
26
+
27
+
28
def get_augmented_dynamics(
    dynamics: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    violations: List["CTCSViolation"],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Wrap ``dynamics`` so that the CTCS violations are appended to its output.

    Args:
        dynamics: The true dynamics ``f(x_true, u_true)``.
        violations: Objects exposing ``g(x_true, u_true, node)``. Their outputs
            are stacked, in order, after the output of ``dynamics``.
        idx_x_true: Slice selecting the true state from the augmented state.
        idx_u_true: Slice selecting the true control from the augmented control.

    Returns:
        A callable ``(x, u, node) -> x_dot``, where ``x_dot`` is the
        concatenation of ``dynamics(x_true, u_true)`` and every violation value.
    """

    def dynamics_augmented(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        # Slice once rather than once per violation.
        x_true = x[idx_x_true]
        u_true = u[idx_u_true]

        # Collect all pieces and concatenate once. jnp.hstack promotes scalar
        # violation values to length-1 arrays, so scalar and vector violations
        # can be mixed.
        pieces = [dynamics(x_true, u_true)]
        pieces.extend(v.g(x_true, u_true, node) for v in violations)
        return jnp.hstack(pieces)

    return dynamics_augmented
45
+
46
+
47
def get_jacobians(
    dyn_augmented: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
    dynamics_non_augmented: "Dynamics",
    violations: List["CTCSViolation"],
    idx_x_true: slice,
    idx_u_true: slice,
) -> Tuple[
    Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
    Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray],
]:
    """Build the state and control Jacobians of the augmented dynamics.

    If neither the dynamics nor any violation supplies a custom gradient, the
    Jacobians are forward-mode autodiff of ``dyn_augmented``. Otherwise they
    are assembled block-wise:

    * rows for the true dynamics come from ``dynamics_non_augmented.A``/``.B``,
      or from autodiff of ``dynamics_non_augmented.f`` when absent;
    * one row block per violation, from ``g_grad_x``/``g_grad_u``, or from
      autodiff of ``g`` when absent. A scalar violation contributes one row;
    * ``A`` is padded on the right with ``n_v`` zero columns. ``n_v`` is the
      total number of violation outputs. Neither ``f`` nor ``g`` is evaluated
      on anything but ``x_aug[idx_x_true]``, so those columns are zero.
      NOTE(review): the padding is appended at the end, so this assumes the
      augmented states follow the true states in ``x_aug``.

    Args:
        dyn_augmented: The augmented dynamics ``(x_aug, u_aug, node) -> x_dot``.
        dynamics_non_augmented: The user dynamics with optional ``A``/``B``.
        violations: Violations exposing ``g`` and optional ``g_grad_x``/``g_grad_u``.
        idx_x_true: Slice selecting the true state from the augmented state.
        idx_u_true: Slice selecting the true control from the augmented control.

    Returns:
        ``(A, B)``, each a callable ``(x_aug, u_aug, node) -> jnp.ndarray``.
    """
    # 1) No custom gradients anywhere: autodiff the whole augmented function.
    no_dyn_grads = dynamics_non_augmented.A is None and dynamics_non_augmented.B is None
    no_vio_grads = all(v.g_grad_x is None and v.g_grad_u is None for v in violations)
    if no_dyn_grads and no_vio_grads:
        return (
            jax.jacfwd(dyn_augmented, argnums=0),
            jax.jacfwd(dyn_augmented, argnums=1),
        )

    # 2) Jacobians of the true dynamics f(x_true, u_true): user-provided if
    #    present, otherwise autodiff.
    f_fn = dynamics_non_augmented.f
    A_f = (
        dynamics_non_augmented.A
        if dynamics_non_augmented.A is not None
        else jax.jacfwd(f_fn, argnums=0)
    )
    B_f = (
        dynamics_non_augmented.B
        if dynamics_non_augmented.B is not None
        else jax.jacfwd(f_fn, argnums=1)
    )

    # 3) Per-violation Jacobians. These are built once here, not on every call:
    #    d g / d x (argnums=0) and d g / d u (argnums=1).
    viol_grads_x = [
        v.g_grad_x if v.g_grad_x is not None else jax.jacfwd(v.g, argnums=0)
        for v in violations
    ]
    viol_grads_u = [
        v.g_grad_u if v.g_grad_u is not None else jax.jacfwd(v.g, argnums=1)
        for v in violations
    ]

    # 4) Assemble the augmented Jacobians.
    def A(x_aug, u_aug, node):
        x_true = x_aug[idx_x_true]
        u_true = u_aug[idx_u_true]
        n_x_true = x_true.shape[0]

        # Number of augmented states = total number of violation outputs.
        # atleast_1d lets a scalar violation count as one instead of failing
        # on ``.shape[0]``.
        nv = sum(
            jnp.atleast_1d(v.g(x_true, u_true, node)).shape[0] for v in violations
        )

        rows = [A_f(x_true, u_true)]  # (n_f, n_x_true)
        # Reshape so a scalar violation's gradient (n_x_true,) becomes a
        # (1, n_x_true) row block.
        rows.extend(
            jnp.reshape(grad(x_true, u_true, node), (-1, n_x_true))
            for grad in viol_grads_x
        )
        stacked = jnp.vstack(rows)  # (n_f + n_v, n_x_true)
        return jnp.hstack([stacked, jnp.zeros((stacked.shape[0], nv))])

    def B(x_aug, u_aug, node):
        x_true = x_aug[idx_x_true]
        u_true = u_aug[idx_u_true]
        n_u_true = u_true.shape[0]

        rows = [B_f(x_true, u_true)]  # (n_f, n_u_true)
        rows.extend(
            jnp.reshape(grad(x_true, u_true, node), (-1, n_u_true))
            for grad in viol_grads_u
        )
        return jnp.vstack(rows)  # (n_f + n_v, n_u_true)

    return A, B
@@ -1,6 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import Callable, Sequence, Tuple, Optional
3
3
  import functools
4
+ import types
4
5
 
5
6
  from jax.lax import cond
6
7
  import jax.numpy as jnp
@@ -10,8 +11,18 @@ import jax.numpy as jnp
10
11
  class CTCSConstraint:
11
12
  func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
12
13
  penalty: Callable[[jnp.ndarray], jnp.ndarray]
13
- nodes: Optional[Sequence[Tuple[int, int]]] = None
14
+ nodes: Optional[Tuple[int, int]] = None
14
15
  idx: Optional[int] = None
16
+ grad_f_x: Optional[Callable] = None
17
+ grad_f_u: Optional[Callable] = None
18
+
19
+ def __post_init__(self):
20
+ if self.grad_f_x is not None:
21
+ _grad_f_x = self.grad_f_x
22
+ self.grad_f_x = lambda x, u, nodes: _grad_f_x(x, u)
23
+ if self.grad_f_u is not None:
24
+ _grad_f_u = self.grad_f_u
25
+ self.grad_f_u = lambda x, u, nodes: _grad_f_u(x, u)
15
26
 
16
27
  def __call__(self, x, u, node):
17
28
  return cond(
@@ -28,6 +39,8 @@ def ctcs(
28
39
  penalty: str = "squared_relu",
29
40
  nodes: Optional[Sequence[Tuple[int, int]]] = None,
30
41
  idx: Optional[int] = None,
42
+ grad_f_x: Optional[Callable] = None,
43
+ grad_f_u: Optional[Callable] = None,
31
44
  ):
32
45
  """Decorator to mark a function as a 'ctcs' constraint.
33
46
 
@@ -44,14 +57,25 @@ def ctcs(
44
57
  def pen(x):
45
58
  r = jnp.maximum(0, x)
46
59
  return jnp.where(r < delta, 0.5 * r**2, r - 0.5 * delta)
47
-
60
+ elif penalty == "smooth_relu":
61
+ c = 1e-8
62
+ pen = lambda x: (jnp.maximum(0, x) ** 2 + c**2) ** 0.5 - c
63
+ elif isinstance(penalty, types.LambdaType):
64
+ pen = penalty
48
65
  else:
49
66
  raise ValueError(f"Unknown penalty {penalty}")
50
67
 
51
68
  def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
52
69
  # wrap so name, doc, signature stay on f
53
70
  wrapped = functools.wraps(f)(f)
54
- return CTCSConstraint(func=wrapped, penalty=pen, nodes=nodes, idx=idx)
71
+ return CTCSConstraint(
72
+ func=wrapped,
73
+ penalty=pen,
74
+ nodes=nodes,
75
+ idx=idx,
76
+ grad_f_x=grad_f_x,
77
+ grad_f_u=grad_f_u,
78
+ )
55
79
 
56
80
  # if called as @ctcs or @ctcs(...), _func will be None and we return decorator
57
81
  if _func is None:
@@ -11,20 +11,22 @@ class NodalConstraint:
11
11
  nodes: Optional[List[int]] = None
12
12
  convex: bool = False
13
13
  vectorized: bool = False
14
+ grad_g_x: Optional[Callable] = None
15
+ grad_g_u: Optional[Callable] = None
14
16
 
15
17
  def __post_init__(self):
16
18
  if not self.convex:
17
- # TODO: (haynec) switch to AOT instead of JIT
18
- if self.vectorized:
19
- # single-node but still using JAX
20
- self.g = jit(self.func)
21
- self.grad_g_x = jit(jacfwd(self.func, argnums=0))
22
- self.grad_g_u = jit(jacfwd(self.func, argnums=1))
23
- else:
24
- self.g = vmap(jit(self.func), in_axes=(0, 0))
25
- self.grad_g_x = jit(vmap(jacfwd(self.func, argnums=0), in_axes=(0, 0)))
26
- self.grad_g_u = jit(vmap(jacfwd(self.func, argnums=1), in_axes=(0, 0)))
27
- # if convex=True and inter_nodal=False, assume an external solver (e.g. CVX) will handle it
19
+ # single-node but still using JAX
20
+ self.g = self.func
21
+ if self.grad_g_x is None:
22
+ self.grad_g_x = jacfwd(self.func, argnums=0)
23
+ if self.grad_g_u is None:
24
+ self.grad_g_u = jacfwd(self.func, argnums=1)
25
+ if not self.vectorized:
26
+ self.g = vmap(self.g, in_axes=(0, 0))
27
+ self.grad_g_x = vmap(self.grad_g_x, in_axes=(0, 0))
28
+ self.grad_g_u = vmap(self.grad_g_u, in_axes=(0, 0))
29
+ # if convex=True assume an external solver (e.g. CVX) will handle it
28
30
 
29
31
  def __call__(self, x: jnp.ndarray, u: jnp.ndarray):
30
32
  return self.func(x, u)
@@ -36,6 +38,8 @@ def nodal(
36
38
  nodes: Optional[List[int]] = None,
37
39
  convex: bool = False,
38
40
  vectorized: bool = False,
41
+ grad_g_x: Optional[Callable] = None,
42
+ grad_g_u: Optional[Callable] = None,
39
43
  ):
40
44
  def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
41
45
  return NodalConstraint(
@@ -43,6 +47,8 @@ def nodal(
43
47
  nodes=nodes,
44
48
  convex=convex,
45
49
  vectorized=vectorized,
50
+ grad_g_x=grad_g_x,
51
+ grad_g_u=grad_g_u,
46
52
  )
47
53
 
48
54
  return decorator if _func is None else decorator(_func)
@@ -1,26 +1,67 @@
1
1
  from collections import defaultdict
2
+ from typing import List, Optional, Tuple, Callable
3
+ from dataclasses import dataclass
2
4
 
3
5
  import jax.numpy as jnp
4
6
 
5
- def get_g_func(constraints_ctcs: list[callable, callable]):
7
+ from openscvx.constraints.ctcs import CTCSConstraint
8
+
9
+
10
@dataclass
class CTCSViolation:
    """Aggregated CTCS violation for one group of constraints sharing an ``idx``.

    All callables take ``(x, u, node)``.
    """

    # Summed violation value of every constraint in the group.
    g: Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]
    # Gradient of g w.r.t. x built from user-supplied gradients; None if not supplied.
    g_grad_x: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None
    # Gradient of g w.r.t. u built from user-supplied gradients; None if not supplied.
    g_grad_u: Optional[Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]] = None


def _sum_constraint_grads(
    constraints_ctcs: List["CTCSConstraint"], attr: str
) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Build a callable summing ``getattr(c, attr)(x, u, node)`` over the constraints that define it."""

    def summed(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        grads = [
            getattr(c, attr)(x, u, node)
            for c in constraints_ctcs
            if getattr(c, attr) is not None
        ]
        # None signals "no gradient available". get_g_funcs only builds these
        # when every constraint supplies a gradient, so this is never hit there.
        return sum(grads) if grads else None

    return summed


def get_g_grad_x(constraints_ctcs: List["CTCSConstraint"]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Return ``(x, u, node) -> sum of c.grad_f_x(x, u, node)`` over constraints with a ``grad_f_x``."""
    return _sum_constraint_grads(constraints_ctcs, "grad_f_x")


def get_g_grad_u(constraints_ctcs: List["CTCSConstraint"]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Return ``(x, u, node) -> sum of c.grad_f_u(x, u, node)`` over constraints with a ``grad_f_u``."""
    return _sum_constraint_grads(constraints_ctcs, "grad_f_u")


def get_g_func(constraints_ctcs: List["CTCSConstraint"]) -> Callable[[jnp.ndarray, jnp.ndarray, int], jnp.ndarray]:
    """Return ``(x, u, node) -> sum of c(x, u, node)`` over all given constraints."""

    def g_func(x: jnp.ndarray, u: jnp.ndarray, node: int) -> jnp.ndarray:
        return sum(c(x, u, node) for c in constraints_ctcs)

    return g_func


def get_g_funcs(constraints_ctcs: List["CTCSConstraint"]) -> List[CTCSViolation]:
    """Group CTCS constraints by ``idx`` and build one ``CTCSViolation`` per group.

    Groups are returned in ascending ``idx`` order. A group gets custom
    ``g_grad_x`` / ``g_grad_u`` only if *every* constraint in it supplies the
    corresponding gradient. Otherwise the field is ``None``.

    NOTE(review): ``g`` is the sum of ``c(x, u, node)``, i.e. the penalized,
    node-gated value, while the custom gradients are the plain sum of the
    constraints' ``grad_f_x``/``grad_f_u``. Judging by the name, these are
    gradients of the raw constraint function. Unless callers are expected to
    supply gradients of the penalized value, the chain rule through the
    penalty and gating is missing here. TODO confirm the intended contract.

    Raises:
        ValueError: If any constraint has no ``idx`` assigned.
    """
    groups = defaultdict(list)
    for c in constraints_ctcs:
        if c.idx is None:
            raise ValueError(f"CTCS constraint {c} has no .idx assigned")
        groups[c.idx].append(c)

    violations: List[CTCSViolation] = []
    for _, bucket in sorted(groups.items(), key=lambda kv: kv[0]):
        has_grad_x = all(c.grad_f_x is not None for c in bucket)
        has_grad_u = all(c.grad_f_u is not None for c in bucket)
        violations.append(
            CTCSViolation(
                g=get_g_func(bucket),
                # Each field must use its own builder: x-gradients for
                # g_grad_x and u-gradients for g_grad_u.
                g_grad_x=get_g_grad_x(bucket) if has_grad_x else None,
                g_grad_u=get_g_grad_u(bucket) if has_grad_u else None,
            )
        )

    return violations
@@ -1,6 +1,7 @@
1
1
  import jax.numpy as jnp
2
2
  import numpy as np
3
3
 
4
+ from openscvx.dynamics import Dynamics
4
5
  from openscvx.integrators import solve_ivp_rk45, solve_ivp_diffrax
5
6
 
6
7
 
@@ -150,13 +151,13 @@ def calculate_discretization(
150
151
  return A_bar, B_bar, C_bar, z_bar, Vmulti
151
152
 
152
153
 
153
- def get_discretization_solver(state_dot, A, B, params):
154
+ def get_discretization_solver(dyn: Dynamics, params):
154
155
  return lambda x, u: calculate_discretization(
155
156
  x=x,
156
157
  u=u,
157
- state_dot=state_dot,
158
- A=A,
159
- B=B,
158
+ state_dot=dyn.f,
159
+ A=dyn.A,
160
+ B=dyn.B,
160
161
  n_x=params.sim.n_states,
161
162
  n_u=params.sim.n_controls,
162
163
  N=params.scp.n,
openscvx/dynamics.py CHANGED
@@ -1,24 +1,41 @@
1
- import jax
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Optional
3
+ import functools
4
+
2
5
  import jax.numpy as jnp
3
6
 
4
7
 
5
- def get_augmented_dynamics(
6
- dynamics: callable, g_funcs: list[callable], idx_x_true: slice, idx_u_true: slice
7
- ) -> callable:
8
- def dynamics_augmented(x: jnp.array, u: jnp.array, node: int) -> jnp.array:
9
- x_dot = dynamics(x[idx_x_true], u[idx_u_true])
8
@dataclass
class Dynamics:
    """Container for a dynamics function and, optionally, its Jacobians.

    ``f`` maps ``(x, u)`` to the state derivative. ``A`` and ``B`` hold
    callables for the Jacobians of ``f`` with respect to ``x`` and ``u``.
    ``None`` means no Jacobian was supplied.
    """

    # Dynamics function f(x, u).
    f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
    # Jacobian of f w.r.t. x, or None if not supplied.
    A: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
    # Jacobian of f w.r.t. u, or None if not supplied.
    B: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None
10
13
 
11
- # Iterate through the g_func dictionary and stack the output each function
12
- # to x_dot
13
- for g in g_funcs:
14
- x_dot = jnp.hstack([x_dot, g(x[idx_x_true], u[idx_u_true], node)])
14
def dynamics(
    _func=None,
    *,
    A: Optional[Callable] = None,
    B: Optional[Callable] = None,
):
    """Decorator to mark a function as defining the system dynamics.

    Use as::

        @dynamics
        def my_dynamics(x, u): ...

        @dynamics(A=my_grad_f_x, B=my_grad_f_u)
        def my_dynamics(x, u): ...

    Args:
        _func: The dynamics function ``f(x, u)`` when used as ``@dynamics`` or
            called as ``dynamics(f)``; ``None`` when called with keyword
            arguments only.
        A: Optional Jacobian of ``f`` with respect to ``x``, stored as
            ``Dynamics.A``.
        B: Optional Jacobian of ``f`` with respect to ``u``, stored as
            ``Dynamics.B``.

    Returns:
        A ``Dynamics`` instance if ``_func`` is given, otherwise a decorator
        that produces one.
    """

    def decorator(f: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]):
        # Store ``f`` itself: it already carries its own name, docstring and
        # signature. ``functools.wraps(f)(f)`` would set ``f.__wrapped__ = f``,
        # a self-loop that makes ``inspect.signature(f)`` raise ValueError.
        return Dynamics(f=f, A=A, B=B)

    # ``@dynamics(...)`` -> _func is None, so hand back the decorator;
    # ``@dynamics`` / ``dynamics(f)`` -> decorate immediately.
    return decorator if _func is None else decorator(_func)
20
41
 
21
- def get_jacobians(dyn: callable):
22
- A = jax.jacfwd(dyn, argnums=0)
23
- B = jax.jacfwd(dyn, argnums=1)
24
- return A, B
@@ -1,5 +1,5 @@
1
1
  import jax.numpy as jnp
2
- from typing import List
2
+ from typing import List, Union
3
3
  import queue
4
4
  import threading
5
5
  import time
@@ -17,9 +17,10 @@ from openscvx.config import (
17
17
  DevConfig,
18
18
  Config,
19
19
  )
20
- from openscvx.dynamics import get_augmented_dynamics, get_jacobians
21
- from openscvx.constraints.violation import get_g_funcs
22
- from openscvx.augmentation import sort_ctcs_constraints
20
+ from openscvx.dynamics import Dynamics
21
+ from openscvx.augmentation.dynamics_augmentation import build_augmented_dynamics
22
+ from openscvx.augmentation.ctcs import sort_ctcs_constraints
23
+ from openscvx.constraints.violation import get_g_funcs, CTCSViolation
23
24
  from openscvx.discretization import get_discretization_solver
24
25
  from openscvx.propagation import get_propagation_solver
25
26
  from openscvx.constraints.boundary import BoundaryConstraint
@@ -35,8 +36,8 @@ from openscvx import io
35
36
  class TrajOptProblem:
36
37
  def __init__(
37
38
  self,
38
- dynamics: callable,
39
- constraints: List[callable],
39
+ dynamics: Dynamics,
40
+ constraints: List[Union[CTCSConstraint, NodalConstraint]],
40
41
  idx_time: int,
41
42
  N: int,
42
43
  time_init: float,
@@ -63,13 +64,12 @@ class TrajOptProblem:
63
64
  # TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
64
65
  constraints_ctcs = []
65
66
  constraints_nodal = []
66
- # TODO: (norrisg) change back to using isinstance once on PyPi
67
67
  for constraint in constraints:
68
- if type(constraint).__name__ == CTCSConstraint.__name__:
68
+ if isinstance(constraint, CTCSConstraint):
69
69
  constraints_ctcs.append(
70
70
  constraint
71
71
  )
72
- elif type(constraint).__name__ == NodalConstraint.__name__:
72
+ elif isinstance(constraint, NodalConstraint):
73
73
  constraints_nodal.append(
74
74
  constraint
75
75
  )
@@ -160,9 +160,8 @@ class TrajOptProblem:
160
160
  sim.constraints_ctcs = constraints_ctcs
161
161
  sim.constraints_nodal = constraints_nodal
162
162
 
163
- g_funcs = get_g_funcs(constraints_ctcs)
164
- self.dynamics_augmented = get_augmented_dynamics(dynamics, g_funcs, idx_x_true, idx_u_true)
165
- self.A_uncompiled, self.B_uncompiled = get_jacobians(self.dynamics_augmented)
163
+ ctcs_violation_funcs = get_g_funcs(constraints_ctcs)
164
+ self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
166
165
 
167
166
  self.params = Config(
168
167
  sim=sim,
@@ -212,18 +211,20 @@ class TrajOptProblem:
212
211
  self.params.sim.__post_init__()
213
212
 
214
213
  # Compile dynamics and jacobians
215
- self.state_dot = jax.vmap(self.dynamics_augmented)
216
- self.A = jax.jit(jax.vmap(self.A_uncompiled, in_axes=(0, 0, 0)))
217
- self.B = jax.jit(jax.vmap(self.B_uncompiled, in_axes=(0, 0, 0)))
218
- # TODO: (norrisg) Could consider using dataclass just to hold dynamics and jacobians
219
- # TODO: (norrisg) Consider writing the compiled versions into the same variables?
220
- # Otherwise if have a dataclass could have 2 instances, one for compied and one for uncompiled
214
+ self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f)
215
+ self.dynamics_augmented.A = jax.jit(jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0)))
216
+ self.dynamics_augmented.B = jax.jit(jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0)))
217
+
218
+ for constraint in self.params.sim.constraints_nodal:
219
+ if not constraint.convex:
220
+ # TODO: (haynec) switch to AOT instead of JIT
221
+ constraint.g = jax.jit(constraint.g)
222
+ constraint.grad_g_x = jax.jit(constraint.grad_g_x)
223
+ constraint.grad_g_u = jax.jit(constraint.grad_g_u)
221
224
 
222
225
  # Generate solvers and optimal control problem
223
- self.discretization_solver = get_discretization_solver(
224
- self.state_dot, self.A, self.B, self.params
225
- )
226
- self.propagation_solver = get_propagation_solver(self.state_dot, self.params)
226
+ self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
227
+ self.propagation_solver = get_propagation_solver(self.dynamics_augmented.f, self.params)
227
228
  self.optimal_control_problem = OptimalControlProblem(self.params)
228
229
 
229
230
  # Initialize the PTR loop
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openscvx
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: A general Python-based successive convexification implementation which uses a JAX backend.
5
5
  Home-page: https://haynec.github.io/openscvx/
6
6
  Author: Chris Hayner and Griffin Norris
@@ -26,10 +26,15 @@ Dynamic: license-file
26
26
 
27
27
  <img src="figures/openscvx_logo.svg" width="1200"/>
28
28
  <p align="center">
29
- <a href="https://github.com//haynec/OpenSCvx/actions/workflows/main.yml/badge.svg"><img src="https://github.com//haynec/OpenSCvx/actions/workflows/main.yml/badge.svg"/></a>
29
+ <a href="https://github.com//haynec/OpenSCvx/actions/workflows/website.yml/badge.svg"><img src="https://github.com//haynec/OpenSCvx/actions/workflows/website.yml/badge.svg"/></a>
30
30
  <a href="https://arxiv.org/abs/2410.22596"><img src="http://img.shields.io/badge/arXiv-2410.22596-B31B1B.svg"/></a>
31
31
  <a href="https://www.apache.org/licenses/LICENSE-2.0"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License: Apache 2.0"/></a>
32
32
  </p>
33
+ <p align="center">
34
+ <a href="https://github.com//haynec/OpenSCvx/actions/workflows/ci.yml/badge.svg"><img src="https://github.com//haynec/OpenSCvx/actions/workflows/ci.yml/badge.svg"/></a>
35
+ <a href="https://github.com//haynec/OpenSCvx/actions/workflows/nightly.yml/badge.svg"><img src="https://github.com//haynec/OpenSCvx/actions/workflows/nightly.yml/badge.svg"/></a>
36
+ <a href="https://github.com//haynec/OpenSCvx/actions/workflows/release.yml/badge.svg"><img src="https://github.com//haynec/OpenSCvx/actions/workflows/release.yml/badge.svg"/></a>
37
+ </p>
33
38
 
34
39
  <!-- PROJECT LOGO -->
35
40
  <br />
@@ -37,20 +42,50 @@ Dynamic: license-file
37
42
  <!-- GETTING STARTED -->
38
43
  ## Getting Started
39
44
 
40
-
41
45
  ### Installation
46
+
47
+ To grab the latest stable release, simply run
48
+
49
+ ```sh
50
+ pip install openscvx
51
+ ```
52
+
53
+ to install OpenSCvx in your Python environment.
54
+
55
+ <details>
56
+ <summary>Install Development / Nightly Version</summary>
57
+
58
+ If you want the pre-release version, you can install the latest `nightly` build with:
59
+
60
+ ```sh
61
+ python3 -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ --pre --upgrade openscvx
62
+ ```
63
+
64
+ This command will also upgrade an existing `nightly` install to the latest version.
65
+ Or, if you want a specific pre-release version, install it with:
66
+
67
+ ```sh
68
+ python3 -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ openscvx==1.2.3.dev45
69
+ ```
70
+
71
+ where `1.2.3.dev45 => <major>.<minor>.<patch>.dev<XY>` corresponds to the exact version you want.
72
+ </details>
73
+
74
+
75
+ #### Dependencies
76
+
42
77
  The main packages are:
43
- - ```cvxpy``` - is used to formulate and solve the convex subproblems
44
- - ```jax``` - is used for determining the Jacobians using automatic differentiation, vectorization, and ahead-of-time (AOT) compilation of the dynamics and their Jacobians
45
- - ```numpy``` - is used for numerical operations
46
- - ```diffrax``` - is used for the numerical integration of the dynamics
47
- - ```termcolor``` - is used for pretty command line output
48
- - ```plotly``` - is used for all visualizations
49
78
 
79
+ - `cvxpy` - is used to formulate and solve the convex subproblems
80
+ - `jax` - is used for determining the Jacobians using automatic differentiation, vectorization, and ahead-of-time (AOT) compilation of the dynamics and their Jacobians
81
+ - `numpy` - is used for numerical operations
82
+ - `diffrax` - is used for the numerical integration of the dynamics
83
+ - `termcolor` - is used for pretty command line output
84
+ - `plotly` - is used for all visualizations
50
85
 
51
- These can be installed via conda or pip.
86
+ These will be installed automatically, but can be installed via conda or pip if you are building from source.
52
87
  <details>
53
- <summary>Via Conda (Recommended) </summary>
88
+ <summary>Via Conda</summary>
54
89
 
55
90
  1. Clone the repo
56
91
  ```sh
@@ -87,11 +122,11 @@ These can be installed via conda or pip.
87
122
  See `examples/` folder for several example trajectory optimization problems.
88
123
  To run a problem simply run `examples/main.py` with:
89
124
 
90
- ```bash
91
- python3 -m examples.main
125
+ ```sh
126
+ python3 examples/main.py
92
127
  ```
93
128
 
94
- To change which example is run by `main` simply replace the `params` import line:
129
+ To change which example is run by `main` simply replace the `problem` import line:
95
130
 
96
131
  ```python
97
132
  # other imports
@@ -99,10 +134,11 @@ from examples.params.dr_vp import problem
99
134
  # rest of code
100
135
  ```
101
136
 
137
+ and adjust the plotting as needed.
102
138
  Check out the problem definitions inside `examples/params` to see how to define your own problems.
103
139
 
104
-
105
140
  ## ToDos
141
+
106
142
  - [X] Standardized Vehicle and Constraint classes
107
143
  - [X] Implement QOCOGen with CVPYGEN
108
144
  - [X] Non-Dilated Time Propagation
@@ -110,8 +146,12 @@ Check out the problem definitions inside `examples/params` to see how to define
110
146
  - [ ] Compiled at the subproblem level with JAX
111
147
  - [ ] Save and reload the compiled JAX code
112
148
  - [ ] Single Shot propagation
149
+ - [ ] Unified Mathematical Interface
150
+
113
151
  ## What is implemented
152
+
114
153
  This repo has the following features:
154
+
115
155
  1. Free Final Time
116
156
  2. Fully adaptive time dilation (```s``` is appended to the control vector)
117
157
  3. Continuous-Time Constraint Satisfaction
@@ -122,11 +162,14 @@ This repo has the following features:
122
162
  <p align="right">(<a href="#readme-top">back to top</a>)</p>
123
163
 
124
164
  ## Acknowledgements
165
+
125
166
  This work was supported by a NASA Space Technology Graduate Research Opportunity and the Office of Naval Research under grant N00014-17-1-2433. The authors would like to acknowledge Natalia Pavlasek, Samuel Buckner, Abhi Kamath, Govind Chari, and Purnanand Elango as well as the other Autonomous Controls Laboratory members, for their many helpful discussions and support throughout this work.
126
167
 
127
168
  ## Citation
169
+
128
170
  Please cite the following works if you use the repository:
129
- ```
171
+
172
+ ```tex
130
173
  @ARTICLE{hayner2025los,
131
174
  author={Hayner, Christopher R. and Carson III, John M. and Açıkmeşe, Behçet and Leung, Karen},
132
175
  journal={IEEE Robotics and Automation Letters},
@@ -139,7 +182,7 @@ Please cite the following works if you use the repository,
139
182
  doi={10.1109/LRA.2025.3545299}}
140
183
  ```
141
184
 
142
- ```
185
+ ```tex
143
186
  @misc{elango2024ctscvx,
144
187
  title={Successive Convexification for Trajectory Optimization with Continuous-Time Constraint Satisfaction},
145
188
  author={Purnanand Elango and Dayou Luo and Abhinav G. Kamath and Samet Uzun and Taewan Kim and Behçet Açıkmeşe},
@@ -0,0 +1,27 @@
1
+ openscvx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ openscvx/_version.py,sha256=bSmADqydH8nBu-J4lG8UVuR7hnU_zcwhnSav2oQ0W0A,511
3
+ openscvx/config.py,sha256=8Cl5O0ekf9MGNDTEeMHsp1C4XvY9NfJQkxd80lvnafM,10296
4
+ openscvx/discretization.py,sha256=YF3mEeyYHgyTWQVNQsqpi1Mv72zDLyNfaMJSWqxj34c,4745
5
+ openscvx/dynamics.py,sha256=X9sPpxUGGbdsnvQzgyrb_939N9ctBSsWVyI1eXtOKpc,1118
6
+ openscvx/integrators.py,sha256=msIS-1Ehj-9TJLHfoCMs3vdyZ8NXz-TM0RII6aqRf4E,3821
7
+ openscvx/io.py,sha256=fOvNWQWAegcN1gejeToaNbXenP5H5bAifNU8edJvdk4,4127
8
+ openscvx/ocp.py,sha256=L_509EQiMsI6s5gBYlYyxKaHEzzRdpo-XAMjliCU3Rc,7544
9
+ openscvx/plotting.py,sha256=fCvWJV4qWMhVyJlh18s12S_5xhj6EviF-_FuP0tWjx4,31207
10
+ openscvx/post_processing.py,sha256=TP1gi4TVlDS2HHpdqaIPCqfM5o4w7a7RCMU3Pu3czHw,1024
11
+ openscvx/propagation.py,sha256=XNezQnAM-NXb9L7aHUgKQOBn0CNUPeGGDL3_BbGoODU,3758
12
+ openscvx/ptr.py,sha256=itDTR6RQUphnU226jaeRaAKuia-6v8U3MqAdw5-BYOk,5268
13
+ openscvx/trajoptproblem.py,sha256=3yufy-egU7m0NV834TH8csY1HJqM90Is7VYw0gQe3pk,11996
14
+ openscvx/utils.py,sha256=zmkKyto8Jowe_RAdOe8K0w6gzOu4JfxmX1RUL-3OFlY,2408
15
+ openscvx/augmentation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ openscvx/augmentation/ctcs.py,sha256=m1jdALXSqHq3WD6lCBAUI7FR0Sfs8aCYr66h0EwE4z4,1707
17
+ openscvx/augmentation/dynamics_augmentation.py,sha256=7PL-mMfSmIIfMiXjdbXTolxOgUFSolzKpLu8WAmD384,4271
18
+ openscvx/constraints/__init__.py,sha256=OOUcYEVoDWOSY50s2TbjpDjl3dRR3U04gRxmOyjbddY,258
19
+ openscvx/constraints/boundary.py,sha256=yEhEnkKJ5f8NUeTksigEJjgBeE_YyuG_PJb_DWxg0L4,1541
20
+ openscvx/constraints/ctcs.py,sha256=V763033aV82nAu7y4653KsAs11A7RpUysR_oUcnLfko,2572
21
+ openscvx/constraints/nodal.py,sha256=YCS0cwUurA2OTQcHBb1EQqLxNt_w3MX8Nj8FH3GYClo,1726
22
+ openscvx/constraints/violation.py,sha256=aIdDhHd-UndT0XB2QeuwLBKSNSAUWVkha_GeHOw9cQg,2362
23
+ openscvx-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
24
+ openscvx-0.1.2.dist-info/METADATA,sha256=MDHeKrpE_3FKRiQD5fVKVzNBWerOvcY0vfapGSRTlbk,6911
25
+ openscvx-0.1.2.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
26
+ openscvx-0.1.2.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
27
+ openscvx-0.1.2.dist-info/RECORD,,
@@ -1,25 +0,0 @@
1
- openscvx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- openscvx/_version.py,sha256=Mmxse1R0ki5tjz9qzU8AQyqUsLt8nTyCAbYQp8R87PU,511
3
- openscvx/augmentation.py,sha256=m1jdALXSqHq3WD6lCBAUI7FR0Sfs8aCYr66h0EwE4z4,1707
4
- openscvx/config.py,sha256=8Cl5O0ekf9MGNDTEeMHsp1C4XvY9NfJQkxd80lvnafM,10296
5
- openscvx/discretization.py,sha256=xqR6QPfp--9Nh0sjMfOkTc4OxOxeyHdZq7ip7ATzMNo,4704
6
- openscvx/dynamics.py,sha256=1GrjV75qxHa4bx2UK_Idul6qZa-IFjMUrqsne_gdV3E,684
7
- openscvx/integrators.py,sha256=msIS-1Ehj-9TJLHfoCMs3vdyZ8NXz-TM0RII6aqRf4E,3821
8
- openscvx/io.py,sha256=fOvNWQWAegcN1gejeToaNbXenP5H5bAifNU8edJvdk4,4127
9
- openscvx/ocp.py,sha256=L_509EQiMsI6s5gBYlYyxKaHEzzRdpo-XAMjliCU3Rc,7544
10
- openscvx/plotting.py,sha256=fCvWJV4qWMhVyJlh18s12S_5xhj6EviF-_FuP0tWjx4,31207
11
- openscvx/post_processing.py,sha256=TP1gi4TVlDS2HHpdqaIPCqfM5o4w7a7RCMU3Pu3czHw,1024
12
- openscvx/propagation.py,sha256=XNezQnAM-NXb9L7aHUgKQOBn0CNUPeGGDL3_BbGoODU,3758
13
- openscvx/ptr.py,sha256=itDTR6RQUphnU226jaeRaAKuia-6v8U3MqAdw5-BYOk,5268
14
- openscvx/trajoptproblem.py,sha256=fq68viMiS1UOcwquyKmWxcYMBQT3NvG3xdH-LHollHQ,11932
15
- openscvx/utils.py,sha256=zmkKyto8Jowe_RAdOe8K0w6gzOu4JfxmX1RUL-3OFlY,2408
16
- openscvx/constraints/__init__.py,sha256=OOUcYEVoDWOSY50s2TbjpDjl3dRR3U04gRxmOyjbddY,258
17
- openscvx/constraints/boundary.py,sha256=yEhEnkKJ5f8NUeTksigEJjgBeE_YyuG_PJb_DWxg0L4,1541
18
- openscvx/constraints/ctcs.py,sha256=05epAuo_mNm1AieNB6FWatkv0wOT1ebD4FdngIROljY,1788
19
- openscvx/constraints/nodal.py,sha256=a3CRI7sYBNoOk2wZz9n7nyUuQUjzAGIjRmuHlgBSidk,1592
20
- openscvx/constraints/violation.py,sha256=wKLNhInHoXXannf2J_nLtvm3dWOMZrhJy3mJLG4CTX0,809
21
- openscvx-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
22
- openscvx-0.1.1.dist-info/METADATA,sha256=DiNKvRx7k9Z7mjnKgXvm5ClBgdIdBd66ub4kDPVy5q4,5384
23
- openscvx-0.1.1.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
24
- openscvx-0.1.1.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
25
- openscvx-0.1.1.dist-info/RECORD,,
File without changes