derivkit 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- derivkit/__init__.py +22 -0
- derivkit/calculus/__init__.py +17 -0
- derivkit/calculus/calculus_core.py +152 -0
- derivkit/calculus/gradient.py +97 -0
- derivkit/calculus/hessian.py +528 -0
- derivkit/calculus/hyper_hessian.py +296 -0
- derivkit/calculus/jacobian.py +156 -0
- derivkit/calculus_kit.py +128 -0
- derivkit/derivative_kit.py +315 -0
- derivkit/derivatives/__init__.py +6 -0
- derivkit/derivatives/adaptive/__init__.py +5 -0
- derivkit/derivatives/adaptive/adaptive_fit.py +238 -0
- derivkit/derivatives/adaptive/batch_eval.py +179 -0
- derivkit/derivatives/adaptive/diagnostics.py +325 -0
- derivkit/derivatives/adaptive/grid.py +333 -0
- derivkit/derivatives/adaptive/polyfit_utils.py +513 -0
- derivkit/derivatives/adaptive/spacing.py +66 -0
- derivkit/derivatives/adaptive/transforms.py +245 -0
- derivkit/derivatives/autodiff/__init__.py +1 -0
- derivkit/derivatives/autodiff/jax_autodiff.py +95 -0
- derivkit/derivatives/autodiff/jax_core.py +217 -0
- derivkit/derivatives/autodiff/jax_utils.py +146 -0
- derivkit/derivatives/finite/__init__.py +5 -0
- derivkit/derivatives/finite/batch_eval.py +91 -0
- derivkit/derivatives/finite/core.py +84 -0
- derivkit/derivatives/finite/extrapolators.py +511 -0
- derivkit/derivatives/finite/finite_difference.py +247 -0
- derivkit/derivatives/finite/stencil.py +206 -0
- derivkit/derivatives/fornberg.py +245 -0
- derivkit/derivatives/local_polynomial_derivative/__init__.py +1 -0
- derivkit/derivatives/local_polynomial_derivative/diagnostics.py +90 -0
- derivkit/derivatives/local_polynomial_derivative/fit.py +199 -0
- derivkit/derivatives/local_polynomial_derivative/local_poly_config.py +95 -0
- derivkit/derivatives/local_polynomial_derivative/local_polynomial_derivative.py +205 -0
- derivkit/derivatives/local_polynomial_derivative/sampling.py +72 -0
- derivkit/derivatives/tabulated_model/__init__.py +1 -0
- derivkit/derivatives/tabulated_model/one_d.py +247 -0
- derivkit/forecast_kit.py +783 -0
- derivkit/forecasting/__init__.py +1 -0
- derivkit/forecasting/dali.py +78 -0
- derivkit/forecasting/expansions.py +486 -0
- derivkit/forecasting/fisher.py +298 -0
- derivkit/forecasting/fisher_gaussian.py +171 -0
- derivkit/forecasting/fisher_xy.py +357 -0
- derivkit/forecasting/forecast_core.py +313 -0
- derivkit/forecasting/getdist_dali_samples.py +429 -0
- derivkit/forecasting/getdist_fisher_samples.py +235 -0
- derivkit/forecasting/laplace.py +259 -0
- derivkit/forecasting/priors_core.py +860 -0
- derivkit/forecasting/sampling_utils.py +388 -0
- derivkit/likelihood_kit.py +114 -0
- derivkit/likelihoods/__init__.py +1 -0
- derivkit/likelihoods/gaussian.py +136 -0
- derivkit/likelihoods/poisson.py +176 -0
- derivkit/utils/__init__.py +13 -0
- derivkit/utils/concurrency.py +213 -0
- derivkit/utils/extrapolation.py +254 -0
- derivkit/utils/linalg.py +513 -0
- derivkit/utils/logger.py +26 -0
- derivkit/utils/numerics.py +262 -0
- derivkit/utils/sandbox.py +74 -0
- derivkit/utils/types.py +15 -0
- derivkit/utils/validate.py +811 -0
- derivkit-1.0.0.dist-info/METADATA +50 -0
- derivkit-1.0.0.dist-info/RECORD +68 -0
- derivkit-1.0.0.dist-info/WHEEL +5 -0
- derivkit-1.0.0.dist-info/licenses/LICENSE +21 -0
- derivkit-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,811 @@
|
|
|
1
|
+
"""Validation utilities for DerivativeKit."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from numpy.typing import NDArray
|
|
10
|
+
|
|
11
|
+
from derivkit.utils.sandbox import get_partial_function
|
|
12
|
+
|
|
13
|
+
# Public API re-exported by ``from derivkit.utils.validate import *``.
# The order below is significant only for readability; grouping comments
# describe what each run of names is about.
__all__ = [
    # Probes on user-supplied callables.
    "is_finite_and_differentiable",
    "check_scalar_valued",
    # Shape / value validation of tabulated data and matrices.
    "validate_tabulated_xy",
    "validate_covariance_matrix_shape",
    "validate_symmetric_psd",
    "validate_fisher_shape",
    # DALI forecast-tensor validation and resolution.
    "validate_dali_shape",
    "resolve_dali_introduced_multiplet",
    "resolve_dali_assembled_multiplet",
    # Generic matrix / parameter-vector helpers.
    "validate_square_matrix",
    "ensure_finite",
    "normalize_theta",
    "validate_theta_1d_finite",
    "validate_square_matrix_finite",
    # Remaining helpers.
    "resolve_covariance_input",
    "flatten_matrix_c_order",
    "require_callable",
]
|
|
32
|
+
|
|
33
|
+
def is_finite_and_differentiable(
    function: Callable[[float], Any],
    x: float,
    delta: float = 1e-5,
) -> bool:
    """Checks that ``function`` yields finite values at ``x`` and ``x + delta``.

    This is a cheap sanity probe rather than a true differentiability test: it
    only verifies that both evaluations produce finite outputs. Both points are
    always evaluated (in the order ``x`` then ``x + delta``), and any exception
    raised by ``function`` propagates to the caller unchanged.

    Args:
        function: Callable ``f(x)`` returning a scalar or array-like.
        x: Probe point.
        delta: Small forward step.

    Returns:
        ``True`` if every entry of the output is finite at both points and
        ``False`` otherwise. The result is a built-in ``bool``.
    """
    f0 = np.asarray(function(x))
    f1 = np.asarray(function(x + delta))
    # ``ndarray.all`` returns ``numpy.bool_``; coerce so the result honours the
    # ``-> bool`` annotation and works with ``is True`` / ``is False`` checks.
    return bool(np.isfinite(f0).all() and np.isfinite(f1).all())
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def check_scalar_valued(function, theta0: np.ndarray, i: int, n_workers: int):
    """Verifies that ``function`` produces exactly one value at ``theta0``.

    Shared guard for ``build_gradient`` and ``build_hessian``. The full
    function is wrapped as a one-parameter partial along index ``i`` (via
    ``get_partial_function``) and probed once at ``theta0[i]``.

    Args:
        function (callable): Function to differentiate. It takes a list or
            array of parameter values and is expected to return a single
            observable value.
        theta0: Parameter point (1D array or list) at which derivatives are
            evaluated.
        i: Zero-based index of the parameter being varied.
        n_workers: Accepted for a uniform call signature with the callers;
            this check does not use it.

    Raises:
        TypeError: If the probe output does not contain exactly one element.
    """
    one_param = get_partial_function(function, i, theta0)
    del n_workers  # intentionally unused by this check

    value = np.asarray(one_param(theta0[i]), dtype=float)
    if value.size == 1:
        return
    raise TypeError(
        "build_gradient() expects a scalar-valued function; "
        f"got shape {value.shape} from full_function(params)."
    )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def validate_tabulated_xy(
    x: Any,
    y: Any,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Validates and converts tabulated ``x`` and ``y`` arrays into NumPy arrays.

    Requirements:
        - ``x`` is 1D and strictly increasing.
        - ``y`` has at least 1 dimension.
        - ``y.shape[0] == x.shape[0]``, but ``y`` may have arbitrary trailing
          dimensions (scalar, vector, or ND output).

    Args:
        x: 1D array-like of x values (must be strictly increasing).
        y: Array-like of y values with ``y.shape[0] == len(x)``.

    Returns:
        Tuple of (x_array, y_array) as float NumPy arrays.

    Raises:
        ValueError: If ``x`` is not 1D, ``y`` is 0-dimensional, the lengths
            along axis 0 differ, or ``x`` is not strictly increasing.
    """
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)

    if x_arr.ndim != 1:
        raise ValueError("x must be 1D.")
    # Check the rank of ``y`` before reading ``y.shape[0]``: a 0-D ``y`` has an
    # empty shape tuple, so indexing it first would raise IndexError instead of
    # the documented ValueError.
    if y_arr.ndim < 1:
        raise ValueError("y must be at least 1D.")
    if x_arr.shape[0] != y_arr.shape[0]:
        raise ValueError("x and y must have the same length along axis 0.")
    if not np.all(np.diff(x_arr) > 0):
        raise ValueError("x must be strictly increasing.")

    return x_arr, y_arr
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def validate_covariance_matrix_shape(cov: Any) -> NDArray[np.float64]:
    """Validates the shape of a covariance input and returns it as a float array.

    0D, 1D and 2D inputs are accepted; a 2D input must additionally be square.
    Values are not inspected.

    Args:
        cov: Array-like covariance specification.

    Returns:
        ``cov`` converted to a float NumPy array.

    Raises:
        ValueError: If ``cov`` has more than two dimensions, or is 2D but not
            square.
    """
    arr = np.asarray(cov, dtype=float)
    if arr.ndim > 2:
        raise ValueError(f"cov must be at most two-dimensional; got ndim={arr.ndim}.")

    is_nonsquare_matrix = arr.ndim == 2 and arr.shape[0] != arr.shape[1]
    if is_nonsquare_matrix:
        raise ValueError(f"cov must be square; got shape={arr.shape}.")
    return arr
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def validate_symmetric_psd(
    matrix: Any,
    *,
    sym_atol: float = 1e-12,
    psd_atol: float = 1e-12,
) -> NDArray[np.float64]:
    """Validates that an input is a symmetric positive semidefinite (PSD) matrix.

    This is intended for strict validation (e.g., inputs passed to GetDist, or
    any code path where an indefinite covariance-like matrix should hard-fail).
    Many algorithms assume PSD inputs, and invalid inputs can otherwise lead to
    silent failures or nonsensical results.

    Policy:
        - Requires a 2D square shape with only finite entries.
        - Requires near-symmetry: ``max(|A - A.T|) <= sym_atol``.
        - After the symmetry check passes, checks PSD via the eigenvalues of the
          symmetrized matrix ``S = 0.5 * (A + A.T)`` (to reduce roundoff
          sensitivity) and requires ``min_eig(S) >= -psd_atol``.

    Args:
        matrix: Array-like input expected to be a covariance-like matrix.
        sym_atol: Absolute tolerance for the symmetry check.
        psd_atol: Absolute tolerance for the PSD check. Allows small negative
            eigenvalues down to ``-psd_atol``.

    Returns:
        A NumPy array view/copy of the input converted to ``float64``. The
        returned matrix is not symmetrized or otherwise modified.

    Raises:
        ValueError: If ``matrix`` is not 2D, is not square, contains non-finite
            values, has ``max(|A - A.T|) > sym_atol``, has
            ``min_eig(0.5 * (A + A.T)) < -psd_atol``, or if the eigenvalue
            computation fails.
    """
    a = np.asarray(matrix, dtype=np.float64)

    if a.ndim != 2:
        raise ValueError(f"matrix must be 2D; got ndim={a.ndim}.")
    if a.shape[0] != a.shape[1]:
        raise ValueError(f"matrix must be square; got shape={a.shape}.")
    if not np.all(np.isfinite(a)):
        raise ValueError("matrix contains non-finite values.")

    skew = a - a.T
    # A 0x0 matrix has no entries; treat it as perfectly symmetric.
    max_abs_skew = float(np.max(np.abs(skew))) if skew.size else 0.0
    if max_abs_skew > sym_atol:
        raise ValueError(
            f"matrix must be symmetric within sym_atol={sym_atol:.2e}; "
            f"max(|A-A^T|)={max_abs_skew:.2e}."
        )

    # eigvalsh assumes a symmetric input; symmetrizing removes the residual
    # (already tolerated) asymmetry before the eigenvalue solve.
    s = 0.5 * (a + a.T)
    try:
        evals = np.linalg.eigvalsh(s)
    except np.linalg.LinAlgError as e:
        raise ValueError("eigenvalue check failed for matrix (LinAlgError).") from e

    min_eig = float(np.min(evals)) if evals.size else 0.0
    if min_eig < -psd_atol:
        raise ValueError(
            f"matrix is not positive semi-definite within psd_atol={psd_atol:.2e}; min eigenvalue={min_eig:.2e}."
        )

    return a
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def validate_fisher_shape(
    theta0: NDArray[np.floating],
    fisher: Any,
    *,
    check_finite: bool = False,
) -> None:
    """Checks that ``fisher`` is a ``(p, p)`` matrix with ``p = len(theta0)``.

    ``theta0`` is flattened before its length is taken, so only its total
    number of elements matters.

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        fisher: Candidate Fisher matrix.
        check_finite: When ``True``, additionally require every entry of
            ``fisher`` to be finite.

    Raises:
        ValueError: If ``theta0`` has no elements, or ``fisher`` does not have
            shape ``(p, p)``.
        FloatingPointError: If ``check_finite`` is set and ``fisher`` has a
            non-finite entry.
    """
    flat_theta = np.asarray(theta0, dtype=np.float64).reshape(-1)
    n_params = int(flat_theta.size)
    if not n_params:
        raise ValueError(
            f"theta0 must be non-empty 1D; got shape {np.asarray(theta0).shape}."
        )

    # A shape equal to (p, p) already implies a 2D array.
    expected = (n_params, n_params)
    fisher_arr = np.asarray(fisher, dtype=np.float64)
    if fisher_arr.shape != expected:
        raise ValueError(
            f"fisher must have shape {expected}; got {fisher_arr.shape}.")

    if check_finite and not np.all(np.isfinite(fisher_arr)):
        raise FloatingPointError("fisher contains non-finite values.")
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def validate_dali_shape(
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    check_finite: bool = False,
) -> None:
    """Validates forecast tensor shapes.

    The accepted input forms match the conventions used by
    :func:`derivkit.forecasting.get_forecast_tensors`:

    - A dict mapping ``order -> multiplet`` for consecutive orders starting at 1.
    - A single multiplet tuple.

    With ``p = len(theta0)``, the required shapes are:

    - order 1 multiplet: ``(F,)`` with ``F`` of shape ``(p, p)``.
    - order 2 multiplet: ``(D_{(2,1)}, D_{(2,2)})`` with shapes
      ``(p, p, p)`` and ``(p, p, p, p)``.
    - order 3 multiplet: ``(T_{(3,1)}, T_{(3,2)}, T_{(3,3)})`` with shapes
      ``(p, p, p, p)``, ``(p, p, p, p, p)``, and ``(p, p, p, p, p, p)``.

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``. It is flattened
            before use, so only its total number of elements matters.
        dali: Forecast tensors to validate. Must be either:

            - ``dict[int, tuple[...]]`` where each value is a multiplet for
              that order, or
            - ``tuple[...]`` which is a single multiplet.

        check_finite: If ``True``, require all validated arrays to be finite.

    Raises:
        TypeError: If ``dali`` has an unsupported type, if dict keys are not
            ints (``bool`` keys are rejected even though ``bool`` subclasses
            ``int``), or if any multiplet is not a tuple.
        ValueError: If ``theta0`` is empty, if dict keys are not consecutive
            starting at 1, if an order other than 1, 2 or 3 is present, if a
            multiplet has the wrong length for its order, or if any tensor has
            the wrong dimension/shape.
        FloatingPointError: If ``check_finite=True`` and any validated array
            contains non-finite values.
    """
    theta0_arr = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if theta0_arr.size == 0:
        raise ValueError(
            f"theta0 must be non-empty 1D; got shape {np.asarray(theta0).shape}."
        )
    p = int(theta0_arr.size)

    def _require_tensor(arr_like: Any, *, idx: int, expected_ndim: int) -> None:
        """Checks one tensor against the expected ``(p,) * expected_ndim`` shape.

        Args:
            arr_like: Array-like tensor to validate.
            idx: Position of the tensor within its multiplet (error messages only).
            expected_ndim: Required tensor rank.

        Raises:
            ValueError: If the rank or the shape is wrong.
            FloatingPointError: If ``check_finite`` (enclosing scope) is set and
                the tensor has non-finite entries.
        """
        arr = np.asarray(arr_like, dtype=np.float64)
        if arr.ndim != expected_ndim:
            raise ValueError(
                f"DALI tensor at position {idx} must have ndim"
                f"={expected_ndim}; got ndim={arr.ndim}."
            )
        expected_shape = (p,) * expected_ndim
        if arr.shape != expected_shape:
            raise ValueError(
                f"DALI tensor at position {idx} must have shape"
                f" {expected_shape}; got {arr.shape}."
            )
        if check_finite and not np.isfinite(arr).all():
            raise FloatingPointError(
                f"DALI tensor at position {idx} contains non-finite values."
            )

    def _validate_order_multiplet(order: int, m: Any) -> None:
        """Validates the multiplet ``m`` introduced at forecast ``order``.

        Order 1 expects ``(F,)`` (checked with ``validate_fisher_shape``),
        order 2 expects ranks ``(3, 4)`` and order 3 expects ranks ``(4, 5, 6)``,
        every axis having length ``p``.

        Raises:
            TypeError: If ``m`` is not a tuple.
            ValueError: If ``order`` is unsupported, the tuple length is wrong,
                or a tensor shape is invalid.
            FloatingPointError: If ``check_finite`` is set and a tensor has
                non-finite entries.
        """
        if not isinstance(m, tuple):
            raise TypeError(f"dali[order={order}] must be a tuple; got {type(m)}.")

        if order == 1:
            if len(m) != 1:
                raise ValueError(f"dali[1] must be a 1-tuple (F,); got length {len(m)}.")
            validate_fisher_shape(theta0_arr, m[0], check_finite=check_finite)
            return

        if order == 2:
            if len(m) != 2:
                raise ValueError(f"dali[2] must be a 2-tuple (D21, D22); got length {len(m)}.")
            _require_tensor(m[0], idx=0, expected_ndim=3)
            _require_tensor(m[1], idx=1, expected_ndim=4)
            return

        if order == 3:
            if len(m) != 3:
                raise ValueError(
                    f"dali[3] must be a 3-tuple (T31, T32, T33); got length {len(m)}."
                )
            _require_tensor(m[0], idx=0, expected_ndim=4)
            _require_tensor(m[1], idx=1, expected_ndim=5)
            _require_tensor(m[2], idx=2, expected_ndim=6)
            return

        raise ValueError(f"Unsupported forecast order={order}. Expected 1, 2, or 3.")

    def _validate_tuple_multiplet(m: tuple[Any, ...]) -> None:
        """Infers the order of a bare multiplet tuple and validates it.

        The order is inferred strictly from ``(len(m), ndim(m[0]))``:
        ``(1, 2)`` -> order 1, ``(2, 3)`` -> order 2, ``(3, 4)`` -> order 3.

        Raises:
            ValueError: If ``m`` is empty or matches none of those structures.
            TypeError/FloatingPointError: Propagated from the order-specific
                validation when shapes or finiteness checks fail.
        """
        if len(m) == 0:
            raise ValueError("DALI tuple must be non-empty.")

        first_ndim = np.asarray(m[0], dtype=np.float64).ndim

        # Disambiguate strictly by (len, ndim of first tensor).
        if len(m) == 1 and first_ndim == 2:
            _validate_order_multiplet(1, m)
            return
        if len(m) == 2 and first_ndim == 3:
            _validate_order_multiplet(2, m)
            return
        if len(m) == 3 and first_ndim == 4:
            _validate_order_multiplet(3, m)
            return

        raise ValueError(
            "Unrecognized DALI tuple form."
            " Expected (F,)"
            " or (D21,D22) or"
            " (T31,T32,T33)."
        )

    # dict[int, tuple[...]]: get_forecast_tensors output
    if isinstance(dali, dict):
        if len(dali) == 0:
            raise ValueError("DALI dict is empty.")

        # Keys must be int, consecutive, starting at 1.
        keys: list[int] = []
        for k in dali.keys():
            # ``bool`` is a subclass of ``int`` and ``True == 1``; without the
            # explicit exclusion ``{True: (F,)}`` would validate as order 1.
            if not isinstance(k, int) or isinstance(k, bool):
                raise TypeError(f"DALI dict keys must be int;"
                                f" got {k!r} ({type(k)}).")
            keys.append(k)

        keys_sorted = sorted(keys)
        if keys_sorted[0] != 1:
            raise ValueError("DALI dict must start at key=1;"
                             f" got keys {keys_sorted}.")
        if keys_sorted != list(range(1, keys_sorted[-1] + 1)):
            raise ValueError(
                "DALI dict keys must be consecutive 1..K;"
                f" got keys {keys_sorted}."
            )

        for order in keys_sorted:
            _validate_order_multiplet(order, dali[order])
        return

    # tuple: single introduced-at-order multiplet
    if isinstance(dali, tuple):
        _validate_tuple_multiplet(dali)
        return

    raise TypeError(
        "Invalid DALI type. Expected dict[int, tuple[...]] or tuple[...] "
        "containing an introduced-at-order multiplet."
    )
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def resolve_dali_introduced_multiplet(
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    forecast_order: int | None = None,
    check_finite: bool = False,
) -> tuple[int, tuple[NDArray[np.float64], ...]]:
    """Returns ``(order, multiplet)`` from any accepted forecast tensor output.

    The accepted input forms match the conventions used by
    :func:`derivkit.forecasting.get_forecast_tensors`:

    - A dict mapping ``order -> multiplet`` for consecutive orders starting at 1.
    - A single multiplet tuple.

    If ``dali`` is a dict and ``forecast_order`` is not provided, the highest
    available order is selected. If ``forecast_order`` is provided, it must be
    present in the dict.

    If ``dali`` is a tuple, the order is inferred from the tuple structure
    ``(len(dali), ndim(dali[0]))``; a provided ``forecast_order`` must match it.

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        dali: Forecast tensors in one of the accepted forms.
        forecast_order: Optional order selector (dict input) or consistency
            check (tuple input).
        check_finite: If ``True``, require all validated arrays to be finite.

    Returns:
        Tuple ``(order, multiplet)`` where ``multiplet`` is the
        introduced-at-order tuple converted to float64 arrays.

    Raises:
        ValueError: If ``theta0`` is empty, if the selected order is not a key
            of the dict, or if ``forecast_order`` disagrees with the order
            inferred from a tuple input.
        TypeError/ValueError/FloatingPointError: Propagated from
            ``validate_dali_shape`` when ``dali`` is malformed.
        RuntimeError: Internal consistency failure (should be unreachable).
    """
    theta0_arr = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if theta0_arr.size == 0:
        raise ValueError(
            "theta0 must be non-empty 1D;"
            f" got shape {np.asarray(theta0).shape}."
        )

    validate_dali_shape(theta0_arr, dali, check_finite=check_finite)

    if isinstance(dali, dict):
        available = sorted(dali.keys())
        chosen = available[-1] if forecast_order is None else int(forecast_order)
        if chosen not in dali:
            raise ValueError(f"forecast_order={chosen} "
                             f"not in DALI dict keys {available}.")
        multiplet = tuple(np.asarray(x, dtype=np.float64) for x in dali[chosen])
        return chosen, multiplet

    # tuple: infer order from strict (len, first_ndim)
    m = dali  # validated as tuple by validate_dali_shape
    first_ndim = np.asarray(m[0], dtype=np.float64).ndim

    if len(m) == 1 and first_ndim == 2:
        order = 1
    elif len(m) == 2 and first_ndim == 3:
        order = 2
    elif len(m) == 3 and first_ndim == 4:
        order = 3
    else:
        # Should be unreachable because validate_dali_shape already enforced.
        raise RuntimeError("internal error: could not infer order from validated tuple.")

    if forecast_order is not None and int(forecast_order) != order:
        raise ValueError(
            f"forecast_order={int(forecast_order)} does not match inferred order={order}."
        )

    multiplet = tuple(np.asarray(x, dtype=np.float64) for x in m)
    return order, multiplet
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def resolve_dali_assembled_multiplet(
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    forecast_order: int | None = None,
    check_finite: bool = False,
) -> tuple[int, tuple[NDArray[np.float64], ...]]:
    """Returns ``(order, multiplet)`` with the multiplet assembled up to ``order``.

    Accepted inputs (matching ``get_forecast_tensors``):

    - ``dict[int, tuple[...]]``: per-order "introduced-at-order" multiplets.
    - ``tuple[...]``: a single introduced-at-order multiplet.

    The returned multiplet concatenates the introduced-at-order multiplets of
    every order from 1 up to the selected order:

    - order 1: ``(F,)``
    - order 2: ``(F, D1, D2)``
    - order 3: ``(F, D1, D2, T1, T2, T3)``

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        dali: Forecast tensors in one of the accepted forms.
        forecast_order: Order to assemble up to. For dict input it defaults to
            the highest key present; for tuple input only ``None`` or ``1`` is
            accepted.
        check_finite: If ``True``, require all validated arrays to be finite.

    Returns:
        Tuple ``(order, multiplet)`` where ``multiplet`` is a tuple of float64
        arrays.

    Raises:
        ValueError: If ``theta0`` is empty, if the requested order is not a key
            of the dict or lies outside 1..3, or if a tuple input would need to
            be assembled beyond order 1.
        TypeError/ValueError/FloatingPointError: Propagated from
            ``validate_dali_shape`` when ``dali`` is malformed.
        RuntimeError: Internal consistency failure (should be unreachable).

    Note:
        Tuple inputs cannot be assembled for order > 1 because they do not
        include ``F``; pass the dict form from ``get_forecast_tensors`` instead.
    """
    flat_theta = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if not flat_theta.size:
        raise ValueError(
            f"theta0 must be non-empty 1D; got shape {np.asarray(theta0).shape}."
        )

    validate_dali_shape(flat_theta, dali, check_finite=check_finite)

    if isinstance(dali, dict):
        orders = sorted(dali.keys())
        target = orders[-1] if forecast_order is None else int(forecast_order)
        if target not in dali:
            raise ValueError(f"forecast_order={target} not in DALI dict keys {orders}.")
        if target not in (1, 2, 3):
            raise ValueError(f"forecast_order must be 1, 2, or 3; got {target}.")

        # Concatenate the introduced-at-order multiplets 1..target; their
        # lengths (1, 2 and 3 tensors) were enforced by validate_dali_shape.
        pieces: list[NDArray[np.float64]] = []
        for level in range(1, target + 1):
            pieces.extend(np.asarray(t, dtype=np.float64) for t in dali[level])
        return int(target), tuple(pieces)

    # Tuple input (already confirmed by validate_dali_shape). Only the
    # Fisher-only form can be returned, since higher-order tuples lack F.
    single = dali
    structure = (len(single), np.asarray(single[0], dtype=np.float64).ndim)

    if structure == (1, 2):
        if forecast_order is not None and int(forecast_order) != 1:
            raise ValueError(
                "forecast_order>1 requires the dict form from get_forecast_tensors "
                "(tuple multiplets do not include Fisher for order>1)."
            )
        return 1, (np.asarray(single[0], dtype=np.float64),)

    if structure in ((2, 3), (3, 4)):
        raise ValueError(
            "Order>1 evaluation requires the dict form from get_forecast_tensors, "
            "because introduced-at-order tuples do not include the Fisher matrix."
        )

    # Should be unreachable because validate_dali_shape already enforced allowed tuple forms.
    raise RuntimeError("internal error: could not infer order from validated tuple.")
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def validate_square_matrix(a: Any, *, name: str = "matrix") -> NDArray[np.float64]:
    """Checks that ``a`` is a 2D square matrix and returns it as a float64 array.

    Args:
        a: Array-like candidate matrix.
        name: Label used in error messages.

    Returns:
        ``a`` converted to a float64 NumPy array.

    Raises:
        ValueError: If ``a`` is not 2D or its two dimensions differ.
    """
    out = np.asarray(a, dtype=np.float64)
    if out.ndim == 2:
        n_rows, n_cols = out.shape
        if n_rows == n_cols:
            return out
        raise ValueError(f"{name} must be square; got shape={out.shape}.")
    raise ValueError(f"{name} must be 2D; got ndim={out.ndim}.")
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def ensure_finite(arr: Any, *, msg: str) -> None:
    """Raises ``FloatingPointError(msg)`` unless every entry of ``arr`` is finite.

    Args:
        arr: Array-like to inspect (converted with ``np.asarray``).
        msg: Message attached to the exception when a NaN or infinity is found.

    Raises:
        FloatingPointError: If any value in ``arr`` is NaN or infinite.
    """
    finite_mask = np.isfinite(np.asarray(arr))
    if finite_mask.all():
        return
    raise FloatingPointError(msg)
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def normalize_theta(theta0: Any) -> NDArray[np.float64]:
    """Flattens ``theta0`` into a non-empty 1D float64 array.

    Inputs of any shape (including scalars) are accepted and flattened with
    ``reshape(-1)``.

    Args:
        theta0: Array-like parameter values.

    Returns:
        1D float64 array holding the flattened values.

    Raises:
        ValueError: If ``theta0`` contains no elements.
    """
    flat = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if flat.size:
        return flat
    raise ValueError("theta0 must be a non-empty 1D array.")
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
def validate_theta_1d_finite(theta: Any, *, name: str = "theta") -> NDArray[np.float64]:
    """Checks that ``theta`` is a non-empty, finite, strictly 1D vector.

    Inputs that are not already 1D are rejected, not reshaped.

    Args:
        theta: Array-like parameter vector.
        name: Label used in error messages.

    Returns:
        ``theta`` as a 1D float64 NumPy array.

    Raises:
        ValueError: If ``theta`` is not 1D, has no elements, or contains
            NaN or infinite values.
    """
    # Converting straight to float64 makes a second cast unnecessary.
    vec = np.asarray(theta, dtype=np.float64)
    if vec.ndim != 1:
        raise ValueError(f"{name} must be 1D; got shape {vec.shape}.")
    if vec.size == 0:
        raise ValueError(f"{name} must be non-empty.")
    if not np.isfinite(vec).all():
        raise ValueError(f"{name} contains non-finite values.")
    return vec
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
def validate_square_matrix_finite(
    a: Any, *, name: str = "matrix"
) -> NDArray[np.float64]:
    """Checks that ``a`` is a square 2D matrix containing only finite entries.

    Args:
        a: Array-like matrix.
        name: Label used in error messages.

    Returns:
        ``a`` as a 2D float64 NumPy array.

    Raises:
        ValueError: If ``a`` is not 2D, its row and column counts differ,
            or any entry is NaN or infinite.
    """
    mat = np.asarray(a, dtype=np.float64)
    if mat.ndim != 2:
        raise ValueError(f"{name} must be 2D; got ndim={mat.ndim}.")
    rows, cols = mat.shape
    if rows != cols:
        raise ValueError(f"{name} must be square; got shape {mat.shape}.")
    if not np.isfinite(mat).all():
        raise ValueError(f"{name} contains non-finite values.")
    return mat
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def resolve_covariance_input(
|
|
706
|
+
cov: NDArray[np.float64]
|
|
707
|
+
| Callable[[NDArray[np.float64]], NDArray[np.float64]],
|
|
708
|
+
*,
|
|
709
|
+
theta0: NDArray[np.float64],
|
|
710
|
+
validate: Callable[[Any], NDArray[np.float64]],
|
|
711
|
+
) -> tuple[NDArray[np.float64], Callable[[NDArray[np.float64]], NDArray[np.float64]] | None]:
|
|
712
|
+
"""Returns the covariance-like input after validation.
|
|
713
|
+
|
|
714
|
+
Args:
|
|
715
|
+
cov: Covariance input. You can pass:
|
|
716
|
+
|
|
717
|
+
- A fixed square covariance array (constant covariance).
|
|
718
|
+
In this case the returned callable is ``None``.
|
|
719
|
+
- A callable that takes ``theta`` and returns a square
|
|
720
|
+
covariance array. In this case the function evaluates
|
|
721
|
+
it at ``theta0`` to get ``cov0`` and returns the callable
|
|
722
|
+
as ``cov_fn``.
|
|
723
|
+
|
|
724
|
+
theta0: Fiducial parameter vector. Only used when ``cov`` is a callable
|
|
725
|
+
covariance function (or when a callable is provided in the tuple
|
|
726
|
+
form). Ignored for fixed covariance arrays.
|
|
727
|
+
validate: A function that converts a covariance-like input into a NumPy
|
|
728
|
+
array and checks its basic shape (and any other rules the caller
|
|
729
|
+
wants). ``resolve_covariance_input`` exists to handle the different
|
|
730
|
+
input types for ``cov`` (array vs callable) and to consistently
|
|
731
|
+
produce ``(cov0, cov_fn)``; ``validate`` is only used to check or
|
|
732
|
+
coerce the arrays that come out of that process.
|
|
733
|
+
|
|
734
|
+
Returns:
|
|
735
|
+
A tuple with two items:
|
|
736
|
+
|
|
737
|
+
- ``cov0``: The validated covariance at ``theta0`` (or the provided
|
|
738
|
+
fixed covariance).
|
|
739
|
+
- ``cov_fn``: The callable covariance function if provided,
|
|
740
|
+
otherwise ``None``.
|
|
741
|
+
"""
|
|
742
|
+
if callable(cov):
|
|
743
|
+
return validate(cov(theta0)), cov
|
|
744
|
+
|
|
745
|
+
return validate(cov), None
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
def flatten_matrix_c_order(
    cov_function: Callable[[NDArray[np.float64]], NDArray[np.float64]],
    theta: NDArray[np.float64],
    *,
    n_observables: int,
) -> NDArray[np.float64]:
    """Evaluates ``cov_function`` at ``theta``, checks its shape, and flattens it to 1D.

    The matrix is raveled in row-major ("C") order, so element ``[i, j]``
    lands at index ``i * n_observables + j``. Flattening is needed because
    the derivative routines that differentiate covariance matrices with
    respect to parameters typically operate on 1D arrays.

    Args:
        cov_function: Callable that returns a covariance matrix for a
            parameter vector.
        theta: Parameter vector passed to ``cov_function``.
        n_observables: Expected size of each side of the covariance matrix.

    Returns:
        1D float64 array of length ``n_observables**2``.

    Raises:
        ValueError: If the returned matrix does not have shape
            ``(n_observables, n_observables)``. Any exception raised by
            ``validate_covariance_matrix_shape`` also propagates.
    """
    raw = cov_function(theta)
    matrix = validate_covariance_matrix_shape(raw)
    expected = (n_observables, n_observables)
    if matrix.shape != expected:
        raise ValueError(
            f"cov_function(theta) must return shape {expected}; got {matrix.shape}."
        )
    return np.asarray(matrix, dtype=np.float64).ravel(order="C")
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
def require_callable(
|
|
780
|
+
func: Callable[..., Any] | None,
|
|
781
|
+
*,
|
|
782
|
+
name: str = "function",
|
|
783
|
+
context: str | None = None,
|
|
784
|
+
hint: str | None = None,
|
|
785
|
+
) -> Callable[..., Any]:
|
|
786
|
+
"""Ensures a required callable is provided.
|
|
787
|
+
|
|
788
|
+
This is a small helper to validate inputs.
|
|
789
|
+
If ``func`` is ``None``, it raises a ``ValueError`` with a clear message (and an
|
|
790
|
+
optional context/hint to make debugging easier). If ``func`` is provided, it is
|
|
791
|
+
returned unchanged so the caller can use it directly.
|
|
792
|
+
|
|
793
|
+
Args:
|
|
794
|
+
func: Callable to validate.
|
|
795
|
+
name: Name shown in the error message.
|
|
796
|
+
context: Optional context prefix (e.g. "ForecastKit.fisher").
|
|
797
|
+
hint: Optional hint appended to the error message.
|
|
798
|
+
|
|
799
|
+
Returns:
|
|
800
|
+
The input callable.
|
|
801
|
+
|
|
802
|
+
Raises:
|
|
803
|
+
ValueError: If ``func`` is None.
|
|
804
|
+
"""
|
|
805
|
+
if func is None:
|
|
806
|
+
prefix = f"{context}: " if context else ""
|
|
807
|
+
msg = f"{prefix}{name} must be provided."
|
|
808
|
+
if hint:
|
|
809
|
+
msg += f" {hint}"
|
|
810
|
+
raise ValueError(msg)
|
|
811
|
+
return func
|