derivkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. derivkit/__init__.py +22 -0
  2. derivkit/calculus/__init__.py +17 -0
  3. derivkit/calculus/calculus_core.py +152 -0
  4. derivkit/calculus/gradient.py +97 -0
  5. derivkit/calculus/hessian.py +528 -0
  6. derivkit/calculus/hyper_hessian.py +296 -0
  7. derivkit/calculus/jacobian.py +156 -0
  8. derivkit/calculus_kit.py +128 -0
  9. derivkit/derivative_kit.py +315 -0
  10. derivkit/derivatives/__init__.py +6 -0
  11. derivkit/derivatives/adaptive/__init__.py +5 -0
  12. derivkit/derivatives/adaptive/adaptive_fit.py +238 -0
  13. derivkit/derivatives/adaptive/batch_eval.py +179 -0
  14. derivkit/derivatives/adaptive/diagnostics.py +325 -0
  15. derivkit/derivatives/adaptive/grid.py +333 -0
  16. derivkit/derivatives/adaptive/polyfit_utils.py +513 -0
  17. derivkit/derivatives/adaptive/spacing.py +66 -0
  18. derivkit/derivatives/adaptive/transforms.py +245 -0
  19. derivkit/derivatives/autodiff/__init__.py +1 -0
  20. derivkit/derivatives/autodiff/jax_autodiff.py +95 -0
  21. derivkit/derivatives/autodiff/jax_core.py +217 -0
  22. derivkit/derivatives/autodiff/jax_utils.py +146 -0
  23. derivkit/derivatives/finite/__init__.py +5 -0
  24. derivkit/derivatives/finite/batch_eval.py +91 -0
  25. derivkit/derivatives/finite/core.py +84 -0
  26. derivkit/derivatives/finite/extrapolators.py +511 -0
  27. derivkit/derivatives/finite/finite_difference.py +247 -0
  28. derivkit/derivatives/finite/stencil.py +206 -0
  29. derivkit/derivatives/fornberg.py +245 -0
  30. derivkit/derivatives/local_polynomial_derivative/__init__.py +1 -0
  31. derivkit/derivatives/local_polynomial_derivative/diagnostics.py +90 -0
  32. derivkit/derivatives/local_polynomial_derivative/fit.py +199 -0
  33. derivkit/derivatives/local_polynomial_derivative/local_poly_config.py +95 -0
  34. derivkit/derivatives/local_polynomial_derivative/local_polynomial_derivative.py +205 -0
  35. derivkit/derivatives/local_polynomial_derivative/sampling.py +72 -0
  36. derivkit/derivatives/tabulated_model/__init__.py +1 -0
  37. derivkit/derivatives/tabulated_model/one_d.py +247 -0
  38. derivkit/forecast_kit.py +783 -0
  39. derivkit/forecasting/__init__.py +1 -0
  40. derivkit/forecasting/dali.py +78 -0
  41. derivkit/forecasting/expansions.py +486 -0
  42. derivkit/forecasting/fisher.py +298 -0
  43. derivkit/forecasting/fisher_gaussian.py +171 -0
  44. derivkit/forecasting/fisher_xy.py +357 -0
  45. derivkit/forecasting/forecast_core.py +313 -0
  46. derivkit/forecasting/getdist_dali_samples.py +429 -0
  47. derivkit/forecasting/getdist_fisher_samples.py +235 -0
  48. derivkit/forecasting/laplace.py +259 -0
  49. derivkit/forecasting/priors_core.py +860 -0
  50. derivkit/forecasting/sampling_utils.py +388 -0
  51. derivkit/likelihood_kit.py +114 -0
  52. derivkit/likelihoods/__init__.py +1 -0
  53. derivkit/likelihoods/gaussian.py +136 -0
  54. derivkit/likelihoods/poisson.py +176 -0
  55. derivkit/utils/__init__.py +13 -0
  56. derivkit/utils/concurrency.py +213 -0
  57. derivkit/utils/extrapolation.py +254 -0
  58. derivkit/utils/linalg.py +513 -0
  59. derivkit/utils/logger.py +26 -0
  60. derivkit/utils/numerics.py +262 -0
  61. derivkit/utils/sandbox.py +74 -0
  62. derivkit/utils/types.py +15 -0
  63. derivkit/utils/validate.py +811 -0
  64. derivkit-1.0.0.dist-info/METADATA +50 -0
  65. derivkit-1.0.0.dist-info/RECORD +68 -0
  66. derivkit-1.0.0.dist-info/WHEEL +5 -0
  67. derivkit-1.0.0.dist-info/licenses/LICENSE +21 -0
  68. derivkit-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,811 @@
1
+ """Validation utilities for DerivativeKit."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+
11
+ from derivkit.utils.sandbox import get_partial_function
12
+
13
# Names exported by ``from <this module> import *``; every entry is a
# function defined further down in this module.
__all__ = [
    "is_finite_and_differentiable",
    "check_scalar_valued",
    "validate_tabulated_xy",
    "validate_covariance_matrix_shape",
    "validate_symmetric_psd",
    "validate_fisher_shape",
    "validate_dali_shape",
    "resolve_dali_introduced_multiplet",
    "resolve_dali_assembled_multiplet",
    "validate_square_matrix",
    "ensure_finite",
    "normalize_theta",
    "validate_theta_1d_finite",
    "validate_square_matrix_finite",
    "resolve_covariance_input",
    "flatten_matrix_c_order",
    "require_callable",
]
32
+
33
def is_finite_and_differentiable(
    function: Callable[[float], Any],
    x: float,
    delta: float = 1e-5,
) -> bool:
    """Check that ``function`` returns finite values at ``x`` and ``x + delta``.

    This is a two-point probe: the function is evaluated once at ``x`` and
    once at ``x + delta`` and every entry of both results must be finite.
    Exceptions raised by ``function`` are not caught here and propagate to
    the caller.

    Args:
        function: Callable ``f(x)`` returning a scalar or array-like.
        x: Probe point.
        delta: Small forward step.

    Returns:
        ``True`` if all returned values are finite at both points and
        ``False`` otherwise. The result is always a built-in ``bool``.
    """
    value_at_x = np.asarray(function(x))
    value_at_step = np.asarray(function(x + delta))
    # ``ndarray.all()`` yields ``numpy.bool_``; convert so the annotated
    # ``bool`` return type holds (e.g. ``result is True`` works).
    return bool(
        np.isfinite(value_at_x).all() and np.isfinite(value_at_step).all()
    )
54
+
55
+
56
def check_scalar_valued(function, theta0: np.ndarray, i: int, n_workers: int):
    """Probe ``function`` along parameter ``i`` and reject non-scalar output.

    Builds a partial function via ``get_partial_function(function, i, theta0)``,
    evaluates it once at ``theta0[i]`` and requires the result to contain
    exactly one element.

    Args:
        function (callable): The function to be checked. It is expected to
            accept a list or array of parameter values and return a scalar.
        theta0: Parameter values at which the probe is evaluated; indexed
            with ``i``.
        i: Zero-based index of the parameter used for the probe.
        n_workers: Accepted for signature compatibility; not used by this
            check.

    Raises:
        TypeError: If the probe result does not have exactly one element.
    """
    # Kept only so the parameter is referenced; the probe is a single call.
    _ = n_workers

    # NOTE(review): presumably a one-argument callable that varies parameter
    # ``i`` while holding the rest of ``theta0`` fixed -- confirm in sandbox.
    along_parameter = get_partial_function(function, i, theta0)
    sample = np.asarray(along_parameter(theta0[i]), dtype=float)

    if sample.size == 1:
        return

    raise TypeError(
        "build_gradient() expects a scalar-valued function; "
        f"got shape {sample.shape} from full_function(params)."
    )
83
+
84
+
85
def validate_tabulated_xy(
    x: Any,
    y: Any,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Validate tabulated ``x`` and ``y`` data and convert both to float arrays.

    Requirements:
        - ``x`` is 1D and strictly increasing.
        - ``y`` has at least 1 dimension.
        - ``y.shape[0] == x.shape[0]``; ``y`` may have arbitrary trailing
          dimensions.

    Args:
        x: 1D array-like of x values (must be strictly increasing).
        y: Array-like of y values with ``y.shape[0] == len(x)``.

    Returns:
        Tuple ``(x_array, y_array)`` of float NumPy arrays.

    Raises:
        ValueError: If ``x`` is not 1D, ``y`` is 0-dimensional, the leading
            lengths differ, or ``x`` is not strictly increasing.
    """
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)

    if x_arr.ndim != 1:
        raise ValueError("x must be 1D.")
    # Check the rank of ``y`` before touching ``y_arr.shape[0]``: a 0-D array
    # has an empty shape tuple, so indexing it would raise IndexError instead
    # of the documented ValueError.
    if y_arr.ndim < 1:
        raise ValueError("y must be at least 1D.")
    if x_arr.shape[0] != y_arr.shape[0]:
        raise ValueError("x and y must have the same length along axis 0.")
    if not np.all(np.diff(x_arr) > 0):
        raise ValueError("x must be strictly increasing.")

    return x_arr, y_arr
120
+
121
+
122
def validate_covariance_matrix_shape(cov: Any) -> NDArray[np.float64]:
    """Convert ``cov`` to a float array and check its shape.

    Scalars (0D), vectors (1D) and matrices (2D) are accepted; a 2D input
    must additionally be square.

    Args:
        cov: Array-like covariance input.

    Returns:
        The input converted to a float NumPy array.

    Raises:
        ValueError: If ``cov`` has more than two dimensions, or is 2D but
            not square.
    """
    as_float = np.asarray(cov, dtype=float)
    rank = as_float.ndim

    if rank > 2:
        raise ValueError(f"cov must be at most two-dimensional; got ndim={rank}.")

    if rank == 2:
        n_rows, n_cols = as_float.shape
        if n_rows != n_cols:
            raise ValueError(f"cov must be square; got shape={as_float.shape}.")

    return as_float
130
+
131
+
132
def validate_symmetric_psd(
    matrix: Any,
    *,
    sym_atol: float = 1e-12,
    psd_atol: float = 1e-12,
) -> NDArray[np.float64]:
    """Strictly validate a symmetric positive semi-definite (PSD) matrix.

    The checks run in this order and the first failure raises:

    1. The input is 2D and square.
    2. Every entry is finite.
    3. ``max(|A - A.T|) <= sym_atol``.
    4. The smallest eigenvalue of ``0.5 * (A + A.T)`` is ``>= -psd_atol``.

    The symmetrized matrix is used only for the eigenvalue computation; the
    returned array is the float64 conversion of the input and is never
    symmetrized or otherwise modified.

    Args:
        matrix: Array-like input expected to be a covariance-like matrix.
        sym_atol: Absolute tolerance for the symmetry check.
        psd_atol: Absolute tolerance for the PSD check; eigenvalues down to
            ``-psd_atol`` are accepted.

    Returns:
        The input as a float64 NumPy array.

    Raises:
        ValueError: If the input is not 2D, not square, contains non-finite
            values, is asymmetric beyond ``sym_atol``, has an eigenvalue
            below ``-psd_atol``, or if the eigenvalue computation fails.
    """
    mat = np.asarray(matrix, dtype=np.float64)

    if mat.ndim != 2:
        raise ValueError(f"matrix must be 2D; got ndim={mat.ndim}.")
    n_rows, n_cols = mat.shape
    if n_rows != n_cols:
        raise ValueError(f"matrix must be square; got shape={mat.shape}.")
    if not np.isfinite(mat).all():
        raise ValueError("matrix contains non-finite values.")

    # Largest absolute deviation from symmetry; an empty matrix has none.
    asymmetry = np.abs(mat - mat.T)
    max_abs_skew = float(asymmetry.max()) if asymmetry.size else 0.0
    if max_abs_skew > sym_atol:
        raise ValueError(
            f"matrix must be symmetric within sym_atol={sym_atol:.2e}; "
            f"max(|A-A^T|)={max_abs_skew:.2e}."
        )

    symmetrized = 0.5 * (mat + mat.T)
    try:
        spectrum = np.linalg.eigvalsh(symmetrized)
    except np.linalg.LinAlgError as err:
        raise ValueError("eigenvalue check failed for matrix (LinAlgError).") from err

    min_eig = float(spectrum.min()) if spectrum.size else 0.0
    if min_eig < -psd_atol:
        raise ValueError(
            f"matrix is not positive semi-definite within psd_atol={psd_atol:.2e}; min eigenvalue={min_eig:.2e}."
        )

    return mat
202
+
203
+
204
def validate_fisher_shape(
    theta0: NDArray[np.floating],
    fisher: Any,
    *,
    check_finite: bool = False,
) -> None:
    """Check that ``fisher`` is a ``(p, p)`` matrix for ``p`` parameters.

    ``theta0`` is flattened before its size is taken, so ``p`` is the total
    number of entries in ``theta0``.

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        fisher: Fisher matrix with shape ``(p, p)``.
        check_finite: If ``True``, also require all entries of ``fisher`` to
            be finite.

    Raises:
        ValueError: If ``theta0`` is empty or ``fisher`` does not have shape
            ``(p, p)``.
        FloatingPointError: If ``check_finite=True`` and ``fisher`` contains
            non-finite values.
    """
    flat_theta = np.asarray(theta0, dtype=np.float64).ravel()
    n_params = int(flat_theta.size)
    if n_params == 0:
        raise ValueError(
            f"theta0 must be non-empty 1D; got shape {np.asarray(theta0).shape}."
        )

    fisher_arr = np.asarray(fisher, dtype=np.float64)
    required = (n_params, n_params)
    # Comparing against a 2-tuple also rejects any array that is not 2D.
    if fisher_arr.shape != required:
        raise ValueError(
            f"fisher must have shape {required}; got {fisher_arr.shape}.")

    if check_finite and not np.all(np.isfinite(fisher_arr)):
        raise FloatingPointError("fisher contains non-finite values.")
239
+
240
+
241
def validate_dali_shape(
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    check_finite: bool = False,
) -> None:
    """Check forecast tensor shapes against the number of parameters.

    Two input forms are accepted (the conventions of
    :func:`derivkit.forecasting.get_forecast_tensors`):

    - a dict mapping ``order -> multiplet`` with consecutive integer orders
      starting at 1, or
    - a single multiplet tuple, whose order is inferred from its length and
      the rank of its first entry.

    With ``p = len(theta0)`` the required multiplets are:

    - order 1: ``(F,)`` with ``F`` of shape ``(p, p)``.
    - order 2: ``(D_{(2,1)}, D_{(2,2)})`` with shapes ``(p, p, p)`` and
      ``(p, p, p, p)``.
    - order 3: ``(T_{(3,1)}, T_{(3,2)}, T_{(3,3)})`` with shapes
      ``(p,) * 4``, ``(p,) * 5`` and ``(p,) * 6``.

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        dali: Forecast tensors, either ``dict[int, tuple[...]]`` or a single
            ``tuple[...]`` multiplet.
        check_finite: If ``True``, require all validated arrays to be finite.

    Raises:
        TypeError: If ``dali`` is neither a dict nor a tuple, if a dict key is
            not an ``int``, or if a multiplet is not a tuple.
        ValueError: If ``theta0`` is empty, the dict is empty or its keys are
            not consecutive from 1, a multiplet has the wrong length, the
            order is unsupported, or a tensor has the wrong rank/shape.
        FloatingPointError: If ``check_finite=True`` and a validated array
            contains non-finite values.
    """
    flat_theta = np.asarray(theta0, dtype=np.float64).reshape(-1)
    n_params = int(flat_theta.size)
    if n_params == 0:
        raise ValueError(
            f"theta0 must be non-empty 1D; got shape {np.asarray(theta0).shape}."
        )

    # order -> (required tuple length, error-message prefix, tensor ranks).
    # Order 1 carries no rank list because the Fisher matrix is delegated to
    # ``validate_fisher_shape``.
    order_specs = {
        1: (1, "dali[1] must be a 1-tuple (F,)", None),
        2: (2, "dali[2] must be a 2-tuple (D21, D22)", (3, 4)),
        3: (3, "dali[3] must be a 3-tuple (T31, T32, T33)", (4, 5, 6)),
    }
    # (tuple length, rank of first entry) -> order inferred for a bare tuple.
    tuple_signatures = {(1, 2): 1, (2, 3): 2, (3, 4): 3}

    def _check_tensor(arr_like: Any, *, idx: int, expected_ndim: int) -> None:
        """Require ``arr_like`` to have shape ``(p,) * expected_ndim``.

        Args:
            arr_like: Array-like object to validate.
            idx: Position of the tensor in its multiplet (error messages only).
            expected_ndim: Required tensor rank.

        Raises:
            ValueError: On a rank or shape mismatch.
            FloatingPointError: If ``check_finite`` is set in the enclosing
                scope and the array contains non-finite values.
        """
        tensor = np.asarray(arr_like, dtype=np.float64)
        if tensor.ndim != expected_ndim:
            raise ValueError(
                f"DALI tensor at position {idx} must have ndim"
                f"={expected_ndim}; got ndim={tensor.ndim}."
            )
        expected_shape = (n_params,) * expected_ndim
        if tensor.shape != expected_shape:
            raise ValueError(
                f"DALI tensor at position {idx} must have shape"
                f" {expected_shape}; got {tensor.shape}."
            )
        if check_finite and not np.isfinite(tensor).all():
            raise FloatingPointError(
                f"DALI tensor at position {idx} contains non-finite values."
            )

    def _check_order(order: int, m: Any) -> None:
        """Validate multiplet ``m`` against the structure required for ``order``.

        Args:
            order: Forecast order associated with this multiplet.
            m: Candidate multiplet.

        Raises:
            TypeError: If ``m`` is not a tuple.
            ValueError: If ``order`` is unsupported, the tuple length is wrong,
                or any tensor shape is invalid.
            FloatingPointError: Propagated from the tensor checks.
        """
        if not isinstance(m, tuple):
            raise TypeError(f"dali[order={order}] must be a tuple; got {type(m)}.")

        spec = order_specs.get(order)
        if spec is None:
            raise ValueError(f"Unsupported forecast order={order}. Expected 1, 2, or 3.")

        required_len, prefix, ranks = spec
        if len(m) != required_len:
            raise ValueError(f"{prefix}; got length {len(m)}.")

        if ranks is None:
            validate_fisher_shape(flat_theta, m[0], check_finite=check_finite)
            return

        for position, rank in enumerate(ranks):
            _check_tensor(m[position], idx=position, expected_ndim=rank)

    def _check_bare_tuple(m: tuple[Any, ...]) -> None:
        """Infer the order of a bare multiplet tuple and validate it.

        The order is decided strictly by ``(len(m), ndim(m[0]))``.

        Args:
            m: Candidate multiplet tuple.

        Raises:
            ValueError: If ``m`` is empty or matches no supported structure.
            TypeError/FloatingPointError: Propagated from the order checks.
        """
        if len(m) == 0:
            raise ValueError("DALI tuple must be non-empty.")

        first_ndim = np.asarray(m[0], dtype=np.float64).ndim
        inferred = tuple_signatures.get((len(m), first_ndim))
        if inferred is None:
            raise ValueError(
                "Unrecognized DALI tuple form."
                " Expected (F,)"
                " or (D21,D22) or"
                " (T31,T32,T33)."
            )
        _check_order(inferred, m)

    if isinstance(dali, tuple):
        # A single introduced-at-order multiplet.
        _check_bare_tuple(dali)
        return

    if not isinstance(dali, dict):
        raise TypeError(
            "Invalid DALI type. Expected dict[int, tuple[...]] or tuple[...] "
            "containing an introduced-at-order multiplet."
        )

    # dict[int, tuple[...]]: keys must be ints, start at 1, and be consecutive.
    if len(dali) == 0:
        raise ValueError("DALI dict is empty.")

    for key in dali:
        if not isinstance(key, int):
            raise TypeError(f"DALI dict keys must be int;"
                            f" got {key!r} ({type(key)}).")

    keys_sorted = sorted(dali)
    if keys_sorted[0] != 1:
        raise ValueError(f"DALI dict must start at key=1;"
                         f" got keys {keys_sorted}.")
    if keys_sorted != list(range(1, len(keys_sorted) + 1)):
        raise ValueError(
            f"DALI dict keys must be consecutive 1..K;"
            f" got keys {keys_sorted}."
        )

    for order in keys_sorted:
        _check_order(order, dali[order])
458
+
459
+
460
def resolve_dali_introduced_multiplet(
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    forecast_order: int | None = None,
    check_finite: bool = False,
) -> tuple[int, tuple[NDArray[np.float64], ...]]:
    """Return ``(order, multiplet)`` from any accepted forecast tensor input.

    The input is first validated with ``validate_dali_shape`` and may be:

    - a dict mapping ``order -> multiplet`` for consecutive orders starting
      at 1, or
    - a single multiplet tuple.

    For a dict, the highest available order is used unless ``forecast_order``
    is given, in which case that key must exist. For a tuple, the order is
    inferred from the tuple length and the rank of its first entry, and a
    supplied ``forecast_order`` must agree with it.

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        dali: Forecast tensors in one of the accepted forms.
        forecast_order: Optional order selector (dict) or consistency check
            (tuple).
        check_finite: If ``True``, require the validated arrays to be finite.

    Returns:
        Tuple ``(order, multiplet)`` where ``multiplet`` is a tuple of
        float64 arrays.

    Raises:
        TypeError/ValueError/FloatingPointError: If ``dali`` fails
            validation, the requested order is not available, or it does not
            match the order inferred from a tuple.
        RuntimeError: If a validated tuple has no recognizable structure
            (not expected to happen).
    """
    flat_theta = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if flat_theta.size == 0:
        raise ValueError(
            f"theta0 must be non-empty 1D;"
            f" got shape {np.asarray(theta0).shape}."
        )

    validate_dali_shape(flat_theta, dali, check_finite=check_finite)

    def _as_float_tuple(tensors: Any) -> tuple[NDArray[np.float64], ...]:
        """Convert every entry of ``tensors`` to a float64 array."""
        return tuple(np.asarray(t, dtype=np.float64) for t in tensors)

    if isinstance(dali, dict):
        available = sorted(dali)
        chosen = available[-1] if forecast_order is None else int(forecast_order)
        if chosen in dali:
            return chosen, _as_float_tuple(dali[chosen])
        raise ValueError(f"forecast_order={chosen} "
                         f"not in DALI dict keys {available}.")

    # Tuple input: the order follows strictly from (length, rank of first entry).
    signature = (len(dali), np.asarray(dali[0], dtype=np.float64).ndim)
    order = {(1, 2): 1, (2, 3): 2, (3, 4): 3}.get(signature)
    if order is None:
        # Should be unreachable because validate_dali_shape already enforced it.
        raise RuntimeError("internal error: could not infer order from validated tuple.")

    if forecast_order is not None:
        requested = int(forecast_order)
        if requested != order:
            raise ValueError(
                f"forecast_order={requested} does not match inferred order={order}."
            )

    return order, _as_float_tuple(dali)
534
+
535
+
536
def resolve_dali_assembled_multiplet(
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    forecast_order: int | None = None,
    check_finite: bool = False,
) -> tuple[int, tuple[NDArray[np.float64], ...]]:
    """Return ``(order, multiplet)`` with the multiplet assembled up to ``order``.

    The input is first validated with ``validate_dali_shape`` and may be:

    - ``dict[int, tuple[...]]``: per-order "introduced-at-order" multiplets.
    - ``tuple[...]``: a single introduced-at-order multiplet.

    Assembled multiplets concatenate all orders up to the selected one:

    - order 1: ``(F,)``
    - order 2: ``(F, D1, D2)``
    - order 3: ``(F, D1, D2, T1, T2, T3)``

    Args:
        theta0: Fiducial parameter vector with shape ``(p,)``.
        dali: Forecast tensors in one of the accepted forms.
        forecast_order: Order to assemble. For a dict it defaults to the
            highest key; for a tuple it may only be ``None`` or 1.
        check_finite: If ``True``, require the validated arrays to be finite.

    Returns:
        Tuple ``(order, multiplet)`` where ``multiplet`` is a tuple of
        float64 arrays.

    Raises:
        TypeError/ValueError/FloatingPointError: If ``dali`` fails
            validation, the requested order is unavailable, or a tuple input
            is used for an order above 1 (such tuples do not contain the
            Fisher matrix, so the dict form is required).
        RuntimeError: If a validated tuple has no recognizable structure
            (not expected to happen).
    """
    flat_theta = np.asarray(theta0, dtype=np.float64).reshape(-1)
    if flat_theta.size == 0:
        raise ValueError(
            f"theta0 must be non-empty 1D; got shape {np.asarray(theta0).shape}."
        )

    validate_dali_shape(flat_theta, dali, check_finite=check_finite)

    if isinstance(dali, dict):
        available = sorted(dali)
        chosen = available[-1] if forecast_order is None else int(forecast_order)
        if chosen not in dali:
            raise ValueError(f"forecast_order={chosen} not in DALI dict keys {available}.")
        if chosen not in (1, 2, 3):
            raise ValueError(f"forecast_order must be 1, 2, or 3; got {chosen}.")

        # Accumulate Fisher first, then each higher order, stopping at the
        # selected one. ``chosen`` is in (1, 2, 3), so the loop always returns.
        assembled: list[NDArray[np.float64]] = []
        for level in (1, 2, 3):
            assembled.extend(np.asarray(t, dtype=np.float64) for t in dali[level])
            if level == chosen:
                return level, tuple(assembled)

    # Tuple input: only Fisher-only tuples can be assembled, because
    # introduced-at-order tuples of order > 1 carry no Fisher matrix.
    signature = (len(dali), np.asarray(dali[0], dtype=np.float64).ndim)

    if signature == (1, 2):
        if forecast_order is not None and int(forecast_order) != 1:
            raise ValueError(
                "forecast_order>1 requires the dict form from get_forecast_tensors "
                "(tuple multiplets do not include Fisher for order>1)."
            )
        return 1, (np.asarray(dali[0], dtype=np.float64),)

    if signature in ((2, 3), (3, 4)):
        raise ValueError(
            "Order>1 evaluation requires the dict form from get_forecast_tensors, "
            "because introduced-at-order tuples do not include the Fisher matrix."
        )

    # Should be unreachable because validate_dali_shape already enforced the
    # allowed tuple forms.
    raise RuntimeError("internal error: could not infer order from validated tuple.")
613
+
614
+
615
def validate_square_matrix(a: Any, *, name: str = "matrix") -> NDArray[np.float64]:
    """Validate that ``a`` is a 2D square matrix and return it as a float array.

    Args:
        a: Array-like matrix.
        name: Name used in error messages.

    Returns:
        The input converted to a float64 NumPy array.

    Raises:
        ValueError: If ``a`` is not 2D or not square.
    """
    candidate = np.asarray(a, dtype=np.float64)
    rank = candidate.ndim
    if rank != 2:
        raise ValueError(f"{name} must be 2D; got ndim={rank}.")

    n_rows, n_cols = candidate.shape
    if n_rows != n_cols:
        raise ValueError(f"{name} must be square; got shape={candidate.shape}.")

    return candidate
623
+
624
+
625
def ensure_finite(arr: Any, *, msg: str) -> None:
    """Raise if any value in ``arr`` is not finite.

    Args:
        arr: Input array-like to check.
        msg: Message of the exception raised when a non-finite value is found.

    Raises:
        FloatingPointError: If any value in ``arr`` is non-finite.
    """
    values = np.asarray(arr)
    if np.all(np.isfinite(values)):
        return
    raise FloatingPointError(msg)
637
+
638
+
639
def normalize_theta(theta0: Any) -> NDArray[np.float64]:
    """Convert ``theta0`` to a flattened, non-empty float64 array.

    Inputs of any dimensionality are flattened to 1D.

    Args:
        theta0: Input array-like to validate and convert.

    Returns:
        1D float array.

    Raises:
        ValueError: If ``theta0`` contains no elements.
    """
    flat = np.ravel(np.asarray(theta0, dtype=np.float64))
    if flat.size > 0:
        return flat
    raise ValueError("theta0 must be a non-empty 1D array.")
655
+
656
+
657
def validate_theta_1d_finite(theta: Any, *, name: str = "theta") -> NDArray[np.float64]:
    """Validate a finite, non-empty 1D parameter vector.

    Args:
        theta: Array-like parameter vector.
        name: Name used in error messages.

    Returns:
        A 1D float64 NumPy array containing the validated parameter vector.

    Raises:
        ValueError: If ``theta`` is not 1D, is empty, or contains non-finite
            values.
    """
    vec = np.asarray(theta, dtype=float)

    # Checks run in order: dimensionality, emptiness, finiteness.
    if vec.ndim != 1:
        problem = f"{name} must be 1D; got shape {vec.shape}."
    elif vec.size == 0:
        problem = f"{name} must be non-empty."
    elif not np.isfinite(vec).all():
        problem = f"{name} contains non-finite values."
    else:
        return vec.astype(np.float64, copy=False)

    raise ValueError(problem)
678
+
679
+
680
def validate_square_matrix_finite(
    a: Any, *, name: str = "matrix"
) -> NDArray[np.float64]:
    """Validate a finite 2D square matrix and return it as float64.

    Args:
        a: Array-like matrix.
        name: Name used in error messages.

    Returns:
        A 2D float64 NumPy array containing the validated square matrix.

    Raises:
        ValueError: If ``a`` is not 2D, is not square, or contains non-finite
            values.
    """
    mat = np.asarray(a, dtype=float)

    # Checks run in order: dimensionality, squareness, finiteness.
    if mat.ndim != 2:
        problem = f"{name} must be 2D; got ndim={mat.ndim}."
    elif mat.shape[0] != mat.shape[1]:
        problem = f"{name} must be square; got shape {mat.shape}."
    elif not np.isfinite(mat).all():
        problem = f"{name} contains non-finite values."
    else:
        return mat.astype(np.float64, copy=False)

    raise ValueError(problem)
703
+
704
+
705
def resolve_covariance_input(
    cov: NDArray[np.float64]
    | Callable[[NDArray[np.float64]], NDArray[np.float64]],
    *,
    theta0: NDArray[np.float64],
    validate: Callable[[Any], NDArray[np.float64]],
) -> tuple[NDArray[np.float64], Callable[[NDArray[np.float64]], NDArray[np.float64]] | None]:
    """Resolve a fixed or callable covariance into ``(cov0, cov_fn)``.

    Args:
        cov: Covariance input. Either:

            - a fixed covariance array, in which case the returned callable
              is ``None``; or
            - a callable taking ``theta`` and returning a covariance array,
              in which case it is evaluated at ``theta0`` to obtain ``cov0``
              and is itself returned as ``cov_fn``.

        theta0: Fiducial parameter vector. Only used when ``cov`` is
            callable; ignored for fixed covariance arrays.
        validate: Function applied to the covariance array (the fixed input
            or the callable's output at ``theta0``); its return value becomes
            ``cov0``.

    Returns:
        A tuple ``(cov0, cov_fn)``:

        - ``cov0``: The result of ``validate`` on the covariance at
          ``theta0`` (or on the fixed covariance).
        - ``cov_fn``: The covariance callable if one was provided,
          otherwise ``None``.
    """
    cov_fn = cov if callable(cov) else None
    raw_cov = cov if cov_fn is None else cov_fn(theta0)
    return validate(raw_cov), cov_fn
746
+
747
+
748
def flatten_matrix_c_order(
    cov_function: Callable[[NDArray[np.float64]], NDArray[np.float64]],
    theta: NDArray[np.float64],
    *,
    n_observables: int,
) -> NDArray[np.float64]:
    """Evaluate a covariance function, check its shape, and flatten the result.

    The matrix returned by ``cov_function(theta)`` is passed through
    ``validate_covariance_matrix_shape``, required to have shape
    ``(n_observables, n_observables)``, and flattened in row-major ("C")
    order.

    Args:
        cov_function: Callable that takes a parameter vector and returns a
            covariance matrix.
        theta: Parameter vector at which to evaluate the covariance function.
        n_observables: Number of observables; sets the required matrix shape.

    Returns:
        A 1D float64 NumPy array holding the flattened covariance matrix.

    Raises:
        ValueError: If the output of ``cov_function`` does not have the
            expected shape.
    """
    required_shape = (n_observables, n_observables)
    matrix = validate_covariance_matrix_shape(cov_function(theta))

    if matrix.shape != required_shape:
        raise ValueError(
            f"cov_function(theta) must return shape {required_shape}; got {matrix.shape}."
        )

    return np.ravel(np.asarray(matrix, dtype=np.float64), order="C")
777
+
778
+
779
def require_callable(
    func: Callable[..., Any] | None,
    *,
    name: str = "function",
    context: str | None = None,
    hint: str | None = None,
) -> Callable[..., Any]:
    """Return ``func`` unchanged, or raise if it was not provided.

    Only ``None`` is rejected; any other value is returned as-is without a
    ``callable()`` check.

    Args:
        func: Callable to validate.
        name: Name shown in the error message.
        context: Optional context prefix (e.g. "ForecastKit.fisher").
        hint: Optional hint appended to the error message.

    Returns:
        The input callable.

    Raises:
        ValueError: If ``func`` is None.
    """
    if func is not None:
        return func

    # Build "<context>: <name> must be provided. <hint>", skipping empty parts.
    message = f"{name} must be provided."
    if context:
        message = f"{context}: {message}"
    if hint:
        message = f"{message} {hint}"
    raise ValueError(message)