cbfpy 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ cbfpy
2
+ test
test/__init__.py ADDED
File without changes
test/test_speed.py ADDED
@@ -0,0 +1,191 @@
1
+ """Speed tests for the CBF solver
2
+
3
+ We evaluate the speed of the solver NOT just via the QP solve but via the whole process
4
+ (solving for the nominal control input, constructing the QP matrices, and then solving).
5
+ This provides a more accurate view of what the controller frequency would actually be if
6
+ deployed on the robot.
7
+
8
+ These test cases can also be used to check that modifications to the CBF implementation
9
+ do not significantly degrade performance.
10
+ """
11
+
12
+ import unittest
13
+ from typing import Callable
14
+ import time
15
+ import jax
16
+ import numpy as np
17
+ import matplotlib.pyplot as plt
18
+
19
+ from cbfpy import CBF, CLFCBF
20
+ import cbfpy.examples.point_robot_demo as prdemo
21
+ import cbfpy.examples.adaptive_cruise_control_demo as accdemo
22
+
23
+ # Seed NumPy's global RNG for repeatability: the benchmark states are sampled with
+ # np.random, so every run of these speed tests times the same set of inputs
24
+ np.random.seed(0)
25
+
26
+
27
# TODO Make this work if there are additional arguments for the barrier function
def eval_speed(
    controller_func: Callable,
    states: np.ndarray,
    des_states: np.ndarray,
    verbose: bool = True,
    plot: bool = True,
) -> float:
    """Tests the speed of a controller function via evaluation on a set of inputs

    The controller is jit-compiled and called once on the first (state, desired state) pair
    to trigger compilation; that call is timed separately. Every remaining pair is then timed
    individually, blocking until the outputs are ready so that JAX's asynchronous dispatch
    does not hide the actual computation time.

    Timing details (average solve time / Hz, distributions of times, etc.) can be printed to the terminal
    or visualized in plots, via the `verbose` and `plot` inputs

    Args:
        controller_func (Callable): Function to time, called as ``controller_func(state, des_state)``.
            This should be the highest-level CBF-based controller function which includes the
            nominal controller, QP construction, and QP solve. It may return a single array or
            any pytree of arrays.
        states (np.ndarray): Set of states to evaluate on, shape (num_evals, state_dim).
            num_evals must be at least 2 (one compilation call plus at least one timed call)
        des_states (np.ndarray): Set of desired states to evaluate on, shape (num_evals, des_state_dim)
        verbose (bool, optional): Whether to print timing details to the terminal. Defaults to True.
        plot (bool, optional): Whether to visualize the distribution of solve times. Defaults to True.

    Returns:
        float: Average solver Hz, computed over all calls after the compilation call
    """
    assert isinstance(controller_func, Callable)
    assert isinstance(states, np.ndarray)
    assert isinstance(des_states, np.ndarray)
    assert isinstance(verbose, bool)
    assert isinstance(plot, bool)
    # The first sample is consumed by the compilation call, so at least one more is needed
    assert states.shape[0] > 1
    assert states.shape[0] == des_states.shape[0]
    controller_func = jax.jit(controller_func)

    # Do an initial solve to jit-compile the function. Block on the output so the measured
    # time covers the whole first solve, and so its pending work cannot spill over into
    # the first timed iteration below (JAX dispatches computations asynchronously)
    start_time = time.perf_counter()
    jax.block_until_ready(controller_func(states[0], des_states[0]))
    first_solve_time = time.perf_counter() - start_time

    # Solve for the remainder of the controls using the jit-compiled controller.
    # jax.block_until_ready also accepts pytree outputs (e.g. tuples of arrays),
    # unlike the array-only .block_until_ready() method
    times = []
    for z, z_des in zip(states[1:], des_states[1:]):
        start_time = time.perf_counter()
        jax.block_until_ready(controller_func(z, z_des))
        times.append(time.perf_counter() - start_time)
    times = np.asarray(times)
    avg_solve_time = np.mean(times)
    max_solve_time = np.max(times)
    avg_hz = 1 / avg_solve_time
    worst_case_hz = 1 / max_solve_time

    if verbose:
        # Print info about solver stats
        print(f"Solved for the first control input in {first_solve_time} seconds")
        print(f"Average solve time: {avg_solve_time} seconds")
        print(f"Average Hz: {avg_hz}")
        print(f"Worst-case solve time: {max_solve_time}")
        print(f"Worst-case Hz: {worst_case_hz}")
        print(
            "NOTE: Worst-case behavior might be inaccurate due to how the OS manages background processes"
        )

    if plot:
        # Create a plot to visualize the distribution of times
        _, axs = plt.subplots(2, 2)
        axs[0, 0].hist(times, 20)
        axs[0, 0].set_title("Solve Times")
        axs[0, 0].set_ylabel("Frequency")
        axs[0, 0].set_xscale("log")
        axs[0, 1].boxplot(times, vert=False)
        axs[0, 1].set_title("Solve Times")
        axs[0, 1].set_xscale("log")
        axs[1, 0].hist(1 / times, 20)
        axs[1, 0].set_title("Hz")
        axs[1, 0].set_ylabel("Frequency")
        axs[1, 1].boxplot(1 / times, vert=False)
        axs[1, 1].set_title("Hz")
        plt.show()

    return float(avg_hz)
106
+
107
+
108
@jax.tree_util.register_static
class PointRobotTest:
    """Test the speed of the point-robot-in-a-box demo, using randomly sampled states

    The class is registered as a static pytree, so instances can be passed as the
    ``self`` argument of the jit-compiled ``controller`` method without being traced.
    """

    def __init__(self):
        self.config = prdemo.PointRobotConfig()
        self.cbf = CBF.from_config(self.config)
        self.nominal_controller = prdemo.nominal_controller
        # Bounds of the keep-in region, used when sampling positions
        self.pos_min = self.config.pos_min
        self.pos_max = self.config.pos_max

    def sample_states(self, num_samples: int) -> np.ndarray:
        """Sample a set of random states for the point robot (3D positions and velocities)

        Args:
            num_samples (int): Number of states to sample

        Returns:
            np.ndarray: States of shape (num_samples, 6), laid out as [position (3), velocity (3)]
        """
        # Sample positions uniformly inside the keep-in region
        positions = np.asarray(self.pos_min) + np.random.rand(
            num_samples, 3
        ) * np.subtract(self.pos_max, self.pos_min)
        # Assume x/y/z velocities are sampled uniformly between [-3, 3]
        velocities = -3.0 + 6 * np.random.rand(num_samples, 3)
        return np.column_stack([positions, velocities])

    @jax.jit
    def controller(self, z, z_des):
        """Full controller to benchmark: nominal controller followed by the CBF safety filter"""
        u = self.nominal_controller(z, z_des)
        return self.cbf.safety_filter(z, u)

    def test_speed(
        self, verbose: bool = True, plot: bool = True, n_samples: int = 10000
    ) -> float:
        """Benchmark the controller on randomly sampled current / desired states

        Args:
            verbose (bool, optional): Whether to print timing details. Defaults to True.
            plot (bool, optional): Whether to plot the timing distributions. Defaults to True.
            n_samples (int, optional): Number of (state, desired state) pairs to time.
                Must be at least 2. Defaults to 10000.

        Returns:
            float: Average controller frequency, in Hz
        """
        states = self.sample_states(n_samples)
        desired_states = self.sample_states(n_samples)
        return eval_speed(self.controller, states, desired_states, verbose, plot)
140
+
141
+
142
@jax.tree_util.register_static
class ACCTest:
    """Test the speed of the adaptive cruise control CLF-CBF demo, using randomly sampled states

    The class is registered as a static pytree, so instances can be passed as the
    ``self`` argument of the jit-compiled ``controller`` method without being traced.
    """

    def __init__(self):
        self.config = accdemo.ACCConfig()
        self.clf_cbf = CLFCBF.from_config(self.config)

    def sample_states(self, num_samples: int) -> np.ndarray:
        """Sample a set of random states for the adaptive cruise control demo

        Args:
            num_samples (int): Number of states to sample

        Returns:
            np.ndarray: States of shape (num_samples, 3), laid out as
                [follower velocity, leader velocity, distance]. Follower velocities are
                uniform in [0, 20), leader velocities in [0, 40), and distances in [10, 110)
        """
        follower_vels = np.random.rand(num_samples) * 20
        leader_vels = np.random.rand(num_samples) * 40
        distances = 10 + np.random.rand(num_samples) * 100
        return np.column_stack([follower_vels, leader_vels, distances])

    @jax.jit
    def controller(self, z, z_des):
        """Full CLF-CBF controller to benchmark"""
        return self.clf_cbf.controller(z, z_des)

    def test_speed(
        self, verbose: bool = True, plot: bool = True, n_samples: int = 10000
    ) -> float:
        """Benchmark the controller on randomly sampled current / desired states

        Args:
            verbose (bool, optional): Whether to print timing details. Defaults to True.
            plot (bool, optional): Whether to plot the timing distributions. Defaults to True.
            n_samples (int, optional): Number of (state, desired state) pairs to time.
                Must be at least 2. Defaults to 10000.

        Returns:
            float: Average controller frequency, in Hz
        """
        states = self.sample_states(n_samples)
        # Note that the desired states aren't really relevant for this specific demo
        # just due to how the ACC problem was constructed
        desired_states = self.sample_states(n_samples)
        return eval_speed(self.controller, states, desired_states, verbose, plot)
169
+
170
+
171
class SpeedTest(unittest.TestCase):
    """Test case to guarantee that the CBFs run at least at kilohertz rates"""

    # Minimum acceptable average controller frequency, in Hz
    MIN_AVG_HZ = 1000

    @classmethod
    def setUpClass(cls) -> None:
        # Build the demo controllers once and share them across the test methods
        cls.point_robot_test = PointRobotTest()
        cls.acc_test = ACCTest()

    def test_point_robot(self):
        avg_hz = self.point_robot_test.test_speed(verbose=False, plot=False)
        print("Point robot average Hz: ", avg_hz)
        # assertGreaterEqual reports both values on failure, unlike assertTrue
        self.assertGreaterEqual(avg_hz, self.MIN_AVG_HZ)

    def test_acc(self):
        avg_hz = self.acc_test.test_speed(verbose=False, plot=False)
        print("ACC average Hz: ", avg_hz)
        self.assertGreaterEqual(avg_hz, self.MIN_AVG_HZ)
188
+
189
+
190
+ # Run the speed tests when this file is executed directly as a script
+ if __name__ == "__main__":
191
+ unittest.main()
test/test_utils.py ADDED
@@ -0,0 +1,34 @@
1
+ """Unit tests for utility functions"""
2
+
3
+ import unittest
4
+ import numpy as np
5
+ import jax
6
+
7
+ import cbfpy.utils.math_utils as math_utils
8
+
9
+
10
+ # When True, test_normalize wraps math_utils.normalize with jax.jit before checking it
+ TEST_JIT = False
11
+
12
+
13
class TestUtils(unittest.TestCase):
    """Unit tests for utility functions"""

    def test_normalize(self):
        # The module-level TEST_JIT flag selects whether the jit-compiled version is checked
        if TEST_JIT:
            normalize = jax.jit(math_utils.normalize)
        else:
            normalize = math_utils.normalize

        # Same tolerances as np.allclose's defaults, but failures report the mismatching values
        tol = dict(rtol=1e-5, atol=1e-8)

        # Test single vector
        vec = np.array([1, 2, 3])
        np.testing.assert_allclose(normalize(vec), vec / np.linalg.norm(vec), **tol)

        # Test multiple vectors: each row is expected to be normalized independently
        vecs = np.array([[1, 2, 3], [4, 5, 6]])
        expected = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
        np.testing.assert_allclose(normalize(vecs), expected, **tol)
31
+
32
+
33
+ # Run the utility tests when this file is executed directly as a script
+ if __name__ == "__main__":
34
+ unittest.main()