openscvx 0.1.3__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic; see the package's page on the registry for more details.
- openscvx/_version.py +2 -2
- openscvx/augmentation/dynamics_augmentation.py +22 -7
- openscvx/config.py +309 -197
- openscvx/constraints/__init__.py +0 -3
- openscvx/constraints/ctcs.py +188 -33
- openscvx/constraints/nodal.py +150 -11
- openscvx/constraints/violation.py +12 -2
- openscvx/discretization.py +115 -37
- openscvx/dynamics.py +150 -11
- openscvx/integrators.py +135 -16
- openscvx/io.py +129 -17
- openscvx/ocp.py +86 -67
- openscvx/plotting.py +72 -215
- openscvx/post_processing.py +48 -22
- openscvx/propagation.py +155 -55
- openscvx/ptr.py +96 -57
- openscvx/results.py +153 -0
- openscvx/trajoptproblem.py +341 -120
- openscvx/utils.py +50 -0
- {openscvx-0.1.3.dist-info → openscvx-0.2.1.dev0.dist-info}/METADATA +129 -41
- openscvx-0.2.1.dev0.dist-info/RECORD +27 -0
- {openscvx-0.1.3.dist-info → openscvx-0.2.1.dev0.dist-info}/WHEEL +1 -1
- openscvx/constraints/boundary.py +0 -49
- openscvx-0.1.3.dist-info/RECORD +0 -27
- {openscvx-0.1.3.dist-info → openscvx-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.3.dist-info → openscvx-0.2.1.dev0.dist-info}/top_level.txt +0 -0
openscvx/trajoptproblem.py
CHANGED
|
@@ -3,9 +3,13 @@ from typing import List, Union, Optional
|
|
|
3
3
|
import queue
|
|
4
4
|
import threading
|
|
5
5
|
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from copy import deepcopy
|
|
6
8
|
|
|
7
9
|
import cvxpy as cp
|
|
8
10
|
import jax
|
|
11
|
+
from jax import export, ShapeDtypeStruct
|
|
12
|
+
from functools import partial
|
|
9
13
|
import numpy as np
|
|
10
14
|
|
|
11
15
|
from openscvx.config import (
|
|
@@ -23,13 +27,18 @@ from openscvx.augmentation.ctcs import sort_ctcs_constraints
|
|
|
23
27
|
from openscvx.constraints.violation import get_g_funcs, CTCSViolation
|
|
24
28
|
from openscvx.discretization import get_discretization_solver
|
|
25
29
|
from openscvx.propagation import get_propagation_solver
|
|
26
|
-
from openscvx.constraints.boundary import BoundaryConstraint, boundary
|
|
27
30
|
from openscvx.constraints.ctcs import CTCSConstraint
|
|
28
31
|
from openscvx.constraints.nodal import NodalConstraint
|
|
29
|
-
from openscvx.ptr import PTR_init,
|
|
32
|
+
from openscvx.ptr import PTR_init, PTR_subproblem, format_result
|
|
30
33
|
from openscvx.post_processing import propagate_trajectory_results
|
|
31
34
|
from openscvx.ocp import OptimalControlProblem
|
|
32
35
|
from openscvx import io
|
|
36
|
+
from openscvx.utils import stable_function_hash
|
|
37
|
+
from openscvx.backend.state import State, Free
|
|
38
|
+
from openscvx.backend.control import Control
|
|
39
|
+
from openscvx.backend.parameter import Parameter
|
|
40
|
+
from openscvx.results import OptimizationResults
|
|
41
|
+
|
|
33
42
|
|
|
34
43
|
|
|
35
44
|
# TODO: (norrisg) Decide whether to have constraints`, `cost`, alongside `dynamics`, ` etc.
|
|
@@ -38,19 +47,13 @@ class TrajOptProblem:
|
|
|
38
47
|
self,
|
|
39
48
|
dynamics: Dynamics,
|
|
40
49
|
constraints: List[Union[CTCSConstraint, NodalConstraint]],
|
|
41
|
-
|
|
50
|
+
x: State,
|
|
51
|
+
u: Control,
|
|
42
52
|
N: int,
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
u_guess: jnp.ndarray,
|
|
46
|
-
initial_state: BoundaryConstraint,
|
|
47
|
-
final_state: BoundaryConstraint,
|
|
48
|
-
x_max: jnp.ndarray,
|
|
49
|
-
x_min: jnp.ndarray,
|
|
50
|
-
u_max: jnp.ndarray,
|
|
51
|
-
u_min: jnp.ndarray,
|
|
53
|
+
idx_time: int,
|
|
54
|
+
params: dict = {},
|
|
52
55
|
dynamics_prop: callable = None,
|
|
53
|
-
|
|
56
|
+
x_prop: State = None,
|
|
54
57
|
scp: Optional[ScpConfig] = None,
|
|
55
58
|
dis: Optional[DiscretizationConfig] = None,
|
|
56
59
|
prp: Optional[PropagationConfig] = None,
|
|
@@ -62,11 +65,44 @@ class TrajOptProblem:
|
|
|
62
65
|
time_dilation_factor_min=0.3,
|
|
63
66
|
time_dilation_factor_max=3.0,
|
|
64
67
|
):
|
|
68
|
+
"""
|
|
69
|
+
The primary class in charge of compiling and exporting the solvers
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
dynamics (Dynamics): Dynamics function decorated with @dynamics
|
|
74
|
+
constraints (List[Union[CTCSConstraint, NodalConstraint]]): List of constraints decorated with @ctcs or @nodal
|
|
75
|
+
idx_time (int): Index of the time variable in the state vector
|
|
76
|
+
N (int): Number of segments in the trajectory
|
|
77
|
+
time_init (float): Initial time for the trajectory
|
|
78
|
+
x_guess (jnp.ndarray): Initial guess for the state trajectory
|
|
79
|
+
u_guess (jnp.ndarray): Initial guess for the control trajectory
|
|
80
|
+
initial_state (BoundaryConstraint): Initial state constraint
|
|
81
|
+
final_state (BoundaryConstraint): Final state constraint
|
|
82
|
+
x_max (jnp.ndarray): Upper bound on the state variables
|
|
83
|
+
x_min (jnp.ndarray): Lower bound on the state variables
|
|
84
|
+
u_max (jnp.ndarray): Upper bound on the control variables
|
|
85
|
+
u_min (jnp.ndarray): Lower bound on the control variables
|
|
86
|
+
dynamics_prop: Propagation dynamics function decorated with @dynamics
|
|
87
|
+
initial_state_prop: Propagation initial state constraint
|
|
88
|
+
scp: SCP configuration object
|
|
89
|
+
dis: Discretization configuration object
|
|
90
|
+
prp: Propagation configuration object
|
|
91
|
+
sim: Simulation configuration object
|
|
92
|
+
dev: Development configuration object
|
|
93
|
+
cvx: Convex solver configuration object
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
None
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
self.params = params
|
|
100
|
+
|
|
65
101
|
if dynamics_prop is None:
|
|
66
102
|
dynamics_prop = dynamics
|
|
67
103
|
|
|
68
|
-
if
|
|
69
|
-
|
|
104
|
+
if x_prop is None:
|
|
105
|
+
x_prop = deepcopy(x)
|
|
70
106
|
|
|
71
107
|
# TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
|
|
72
108
|
constraints_ctcs = []
|
|
@@ -88,9 +124,9 @@ class TrajOptProblem:
|
|
|
88
124
|
constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)
|
|
89
125
|
|
|
90
126
|
# Index tracking
|
|
91
|
-
idx_x_true = slice(0,
|
|
92
|
-
idx_x_true_prop = slice(0,
|
|
93
|
-
idx_u_true = slice(0,
|
|
127
|
+
idx_x_true = slice(0, x.shape[0])
|
|
128
|
+
idx_x_true_prop = slice(0, x_prop.shape[0])
|
|
129
|
+
idx_u_true = slice(0, u.shape[0])
|
|
94
130
|
idx_constraint_violation = slice(
|
|
95
131
|
idx_x_true.stop, idx_x_true.stop + num_augmented_states
|
|
96
132
|
)
|
|
@@ -102,43 +138,41 @@ class TrajOptProblem:
|
|
|
102
138
|
|
|
103
139
|
# check that idx_time is in the correct range
|
|
104
140
|
assert idx_time >= 0 and idx_time < len(
|
|
105
|
-
|
|
141
|
+
x.max
|
|
106
142
|
), "idx_time must be in the range of the state vector and non-negative"
|
|
107
143
|
idx_time = slice(idx_time, idx_time + 1)
|
|
108
144
|
|
|
109
|
-
|
|
110
|
-
|
|
145
|
+
# Create a new state object for the augmented states
|
|
146
|
+
if num_augmented_states != 0:
|
|
147
|
+
y = State(name="y", shape=(num_augmented_states,))
|
|
148
|
+
y.initial = np.zeros((num_augmented_states,))
|
|
149
|
+
y.final = np.array([Free(0)] * num_augmented_states)
|
|
150
|
+
y.guess = np.zeros((N, num_augmented_states,))
|
|
151
|
+
y.min = np.zeros((num_augmented_states,))
|
|
152
|
+
y.max = licq_max * np.ones((num_augmented_states,))
|
|
153
|
+
|
|
154
|
+
x.append(y, augmented=True)
|
|
155
|
+
x_prop.append(y, augmented=True)
|
|
156
|
+
|
|
157
|
+
s = Control(name="s", shape=(1,))
|
|
158
|
+
s.min = np.array([time_dilation_factor_min * x.final[idx_time][0]])
|
|
159
|
+
s.max = np.array([time_dilation_factor_max * x.final[idx_time][0]])
|
|
160
|
+
s.guess = np.ones((N, 1)) * x.final[idx_time][0]
|
|
111
161
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
x_bar_augmented = np.hstack([x_guess, np.full((x_guess.shape[0], num_augmented_states), 0)])
|
|
116
|
-
u_bar_augmented = np.hstack(
|
|
117
|
-
[u_guess, np.full((u_guess.shape[0], 1), time_init)]
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
initial_state_prop_values = np.hstack([initial_state_prop.value, np.repeat(licq_min, num_augmented_states)])
|
|
121
|
-
initial_state_prop_types = np.hstack([initial_state_prop.type, ["Fix"] * num_augmented_states])
|
|
122
|
-
initial_state_prop = boundary(initial_state_prop_values)
|
|
123
|
-
initial_state_prop.types = initial_state_prop_types
|
|
162
|
+
|
|
163
|
+
u.append(s, augmented=True)
|
|
124
164
|
|
|
125
165
|
if dis is None:
|
|
126
166
|
dis = DiscretizationConfig()
|
|
127
167
|
|
|
128
168
|
if sim is None:
|
|
129
169
|
sim = SimConfig(
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
min_state=x_min_augmented,
|
|
137
|
-
max_control=u_max_augmented,
|
|
138
|
-
min_control=u_min_augmented,
|
|
139
|
-
total_time=time_init,
|
|
140
|
-
n_states=len(initial_state.value),
|
|
141
|
-
n_states_prop=len(initial_state_prop.value),
|
|
170
|
+
x=x,
|
|
171
|
+
x_prop=x_prop,
|
|
172
|
+
u=u,
|
|
173
|
+
total_time=x.initial[idx_time][0],
|
|
174
|
+
n_states=x.initial.shape[0],
|
|
175
|
+
n_states_prop=x_prop.initial.shape[0],
|
|
142
176
|
idx_x_true=idx_x_true,
|
|
143
177
|
idx_x_true_prop=idx_x_true_prop,
|
|
144
178
|
idx_u_true=idx_u_true,
|
|
@@ -152,22 +186,11 @@ class TrajOptProblem:
|
|
|
152
186
|
if scp is None:
|
|
153
187
|
scp = ScpConfig(
|
|
154
188
|
n=N,
|
|
155
|
-
k_max=200,
|
|
156
|
-
w_tr=1e1, # Weight on the Trust Reigon
|
|
157
|
-
lam_cost=1e1, # Weight on the Nonlinear Cost
|
|
158
|
-
lam_vc=1e2, # Weight on the Virtual Control Objective
|
|
159
|
-
lam_vb=0e0, # Weight on the Virtual Buffer Objective (only for penalized nodal constraints)
|
|
160
|
-
ep_tr=1e-4, # Trust Region Tolerance
|
|
161
|
-
ep_vb=1e-4, # Virtual Control Tolerance
|
|
162
|
-
ep_vc=1e-8, # Virtual Control Tolerance for CTCS
|
|
163
|
-
cost_drop=4, # SCP iteration to relax minimal final time objective
|
|
164
|
-
cost_relax=0.5, # Minimal Time Relaxation Factor
|
|
165
|
-
w_tr_adapt=1.2, # Trust Region Adaptation Factor
|
|
166
189
|
w_tr_max_scaling_factor=1e2, # Maximum Trust Region Weight
|
|
167
190
|
)
|
|
168
191
|
else:
|
|
169
192
|
assert (
|
|
170
|
-
self.scp.n == N
|
|
193
|
+
self.settings.scp.n == N
|
|
171
194
|
), "Number of segments must be the same as in the config"
|
|
172
195
|
|
|
173
196
|
if dev is None:
|
|
@@ -184,7 +207,7 @@ class TrajOptProblem:
|
|
|
184
207
|
self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
|
|
185
208
|
self.dynamics_augmented_prop = build_augmented_dynamics(dynamics_prop, ctcs_violation_funcs, idx_x_true_prop, idx_u_true)
|
|
186
209
|
|
|
187
|
-
self.
|
|
210
|
+
self.settings = Config(
|
|
188
211
|
sim=sim,
|
|
189
212
|
scp=scp,
|
|
190
213
|
dis=dis,
|
|
@@ -192,18 +215,18 @@ class TrajOptProblem:
|
|
|
192
215
|
cvx=cvx,
|
|
193
216
|
prp=prp,
|
|
194
217
|
)
|
|
195
|
-
|
|
218
|
+
|
|
196
219
|
self.optimal_control_problem: cp.Problem = None
|
|
197
220
|
self.discretization_solver: callable = None
|
|
198
221
|
self.cpg_solve = None
|
|
199
222
|
|
|
200
223
|
# set up emitter & thread only if printing is enabled
|
|
201
|
-
if self.
|
|
224
|
+
if self.settings.dev.printing:
|
|
202
225
|
self.print_queue = queue.Queue()
|
|
203
226
|
self.emitter_function = lambda data: self.print_queue.put(data)
|
|
204
227
|
self.print_thread = threading.Thread(
|
|
205
228
|
target=io.intermediate,
|
|
206
|
-
args=(self.print_queue, self.
|
|
229
|
+
args=(self.print_queue, self.settings),
|
|
207
230
|
daemon=True,
|
|
208
231
|
)
|
|
209
232
|
self.print_thread.start()
|
|
@@ -216,11 +239,23 @@ class TrajOptProblem:
|
|
|
216
239
|
self.timing_solve = None
|
|
217
240
|
self.timing_post = None
|
|
218
241
|
|
|
242
|
+
# SCP state variables
|
|
243
|
+
self.scp_k = 0
|
|
244
|
+
self.scp_J_tr = 1e2
|
|
245
|
+
self.scp_J_vb = 1e2
|
|
246
|
+
self.scp_J_vc = 1e2
|
|
247
|
+
self.scp_trajs = []
|
|
248
|
+
self.scp_controls = []
|
|
249
|
+
self.scp_V_multi_shoot_traj = []
|
|
250
|
+
|
|
219
251
|
def initialize(self):
|
|
220
252
|
io.intro()
|
|
221
253
|
|
|
254
|
+
# Print problem summary
|
|
255
|
+
io.print_problem_summary(self.settings)
|
|
256
|
+
|
|
222
257
|
# Enable the profiler
|
|
223
|
-
if self.
|
|
258
|
+
if self.settings.dev.profiling:
|
|
224
259
|
import cProfile
|
|
225
260
|
|
|
226
261
|
pr = cProfile.Profile()
|
|
@@ -228,18 +263,17 @@ class TrajOptProblem:
|
|
|
228
263
|
|
|
229
264
|
t_0_while = time.time()
|
|
230
265
|
# Ensure parameter sizes and normalization are correct
|
|
231
|
-
self.
|
|
232
|
-
self.
|
|
266
|
+
self.settings.scp.__post_init__()
|
|
267
|
+
self.settings.sim.__post_init__()
|
|
233
268
|
|
|
234
269
|
# Compile dynamics and jacobians
|
|
235
|
-
self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f)
|
|
236
|
-
self.dynamics_augmented.A = jax.
|
|
237
|
-
self.dynamics_augmented.B = jax.
|
|
270
|
+
self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
271
|
+
self.dynamics_augmented.A = jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
272
|
+
self.dynamics_augmented.B = jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
273
|
+
|
|
274
|
+
self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
238
275
|
|
|
239
|
-
|
|
240
|
-
self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f)
|
|
241
|
-
|
|
242
|
-
for constraint in self.params.sim.constraints_nodal:
|
|
276
|
+
for constraint in self.settings.sim.constraints_nodal:
|
|
243
277
|
if not constraint.convex:
|
|
244
278
|
# TODO: (haynec) switch to AOT instead of JIT
|
|
245
279
|
constraint.g = jax.jit(constraint.g)
|
|
@@ -247,55 +281,241 @@ class TrajOptProblem:
|
|
|
247
281
|
constraint.grad_g_u = jax.jit(constraint.grad_g_u)
|
|
248
282
|
|
|
249
283
|
# Generate solvers and optimal control problem
|
|
250
|
-
self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
|
|
251
|
-
self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.params)
|
|
252
|
-
self.optimal_control_problem = OptimalControlProblem(self.
|
|
253
|
-
|
|
254
|
-
#
|
|
255
|
-
self.
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
284
|
+
self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.settings, self.params)
|
|
285
|
+
self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.settings, self.params)
|
|
286
|
+
self.optimal_control_problem = OptimalControlProblem(self.settings)
|
|
287
|
+
|
|
288
|
+
# Collect all relevant functions
|
|
289
|
+
functions_to_hash = [self.dynamics_augmented.f, self.dynamics_augmented_prop.f]
|
|
290
|
+
for constraint in self.settings.sim.constraints_nodal:
|
|
291
|
+
functions_to_hash.append(constraint.func)
|
|
292
|
+
for constraint in self.settings.sim.constraints_ctcs:
|
|
293
|
+
functions_to_hash.append(constraint.func)
|
|
294
|
+
|
|
295
|
+
# Get unique source-based hash
|
|
296
|
+
function_hash = stable_function_hash(
|
|
297
|
+
functions_to_hash,
|
|
298
|
+
n_discretization_nodes=self.settings.scp.n,
|
|
299
|
+
dt=self.settings.prp.dt,
|
|
300
|
+
total_time=self.settings.sim.total_time,
|
|
301
|
+
state_max=self.settings.sim.x.max,
|
|
302
|
+
state_min=self.settings.sim.x.min,
|
|
303
|
+
control_max=self.settings.sim.u.max,
|
|
304
|
+
control_min=self.settings.sim.u.min
|
|
259
305
|
)
|
|
260
306
|
|
|
307
|
+
solver_dir = Path(".tmp")
|
|
308
|
+
solver_dir.mkdir(parents=True, exist_ok=True)
|
|
309
|
+
dis_solver_file = solver_dir / f"compiled_discretization_solver_{function_hash}.jax"
|
|
310
|
+
prop_solver_file = solver_dir / f"compiled_propagation_solver_{function_hash}.jax"
|
|
311
|
+
|
|
312
|
+
|
|
261
313
|
# Compile the solvers
|
|
262
|
-
if not self.
|
|
263
|
-
self.
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
314
|
+
if not self.settings.dev.debug:
|
|
315
|
+
if self.settings.sim.save_compiled:
|
|
316
|
+
# Check if the compiled file already exists
|
|
317
|
+
try:
|
|
318
|
+
with open(dis_solver_file, "rb") as f:
|
|
319
|
+
serial_dis = f.read()
|
|
320
|
+
# Load the compiled code
|
|
321
|
+
self.discretization_solver = export.deserialize(serial_dis)
|
|
322
|
+
print("✓ Loaded existing discretization solver")
|
|
323
|
+
except FileNotFoundError:
|
|
324
|
+
print("Compiling discretization solver...")
|
|
325
|
+
# Extract parameter values and names in order
|
|
326
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
327
|
+
|
|
328
|
+
self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
|
|
329
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_states)),
|
|
330
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
|
|
331
|
+
*param_values
|
|
332
|
+
)
|
|
333
|
+
# Serialize and Save the compiled code in a temp directory
|
|
334
|
+
with open(dis_solver_file, "wb") as f:
|
|
335
|
+
f.write(self.discretization_solver.serialize())
|
|
336
|
+
print("✓ Discretization solver compiled and saved")
|
|
337
|
+
else:
|
|
338
|
+
print("Compiling discretization solver (not saving/loading from disk)...")
|
|
339
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
340
|
+
self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
|
|
341
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_states)),
|
|
342
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
|
|
343
|
+
*param_values
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
# Compile the discretization solver and save it
|
|
347
|
+
dtau = 1.0 / (self.settings.scp.n - 1)
|
|
348
|
+
dt_max = self.settings.sim.u.max[self.settings.sim.idx_s][0] * dtau
|
|
349
|
+
|
|
350
|
+
self.settings.prp.max_tau_len = int(dt_max / self.settings.prp.dt) + 2
|
|
351
|
+
|
|
352
|
+
# Check if the compiled file already exists
|
|
353
|
+
if self.settings.sim.save_compiled:
|
|
354
|
+
try:
|
|
355
|
+
with open(prop_solver_file, "rb") as f:
|
|
356
|
+
serial_prop = f.read()
|
|
357
|
+
# Load the compiled code
|
|
358
|
+
self.propagation_solver = export.deserialize(serial_prop)
|
|
359
|
+
print("✓ Loaded existing propagation solver")
|
|
360
|
+
except FileNotFoundError:
|
|
361
|
+
print("Compiling propagation solver...")
|
|
362
|
+
# Extract parameter values and names in order
|
|
363
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
364
|
+
|
|
365
|
+
propagation_solver = export.export(jax.jit(self.propagation_solver))(
|
|
366
|
+
np.ones((self.settings.sim.n_states_prop)), # x_0
|
|
367
|
+
(0.0, 0.0), # time span
|
|
368
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_current
|
|
369
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_next
|
|
370
|
+
np.ones((1, 1)), # tau_0
|
|
371
|
+
np.ones((1, 1)).astype("int"), # segment index
|
|
372
|
+
0, # idx_s_stop
|
|
373
|
+
np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
|
|
374
|
+
np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
|
|
375
|
+
*param_values, # additional parameters
|
|
268
376
|
)
|
|
269
|
-
.compile()
|
|
270
|
-
)
|
|
271
377
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
378
|
+
# Serialize and Save the compiled code in a temp directory
|
|
379
|
+
self.propagation_solver = propagation_solver
|
|
380
|
+
|
|
381
|
+
with open(prop_solver_file, "wb") as f:
|
|
382
|
+
f.write(self.propagation_solver.serialize())
|
|
383
|
+
print("✓ Propagation solver compiled and saved")
|
|
384
|
+
else:
|
|
385
|
+
print("Compiling propagation solver (not saving/loading from disk)...")
|
|
386
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
387
|
+
propagation_solver = export.export(jax.jit(self.propagation_solver))(
|
|
388
|
+
np.ones((self.settings.sim.n_states_prop)), # x_0
|
|
389
|
+
(0.0, 0.0), # time span
|
|
390
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_current
|
|
391
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_next
|
|
392
|
+
np.ones((1, 1)), # tau_0
|
|
393
|
+
np.ones((1, 1)).astype("int"), # segment index
|
|
394
|
+
0, # idx_s_stop
|
|
395
|
+
np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
|
|
396
|
+
np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
|
|
397
|
+
*param_values, # additional parameters
|
|
282
398
|
)
|
|
283
|
-
.
|
|
399
|
+
self.propagation_solver = propagation_solver
|
|
400
|
+
|
|
401
|
+
# Initialize the PTR loop
|
|
402
|
+
print("Initializing the SCvx Subproblem Solver...")
|
|
403
|
+
self.cpg_solve = PTR_init(
|
|
404
|
+
self.params,
|
|
405
|
+
self.optimal_control_problem,
|
|
406
|
+
self.discretization_solver,
|
|
407
|
+
self.settings,
|
|
284
408
|
)
|
|
409
|
+
print("✓ SCvx Subproblem Solver initialized")
|
|
410
|
+
|
|
411
|
+
# Reset SCP state
|
|
412
|
+
self.scp_k = 1
|
|
413
|
+
self.scp_J_tr = 1e2
|
|
414
|
+
self.scp_J_vb = 1e2
|
|
415
|
+
self.scp_J_vc = 1e2
|
|
416
|
+
self.scp_trajs = [self.settings.sim.x.guess]
|
|
417
|
+
self.scp_controls = [self.settings.sim.u.guess]
|
|
418
|
+
self.scp_V_multi_shoot_traj = []
|
|
285
419
|
|
|
286
420
|
t_f_while = time.time()
|
|
287
421
|
self.timing_init = t_f_while - t_0_while
|
|
288
422
|
print("Total Initialization Time: ", self.timing_init)
|
|
289
423
|
|
|
290
|
-
|
|
424
|
+
# Robust priming call for propagation_solver.call (no debug prints)
|
|
425
|
+
try:
|
|
426
|
+
x_0 = np.ones(self.settings.sim.x_prop.initial.shape, dtype=self.settings.sim.x_prop.initial.dtype)
|
|
427
|
+
tau_grid = (0.0, 1.0)
|
|
428
|
+
controls_current = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
|
|
429
|
+
controls_next = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
|
|
430
|
+
tau_init = np.array([[0.0]], dtype=np.float64)
|
|
431
|
+
node = np.array([[0]], dtype=np.int64)
|
|
432
|
+
idx_s_stop = self.settings.sim.idx_s.stop
|
|
433
|
+
save_time = np.ones((self.settings.prp.max_tau_len,), dtype=np.float64)
|
|
434
|
+
mask_padded = np.ones((self.settings.prp.max_tau_len,), dtype=bool)
|
|
435
|
+
param_values = [np.ones_like(param.value) if hasattr(param.value, 'shape') else float(param.value) for _, param in self.params.items()]
|
|
436
|
+
self.propagation_solver.call(
|
|
437
|
+
x_0, tau_grid, controls_current, controls_next, tau_init, node, idx_s_stop, save_time, mask_padded, *param_values
|
|
438
|
+
)
|
|
439
|
+
except Exception as e:
|
|
440
|
+
print(f"[Initialization] Priming propagation_solver.call failed: {e}")
|
|
441
|
+
|
|
442
|
+
if self.settings.dev.profiling:
|
|
291
443
|
pr.disable()
|
|
292
444
|
# Save results so it can be viusualized with snakeviz
|
|
293
445
|
pr.dump_stats("profiling_initialize.prof")
|
|
294
446
|
|
|
295
|
-
def
|
|
447
|
+
def step(self):
|
|
448
|
+
"""Performs a single SCP iteration.
|
|
449
|
+
|
|
450
|
+
This method is designed for real-time plotting and interactive optimization.
|
|
451
|
+
It performs one complete SCP iteration including subproblem solving,
|
|
452
|
+
state updates, and progress emission for real-time visualization.
|
|
453
|
+
|
|
454
|
+
Returns:
|
|
455
|
+
dict: Dictionary containing convergence status and current state
|
|
456
|
+
"""
|
|
457
|
+
x = self.settings.sim.x
|
|
458
|
+
u = self.settings.sim.u
|
|
459
|
+
|
|
460
|
+
# Run the subproblem
|
|
461
|
+
x_sol, u_sol, cost, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(
|
|
462
|
+
self.params.items(),
|
|
463
|
+
self.cpg_solve,
|
|
464
|
+
x,
|
|
465
|
+
u,
|
|
466
|
+
self.discretization_solver,
|
|
467
|
+
self.optimal_control_problem,
|
|
468
|
+
self.settings,
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
# Update state
|
|
472
|
+
self.scp_V_multi_shoot_traj.append(V_multi_shoot)
|
|
473
|
+
x.guess = x_sol
|
|
474
|
+
u.guess = u_sol
|
|
475
|
+
self.scp_trajs.append(x.guess)
|
|
476
|
+
self.scp_controls.append(u.guess)
|
|
477
|
+
|
|
478
|
+
self.scp_J_tr = np.sum(np.array(J_tr_vec))
|
|
479
|
+
self.scp_J_vb = np.sum(np.array(J_vb_vec))
|
|
480
|
+
self.scp_J_vc = np.sum(np.array(J_vc_vec))
|
|
481
|
+
|
|
482
|
+
# Update weights
|
|
483
|
+
self.settings.scp.w_tr = min(self.settings.scp.w_tr * self.settings.scp.w_tr_adapt, self.settings.scp.w_tr_max)
|
|
484
|
+
if self.scp_k > self.settings.scp.cost_drop:
|
|
485
|
+
self.settings.scp.lam_cost = self.settings.scp.lam_cost * self.settings.scp.cost_relax
|
|
486
|
+
|
|
487
|
+
# Emit data
|
|
488
|
+
self.emitter_function(
|
|
489
|
+
{
|
|
490
|
+
"iter": self.scp_k,
|
|
491
|
+
"dis_time": dis_time * 1000.0,
|
|
492
|
+
"subprop_time": subprop_time * 1000.0,
|
|
493
|
+
"J_total": J_total,
|
|
494
|
+
"J_tr": self.scp_J_tr,
|
|
495
|
+
"J_vb": self.scp_J_vb,
|
|
496
|
+
"J_vc": self.scp_J_vc,
|
|
497
|
+
"cost": cost[-1],
|
|
498
|
+
"prob_stat": prob_stat,
|
|
499
|
+
}
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
# Increment counter
|
|
503
|
+
self.scp_k += 1
|
|
504
|
+
|
|
505
|
+
# Create a result dictionary for this step
|
|
506
|
+
return {
|
|
507
|
+
"converged": (self.scp_J_tr < self.settings.scp.ep_tr) and \
|
|
508
|
+
(self.scp_J_vb < self.settings.scp.ep_vb) and \
|
|
509
|
+
(self.scp_J_vc < self.settings.scp.ep_vc),
|
|
510
|
+
"u": u,
|
|
511
|
+
"x": x,
|
|
512
|
+
"V_multi_shoot": V_multi_shoot
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
def solve(self, max_iters: Optional[int] = None, continuous: bool = False) -> OptimizationResults:
|
|
296
516
|
# Ensure parameter sizes and normalization are correct
|
|
297
|
-
self.
|
|
298
|
-
self.
|
|
517
|
+
self.settings.scp.__post_init__()
|
|
518
|
+
self.settings.sim.__post_init__()
|
|
299
519
|
|
|
300
520
|
if self.optimal_control_problem is None or self.discretization_solver is None:
|
|
301
521
|
raise ValueError(
|
|
@@ -303,7 +523,7 @@ class TrajOptProblem:
|
|
|
303
523
|
)
|
|
304
524
|
|
|
305
525
|
# Enable the profiler
|
|
306
|
-
if self.
|
|
526
|
+
if self.settings.dev.profiling:
|
|
307
527
|
import cProfile
|
|
308
528
|
|
|
309
529
|
pr = cProfile.Profile()
|
|
@@ -312,14 +532,13 @@ class TrajOptProblem:
|
|
|
312
532
|
t_0_while = time.time()
|
|
313
533
|
# Print top header for solver results
|
|
314
534
|
io.header()
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
self.
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
)
|
|
535
|
+
|
|
536
|
+
k_max = max_iters if max_iters is not None else self.settings.scp.k_max
|
|
537
|
+
|
|
538
|
+
while self.scp_k <= k_max:
|
|
539
|
+
result = self.step()
|
|
540
|
+
if result["converged"] and not continuous:
|
|
541
|
+
break
|
|
323
542
|
|
|
324
543
|
t_f_while = time.time()
|
|
325
544
|
self.timing_solve = t_f_while - t_0_while
|
|
@@ -328,33 +547,35 @@ class TrajOptProblem:
|
|
|
328
547
|
time.sleep(0.1)
|
|
329
548
|
|
|
330
549
|
# Print bottom footer for solver results as well as total computation time
|
|
331
|
-
io.footer(
|
|
550
|
+
io.footer()
|
|
332
551
|
|
|
333
552
|
# Disable the profiler
|
|
334
|
-
if self.
|
|
553
|
+
if self.settings.dev.profiling:
|
|
335
554
|
pr.disable()
|
|
336
555
|
# Save results so it can be viusualized with snakeviz
|
|
337
556
|
pr.dump_stats("profiling_solve.prof")
|
|
338
557
|
|
|
339
|
-
return
|
|
558
|
+
return format_result(self, self.scp_k <= k_max)
|
|
340
559
|
|
|
341
|
-
def post_process(self, result):
|
|
560
|
+
def post_process(self, result: OptimizationResults) -> OptimizationResults:
|
|
342
561
|
# Enable the profiler
|
|
343
|
-
if self.
|
|
562
|
+
if self.settings.dev.profiling:
|
|
344
563
|
import cProfile
|
|
345
564
|
|
|
346
565
|
pr = cProfile.Profile()
|
|
347
566
|
pr.enable()
|
|
348
567
|
|
|
349
568
|
t_0_post = time.time()
|
|
350
|
-
result = propagate_trajectory_results(self.params, result, self.propagation_solver)
|
|
569
|
+
result = propagate_trajectory_results(self.params, self.settings, result, self.propagation_solver)
|
|
351
570
|
t_f_post = time.time()
|
|
352
571
|
|
|
353
572
|
self.timing_post = t_f_post - t_0_post
|
|
354
|
-
|
|
573
|
+
|
|
574
|
+
# Print results summary
|
|
575
|
+
io.print_results_summary(result, self.timing_post, self.timing_init, self.timing_solve)
|
|
355
576
|
|
|
356
577
|
# Disable the profiler
|
|
357
|
-
if self.
|
|
578
|
+
if self.settings.dev.profiling:
|
|
358
579
|
pr.disable()
|
|
359
580
|
# Save results so it can be viusualized with snakeviz
|
|
360
581
|
pr.dump_stats("profiling_postprocess.prof")
|