openscvx 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic. Click here for more details.

@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+ from openscvx.config import Config
4
+ from openscvx.integrators import solve_ivp_diffrax_prop
5
+
6
+
7
def prop_aug_dy(
    tau: float,
    x: np.ndarray,
    u_current: np.ndarray,
    u_next: np.ndarray,
    tau_init: float,
    node: int,
    idx_s: int,
    state_dot: callable,
    dis_type: str,
    N: int,
) -> np.ndarray:
    """Right-hand side of the time-dilated dynamics used for propagation.

    The control is held at ``u_current`` ("ZOH") or blended linearly from
    ``u_current`` towards ``u_next`` starting at ``tau_init`` ("FOH"). The
    batched dynamics are evaluated at that control and scaled by the
    time-dilation entry ``u[:, idx_s]``.

    Args:
        tau: Current normalized time.
        x: State vector; a leading batch axis of size 1 is added before
            calling ``state_dot``.
        u_current: Control at the start of the interval, 2-D (presumably a
            single row — TODO confirm).
        u_next: Control at the end of the interval, same shape as ``u_current``.
        tau_init: Normalized time at the start of the interval.
        node: Node index forwarded to ``state_dot``.
        idx_s: Column of the blended control holding the time-dilation factor.
        state_dot: Batched dynamics ``state_dot(x, u, node)``; it receives every
            control column except the last one.
        dis_type: Discretization type, "ZOH" or "FOH".
        N: Factor applied to ``tau - tau_init`` to obtain the FOH blend weight.

    Returns:
        ``state_dot`` output (squeezed) multiplied by ``u[:, idx_s]``.

    Raises:
        ValueError: If ``dis_type`` is neither "ZOH" nor "FOH" (this used to
            surface as an UnboundLocalError on ``beta``).
    """
    x = x[None, :]

    if dis_type == "ZOH":
        beta = 0.0
    elif dis_type == "FOH":
        # NOTE(review): the nodal grid used elsewhere is linspace(0, 1, N), i.e. a
        # spacing of 1/(N - 1), so this weight reaches N/(N - 1) rather than 1 at the
        # end of an interval. Left unchanged because it may have to match the
        # discretization model (not visible here) — confirm.
        beta = (tau - tau_init) * N
    else:
        raise ValueError(
            f"Unknown discretization type: {dis_type!r}; expected 'ZOH' or 'FOH'"
        )
    u = u_current + beta * (u_next - u_current)

    return u[:, idx_s] * state_dot(x, u[:, :-1], node).squeeze()
28
+
29
+
30
def get_propagation_solver(state_dot, params):
    """Build a closure that integrates ``prop_aug_dy`` over one normalized-time interval.

    Args:
        state_dot: Batched dynamics function forwarded to ``prop_aug_dy``.
        params: Configuration object; ``params.dis.dis_type`` and
            ``params.scp.n`` are read whenever the returned solver is invoked.

    Returns:
        ``propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, idx_s)``,
        which integrates from ``tau_grid[0]`` to ``tau_grid[1]`` starting at
        ``V0`` and returns the result of ``solve_ivp_diffrax_prop`` unchanged.
    """

    def propagation_solver(V0, tau_grid, u_cur, u_next, tau_init, node, idx_s):
        tau_end = tau_grid[1]
        tau_start = tau_grid[0]
        # Extra positional arguments of prop_aug_dy after (tau, x).
        ode_args = (
            u_cur,
            u_next,
            tau_init,
            node,
            idx_s,
            state_dot,
            params.dis.dis_type,
            params.scp.n,
        )
        return solve_ivp_diffrax_prop(
            f=prop_aug_dy,
            tau_final=tau_end,
            y_0=V0,
            args=ode_args,
            tau_0=tau_start,
        )

    return propagation_solver
50
+
51
+
52
def s_to_t(u, params: Config):
    """Convert the time-dilation controls into physical time at every node.

    Integrates dt/dtau = s over the uniform grid ``np.linspace(0, 1, params.scp.n)``,
    where ``s`` is the last column of ``u``. With "ZOH" the value at the start
    of each interval is used; any other discretization type uses the
    trapezoidal average of both endpoints.

    Args:
        u: Nodal controls, one row per node, time-dilation factor in the last column.
        params: Configuration providing ``scp.n`` and ``dis.dis_type``.

    Returns:
        List of node times, starting with ``0``.
    """
    tau_grid = np.linspace(0, 1, params.scp.n)
    times = [0]
    for node in range(1, params.scp.n):
        d_tau = tau_grid[node] - tau_grid[node - 1]
        s_prev = u[node - 1, -1]
        s_here = u[node, -1]
        if params.dis.dis_type == "ZOH":
            step = d_tau * s_prev
        else:
            step = 0.5 * (s_here + s_prev) * d_tau
        times.append(times[-1] + step)
    return times
63
+
64
+
65
def t_to_tau(u, t, u_nodal, t_nodal, params: Config):
    """Map physical sample times back to normalized time ``tau``.

    For every sample time the enclosing nodal interval is located and the time
    elapsed inside it is divided by the time-dilation factor (last control
    column). With "ZOH" the factor at the interval start is used. Otherwise the
    average of the interval-start factor and the factor interpolated at the
    sample time is used; this is an approximate inverse of a trapezoidal rule.

    Args:
        u: Controls given at ``t_nodal`` (one row per node). They are linearly
            interpolated in physical time onto ``t``.
        t: Sample times, assumed ascending and starting at ``t_nodal[0]``.
        u_nodal: Nodal controls whose last column is the time-dilation factor.
        t_nodal: Physical node times, ascending.
        params: Configuration providing ``scp.n`` and ``dis.dis_type``.

    Returns:
        ``(tau, u_interp)``. ``tau`` has one entry per sample with ``tau[0] = 0``.
        ``u_interp`` has shape ``(len(t), u.shape[1])``.
    """
    t = np.asarray(t, dtype=float)
    t_nodal = np.asarray(t_nodal)

    # One vectorized np.interp per control channel, instead of one scalar call
    # per (sample, channel) pair.
    u_interp = np.column_stack(
        [np.interp(t, t_nodal, u[:, i]) for i in range(u.shape[1])]
    )

    tau = np.zeros(len(t))
    tau_nodal = np.linspace(0, 1, params.scp.n)
    is_zoh = params.dis.dis_type == "ZOH"

    # Last node strictly before each sample (what np.where(t_nodal < t[k])[0][-1]
    # returned for sorted t_nodal). It is clamped to 0, so a sample at or before
    # the first node maps into the first interval instead of raising IndexError.
    interval_idx = np.maximum(np.searchsorted(t_nodal, t, side="left") - 1, 0)

    for k in range(1, len(t)):
        k_nodal = interval_idx[k]
        s_kp = u_nodal[k_nodal, -1]
        tp = t_nodal[k_nodal]
        tau_p = tau_nodal[k_nodal]
        if is_zoh:
            tau[k] = tau_p + (t[k] - tp) / s_kp
        else:
            s_k = u_interp[k, -1]
            tau[k] = tau_p + 2 * (t[k] - tp) / (s_k + s_kp)
    return tau, u_interp
85
+
86
+
87
def simulate_nonlinear_time(x_0, u, tau_vals, t, params, propagation_solver):
    """Propagate the nonlinear dynamics interval by interval and sample them at ``tau_vals``.

    For each interval ``[tau[k], tau[k + 1]]`` of ``np.linspace(0, 1, params.scp.n)``,
    the controls at ``t[k]`` and ``t[k + 1]`` are passed to ``propagation_solver``.
    These controls come from linear interpolation of ``u`` over ``t``. The
    returned solution is evaluated at every ``tau_vals`` entry binned into that
    interval, and its last ``ys`` row seeds the next interval.

    Args:
        x_0: Initial state (1-D).
        u: Controls, one row per entry of ``t``.
        tau_vals: Normalized sample times. They must be sorted ascending, because
            each interval takes the next consecutive slice of them.
        t: Physical node times (at least ``params.scp.n`` entries).
        params: Configuration providing ``scp.n`` and ``sim.idx_s``.
        propagation_solver: Solver as built by ``get_propagation_solver``; its
            result must expose ``ys`` and ``evaluate(tau)``.

    Returns:
        Array of shape ``(n_samples, x_0.shape[0])``, one sampled state per row.
    """
    n_nodes = params.scp.n
    tau = np.linspace(0, 1, n_nodes)

    def u_lam(new_t):
        # Every control channel interpolated at new_t, returned as a single row.
        # (reshape instead of squeeze()[None, :], which broke with one channel.)
        return np.array(
            [np.interp(new_t, t, u[:, i]) for i in range(u.shape[1])]
        ).reshape(1, -1)

    # Bin each sample into its interval. Samples at (or beyond) tau = 1 are
    # assigned to the last interval.
    tau_inds = np.digitize(tau_vals, tau) - 1
    tau_inds = np.where(tau_inds == n_nodes - 1, n_nodes - 2, tau_inds)

    # Gather sampled states as columns and concatenate once at the end; growing
    # the array per sample was quadratic in the number of samples.
    columns = [np.empty((x_0.shape[0], 0))]
    prev_count = 0

    for k in range(n_nodes - 1):
        controls_current = u_lam(t[k])
        controls_next = u_lam(t[k + 1])

        count = int(np.count_nonzero(tau_inds == k))
        tau_cur = tau_vals[prev_count : prev_count + count]

        sol = propagation_solver(
            x_0,
            (tau[k], tau[k + 1]),
            controls_current,
            controls_next,
            np.array([[tau[k]]]),
            np.array([[k]]),
            # NOTE(review): if idx_s is slice(j, j + 1), this passes j + 1, one past
            # the time-dilation column. It presumably works only because JAX clamps
            # out-of-range indices — confirm.
            params.sim.idx_s.stop,
        )

        for tau_i in tau_cur:
            columns.append(sol.evaluate(tau_i).reshape(-1, 1))

        x_0 = sol.ys[-1]
        prev_count += count

    return np.concatenate(columns, axis=1).T
openscvx/ptr.py ADDED
@@ -0,0 +1,149 @@
1
+ import numpy as np
2
+ import numpy.linalg as la
3
+ import cvxpy as cp
4
+ import pickle
5
+ import time
6
+
7
+ from openscvx.config import Config
8
+
9
+ import warnings
10
# Install an "ignore" filter for every warning category, process-wide, as soon
# as this module is imported.
# NOTE(review): this also hides warnings raised by user code and other libraries;
# consider scoping it with warnings.catch_warnings() around the solver calls.
warnings.simplefilter("ignore")
11
+
12
def PTR_init(ocp: cp.Problem, discretization_solver: callable, params: Config):
    """Prepare the SCP loop and warm up its solvers.

    When ``params.cvx.cvxpygen`` is set, the generated ``cpg_solve`` is
    imported from the ``solver`` package, which must be importable from the
    working directory. One subproblem is then solved at the initial guess
    ``(params.sim.x_bar, params.sim.u_bar)`` and its result discarded.

    Args:
        ocp: Convex subproblem solved by ``PTR_subproblem``.
        discretization_solver: Discretization function forwarded to ``PTR_subproblem``.
        params: Solver configuration.

    Returns:
        The cvxpygen solve function, or ``None`` when cvxpygen is disabled.
    """
    if params.cvx.cvxpygen:
        # Only the generated solve function is needed: the problem actually solved
        # is ``ocp``. The unused unpickling of solver/problem.pickle was removed.
        from solver.cpg_solver import cpg_solve
    else:
        cpg_solve = None

    # Throw-away solve so cvxpy's DPP setup and the JAX Jacobians are built
    # before the main loop.
    PTR_subproblem(
        cpg_solve,
        params.sim.x_bar,
        params.sim.u_bar,
        discretization_solver,
        ocp,
        params,
    )

    return cpg_solve
24
+
25
def PTR_main(params: Config, prob: cp.Problem, aug_dy: callable, cpg_solve, emitter_function) -> dict:
    """Run the sequential convex programming loop until convergence or ``k_max``.

    Each iteration solves one convex subproblem around the previous iterate
    (``PTR_subproblem``) and adopts its solution as the new reference. The loop
    stops in either of two cases:

    * the summed trust-region (``J_tr``), virtual-buffer (``J_vb``) and
      virtual-control (``J_vc``) measures are all below ``params.scp.ep_tr``,
      ``ep_vb`` and ``ep_vc``;
    * ``params.scp.k_max`` iterations have run.

    Side effects:
        ``params.scp.w_tr`` is multiplied by ``w_tr_adapt`` (capped at ``w_tr_max``)
        every iteration. ``params.scp.lam_cost`` is scaled by ``cost_relax`` once the
        iteration count exceeds ``cost_drop``. Both are written back into
        ``params``, so a later call starts from the adapted values.
        ``emitter_function`` receives one statistics dict per iteration.

    Args:
        params: Configuration (initial guess in ``params.sim``, SCP settings in ``params.scp``).
        prob: Convex subproblem passed to ``PTR_subproblem``.
        aug_dy: Discretization function passed to ``PTR_subproblem``.
        cpg_solve: cvxpygen solve function, or ``None``.
        emitter_function: Callback for per-iteration statistics.

    Returns:
        dict with keys ``converged`` (all tolerances met), ``t_final``, ``u``,
        ``x``, ``x_history`` / ``u_history`` (initial guess followed by every
        iterate), ``discretization_history``, and ``J_tr_history``,
        ``J_vb_history``, ``J_vc_history``. The last three hold only the final
        iteration's vectors (``None`` if no iteration ran).
    """
    scp = params.scp

    # Seed values: the loop runs as long as any tolerance is below 1e2.
    J_vb = 1e2
    J_vc = 1e2
    J_tr = 1e2

    x_bar = params.sim.x_bar
    u_bar = params.sim.u_bar

    # Fallbacks so the result can be built even if no iteration runs (k_max < 1).
    x, u = x_bar, u_bar
    J_tr_vec = J_vb_vec = J_vc_vec = None

    scp_trajs = [x_bar]
    scp_controls = [u_bar]
    V_multi_shoot_traj = []

    k = 1
    while k <= scp.k_max and (J_tr >= scp.ep_tr or J_vb >= scp.ep_vb or J_vc >= scp.ep_vc):
        (
            x,
            u,
            costs,
            J_total,
            J_vb_vec,
            J_vc_vec,
            J_tr_vec,
            prob_stat,
            V_multi_shoot,
            subprop_time,
            dis_time,
        ) = PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params)

        V_multi_shoot_traj.append(V_multi_shoot)

        x_bar = x
        u_bar = u

        J_tr = np.sum(np.array(J_tr_vec))
        J_vb = np.sum(np.array(J_vb_vec))
        J_vc = np.sum(np.array(J_vc_vec))
        scp_trajs.append(x)
        scp_controls.append(u)

        scp.w_tr = min(scp.w_tr * scp.w_tr_adapt, scp.w_tr_max)
        if k > scp.cost_drop:
            scp.lam_cost = scp.lam_cost * scp.cost_relax

        emitter_function(
            {
                "iter": k,
                "dis_time": dis_time * 1000.0,
                "subprop_time": subprop_time * 1000.0,
                "J_total": J_total,
                "J_tr": J_tr,
                "J_vb": J_vb,
                "J_vc": J_vc,
                "cost": costs[-1],
                "prob_stat": prob_stat,
            }
        )

        k += 1

    # Decide convergence from the tolerances themselves. The old `k <= k_max`
    # test reported failure when convergence happened on the last allowed
    # iteration, because k has already been incremented past k_max by then.
    converged = bool(J_tr < scp.ep_tr and J_vb < scp.ep_vb and J_vc < scp.ep_vc)

    result = dict(
        converged=converged,
        t_final=x[:, params.sim.idx_t][-1],
        u=u,
        x=x,
        x_history=scp_trajs,
        u_history=scp_controls,
        discretization_history=V_multi_shoot_traj,
        J_tr_history=J_tr_vec,
        J_vb_history=J_vb_vec,
        J_vc_history=J_vc_vec,
    )
    return result
86
+
87
+
88
def PTR_subproblem(cpg_solve, x_bar, u_bar, aug_dy, prob, params: Config):
    """Discretize about ``(x_bar, u_bar)``, solve the convex subproblem and score the result.

    The following are written into ``prob.param_dict``:

    * the reference trajectory;
    * the matrices returned by ``aug_dy``;
    * for every non-convex nodal constraint, its value and gradients;
    * the current ``w_tr`` and ``lam_cost``.

    ``prob`` is then solved with the cvxpygen solver (``params.cvx.cvxpygen``)
    or with ``params.cvx.solver``.

    Args:
        cpg_solve: cvxpygen solve function (used only when cvxpygen is enabled).
        x_bar, u_bar: Reference state / control trajectories, one row per node.
        aug_dy: Discretization function returning ``(A, B, C, z, V_multi_shoot)``.
        prob: Convex subproblem exposing ``param_dict`` / ``var_dict``.
        params: Solver configuration.

    Returns:
        Tuple ``(x, u, costs, J_total, J_vb_vec, J_vc_vec, J_tr_vec, status,
        V_multi_shoot, subprop_time, dis_time)``:

        * ``x``, ``u``: solution mapped back through ``S_x``/``c_x`` and
          ``S_u``/``c_u``, one row per node.
        * ``costs``: per-node array. It is the sum of the states marked
          "Minimize" minus the sum of those marked "Maximize" in
          ``params.sim.final_state.type``.
        * ``J_total``: ``prob.value``.
        * ``J_vb_vec``: sum of the positive parts of the ``nu_vb_*`` variables
          (``0`` when there are no non-convex nodal constraints).
        * ``J_vc_vec``: per-row L1 norm of the ``nu`` variable.
        * ``J_tr_vec``: per-node squared norm of the (inverse-scaled) step from
          the reference.
        * ``status``: ``prob.status``.
        * ``V_multi_shoot``: passed through from ``aug_dy``.
        * ``subprop_time``, ``dis_time``: solve and discretization wall time in seconds.
    """
    prob.param_dict['x_bar'].value = x_bar
    prob.param_dict['u_bar'].value = u_bar

    # Discretize the dynamics about the reference trajectory.
    t0 = time.perf_counter()
    A_bar, B_bar, C_bar, z_bar, V_multi_shoot = aug_dy(x_bar, u_bar.astype(float))

    prob.param_dict['A_d'].value = np.asarray(A_bar)
    prob.param_dict['B_d'].value = np.asarray(B_bar)
    prob.param_dict['C_d'].value = np.asarray(C_bar)
    prob.param_dict['z_d'].value = np.asarray(z_bar)
    dis_time = time.perf_counter() - t0

    nodal_constraints = params.sim.constraints_nodal or []

    # Linearize non-convex nodal constraints about the reference.
    # NOTE(review): these parameters are keyed by the index among *all* nodal
    # constraints, while the nu_vb_* variables below are keyed by the index among
    # non-convex ones only — confirm this matches how the OCP names them.
    for g_id, constraint in enumerate(nodal_constraints):
        if not constraint.convex:
            prob.param_dict[f"g_{g_id}"].value = np.asarray(constraint.g(x_bar, u_bar))
            prob.param_dict[f"grad_g_x_{g_id}"].value = np.asarray(constraint.grad_g_x(x_bar, u_bar))
            prob.param_dict[f"grad_g_u_{g_id}"].value = np.asarray(constraint.grad_g_u(x_bar, u_bar))

    prob.param_dict['w_tr'].value = params.scp.w_tr
    prob.param_dict['lam_cost'].value = params.scp.lam_cost

    t0 = time.perf_counter()
    if params.cvx.cvxpygen:
        prob.register_solve('CPG', cpg_solve)
        prob.solve(method='CPG', **params.cvx.solver_args)
    else:
        prob.solve(solver=params.cvx.solver, enforce_dpp=True, **params.cvx.solver_args)
    subprop_time = time.perf_counter() - t0

    # Undo the affine scaling of the decision variables.
    x = (params.sim.S_x @ prob.var_dict['x'].value.T + np.expand_dims(params.sim.c_x, axis=1)).T
    u = (params.sim.S_u @ prob.var_dict['u'].value.T + np.expand_dims(params.sim.c_u, axis=1)).T

    # Per-node objective trace. It was previously a list `[0]`, so "+=" extended
    # the list and combining "Minimize" with "Maximize" raised a shape error.
    costs = np.zeros(x.shape[0])
    for i, objective in enumerate(params.sim.final_state.type):
        if objective == 'Minimize':
            costs += x[:, i]
        elif objective == 'Maximize':
            costs -= x[:, i]

    # Block-diagonal inverse scaling diag(inv_S_x, inv_S_u) built with NumPy.
    inv_S_x = params.sim.inv_S_x
    inv_S_u = params.sim.inv_S_u
    inv_block_diag = np.block([
        [inv_S_x, np.zeros((inv_S_x.shape[0], inv_S_u.shape[1]))],
        [np.zeros((inv_S_u.shape[0], inv_S_x.shape[1])), inv_S_u],
    ])

    step = np.hstack((x - x_bar, u - u_bar)).T
    J_tr_vec = la.norm(inv_block_diag @ step, axis=0) ** 2
    J_vc_vec = np.sum(np.abs(prob.var_dict['nu'].value), axis=1)

    J_vb_vec = 0
    id_ncvx = 0
    for constraint in nodal_constraints:
        if not constraint.convex:
            J_vb_vec += np.maximum(0, prob.var_dict[f"nu_vb_{id_ncvx}"].value)
            id_ncvx += 1

    return x, u, costs, prob.value, J_vb_vec, J_vc_vec, J_tr_vec, prob.status, V_multi_shoot, subprop_time, dis_time
@@ -0,0 +1,336 @@
1
+ import jax.numpy as jnp
2
+ from typing import List
3
+ import queue
4
+ import threading
5
+ import time
6
+
7
+ import cvxpy as cp
8
+ import jax
9
+ import numpy as np
10
+
11
+ from openscvx.config import (
12
+ ScpConfig,
13
+ SimConfig,
14
+ ConvexSolverConfig,
15
+ DiscretizationConfig,
16
+ PropagationConfig,
17
+ DevConfig,
18
+ Config,
19
+ )
20
+ from openscvx.dynamics import get_augmented_dynamics, get_jacobians
21
+ from openscvx.constraints.violation import get_g_funcs
22
+ from openscvx.augmentation import sort_ctcs_constraints
23
+ from openscvx.discretization import get_discretization_solver
24
+ from openscvx.propagation import get_propagation_solver
25
+ from openscvx.constraints.boundary import BoundaryConstraint
26
+ from openscvx.constraints.ctcs import CTCSConstraint
27
+ from openscvx.constraints.nodal import NodalConstraint
28
+ from openscvx.ptr import PTR_init, PTR_main
29
+ from openscvx.post_processing import propagate_trajectory_results
30
+ from openscvx.ocp import OptimalControlProblem
31
+ from openscvx import io
32
+
33
+
34
+ # TODO: (norrisg) Decide whether to have constraints`, `cost`, alongside `dynamics`, ` etc.
35
class TrajOptProblem:
    """Trajectory-optimization problem assembled from dynamics, constraints and bounds.

    The constructor does the following:

    * splits the constraints into CTCS and nodal groups;
    * appends the augmented constraint-violation states (as counted by
      ``sort_ctcs_constraints``) to the state;
    * appends one time-dilation entry to the control;
    * builds the configuration.

    ``initialize()`` builds and compiles the solvers, ``solve()`` runs the SCP
    loop, and ``post_process()`` propagates the result.
    """

    def __init__(
        self,
        dynamics: callable,
        constraints: List[callable],
        idx_time: int,
        N: int,
        time_init: float,
        x_guess: jnp.ndarray,
        u_guess: jnp.ndarray,
        initial_state: BoundaryConstraint,
        final_state: BoundaryConstraint,
        x_max: jnp.ndarray,
        x_min: jnp.ndarray,
        u_max: jnp.ndarray,
        u_min: jnp.ndarray,
        scp: ScpConfig = None,
        dis: DiscretizationConfig = None,
        prp: PropagationConfig = None,
        sim: SimConfig = None,
        dev: DevConfig = None,
        cvx: ConvexSolverConfig = None,
        licq_min=0.0,
        licq_max=1e-4,
        time_dilation_factor_min=0.3,
        time_dilation_factor_max=3.0,
    ):
        """Build the augmented problem description and solver configuration.

        Args:
            dynamics: State-derivative function of the un-augmented system.
            constraints: Constraints decorated with ``@ctcs`` or ``@nodal``.
            idx_time: Index of the time state within ``x``.
            N: Number of discretization points (becomes ``scp.n``).
            time_init: Initial value of the time-dilation control. It also scales
                that control's bounds.
            x_guess, u_guess: Initial state / control guesses, one row per node.
            initial_state, final_state: Boundary conditions.
            x_max, x_min, u_max, u_min: Bounds of the un-augmented state / control.
            scp, dis, prp, sim, dev, cvx: Optional configs; defaults are built when
                ``None``. A supplied ``scp`` must satisfy ``scp.n == N``.
            licq_min, licq_max: Bounds of the augmented constraint-violation states.
            time_dilation_factor_min, time_dilation_factor_max: Bounds of the
                time-dilation control, as multiples of ``time_init``.

        Raises:
            ValueError: If a constraint is neither a CTCS nor a nodal constraint.
            AssertionError: If ``idx_time`` is out of range or ``scp.n != N``.
        """
        # TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
        constraints_ctcs = []
        constraints_nodal = []
        # TODO: (norrisg) change back to using isinstance once on PyPi
        for constraint in constraints:
            constraint_type = type(constraint).__name__
            if constraint_type == CTCSConstraint.__name__:
                constraints_ctcs.append(constraint)
            elif constraint_type == NodalConstraint.__name__:
                constraints_nodal.append(constraint)
            else:
                raise ValueError(
                    f"Unknown constraint type: {type(constraint)}, All constraints must be decorated with @ctcs or @nodal"
                )

        constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)

        # Index tracking: augmented violation states follow the true states, and
        # the time-dilation control follows the true controls.
        idx_x_true = slice(0, len(x_max))
        idx_u_true = slice(0, len(u_max))
        idx_constraint_violation = slice(
            idx_x_true.stop, idx_x_true.stop + num_augmented_states
        )
        idx_time_dilation = slice(idx_u_true.stop, idx_u_true.stop + 1)

        # check that idx_time is in the correct range
        assert idx_time >= 0 and idx_time < len(
            x_max
        ), "idx_time must be in the range of the state vector and non-negative"
        idx_time = slice(idx_time, idx_time + 1)

        x_min_augmented = np.hstack([x_min, np.repeat(licq_min, num_augmented_states)])
        x_max_augmented = np.hstack([x_max, np.repeat(licq_max, num_augmented_states)])

        u_min_augmented = np.hstack([u_min, time_dilation_factor_min * time_init])
        u_max_augmented = np.hstack([u_max, time_dilation_factor_max * time_init])

        x_bar_augmented = np.hstack([x_guess, np.full((x_guess.shape[0], num_augmented_states), 0)])
        u_bar_augmented = np.hstack(
            [u_guess, np.full((u_guess.shape[0], 1), time_init)]
        )

        if dis is None:
            dis = DiscretizationConfig()

        if sim is None:
            sim = SimConfig(
                x_bar=x_bar_augmented,
                u_bar=u_bar_augmented,
                initial_state=initial_state,
                final_state=final_state,
                max_state=x_max_augmented,
                min_state=x_min_augmented,
                max_control=u_max_augmented,
                min_control=u_min_augmented,
                total_time=time_init,
                n_states=len(x_max),
                idx_x_true=idx_x_true,
                idx_u_true=idx_u_true,
                idx_t=idx_time,
                idx_y=idx_constraint_violation,
                idx_s=idx_time_dilation,
                ctcs_node_intervals=node_intervals,
            )

        if scp is None:
            scp = ScpConfig(
                n=N,
                k_max=200,
                w_tr=1e1,  # Weight on the Trust Region
                lam_cost=1e1,  # Weight on the Nonlinear Cost
                lam_vc=1e2,  # Weight on the Virtual Control Objective
                lam_vb=0e0,  # Weight on the Virtual Buffer Objective (only for penalized nodal constraints)
                ep_tr=1e-4,  # Trust Region Tolerance
                ep_vb=1e-4,  # Virtual Buffer Tolerance
                ep_vc=1e-8,  # Virtual Control Tolerance for CTCS
                cost_drop=4,  # SCP iteration to relax minimal final time objective
                cost_relax=0.5,  # Minimal Time Relaxation Factor
                w_tr_adapt=1.2,  # Trust Region Adaptation Factor
                w_tr_max_scaling_factor=1e2,  # Maximum Trust Region Weight
            )
        else:
            # Validate the supplied config itself. `self.scp` does not exist here,
            # so the previous check raised AttributeError for every supplied config.
            assert (
                scp.n == N
            ), "Number of segments must be the same as in the config"

        if dev is None:
            dev = DevConfig()
        if cvx is None:
            cvx = ConvexSolverConfig()
        if prp is None:
            prp = PropagationConfig()

        sim.constraints_ctcs = constraints_ctcs
        sim.constraints_nodal = constraints_nodal

        g_funcs = get_g_funcs(constraints_ctcs)
        self.dynamics_augmented = get_augmented_dynamics(dynamics, g_funcs, idx_x_true, idx_u_true)
        self.A_uncompiled, self.B_uncompiled = get_jacobians(self.dynamics_augmented)

        self.params = Config(
            sim=sim,
            scp=scp,
            dis=dis,
            dev=dev,
            cvx=cvx,
            prp=prp,
        )

        # Populated by initialize().
        self.optimal_control_problem: cp.Problem = None
        self.discretization_solver: callable = None
        self.propagation_solver: callable = None
        self.cpg_solve = None

        # set up emitter & thread only if printing is enabled
        if self.params.dev.printing:
            self.print_queue = queue.Queue()
            self.emitter_function = lambda data: self.print_queue.put(data)
            self.print_thread = threading.Thread(
                target=io.intermediate,
                args=(self.print_queue, self.params),
                daemon=True,
            )
            self.print_thread.start()
        else:
            # No-op emitter: nothing is ever queued or printed. The queue/thread
            # attributes exist (as None) so solve() can check them safely.
            self.print_queue = None
            self.print_thread = None
            self.emitter_function = lambda data: None

        self.timing_init = None
        self.timing_solve = None
        self.timing_post = None

    def _start_profiler(self):
        """Return an enabled cProfile profiler when ``dev.profiling`` is set, else ``None``."""
        if not self.params.dev.profiling:
            return None
        import cProfile

        profiler = cProfile.Profile()
        profiler.enable()
        return profiler

    @staticmethod
    def _stop_profiler(profiler, filename):
        """Disable ``profiler`` (if any) and dump its stats to ``filename`` (viewable with snakeviz)."""
        if profiler is None:
            return
        profiler.disable()
        profiler.dump_stats(filename)

    def initialize(self):
        """Build the dynamics, Jacobians, solvers and convex subproblem, then warm up.

        Unless ``dev.debug`` is set, the discretization and propagation solvers
        are additionally ahead-of-time compiled with ``jax.jit(...).lower(...).compile()``.
        ``self.timing_init`` records the elapsed time.
        """
        io.intro()

        profiler = self._start_profiler()

        t_0_while = time.perf_counter()
        # Ensure parameter sizes and normalization are correct
        self.params.scp.__post_init__()
        self.params.sim.__post_init__()

        # Compile dynamics and jacobians
        self.state_dot = jax.vmap(self.dynamics_augmented)
        self.A = jax.jit(jax.vmap(self.A_uncompiled, in_axes=(0, 0, 0)))
        self.B = jax.jit(jax.vmap(self.B_uncompiled, in_axes=(0, 0, 0)))
        # TODO: (norrisg) Could consider using dataclass just to hold dynamics and jacobians
        # TODO: (norrisg) Consider writing the compiled versions into the same variables?
        # Otherwise if have a dataclass could have 2 instances, one for compiled and one for uncompiled

        # Generate solvers and optimal control problem
        self.discretization_solver = get_discretization_solver(
            self.state_dot, self.A, self.B, self.params
        )
        self.propagation_solver = get_propagation_solver(self.state_dot, self.params)
        self.optimal_control_problem = OptimalControlProblem(self.params)

        # Initialize the PTR loop
        self.cpg_solve = PTR_init(
            self.optimal_control_problem,
            self.discretization_solver,
            self.params,
        )

        # Compile the solvers ahead of time for fixed argument shapes.
        if not self.params.dev.debug:
            self.discretization_solver = (
                jax.jit(self.discretization_solver)
                .lower(
                    np.ones((self.params.scp.n, self.params.sim.n_states)),
                    np.ones((self.params.scp.n, self.params.sim.n_controls)),
                )
                .compile()
            )

            self.propagation_solver = (
                jax.jit(self.propagation_solver)
                .lower(
                    np.ones((self.params.sim.n_states)),
                    (0.0, 0.0),
                    np.ones((1, self.params.sim.n_controls)),
                    np.ones((1, self.params.sim.n_controls)),
                    np.ones((1, 1)),
                    np.ones((1, 1)).astype("int"),
                    0,
                )
                .compile()
            )

        t_f_while = time.perf_counter()
        self.timing_init = t_f_while - t_0_while
        print("Total Initialization Time: ", self.timing_init)

        self._stop_profiler(profiler, "profiling_initialize.prof")

    def solve(self):
        """Run the SCP loop and return the ``PTR_main`` result dict.

        Raises:
            ValueError: If ``initialize()`` has not been called.
        """
        # Ensure parameter sizes and normalization are correct
        self.params.scp.__post_init__()
        self.params.sim.__post_init__()

        if self.optimal_control_problem is None or self.discretization_solver is None:
            raise ValueError(
                "Problem has not been initialized. Call initialize() before solve()"
            )

        profiler = self._start_profiler()

        t_0_while = time.perf_counter()
        # Print top header for solver results
        io.header()

        result = PTR_main(
            self.params,
            self.optimal_control_problem,
            self.discretization_solver,
            self.cpg_solve,
            self.emitter_function,
        )

        t_f_while = time.perf_counter()
        self.timing_solve = t_f_while - t_0_while

        # Give the printer thread time to drain queued rows before the footer.
        # There is no queue when printing is disabled (previously an AttributeError).
        print_queue = getattr(self, "print_queue", None)
        if print_queue is not None:
            while print_queue.qsize() > 0:
                time.sleep(0.1)

        # Print bottom footer for solver results as well as total computation time
        io.footer(self.timing_solve)

        self._stop_profiler(profiler, "profiling_solve.prof")

        return result

    def post_process(self, result):
        """Propagate the optimized trajectory and return the augmented result.

        Raises:
            ValueError: If ``initialize()`` has not been called.
        """
        if getattr(self, "propagation_solver", None) is None:
            raise ValueError(
                "Problem has not been initialized. Call initialize() before post_process()"
            )

        profiler = self._start_profiler()

        t_0_post = time.perf_counter()
        result = propagate_trajectory_results(self.params, result, self.propagation_solver)
        t_f_post = time.perf_counter()

        self.timing_post = t_f_post - t_0_post
        print("Total Post Processing Time: ", self.timing_post)

        self._stop_profiler(profiler, "profiling_postprocess.prof")
        return result
openscvx/utils.py ADDED
@@ -0,0 +1,80 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+
5
+
6
def qdcm(q: jnp.ndarray) -> jnp.ndarray:
    """Return the 3x3 direction cosine matrix of quaternion ``q``.

    ``q`` is unpacked as ``(w, x, y, z)`` (scalar first) and normalized first,
    so non-unit quaternions are accepted.
    """
    magnitude = (q[0] ** 2 + q[1] ** 2 + q[2] ** 2 + q[3] ** 2) ** 0.5
    w, x, y, z = q / magnitude

    # Products shared between the matrix entries.
    xx, yy, zz = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    first_row = [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)]
    second_row = [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)]
    third_row = [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)]
    return jnp.array([first_row, second_row, third_row])
17
+
18
+
19
def SSMP(w: jnp.ndarray):
    """Return the 4x4 skew-symmetric matrix built from angular rate ``w = (x, y, z)``.

    Layout::

        [[ 0, -x, -y, -z],
         [ x,  0,  z, -y],
         [ y, -z,  0,  x],
         [ z,  y, -x,  0]]
    """
    wx, wy, wz = w
    rows = [
        [0, -wx, -wy, -wz],
        [wx, 0, wz, -wy],
        [wy, -wz, 0, wx],
        [wz, wy, -wx, 0],
    ]
    return jnp.array(rows)
23
+
24
+
25
def SSM(w: jnp.ndarray):
    """Return the 3x3 skew-symmetric (cross-product) matrix of ``w``.

    ``SSM(w) @ v`` equals the cross product ``w x v``.
    """
    wx, wy, wz = w
    rows = [
        [0, -wz, wy],
        [wz, 0, -wx],
        [-wy, wx, 0],
    ]
    return jnp.array(rows)
29
+
30
+
31
def generate_orthogonal_unit_vectors(vectors=None, seed: int = 0):
    """
    Generates 3 orthogonal unit vectors to model the axis of the ellipsoid via QR decomposition

    Parameters:
        vectors (array-like): Optional, axes of the ellipsoid to be orthonormalized.
                              If None, a 3x3 matrix is drawn uniformly from [0, 1)
                              with ``jax.random.PRNGKey(seed)``.
        seed (int): Seed used only when ``vectors`` is None. The output is
                    deterministic for a given seed; the default of 0 reproduces
                    the previous (hard-coded) behaviour.

    Returns:
        jax array: The Q factor of ``jnp.linalg.qr(vectors)``. For a 3x3 input
        this is a 3x3 matrix whose columns are orthonormal.
    """
    if vectors is None:
        key = jax.random.PRNGKey(seed)
        vectors = jax.random.uniform(key, (3, 3))
    Q, _ = jnp.linalg.qr(vectors)
    return Q
50
+
51
+
52
# Fixed quarter-turn about z applied to gate vertices: it maps +x to -y and
# +y to +x, leaving z unchanged.
_ROT_ANGLE = np.pi / 2
rot = np.array(
    [
        [np.cos(_ROT_ANGLE), np.sin(_ROT_ANGLE), 0],
        [-np.sin(_ROT_ANGLE), np.cos(_ROT_ANGLE), 0],
        [0, 0, 1],
    ]
)


def gen_vertices(center, radii):
    """Return the four corner vertices of a gate.

    In the gate's local frame the corners are (+r0, 0, +r2), (-r0, 0, +r2),
    (-r0, 0, -r2) and (+r0, 0, -r2), with ``r0 = radii[0]`` and
    ``r2 = radii[2]``; ``radii[1]`` is not used. Each corner is rotated by
    ``rot`` and offset by ``center``.
    """
    rx, rz = radii[0], radii[2]
    local_corners = (
        [rx, 0, rz],
        [-rx, 0, rz],
        [-rx, 0, -rz],
        [rx, 0, -rz],
    )
    return [center + rot @ corner for corner in local_corners]
69
+
70
+
71
def get_kp_pose(t, init_pose, loop_time: float = 40.0, loop_radius: float = 20.0):
    """Return points on a periodic closed reference path, offset by ``init_pose``.

    With ``a = 2*pi*t / loop_time``:

        x = loop_radius * sin(a)
        y = x * cos(a)
        z = 0.5 * x * sin(a)

    Args:
        t: Time (scalar or 1-D array).
        init_pose: Offset added to every point (length 3).
        loop_time: Period of the path. Defaults to the previously hard-coded 40.0.
        loop_radius: Amplitude of ``x``. Defaults to the previously hard-coded 20.0.

    Returns:
        Array of shape ``(3,)`` for scalar ``t``, or ``(len(t), 3)`` for 1-D ``t``.
    """
    t_angle = t / loop_time * (2 * jnp.pi)
    x = loop_radius * jnp.sin(t_angle)
    y = x * jnp.cos(t_angle)
    z = 0.5 * x * jnp.sin(t_angle)
    return jnp.array([x, y, z]).T + init_pose
+ return jnp.array([x, y, z]).T + init_pose