derivkit 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- derivkit/__init__.py +22 -0
- derivkit/calculus/__init__.py +17 -0
- derivkit/calculus/calculus_core.py +152 -0
- derivkit/calculus/gradient.py +97 -0
- derivkit/calculus/hessian.py +528 -0
- derivkit/calculus/hyper_hessian.py +296 -0
- derivkit/calculus/jacobian.py +156 -0
- derivkit/calculus_kit.py +128 -0
- derivkit/derivative_kit.py +315 -0
- derivkit/derivatives/__init__.py +6 -0
- derivkit/derivatives/adaptive/__init__.py +5 -0
- derivkit/derivatives/adaptive/adaptive_fit.py +238 -0
- derivkit/derivatives/adaptive/batch_eval.py +179 -0
- derivkit/derivatives/adaptive/diagnostics.py +325 -0
- derivkit/derivatives/adaptive/grid.py +333 -0
- derivkit/derivatives/adaptive/polyfit_utils.py +513 -0
- derivkit/derivatives/adaptive/spacing.py +66 -0
- derivkit/derivatives/adaptive/transforms.py +245 -0
- derivkit/derivatives/autodiff/__init__.py +1 -0
- derivkit/derivatives/autodiff/jax_autodiff.py +95 -0
- derivkit/derivatives/autodiff/jax_core.py +217 -0
- derivkit/derivatives/autodiff/jax_utils.py +146 -0
- derivkit/derivatives/finite/__init__.py +5 -0
- derivkit/derivatives/finite/batch_eval.py +91 -0
- derivkit/derivatives/finite/core.py +84 -0
- derivkit/derivatives/finite/extrapolators.py +511 -0
- derivkit/derivatives/finite/finite_difference.py +247 -0
- derivkit/derivatives/finite/stencil.py +206 -0
- derivkit/derivatives/fornberg.py +245 -0
- derivkit/derivatives/local_polynomial_derivative/__init__.py +1 -0
- derivkit/derivatives/local_polynomial_derivative/diagnostics.py +90 -0
- derivkit/derivatives/local_polynomial_derivative/fit.py +199 -0
- derivkit/derivatives/local_polynomial_derivative/local_poly_config.py +95 -0
- derivkit/derivatives/local_polynomial_derivative/local_polynomial_derivative.py +205 -0
- derivkit/derivatives/local_polynomial_derivative/sampling.py +72 -0
- derivkit/derivatives/tabulated_model/__init__.py +1 -0
- derivkit/derivatives/tabulated_model/one_d.py +247 -0
- derivkit/forecast_kit.py +783 -0
- derivkit/forecasting/__init__.py +1 -0
- derivkit/forecasting/dali.py +78 -0
- derivkit/forecasting/expansions.py +486 -0
- derivkit/forecasting/fisher.py +298 -0
- derivkit/forecasting/fisher_gaussian.py +171 -0
- derivkit/forecasting/fisher_xy.py +357 -0
- derivkit/forecasting/forecast_core.py +313 -0
- derivkit/forecasting/getdist_dali_samples.py +429 -0
- derivkit/forecasting/getdist_fisher_samples.py +235 -0
- derivkit/forecasting/laplace.py +259 -0
- derivkit/forecasting/priors_core.py +860 -0
- derivkit/forecasting/sampling_utils.py +388 -0
- derivkit/likelihood_kit.py +114 -0
- derivkit/likelihoods/__init__.py +1 -0
- derivkit/likelihoods/gaussian.py +136 -0
- derivkit/likelihoods/poisson.py +176 -0
- derivkit/utils/__init__.py +13 -0
- derivkit/utils/concurrency.py +213 -0
- derivkit/utils/extrapolation.py +254 -0
- derivkit/utils/linalg.py +513 -0
- derivkit/utils/logger.py +26 -0
- derivkit/utils/numerics.py +262 -0
- derivkit/utils/sandbox.py +74 -0
- derivkit/utils/types.py +15 -0
- derivkit/utils/validate.py +811 -0
- derivkit-1.0.0.dist-info/METADATA +50 -0
- derivkit-1.0.0.dist-info/RECORD +68 -0
- derivkit-1.0.0.dist-info/WHEEL +5 -0
- derivkit-1.0.0.dist-info/licenses/LICENSE +21 -0
- derivkit-1.0.0.dist-info/top_level.txt +1 -0
derivkit/forecast_kit.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
1
|
+
r"""Provides the ForecastKit class.
|
|
2
|
+
|
|
3
|
+
A light wrapper around the core forecasting utilities
|
|
4
|
+
(:func:`derivkit.forecasting.fisher.build_fisher_matrix`,
|
|
5
|
+
:func:`derivkit.forecasting.dali.build_dali`,
|
|
6
|
+
:func:`derivkit.forecasting.fisher.build_delta_nu`,
|
|
7
|
+
and :func:`derivkit.forecasting.fisher.build_fisher_bias`) that exposes a simple
|
|
8
|
+
API for Fisher and DALI tensors.
|
|
9
|
+
|
|
10
|
+
Typical usage example:
|
|
11
|
+
|
|
12
|
+
>>> import numpy as np
|
|
13
|
+
>>> from derivkit.forecast_kit import ForecastKit
|
|
14
|
+
>>>
|
|
15
|
+
>>> # Toy linear model: 2 params -> 2 observables
|
|
16
|
+
>>> def model(theta: np.ndarray) -> np.ndarray:
|
|
17
|
+
... theta = np.asarray(theta, dtype=float)
|
|
18
|
+
... return np.array([theta[0] + 2.0 * theta[1], 3.0 * theta[0] - theta[1]], dtype=float)
|
|
19
|
+
>>>
|
|
20
|
+
>>> theta0 = np.array([0.1, -0.2])
|
|
21
|
+
>>> cov = np.eye(2)
|
|
22
|
+
>>>
|
|
23
|
+
>>> fk = ForecastKit(function=model, theta0=theta0, cov=cov)
|
|
24
|
+
>>> fisher_matrix = fk.fisher(method="finite", n_workers=1)
|
|
25
|
+
>>> fisher_matrix.shape
|
|
26
|
+
(2, 2)
|
|
27
|
+
>>> dali = fk.dali(forecast_order=2, method="finite", n_workers=1)
|
|
28
|
+
>>> F = dali[1][0]
|
|
29
|
+
>>> D1, D2 = dali[2]
|
|
30
|
+
>>>
|
|
31
|
+
>>> data_unbiased = model(theta0)
|
|
32
|
+
>>> data_biased = data_unbiased + np.array([1e-3, -2e-3])
|
|
33
|
+
>>> dn = fk.delta_nu(data_unbiased=data_unbiased, data_biased=data_biased)
|
|
34
|
+
>>> dn.shape
|
|
35
|
+
(2,)
|
|
36
|
+
>>>
|
|
37
|
+
>>> bias_vec, delta_theta = fk.fisher_bias(
|
|
38
|
+
... fisher_matrix=fisher_matrix,
|
|
39
|
+
... delta_nu=dn,
|
|
40
|
+
... method="finite",
|
|
41
|
+
... n_workers=1,
|
|
42
|
+
... )
|
|
43
|
+
>>> bias_vec.shape, delta_theta.shape
|
|
44
|
+
((2,), (2,))
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
from __future__ import annotations
|
|
48
|
+
|
|
49
|
+
from collections.abc import Callable
|
|
50
|
+
from typing import Any, Mapping, Sequence
|
|
51
|
+
|
|
52
|
+
import numpy as np
|
|
53
|
+
|
|
54
|
+
from derivkit.forecasting.dali import build_dali
|
|
55
|
+
from derivkit.forecasting.expansions import (
|
|
56
|
+
build_delta_chi2_dali,
|
|
57
|
+
build_delta_chi2_fisher,
|
|
58
|
+
build_logposterior_dali,
|
|
59
|
+
build_logposterior_fisher,
|
|
60
|
+
)
|
|
61
|
+
from derivkit.forecasting.fisher import (
|
|
62
|
+
build_delta_nu,
|
|
63
|
+
build_fisher_bias,
|
|
64
|
+
build_fisher_matrix,
|
|
65
|
+
)
|
|
66
|
+
from derivkit.forecasting.fisher_gaussian import (
|
|
67
|
+
build_gaussian_fisher_matrix,
|
|
68
|
+
)
|
|
69
|
+
from derivkit.forecasting.getdist_dali_samples import (
|
|
70
|
+
dali_to_getdist_emcee,
|
|
71
|
+
dali_to_getdist_importance,
|
|
72
|
+
)
|
|
73
|
+
from derivkit.forecasting.getdist_fisher_samples import (
|
|
74
|
+
fisher_to_getdist_gaussiannd,
|
|
75
|
+
fisher_to_getdist_samples,
|
|
76
|
+
)
|
|
77
|
+
from derivkit.forecasting.laplace import (
|
|
78
|
+
build_laplace_approximation,
|
|
79
|
+
build_laplace_covariance,
|
|
80
|
+
build_laplace_hessian,
|
|
81
|
+
build_negative_logposterior,
|
|
82
|
+
)
|
|
83
|
+
from derivkit.utils.types import FloatArray
|
|
84
|
+
from derivkit.utils.validate import (
|
|
85
|
+
require_callable,
|
|
86
|
+
resolve_covariance_input,
|
|
87
|
+
validate_covariance_matrix_shape,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
_RESERVED_KWARGS = {"theta0"}
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ForecastKit:
|
|
94
|
+
"""Provides access to forecasting workflows."""
|
|
95
|
+
def __init__(
    self,
    function: Callable[[Sequence[float] | np.ndarray], np.ndarray] | None,
    theta0: Sequence[float] | np.ndarray,
    cov: np.ndarray | Callable[[np.ndarray], np.ndarray],
):
    r"""Stores the model, the fiducial point, and the resolved covariance.

    Args:
        function: Callable mapping a parameter vector to the model mean
            vector :math:`\mu(\theta)`. ``None`` is accepted for workflows
            that only need the covariance; :meth:`fisher`, :meth:`dali`
            and :meth:`fisher_bias` refuse to run without a model.
        theta0: Fiducial parameter values of shape ``(p,)``, where ``p`` is
            the number of parameters. Stored as a float64 array with at
            least one dimension.
        cov: Covariance specification, either

            - a fixed matrix :math:`C(\theta_0)` of shape ``(n_obs, n_obs)``
              (``n_obs`` = number of observables), or
            - a callable ``cov_fn(theta)`` returning :math:`C(\theta)` with
              that shape; it is evaluated once at ``theta0`` during
              construction and the result kept as ``self.cov0``.
    """
    self.function = function

    # Scalars become length-1 vectors so downstream code can index uniformly.
    fiducial = np.asarray(theta0, dtype=np.float64)
    self.theta0 = np.atleast_1d(fiducial)

    # cov0: matrix at theta0. cov_fn: the callable form — presumably None
    # when a fixed matrix was passed (gaussian_fisher checks for None).
    self.cov0, self.cov_fn = resolve_covariance_input(
        cov,
        theta0=self.theta0,
        validate=validate_covariance_matrix_shape,
    )

    # Number of observables is read off the leading dimension of cov0.
    self.n_observables = int(self.cov0.shape[0])
|
|
132
|
+
|
|
133
|
+
def fisher(
    self,
    *,
    method: str | None = None,
    n_workers: int = 1,
    **dk_kwargs: Any,
) -> np.ndarray:
    """Builds the Fisher information matrix at the stored fiducial point.

    The stored model, ``self.theta0`` and the fiducial covariance
    ``self.cov0`` are forwarded to
    :func:`derivkit.forecasting.fisher.build_fisher_matrix`.

    Args:
        method: Derivative method name or alias (for example ``"adaptive"``
            or ``"finite"``). ``None`` selects the
            :class:`derivkit.derivative_kit.DerivativeKit` default.
        n_workers: Number of workers for per-parameter parallelisation;
            ``1`` (the default) runs serially.
        **dk_kwargs: Extra options handed on to
            :meth:`derivkit.derivative_kit.DerivativeKit.differentiate`.

    Returns:
        Fisher matrix of shape ``(n_parameters, n_parameters)``.

    Note:
        The stored model is validated with ``require_callable`` first, so a
        kit constructed with ``function=None`` fails here.
    """
    model = require_callable(self.function, context="ForecastKit.fisher")
    return build_fisher_matrix(
        function=model,
        theta0=self.theta0,
        cov=self.cov0,
        method=method,
        n_workers=n_workers,
        **dk_kwargs,
    )
|
|
165
|
+
|
|
166
|
+
def fisher_bias(
    self,
    *,
    fisher_matrix: np.ndarray,
    delta_nu: np.ndarray,
    method: str | None = None,
    n_workers: int = 1,
    rcond: float = 1e-12,
    **dk_kwargs: Any,
) -> tuple[np.ndarray, np.ndarray]:
    r"""Maps a data-vector shift into a parameter-space bias.

    Combines the stored model, ``self.theta0`` and ``self.cov0`` with a
    caller-supplied Fisher matrix and data difference ``delta_nu``, and
    delegates the computation to
    :func:`derivkit.forecasting.fisher.build_fisher_bias`. This answers the
    classic "Fisher bias" question: how far would the inferred parameters
    move if a systematic changed the data by ``delta_nu``? See
    https://arxiv.org/abs/0710.5171 for the formalism.

    Args:
        fisher_matrix: Square Fisher matrix of shape ``(p, p)``, where ``p``
            is the number of parameters.
        delta_nu: Data-vector difference, for example
            :math:`\Delta\nu = \nu_{\mathrm{biased}} - \nu_{\mathrm{unbiased}}`.
            Either a 1D array of length ``n`` (number of observables) or a 2D
            array that is flattened in row-major ("C") order to length ``n``.
            A 1D input must already follow that row-major convention.
        method: Derivative method name or alias (e.g. ``"adaptive"``,
            ``"finite"``). ``None`` uses the
            :class:`derivkit.derivative_kit.DerivativeKit` default.
        n_workers: Number of workers for the derivative routine that forms
            the Jacobian.
        rcond: Regularisation cutoff for the pseudoinverse (default
            ``1e-12``).
        **dk_kwargs: Extra options passed on to
            :meth:`derivkit.derivative_kit.DerivativeKit.differentiate`.

    Returns:
        Tuple ``(bias_vec, delta_theta)`` of 1D arrays of length ``p``: the
        parameter-space bias vector and the corresponding parameter shifts.
    """
    model = require_callable(self.function, context="ForecastKit.fisher_bias")
    return build_fisher_bias(
        function=model,
        theta0=self.theta0,
        cov=self.cov0,
        fisher_matrix=fisher_matrix,
        delta_nu=delta_nu,
        method=method,
        n_workers=n_workers,
        rcond=rcond,
        **dk_kwargs,
    )
|
|
228
|
+
|
|
229
|
+
def delta_nu(
    self,
    data_unbiased: np.ndarray,
    data_biased: np.ndarray,
) -> np.ndarray:
    """Returns the flattened difference ``data_biased - data_unbiased``.

    Delegates to :func:`derivkit.forecasting.fisher.build_delta_nu` together
    with the stored fiducial covariance ``self.cov0``, so the result has one
    entry per observable implied by that covariance. A typical use is
    comparing a data vector that contains a systematic with a clean one
    before calling :meth:`fisher_bias`.

    Both inputs may be 1D or 2D (for example, correlation-by-ell). 2D inputs
    are flattened in NumPy row-major ("C") order, the convention used
    throughout DerivKit; 1D inputs must already be ordered that way.

    Args:
        data_unbiased: Reference data vector without the systematic.
        data_biased: Data vector that includes the systematic effect.

    Returns:
        1D array of length ``n_observables`` holding the element-wise
        difference (input with systematic minus input without).
    """
    return build_delta_nu(
        data_unbiased=data_unbiased,
        data_biased=data_biased,
        cov=self.cov0,
    )
|
|
266
|
+
|
|
267
|
+
def dali(
    self,
    *,
    method: str | None = None,
    forecast_order: int = 2,
    n_workers: int = 1,
    **dk_kwargs: Any,
) -> dict[int, tuple[FloatArray, ...]]:
    """Computes the DALI tensors up to ``forecast_order`` at the stored fiducial.

    The stored model, ``self.theta0`` and ``self.cov0`` are handed to
    :func:`derivkit.forecasting.dali.build_dali`.

    Args:
        method: Derivative method name or alias (e.g. ``"adaptive"``,
            ``"finite"``). ``None`` falls back to the
            :class:`derivkit.derivative_kit.DerivativeKit` default.
        forecast_order: Highest forecast order to build. Supported values
            and their meaning are listed in
            :data:`derivkit.forecasting.forecast_core.SUPPORTED_FORECAST_ORDERS`.
        n_workers: Workers for per-parameter parallelisation/threads;
            ``1`` (default) is serial. Inner batch evaluation stays serial
            to avoid oversubscription.
        **dk_kwargs: Extra options forwarded to
            :class:`derivkit.calculus_kit.CalculusKit`.

    Returns:
        Dictionary keyed by order ``1..forecast_order``. Each value is the
        tuple of tensors first introduced at that order:

        - order 1: ``(F_{(1,1)},)`` (the Fisher matrix)
        - order 2: ``(D_{(2,1)}, D_{(2,2)})``
        - order 3: ``(T_{(3,1)}, T_{(3,2)}, T_{(3,3)})``

        ``D_{(k,l)}`` and ``T_{(k,l)}`` are contractions of the ``k``-th and
        ``l``-th order derivatives through the inverse covariance. Every
        axis has length ``p = len(self.theta0)``, and the tensors added at
        order ``k`` have parameter-axis ranks ``k+1`` through ``2*k``.
    """
    model = require_callable(self.function, context="ForecastKit.dali")
    return build_dali(
        function=model,
        theta0=self.theta0,
        cov=self.cov0,
        method=method,
        forecast_order=forecast_order,
        n_workers=n_workers,
        **dk_kwargs,
    )
|
|
319
|
+
|
|
320
|
+
def gaussian_fisher(
    self,
    *,
    method: str | None = None,
    n_workers: int = 1,
    rcond: float = 1e-12,
    symmetrize_dcov: bool = True,
    **dk_kwargs: Any,
) -> np.ndarray:
    r"""Computes the Gaussian Fisher matrix with parameter-dependent mean and/or covariance.

    Delegates to
    :func:`derivkit.forecasting.fisher_gaussian.build_gaussian_fisher_matrix`.
    If a covariance callable was supplied at construction (``self.cov_fn``),
    that callable is passed on; otherwise the fixed matrix ``self.cov0`` is
    used. The stored model ``self.function`` is passed through unchanged and
    may be ``None``.

    Args:
        method: Derivative method name or alias (e.g. ``"adaptive"``, ``"finite"``).
        n_workers: Number of workers for per-parameter parallelisation.
        rcond: Regularisation cutoff for the pseudoinverse fallback in
            linear solves.
        symmetrize_dcov: If ``True``, each covariance derivative is replaced
            by :math:`\tfrac{1}{2}(C_{,i} + C_{,i}^{\mathsf{T}})`.
        **dk_kwargs: Forwarded to the internal derivative calls.

    Returns:
        Fisher matrix of shape ``(p, p)``.
    """
    # Hand over the parameter-dependent covariance when one exists, so the
    # builder can differentiate it; otherwise fall back to the fixed matrix.
    if self.cov_fn is None:
        cov_spec = self.cov0
    else:
        cov_spec = self.cov_fn

    return build_gaussian_fisher_matrix(
        theta0=self.theta0,
        cov=cov_spec,
        function=self.function,
        method=method,
        n_workers=n_workers,
        rcond=rcond,
        symmetrize_dcov=symmetrize_dcov,
        **dk_kwargs,
    )
|
|
358
|
+
|
|
359
|
+
def delta_chi2_fisher(
    self,
    *,
    theta: np.ndarray,
    fisher: np.ndarray,
) -> float:
    """Evaluates the Fisher quadratic form for a displacement from the fiducial.

    With ``d = theta - theta0``, where ``theta0`` is the stored
    :attr:`ForecastKit.theta0`, the value is

    ``delta_chi2 = d^T @ F @ d``

    and is computed by
    :func:`derivkit.forecasting.expansions.build_delta_chi2_fisher`.

    Args:
        theta: Parameter-space evaluation point of shape ``(p,)``.
        fisher: Fisher matrix of shape ``(p, p)``.

    Returns:
        Scalar delta chi-squared.
    """
    return build_delta_chi2_fisher(
        theta=theta,
        theta0=self.theta0,
        fisher=fisher,
    )
|
|
381
|
+
|
|
382
|
+
def delta_chi2_dali(
    self,
    *,
    theta: np.ndarray,
    dali: dict[int, tuple[np.ndarray, ...]],
    forecast_order: int | None = 2,
) -> float:
    """Evaluates delta chi-squared for a displacement using DALI tensors.

    The displacement is ``d = theta - theta0`` with ``theta0`` taken from
    :attr:`ForecastKit.theta0`. The contraction with the supplied tensors is
    performed by :func:`derivkit.forecasting.expansions.build_delta_chi2_dali`.

    Args:
        theta: Parameter-space evaluation point of shape ``(p,)``.
        dali: DALI tensors in the format returned by :meth:`ForecastKit.dali`.
        forecast_order: Forecast order used for the DALI contractions.

    Returns:
        Scalar delta chi-squared.
    """
    fiducial = self.theta0
    return build_delta_chi2_dali(
        theta=theta,
        theta0=fiducial,
        dali=dali,
        forecast_order=forecast_order,
    )
|
|
411
|
+
|
|
412
|
+
def logposterior_fisher(
    self,
    *,
    theta: np.ndarray,
    fisher: np.ndarray,
    prior_terms: Sequence[tuple[str, dict[str, Any]] | dict[str, Any]] | None = None,
    prior_bounds: Sequence[tuple[float | None, float | None]] | None = None,
    logprior: Callable[[np.ndarray], float] | None = None,
) -> float:
    """Evaluates the Fisher-approximation log posterior at ``theta``.

    The result is only defined up to an additive constant in log space,
    which corresponds to an overall multiplicative normalisation of the
    posterior density.

    Without any prior arguments, this is the Fisher log-likelihood expansion
    with a flat prior and no hard cutoffs. A prior can be given either as a
    ready-made ``logprior(theta)`` callable or as a lightweight
    specification through ``prior_terms`` and/or ``prior_bounds``. The
    expansion point is the stored ``self.theta0``; the evaluation is done by
    :func:`derivkit.forecasting.expansions.build_logposterior_fisher`.

    Args:
        theta: Parameter-space evaluation point of shape ``(p,)``.
        fisher: Fisher matrix of shape ``(p, p)``.
        prior_terms: Prior term specification for the underlying prior
            builder. Only meaningful when ``logprior`` is not given.
        prior_bounds: Global hard bounds for the underlying prior builder.
            Only meaningful when ``logprior`` is not given.
        logprior: Optional custom log-prior callable. A non-finite return
            value means zero posterior at that point, and ``-np.inf`` is
            returned. Not combinable with ``prior_terms`` or
            ``prior_bounds``.

    Returns:
        Scalar log posterior, defined up to an additive constant.
    """
    prior_options = {
        "prior_terms": prior_terms,
        "prior_bounds": prior_bounds,
        "logprior": logprior,
    }
    return build_logposterior_fisher(
        theta=theta,
        theta0=self.theta0,
        fisher=fisher,
        **prior_options,
    )
|
|
457
|
+
|
|
458
|
+
def logposterior_dali(
    self,
    *,
    theta: np.ndarray,
    dali: dict[int, tuple[np.ndarray, ...]],
    forecast_order: int | None = 2,
    prior_terms: Sequence[tuple[str, dict[str, Any]] | dict[str, Any]] | None = None,
    prior_bounds: Sequence[tuple[float | None, float | None]] | None = None,
    logprior: Callable[[np.ndarray], float] | None = None,
) -> float:
    """Evaluates the DALI-approximation log posterior (up to a constant) at ``theta``.

    Without any prior arguments, this is the DALI log-likelihood expansion
    with a flat prior and no hard cutoffs. A prior can be given either as a
    ready-made ``logprior(theta)`` callable or as a lightweight
    specification through ``prior_terms`` and/or ``prior_bounds``. The
    expansion point is the stored ``self.theta0``; the evaluation is done by
    :func:`derivkit.forecasting.expansions.build_logposterior_dali`.

    Args:
        theta: Parameter-space evaluation point of shape ``(p,)``.
        dali: DALI tensors in the format returned by :meth:`ForecastKit.dali`.
        forecast_order: Forecast order used for the DALI contractions.
        prior_terms: Prior term specification for the underlying prior
            builder. Only meaningful when ``logprior`` is not given.
        prior_bounds: Global hard bounds for the underlying prior builder.
            Only meaningful when ``logprior`` is not given.
        logprior: Optional custom log-prior callable. A non-finite return
            value means zero posterior at that point, and ``-np.inf`` is
            returned. Not combinable with ``prior_terms`` or
            ``prior_bounds``.

    Returns:
        Scalar log posterior, defined up to an additive constant.
    """
    prior_options = {
        "prior_terms": prior_terms,
        "prior_bounds": prior_bounds,
        "logprior": logprior,
    }
    return build_logposterior_dali(
        theta=theta,
        theta0=self.theta0,
        dali=dali,
        forecast_order=forecast_order,
        **prior_options,
    )
|
|
502
|
+
|
|
503
|
+
def negative_logposterior(
    self,
    theta: Sequence[float] | np.ndarray,
    *,
    logposterior: Callable[[np.ndarray], float],
) -> float:
    """Evaluates ``-logposterior(theta)`` as a minimisation objective.

    This is the quantity minimised in MAP estimation and differentiated by
    the curvature-based (Laplace) methods. The work, including the check
    that the result is finite, is delegated to
    :func:`derivkit.forecasting.laplace.build_negative_logposterior`.

    Args:
        theta: One-dimensional array-like parameter vector.
        logposterior: Callable taking a 1D float64 array and returning a
            scalar float.

    Returns:
        Negative log-posterior value as a float.
    """
    objective = build_negative_logposterior(theta, logposterior=logposterior)
    return objective
|
|
523
|
+
|
|
524
|
+
def laplace_hessian(
    self,
    *,
    neg_logposterior: Callable[[np.ndarray], float],
    theta_map: Sequence[float] | np.ndarray | None = None,
    method: str | None = None,
    n_workers: int = 1,
    **dk_kwargs: Any,
) -> np.ndarray:
    """Computes the Hessian of the negative log-posterior at the expansion point.

    This Hessian describes the local curvature of the posterior peak. In the
    Laplace approximation it acts as a local precision matrix, and its
    inverse gives a quick Gaussian estimate of parameter uncertainties and
    correlations. The computation is delegated to
    :func:`derivkit.forecasting.laplace.build_laplace_hessian`.

    Args:
        neg_logposterior: Callable returning the scalar negative
            log-posterior.
        theta_map: Point at which the curvature is evaluated, typically the
            MAP. When ``None``, the stored ``self.theta0`` is used.
        method: Derivative method name/alias forwarded to the calculus
            machinery.
        n_workers: Outer parallelism forwarded to the Hessian construction.
        **dk_kwargs: Extra options forwarded to
            :meth:`derivkit.derivative_kit.DerivativeKit.differentiate`.

    Returns:
        Symmetric 2D array of shape ``(p, p)``: the Hessian of
        ``neg_logposterior`` at the chosen point.
    """
    expansion_point = theta_map
    if expansion_point is None:
        expansion_point = self.theta0

    return build_laplace_hessian(
        neg_logposterior=neg_logposterior,
        theta_map=expansion_point,
        method=method,
        n_workers=n_workers,
        **dk_kwargs,
    )
|
|
563
|
+
|
|
564
|
+
def laplace_covariance(
    self,
    hessian: np.ndarray,
    *,
    rcond: float = 1e-12,
) -> np.ndarray:
    """Turns a negative log-posterior Hessian into a Laplace covariance.

    Under the Laplace (Gaussian) approximation, the Hessian at the expansion
    point plays the role of a local precision matrix, so the approximate
    posterior covariance is its inverse. The inversion, including the
    pseudoinverse fallback, is delegated to
    :func:`derivkit.forecasting.laplace.build_laplace_covariance`.

    Args:
        hessian: Square 2D Hessian matrix.
        rcond: Cutoff for small singular values in the pseudoinverse
            fallback.

    Returns:
        Symmetric 2D covariance matrix with the same shape as ``hessian``.
    """
    covariance = build_laplace_covariance(hessian, rcond=rcond)
    return covariance
|
|
584
|
+
|
|
585
|
+
def laplace_approximation(
    self,
    *,
    neg_logposterior: Callable[[np.ndarray], float],
    theta_map: Sequence[float] | np.ndarray | None = None,
    method: str | None = None,
    n_workers: int = 1,
    ensure_spd: bool = True,
    rcond: float = 1e-12,
    **dk_kwargs: Any,
) -> dict[str, Any]:
    """Builds a Laplace (Gaussian) approximation of the posterior.

    Near its peak, the posterior is replaced by a Gaussian. The Hessian of
    the negative log-posterior at the expansion point acts as a local
    precision matrix, and its inverse is the approximate covariance. The
    work is delegated to
    :func:`derivkit.forecasting.laplace.build_laplace_approximation`.

    Args:
        neg_logposterior: Callable taking a 1D float64 parameter vector and
            returning the scalar negative log-posterior.
        theta_map: Expansion point, often the maximum a posteriori (MAP)
            estimate. When ``None``, the stored ``self.theta0`` is used.
        method: Derivative method name/alias forwarded to the Hessian
            builder.
        n_workers: Outer parallelism forwarded to the Hessian construction.
        ensure_spd: If ``True``, try to make the Hessian symmetric positive
            definite (SPD) by adding diagonal jitter.
        rcond: Cutoff for small singular values in the pseudoinverse
            fallback used when computing the covariance.
        **dk_kwargs: Extra options forwarded to
            :meth:`derivkit.derivative_kit.DerivativeKit.differentiate`.

    Returns:
        Dictionary of Laplace outputs (theta_map, neg_logposterior_at_map,
        hessian, cov, and jitter).
    """
    expansion_point = theta_map
    if expansion_point is None:
        expansion_point = self.theta0

    return build_laplace_approximation(
        neg_logposterior=neg_logposterior,
        theta_map=expansion_point,
        method=method,
        n_workers=n_workers,
        ensure_spd=ensure_spd,
        rcond=rcond,
        **dk_kwargs,
    )
|
|
633
|
+
|
|
634
|
+
def getdist_fisher_gaussian(
    self,
    *,
    fisher: np.ndarray,
    names: Sequence[str] | None = None,
    labels: Sequence[str] | None = None,
    **kwargs: Any,
):
    """Builds a GetDist :class:`getdist.gaussian_mixtures.GaussianND` from a Fisher matrix.

    The Gaussian is always centred on this object's stored expansion point
    ``self.theta0``. All other work is delegated to
    :func:`derivkit.forecasting.getdist_fisher_samples.fisher_to_getdist_gaussiannd`.

    Args:
        fisher: Fisher matrix of shape ``(p, p)``, evaluated at ``self.theta0``.
        names: Optional parameter names, one per parameter (length ``p``).
        labels: Optional parameter labels, one per parameter (length ``p``).
        **kwargs: Extra options passed through unchanged to
            :func:`derivkit.forecasting.getdist_fisher_samples.fisher_to_getdist_gaussiannd`
            (e.g. ``label``, ``rcond``).

    Returns:
        A :class:`getdist.gaussian_mixtures.GaussianND` whose mean is
        ``self.theta0`` and whose covariance is the (pseudo-)inverse of ``fisher``.
    """
    # The mean and the Fisher matrix are passed positionally, as in the underlying helper's call convention.
    mean = self.theta0
    return fisher_to_getdist_gaussiannd(
        mean,
        fisher,
        names=names,
        labels=labels,
        **kwargs,
    )
|
|
667
|
+
|
|
668
|
+
def getdist_fisher_samples(
    self,
    *,
    fisher: np.ndarray,
    names: Sequence[str],
    labels: Sequence[str],
    **kwargs: Any,
):
    """Samples the Fisher Gaussian centred on ``self.theta0`` into GetDist :class:`getdist.MCSamples`.

    The sampling centre is pinned to this object's stored expansion point. All
    other work is delegated to
    :func:`derivkit.forecasting.getdist_fisher_samples.fisher_to_getdist_samples`.

    Args:
        fisher: Fisher matrix of shape ``(p, p)``, evaluated at ``self.theta0``.
        names: GetDist parameter names, one per parameter (length ``p``).
        labels: GetDist parameter labels, one per parameter (length ``p``).
        **kwargs: Extra options passed through unchanged to
            :func:`derivkit.forecasting.getdist_fisher_samples.fisher_to_getdist_samples`
            (e.g. ``n_samples``, ``seed``, ``kernel_scale``, ``prior_terms``,
            ``prior_bounds``, ``logprior``, ``hard_bounds``, ``store_loglikes``, ``label``).

    Returns:
        A :class:`getdist.MCSamples` object holding draws from the Fisher Gaussian.
    """
    center = self.theta0
    return fisher_to_getdist_samples(
        center,
        fisher,
        names=names,
        labels=labels,
        **kwargs,
    )
|
|
701
|
+
|
|
702
|
+
def getdist_dali_importance(
    self,
    *,
    dali: dict[int, tuple[np.ndarray, ...]],
    names: Sequence[str],
    labels: Sequence[str],
    **kwargs: Any,
):
    """Importance-samples a DALI posterior into GetDist :class:`getdist.MCSamples`.

    The expansion point is pinned to ``self.theta0``. All other work is delegated to
    :func:`derivkit.forecasting.getdist_dali_samples.dali_to_getdist_importance`.

    Args:
        dali: DALI tensors in the form produced by :meth:`ForecastKit.dali`.
        names: GetDist parameter names, one per parameter (length ``p``).
        labels: GetDist parameter labels, one per parameter (length ``p``).
        **kwargs: Extra options for
            :func:`derivkit.forecasting.getdist_dali_samples.dali_to_getdist_importance`
            (e.g. ``n_samples``, ``kernel_scale``, ``seed``,
            ``prior_terms``, ``prior_bounds``, ``logprior``,
            ``sampler_bounds``, ``label``). Any key listed in the module-level
            ``_RESERVED_KWARGS`` is silently discarded before forwarding.

    Returns:
        A :class:`getdist.MCSamples` carrying importance weights.
    """
    # Strip reserved keys first, before the expansion point is read.
    passthrough = _drop_reserved_kwargs(kwargs, reserved=_RESERVED_KWARGS)
    return dali_to_getdist_importance(
        theta0=self.theta0,
        dali=dali,
        names=names,
        labels=labels,
        **passthrough,
    )
|
|
738
|
+
|
|
739
|
+
def getdist_dali_emcee(
    self,
    *,
    dali: dict[int, tuple[np.ndarray, ...]],
    names: Sequence[str],
    labels: Sequence[str],
    **kwargs: Any,
):
    """Runs ``emcee`` on a DALI posterior and wraps the chains as GetDist :class:`getdist.MCSamples`.

    The expansion point is pinned to ``self.theta0``. All other work is delegated to
    :func:`derivkit.forecasting.getdist_dali_samples.dali_to_getdist_emcee`.

    Args:
        dali: DALI tensors in the form produced by :meth:`ForecastKit.dali`.
        names: GetDist parameter names, one per parameter (length ``p``).
        labels: GetDist parameter labels, one per parameter (length ``p``).
        **kwargs: Extra options for
            :func:`derivkit.forecasting.getdist_dali_samples.dali_to_getdist_emcee`
            (e.g. ``n_steps``, ``burn``, ``thin``, ``n_walkers``, ``init_scale``, ``seed``,
            ``prior_terms``, ``prior_bounds``, ``logprior``,
            ``sampler_bounds``, ``label``). Any key listed in the module-level
            ``_RESERVED_KWARGS`` is silently discarded before forwarding.

    Returns:
        A :class:`getdist.MCSamples` holding the MCMC chains.
    """
    # Strip reserved keys first, before the expansion point is read.
    passthrough = _drop_reserved_kwargs(kwargs, reserved=_RESERVED_KWARGS)
    return dali_to_getdist_emcee(
        theta0=self.theta0,
        dali=dali,
        names=names,
        labels=labels,
        **passthrough,
    )
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def _drop_reserved_kwargs(
|
|
778
|
+
kwargs: Mapping[str, Any],
|
|
779
|
+
*,
|
|
780
|
+
reserved: set[str]
|
|
781
|
+
) -> dict[str, Any]:
|
|
782
|
+
"""Removes reserved keyword arguments from a dictionary."""
|
|
783
|
+
return {k: v for k, v in kwargs.items() if k not in reserved}
|