openscvx 0.3.2.dev170__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openscvx might be problematic. Click here for more details.
- openscvx/__init__.py +123 -0
- openscvx/_version.py +34 -0
- openscvx/algorithms/__init__.py +92 -0
- openscvx/algorithms/autotuning.py +24 -0
- openscvx/algorithms/base.py +351 -0
- openscvx/algorithms/optimization_results.py +215 -0
- openscvx/algorithms/penalized_trust_region.py +384 -0
- openscvx/config.py +437 -0
- openscvx/discretization/__init__.py +47 -0
- openscvx/discretization/discretization.py +236 -0
- openscvx/expert/__init__.py +23 -0
- openscvx/expert/byof.py +326 -0
- openscvx/expert/lowering.py +419 -0
- openscvx/expert/validation.py +357 -0
- openscvx/integrators/__init__.py +48 -0
- openscvx/integrators/runge_kutta.py +281 -0
- openscvx/lowered/__init__.py +30 -0
- openscvx/lowered/cvxpy_constraints.py +23 -0
- openscvx/lowered/cvxpy_variables.py +124 -0
- openscvx/lowered/dynamics.py +34 -0
- openscvx/lowered/jax_constraints.py +133 -0
- openscvx/lowered/parameters.py +54 -0
- openscvx/lowered/problem.py +70 -0
- openscvx/lowered/unified.py +718 -0
- openscvx/plotting/__init__.py +63 -0
- openscvx/plotting/plotting.py +756 -0
- openscvx/plotting/scp_iteration.py +299 -0
- openscvx/plotting/viser/__init__.py +126 -0
- openscvx/plotting/viser/animated.py +605 -0
- openscvx/plotting/viser/plotly_integration.py +333 -0
- openscvx/plotting/viser/primitives.py +355 -0
- openscvx/plotting/viser/scp.py +459 -0
- openscvx/plotting/viser/server.py +112 -0
- openscvx/problem.py +734 -0
- openscvx/propagation/__init__.py +60 -0
- openscvx/propagation/post_processing.py +104 -0
- openscvx/propagation/propagation.py +248 -0
- openscvx/solvers/__init__.py +51 -0
- openscvx/solvers/cvxpy.py +226 -0
- openscvx/symbolic/__init__.py +9 -0
- openscvx/symbolic/augmentation.py +630 -0
- openscvx/symbolic/builder.py +492 -0
- openscvx/symbolic/constraint_set.py +92 -0
- openscvx/symbolic/expr/__init__.py +222 -0
- openscvx/symbolic/expr/arithmetic.py +517 -0
- openscvx/symbolic/expr/array.py +632 -0
- openscvx/symbolic/expr/constraint.py +796 -0
- openscvx/symbolic/expr/control.py +135 -0
- openscvx/symbolic/expr/expr.py +720 -0
- openscvx/symbolic/expr/lie/__init__.py +87 -0
- openscvx/symbolic/expr/lie/adjoint.py +357 -0
- openscvx/symbolic/expr/lie/se3.py +172 -0
- openscvx/symbolic/expr/lie/so3.py +138 -0
- openscvx/symbolic/expr/linalg.py +279 -0
- openscvx/symbolic/expr/math.py +699 -0
- openscvx/symbolic/expr/spatial.py +209 -0
- openscvx/symbolic/expr/state.py +607 -0
- openscvx/symbolic/expr/stl.py +136 -0
- openscvx/symbolic/expr/variable.py +321 -0
- openscvx/symbolic/hashing.py +112 -0
- openscvx/symbolic/lower.py +760 -0
- openscvx/symbolic/lowerers/__init__.py +106 -0
- openscvx/symbolic/lowerers/cvxpy.py +1302 -0
- openscvx/symbolic/lowerers/jax.py +1382 -0
- openscvx/symbolic/preprocessing.py +757 -0
- openscvx/symbolic/problem.py +110 -0
- openscvx/symbolic/time.py +116 -0
- openscvx/symbolic/unified.py +420 -0
- openscvx/utils/__init__.py +20 -0
- openscvx/utils/cache.py +131 -0
- openscvx/utils/caching.py +210 -0
- openscvx/utils/printing.py +301 -0
- openscvx/utils/profiling.py +37 -0
- openscvx/utils/utils.py +100 -0
- openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
- openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
- openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
- openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
- openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
openscvx/utils/cache.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Cache management for compiled solvers.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for managing the cache directory where compiled
|
|
4
|
+
JAX solvers are stored. The cache location follows platform conventions:
|
|
5
|
+
|
|
6
|
+
- **Linux**: ``~/.cache/openscvx/``
|
|
7
|
+
- **macOS**: ``~/Library/Caches/openscvx/``
|
|
8
|
+
- **Windows**: ``%LOCALAPPDATA%/openscvx/Cache/``
|
|
9
|
+
|
|
10
|
+
The cache location can be overridden by setting the ``OPENSCVX_CACHE_DIR``
|
|
11
|
+
environment variable.
|
|
12
|
+
|
|
13
|
+
Example:
|
|
14
|
+
Get the cache directory::
|
|
15
|
+
|
|
16
|
+
import openscvx as ox
|
|
17
|
+
print(ox.get_cache_dir()) # /home/user/.cache/openscvx
|
|
18
|
+
|
|
19
|
+
Clear all cached solvers::
|
|
20
|
+
|
|
21
|
+
import openscvx as ox
|
|
22
|
+
ox.clear_cache()
|
|
23
|
+
|
|
24
|
+
Check cache size::
|
|
25
|
+
|
|
26
|
+
import openscvx as ox
|
|
27
|
+
size_mb = ox.get_cache_size() / (1024 * 1024)
|
|
28
|
+
print(f"Cache size: {size_mb:.1f} MB")
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
import os
|
|
32
|
+
import shutil
|
|
33
|
+
import sys
|
|
34
|
+
from pathlib import Path
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_cache_dir() -> Path:
    """Get the cache directory for compiled solvers.

    The cache location is determined in the following order:

    1. ``OPENSCVX_CACHE_DIR`` environment variable (if set and non-empty).
       A leading ``~`` is expanded to the user's home directory, because the
       variable is often set in places where no shell expands it.
    2. Platform-specific default:

       - Linux/others: ``$XDG_CACHE_HOME/openscvx/`` when ``XDG_CACHE_HOME`` is
         an absolute path, otherwise ``~/.cache/openscvx/``
       - macOS: ``~/Library/Caches/openscvx/``
       - Windows: ``%LOCALAPPDATA%/openscvx/Cache/`` (or
         ``~/AppData/Local/openscvx/Cache/`` when ``LOCALAPPDATA`` is unset)

    Returns:
        Path to the cache directory (may not exist yet; this function never
        creates it).
    """
    # Explicit override always wins.
    env_dir = os.environ.get("OPENSCVX_CACHE_DIR")
    if env_dir:
        return Path(env_dir).expanduser()

    if sys.platform == "darwin":
        return Path.home() / "Library" / "Caches" / "openscvx"

    if sys.platform == "win32":
        local_app_data = os.environ.get("LOCALAPPDATA")
        if local_app_data:
            return Path(local_app_data) / "openscvx" / "Cache"
        # Fallback if LOCALAPPDATA is not set
        return Path.home() / "AppData" / "Local" / "openscvx" / "Cache"

    # Linux and others: XDG Base Directory Specification. The spec requires
    # relative paths in XDG_* variables to be treated as invalid and ignored.
    xdg_cache = os.environ.get("XDG_CACHE_HOME")
    if xdg_cache and Path(xdg_cache).is_absolute():
        return Path(xdg_cache) / "openscvx"
    return Path.home() / ".cache" / "openscvx"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def clear_cache(cache_dir: "Path | None" = None) -> int:
    """Clear all cached compiled solvers.

    Removes every entry in the cache directory. Regular files and symlinks
    (including dangling ones) are unlinked, and subdirectories are removed
    recursively. The directory itself is preserved but emptied.

    Args:
        cache_dir: Directory to clear. If None (the default), the directory
            returned by :func:`get_cache_dir` is used.

    Returns:
        Number of top-level entries removed. A subdirectory counts as one
        entry regardless of its contents. Returns 0 if the directory does
        not exist.

    Example:
        Clear the cache::

            import openscvx as ox
            deleted = ox.clear_cache()
            print(f"Deleted {deleted} cached entries")
    """
    target = get_cache_dir() if cache_dir is None else Path(cache_dir)
    if not target.exists():
        return 0

    count = 0
    for item in target.iterdir():
        # Test for symlinks first: is_file()/is_dir() follow links. Otherwise a
        # link to a directory would reach shutil.rmtree (which refuses
        # symlinks), and a dangling link would never be removed.
        if item.is_symlink() or item.is_file():
            item.unlink()
        elif item.is_dir():
            shutil.rmtree(item)
        else:
            continue
        count += 1

    return count
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_cache_size(cache_dir: "Path | None" = None) -> int:
    """Get the total size of the cache in bytes.

    Walks the cache directory recursively and sums ``st_size`` of every entry
    for which ``Path.is_file()`` is true. Entries that disappear while the
    directory is being scanned (e.g. during a concurrent clear or cache
    write) are skipped instead of raising.

    Args:
        cache_dir: Directory to measure. If None (the default), the directory
            returned by :func:`get_cache_dir` is used.

    Returns:
        Total size of all files in the cache directory in bytes.
        Returns 0 if the cache directory doesn't exist.

    Example:
        Check cache size in megabytes::

            import openscvx as ox
            size_mb = ox.get_cache_size() / (1024 * 1024)
            print(f"Cache size: {size_mb:.1f} MB")
    """
    target = get_cache_dir() if cache_dir is None else Path(cache_dir)
    if not target.exists():
        return 0

    total = 0
    for item in target.rglob("*"):
        try:
            if item.is_file():
                total += item.stat().st_size
        except FileNotFoundError:
            # Removed between listing and stat(); it no longer counts.
            continue

    return total
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import numpy as np
|
|
7
|
+
from jax import export
|
|
8
|
+
|
|
9
|
+
from openscvx.utils.cache import get_cache_dir
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from openscvx.symbolic.problem import SymbolicProblem
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_solver_cache_paths(
    symbolic_problem: "SymbolicProblem",
    dt: float,
    total_time: float,
    cache_dir: Optional[Path] = None,
) -> Tuple[Path, Path]:
    """Build the on-disk cache paths for a problem's compiled solvers.

    The cache key is the first 32 hex characters of a SHA-256 digest over
    three parts: the structural hash returned by
    ``openscvx.symbolic.hashing.hash_symbolic_problem``, the text
    ``"dt:<dt>"``, and the text ``"total_time:<total_time>"``. Hashing the
    symbolic structure rather than lowered JAX code is intended to make the
    key stable; how name-insensitive it is depends on that hashing function.

    Side effect: the target directory is created (including parents) if it
    does not already exist.

    Args:
        symbolic_problem: The preprocessed SymbolicProblem
        dt: Time step for propagation
        total_time: Total simulation time
        cache_dir: Directory to store cached solvers. If None, uses the default
            cache directory (see :func:`openscvx.get_cache_dir`).

    Returns:
        Tuple of (discretization_solver_path, propagation_solver_path)
    """
    from openscvx.symbolic.hashing import hash_symbolic_problem

    digest = hashlib.sha256()
    key_parts = (
        hash_symbolic_problem(symbolic_problem),
        f"dt:{dt}",
        f"total_time:{total_time}",
    )
    for part in key_parts:
        digest.update(part.encode())
    key = digest.hexdigest()[:32]

    target_dir = get_cache_dir() if cache_dir is None else cache_dir
    target_dir.mkdir(parents=True, exist_ok=True)

    return (
        target_dir / f"compiled_discretization_solver_{key}.jax",
        target_dir / f"compiled_propagation_solver_{key}.jax",
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def load_or_compile_discretization_solver(
    discretization_solver: callable,
    cache_file: Path,
    params: Dict[str, Any],
    n_discretization_nodes: int,
    n_states: int,
    n_controls: int,
    save_compiled: bool = False,
    debug: bool = False,
) -> callable:
    """Load discretization solver from cache or compile and cache it.

    In debug mode the raw solver is returned unchanged; nothing is compiled.

    Otherwise the solver is wrapped in ``jax.jit`` and exported with
    ``jax.export.export``. Export traces it with placeholder arrays of shape
    ``(n_discretization_nodes, n_states)`` and
    ``(n_discretization_nodes, n_controls)``, plus ``params``.

    When ``save_compiled`` is true:

    - An existing ``cache_file`` is deserialized and returned.
    - A file that cannot be deserialized is reported and recompiled rather
      than aborting. This covers truncated writes, corruption, and cache files
      from an incompatible JAX serialization format.
    - Freshly compiled solvers are written to a temporary file and renamed
      into place. An interrupted write therefore never leaves a partial cache
      file.

    Args:
        discretization_solver: The solver function to compile
        cache_file: Path to cache file (its parent directory is created if
            needed)
        params: Parameters dictionary
        n_discretization_nodes: Number of discretization nodes
        n_states: Number of state variables
        n_controls: Number of control variables
        save_compiled: Whether to save/load compiled solvers
        debug: Whether in debug mode (skip compilation)

    Returns:
        The exported (compiled) discretization solver, or
        ``discretization_solver`` itself when ``debug`` is true.
    """
    if debug:
        return discretization_solver

    cache_file = Path(cache_file)

    if save_compiled:
        try:
            serial_dis = cache_file.read_bytes()
        except FileNotFoundError:
            print("Compiling discretization solver...")
        else:
            try:
                compiled_solver = export.deserialize(serial_dis)
            except Exception as e:
                # Unreadable cache entry: fall back to a fresh compile, which
                # also overwrites the bad file below.
                print(f"Cached discretization solver could not be loaded ({e}); recompiling...")
            else:
                print("✓ Loaded existing discretization solver")
                return compiled_solver
    else:
        print("Compiling discretization solver (not saving/loading from disk)...")

    # Pass parameters as a single dictionary
    compiled_solver = export.export(jax.jit(discretization_solver))(
        np.ones((n_discretization_nodes, n_states)),
        np.ones((n_discretization_nodes, n_controls)),
        params,
    )

    if save_compiled:
        import os

        cache_file.parent.mkdir(parents=True, exist_ok=True)
        # Per-process temp name + atomic rename: concurrent writers or an
        # interrupted write cannot leave a truncated cache file behind.
        tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
        try:
            tmp_file.write_bytes(compiled_solver.serialize())
            tmp_file.replace(cache_file)
        finally:
            tmp_file.unlink(missing_ok=True)
        print("✓ Discretization solver compiled and saved")

    return compiled_solver
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_or_compile_propagation_solver(
    propagation_solver: callable,
    cache_file: Path,
    params: Dict[str, Any],
    n_states_prop: int,
    n_controls: int,
    max_tau_len: int,
    save_compiled: bool = False,
) -> callable:
    """Load propagation solver from cache or compile and cache it.

    The solver is wrapped in ``jax.jit`` and exported with
    ``jax.export.export``, traced with placeholder inputs. These comprise a
    state vector of length ``n_states_prop``, ``(1, n_controls)`` control
    rows, and ``(max_tau_len,)`` save-time/mask vectors, plus ``params``.

    When ``save_compiled`` is true:

    - An existing ``cache_file`` is deserialized and returned.
    - A file that cannot be deserialized (truncated, corrupt, or from an
      incompatible JAX serialization format) is reported and recompiled
      rather than aborting.
    - Freshly compiled solvers are written to a temporary file and renamed
      into place, so an interrupted write never leaves a partial cache file.

    Args:
        propagation_solver: The solver function to compile
        cache_file: Path to cache file (its parent directory is created if
            needed)
        params: Parameters dictionary
        n_states_prop: Number of propagation state variables
        n_controls: Number of control variables
        max_tau_len: Maximum tau length for propagation
        save_compiled: Whether to save/load compiled solvers

    Returns:
        Compiled (exported) propagation solver
    """
    cache_file = Path(cache_file)

    if save_compiled:
        try:
            serial_prop = cache_file.read_bytes()
        except FileNotFoundError:
            print("Compiling propagation solver...")
        else:
            try:
                compiled_solver = export.deserialize(serial_prop)
            except Exception as e:
                # Unreadable cache entry: fall back to a fresh compile, which
                # also overwrites the bad file below.
                print(f"Cached propagation solver could not be loaded ({e}); recompiling...")
            else:
                print("✓ Loaded existing propagation solver")
                return compiled_solver
    else:
        print("Compiling propagation solver (not saving/loading from disk)...")

    # Pass parameters as a single dictionary
    compiled_solver = export.export(jax.jit(propagation_solver))(
        np.ones(n_states_prop),  # x_0
        (0.0, 0.0),  # time span
        np.ones((1, n_controls)),  # controls_current
        np.ones((1, n_controls)),  # controls_next
        np.ones((1, 1)),  # tau_0
        np.ones((1, 1)).astype("int"),  # segment index
        0,  # idx_s_stop
        np.ones((max_tau_len,)),  # save_time (tau_cur_padded)
        np.ones((max_tau_len,), dtype=bool),  # mask_padded (boolean mask)
        params,  # additional parameters as dict
    )

    if save_compiled:
        import os

        cache_file.parent.mkdir(parents=True, exist_ok=True)
        # Per-process temp name + atomic rename: concurrent writers or an
        # interrupted write cannot leave a truncated cache file behind.
        tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
        try:
            tmp_file.write_bytes(compiled_solver.serialize())
            tmp_file.replace(cache_file)
        finally:
            tmp_file.unlink(missing_ok=True)
        print("✓ Propagation solver compiled and saved")

    return compiled_solver
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def prime_propagation_solver(
    propagation_solver: callable, params: Dict[str, Any], settings
) -> None:
    """Prime the propagation solver with a test call to ensure it works.

    Calls ``propagation_solver.call(...)`` once with placeholder inputs whose
    shapes and dtypes come from ``settings`` and whose parameter values are
    all ones. Priming is best-effort: any exception is reported with a printed
    message rather than raised.

    Placeholder parameters keep the type of the real ones:

    - Array-like values (anything with a ``shape``) become
      ``np.ones_like(value)``.
    - Plain Python ``bool``/``int``/``float``/``complex`` scalars become ``1``
      of the same type, so they match the dtype the solver was traced with.
    - Any other value is converted with ``float``.

    Args:
        propagation_solver: Compiled propagation solver (must expose ``.call``)
        params: Parameters dictionary
        settings: Settings configuration object; reads ``sim.x_prop.initial``,
            ``sim.u.shape``, ``sim.u.guess.dtype``,
            ``sim.time_dilation_slice.stop`` and ``prp.max_tau_len``.
    """

    def _placeholder(value):
        # Same-shape/same-dtype ones for arrays and numpy scalars.
        if hasattr(value, "shape"):
            return np.ones_like(value)
        # Preserve exact Python scalar types: coercing an int/bool parameter
        # to float would not match the exported signature.
        if type(value) in (bool, int, float, complex):
            return type(value)(1)
        return float(value)

    try:
        x_0 = np.ones(settings.sim.x_prop.initial.shape, dtype=settings.sim.x_prop.initial.dtype)
        tau_grid = (0.0, 1.0)
        controls_current = np.ones((1, settings.sim.u.shape[0]), dtype=settings.sim.u.guess.dtype)
        controls_next = np.ones((1, settings.sim.u.shape[0]), dtype=settings.sim.u.guess.dtype)
        tau_init = np.array([[0.0]], dtype=np.float64)
        node = np.array([[0]], dtype=np.int64)
        idx_s_stop = settings.sim.time_dilation_slice.stop
        save_time = np.ones((settings.prp.max_tau_len,), dtype=np.float64)
        mask_padded = np.ones((settings.prp.max_tau_len,), dtype=bool)
        # Dummy params dict with the same keys and value types
        dummy_params = {name: _placeholder(value) for name, value in params.items()}
        propagation_solver.call(
            x_0,
            tau_grid,
            controls_current,
            controls_next,
            tau_init,
            node,
            idx_s_stop,
            save_time,
            mask_padded,
            dummy_params,
        )
    except Exception as e:
        print(f"[Initialization] Priming propagation_solver.call failed: {e}")
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import queue
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
import warnings
|
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import numpy as np
|
|
9
|
+
from termcolor import colored
|
|
10
|
+
|
|
11
|
+
from openscvx.algorithms import OptimizationResults
|
|
12
|
+
|
|
13
|
+
# NOTE(review): this executes at import time and installs a process-wide
# "ignore" filter, silencing *all* warnings (from any module, any category)
# for every program that imports this module. Consider scoping it
# (e.g. warnings.catch_warnings) if that is not intended.
warnings.filterwarnings("ignore")


# Define colors for printing
# termcolor color names. In `intermediate`, col_pos marks metrics within
# tolerance and an "optimal" status; col_neg marks everything else.
# col_main is not referenced by the functions visible in this module
# (presumably used elsewhere — confirm before removing).
col_main = "blue"
col_pos = "green"
col_neg = "red"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_version() -> str:
    """Return the installed ``openscvx`` distribution version string.

    Falls back to ``"0.0.0"`` when no installed distribution metadata is
    found (``importlib.metadata.PackageNotFoundError``).
    """
    try:
        installed = version("openscvx")
    except PackageNotFoundError:
        installed = "0.0.0"
    return installed
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def print_summary_box(lines, title="Summary"):
    """
    Print a rounded box with ``title`` on top and ``lines[1:]`` as its rows.

    ``lines[0]`` is not printed: callers pass the title as the first element
    and it is rendered from ``title`` instead. The box is as wide as the
    longer of the title and the longest row, plus 4 columns of padding. It is
    horizontally centered within 89 columns, or indented by a fixed 2 spaces
    when it is wider than that.

    Args:
        lines (list): Strings to display; the first element is skipped.
        title (str): Title for the box (default: "Summary")
    """
    body = lines[1:] if len(lines) > 1 else []
    widest = max([len(title), *(len(row) for row in body)])
    inner = widest + 4

    page_width = 89  # width the box is centered within
    if inner > page_width:
        pad = " " * 2
    else:
        pad = " " * ((page_width - inner) // 2)

    bar = "─" * inner
    print(f"\n{pad}╭{bar}╮")
    print(f"{pad}│ {title:^{inner - 2}} │")
    print(f"{pad}├{bar}┤")
    for row in body:
        print(f"{pad}│ {row:<{inner - 2}} │")
    print(f"{pad}╰{bar}╯\n")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def print_problem_summary(settings, lowered):
    """
    Print the "Problem Summary" box for a lowered problem.

    The box reports:

    - state/control/node dimensions;
    - constraint counts (convex nodal, nonconvex nodal, CTCS);
    - the size of the convex subproblem;
    - the SCP weights;
    - the configured solvers and the active JAX backend.

    The subproblem size is measured by building the OCP with
    ``openscvx.solvers.optimal_control_problem``. If that raises, the
    variable/parameter/constraint counts are reported as 0 instead.

    Args:
        settings: Configuration settings containing problem information
        lowered: LoweredProblem from lower_symbolic_problem()
    """
    n_nodal_convex = len(lowered.cvxpy_constraints.constraints)
    n_nodal_nonconvex = len(lowered.jax_constraints.nodal)
    n_ctcs = len(lowered.jax_constraints.ctcs)
    sim = settings.sim
    scp = settings.scp
    # Number of states beyond the true-state slice (shown as "aug").
    n_augmented = sim.n_states - sim.true_state_slice.stop

    from openscvx.solvers import optimal_control_problem

    try:
        ocp = optimal_control_problem(settings, lowered)
        # Scalar element counts, as in CVXPy's verbose size report
        n_vars = sum(v.size for v in ocp.variables())
        n_params = sum(p.size for p in ocp.parameters())
        n_cons = sum(c.size for c in ocp.constraints)
    except Exception:
        # Best-effort: still print the summary if the OCP cannot be built.
        n_vars = n_params = n_cons = 0

    backend = jax.devices()[0].platform.upper()

    lam_vc = scp.lam_vc
    if isinstance(lam_vc, np.ndarray):
        lam_vc_str = f"λ_vc=matrix({lam_vc.shape})"
    else:
        lam_vc_str = f"λ_vc={lam_vc:4.1f}"

    weights = [
        f"λ_cost={scp.lam_cost:4.1f}",
        f"λ_tr={scp.w_tr:4.1f}",
        lam_vc_str,
    ]
    # λ_vb only matters when nodal nonconvex constraints exist
    if n_nodal_nonconvex > 0:
        weights.append(f"λ_vb={scp.lam_vb:4.1f}")

    summary = [
        "Problem Summary",
        f"Dimensions: {sim.n_states} states ({n_augmented} aug),"
        f" {sim.n_controls} controls, {scp.n} nodes",
        f"Constraints: {n_nodal_convex} conv, {n_nodal_nonconvex} nonconv, {n_ctcs} ctcs",
        f"Subproblem: {n_vars} vars, {n_params} params, {n_cons} constraints",
        f"Weights: {', '.join(weights)}",
        f"CVX Solver: {settings.cvx.solver}, Discretization Solver: {settings.dis.solver}",
        f"JAX Backend: {backend} (v{jax.__version__})",
    ]

    print_summary_box(summary, "Problem Summary")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def print_results_summary(result: OptimizationResults, timing_post, timing_init, timing_solve):
    """
    Print the results summary box.

    Any timing that is ``None`` is displayed and summed as 0.0. The CTCS
    violation may be a scalar or an array of any shape. Arrays are flattened
    and displayed as a 1-D list in ``.2e`` format.

    Args:
        result (OptimizationResults): Optimization results object; the
            ``"cost"`` and ``"ctcs_violation"`` entries are read via ``get``
            (defaulting to 0.0).
        timing_post (float | None): Post-processing time in seconds
        timing_init (float | None): Initialization time in seconds
        timing_solve (float | None): Solve time in seconds
    """
    cost = result.get("cost", 0.0)
    ctcs_violation = result.get("ctcs_violation", 0.0)

    # Convert numpy/jax scalars to Python scalars for formatting
    if hasattr(cost, "item"):
        cost = cost.item()

    # np.ravel handles Python scalars and 0-d/1-d/N-d arrays uniformly.
    # Iterating an N-d array directly would yield rows that cannot be
    # formatted with ".2e".
    violations = np.ravel(ctcs_violation)
    ctcs_violation_str = "[" + ", ".join(f"{float(v):.2e}" for v in violations) + "]"

    t_init = timing_init or 0.0
    t_solve = timing_solve or 0.0
    t_post = timing_post or 0.0
    total_time = t_init + t_solve + t_post

    lines = [
        "Results Summary",
        f"Cost: {cost:.6f}",
        f"CTCS Constraint Violation: {ctcs_violation_str}",
        f"Preprocessing Time: {t_init:.3f}s",
        f"Main Solve Time: {t_solve:.3f}s",
        f"Post-processing Time: {t_post:.3f}s",
        f"Total Computation Time: {total_time:.3f}s",
    ]

    print_summary_box(lines, "Results Summary")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def intro():
    """Print the OpenSCvx ASCII-art banner with authors, affiliation and version.

    The version line is filled in from ``get_version()``.
    """
    # Silence warnings. NOTE(review): filterwarnings("ignore") installs a
    # process-wide filter that suppresses *all* warning categories, not only
    # syntax warnings.
    warnings.filterwarnings("ignore")
    # fmt: off
    # Raw f-string: backslashes in the art are kept literally while
    # {get_version()} is still interpolated.
    ascii_art = rf"""

____ _____ _____
/ __ \ / ____|/ ____|
| | | |_ __ ___ _ __ | (___ | | __ ____ __
| | | | '_ \ / _ \ '_ \ \___ \| | \ \ / /\ \/ /
| |__| | |_) | __/ | | |____) | |___\ V / > <
\____/| .__/ \___|_| |_|_____/ \_____\_/ /_/\_\
| |
|_|
─────────────────────────────────────────────────────────────────────────────────────────────────────────
Author: Chris Hayner and Griffin Norris
Autonomous Controls Laboratory
University of Washington
Version: {get_version()}
─────────────────────────────────────────────────────────────────────────────────────────────────────────
"""
    # fmt: on
    print(ascii_art)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def header():
    """Print the SCP iteration-table header: a rule, the column labels, then another rule."""
    rule = colored(
        "─────────────────────────────────────────────────────────────────────────────────────────────────────────"
    )
    columns = (
        "Iter",
        "Dis Time (ms)",
        "Solve Time (ms)",
        "J_total",
        "J_tr",
        "J_vb",
        "J_vc",
        "Cost",
        "Solver Status",
    )
    layout = "{:^4} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^14}"

    print(rule)
    print(layout.format(*columns))
    print(rule)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def intermediate(print_queue, params):
    """Consume per-iteration SCP records from ``print_queue`` and print table rows.

    Each record is a dict read via these keys:

    - ``iter``, ``dis_time``, ``subprop_time``
    - ``J_total``, ``J_tr``, ``J_vb``, ``J_vc``
    - ``cost``, ``prob_stat``

    Every row is followed by a rule and the column labels. For records with
    ``iter != 1``, those two trailing lines are first erased with ANSI
    cursor-up/clear-line escapes, so the labels stay below the newest row.

    Coloring:

    - ``J_tr``, ``J_vb`` and ``J_vc`` are green when ``<=`` ``params.scp.ep_tr``,
      ``ep_vb`` and ``ep_vc`` respectively, and red otherwise.
    - The status is green only when it equals ``"optimal"``.

    The queue is polled at up to 30 Hz. The loop runs until a ``None``
    sentinel is received; without one it never returns.

    Args:
        print_queue: A ``queue.Queue``-like object supporting ``get(timeout=...)``.
        params: Settings object providing ``scp.ep_tr``, ``scp.ep_vb``, ``scp.ep_vc``.
    """
    hz = 30.0
    period = 1.0 / hz
    while True:
        t_start = time.time()
        try:
            data = print_queue.get(timeout=period)
            if data is None:
                # Sentinel: lets the producer stop the printer cleanly.
                return
            # remove bottom labels and line
            if data["iter"] != 1:
                sys.stdout.write("\x1b[1A\x1b[2K\x1b[1A\x1b[2K")
            prob_stat = data["prob_stat"]
            # Length guard: a status shorter than 4 items previously raised
            # IndexError and killed the printer loop.
            # NOTE(review): the meaning of "f" at index 3 is not evident here — confirm.
            if len(prob_stat) > 3 and prob_stat[3] == "f":
                # Only show the first element of the string
                data["prob_stat"] = prob_stat[0]

            iter_colored = colored("{:4d}".format(data["iter"]))
            J_tot_colored = colored("{:.1e}".format(data["J_total"]))
            J_tr_colored = colored(
                "{:.1e}".format(data["J_tr"]),
                col_pos if data["J_tr"] <= params.scp.ep_tr else col_neg,
            )
            J_vb_colored = colored(
                "{:.1e}".format(data["J_vb"]),
                col_pos if data["J_vb"] <= params.scp.ep_vb else col_neg,
            )
            J_vc_colored = colored(
                "{:.1e}".format(data["J_vc"]),
                col_pos if data["J_vc"] <= params.scp.ep_vc else col_neg,
            )
            cost_colored = colored("{:.1e}".format(data["cost"]))
            prob_stat_colored = colored(
                data["prob_stat"], col_pos if data["prob_stat"] == "optimal" else col_neg
            )

            print(
                "{:^4} │ {:^6.2f} │ {:^6.2F} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ "
                " {:^7} │ {:^14}".format(
                    iter_colored,
                    data["dis_time"],
                    data["subprop_time"],
                    J_tot_colored,
                    J_tr_colored,
                    J_vb_colored,
                    J_vc_colored,
                    cost_colored,
                    prob_stat_colored,
                )
            )

            print(
                colored(
                    "─────────────────────────────────────────────────────────────────────────────────────────────────────────"
                )
            )
            print(
                "{:^4} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^14}".format(
                    "Iter",
                    "Dis Time (ms)",
                    "Solve Time (ms)",
                    "J_total",
                    "J_tr",
                    "J_vb",
                    "J_vc",
                    "Cost",
                    "Solver Status",
                )
            )
        except queue.Empty:
            pass
        # Rate-limit to at most hz iterations per second.
        time.sleep(max(0.0, period - (time.time() - t_start)))
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def footer():
    """Print the closing horizontal rule of the SCP iteration table."""
    closing_rule = colored(
        "─────────────────────────────────────────────────────────────────────────────────────────────────────────"
    )
    print(closing_rule)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
import cProfile
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def profiling_start(profiling_enabled: bool) -> "Optional[cProfile.Profile]":
    """Create and enable a ``cProfile.Profile`` if profiling is requested.

    Args:
        profiling_enabled: When falsy, no profiler is created.

    Returns:
        The running profiler when enabled, otherwise ``None``.
    """
    if not profiling_enabled:
        return None

    # Imported lazily so cProfile is only loaded when actually used.
    import cProfile

    profiler = cProfile.Profile()
    profiler.enable()
    return profiler
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def profiling_end(
    pr: "Optional[cProfile.Profile]", identifier: str, output_dir: str = "profiling"
) -> Optional[str]:
    """Stop profiling and save results with timestamp.

    The stats are written to
    ``<output_dir>/<YYYYmmdd_HHMMSS>_<identifier>.prof``. That is the
    ``cProfile``/``pstats`` format, which snakeviz can visualize.
    ``output_dir`` is created if it does not exist; previously a missing
    ``profiling/`` directory made ``dump_stats`` fail.

    Args:
        pr: Profile object from profiling_start, or None (no-op).
        identifier: Identifier for the profiling session (e.g., "solve", "initialize").
        output_dir: Directory for the ``.prof`` file. Relative paths resolve
            against the current working directory. Defaults to ``"profiling"``.

    Returns:
        Path of the written stats file, or ``None`` when ``pr`` is ``None``.
    """
    if pr is None:
        return None

    import os

    pr.disable()
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"{timestamp}_{identifier}.prof")
    pr.dump_stats(out_path)
    return out_path
|