openscvx 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx might be problematic; consult the registry's advisory page for more details.

openscvx/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.1.2'
21
- __version_tuple__ = version_tuple = (0, 1, 2)
20
+ __version__ = version = '0.1.3'
21
+ __version_tuple__ = version_tuple = (0, 1, 3)
openscvx/config.py CHANGED
@@ -2,6 +2,8 @@ import numpy as np
2
2
  from dataclasses import dataclass, field
3
3
  from typing import Dict, List
4
4
 
5
+ from openscvx.constraints.boundary import BoundaryConstraint
6
+
5
7
 
6
8
  def get_affine_scaling_matrices(n, minimum, maximum):
7
9
  S = np.diag(np.maximum(np.ones(n), abs(minimum - maximum) / 2))
@@ -120,7 +122,8 @@ class PropagationConfig:
120
122
  class SimConfig:
121
123
  x_bar: np.ndarray
122
124
  u_bar: np.ndarray
123
- initial_state: np.ndarray
125
+ initial_state: BoundaryConstraint
126
+ initial_state_prop: BoundaryConstraint
124
127
  final_state: np.ndarray
125
128
  max_state: np.ndarray
126
129
  min_state: np.ndarray
@@ -128,9 +131,11 @@ class SimConfig:
128
131
  min_control: np.ndarray
129
132
  total_time: float
130
133
  idx_x_true: slice
134
+ idx_x_true_prop: slice
131
135
  idx_u_true: slice
132
136
  idx_t: slice
133
137
  idx_y: slice
138
+ idx_y_prop: slice
134
139
  idx_s: slice
135
140
  ctcs_node_intervals: list = None
136
141
  constraints_ctcs: List[callable] = field(
@@ -138,6 +143,7 @@ class SimConfig:
138
143
  ) # TODO (norrisg): clean this up, consider moving to dedicated `constraints` dataclass
139
144
  constraints_nodal: List[callable] = field(default_factory=list)
140
145
  n_states: int = None
146
+ n_states_prop: int = None
141
147
  n_controls: int = None
142
148
  S_x: np.ndarray = None
143
149
  inv_S_x: np.ndarray = None
@@ -1,4 +1,5 @@
1
1
  import numpy as np
2
+ import jax.numpy as jnp
2
3
 
3
4
  from openscvx.propagation import s_to_t, t_to_tau, simulate_nonlinear_time
4
5
  from openscvx.config import Config
@@ -14,9 +15,13 @@ def propagate_trajectory_results(params: Config, result: dict, propagation_solve
14
15
 
15
16
  tau_vals, u_full = t_to_tau(u, t_full, u, t, params)
16
17
 
17
- x_full = simulate_nonlinear_time(x[0], u, tau_vals, t, params, propagation_solver)
18
+ # Match free values from initial state to the initial value from the result
19
+ mask = jnp.array([t == "Free" for t in params.sim.initial_state_prop.types], dtype=bool)
20
+ params.sim.initial_state_prop.value = jnp.where(mask, x[0], params.sim.initial_state_prop.value)
18
21
 
19
- print("Total CTCS Constraint Violation:", x_full[-1, params.sim.idx_y])
22
+ x_full = simulate_nonlinear_time(params.sim.initial_state_prop.value, u, tau_vals, t, params, propagation_solver)
23
+
24
+ print("Total CTCS Constraint Violation:", x_full[-1, params.sim.idx_y_prop])
20
25
  i = 0
21
26
  cost = np.zeros_like(x[-1, i])
22
27
  for type in params.sim.initial_state.type:
@@ -28,6 +33,16 @@ def propagate_trajectory_results(params: Config, result: dict, propagation_solve
28
33
  if type == "Minimize":
29
34
  cost += x[-1, i]
30
35
  i += 1
36
+ i=0
37
+ for type in params.sim.initial_state.type:
38
+ if type == "Maximize":
39
+ cost -= x[0, i]
40
+ i += 1
41
+ i = 0
42
+ for type in params.sim.final_state.type:
43
+ if type == "Maximize":
44
+ cost -= x[-1, i]
45
+ i += 1
31
46
  print("Cost: ", cost)
32
47
 
33
48
  more_result = dict(t_full=t_full, x_full=x_full, u_full=u_full)
@@ -1,5 +1,5 @@
1
1
  import jax.numpy as jnp
2
- from typing import List, Union
2
+ from typing import List, Union, Optional
3
3
  import queue
4
4
  import threading
5
5
  import time
@@ -23,7 +23,7 @@ from openscvx.augmentation.ctcs import sort_ctcs_constraints
23
23
  from openscvx.constraints.violation import get_g_funcs, CTCSViolation
24
24
  from openscvx.discretization import get_discretization_solver
25
25
  from openscvx.propagation import get_propagation_solver
26
- from openscvx.constraints.boundary import BoundaryConstraint
26
+ from openscvx.constraints.boundary import BoundaryConstraint, boundary
27
27
  from openscvx.constraints.ctcs import CTCSConstraint
28
28
  from openscvx.constraints.nodal import NodalConstraint
29
29
  from openscvx.ptr import PTR_init, PTR_main
@@ -49,17 +49,24 @@ class TrajOptProblem:
49
49
  x_min: jnp.ndarray,
50
50
  u_max: jnp.ndarray,
51
51
  u_min: jnp.ndarray,
52
- scp: ScpConfig = None,
53
- dis: DiscretizationConfig = None,
54
- prp: PropagationConfig = None,
55
- sim: SimConfig = None,
56
- dev: DevConfig = None,
57
- cvx: ConvexSolverConfig = None,
52
+ dynamics_prop: callable = None,
53
+ initial_state_prop: BoundaryConstraint = None,
54
+ scp: Optional[ScpConfig] = None,
55
+ dis: Optional[DiscretizationConfig] = None,
56
+ prp: Optional[PropagationConfig] = None,
57
+ sim: Optional[SimConfig] = None,
58
+ dev: Optional[DevConfig] = None,
59
+ cvx: Optional[ConvexSolverConfig] = None,
58
60
  licq_min=0.0,
59
61
  licq_max=1e-4,
60
62
  time_dilation_factor_min=0.3,
61
63
  time_dilation_factor_max=3.0,
62
64
  ):
65
+ if dynamics_prop is None:
66
+ dynamics_prop = dynamics
67
+
68
+ if initial_state_prop is None:
69
+ initial_state_prop = initial_state
63
70
 
64
71
  # TODO (norrisg) move this into some augmentation function, if we want to make this be executed after the init (i.e. within problem.initialize) need to rethink how problem is defined
65
72
  constraints_ctcs = []
@@ -81,11 +88,15 @@ class TrajOptProblem:
81
88
  constraints_ctcs, node_intervals, num_augmented_states = sort_ctcs_constraints(constraints_ctcs, N)
82
89
 
83
90
  # Index tracking
84
- idx_x_true = slice(0, len(x_max))
91
+ idx_x_true = slice(0, len(initial_state.value))
92
+ idx_x_true_prop = slice(0, len(initial_state_prop.value))
85
93
  idx_u_true = slice(0, len(u_max))
86
94
  idx_constraint_violation = slice(
87
95
  idx_x_true.stop, idx_x_true.stop + num_augmented_states
88
96
  )
97
+ idx_constraint_violation_prop = slice(
98
+ idx_x_true_prop.stop, idx_x_true_prop.stop + num_augmented_states
99
+ )
89
100
 
90
101
  idx_time_dilation = slice(idx_u_true.stop, idx_u_true.stop + 1)
91
102
 
@@ -106,6 +117,11 @@ class TrajOptProblem:
106
117
  [u_guess, np.full((u_guess.shape[0], 1), time_init)]
107
118
  )
108
119
 
120
+ initial_state_prop_values = np.hstack([initial_state_prop.value, np.repeat(licq_min, num_augmented_states)])
121
+ initial_state_prop_types = np.hstack([initial_state_prop.type, ["Fix"] * num_augmented_states])
122
+ initial_state_prop = boundary(initial_state_prop_values)
123
+ initial_state_prop.types = initial_state_prop_types
124
+
109
125
  if dis is None:
110
126
  dis = DiscretizationConfig()
111
127
 
@@ -114,17 +130,21 @@ class TrajOptProblem:
114
130
  x_bar=x_bar_augmented,
115
131
  u_bar=u_bar_augmented,
116
132
  initial_state=initial_state,
133
+ initial_state_prop=initial_state_prop,
117
134
  final_state=final_state,
118
135
  max_state=x_max_augmented,
119
136
  min_state=x_min_augmented,
120
137
  max_control=u_max_augmented,
121
138
  min_control=u_min_augmented,
122
139
  total_time=time_init,
123
- n_states=len(x_max),
140
+ n_states=len(initial_state.value),
141
+ n_states_prop=len(initial_state_prop.value),
124
142
  idx_x_true=idx_x_true,
143
+ idx_x_true_prop=idx_x_true_prop,
125
144
  idx_u_true=idx_u_true,
126
145
  idx_t=idx_time,
127
146
  idx_y=idx_constraint_violation,
147
+ idx_y_prop=idx_constraint_violation_prop,
128
148
  idx_s=idx_time_dilation,
129
149
  ctcs_node_intervals=node_intervals,
130
150
  )
@@ -162,6 +182,7 @@ class TrajOptProblem:
162
182
 
163
183
  ctcs_violation_funcs = get_g_funcs(constraints_ctcs)
164
184
  self.dynamics_augmented = build_augmented_dynamics(dynamics, ctcs_violation_funcs, idx_x_true, idx_u_true)
185
+ self.dynamics_augmented_prop = build_augmented_dynamics(dynamics_prop, ctcs_violation_funcs, idx_x_true_prop, idx_u_true)
165
186
 
166
187
  self.params = Config(
167
188
  sim=sim,
@@ -215,6 +236,9 @@ class TrajOptProblem:
215
236
  self.dynamics_augmented.A = jax.jit(jax.vmap(self.dynamics_augmented.A, in_axes=(0, 0, 0)))
216
237
  self.dynamics_augmented.B = jax.jit(jax.vmap(self.dynamics_augmented.B, in_axes=(0, 0, 0)))
217
238
 
239
+
240
+ self.dynamics_augmented_prop.f = jax.vmap(self.dynamics_augmented_prop.f)
241
+
218
242
  for constraint in self.params.sim.constraints_nodal:
219
243
  if not constraint.convex:
220
244
  # TODO: (haynec) switch to AOT instead of JIT
@@ -224,7 +248,7 @@ class TrajOptProblem:
224
248
 
225
249
  # Generate solvers and optimal control problem
226
250
  self.discretization_solver = get_discretization_solver(self.dynamics_augmented, self.params)
227
- self.propagation_solver = get_propagation_solver(self.dynamics_augmented.f, self.params)
251
+ self.propagation_solver = get_propagation_solver(self.dynamics_augmented_prop.f, self.params)
228
252
  self.optimal_control_problem = OptimalControlProblem(self.params)
229
253
 
230
254
  # Initialize the PTR loop
@@ -248,7 +272,7 @@ class TrajOptProblem:
248
272
  self.propagation_solver = (
249
273
  jax.jit(self.propagation_solver)
250
274
  .lower(
251
- np.ones((self.params.sim.n_states)),
275
+ np.ones((self.params.sim.n_states_prop)),
252
276
  (0.0, 0.0),
253
277
  np.ones((1, self.params.sim.n_controls)),
254
278
  np.ones((1, self.params.sim.n_controls)),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openscvx
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: A general Python-based successive convexification implementation which uses a JAX backend.
5
5
  Home-page: https://haynec.github.io/openscvx/
6
6
  Author: Chris Hayner and Griffin Norris
@@ -1,16 +1,16 @@
1
1
  openscvx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- openscvx/_version.py,sha256=bSmADqydH8nBu-J4lG8UVuR7hnU_zcwhnSav2oQ0W0A,511
3
- openscvx/config.py,sha256=8Cl5O0ekf9MGNDTEeMHsp1C4XvY9NfJQkxd80lvnafM,10296
2
+ openscvx/_version.py,sha256=NIzzV8ZM0W-CSLuEs1weG4zPrn_-8yr1AwwI1iuS6yo,511
3
+ openscvx/config.py,sha256=lTdvJtU-Xiwq55zGtW3QvmldZj9Ne0xyKsSDSIKwqiI,10488
4
4
  openscvx/discretization.py,sha256=YF3mEeyYHgyTWQVNQsqpi1Mv72zDLyNfaMJSWqxj34c,4745
5
5
  openscvx/dynamics.py,sha256=X9sPpxUGGbdsnvQzgyrb_939N9ctBSsWVyI1eXtOKpc,1118
6
6
  openscvx/integrators.py,sha256=msIS-1Ehj-9TJLHfoCMs3vdyZ8NXz-TM0RII6aqRf4E,3821
7
7
  openscvx/io.py,sha256=fOvNWQWAegcN1gejeToaNbXenP5H5bAifNU8edJvdk4,4127
8
8
  openscvx/ocp.py,sha256=L_509EQiMsI6s5gBYlYyxKaHEzzRdpo-XAMjliCU3Rc,7544
9
9
  openscvx/plotting.py,sha256=fCvWJV4qWMhVyJlh18s12S_5xhj6EviF-_FuP0tWjx4,31207
10
- openscvx/post_processing.py,sha256=TP1gi4TVlDS2HHpdqaIPCqfM5o4w7a7RCMU3Pu3czHw,1024
10
+ openscvx/post_processing.py,sha256=t3fUeDfA2PZq8S-WVnxDyN0xIgFuvpDP-5wbGQuIOVY,1618
11
11
  openscvx/propagation.py,sha256=XNezQnAM-NXb9L7aHUgKQOBn0CNUPeGGDL3_BbGoODU,3758
12
12
  openscvx/ptr.py,sha256=itDTR6RQUphnU226jaeRaAKuia-6v8U3MqAdw5-BYOk,5268
13
- openscvx/trajoptproblem.py,sha256=3yufy-egU7m0NV834TH8csY1HJqM90Is7VYw0gQe3pk,11996
13
+ openscvx/trajoptproblem.py,sha256=hcju3UB3iZ5peqhxh6NvDQYnsArnabR0mGUAKZ3sIaw,13360
14
14
  openscvx/utils.py,sha256=zmkKyto8Jowe_RAdOe8K0w6gzOu4JfxmX1RUL-3OFlY,2408
15
15
  openscvx/augmentation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  openscvx/augmentation/ctcs.py,sha256=m1jdALXSqHq3WD6lCBAUI7FR0Sfs8aCYr66h0EwE4z4,1707
@@ -20,8 +20,8 @@ openscvx/constraints/boundary.py,sha256=yEhEnkKJ5f8NUeTksigEJjgBeE_YyuG_PJb_DWxg
20
20
  openscvx/constraints/ctcs.py,sha256=V763033aV82nAu7y4653KsAs11A7RpUysR_oUcnLfko,2572
21
21
  openscvx/constraints/nodal.py,sha256=YCS0cwUurA2OTQcHBb1EQqLxNt_w3MX8Nj8FH3GYClo,1726
22
22
  openscvx/constraints/violation.py,sha256=aIdDhHd-UndT0XB2QeuwLBKSNSAUWVkha_GeHOw9cQg,2362
23
- openscvx-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
24
- openscvx-0.1.2.dist-info/METADATA,sha256=MDHeKrpE_3FKRiQD5fVKVzNBWerOvcY0vfapGSRTlbk,6911
25
- openscvx-0.1.2.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
26
- openscvx-0.1.2.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
27
- openscvx-0.1.2.dist-info/RECORD,,
23
+ openscvx-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
24
+ openscvx-0.1.3.dist-info/METADATA,sha256=CVr7tDDV75qDlnuEuxf-IL0Arxkb1RXwv0Xj9oVxf_s,6911
25
+ openscvx-0.1.3.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
26
+ openscvx-0.1.3.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
27
+ openscvx-0.1.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.4.0)
2
+ Generator: setuptools (80.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5