cbfpy 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cbfpy/cbfs/clf_cbf.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """
2
2
  # Control Lyapunov Function / Control Barrier Functions (CLF-CBFs)
3
3
 
4
- Whereas a CBF acts as a safety filter on top of a nominal controller, a CLF-CBF acts as a safe controller itself,
4
+ Whereas a CBF acts as a safety filter on top of a nominal controller, a CLF-CBF acts as a safe controller itself,
5
5
  based on a control objective defined by the CLF and a safety constraint defined by the CBF. Note that the CLF
6
6
  objective should be quadratic and positive-definite to fit in this QP framework.
7
7
 
@@ -43,17 +43,12 @@ import jax
43
43
  import jax.numpy as jnp
44
44
  from jax import Array
45
45
  from jax.typing import ArrayLike
46
+ import numpy as np
46
47
  import qpax
47
48
 
48
49
  from cbfpy.config.clf_cbf_config import CLFCBFConfig
49
- from cbfpy.utils.jax_utils import conditional_jit
50
50
  from cbfpy.utils.general_utils import print_warning
51
51
 
52
- # Debugging flags to disable jit in specific sections of the code.
53
- # Note: If any higher-level jits exist, those must also be set to debug (disable jit)
54
- DEBUG_CONTROLLER = False
55
- DEBUG_QP_DATA = False
56
-
57
52
 
58
53
  @jax.tree_util.register_static
59
54
  class CLFCBF:
@@ -87,9 +82,9 @@ class CLFCBF:
87
82
  u_min: Optional[tuple],
88
83
  u_max: Optional[tuple],
89
84
  control_constrained: bool,
90
- relax_cbf: bool,
91
- cbf_relaxation_penalty: float,
85
+ relax_qp: bool,
92
86
  clf_relaxation_penalty: float,
87
+ constraint_relaxation_penalties: tuple,
93
88
  h_1: Callable[[ArrayLike], Array],
94
89
  h_2: Callable[[ArrayLike], Array],
95
90
  f: Callable[[ArrayLike], Array],
@@ -111,9 +106,9 @@ class CLFCBF:
111
106
  self.u_min = u_min
112
107
  self.u_max = u_max
113
108
  self.control_constrained = control_constrained
114
- self.relax_cbf = relax_cbf
115
- self.cbf_relaxation_penalty = cbf_relaxation_penalty
109
+ self.relax_qp = relax_qp
116
110
  self.clf_relaxation_penalty = clf_relaxation_penalty
111
+ self.constraint_relaxation_penalties = constraint_relaxation_penalties
117
112
  self.h_1 = h_1
118
113
  self.h_2 = h_2
119
114
  self.f = f
@@ -127,10 +122,6 @@ class CLFCBF:
127
122
  self.H = H
128
123
  self.F = F
129
124
  self.solver_tol = solver_tol
130
- if relax_cbf:
131
- self.qp_solver: Callable = jax.jit(qpax.solve_qp_elastic)
132
- else:
133
- self.qp_solver: Callable = jax.jit(qpax.solve_qp)
134
125
 
135
126
  @classmethod
136
127
  def from_config(cls, config: CLFCBFConfig) -> "CLFCBF":
@@ -151,9 +142,9 @@ class CLFCBF:
151
142
  config.u_min,
152
143
  config.u_max,
153
144
  config.control_constrained,
154
- config.relax_cbf,
155
- config.cbf_relaxation_penalty,
145
+ config.relax_qp,
156
146
  config.clf_relaxation_penalty,
147
+ config.constraint_relaxation_penalties,
157
148
  config.h_1,
158
149
  config.h_2,
159
150
  config.f,
@@ -168,19 +159,16 @@ class CLFCBF:
168
159
  config.F,
169
160
  config.solver_tol,
170
161
  )
171
- instance._validate_instance(*config.init_args)
162
+ instance._validate_instance(*config.init_args, **config.init_kwargs)
172
163
  return instance
173
164
 
174
- def _validate_instance(self, *h_args) -> None:
175
- """Checks that the CLF-CBF is valid; warns the user if not
165
+ def _validate_instance(self, *args, **kwargs) -> None:
166
+ """Checks that the CLF-CBF is valid; warns the user if not"""
176
167
 
177
- Args:
178
- *h_args: Optional additional arguments for the barrier function.
179
- """
180
- test_z = jnp.ones(self.n)
168
+ test_z = np.ones(self.n)
181
169
  try:
182
- test_lgh = self.Lgh(test_z, *h_args)
183
- if jnp.allclose(test_lgh, 0):
170
+ test_lgh = self.Lgh(test_z, *args, **kwargs)
171
+ if np.allclose(test_lgh, 0):
184
172
  print_warning(
185
173
  "Lgh is zero. Consider increasing the relative degree or modifying the barrier function."
186
174
  )
@@ -189,37 +177,36 @@ class CLFCBF:
189
177
  "Cannot test Lgh; missing additional arguments.\n"
190
178
  + "Please provide an initial seed for these args in the config's init_args input"
191
179
  )
192
- test_lgv = self.LgV(test_z)
193
- if jnp.allclose(test_lgv, 0):
180
+ test_lgv = self.LgV(test_z, test_z, *args, **kwargs)
181
+ if np.allclose(test_lgv, 0):
194
182
  print_warning(
195
183
  "LgV is zero. Consider increasing the relative degree or modifying the Lyapunov function."
196
184
  )
197
185
 
198
- @conditional_jit(not DEBUG_CONTROLLER)
199
- def controller(self, z: Array, z_des: Array, *h_args) -> Array:
186
+ @jax.jit
187
+ def controller(self, z: Array, z_des: Array, *args, **kwargs) -> Array:
200
188
  """Compute the CLF-CBF optimal control input, optimizing for the CLF objective while
201
189
  satisfying the CBF safety constraint.
202
190
 
203
191
  Args:
204
192
  z (Array): State, shape (n,)
205
193
  z_des (Array): Desired state, shape (n,)
206
- *h_args: Optional additional arguments for the barrier function.
207
194
 
208
195
  Returns:
209
196
  Array: Safe control input, shape (m,)
210
197
  """
211
- P, q, A, b, G, h = self.qp_data(z, z_des, *h_args)
212
- if self.relax_cbf:
213
- x_qp, t_qp, s1_qp, s2_qp, z1_qp, z2_qp, converged, iters = self.qp_solver(
198
+ P, q, A, b, G, h = self.qp_data(z, z_des, *args, **kwargs)
199
+ if self.relax_qp:
200
+ x_qp = qpax.solve_qp_elastic_primal(
214
201
  P,
215
202
  q,
216
203
  G,
217
204
  h,
218
- self.cbf_relaxation_penalty,
205
+ penalty=jnp.asarray(self.constraint_relaxation_penalties),
219
206
  solver_tol=self.solver_tol,
220
207
  )
221
208
  else:
222
- x_qp, s_qp, z_qp, y_qp, converged, iters = self.qp_solver(
209
+ x_qp, s_qp, z_qp, y_qp, converged, iters = qpax.solve_qp(
223
210
  P,
224
211
  q,
225
212
  A,
@@ -228,18 +215,13 @@ class CLFCBF:
228
215
  h,
229
216
  solver_tol=self.solver_tol,
230
217
  )
231
- if DEBUG_CONTROLLER:
232
- print(
233
- f"{'Converged' if converged else 'Did not converge'}. Iterations: {iters}"
234
- )
235
218
  return x_qp[: self.m]
236
219
 
237
- def h(self, z: ArrayLike, *h_args) -> Array:
220
+ def h(self, z: ArrayLike, *args, **kwargs) -> Array:
238
221
  """Barrier function(s)
239
222
 
240
223
  Args:
241
224
  z (ArrayLike): State, shape (n,)
242
- *h_args: Optional additional arguments for the barrier function.
243
225
 
244
226
  Returns:
245
227
  Array: Barrier function evaluation, shape (num_barr,)
@@ -247,16 +229,16 @@ class CLFCBF:
247
229
 
248
230
  # Take any relative-degree-2 barrier functions and convert them to relative-degree-1
249
231
  def _h_2(state):
250
- return self.h_2(state, *h_args)
232
+ return self.h_2(state, *args, **kwargs)
251
233
 
252
- h_2, dh_2_dt = jax.jvp(_h_2, (z,), (self.f(z),))
253
- h_2_as_rd1 = dh_2_dt + self.alpha_2(h_2)
234
+ h_2, dh_2_dt = jax.jvp(_h_2, (z,), (self.f(z, *args, **kwargs),))
235
+ h_2_as_rd1 = dh_2_dt + self.alpha_2(h_2, *args, **kwargs)
254
236
 
255
237
  # Merge the relative-degree-1 and relative-degree-2 barrier functions
256
- return jnp.concatenate([self.h_1(z, *h_args), h_2_as_rd1])
238
+ return jnp.concatenate([self.h_1(z, *args, **kwargs), h_2_as_rd1])
257
239
 
258
240
  def h_and_Lfh( # pylint: disable=invalid-name
259
- self, z: ArrayLike, *h_args
241
+ self, z: ArrayLike, *args, **kwargs
260
242
  ) -> Tuple[Array, Array]:
261
243
  """Lie derivative of the barrier function(s) wrt the autonomous dynamics `f(z)`
262
244
 
@@ -264,7 +246,6 @@ class CLFCBF:
264
246
 
265
247
  Args:
266
248
  z (ArrayLike): State, shape (n,)
267
- *h_args: Optional additional arguments for the barrier function.
268
249
 
269
250
  Returns:
270
251
  h (Array): Barrier function evaluation, shape (num_barr,)
@@ -274,16 +255,17 @@ class CLFCBF:
274
255
  # with the bonus benefit of also evaluating the barrier function
275
256
 
276
257
  def _h(state):
277
- return self.h(state, *h_args)
258
+ return self.h(state, *args, **kwargs)
278
259
 
279
- return jax.jvp(_h, (z,), (self.f(z),))
260
+ return jax.jvp(_h, (z,), (self.f(z, *args, **kwargs),))
280
261
 
281
- def Lgh(self, z: ArrayLike, *h_args) -> Array: # pylint: disable=invalid-name
262
+ def Lgh(
263
+ self, z: ArrayLike, *args, **kwargs
264
+ ) -> Array: # pylint: disable=invalid-name
282
265
  """Lie derivative of the barrier function(s) wrt the control dynamics `g(z)u`
283
266
 
284
267
  Args:
285
268
  z (ArrayLike): State, shape (n,)
286
- *h_args: Optional additional arguments for the barrier function.
287
269
 
288
270
  Returns:
289
271
  Array: Lgh, shape (num_barr, m)
@@ -291,98 +273,112 @@ class CLFCBF:
291
273
  # Note: the below code is just a more efficient way of stating `Lgh = jax.jacobian(self.h)(z) @ self.g(z)`
292
274
 
293
275
  def _h(state):
294
- return self.h(state, *h_args)
276
+ return self.h(state, *args, **kwargs)
295
277
 
296
278
  def _jvp(g_column):
297
279
  return jax.jvp(_h, (z,), (g_column,))[1]
298
280
 
299
- return jax.vmap(_jvp, in_axes=1, out_axes=1)(self.g(z))
281
+ return jax.vmap(_jvp, in_axes=1, out_axes=1)(self.g(z, *args, **kwargs))
300
282
 
301
283
  ## CLF functions ##
302
284
 
303
- def V(self, z: ArrayLike) -> Array:
285
+ def V(self, z: ArrayLike, z_des: ArrayLike, *args, **kwargs) -> Array:
304
286
  """Control Lyapunov Function(s)
305
287
 
306
288
  Args:
307
289
  z (ArrayLike): State, shape (n,)
290
+ z_des (ArrayLike): Desired state, shape (n,)
308
291
 
309
292
  Returns:
310
293
  Array: CLF evaluation, shape (num_clf,)
311
294
  """
295
+
296
+ def _V_2(state):
297
+ return self.V_2(state, z_des, *args, **kwargs)
298
+
312
299
  # Take any relative-degree-2 CLFs and convert them to relative-degree-1
313
300
  # NOTE: If adding args to the CLF, create a wrapper func like with the barrier function
314
- V_2, dV_2_dt = jax.jvp(self.V_2, (z,), (self.f(z),))
315
- V2_rd1 = dV_2_dt + self.gamma_2(V_2)
301
+ V_2, dV_2_dt = jax.jvp(_V_2, (z,), (self.f(z, *args, **kwargs),))
302
+ V2_rd1 = dV_2_dt + self.gamma_2(V_2, *args, **kwargs)
316
303
 
317
304
  # Merge the relative-degree-1 and relative-degree-2 CLFs
318
- return jnp.concatenate([self.V_1(z), V2_rd1])
305
+ return jnp.concatenate([self.V_1(z, z_des, *args, **kwargs), V2_rd1])
319
306
 
320
- def V_and_LfV(self, z: ArrayLike) -> Tuple[Array, Array]:
307
+ def V_and_LfV(
308
+ self, z: ArrayLike, z_des: ArrayLike, *args, **kwargs
309
+ ) -> Tuple[Array, Array]:
321
310
  """Lie derivative of the CLF wrt the autonomous dynamics `f(z)`
322
311
 
323
312
  The evaluation of the CLF is also returned "for free", a byproduct of the jacobian-vector-product
324
313
 
325
314
  Args:
326
315
  z (ArrayLike): State, shape (n,)
316
+ z_des (ArrayLike): Desired state, shape (n,)
327
317
 
328
318
  Returns:
329
319
  V (Array): CLF evaluation, shape (num_clf,)
330
320
  LfV (Array): Lie derivative of `V` w.r.t. `f`, shape (num_clf,)
331
321
  """
332
- return jax.jvp(self.V, (z,), (self.f(z),))
333
322
 
334
- def LgV(self, z: ArrayLike) -> Array:
323
+ def _V(state):
324
+ return self.V(state, z_des, *args, **kwargs)
325
+
326
+ return jax.jvp(_V, (z,), (self.f(z, *args, **kwargs),))
327
+
328
+ def LgV(self, z: ArrayLike, z_des: ArrayLike, *args, **kwargs) -> Array:
335
329
  """Lie derivative of the CLF wrt the control dynamics `g(z)u`
336
330
 
337
331
  Args:
338
332
  z (ArrayLike): State, shape (n,)
333
+ z_des (ArrayLike): Desired state, shape (n,)
339
334
 
340
335
  Returns:
341
336
  Array: LgV, shape (num_clf, m)
342
337
  """
343
338
 
339
+ def _V(state):
340
+ return self.V(state, z_des, *args, **kwargs)
341
+
344
342
  def _jvp(g_column):
345
- return jax.jvp(self.V, (z,), (g_column,))[1]
343
+ return jax.jvp(_V, (z,), (g_column,))[1]
346
344
 
347
- return jax.vmap(_jvp, in_axes=1, out_axes=1)(self.g(z))
345
+ return jax.vmap(_jvp, in_axes=1, out_axes=1)(self.g(z, *args, **kwargs))
348
346
 
349
347
  ## QP Matrices ##
350
348
 
351
349
  def P_qp( # pylint: disable=invalid-name
352
- self, z: Array, z_des: Array, *h_args
350
+ self, z: Array, z_des: Array, *args, **kwargs
353
351
  ) -> Array:
354
352
  """Quadratic term in the QP objective (`minimize 0.5 * x^T P x + q^T x`)
355
353
 
356
354
  Args:
357
355
  z (Array): State, shape (n,)
358
356
  z_des (Array): Desired state, shape (n,)
359
- *h_args: Optional additional arguments for the barrier function.
360
357
 
361
358
  Returns:
362
359
  Array: P matrix, shape (m + 1, m + 1)
363
360
  """
364
361
  return jnp.block(
365
362
  [
366
- [self.H(z), jnp.zeros((self.m, 1))],
363
+ [self.H(z, *args, **kwargs), jnp.zeros((self.m, 1))],
367
364
  [jnp.zeros((1, self.m)), jnp.atleast_1d(self.clf_relaxation_penalty)],
368
365
  ]
369
366
  )
370
367
 
371
- def q_qp(self, z: Array, z_des: Array, *h_args) -> Array:
368
+ def q_qp(self, z: Array, z_des: Array, *args, **kwargs) -> Array:
372
369
  """Linear term in the QP objective (`minimize 0.5 * x^T P x + q^T x`)
373
370
 
374
371
  Args:
375
372
  z (Array): State, shape (n,)
376
373
  z_des (Array): Desired state, shape (n,)
377
- *h_args: Optional additional arguments for the barrier function.
378
374
 
379
375
  Returns:
380
376
  Array: q vector, shape (m + 1,)
381
377
  """
382
- return jnp.concatenate([self.F(z), jnp.array([0.0])])
378
+ return jnp.concatenate([self.F(z, *args, **kwargs), jnp.array([0.0])])
383
379
 
384
380
  def G_qp( # pylint: disable=invalid-name
385
- self, z: Array, z_des: Array, *h_args
381
+ self, z: Array, z_des: Array, *args, **kwargs
386
382
  ) -> Array:
387
383
  """Inequality constraint matrix for the QP (`Gx <= h`)
388
384
 
@@ -394,15 +390,17 @@ class CLFCBF:
394
390
  Args:
395
391
  z (Array): State, shape (n,)
396
392
  z_des (Array): Desired state, shape (n,)
397
- *h_args: Optional additional arguments for the barrier function.
398
393
 
399
394
  Returns:
400
395
  Array: G matrix, shape (num_constraints, m + 1)
401
396
  """
402
397
  G = jnp.block(
403
398
  [
404
- [self.LgV(z), -1.0 * jnp.ones((self.num_clf, 1))],
405
- [-self.Lgh(z, *h_args), jnp.zeros((self.num_cbf, 1))],
399
+ [
400
+ self.LgV(z, z_des, *args, **kwargs),
401
+ -1.0 * jnp.ones((self.num_clf, 1)),
402
+ ],
403
+ [-self.Lgh(z, *args, **kwargs), jnp.zeros((self.num_cbf, 1))],
406
404
  ]
407
405
  )
408
406
  if self.control_constrained:
@@ -416,7 +414,7 @@ class CLFCBF:
416
414
  else:
417
415
  return G
418
416
 
419
- def h_qp(self, z: Array, z_des: Array, *h_args) -> Array:
417
+ def h_qp(self, z: Array, z_des: Array, *args, **kwargs) -> Array:
420
418
  """Upper bound on constraints for the QP (`Gx <= h`)
421
419
 
422
420
  Note:
@@ -427,17 +425,16 @@ class CLFCBF:
427
425
  Args:
428
426
  z (Array): State, shape (n,)
429
427
  z_des (Array): Desired state, shape (n,)
430
- *h_args: Optional additional arguments for the barrier function.
431
428
 
432
429
  Returns:
433
430
  Array: h vector, shape (num_constraints,)
434
431
  """
435
- hz, lfh = self.h_and_Lfh(z, *h_args)
436
- vz, lfv = self.V_and_LfV(z)
432
+ hz, lfh = self.h_and_Lfh(z, *args, **kwargs)
433
+ vz, lfv = self.V_and_LfV(z, z_des, *args, **kwargs)
437
434
  h = jnp.concatenate(
438
435
  [
439
- -lfv - self.gamma(vz),
440
- self.alpha(hz) + lfh,
436
+ -lfv - self.gamma(vz, *args, **kwargs),
437
+ self.alpha(hz, *args, **kwargs) + lfh,
441
438
  ]
442
439
  )
443
440
  if self.control_constrained:
@@ -447,9 +444,8 @@ class CLFCBF:
447
444
  else:
448
445
  return h
449
446
 
450
- @conditional_jit(not DEBUG_QP_DATA)
451
447
  def qp_data(
452
- self, z: Array, z_des: Array, *h_args
448
+ self, z: Array, z_des: Array, *args, **kwargs
453
449
  ) -> Tuple[Array, Array, Array, Array, Array, Array]:
454
450
  """Constructs the QP matrices based on the current state and desired control
455
451
 
@@ -470,7 +466,6 @@ class CLFCBF:
470
466
  Args:
471
467
  z (Array): State, shape (n,)
472
468
  z_des (Array): Desired state, shape (n,)
473
- *h_args: Optional additional arguments for the barrier function.
474
469
 
475
470
  Returns:
476
471
  P (Array): Quadratic term in the QP objective, shape (m + 1, m + 1)
@@ -481,10 +476,10 @@ class CLFCBF:
481
476
  h (Array): Upper bound on constraints, shape (num_constraints,)
482
477
  """
483
478
  return (
484
- self.P_qp(z, z_des, *h_args),
485
- self.q_qp(z, z_des, *h_args),
479
+ self.P_qp(z, z_des, *args, **kwargs),
480
+ self.q_qp(z, z_des, *args, **kwargs),
486
481
  jnp.zeros((0, self.m + 1)), # Equality matrix (not used for CLF-CBF)
487
482
  jnp.zeros(0), # Equality vector (not used for CLF-CBF)
488
- self.G_qp(z, z_des, *h_args),
489
- self.h_qp(z, z_des, *h_args),
483
+ self.G_qp(z, z_des, *args, **kwargs),
484
+ self.h_qp(z, z_des, *args, **kwargs),
490
485
  )