cbfpy 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cbfpy/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """CBFpy: Control Barrier Functions in Python and Jax"""
2
+
3
+ import jax as _jax
4
+
5
# 64 bit precision is generally necessary for these problems to be feasible.
# This updates JAX's global configuration at package import time, before the
# CBF / CLF-CBF submodules below are imported.
_jax.config.update("jax_enable_x64", True)
7
+
8
+ from cbfpy.cbfs.cbf import CBF
9
+ from cbfpy.cbfs.clf_cbf import CLFCBF
10
+ from cbfpy.config.cbf_config import CBFConfig
11
+ from cbfpy.config.clf_cbf_config import CLFCBFConfig
cbfpy/cbfs/__init__.py ADDED
File without changes
cbfpy/cbfs/cbf.py ADDED
@@ -0,0 +1,384 @@
1
+ """
2
+ # Control Barrier Functions (CBFs)
3
+
4
+ CBFs serve as safety filters on top of a nominal controller. Given a nominal control input, the CBF will compute a
5
+ safe control input to keep the system within a safe set.
6
+
7
+ For a relative-degree-1 system, this optimizes the standard min-norm objective with the constraint
8
+ `h_dot >= -alpha(h(z))`
9
+ ```
10
+ minimize ||u - u_des||_{2}^{2} # CBF Objective (Example)
11
+ subject to Lfh(z) + Lgh(z)u >= -alpha(h(z)) # RD1 CBF Constraint
12
+ ```
13
+
14
In the case of a relative-degree-2 system, each RD2 barrier `h_2` is first lowered to a relative-degree-1
barrier `h_2_rd1(z) = Lfh_2(z) + alpha_2(h_2(z))`, and the RD1 condition `h_2_rd1_dot >= -alpha(h_2_rd1(z))`
is then enforced on it:
```
minimize ||u - u_des||_{2}^{2} # CBF Objective (Example)
subject to Lfh_2_rd1(z) + Lgh_2_rd1(z)u >= -alpha(h_2_rd1(z)) # RD2 CBF Constraint
```
20
+
21
+ If there are constraints on the control input, we also enforce another constraint:
22
+ ```
23
+ u_min <= u <= u_max # Control constraint
24
+ ```
25
+ """
26
+
27
+ from typing import Tuple, Callable, Optional
28
+
29
+ import jax
30
+ import jax.numpy as jnp
31
+ from jax import Array
32
+ from jax.typing import ArrayLike
33
+ import qpax
34
+
35
+ from cbfpy.config.cbf_config import CBFConfig
36
+ from cbfpy.utils.jax_utils import conditional_jit
37
+ from cbfpy.utils.general_utils import print_warning
38
+
39
# Debugging flags to disable jit in specific sections of the code.
#   DEBUG_SAFETY_FILTER: controls jit of `CBF.safety_filter` (passed as `not DEBUG_SAFETY_FILTER` to
#       `conditional_jit`); when True, the filter also prints QP convergence status and iteration count.
#   DEBUG_QP_DATA: controls jit of `CBF.qp_data` (passed as `not DEBUG_QP_DATA` to `conditional_jit`).
# Note: If any higher-level jits exist, those must also be set to debug (disable jit)
DEBUG_SAFETY_FILTER = False
DEBUG_QP_DATA = False
43
+
44
+
45
@jax.tree_util.register_static
class CBF:
    """Control Barrier Function (CBF) safety filter.

    Build instances with `CBF.from_config`, which takes a `CBFConfig` describing the dynamics,
    the barrier function(s), the QP objective, and the solver settings. The resulting object exposes
    `safety_filter`, which maps a nominal control input to one that satisfies the CBF constraint(s)
    by solving a small quadratic program.

    The class is registered with `jax.tree_util.register_static`, so an instance can be passed as
    `self` into the jitted methods below.

    Examples:
        ```
        # Construct a CBFConfig for your problem
        config = DroneConfig()
        # Construct a CBF instance based on the config
        cbf = CBF.from_config(config)
        # Compute the safe control input
        safe_control = cbf.safety_filter(current_state, nominal_control)
        ```
    """

    # NOTE: Do not call __init__ directly; use `from_config`. Jax prefers that __init__ performs no
    # input validation, so all validation is handled by the CBFConfig class instead.
    def __init__(
        self,
        n: int,
        m: int,
        num_cbf: int,
        u_min: Optional[tuple],
        u_max: Optional[tuple],
        control_constrained: bool,
        relax_cbf: bool,
        cbf_relaxation_penalty: float,
        h_1: Callable[[ArrayLike], Array],
        h_2: Callable[[ArrayLike], Array],
        f: Callable[[ArrayLike], Array],
        g: Callable[[ArrayLike], Array],
        alpha: Callable[[ArrayLike], Array],
        alpha_2: Callable[[ArrayLike], Array],
        P: Callable[[ArrayLike, ArrayLike, Tuple[ArrayLike, ...]], Array],
        q: Callable[[ArrayLike, ArrayLike, Tuple[ArrayLike, ...]], Array],
        solver_tol: float,
    ):
        # Problem dimensions: state size, control size, number of barriers
        self.n, self.m, self.num_cbf = n, m, num_cbf

        # Optional box limits on the control input (only used when control_constrained is True)
        self.u_min = u_min
        self.u_max = u_max
        self.control_constrained = control_constrained

        # Barrier functions (relative degree 1 and 2) and the alpha functions used in their constraints
        self.h_1 = h_1
        self.h_2 = h_2
        self.alpha = alpha
        self.alpha_2 = alpha_2

        # Dynamics, used as z_dot = f(z) + g(z) u in the Lie derivatives below
        self.f = f
        self.g = g

        # User-configurable QP objective terms (consumed by P_qp / q_qp)
        self.P_config = P
        self.q_config = q

        # Solver configuration: the elastic qpax solver is used when the CBF is relaxed
        self.relax_cbf = relax_cbf
        self.cbf_relaxation_penalty = cbf_relaxation_penalty
        self.solver_tol = solver_tol
        solver = qpax.solve_qp_elastic if relax_cbf else qpax.solve_qp
        self.qp_solver: Callable = jax.jit(solver)

    @classmethod
    def from_config(cls, config: CBFConfig) -> "CBF":
        """Build a CBF from a configuration object and run a one-time sanity check on it.

        Args:
            config (CBFConfig): Problem description (dynamics, barrier functions, objective, solver settings).

        Returns:
            CBF: The constructed Control Barrier Function instance
        """
        # Arguments are passed positionally, in the exact order of __init__'s parameters
        instance = cls(
            config.n,
            config.m,
            config.num_cbf,
            config.u_min,
            config.u_max,
            config.control_constrained,
            config.relax_cbf,
            config.cbf_relaxation_penalty,
            config.h_1,
            config.h_2,
            config.f,
            config.g,
            config.alpha,
            config.alpha_2,
            config.P,
            config.q,
            config.solver_tol,
        )
        instance._validate_instance(*config.init_args)
        return instance

    def _validate_instance(self, *h_args) -> None:
        """Evaluate Lgh once at the all-ones state and print a warning if something looks wrong.

        A warning is printed when Lgh is (numerically) zero everywhere, or when Lgh cannot be evaluated
        because the barrier function raised a TypeError (e.g. extra arguments were not supplied).

        Args:
            *h_args: Optional additional arguments for the barrier function (seeded from the config's init_args).
        """
        try:
            # TODO: Decide if this should be checked on a row-by-row basis or via the full matrix
            lgh_at_ones = self.Lgh(jnp.ones(self.n), *h_args)
            lgh_vanishes = jnp.allclose(lgh_at_ones, 0)
            if lgh_vanishes:
                print_warning(
                    "Lgh is zero. Consider increasing the relative degree or modifying the barrier function."
                )
        except TypeError:
            print_warning(
                "Cannot test Lgh; missing additional arguments.\n"
                + "Please provide an initial seed for these args in the config's init_args input"
            )

    @conditional_jit(not DEBUG_SAFETY_FILTER)
    def safety_filter(self, z: Array, u_des: Array, *h_args) -> Array:
        """Filter a nominal control input through the CBF quadratic program.

        Args:
            z (Array): State, shape (n,)
            u_des (Array): Desired (nominal) control input, shape (m,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: Safe control input (first m entries of the QP solution), shape (m,)
        """
        P, q, A, b, G, h = self.qp_data(z, u_des, *h_args)

        if self.relax_cbf:
            # Elastic solve: takes no equality constraints; the inequality constraints are softened with
            # weight `cbf_relaxation_penalty` (see the qpax documentation for the exact formulation)
            solution = self.qp_solver(
                P, q, G, h, self.cbf_relaxation_penalty, solver_tol=self.solver_tol
            )
            x_qp, _t, _s1, _s2, _z1, _z2, converged, iters = solution
        else:
            solution = self.qp_solver(P, q, A, b, G, h, solver_tol=self.solver_tol)
            x_qp, _s, _z, _y, converged, iters = solution

        if DEBUG_SAFETY_FILTER:
            # `converged` must be a concrete value here, so any enclosing jit must also be disabled
            status = "Converged" if converged else "Did not converge"
            print(f"{status}. Iterations: {iters}")

        return x_qp[: self.m]

    def h(self, z: ArrayLike, *h_args) -> Array:
        """Evaluate every barrier function in relative-degree-1 form.

        Relative-degree-1 barriers `h_1` are used as-is. Relative-degree-2 barriers `h_2` are lowered
        to relative degree 1 as `Lfh_2(z) + alpha_2(h_2(z))`, with the derivative taken along `f` only.

        Args:
            z (ArrayLike): State, shape (n,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: Concatenation [h_1(z), lowered h_2(z)], shape (num_barr,)
        """
        drift = self.f(z)
        h2_value, h2_rate = jax.jvp(lambda state: self.h_2(state, *h_args), (z,), (drift,))
        h2_lowered = h2_rate + self.alpha_2(h2_value)
        return jnp.concatenate([self.h_1(z, *h_args), h2_lowered])

    def h_and_Lfh(  # pylint: disable=invalid-name
        self, z: ArrayLike, *h_args
    ) -> Tuple[Array, Array]:
        """Evaluate the barrier function(s) together with their Lie derivative along `f(z)`.

        Args:
            z (ArrayLike): State, shape (n,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            h (Array): Barrier function evaluation, shape (num_barr,)
            Lfh (Array): Lie derivative of `h` w.r.t. `f`, shape (num_barr,)
        """
        # One forward-mode pass, equivalent to `jax.jacobian(self.h)(z) @ self.f(z)`;
        # the primal output of the JVP is h(z) itself, obtained at no extra cost.
        return jax.jvp(lambda state: self.h(state, *h_args), (z,), (self.f(z),))

    def Lgh(self, z: ArrayLike, *h_args) -> Array:  # pylint: disable=invalid-name
        """Lie derivative of the barrier function(s) along the control directions `g(z)`.

        Args:
            z (ArrayLike): State, shape (n,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: Lgh, shape (num_barr, m)
        """

        # Equivalent to `jax.jacobian(self.h)(z) @ self.g(z)`, computed as one JVP per column of g(z)
        def directional_derivative(direction):
            _, tangent = jax.jvp(lambda state: self.h(state, *h_args), (z,), (direction,))
            return tangent

        # Map over the columns of g(z) and stack the resulting derivatives as columns
        return jax.vmap(directional_derivative, in_axes=1, out_axes=1)(self.g(z))

    ## QP Matrices ##

    def P_qp(  # pylint: disable=invalid-name
        self, z: Array, u_des: Array, *h_args
    ) -> Array:
        """Quadratic objective term P in `minimize 0.5 * x^T P x + q^T x`.

        Delegates to the config's `P` function. The config's default is 2 * I (the standard min-norm
        CBF objective) according to this package's comments; confirm in CBFConfig.

        Args:
            z (Array): State, shape (n,)
            u_des (Array): Desired control input, shape (m,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: P matrix, shape (m, m)
        """
        return self.P_config(z, u_des, *h_args)

    def q_qp(self, z: Array, u_des: Array, *h_args) -> Array:
        """Linear objective term q in `minimize 0.5 * x^T P x + q^T x`.

        Delegates to the config's `q` function. The config's default is -2 * u_des (the standard min-norm
        CBF objective) according to this package's comments; confirm in CBFConfig.

        Args:
            z (Array): State, shape (n,)
            u_des (Array): Desired control input, shape (m,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: q vector, shape (m,)
        """
        return self.q_config(z, u_des, *h_args)

    def G_qp(  # pylint: disable=invalid-name
        self, z: Array, u_des: Array, *h_args
    ) -> Array:
        """Inequality constraint matrix G in `Gx <= h`.

        Rows are the negated Lgh (one per barrier). With control constraints, they are followed by
        +I (for u <= u_max) and -I (for -u <= -u_min), giving num_barriers + 2*m rows in total.

        Args:
            z (Array): State, shape (n,)
            u_des (Array): Desired control input, shape (m,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: G matrix, shape (num_constraints, m)
        """
        neg_lgh = -self.Lgh(z, *h_args)
        if not self.control_constrained:
            return neg_lgh
        identity = jnp.eye(self.m)
        return jnp.block([[neg_lgh], [identity], [-identity]])

    def h_qp(self, z: Array, u_des: Array, *h_args) -> Array:
        """Inequality constraint upper bound h in `Gx <= h`.

        Entries are alpha(h(z)) + Lfh(z), one per barrier. With control constraints, they are followed by
        u_max and -u_min, giving num_barriers + 2*m entries in total.

        Args:
            z (Array): State, shape (n,)
            u_des (Array): Desired control input, shape (m,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            Array: h vector, shape (num_constraints,)
        """
        barrier_vals, lie_f = self.h_and_Lfh(z, *h_args)
        # Lfh + Lgh u >= -alpha(h)  <=>  -Lgh u <= alpha(h) + Lfh
        cbf_bound = self.alpha(barrier_vals) + lie_f
        if not self.control_constrained:
            return cbf_bound
        # Matches the +I / -I blocks appended to G in G_qp
        return jnp.concatenate(
            [cbf_bound, jnp.asarray(self.u_max), -jnp.asarray(self.u_min)]
        )

    @conditional_jit(not DEBUG_QP_DATA)
    def qp_data(
        self, z: Array, u_des: Array, *h_args
    ) -> Tuple[Array, Array, Array, Array, Array, Array]:
        """Assemble (P, q, A, b, G, h) for the QP at the given state and desired control.

        The problem solved is:

        ```
        minimize 0.5 * x^T P x + q^T x
        subject to A x == b
                   G x <= h
        ```

        Note:
            - CBFs impose no equality constraints, so `A` and `b` are empty.
            - num_constraints == num_barriers without control constraints, and
              num_barriers + 2*m with them.

        Args:
            z (Array): State, shape (n,)
            u_des (Array): Desired control input, shape (m,)
            *h_args: Optional additional arguments for the barrier function.

        Returns:
            P (Array): Quadratic term in the QP objective, shape (m, m)
            q (Array): Linear term in the QP objective, shape (m,)
            A (Array): Equality constraint matrix, shape (0, m)
            b (Array): Equality constraint vector, shape (0,)
            G (Array): Inequality constraint matrix, shape (num_constraints, m)
            h (Array): Upper bound on constraints, shape (num_constraints,)
        """
        P = self.P_qp(z, u_des, *h_args)
        q = self.q_qp(z, u_des, *h_args)
        A = jnp.zeros((0, self.m))  # no equality rows
        b = jnp.zeros(0)
        G = self.G_qp(z, u_des, *h_args)
        h = self.h_qp(z, u_des, *h_args)
        return P, q, A, b, G, h