openscvx 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/_version.py +2 -2
- openscvx/config.py +7 -1
- openscvx/post_processing.py +17 -2
- openscvx/trajoptproblem.py +36 -12
- {openscvx-0.1.2.dist-info → openscvx-0.1.3.dist-info}/METADATA +1 -1
- {openscvx-0.1.2.dist-info → openscvx-0.1.3.dist-info}/RECORD +9 -9
- {openscvx-0.1.2.dist-info → openscvx-0.1.3.dist-info}/WHEEL +1 -1
- {openscvx-0.1.2.dist-info → openscvx-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.2.dist-info → openscvx-0.1.3.dist-info}/top_level.txt +0 -0
openscvx/_version.py
CHANGED
openscvx/config.py
CHANGED
|
@@ -2,6 +2,8 @@ import numpy as np
|
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
3
|
from typing import Dict, List
|
|
4
4
|
|
|
5
|
+
from openscvx.constraints.boundary import BoundaryConstraint
|
|
6
|
+
|
|
5
7
|
|
|
6
8
|
def get_affine_scaling_matrices(n, minimum, maximum):
|
|
7
9
|
S = np.diag(np.maximum(np.ones(n), abs(minimum - maximum) / 2))
|
|
@@ -120,7 +122,8 @@ class PropagationConfig:
|
|
|
120
122
|
class SimConfig:
|
|
121
123
|
x_bar: np.ndarray
|
|
122
124
|
u_bar: np.ndarray
|
|
123
|
-
initial_state:
|
|
125
|
+
initial_state: BoundaryConstraint
|
|
126
|
+
initial_state_prop: BoundaryConstraint
|
|
124
127
|
final_state: np.ndarray
|
|
125
128
|
max_state: np.ndarray
|
|
126
129
|
min_state: np.ndarray
|
|
@@ -128,9 +131,11 @@ class SimConfig:
|
|
|
128
131
|
min_control: np.ndarray
|
|
129
132
|
total_time: float
|
|
130
133
|
idx_x_true: slice
|
|
134
|
+
idx_x_true_prop: slice
|
|
131
135
|
idx_u_true: slice
|
|
132
136
|
idx_t: slice
|
|
133
137
|
idx_y: slice
|
|
138
|
+
idx_y_prop: slice
|
|
134
139
|
idx_s: slice
|
|
135
140
|
ctcs_node_intervals: list = None
|
|
136
141
|
constraints_ctcs: List[callable] = field(
|
|
@@ -138,6 +143,7 @@ class SimConfig:
|
|
|
138
143
|
) # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
|
|
139
144
|
constraints_nodal: List[callable] = field(default_factory=list)
|
|
140
145
|
n_states: int = None
|
|
146
|
+
n_states_prop: int = None
|
|
141
147
|
n_controls: int = None
|
|
142
148
|
S_x: np.ndarray = None
|
|
143
149
|
inv_S_x: np.ndarray = None
|
openscvx/post_processing.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
+
import jax.numpy as jnp
|
|
2
3
|
|
|
3
4
|
from openscvx.propagation import s_to_t, t_to_tau, simulate_nonlinear_time
|
|
4
5
|
from openscvx.config import Config
|
|
@@ -14,9 +15,13 @@ def propagate_trajectory_results(params: Config, result: dict, propagation_solve
|
|
|
14
15
|
|
|
15
16
|
tau_vals, u_full = t_to_tau(u, t_full, u, t, params)
|
|
16
17
|
|
|
17
|
-
|
|
18
|
+
# Match free values from initial state to the initial value from the result
|
|
19
|
+
mask = jnp.array([t == "Free" for t in params.sim.initial_state_prop.types], dtype=bool)
|
|
20
|
+
params.sim.initial_state_prop.value = jnp.where(mask, x[0], params.sim.initial_state_prop.value)
|
|
18
21
|
|
|
19
|
-
|
|
22
|
+
x_full = simulate_nonlinear_time(params.sim.initial_state_prop.value, u, tau_vals, t, params, propagation_solver)
|
|
23
|
+
|
|
24
|
+
print("Total CTCS Constraint Violation:", x_full[-1, params.sim.idx_y_prop])
|
|
20
25
|
i = 0
|
|
21
26
|
cost = np.zeros_like(x[-1, i])
|
|
22
27
|
for type in params.sim.initial_state.type:
|
|
@@ -28,6 +33,16 @@ def propagate_trajectory_results(params: Config, result: dict, propagation_solve
|
|
|
28
33
|
if type == "Minimize":
|
|
29
34
|
cost += x[-1, i]
|
|
30
35
|
i += 1
|
|
36
|
+
i=0
|
|
37
|
+
for type in params.sim.initial_state.type:
|
|
38
|
+
if type == "Maximize":
|
|
39
|
+
cost -= x[0, i]
|
|
40
|
+
i += 1
|
|
41
|
+
i = 0
|
|
42
|
+
for type in params.sim.final_state.type:
|
|
43
|
+
if type == "Maximize":
|
|
44
|
+
cost -= x[-1, i]
|
|
45
|
+
i += 1
|
|
31
46
|
print("Cost: ", cost)
|
|
32
47
|
|
|
33
48
|
more_result = dict(t_full=t_full, x_full=x_full, u_full=u_full)
|
openscvx/trajoptproblem.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
-
from typing import List, Union
|
|
2
|
+
from typing import List, Union, Optional
|
|
3
3
|
import queue
|
|
4
4
|
import threading
|
|
5
5
|
import time
|
|
@@ -23,7 +23,7 @@ from openscvx.augmentation.ctcs import sort_ctcs_constraints
|
|
|
23
23
|
from openscvx.constraints.violation import get_g_funcs, CTCSViolation
|
|
24
24
|
from openscvx.discretization import get_discretization_solver
|
|
25
25
|
from openscvx.propagation import get_propagation_solver
|
|
26
|
-
from openscvx.constraints.boundary import BoundaryConstraint
|
|
26
|
+
from openscvx.constraints.boundary import BoundaryConstraint, boundary
|
|
27
27
|
from openscvx.constraints.ctcs import CTCSConstraint
|
|
28
28
|
from openscvx.constraints.nodal import NodalConstraint
|
|
29
29
|
from openscvx.ptr import PTR_init, PTR_main
|
|
@@ -49,17 +49,24 @@ class TrajOptProblem:
|
|
|
49
49
|
x_min: jnp.ndarray,
|
|
50
50
|
u_max: jnp.ndarray,
|
|
51
51
|
u_min: jnp.ndarray,
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
52
|
+
dynamics_prop: callable = None,
|
|
53
|
+
initial_state_prop: BoundaryConstraint = None,
|
|
54
|
+
scp: Optional[ScpConfig] = None,
|
|
55
|
+
dis: Optional[DiscretizationConfig] = None,
|
|
56
|
+
prp: Optional[PropagationConfig] = None,
|
|
57
|
+
sim: Optional[SimConfig] = None,
|
|
58
|
+
dev: Optional[DevConfig] = None,
|
|
59
|
+
cvx: Optional[ConvexSolverConfig] = None,
|
|
58
60
|
licq_min=0.0,
|
|
59
61
|
licq_max=1e-4,
|
|
60
62
|
time_dilation_factor_min=0.3,
|
|
61
63
|
time_dilation_factor_max=3.0,
|
|
62
64
|
):
|
|
65
|
+
if dynamics_prop is None:
|
|
66
|
+
dynamics_prop = dynamics
|
|
67
|
+
|
|
68
|
+
if initial_state_prop is None:
|
|
69
|
+
initial_state_prop = initial_state
|
|
63
70
|
|
|
64
71
|
# TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
|
|
65
72
|
constraints_ctcs = []
|
|
@@ -81,11 +88,15 @@ class TrajOptProblem:
|
|
|
81
88
|
constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)
|
|
82
89
|
|
|
83
90
|
# Index tracking
|
|
84
|
-
idx_x_true = slice(0, len(
|
|
91
|
+
idx_x_true = slice(0, len(initial_state.value))
|
|
92
|
+
idx_x_true_prop = slice(0, len(initial_state_prop.value))
|
|
85
93
|
idx_u_true = slice(0, len(u_max))
|
|
86
94
|
idx_constraint_violation = slice(
|
|
87
95
|
idx_x_true.stop, idx_x_true.stop + num_augmented_states
|
|
88
96
|
)
|
|
97
|
+
idx_constraint_violation_prop = slice(
|
|
98
|
+
idx_x_true_prop.stop, idx_x_true_prop.stop + num_augmented_states
|
|
99
|
+
)
|
|
89
100
|
|
|
90
101
|
idx_time_dilation = slice(idx_u_true.stop, idx_u_true.stop + 1)
|
|
91
102
|
|
|
@@ -106,6 +117,11 @@ class TrajOptProblem:
|
|
|
106
117
|
[u_guess, np.full((u_guess.shape[0], 1), time_init)]
|
|
107
118
|
)
|
|
108
119
|
|
|
120
|
+
initial_state_prop_values = np.hstack([initial_state_prop.value, np.repeat(licq_min, num_augmented_states)])
|
|
121
|
+
initial_state_prop_types = np.hstack([initial_state_prop.type, ["Fix"] * num_augmented_states])
|
|
122
|
+
initial_state_prop = boundary(initial_state_prop_values)
|
|
123
|
+
initial_state_prop.types = initial_state_prop_types
|
|
124
|
+
|
|
109
125
|
if dis is None:
|
|
110
126
|
dis = DiscretizationConfig()
|
|
111
127
|
|
|
@@ -114,17 +130,21 @@ class TrajOptProblem:
|
|
|
114
130
|
x_bar=x_bar_augmented,
|
|
115
131
|
u_bar=u_bar_augmented,
|
|
116
132
|
initial_state=initial_state,
|
|
133
|
+
initial_state_prop=initial_state_prop,
|
|
117
134
|
final_state=final_state,
|
|
118
135
|
max_state=x_max_augmented,
|
|
119
136
|
min_state=x_min_augmented,
|
|
120
137
|
max_control=u_max_augmented,
|
|
121
138
|
min_control=u_min_augmented,
|
|
122
139
|
total_time=time_init,
|
|
123
|
-
n_states=len(
|
|
140
|
+
n_states=len(initial_state.value),
|
|
141
|
+
n_states_prop=len(initial_state_prop.value),
|
|
124
142
|
idx_x_true=idx_x_true,
|
|
143
|
+
idx_x_true_prop=idx_x_true_prop,
|
|
125
144
|
idx_u_true=idx_u_true,
|
|
126
145
|
idx_t=idx_time,
|
|
127
146
|
idx_y=idx_constraint_violation,
|
|
147
|
+
idx_y_prop=idx_constraint_violation_prop,
|
|
128
148
|
idx_s=idx_time_dilation,
|
|
129
149
|
ctcs_node_intervals=node_intervals,
|
|
130
150
|
)
|
|
@@ -162,6 +182,7 @@ class TrajOptProblem:
|
|
|
162
182
|
|
|
163
183
|
ctcs_violation_funcs = get_g_funcs(constraints_ctcs)
|
|
164
184
|
self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
|
|
185
|
+
self.dynamics_augmented_prop = build_augmented_dynamics(dynamics_prop, ctcs_violation_funcs, idx_x_true_prop, idx_u_true)
|
|
165
186
|
|
|
166
187
|
self.params = Config(
|
|
167
188
|
sim=sim,
|
|
@@ -215,6 +236,9 @@ class TrajOptProblem:
|
|
|
215
236
|
self.dynamics_augmented.A = jax.jit(jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0)))
|
|
216
237
|
self.dynamics_augmented.B = jax.jit(jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0)))
|
|
217
238
|
|
|
239
|
+
|
|
240
|
+
self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f)
|
|
241
|
+
|
|
218
242
|
for constraint in self.params.sim.constraints_nodal:
|
|
219
243
|
if not constraint.convex:
|
|
220
244
|
# TODO: (haynec) switch to AOT instead of JIT
|
|
@@ -224,7 +248,7 @@ class TrajOptProblem:
|
|
|
224
248
|
|
|
225
249
|
# Generate solvers and optimal control problem
|
|
226
250
|
self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
|
|
227
|
-
self.propagation_solver = get_propagation_solver(self.
|
|
251
|
+
self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.params)
|
|
228
252
|
self.optimal_control_problem = OptimalControlProblem(self.params)
|
|
229
253
|
|
|
230
254
|
# Initialize the PTR loop
|
|
@@ -248,7 +272,7 @@ class TrajOptProblem:
|
|
|
248
272
|
self.propagation_solver = (
|
|
249
273
|
jax.jit(self.propagation_solver)
|
|
250
274
|
.lower(
|
|
251
|
-
np.ones((self.params.sim.
|
|
275
|
+
np.ones((self.params.sim.n_states_prop)),
|
|
252
276
|
(0.0, 0.0),
|
|
253
277
|
np.ones((1, self.params.sim.n_controls)),
|
|
254
278
|
np.ones((1, self.params.sim.n_controls)),
|
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
openscvx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
openscvx/_version.py,sha256=
|
|
3
|
-
openscvx/config.py,sha256=
|
|
2
|
+
openscvx/_version.py,sha256=NIzzV8ZM0W-CSLuEs1weG4zPrn_-8yr1AwwI1iuS6yo,511
|
|
3
|
+
openscvx/config.py,sha256=lTdvJtU-Xiwq55zGtW3QvmldZj9Ne0xyKsSDSIKwqiI,10488
|
|
4
4
|
openscvx/discretization.py,sha256=YF3mEeyYHgyTWQVNQsqpi1Mv72zDLyNfaMJSWqxj34c,4745
|
|
5
5
|
openscvx/dynamics.py,sha256=X9sPpxUGGbdsnvQzgyrb_939N9ctBSsWVyI1eXtOKpc,1118
|
|
6
6
|
openscvx/integrators.py,sha256=msIS-1Ehj-9TJLHfoCMs3vdyZ8NXz-TM0RII6aqRf4E,3821
|
|
7
7
|
openscvx/io.py,sha256=fOvNWQWAegcN1gejeToaNbXenP5H5bAifNU8edJvdk4,4127
|
|
8
8
|
openscvx/ocp.py,sha256=L_509EQiMsI6s5gBYlYyxKaHEzzRdpo-XAMjliCU3Rc,7544
|
|
9
9
|
openscvx/plotting.py,sha256=fCvWJV4qWMhVyJlh18s12S_5xhj6EviF-_FuP0tWjx4,31207
|
|
10
|
-
openscvx/post_processing.py,sha256=
|
|
10
|
+
openscvx/post_processing.py,sha256=t3fUeDfA2PZq8S-WVnxDyN0xIgFuvpDP-5wbGQuIOVY,1618
|
|
11
11
|
openscvx/propagation.py,sha256=XNezQnAM-NXb9L7aHUgKQOBn0CNUPeGGDL3_BbGoODU,3758
|
|
12
12
|
openscvx/ptr.py,sha256=itDTR6RQUphnU226jaeRaAKuia-6v8U3MqAdw5-BYOk,5268
|
|
13
|
-
openscvx/trajoptproblem.py,sha256=
|
|
13
|
+
openscvx/trajoptproblem.py,sha256=hcju3UB3iZ5peqhxh6NvDQYnsArnabR0mGUAKZ3sIaw,13360
|
|
14
14
|
openscvx/utils.py,sha256=zmkKyto8Jowe_RAdOe8K0w6gzOu4JfxmX1RUL-3OFlY,2408
|
|
15
15
|
openscvx/augmentation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
openscvx/augmentation/ctcs.py,sha256=m1jdALXSqHq3WD6lCBAUI7FR0Sfs8aCYr66h0EwE4z4,1707
|
|
@@ -20,8 +20,8 @@ openscvx/constraints/boundary.py,sha256=yEhEnkKJ5f8NUeTksigEJjgBeE_YyuG_PJb_DWxg
|
|
|
20
20
|
openscvx/constraints/ctcs.py,sha256=V763033aV82nAu7y4653KsAs11A7RpUysR_oUcnLfko,2572
|
|
21
21
|
openscvx/constraints/nodal.py,sha256=YCS0cwUurA2OTQcHBb1EQqLxNt_w3MX8Nj8FH3GYClo,1726
|
|
22
22
|
openscvx/constraints/violation.py,sha256=aIdDhHd-UndT0XB2QeuwLBKSNSAUWVkha_GeHOw9cQg,2362
|
|
23
|
-
openscvx-0.1.
|
|
24
|
-
openscvx-0.1.
|
|
25
|
-
openscvx-0.1.
|
|
26
|
-
openscvx-0.1.
|
|
27
|
-
openscvx-0.1.
|
|
23
|
+
openscvx-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
24
|
+
openscvx-0.1.3.dist-info/METADATA,sha256=CVr7tDDV75qDlnuEuxf-IL0Arxkb1RXwv0Xj9oVxf_s,6911
|
|
25
|
+
openscvx-0.1.3.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
|
|
26
|
+
openscvx-0.1.3.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
|
|
27
|
+
openscvx-0.1.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|