openscvx 0.1.2__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/_version.py +2 -2
- openscvx/augmentation/dynamics_augmentation.py +22 -7
- openscvx/config.py +310 -192
- openscvx/constraints/__init__.py +0 -3
- openscvx/constraints/ctcs.py +188 -33
- openscvx/constraints/nodal.py +150 -11
- openscvx/constraints/violation.py +12 -2
- openscvx/discretization.py +115 -37
- openscvx/dynamics.py +150 -11
- openscvx/integrators.py +135 -16
- openscvx/io.py +129 -17
- openscvx/ocp.py +86 -67
- openscvx/plotting.py +72 -215
- openscvx/post_processing.py +57 -16
- openscvx/propagation.py +155 -55
- openscvx/ptr.py +96 -57
- openscvx/results.py +153 -0
- openscvx/trajoptproblem.py +359 -114
- openscvx/utils.py +50 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/METADATA +129 -41
- openscvx-0.2.1.dev0.dist-info/RECORD +27 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/WHEEL +1 -1
- openscvx/constraints/boundary.py +0 -49
- openscvx-0.1.2.dist-info/RECORD +0 -27
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.1.2.dist-info → openscvx-0.2.1.dev0.dist-info}/top_level.txt +0 -0
openscvx/trajoptproblem.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
-
from typing import List, Union
|
|
2
|
+
from typing import List, Union, Optional
|
|
3
3
|
import queue
|
|
4
4
|
import threading
|
|
5
5
|
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from copy import deepcopy
|
|
6
8
|
|
|
7
9
|
import cvxpy as cp
|
|
8
10
|
import jax
|
|
11
|
+
from jax import export, ShapeDtypeStruct
|
|
12
|
+
from functools import partial
|
|
9
13
|
import numpy as np
|
|
10
14
|
|
|
11
15
|
from openscvx.config import (
|
|
@@ -23,13 +27,18 @@ from openscvx.augmentation.ctcs import sort_ctcs_constraints
|
|
|
23
27
|
from openscvx.constraints.violation import get_g_funcs, CTCSViolation
|
|
24
28
|
from openscvx.discretization import get_discretization_solver
|
|
25
29
|
from openscvx.propagation import get_propagation_solver
|
|
26
|
-
from openscvx.constraints.boundary import BoundaryConstraint
|
|
27
30
|
from openscvx.constraints.ctcs import CTCSConstraint
|
|
28
31
|
from openscvx.constraints.nodal import NodalConstraint
|
|
29
|
-
from openscvx.ptr import PTR_init,
|
|
32
|
+
from openscvx.ptr import PTR_init, PTR_subproblem, format_result
|
|
30
33
|
from openscvx.post_processing import propagate_trajectory_results
|
|
31
34
|
from openscvx.ocp import OptimalControlProblem
|
|
32
35
|
from openscvx import io
|
|
36
|
+
from openscvx.utils import stable_function_hash
|
|
37
|
+
from openscvx.backend.state import State, Free
|
|
38
|
+
from openscvx.backend.control import Control
|
|
39
|
+
from openscvx.backend.parameter import Parameter
|
|
40
|
+
from openscvx.results import OptimizationResults
|
|
41
|
+
|
|
33
42
|
|
|
34
43
|
|
|
35
44
|
# TODO: (norrisg) Decide whether to have constraints`, `cost`, alongside `dynamics`, ` etc.
|
|
@@ -38,28 +47,62 @@ class TrajOptProblem:
|
|
|
38
47
|
self,
|
|
39
48
|
dynamics: Dynamics,
|
|
40
49
|
constraints: List[Union[CTCSConstraint, NodalConstraint]],
|
|
41
|
-
|
|
50
|
+
x: State,
|
|
51
|
+
u: Control,
|
|
42
52
|
N: int,
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
dis: DiscretizationConfig = None,
|
|
54
|
-
prp: PropagationConfig = None,
|
|
55
|
-
sim: SimConfig = None,
|
|
56
|
-
dev: DevConfig = None,
|
|
57
|
-
cvx: ConvexSolverConfig = None,
|
|
53
|
+
idx_time: int,
|
|
54
|
+
params: dict = {},
|
|
55
|
+
dynamics_prop: callable = None,
|
|
56
|
+
x_prop: State = None,
|
|
57
|
+
scp: Optional[ScpConfig] = None,
|
|
58
|
+
dis: Optional[DiscretizationConfig] = None,
|
|
59
|
+
prp: Optional[PropagationConfig] = None,
|
|
60
|
+
sim: Optional[SimConfig] = None,
|
|
61
|
+
dev: Optional[DevConfig] = None,
|
|
62
|
+
cvx: Optional[ConvexSolverConfig] = None,
|
|
58
63
|
licq_min=0.0,
|
|
59
64
|
licq_max=1e-4,
|
|
60
65
|
time_dilation_factor_min=0.3,
|
|
61
66
|
time_dilation_factor_max=3.0,
|
|
62
67
|
):
|
|
68
|
+
"""
|
|
69
|
+
The primary class in charge of compiling and exporting the solvers
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
dynamics (Dynamics): Dynamics function decorated with @dynamics
|
|
74
|
+
constraints (List[Union[CTCSConstraint, NodalConstraint]]): List of constraints decorated with @ctcs or @nodal
|
|
75
|
+
idx_time (int): Index of the time variable in the state vector
|
|
76
|
+
N (int): Number of segments in the trajectory
|
|
77
|
+
time_init (float): Initial time for the trajectory
|
|
78
|
+
x_guess (jnp.ndarray): Initial guess for the state trajectory
|
|
79
|
+
u_guess (jnp.ndarray): Initial guess for the control trajectory
|
|
80
|
+
initial_state (BoundaryConstraint): Initial state constraint
|
|
81
|
+
final_state (BoundaryConstraint): Final state constraint
|
|
82
|
+
x_max (jnp.ndarray): Upper bound on the state variables
|
|
83
|
+
x_min (jnp.ndarray): Lower bound on the state variables
|
|
84
|
+
u_max (jnp.ndarray): Upper bound on the control variables
|
|
85
|
+
u_min (jnp.ndarray): Lower bound on the control variables
|
|
86
|
+
dynamics_prop: Propagation dynamics function decorated with @dynamics
|
|
87
|
+
initial_state_prop: Propagation initial state constraint
|
|
88
|
+
scp: SCP configuration object
|
|
89
|
+
dis: Discretization configuration object
|
|
90
|
+
prp: Propagation configuration object
|
|
91
|
+
sim: Simulation configuration object
|
|
92
|
+
dev: Development configuration object
|
|
93
|
+
cvx: Convex solver configuration object
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
None
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
self.params = params
|
|
100
|
+
|
|
101
|
+
if dynamics_prop is None:
|
|
102
|
+
dynamics_prop = dynamics
|
|
103
|
+
|
|
104
|
+
if x_prop is None:
|
|
105
|
+
x_prop = deepcopy(x)
|
|
63
106
|
|
|
64
107
|
# TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
|
|
65
108
|
constraints_ctcs = []
|
|
@@ -81,50 +124,61 @@ class TrajOptProblem:
|
|
|
81
124
|
constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)
|
|
82
125
|
|
|
83
126
|
# Index tracking
|
|
84
|
-
idx_x_true = slice(0,
|
|
85
|
-
|
|
127
|
+
idx_x_true = slice(0, x.shape[0])
|
|
128
|
+
idx_x_true_prop = slice(0, x_prop.shape[0])
|
|
129
|
+
idx_u_true = slice(0, u.shape[0])
|
|
86
130
|
idx_constraint_violation = slice(
|
|
87
131
|
idx_x_true.stop, idx_x_true.stop + num_augmented_states
|
|
88
132
|
)
|
|
133
|
+
idx_constraint_violation_prop = slice(
|
|
134
|
+
idx_x_true_prop.stop, idx_x_true_prop.stop + num_augmented_states
|
|
135
|
+
)
|
|
89
136
|
|
|
90
137
|
idx_time_dilation = slice(idx_u_true.stop, idx_u_true.stop + 1)
|
|
91
138
|
|
|
92
139
|
# check that idx_time is in the correct range
|
|
93
140
|
assert idx_time >= 0 and idx_time < len(
|
|
94
|
-
|
|
141
|
+
x.max
|
|
95
142
|
), "idx_time must be in the range of the state vector and non-negative"
|
|
96
143
|
idx_time = slice(idx_time, idx_time + 1)
|
|
97
144
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
145
|
+
# Create a new state object for the augmented states
|
|
146
|
+
if num_augmented_states != 0:
|
|
147
|
+
y = State(name="y", shape=(num_augmented_states,))
|
|
148
|
+
y.initial = np.zeros((num_augmented_states,))
|
|
149
|
+
y.final = np.array([Free(0)] * num_augmented_states)
|
|
150
|
+
y.guess = np.zeros((N, num_augmented_states,))
|
|
151
|
+
y.min = np.zeros((num_augmented_states,))
|
|
152
|
+
y.max = licq_max * np.ones((num_augmented_states,))
|
|
153
|
+
|
|
154
|
+
x.append(y, augmented=True)
|
|
155
|
+
x_prop.append(y, augmented=True)
|
|
156
|
+
|
|
157
|
+
s = Control(name="s", shape=(1,))
|
|
158
|
+
s.min = np.array([time_dilation_factor_min * x.final[idx_time][0]])
|
|
159
|
+
s.max = np.array([time_dilation_factor_max * x.final[idx_time][0]])
|
|
160
|
+
s.guess = np.ones((N, 1)) * x.final[idx_time][0]
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
u.append(s, augmented=True)
|
|
108
164
|
|
|
109
165
|
if dis is None:
|
|
110
166
|
dis = DiscretizationConfig()
|
|
111
167
|
|
|
112
168
|
if sim is None:
|
|
113
169
|
sim = SimConfig(
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
max_control=u_max_augmented,
|
|
121
|
-
min_control=u_min_augmented,
|
|
122
|
-
total_time=time_init,
|
|
123
|
-
n_states=len(x_max),
|
|
170
|
+
x=x,
|
|
171
|
+
x_prop=x_prop,
|
|
172
|
+
u=u,
|
|
173
|
+
total_time=x.initial[idx_time][0],
|
|
174
|
+
n_states=x.initial.shape[0],
|
|
175
|
+
n_states_prop=x_prop.initial.shape[0],
|
|
124
176
|
idx_x_true=idx_x_true,
|
|
177
|
+
idx_x_true_prop=idx_x_true_prop,
|
|
125
178
|
idx_u_true=idx_u_true,
|
|
126
179
|
idx_t=idx_time,
|
|
127
180
|
idx_y=idx_constraint_violation,
|
|
181
|
+
idx_y_prop=idx_constraint_violation_prop,
|
|
128
182
|
idx_s=idx_time_dilation,
|
|
129
183
|
ctcs_node_intervals=node_intervals,
|
|
130
184
|
)
|
|
@@ -132,22 +186,11 @@ class TrajOptProblem:
|
|
|
132
186
|
if scp is None:
|
|
133
187
|
scp = ScpConfig(
|
|
134
188
|
n=N,
|
|
135
|
-
k_max=200,
|
|
136
|
-
w_tr=1e1, # Weight on the Trust Reigon
|
|
137
|
-
lam_cost=1e1, # Weight on the Nonlinear Cost
|
|
138
|
-
lam_vc=1e2, # Weight on the Virtual Control Objective
|
|
139
|
-
lam_vb=0e0, # Weight on the Virtual Buffer Objective (only for penalized nodal constraints)
|
|
140
|
-
ep_tr=1e-4, # Trust Region Tolerance
|
|
141
|
-
ep_vb=1e-4, # Virtual Control Tolerance
|
|
142
|
-
ep_vc=1e-8, # Virtual Control Tolerance for CTCS
|
|
143
|
-
cost_drop=4, # SCP iteration to relax minimal final time objective
|
|
144
|
-
cost_relax=0.5, # Minimal Time Relaxation Factor
|
|
145
|
-
w_tr_adapt=1.2, # Trust Region Adaptation Factor
|
|
146
189
|
w_tr_max_scaling_factor=1e2, # Maximum Trust Region Weight
|
|
147
190
|
)
|
|
148
191
|
else:
|
|
149
192
|
assert (
|
|
150
|
-
self.scp.n == N
|
|
193
|
+
self.settings.scp.n == N
|
|
151
194
|
), "Number of segments must be the same as in the config"
|
|
152
195
|
|
|
153
196
|
if dev is None:
|
|
@@ -162,8 +205,9 @@ class TrajOptProblem:
|
|
|
162
205
|
|
|
163
206
|
ctcs_violation_funcs = get_g_funcs(constraints_ctcs)
|
|
164
207
|
self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
|
|
208
|
+
self.dynamics_augmented_prop = build_augmented_dynamics(dynamics_prop, ctcs_violation_funcs, idx_x_true_prop, idx_u_true)
|
|
165
209
|
|
|
166
|
-
self.
|
|
210
|
+
self.settings = Config(
|
|
167
211
|
sim=sim,
|
|
168
212
|
scp=scp,
|
|
169
213
|
dis=dis,
|
|
@@ -171,18 +215,18 @@ class TrajOptProblem:
|
|
|
171
215
|
cvx=cvx,
|
|
172
216
|
prp=prp,
|
|
173
217
|
)
|
|
174
|
-
|
|
218
|
+
|
|
175
219
|
self.optimal_control_problem: cp.Problem = None
|
|
176
220
|
self.discretization_solver: callable = None
|
|
177
221
|
self.cpg_solve = None
|
|
178
222
|
|
|
179
223
|
# set up emitter & thread only if printing is enabled
|
|
180
|
-
if self.
|
|
224
|
+
if self.settings.dev.printing:
|
|
181
225
|
self.print_queue = queue.Queue()
|
|
182
226
|
self.emitter_function = lambda data: self.print_queue.put(data)
|
|
183
227
|
self.print_thread = threading.Thread(
|
|
184
228
|
target=io.intermediate,
|
|
185
|
-
args=(self.print_queue, self.
|
|
229
|
+
args=(self.print_queue, self.settings),
|
|
186
230
|
daemon=True,
|
|
187
231
|
)
|
|
188
232
|
self.print_thread.start()
|
|
@@ -195,11 +239,23 @@ class TrajOptProblem:
|
|
|
195
239
|
self.timing_solve = None
|
|
196
240
|
self.timing_post = None
|
|
197
241
|
|
|
242
|
+
# SCP state variables
|
|
243
|
+
self.scp_k = 0
|
|
244
|
+
self.scp_J_tr = 1e2
|
|
245
|
+
self.scp_J_vb = 1e2
|
|
246
|
+
self.scp_J_vc = 1e2
|
|
247
|
+
self.scp_trajs = []
|
|
248
|
+
self.scp_controls = []
|
|
249
|
+
self.scp_V_multi_shoot_traj = []
|
|
250
|
+
|
|
198
251
|
def initialize(self):
|
|
199
252
|
io.intro()
|
|
200
253
|
|
|
254
|
+
# Print problem summary
|
|
255
|
+
io.print_problem_summary(self.settings)
|
|
256
|
+
|
|
201
257
|
# Enable the profiler
|
|
202
|
-
if self.
|
|
258
|
+
if self.settings.dev.profiling:
|
|
203
259
|
import cProfile
|
|
204
260
|
|
|
205
261
|
pr = cProfile.Profile()
|
|
@@ -207,15 +263,17 @@ class TrajOptProblem:
|
|
|
207
263
|
|
|
208
264
|
t_0_while = time.time()
|
|
209
265
|
# Ensure parameter sizes and normalization are correct
|
|
210
|
-
self.
|
|
211
|
-
self.
|
|
266
|
+
self.settings.scp.__post_init__()
|
|
267
|
+
self.settings.sim.__post_init__()
|
|
212
268
|
|
|
213
269
|
# Compile dynamics and jacobians
|
|
214
|
-
self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f)
|
|
215
|
-
self.dynamics_augmented.A = jax.
|
|
216
|
-
self.dynamics_augmented.B = jax.
|
|
270
|
+
self.dynamics_augmented.f = jax.vmap(self.dynamics_augmented.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
271
|
+
self.dynamics_augmented.A = jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
272
|
+
self.dynamics_augmented.B = jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
273
|
+
|
|
274
|
+
self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f, in_axes=(0, 0, 0, *(None,) * len(self.params)))
|
|
217
275
|
|
|
218
|
-
for constraint in self.
|
|
276
|
+
for constraint in self.settings.sim.constraints_nodal:
|
|
219
277
|
if not constraint.convex:
|
|
220
278
|
# TODO: (haynec) switch to AOT instead of JIT
|
|
221
279
|
constraint.g = jax.jit(constraint.g)
|
|
@@ -223,55 +281,241 @@ class TrajOptProblem:
|
|
|
223
281
|
constraint.grad_g_u = jax.jit(constraint.grad_g_u)
|
|
224
282
|
|
|
225
283
|
# Generate solvers and optimal control problem
|
|
226
|
-
self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
|
|
227
|
-
self.propagation_solver = get_propagation_solver(self.
|
|
228
|
-
self.optimal_control_problem = OptimalControlProblem(self.
|
|
229
|
-
|
|
230
|
-
#
|
|
231
|
-
self.
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
284
|
+
self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.settings, self.params)
|
|
285
|
+
self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.settings, self.params)
|
|
286
|
+
self.optimal_control_problem = OptimalControlProblem(self.settings)
|
|
287
|
+
|
|
288
|
+
# Collect all relevant functions
|
|
289
|
+
functions_to_hash = [self.dynamics_augmented.f, self.dynamics_augmented_prop.f]
|
|
290
|
+
for constraint in self.settings.sim.constraints_nodal:
|
|
291
|
+
functions_to_hash.append(constraint.func)
|
|
292
|
+
for constraint in self.settings.sim.constraints_ctcs:
|
|
293
|
+
functions_to_hash.append(constraint.func)
|
|
294
|
+
|
|
295
|
+
# Get unique source-based hash
|
|
296
|
+
function_hash = stable_function_hash(
|
|
297
|
+
functions_to_hash,
|
|
298
|
+
n_discretization_nodes=self.settings.scp.n,
|
|
299
|
+
dt=self.settings.prp.dt,
|
|
300
|
+
total_time=self.settings.sim.total_time,
|
|
301
|
+
state_max=self.settings.sim.x.max,
|
|
302
|
+
state_min=self.settings.sim.x.min,
|
|
303
|
+
control_max=self.settings.sim.u.max,
|
|
304
|
+
control_min=self.settings.sim.u.min
|
|
235
305
|
)
|
|
236
306
|
|
|
307
|
+
solver_dir = Path(".tmp")
|
|
308
|
+
solver_dir.mkdir(parents=True, exist_ok=True)
|
|
309
|
+
dis_solver_file = solver_dir / f"compiled_discretization_solver_{function_hash}.jax"
|
|
310
|
+
prop_solver_file = solver_dir / f"compiled_propagation_solver_{function_hash}.jax"
|
|
311
|
+
|
|
312
|
+
|
|
237
313
|
# Compile the solvers
|
|
238
|
-
if not self.
|
|
239
|
-
self.
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
314
|
+
if not self.settings.dev.debug:
|
|
315
|
+
if self.settings.sim.save_compiled:
|
|
316
|
+
# Check if the compiled file already exists
|
|
317
|
+
try:
|
|
318
|
+
with open(dis_solver_file, "rb") as f:
|
|
319
|
+
serial_dis = f.read()
|
|
320
|
+
# Load the compiled code
|
|
321
|
+
self.discretization_solver = export.deserialize(serial_dis)
|
|
322
|
+
print("✓ Loaded existing discretization solver")
|
|
323
|
+
except FileNotFoundError:
|
|
324
|
+
print("Compiling discretization solver...")
|
|
325
|
+
# Extract parameter values and names in order
|
|
326
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
327
|
+
|
|
328
|
+
self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
|
|
329
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_states)),
|
|
330
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
|
|
331
|
+
*param_values
|
|
332
|
+
)
|
|
333
|
+
# Serialize and Save the compiled code in a temp directory
|
|
334
|
+
with open(dis_solver_file, "wb") as f:
|
|
335
|
+
f.write(self.discretization_solver.serialize())
|
|
336
|
+
print("✓ Discretization solver compiled and saved")
|
|
337
|
+
else:
|
|
338
|
+
print("Compiling discretization solver (not saving/loading from disk)...")
|
|
339
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
340
|
+
self.discretization_solver = export.export(jax.jit(self.discretization_solver))(
|
|
341
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_states)),
|
|
342
|
+
np.ones((self.settings.scp.n, self.settings.sim.n_controls)),
|
|
343
|
+
*param_values
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
# Compile the discretization solver and save it
|
|
347
|
+
dtau = 1.0 / (self.settings.scp.n - 1)
|
|
348
|
+
dt_max = self.settings.sim.u.max[self.settings.sim.idx_s][0] * dtau
|
|
349
|
+
|
|
350
|
+
self.settings.prp.max_tau_len = int(dt_max / self.settings.prp.dt) + 2
|
|
351
|
+
|
|
352
|
+
# Check if the compiled file already exists
|
|
353
|
+
if self.settings.sim.save_compiled:
|
|
354
|
+
try:
|
|
355
|
+
with open(prop_solver_file, "rb") as f:
|
|
356
|
+
serial_prop = f.read()
|
|
357
|
+
# Load the compiled code
|
|
358
|
+
self.propagation_solver = export.deserialize(serial_prop)
|
|
359
|
+
print("✓ Loaded existing propagation solver")
|
|
360
|
+
except FileNotFoundError:
|
|
361
|
+
print("Compiling propagation solver...")
|
|
362
|
+
# Extract parameter values and names in order
|
|
363
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
364
|
+
|
|
365
|
+
propagation_solver = export.export(jax.jit(self.propagation_solver))(
|
|
366
|
+
np.ones((self.settings.sim.n_states_prop)), # x_0
|
|
367
|
+
(0.0, 0.0), # time span
|
|
368
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_current
|
|
369
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_next
|
|
370
|
+
np.ones((1, 1)), # tau_0
|
|
371
|
+
np.ones((1, 1)).astype("int"), # segment index
|
|
372
|
+
0, # idx_s_stop
|
|
373
|
+
np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
|
|
374
|
+
np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
|
|
375
|
+
*param_values, # additional parameters
|
|
244
376
|
)
|
|
245
|
-
.compile()
|
|
246
|
-
)
|
|
247
377
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
378
|
+
# Serialize and Save the compiled code in a temp directory
|
|
379
|
+
self.propagation_solver = propagation_solver
|
|
380
|
+
|
|
381
|
+
with open(prop_solver_file, "wb") as f:
|
|
382
|
+
f.write(self.propagation_solver.serialize())
|
|
383
|
+
print("✓ Propagation solver compiled and saved")
|
|
384
|
+
else:
|
|
385
|
+
print("Compiling propagation solver (not saving/loading from disk)...")
|
|
386
|
+
param_values = [param.value for _, param in self.params.items()]
|
|
387
|
+
propagation_solver = export.export(jax.jit(self.propagation_solver))(
|
|
388
|
+
np.ones((self.settings.sim.n_states_prop)), # x_0
|
|
389
|
+
(0.0, 0.0), # time span
|
|
390
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_current
|
|
391
|
+
np.ones((1, self.settings.sim.n_controls)), # controls_next
|
|
392
|
+
np.ones((1, 1)), # tau_0
|
|
393
|
+
np.ones((1, 1)).astype("int"), # segment index
|
|
394
|
+
0, # idx_s_stop
|
|
395
|
+
np.ones((self.settings.prp.max_tau_len,)), # save_time (tau_cur_padded)
|
|
396
|
+
np.ones((self.settings.prp.max_tau_len,), dtype=bool), # mask_padded (boolean mask)
|
|
397
|
+
*param_values, # additional parameters
|
|
258
398
|
)
|
|
259
|
-
.
|
|
399
|
+
self.propagation_solver = propagation_solver
|
|
400
|
+
|
|
401
|
+
# Initialize the PTR loop
|
|
402
|
+
print("Initializing the SCvx Subproblem Solver...")
|
|
403
|
+
self.cpg_solve = PTR_init(
|
|
404
|
+
self.params,
|
|
405
|
+
self.optimal_control_problem,
|
|
406
|
+
self.discretization_solver,
|
|
407
|
+
self.settings,
|
|
260
408
|
)
|
|
409
|
+
print("✓ SCvx Subproblem Solver initialized")
|
|
410
|
+
|
|
411
|
+
# Reset SCP state
|
|
412
|
+
self.scp_k = 1
|
|
413
|
+
self.scp_J_tr = 1e2
|
|
414
|
+
self.scp_J_vb = 1e2
|
|
415
|
+
self.scp_J_vc = 1e2
|
|
416
|
+
self.scp_trajs = [self.settings.sim.x.guess]
|
|
417
|
+
self.scp_controls = [self.settings.sim.u.guess]
|
|
418
|
+
self.scp_V_multi_shoot_traj = []
|
|
261
419
|
|
|
262
420
|
t_f_while = time.time()
|
|
263
421
|
self.timing_init = t_f_while - t_0_while
|
|
264
422
|
print("Total Initialization Time: ", self.timing_init)
|
|
265
423
|
|
|
266
|
-
|
|
424
|
+
# Robust priming call for propagation_solver.call (no debug prints)
|
|
425
|
+
try:
|
|
426
|
+
x_0 = np.ones(self.settings.sim.x_prop.initial.shape, dtype=self.settings.sim.x_prop.initial.dtype)
|
|
427
|
+
tau_grid = (0.0, 1.0)
|
|
428
|
+
controls_current = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
|
|
429
|
+
controls_next = np.ones((1, self.settings.sim.u.shape[0]), dtype=self.settings.sim.u.guess.dtype)
|
|
430
|
+
tau_init = np.array([[0.0]], dtype=np.float64)
|
|
431
|
+
node = np.array([[0]], dtype=np.int64)
|
|
432
|
+
idx_s_stop = self.settings.sim.idx_s.stop
|
|
433
|
+
save_time = np.ones((self.settings.prp.max_tau_len,), dtype=np.float64)
|
|
434
|
+
mask_padded = np.ones((self.settings.prp.max_tau_len,), dtype=bool)
|
|
435
|
+
param_values = [np.ones_like(param.value) if hasattr(param.value, 'shape') else float(param.value) for _, param in self.params.items()]
|
|
436
|
+
self.propagation_solver.call(
|
|
437
|
+
x_0, tau_grid, controls_current, controls_next, tau_init, node, idx_s_stop, save_time, mask_padded, *param_values
|
|
438
|
+
)
|
|
439
|
+
except Exception as e:
|
|
440
|
+
print(f"[Initialization] Priming propagation_solver.call failed: {e}")
|
|
441
|
+
|
|
442
|
+
if self.settings.dev.profiling:
|
|
267
443
|
pr.disable()
|
|
268
444
|
# Save results so it can be viusualized with snakeviz
|
|
269
445
|
pr.dump_stats("profiling_initialize.prof")
|
|
270
446
|
|
|
271
|
-
def
|
|
447
|
+
def step(self):
|
|
448
|
+
"""Performs a single SCP iteration.
|
|
449
|
+
|
|
450
|
+
This method is designed for real-time plotting and interactive optimization.
|
|
451
|
+
It performs one complete SCP iteration including subproblem solving,
|
|
452
|
+
state updates, and progress emission for real-time visualization.
|
|
453
|
+
|
|
454
|
+
Returns:
|
|
455
|
+
dict: Dictionary containing convergence status and current state
|
|
456
|
+
"""
|
|
457
|
+
x = self.settings.sim.x
|
|
458
|
+
u = self.settings.sim.u
|
|
459
|
+
|
|
460
|
+
# Run the subproblem
|
|
461
|
+
x_sol, u_sol, cost, J_total, J_vb_vec, J_vc_vec, J_tr_vec, prob_stat, V_multi_shoot, subprop_time, dis_time = PTR_subproblem(
|
|
462
|
+
self.params.items(),
|
|
463
|
+
self.cpg_solve,
|
|
464
|
+
x,
|
|
465
|
+
u,
|
|
466
|
+
self.discretization_solver,
|
|
467
|
+
self.optimal_control_problem,
|
|
468
|
+
self.settings,
|
|
469
|
+
)
|
|
470
|
+
|
|
471
|
+
# Update state
|
|
472
|
+
self.scp_V_multi_shoot_traj.append(V_multi_shoot)
|
|
473
|
+
x.guess = x_sol
|
|
474
|
+
u.guess = u_sol
|
|
475
|
+
self.scp_trajs.append(x.guess)
|
|
476
|
+
self.scp_controls.append(u.guess)
|
|
477
|
+
|
|
478
|
+
self.scp_J_tr = np.sum(np.array(J_tr_vec))
|
|
479
|
+
self.scp_J_vb = np.sum(np.array(J_vb_vec))
|
|
480
|
+
self.scp_J_vc = np.sum(np.array(J_vc_vec))
|
|
481
|
+
|
|
482
|
+
# Update weights
|
|
483
|
+
self.settings.scp.w_tr = min(self.settings.scp.w_tr * self.settings.scp.w_tr_adapt, self.settings.scp.w_tr_max)
|
|
484
|
+
if self.scp_k > self.settings.scp.cost_drop:
|
|
485
|
+
self.settings.scp.lam_cost = self.settings.scp.lam_cost * self.settings.scp.cost_relax
|
|
486
|
+
|
|
487
|
+
# Emit data
|
|
488
|
+
self.emitter_function(
|
|
489
|
+
{
|
|
490
|
+
"iter": self.scp_k,
|
|
491
|
+
"dis_time": dis_time * 1000.0,
|
|
492
|
+
"subprop_time": subprop_time * 1000.0,
|
|
493
|
+
"J_total": J_total,
|
|
494
|
+
"J_tr": self.scp_J_tr,
|
|
495
|
+
"J_vb": self.scp_J_vb,
|
|
496
|
+
"J_vc": self.scp_J_vc,
|
|
497
|
+
"cost": cost[-1],
|
|
498
|
+
"prob_stat": prob_stat,
|
|
499
|
+
}
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
# Increment counter
|
|
503
|
+
self.scp_k += 1
|
|
504
|
+
|
|
505
|
+
# Create a result dictionary for this step
|
|
506
|
+
return {
|
|
507
|
+
"converged": (self.scp_J_tr < self.settings.scp.ep_tr) and \
|
|
508
|
+
(self.scp_J_vb < self.settings.scp.ep_vb) and \
|
|
509
|
+
(self.scp_J_vc < self.settings.scp.ep_vc),
|
|
510
|
+
"u": u,
|
|
511
|
+
"x": x,
|
|
512
|
+
"V_multi_shoot": V_multi_shoot
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
def solve(self, max_iters: Optional[int] = None, continuous: bool = False) -> OptimizationResults:
|
|
272
516
|
# Ensure parameter sizes and normalization are correct
|
|
273
|
-
self.
|
|
274
|
-
self.
|
|
517
|
+
self.settings.scp.__post_init__()
|
|
518
|
+
self.settings.sim.__post_init__()
|
|
275
519
|
|
|
276
520
|
if self.optimal_control_problem is None or self.discretization_solver is None:
|
|
277
521
|
raise ValueError(
|
|
@@ -279,7 +523,7 @@ class TrajOptProblem:
|
|
|
279
523
|
)
|
|
280
524
|
|
|
281
525
|
# Enable the profiler
|
|
282
|
-
if self.
|
|
526
|
+
if self.settings.dev.profiling:
|
|
283
527
|
import cProfile
|
|
284
528
|
|
|
285
529
|
pr = cProfile.Profile()
|
|
@@ -288,14 +532,13 @@ class TrajOptProblem:
|
|
|
288
532
|
t_0_while = time.time()
|
|
289
533
|
# Print top header for solver results
|
|
290
534
|
io.header()
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
self.
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
)
|
|
535
|
+
|
|
536
|
+
k_max = max_iters if max_iters is not None else self.settings.scp.k_max
|
|
537
|
+
|
|
538
|
+
while self.scp_k <= k_max:
|
|
539
|
+
result = self.step()
|
|
540
|
+
if result["converged"] and not continuous:
|
|
541
|
+
break
|
|
299
542
|
|
|
300
543
|
t_f_while = time.time()
|
|
301
544
|
self.timing_solve = t_f_while - t_0_while
|
|
@@ -304,33 +547,35 @@ class TrajOptProblem:
|
|
|
304
547
|
time.sleep(0.1)
|
|
305
548
|
|
|
306
549
|
# Print bottom footer for solver results as well as total computation time
|
|
307
|
-
io.footer(
|
|
550
|
+
io.footer()
|
|
308
551
|
|
|
309
552
|
# Disable the profiler
|
|
310
|
-
if self.
|
|
553
|
+
if self.settings.dev.profiling:
|
|
311
554
|
pr.disable()
|
|
312
555
|
# Save results so it can be viusualized with snakeviz
|
|
313
556
|
pr.dump_stats("profiling_solve.prof")
|
|
314
557
|
|
|
315
|
-
return
|
|
558
|
+
return format_result(self, self.scp_k <= k_max)
|
|
316
559
|
|
|
317
|
-
def post_process(self, result):
|
|
560
|
+
def post_process(self, result: OptimizationResults) -> OptimizationResults:
|
|
318
561
|
# Enable the profiler
|
|
319
|
-
if self.
|
|
562
|
+
if self.settings.dev.profiling:
|
|
320
563
|
import cProfile
|
|
321
564
|
|
|
322
565
|
pr = cProfile.Profile()
|
|
323
566
|
pr.enable()
|
|
324
567
|
|
|
325
568
|
t_0_post = time.time()
|
|
326
|
-
result = propagate_trajectory_results(self.params, result, self.propagation_solver)
|
|
569
|
+
result = propagate_trajectory_results(self.params, self.settings, result, self.propagation_solver)
|
|
327
570
|
t_f_post = time.time()
|
|
328
571
|
|
|
329
572
|
self.timing_post = t_f_post - t_0_post
|
|
330
|
-
|
|
573
|
+
|
|
574
|
+
# Print results summary
|
|
575
|
+
io.print_results_summary(result, self.timing_post, self.timing_init, self.timing_solve)
|
|
331
576
|
|
|
332
577
|
# Disable the profiler
|
|
333
|
-
if self.
|
|
578
|
+
if self.settings.dev.profiling:
|
|
334
579
|
pr.disable()
|
|
335
580
|
# Save results so it can be viusualized with snakeviz
|
|
336
581
|
pr.dump_stats("profiling_postprocess.prof")
|