openscvx-0.3.2.dev170-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of openscvx has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (79):
  1. openscvx/__init__.py +123 -0
  2. openscvx/_version.py +34 -0
  3. openscvx/algorithms/__init__.py +92 -0
  4. openscvx/algorithms/autotuning.py +24 -0
  5. openscvx/algorithms/base.py +351 -0
  6. openscvx/algorithms/optimization_results.py +215 -0
  7. openscvx/algorithms/penalized_trust_region.py +384 -0
  8. openscvx/config.py +437 -0
  9. openscvx/discretization/__init__.py +47 -0
  10. openscvx/discretization/discretization.py +236 -0
  11. openscvx/expert/__init__.py +23 -0
  12. openscvx/expert/byof.py +326 -0
  13. openscvx/expert/lowering.py +419 -0
  14. openscvx/expert/validation.py +357 -0
  15. openscvx/integrators/__init__.py +48 -0
  16. openscvx/integrators/runge_kutta.py +281 -0
  17. openscvx/lowered/__init__.py +30 -0
  18. openscvx/lowered/cvxpy_constraints.py +23 -0
  19. openscvx/lowered/cvxpy_variables.py +124 -0
  20. openscvx/lowered/dynamics.py +34 -0
  21. openscvx/lowered/jax_constraints.py +133 -0
  22. openscvx/lowered/parameters.py +54 -0
  23. openscvx/lowered/problem.py +70 -0
  24. openscvx/lowered/unified.py +718 -0
  25. openscvx/plotting/__init__.py +63 -0
  26. openscvx/plotting/plotting.py +756 -0
  27. openscvx/plotting/scp_iteration.py +299 -0
  28. openscvx/plotting/viser/__init__.py +126 -0
  29. openscvx/plotting/viser/animated.py +605 -0
  30. openscvx/plotting/viser/plotly_integration.py +333 -0
  31. openscvx/plotting/viser/primitives.py +355 -0
  32. openscvx/plotting/viser/scp.py +459 -0
  33. openscvx/plotting/viser/server.py +112 -0
  34. openscvx/problem.py +734 -0
  35. openscvx/propagation/__init__.py +60 -0
  36. openscvx/propagation/post_processing.py +104 -0
  37. openscvx/propagation/propagation.py +248 -0
  38. openscvx/solvers/__init__.py +51 -0
  39. openscvx/solvers/cvxpy.py +226 -0
  40. openscvx/symbolic/__init__.py +9 -0
  41. openscvx/symbolic/augmentation.py +630 -0
  42. openscvx/symbolic/builder.py +492 -0
  43. openscvx/symbolic/constraint_set.py +92 -0
  44. openscvx/symbolic/expr/__init__.py +222 -0
  45. openscvx/symbolic/expr/arithmetic.py +517 -0
  46. openscvx/symbolic/expr/array.py +632 -0
  47. openscvx/symbolic/expr/constraint.py +796 -0
  48. openscvx/symbolic/expr/control.py +135 -0
  49. openscvx/symbolic/expr/expr.py +720 -0
  50. openscvx/symbolic/expr/lie/__init__.py +87 -0
  51. openscvx/symbolic/expr/lie/adjoint.py +357 -0
  52. openscvx/symbolic/expr/lie/se3.py +172 -0
  53. openscvx/symbolic/expr/lie/so3.py +138 -0
  54. openscvx/symbolic/expr/linalg.py +279 -0
  55. openscvx/symbolic/expr/math.py +699 -0
  56. openscvx/symbolic/expr/spatial.py +209 -0
  57. openscvx/symbolic/expr/state.py +607 -0
  58. openscvx/symbolic/expr/stl.py +136 -0
  59. openscvx/symbolic/expr/variable.py +321 -0
  60. openscvx/symbolic/hashing.py +112 -0
  61. openscvx/symbolic/lower.py +760 -0
  62. openscvx/symbolic/lowerers/__init__.py +106 -0
  63. openscvx/symbolic/lowerers/cvxpy.py +1302 -0
  64. openscvx/symbolic/lowerers/jax.py +1382 -0
  65. openscvx/symbolic/preprocessing.py +757 -0
  66. openscvx/symbolic/problem.py +110 -0
  67. openscvx/symbolic/time.py +116 -0
  68. openscvx/symbolic/unified.py +420 -0
  69. openscvx/utils/__init__.py +20 -0
  70. openscvx/utils/cache.py +131 -0
  71. openscvx/utils/caching.py +210 -0
  72. openscvx/utils/printing.py +301 -0
  73. openscvx/utils/profiling.py +37 -0
  74. openscvx/utils/utils.py +100 -0
  75. openscvx-0.3.2.dev170.dist-info/METADATA +350 -0
  76. openscvx-0.3.2.dev170.dist-info/RECORD +79 -0
  77. openscvx-0.3.2.dev170.dist-info/WHEEL +5 -0
  78. openscvx-0.3.2.dev170.dist-info/licenses/LICENSE +201 -0
  79. openscvx-0.3.2.dev170.dist-info/top_level.txt +1 -0
@@ -0,0 +1,131 @@
1
+ """Cache management for compiled solvers.
2
+
3
+ This module provides utilities for managing the cache directory where compiled
4
+ JAX solvers are stored. The cache location follows platform conventions:
5
+
6
+ - **Linux**: ``~/.cache/openscvx/``
7
+ - **macOS**: ``~/Library/Caches/openscvx/``
8
+ - **Windows**: ``%LOCALAPPDATA%/openscvx/Cache/``
9
+
10
+ The cache location can be overridden by setting the ``OPENSCVX_CACHE_DIR``
11
+ environment variable.
12
+
13
+ Example:
14
+ Get the cache directory::
15
+
16
+ import openscvx as ox
17
+ print(ox.get_cache_dir()) # /home/user/.cache/openscvx
18
+
19
+ Clear all cached solvers::
20
+
21
+ import openscvx as ox
22
+ ox.clear_cache()
23
+
24
+ Check cache size::
25
+
26
+ import openscvx as ox
27
+ size_mb = ox.get_cache_size() / (1024 * 1024)
28
+ print(f"Cache size: {size_mb:.1f} MB")
29
+ """
30
+
31
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
35
+
36
+
37
def get_cache_dir() -> Path:
    """Get the cache directory for compiled solvers.

    The cache location is determined in the following order:

    1. ``OPENSCVX_CACHE_DIR`` environment variable (if set and non-empty).
       A leading ``~`` is expanded to the user's home directory.
    2. Platform-specific default:

       - Linux and other POSIX systems: ``$XDG_CACHE_HOME/openscvx/`` when
         ``XDG_CACHE_HOME`` is set to an absolute path, otherwise
         ``~/.cache/openscvx/``
       - macOS: ``~/Library/Caches/openscvx/``
       - Windows: ``%LOCALAPPDATA%/openscvx/Cache/`` (falling back to
         ``~/AppData/Local/openscvx/Cache/`` when ``LOCALAPPDATA`` is unset)

    Returns:
        Path to the cache directory. The directory is not created here and
        may not exist yet.
    """
    # An explicit override always wins. Expand "~" because a quoted value such
    # as OPENSCVX_CACHE_DIR="~/cache" is not expanded by the shell.
    env_dir = os.environ.get("OPENSCVX_CACHE_DIR")
    if env_dir:
        return Path(env_dir).expanduser()

    if sys.platform == "darwin":
        return Path.home() / "Library" / "Caches" / "openscvx"

    if sys.platform == "win32":
        local_app_data = os.environ.get("LOCALAPPDATA")
        if local_app_data:
            base = Path(local_app_data)
        else:
            # Fallback if LOCALAPPDATA is not set
            base = Path.home() / "AppData" / "Local"
        return base / "openscvx" / "Cache"

    # Linux and others: XDG Base Directory Specification. The spec requires
    # XDG_CACHE_HOME to be absolute and says relative values must be ignored.
    xdg_cache = os.environ.get("XDG_CACHE_HOME")
    if xdg_cache and Path(xdg_cache).is_absolute():
        return Path(xdg_cache) / "openscvx"
    return Path.home() / ".cache" / "openscvx"
74
+
75
+
76
def clear_cache(cache_dir: Optional[Path] = None) -> int:
    """Clear all cached compiled solvers.

    Removes every entry directly inside the cache directory:

    - regular files and symlinks are unlinked (a symlink's target is never
      touched);
    - sub-directories are removed recursively.

    The cache directory itself is preserved but emptied.

    Args:
        cache_dir: Cache directory to clear. If None, uses :func:`get_cache_dir`.
            Pass the same directory given to ``get_solver_cache_paths`` when a
            custom cache location is in use.

    Returns:
        Number of top-level entries removed. A removed sub-directory counts as
        one entry. Returns 0 if the cache directory does not exist.

    Example:
        Clear the cache::

            import openscvx as ox
            deleted = ox.clear_cache()
            print(f"Deleted {deleted} cached files")
    """
    root = Path(cache_dir) if cache_dir is not None else get_cache_dir()
    if not root.exists():
        return 0

    removed = 0
    for entry in root.iterdir():
        # Test for symlinks first. shutil.rmtree refuses to operate on a
        # symlink to a directory. A dangling symlink is neither a file nor a
        # directory, so it would otherwise be left behind.
        if entry.is_symlink() or entry.is_file():
            entry.unlink()
            removed += 1
        elif entry.is_dir():
            shutil.rmtree(entry)
            removed += 1
    return removed
106
+
107
+
108
def get_cache_size(cache_dir: Optional[Path] = None) -> int:
    """Get the total size of the cache in bytes.

    Sums the sizes of all regular files found recursively under the cache
    directory. Symlinks are skipped: their targets live outside the cache (or
    may not exist), and ``clear_cache`` would not free that space.

    Args:
        cache_dir: Cache directory to measure. If None, uses :func:`get_cache_dir`.

    Returns:
        Total size of all files in the cache directory in bytes.
        Returns 0 if the cache directory doesn't exist.

    Example:
        Check cache size in megabytes::

            import openscvx as ox
            size_mb = ox.get_cache_size() / (1024 * 1024)
            print(f"Cache size: {size_mb:.1f} MB")
    """
    root = Path(cache_dir) if cache_dir is not None else get_cache_dir()
    if not root.exists():
        return 0

    total = 0
    for entry in root.rglob("*"):
        if entry.is_symlink() or not entry.is_file():
            continue
        try:
            total += entry.stat().st_size
        except FileNotFoundError:
            # The file disappeared between listing and stat, e.g. because of a
            # concurrent clear_cache() or cache rewrite. It no longer uses space.
            continue
    return total
@@ -0,0 +1,210 @@
1
+ import hashlib
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
4
+
5
+ import jax
6
+ import numpy as np
7
+ from jax import export
8
+
9
+ from openscvx.utils.cache import get_cache_dir
10
+
11
+ if TYPE_CHECKING:
12
+ from openscvx.symbolic.problem import SymbolicProblem
13
+
14
+
15
def get_solver_cache_paths(
    symbolic_problem: "SymbolicProblem",
    dt: float,
    total_time: float,
    cache_dir: Optional[Path] = None,
) -> Tuple[Path, Path]:
    """Generate cache file paths for the compiled solvers of a problem.

    The cache key is a SHA-256 digest (truncated to 32 hex characters) of:

    - the structural hash from ``openscvx.symbolic.hashing.hash_symbolic_problem``,
      which is computed from the symbolic problem rather than from lowered JAX
      code (its exact invariances are defined by that helper);
    - the runtime propagation settings ``dt`` and ``total_time``;
    - the installed JAX version. Serialized ``jax.export`` artifacts are only
      guaranteed to load within a limited range of JAX versions, so keying on
      the version stops environments that share a cache directory from
      loading each other's incompatible files.

    Side effect: creates the cache directory (including parents) if missing.

    Args:
        symbolic_problem: The preprocessed SymbolicProblem
        dt: Time step for propagation
        total_time: Total simulation time
        cache_dir: Directory to store cached solvers. If None, uses the default
            cache directory (see :func:`openscvx.get_cache_dir`).

    Returns:
        Tuple of (discretization_solver_path, propagation_solver_path)
    """
    # Imported here rather than at module level (presumably to avoid an
    # import cycle; TODO confirm).
    from openscvx.symbolic.hashing import hash_symbolic_problem

    problem_hash = hash_symbolic_problem(symbolic_problem)

    # Fold runtime configuration and toolchain version into the key.
    final_hasher = hashlib.sha256()
    final_hasher.update(problem_hash.encode())
    final_hasher.update(f"dt:{dt}".encode())
    final_hasher.update(f"total_time:{total_time}".encode())
    final_hasher.update(f"jax:{jax.__version__}".encode())
    final_hash = final_hasher.hexdigest()[:32]

    solver_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir()
    solver_dir.mkdir(parents=True, exist_ok=True)

    dis_solver_file = solver_dir / f"compiled_discretization_solver_{final_hash}.jax"
    prop_solver_file = solver_dir / f"compiled_propagation_solver_{final_hash}.jax"

    return dis_solver_file, prop_solver_file
56
+
57
+
58
def _load_cached_exported(cache_file: Path, label: str):
    """Deserialize a cached ``jax.export.Exported`` from *cache_file*.

    Args:
        cache_file: Path of the serialized artifact.
        label: Human-readable solver name used in the diagnostic message.

    Returns:
        The deserialized ``Exported``, or None when the file does not exist or
        cannot be deserialized (e.g. truncated by an interrupted write, or
        written by an incompatible JAX version). Other I/O errors propagate.
    """
    try:
        serialized = cache_file.read_bytes()
    except FileNotFoundError:
        return None
    try:
        return export.deserialize(serialized)
    except Exception as exc:
        # A bad cache entry is treated as a cache miss. The caller recompiles
        # and overwrites it with a fresh artifact.
        print(f"Cached {label} solver at {cache_file} could not be loaded ({exc}); ignoring it")
        return None


def _save_exported(exported, cache_file: Path) -> None:
    """Atomically write ``exported.serialize()`` to *cache_file*.

    The bytes go to a temporary file in the same directory, which is then
    moved into place with ``os.replace``. A crash mid-write therefore never
    leaves a truncated cache entry for later runs to load.
    """
    import contextlib
    import os
    import tempfile

    blob = exported.serialize()
    fd, tmp_name = tempfile.mkstemp(
        dir=cache_file.parent, prefix=f".{cache_file.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(blob)
        os.replace(tmp_name, cache_file)
    except BaseException:
        # Do not leave orphaned temp files behind; re-raise the original error.
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise


def load_or_compile_discretization_solver(
    discretization_solver: callable,
    cache_file: Path,
    params: Dict[str, Any],
    n_discretization_nodes: int,
    n_states: int,
    n_controls: int,
    save_compiled: bool = False,
    debug: bool = False,
) -> callable:
    """Load discretization solver from cache or compile and cache it.

    With ``save_compiled`` the cache file is tried first. A missing or
    unreadable (corrupt/incompatible) entry triggers a fresh
    ``jax.export`` of the solver. The new artifact is then written to
    *cache_file* atomically.

    Args:
        discretization_solver: The solver function to compile
        cache_file: Path to cache file (its parent directory must exist)
        params: Parameters dictionary
        n_discretization_nodes: Number of discretization nodes
        n_states: Number of state variables
        n_controls: Number of control variables
        save_compiled: Whether to save/load compiled solvers
        debug: Whether in debug mode. If True the original, uncompiled
            function is returned unchanged.

    Returns:
        Compiled discretization solver (a ``jax.export.Exported``), or
        ``discretization_solver`` itself when ``debug`` is True.
    """
    if debug:
        return discretization_solver

    cache_file = Path(cache_file)
    if save_compiled:
        cached = _load_cached_exported(cache_file, "discretization")
        if cached is not None:
            print("✓ Loaded existing discretization solver")
            return cached
        print("Compiling discretization solver...")
    else:
        print("Compiling discretization solver (not saving/loading from disk)...")

    # Trace with dummy arrays of the right shapes; params are passed as a
    # single dictionary.
    compiled_solver = export.export(jax.jit(discretization_solver))(
        np.ones((n_discretization_nodes, n_states)),
        np.ones((n_discretization_nodes, n_controls)),
        params,
    )

    if save_compiled:
        _save_exported(compiled_solver, cache_file)
        print("✓ Discretization solver compiled and saved")

    return compiled_solver


def load_or_compile_propagation_solver(
    propagation_solver: callable,
    cache_file: Path,
    params: Dict[str, Any],
    n_states_prop: int,
    n_controls: int,
    max_tau_len: int,
    save_compiled: bool = False,
) -> callable:
    """Load propagation solver from cache or compile and cache it.

    Cache handling matches ``load_or_compile_discretization_solver``. Corrupt
    or incompatible entries are recompiled, and writes are atomic.

    Args:
        propagation_solver: The solver function to compile
        cache_file: Path to cache file (its parent directory must exist)
        params: Parameters dictionary
        n_states_prop: Number of propagation state variables
        n_controls: Number of control variables
        max_tau_len: Maximum tau length for propagation
        save_compiled: Whether to save/load compiled solvers

    Returns:
        Compiled propagation solver (a ``jax.export.Exported``)
    """
    cache_file = Path(cache_file)
    if save_compiled:
        cached = _load_cached_exported(cache_file, "propagation")
        if cached is not None:
            print("✓ Loaded existing propagation solver")
            return cached
        print("Compiling propagation solver...")
    else:
        print("Compiling propagation solver (not saving/loading from disk)...")

    # Trace with dummy arguments matching the solver's call signature;
    # params are passed as a single dictionary.
    compiled_solver = export.export(jax.jit(propagation_solver))(
        np.ones(n_states_prop),  # x_0
        (0.0, 0.0),  # time span
        np.ones((1, n_controls)),  # controls_current
        np.ones((1, n_controls)),  # controls_next
        np.ones((1, 1)),  # tau_0
        np.ones((1, 1)).astype("int"),  # segment index
        0,  # idx_s_stop
        np.ones((max_tau_len,)),  # save_time (tau_cur_padded)
        np.ones((max_tau_len,), dtype=bool),  # mask_padded (boolean mask)
        params,  # additional parameters as dict
    )

    if save_compiled:
        _save_exported(compiled_solver, cache_file)
        print("✓ Propagation solver compiled and saved")

    return compiled_solver
170
+
171
+
172
def prime_propagation_solver(
    propagation_solver: callable, params: Dict[str, Any], settings
) -> None:
    """Warm up a compiled propagation solver with one call on dummy inputs.

    Dummy arguments are derived from ``settings`` so their shapes and dtypes
    match what the solver was exported with. ``params`` contributes only its
    structure: values with a ``shape`` become arrays of ones, and everything
    else is converted with ``float``. Any exception is reported on stdout and
    swallowed, so priming never aborts initialization.

    Args:
        propagation_solver: Compiled propagation solver (must expose ``.call``)
        params: Parameters dictionary
        settings: Settings configuration object
    """
    try:
        initial_state = settings.sim.x_prop.initial
        call_args = [np.ones(initial_state.shape, dtype=settings.sim.x_prop.initial.dtype)]
        call_args.append((0.0, 1.0))  # time span
        for _ in range(2):  # controls_current, then controls_next
            call_args.append(
                np.ones((1, settings.sim.u.shape[0]), dtype=settings.sim.u.guess.dtype)
            )
        call_args.append(np.array([[0.0]], dtype=np.float64))  # tau_0
        call_args.append(np.array([[0]], dtype=np.int64))  # segment index
        call_args.append(settings.sim.time_dilation_slice.stop)  # idx_s_stop
        call_args.append(np.ones((settings.prp.max_tau_len,), dtype=np.float64))  # save_time
        call_args.append(np.ones((settings.prp.max_tau_len,), dtype=bool))  # mask_padded

        # Same keys as ``params``; array-likes keep their shape/dtype.
        placeholder_params = {}
        for key, val in params.items():
            if hasattr(val, "shape"):
                placeholder_params[key] = np.ones_like(val)
            else:
                placeholder_params[key] = float(val)
        call_args.append(placeholder_params)

        propagation_solver.call(*call_args)
    except Exception as e:
        print(f"[Initialization] Priming propagation_solver.call failed: {e}")
@@ -0,0 +1,301 @@
1
+ import queue
2
+ import sys
3
+ import time
4
+ import warnings
5
+ from importlib.metadata import PackageNotFoundError, version
6
+
7
+ import jax
8
+ import numpy as np
9
+ from termcolor import colored
10
+
11
+ from openscvx.algorithms import OptimizationResults
12
+
13
# NOTE(review): this runs at import time and installs a process-wide "ignore"
# filter for *all* warnings (no category or module restriction). Once this
# module is imported, warnings from user code and other libraries are hidden
# too. Consider scoping it (e.g. with warnings.catch_warnings()); it is kept
# as-is to preserve current behavior.
warnings.filterwarnings("ignore")


# Define colors for printing (termcolor color names)
col_main = "blue"  # not referenced by the functions in this module's visible code
col_pos = "green"  # metric within its SCP tolerance / solver status "optimal"
col_neg = "red"  # metric above its SCP tolerance / any other solver status
20
+
21
+
22
def get_version() -> str:
    """Return the installed openscvx distribution version.

    Falls back to ``"0.0.0"`` when the package metadata cannot be found
    (e.g. when running from a source tree that is not installed).
    """
    try:
        pkg_version = version("openscvx")
    except PackageNotFoundError:
        pkg_version = "0.0.0"
    return pkg_version
27
+
28
+
29
def print_summary_box(lines, title="Summary"):
    """
    Print a rounded box containing ``title`` and the body rows of ``lines``.

    The first element of ``lines`` is treated as a heading and is NOT printed.
    The box heading is always ``title``, centered. The remaining elements are
    printed left-aligned, one per row. The box is centered within an 89-column
    field. If it is wider than that, it is indented by 2 spaces instead.

    Args:
        lines (list): Heading followed by the strings to display; only
            ``lines[1:]`` are rendered.
        title (str): Text shown centered in the box header (default: "Summary")
    """
    body_rows = lines[1:]
    widest = max([len(title), *(len(row) for row in body_rows)])
    box_width = widest + 4  # room for the border characters and padding

    field_width = 89
    margin = " " * ((field_width - box_width) // 2 if box_width <= field_width else 2)
    inner = box_width - 2
    bar = "─" * box_width

    print(f"\n{margin}╭{bar}╮")
    print(f"{margin}│ {title:^{inner}} │")
    print(f"{margin}├{bar}┤")
    for row in body_rows:
        print(f"{margin}│ {row:<{inner}} │")
    print(f"{margin}╰{bar}╯\n")
60
+
61
+
62
def print_problem_summary(settings, lowered):
    """
    Print a boxed overview of the problem about to be solved.

    The box reports:

    - state/control/node dimensions;
    - constraint counts by kind (convex nodal, nonconvex nodal, CTCS);
    - the size of the CVXPy subproblem;
    - the SCP weights;
    - the configured solvers;
    - the JAX backend and version.

    The subproblem size comes from building the OCP via
    ``openscvx.solvers.optimal_control_problem``. If that raises, the
    variable/parameter/constraint counts are all reported as 0.

    Args:
        settings: Configuration settings containing problem information
        lowered: LoweredProblem from lower_symbolic_problem()
    """
    n_convex = len(lowered.cvxpy_constraints.constraints)
    n_nonconvex = len(lowered.jax_constraints.nodal)
    n_ctcs = len(lowered.jax_constraints.ctcs)
    # States beyond the true-state slice are augmented states.
    n_aug = settings.sim.n_states - settings.sim.true_state_slice.stop

    from openscvx.solvers import optimal_control_problem

    try:
        ocp = optimal_control_problem(settings, lowered)
        subproblem_sizes = (
            sum(v.size for v in ocp.variables()),
            sum(p.size for p in ocp.parameters()),
            sum(c.size for c in ocp.constraints),
        )
    except Exception:
        # Problem construction failed; report an empty subproblem.
        subproblem_sizes = (0, 0, 0)
    n_vars, n_params, n_cons = subproblem_sizes

    first_device = jax.devices()[0]
    backend_name = first_device.platform.upper()
    jax_ver = jax.__version__

    lam_vc = settings.scp.lam_vc
    if not isinstance(lam_vc, np.ndarray):
        vc_part = f"λ_vc={lam_vc:4.1f}"
    else:
        vc_part = f"λ_vc=matrix({lam_vc.shape})"
    weight_parts = [
        f"λ_cost={settings.scp.lam_cost:4.1f}",
        f"λ_tr={settings.scp.w_tr:4.1f}",
        vc_part,
    ]
    # λ_vb only matters when there are nodal nonconvex constraints.
    if n_nonconvex > 0:
        weight_parts.append(f"λ_vb={settings.scp.lam_vb:4.1f}")

    summary_rows = [
        "Problem Summary",
        f"Dimensions: {settings.sim.n_states} states ({n_aug} aug), "
        f"{settings.sim.n_controls} controls, {settings.scp.n} nodes",
        f"Constraints: {n_convex} conv, {n_nonconvex} nonconv, {n_ctcs} ctcs",
        f"Subproblem: {n_vars} vars, {n_params} params, {n_cons} constraints",
        f"Weights: {', '.join(weight_parts)}",
        f"CVX Solver: {settings.cvx.solver}, Discretization Solver: {settings.dis.solver}",
        f"JAX Backend: {backend_name} (v{jax_ver})",
    ]
    print_summary_box(summary_rows, "Problem Summary")
130
+
131
+
132
def print_results_summary(result: "OptimizationResults", timing_post, timing_init, timing_solve):
    """
    Print the results summary box.

    Args:
        result (OptimizationResults): Optimization results object. Only
            ``result.get("cost", 0.0)`` and ``result.get("ctcs_violation", 0.0)``
            are read.
        timing_post (float | None): Post-processing time in seconds; None is
            treated as 0.0.
        timing_init (float | None): Initialization time in seconds; None is
            treated as 0.0.
        timing_solve (float | None): Solve time in seconds; None is treated
            as 0.0.
    """
    cost = result.get("cost", 0.0)
    ctcs_violation = result.get("ctcs_violation", 0.0)

    # Convert numpy/JAX scalars to Python floats for formatting.
    if hasattr(cost, "item"):
        cost = cost.item()

    # Always render the CTCS violation as a flat 1-D list. np.ravel flattens
    # multi-dimensional arrays (iterating a 2-D array would yield rows, which
    # cannot be formatted with ".2e") and also accepts scalars and lists.
    violations = np.ravel(np.asarray(ctcs_violation, dtype=float))
    ctcs_violation_str = "[" + ", ".join(f"{v:.2e}" for v in violations) + "]"

    # Treat every missing timing the same way (previously only init/solve
    # were guarded, so timing_post=None raised a TypeError).
    timing_init = timing_init or 0.0
    timing_solve = timing_solve or 0.0
    timing_post = timing_post or 0.0
    total_time = timing_init + timing_solve + timing_post

    lines = [
        "Results Summary",
        f"Cost: {cost:.6f}",
        f"CTCS Constraint Violation: {ctcs_violation_str}",
        f"Preprocessing Time: {timing_init:.3f}s",
        f"Main Solve Time: {timing_solve:.3f}s",
        f"Post-processing Time: {timing_post:.3f}s",
        f"Total Computation Time: {total_time:.3f}s",
    ]

    print_summary_box(lines, "Results Summary")
173
+
174
+
175
def intro():
    """Print the openscvx ASCII-art banner with authors, affiliation and version.

    The version line is filled in from ``get_version()``. Before printing, a
    global ``warnings.filterwarnings("ignore")`` filter is installed.
    """
    # Silence syntax warnings
    # NOTE(review): this call ignores *all* warning categories process-wide,
    # not only SyntaxWarning as the comment above suggests.
    warnings.filterwarnings("ignore")
    # fmt: off
    # Raw f-string: backslashes in the art are kept literally, and only
    # {get_version()} is interpolated.
    ascii_art = rf"""

____ _____ _____
/ __ \ / ____|/ ____|
| | | |_ __ ___ _ __ | (___ | | __ ____ __
| | | | '_ \ / _ \ '_ \ \___ \| | \ \ / /\ \/ /
| |__| | |_) | __/ | | |____) | |___\ V / > <
\____/| .__/ \___|_| |_|_____/ \_____\_/ /_/\_\
| |
|_|
─────────────────────────────────────────────────────────────────────────────────────────────────────────
Author: Chris Hayner and Griffin Norris
Autonomous Controls Laboratory
University of Washington
Version: {get_version()}
─────────────────────────────────────────────────────────────────────────────────────────────────────────
"""
    # fmt: on
    print(ascii_art)
198
+
199
+
200
def header():
    """Print the column-label row of the SCP iteration table between two rules.

    The labels match the columns later filled in by ``intermediate``.
    """
    rule = "─────────────────────────────────────────────────────────────────────────────────────────────────────────"
    column_labels = (
        "Iter",
        "Dis Time (ms)",
        "Solve Time (ms)",
        "J_total",
        "J_tr",
        "J_vb",
        "J_vc",
        "Cost",
        "Solver Status",
    )
    label_format = "{:^4} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^14}"

    print(colored(rule))
    print(label_format.format(*column_labels))
    print(colored(rule))
224
+
225
+
226
def intermediate(print_queue, params):
    """Continuously render SCP iteration rows pushed onto ``print_queue``.

    Loops forever (``while True`` with no exit condition) and polls the queue
    at about 30 Hz. It is presumably meant to run in a background/daemon
    thread (TODO confirm against the caller).

    Each queue item is a dict read via the keys:

    - ``iter``
    - ``dis_time``
    - ``subprop_time``
    - ``J_total``
    - ``J_tr``
    - ``J_vb``
    - ``J_vc``
    - ``cost``
    - ``prob_stat``

    For every item whose ``iter`` is not 1, the two lines printed under the
    previous row (rule and column labels) are first erased with ANSI escape
    codes. The new row is printed, and then the rule and labels are
    re-printed so they stay at the bottom of the table.

    Args:
        print_queue: Queue-like object with ``get(timeout=...)`` that raises
            ``queue.Empty`` on timeout.
        params: Settings object. ``params.scp.ep_tr``, ``ep_vb`` and ``ep_vc``
            are the thresholds: J_tr/J_vb/J_vc are colored green when
            ``<=`` their threshold and red otherwise.
    """
    hz = 30.0
    while True:
        t_start = time.time()
        try:
            data = print_queue.get(timeout=1.0 / hz)
            # remove bottom labels and line
            # ("\x1b[1A\x1b[2K" = cursor up one line + erase line, done twice)
            if data["iter"] != 1:
                sys.stdout.write("\x1b[1A\x1b[2K\x1b[1A\x1b[2K")
            # NOTE(review): heuristic. When the 4th element of prob_stat is "f",
            # it is replaced by its first element. This raises IndexError if
            # prob_stat has fewer than 4 elements; intent unclear, confirm.
            if data["prob_stat"][3] == "f":
                # Only show the first element of the string
                data["prob_stat"] = data["prob_stat"][0]

            iter_colored = colored("{:4d}".format(data["iter"]))
            J_tot_colored = colored("{:.1e}".format(data["J_total"]))
            J_tr_colored = colored(
                "{:.1e}".format(data["J_tr"]),
                col_pos if data["J_tr"] <= params.scp.ep_tr else col_neg,
            )
            J_vb_colored = colored(
                "{:.1e}".format(data["J_vb"]),
                col_pos if data["J_vb"] <= params.scp.ep_vb else col_neg,
            )
            J_vc_colored = colored(
                "{:.1e}".format(data["J_vc"]),
                col_pos if data["J_vc"] <= params.scp.ep_vc else col_neg,
            )
            cost_colored = colored("{:.1e}".format(data["cost"]))
            prob_stat_colored = colored(
                data["prob_stat"], col_pos if data["prob_stat"] == "optimal" else col_neg
            )

            # NOTE(review): "{:^6.2F}" (uppercase F) differs from "{:^6.2f}"
            # only in rendering inf/nan as INF/NAN.
            print(
                "{:^4} │ {:^6.2f} │ {:^6.2F} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ "
                " {:^7} │ {:^14}".format(
                    iter_colored,
                    data["dis_time"],
                    data["subprop_time"],
                    J_tot_colored,
                    J_tr_colored,
                    J_vb_colored,
                    J_vc_colored,
                    cost_colored,
                    prob_stat_colored,
                )
            )

            # Re-draw the rule and column labels below the newest row.
            print(
                colored(
                    "─────────────────────────────────────────────────────────────────────────────────────────────────────────"
                )
            )
            print(
                "{:^4} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^7} │ {:^14}".format(
                    "Iter",
                    "Dis Time (ms)",
                    "Solve Time (ms)",
                    "J_total",
                    "J_tr",
                    "J_vb",
                    "J_vc",
                    "Cost",
                    "Solver Status",
                )
            )
        except queue.Empty:
            # Nothing to print this tick.
            pass
        # Sleep for whatever remains of this 1/hz tick.
        time.sleep(max(0.0, 1.0 / hz - (time.time() - t_start)))
294
+
295
+
296
def footer():
    """Print the horizontal rule that closes the SCP iteration table."""
    closing_rule = colored(
        "─────────────────────────────────────────────────────────────────────────────────────────────────────────"
    )
    print(closing_rule)
@@ -0,0 +1,37 @@
1
+ from datetime import datetime
2
+ from typing import TYPE_CHECKING, Optional
3
+
4
+ if TYPE_CHECKING:
5
+ import cProfile
6
+
7
+
8
def profiling_start(profiling_enabled: bool) -> "Optional[cProfile.Profile]":
    """Create and enable a ``cProfile.Profile`` when profiling is requested.

    ``cProfile`` is imported only when profiling is actually enabled.

    Args:
        profiling_enabled: Whether to enable profiling.

    Returns:
        The running profiler, or None when ``profiling_enabled`` is falsy.
    """
    if not profiling_enabled:
        return None

    import cProfile

    profiler = cProfile.Profile()
    profiler.enable()
    return profiler
24
+
25
+
26
def profiling_end(
    pr: "Optional[cProfile.Profile]", identifier: str, output_dir: str = "profiling"
):
    """Stop profiling and save results with timestamp.

    The stats are written to ``<output_dir>/<YYYYmmdd_HHMMSS>_<identifier>.prof``
    so they can be visualized with snakeviz. A relative ``output_dir`` is
    resolved against the current working directory. The output directory is
    created if it does not exist.

    Args:
        pr: Profile object from profiling_start, or None (then nothing happens).
        identifier: Identifier for the profiling session (e.g., "solve", "initialize").
        output_dir: Directory that receives the ``.prof`` file (str or path-like).
            Defaults to ``"profiling"``.
    """
    if pr is None:
        return

    from pathlib import Path

    pr.disable()
    target_dir = Path(output_dir)
    # dump_stats does not create parent directories and would raise
    # FileNotFoundError on a fresh checkout.
    target_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    pr.dump_stats(str(target_dir / f"{timestamp}_{identifier}.prof"))