cbfpy 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cbfpy/__init__.py +11 -0
- cbfpy/cbfs/__init__.py +0 -0
- cbfpy/cbfs/cbf.py +384 -0
- cbfpy/cbfs/clf_cbf.py +490 -0
- cbfpy/config/__init__.py +0 -0
- cbfpy/config/cbf_config.py +401 -0
- cbfpy/config/clf_cbf_config.py +251 -0
- cbfpy/envs/__init__.py +0 -0
- cbfpy/envs/arm_envs.py +84 -0
- cbfpy/envs/base_env.py +69 -0
- cbfpy/envs/car_env.py +332 -0
- cbfpy/envs/drone_env.py +153 -0
- cbfpy/envs/point_robot_envs.py +209 -0
- cbfpy/examples/__init__.py +0 -0
- cbfpy/examples/adaptive_cruise_control_demo.py +117 -0
- cbfpy/examples/drone_demo.py +109 -0
- cbfpy/examples/joint_limits_demo.py +150 -0
- cbfpy/examples/point_robot_demo.py +91 -0
- cbfpy/examples/point_robot_obstacle_demo.py +118 -0
- cbfpy/temp/test_import.py +3 -0
- cbfpy/utils/__init__.py +0 -0
- cbfpy/utils/general_utils.py +131 -0
- cbfpy/utils/jax_utils.py +26 -0
- cbfpy/utils/math_utils.py +21 -0
- cbfpy/utils/visualization.py +93 -0
- cbfpy-0.0.1.dist-info/LICENSE +21 -0
- cbfpy-0.0.1.dist-info/METADATA +226 -0
- cbfpy-0.0.1.dist-info/RECORD +33 -0
- cbfpy-0.0.1.dist-info/WHEEL +5 -0
- cbfpy-0.0.1.dist-info/top_level.txt +2 -0
- test/__init__.py +0 -0
- test/test_speed.py +191 -0
- test/test_utils.py +34 -0
test/__init__.py
ADDED
|
File without changes
|
test/test_speed.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Speed tests for the CBF solver
|
|
2
|
+
|
|
3
|
+
We evaluate the speed of the solver NOT just via the QP solve but via the whole process
|
|
4
|
+
(solving for the nominal control input, constructing the QP matrices, and then solving).
|
|
5
|
+
This provides a more accurate view of what the controller frequency would actually be if
|
|
6
|
+
deployed on the robot.
|
|
7
|
+
|
|
8
|
+
These test cases can also be used to check that modifications to the CBF implementation
|
|
9
|
+
do not significantly degrade performance
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import unittest
|
|
13
|
+
from typing import Callable
|
|
14
|
+
import time
|
|
15
|
+
import jax
|
|
16
|
+
import numpy as np
|
|
17
|
+
import matplotlib.pyplot as plt
|
|
18
|
+
|
|
19
|
+
from cbfpy import CBF, CLFCBF
|
|
20
|
+
import cbfpy.examples.point_robot_demo as prdemo
|
|
21
|
+
import cbfpy.examples.adaptive_cruise_control_demo as accdemo
|
|
22
|
+
|
|
23
|
+
# Seed NumPy's global RNG once at import time so the randomly sampled benchmark
# states (drawn with np.random.rand below) are the same on every run
np.random.seed(seed=0)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# TODO Make this work if there are additional arguments for the barrier function
def eval_speed(
    controller_func: Callable,
    states: np.ndarray,
    des_states: np.ndarray,
    verbose: bool = True,
    plot: bool = True,
) -> float:
    """Tests the speed of a controller function via evaluation on a set of inputs

    Timing details (average solve time / Hz, distributions of times, etc.) can be printed to the terminal
    or visualized in plots, via the `verbose` and `plot` inputs

    The first call jit-compiles the function and is timed separately; it is excluded from the
    reported statistics. Because JAX dispatches work asynchronously, every timed call blocks until
    its output is ready, so the measured times include the actual computation.

    Args:
        controller_func (Callable): Function to time. This should be the highest-level CBF-based controller
            function which includes the nominal controller, QP construction, and QP solve. It is called as
            ``controller_func(state, des_state)`` and may return an array or a pytree of arrays
        states (np.ndarray): Set of states to evaluate on, shape (num_evals, state_dim), with num_evals > 1
        des_states (np.ndarray): Set of desired states to evaluate on, shape (num_evals, des_state_dim)
        verbose (bool, optional): Whether to print timing details to the terminal. Defaults to True.
        plot (bool, optional): Whether to visualize the distribution of solve times. Defaults to True.

    Returns:
        float: Average solver Hz, over every evaluation except the first (compilation) call
    """
    assert callable(controller_func)
    assert isinstance(states, np.ndarray)
    assert isinstance(des_states, np.ndarray)
    assert isinstance(verbose, bool)
    assert isinstance(plot, bool)
    # At least one evaluation must remain after the warm-up (compilation) call
    assert states.shape[0] > 1
    assert states.shape[0] == des_states.shape[0]
    jitted_controller = jax.jit(controller_func)

    # Do an initial solve to jit-compile the function. Block on the output so this time covers
    # compilation AND execution, not just the asynchronous dispatch
    start_time = time.perf_counter()
    jax.block_until_ready(jitted_controller(states[0], des_states[0]))
    first_solve_time = time.perf_counter() - start_time

    # Solve for the remainder of the controls using the jit-compiled controller
    times = []
    for state, des_state in zip(states[1:], des_states[1:]):
        start_time = time.perf_counter()
        # jax.block_until_ready works for a single array as well as a pytree of arrays
        jax.block_until_ready(jitted_controller(state, des_state))
        times.append(time.perf_counter() - start_time)
    times = np.asarray(times)
    avg_solve_time = np.mean(times)
    max_solve_time = np.max(times)
    avg_hz = 1 / avg_solve_time
    worst_case_hz = 1 / max_solve_time

    if verbose:
        # Print info about solver stats
        print(f"Solved for the first control input in {first_solve_time} seconds")
        print(f"Average solve time: {avg_solve_time} seconds")
        print(f"Average Hz: {avg_hz}")
        print(f"Worst-case solve time: {max_solve_time} seconds")
        print(f"Worst-case Hz: {worst_case_hz}")
        print(
            "NOTE: Worst-case behavior might be inaccurate due to how the OS manages background processes"
        )

    if plot:
        # Create a plot to visualize the distribution of times
        fig, axs = plt.subplots(2, 2)
        axs[0, 0].hist(times, 20)
        axs[0, 0].set_title("Solve Times")
        axs[0, 0].set_ylabel("Frequency")
        axs[0, 0].set_xlabel("Seconds")
        axs[0, 0].set_xscale("log")
        axs[0, 1].boxplot(times, vert=False)
        axs[0, 1].set_title("Solve Times")
        axs[0, 1].set_xlabel("Seconds")
        axs[0, 1].set_xscale("log")
        axs[1, 0].hist(1 / times, 20)
        axs[1, 0].set_title("Hz")
        axs[1, 0].set_ylabel("Frequency")
        axs[1, 1].boxplot(1 / times, vert=False)
        axs[1, 1].set_title("Hz")
        fig.tight_layout()
        plt.show()

    return float(avg_hz)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@jax.tree_util.register_static
class PointRobotTest:
    """Test the speed of the point-robot-in-a-box demo, using randomly sampled states

    The class is registered as a static pytree node, so `self` can be passed into the
    `jax.jit`-decorated `controller` method.
    """

    def __init__(self):
        self.config = prdemo.PointRobotConfig()
        self.cbf = CBF.from_config(self.config)
        self.nominal_controller = prdemo.nominal_controller
        # Bounds of the keep-in box, used when sampling positions
        self.pos_min = self.config.pos_min
        self.pos_max = self.config.pos_max

    def sample_states(self, num_samples: int) -> np.ndarray:
        """Sample a set of random states for the point robot (3D positions and velocities)

        Args:
            num_samples (int): Number of states to sample

        Returns:
            np.ndarray: Sampled states, shape (num_samples, 6): positions in columns 0-2,
                velocities in columns 3-5
        """
        # Sample positions uniformly inside the keep-in region
        positions = np.asarray(self.pos_min) + np.random.rand(
            num_samples, 3
        ) * np.subtract(self.pos_max, self.pos_min)
        # Assume x/y/z velocities are sampled uniformly between [-3, 3]
        velocities = -3.0 + 6 * np.random.rand(num_samples, 3)
        return np.column_stack([positions, velocities])

    @jax.jit
    def controller(self, z, z_des):
        """Full controller: nominal control input, passed through the CBF safety filter"""
        u = self.nominal_controller(z, z_des)
        return self.cbf.safety_filter(z, u)

    def test_speed(
        self, verbose: bool = True, plot: bool = True, n_samples: int = 10000
    ) -> float:
        """Benchmark the controller on randomly sampled current/desired states

        Args:
            verbose (bool, optional): Whether to print timing details. Defaults to True.
            plot (bool, optional): Whether to plot the timing distributions. Defaults to True.
            n_samples (int, optional): Number of (state, desired state) pairs to evaluate. Must be > 1.
                Defaults to 10000.

        Returns:
            float: Average controller frequency, in Hz
        """
        states = self.sample_states(n_samples)
        desired_states = self.sample_states(n_samples)
        avg_hz = eval_speed(self.controller, states, desired_states, verbose, plot)
        return avg_hz
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@jax.tree_util.register_static
class ACCTest:
    """Test the speed of the adaptive cruise control CLF-CBF demo, using randomly sampled states

    The class is registered as a static pytree node, so `self` can be passed into the
    `jax.jit`-decorated `controller` method.
    """

    def __init__(self):
        self.config = accdemo.ACCConfig()
        self.clf_cbf = CLFCBF.from_config(self.config)

    def sample_states(self, num_samples: int) -> np.ndarray:
        """Sample a set of random states for the adaptive cruise control demo

        Args:
            num_samples (int): Number of states to sample

        Returns:
            np.ndarray: Sampled states, shape (num_samples, 3). Columns are the follower velocity
                in [0, 20), the leader velocity in [0, 40), and the distance in [10, 110)
        """
        follower_vels = np.random.rand(num_samples) * 20
        leader_vels = np.random.rand(num_samples) * 40
        distances = 10 + np.random.rand(num_samples) * 100
        return np.column_stack([follower_vels, leader_vels, distances])

    @jax.jit
    def controller(self, z, z_des):
        """Full CLF-CBF controller evaluated at state `z` with desired state `z_des`"""
        return self.clf_cbf.controller(z, z_des)

    def test_speed(
        self, verbose: bool = True, plot: bool = True, n_samples: int = 10000
    ) -> float:
        """Benchmark the controller on randomly sampled current/desired states

        Args:
            verbose (bool, optional): Whether to print timing details. Defaults to True.
            plot (bool, optional): Whether to plot the timing distributions. Defaults to True.
            n_samples (int, optional): Number of (state, desired state) pairs to evaluate. Must be > 1.
                Defaults to 10000.

        Returns:
            float: Average controller frequency, in Hz
        """
        states = self.sample_states(n_samples)
        # Note that the desired states aren't really relevant for this specific demo
        # just due to how the ACC problem was constructed
        desired_states = self.sample_states(n_samples)
        avg_hz = eval_speed(self.controller, states, desired_states, verbose, plot)
        return avg_hz
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class SpeedTest(unittest.TestCase):
    """Test case to guarantee that the CBFs run at least at kilohertz rates"""

    # Minimum acceptable average controller frequency, in Hz
    MIN_AVG_HZ = 1000

    @classmethod
    def setUpClass(cls) -> None:
        # Build the controllers once; they are shared by every test method in this case
        cls.point_robot_test = PointRobotTest()
        cls.acc_test = ACCTest()

    def test_point_robot(self):
        avg_hz = self.point_robot_test.test_speed(verbose=False, plot=False)
        print("Point robot average Hz: ", avg_hz)
        # assertGreaterEqual reports the measured value on failure (assertTrue would not)
        self.assertGreaterEqual(avg_hz, self.MIN_AVG_HZ)

    def test_acc(self):
        avg_hz = self.acc_test.test_speed(verbose=False, plot=False)
        print("ACC average Hz: ", avg_hz)
        self.assertGreaterEqual(avg_hz, self.MIN_AVG_HZ)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
if __name__ == "__main__":
    # Run the speed benchmarks when this file is executed directly as a script
    unittest.main(verbosity=1)
|
test/test_utils.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Unit tests for utility functions"""
|
|
2
|
+
|
|
3
|
+
import unittest
|
|
4
|
+
import numpy as np
|
|
5
|
+
import jax
|
|
6
|
+
|
|
7
|
+
import cbfpy.utils.math_utils as math_utils
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# When True, test_normalize wraps math_utils.normalize in jax.jit before testing it;
# when False, the plain (non-jitted) function is tested
TEST_JIT: bool = False
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestUtils(unittest.TestCase):
    """Unit tests for utility functions"""

    # Same tolerances as np.allclose's defaults. This keeps float32 results (e.g. from jax)
    # acceptable, since assert_allclose's default rtol (1e-7) would be far stricter
    RTOL = 1e-5
    ATOL = 1e-8

    def test_normalize(self):
        # Optionally test the jit-compiled version of the function instead
        normalize = jax.jit(math_utils.normalize) if TEST_JIT else math_utils.normalize

        # Test single vector
        vec = np.array([1, 2, 3])
        np.testing.assert_allclose(
            normalize(vec), vec / np.linalg.norm(vec), rtol=self.RTOL, atol=self.ATOL
        )

        # Test multiple vectors: each row should be normalized independently
        vecs = np.array([[1, 2, 3], [4, 5, 6]])
        expected = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
        np.testing.assert_allclose(
            normalize(vecs), expected, rtol=self.RTOL, atol=self.ATOL
        )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
if __name__ == "__main__":
    # Run the utility unit tests when this file is executed directly as a script
    unittest.main(verbosity=1)
|