derivkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. derivkit/__init__.py +22 -0
  2. derivkit/calculus/__init__.py +17 -0
  3. derivkit/calculus/calculus_core.py +152 -0
  4. derivkit/calculus/gradient.py +97 -0
  5. derivkit/calculus/hessian.py +528 -0
  6. derivkit/calculus/hyper_hessian.py +296 -0
  7. derivkit/calculus/jacobian.py +156 -0
  8. derivkit/calculus_kit.py +128 -0
  9. derivkit/derivative_kit.py +315 -0
  10. derivkit/derivatives/__init__.py +6 -0
  11. derivkit/derivatives/adaptive/__init__.py +5 -0
  12. derivkit/derivatives/adaptive/adaptive_fit.py +238 -0
  13. derivkit/derivatives/adaptive/batch_eval.py +179 -0
  14. derivkit/derivatives/adaptive/diagnostics.py +325 -0
  15. derivkit/derivatives/adaptive/grid.py +333 -0
  16. derivkit/derivatives/adaptive/polyfit_utils.py +513 -0
  17. derivkit/derivatives/adaptive/spacing.py +66 -0
  18. derivkit/derivatives/adaptive/transforms.py +245 -0
  19. derivkit/derivatives/autodiff/__init__.py +1 -0
  20. derivkit/derivatives/autodiff/jax_autodiff.py +95 -0
  21. derivkit/derivatives/autodiff/jax_core.py +217 -0
  22. derivkit/derivatives/autodiff/jax_utils.py +146 -0
  23. derivkit/derivatives/finite/__init__.py +5 -0
  24. derivkit/derivatives/finite/batch_eval.py +91 -0
  25. derivkit/derivatives/finite/core.py +84 -0
  26. derivkit/derivatives/finite/extrapolators.py +511 -0
  27. derivkit/derivatives/finite/finite_difference.py +247 -0
  28. derivkit/derivatives/finite/stencil.py +206 -0
  29. derivkit/derivatives/fornberg.py +245 -0
  30. derivkit/derivatives/local_polynomial_derivative/__init__.py +1 -0
  31. derivkit/derivatives/local_polynomial_derivative/diagnostics.py +90 -0
  32. derivkit/derivatives/local_polynomial_derivative/fit.py +199 -0
  33. derivkit/derivatives/local_polynomial_derivative/local_poly_config.py +95 -0
  34. derivkit/derivatives/local_polynomial_derivative/local_polynomial_derivative.py +205 -0
  35. derivkit/derivatives/local_polynomial_derivative/sampling.py +72 -0
  36. derivkit/derivatives/tabulated_model/__init__.py +1 -0
  37. derivkit/derivatives/tabulated_model/one_d.py +247 -0
  38. derivkit/forecast_kit.py +783 -0
  39. derivkit/forecasting/__init__.py +1 -0
  40. derivkit/forecasting/dali.py +78 -0
  41. derivkit/forecasting/expansions.py +486 -0
  42. derivkit/forecasting/fisher.py +298 -0
  43. derivkit/forecasting/fisher_gaussian.py +171 -0
  44. derivkit/forecasting/fisher_xy.py +357 -0
  45. derivkit/forecasting/forecast_core.py +313 -0
  46. derivkit/forecasting/getdist_dali_samples.py +429 -0
  47. derivkit/forecasting/getdist_fisher_samples.py +235 -0
  48. derivkit/forecasting/laplace.py +259 -0
  49. derivkit/forecasting/priors_core.py +860 -0
  50. derivkit/forecasting/sampling_utils.py +388 -0
  51. derivkit/likelihood_kit.py +114 -0
  52. derivkit/likelihoods/__init__.py +1 -0
  53. derivkit/likelihoods/gaussian.py +136 -0
  54. derivkit/likelihoods/poisson.py +176 -0
  55. derivkit/utils/__init__.py +13 -0
  56. derivkit/utils/concurrency.py +213 -0
  57. derivkit/utils/extrapolation.py +254 -0
  58. derivkit/utils/linalg.py +513 -0
  59. derivkit/utils/logger.py +26 -0
  60. derivkit/utils/numerics.py +262 -0
  61. derivkit/utils/sandbox.py +74 -0
  62. derivkit/utils/types.py +15 -0
  63. derivkit/utils/validate.py +811 -0
  64. derivkit-1.0.0.dist-info/METADATA +50 -0
  65. derivkit-1.0.0.dist-info/RECORD +68 -0
  66. derivkit-1.0.0.dist-info/WHEEL +5 -0
  67. derivkit-1.0.0.dist-info/licenses/LICENSE +21 -0
  68. derivkit-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1 @@
1
+ """Forecasting utilities."""
@@ -0,0 +1,78 @@
1
+ """DALI forecasting utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable
6
+
7
+ import numpy as np
8
+
9
+ from derivkit.forecasting.forecast_core import get_forecast_tensors
10
+ from derivkit.utils.types import Array, ArrayLike1D, ArrayLike2D, FloatArray
11
+
12
+ __all__ = [
13
+ "build_dali",
14
+ ]
15
+
16
+
17
def build_dali(
    function: Callable[[ArrayLike1D], np.floating | Array],
    theta0: ArrayLike1D,
    cov: ArrayLike2D,
    *,
    method: str | None = None,
    forecast_order: int = 2,
    n_workers: int = 1,
    **dk_kwargs: Any,
) -> dict[int, tuple[FloatArray, ...]]:
    """Build the DALI forecast tensors of a model up to the requested order.

    This is a thin convenience wrapper: every argument is forwarded unchanged to
    :func:`derivkit.forecasting.forecast_core.get_forecast_tensors`.

    Args:
        function: Scalar- or vector-valued model to differentiate. It receives a
            list/array of parameter values and returns either a scalar or a
            :class:`np.ndarray` of observables.
        theta0: Expansion point, a 1D parameter vector of length ``p`` (the
            number of parameters), at which derivatives are evaluated.
        cov: Covariance matrix of the observables, a square matrix of shape
            ``(n_observables, n_observables)`` where ``n_observables`` is the
            number of observables returned by ``function``.
        method: Derivative method name or alias (e.g. ``"adaptive"``,
            ``"finite"``). ``None`` selects the
            :class:`derivkit.derivative_kit.DerivativeKit` default.
        forecast_order: Requested forecast order. Supported values are listed in
            :data:`derivkit.forecasting.forecast_core.SUPPORTED_FORECAST_ORDERS`.
        n_workers: Number of workers for per-parameter parallelization/threads.
            Defaults to ``1`` (serial); inner batch evaluation stays serial to
            avoid oversubscription.
        **dk_kwargs: Extra keyword arguments forwarded to
            :class:`derivkit.calculus_kit.CalculusKit`.

    Returns:
        A dict mapping each DALI order to the multiplet of tensors introduced at
        that order:

        - order 1: ``(F_{(1,1)},)`` -- the DALI singlet, i.e. the Fisher matrix
        - order 2: ``(D_{(2,1)}, D_{(2,2)})`` -- the DALI doublet tensors
        - order 3: ``(T_{(3,1)}, T_{(3,2)}, T_{(3,3)})`` -- the DALI triplet tensors

        ``D_{(k,l)}`` and ``T_{(k,l)}`` are contractions of the ``k``-th and
        ``l``-th order derivatives through the inverse covariance. Every tensor
        axis has length ``p = len(theta0)``; the tensors added at order ``k`` have
        ranks ``k+1`` through ``2*k``.
    """
    # Named options and the free-form kwargs are merged into a single mapping;
    # the named parameters can never collide with keys inside ``dk_kwargs``.
    forwarded_options = {
        "method": method,
        "forecast_order": forecast_order,
        "n_workers": n_workers,
        **dk_kwargs,
    }
    return get_forecast_tensors(function, theta0, cov, **forwarded_options)
@@ -0,0 +1,486 @@
1
+ """Utilities for evaluating Fisher and DALI likelihood expansions.
2
+
3
+ This module provides functional helpers to evaluate approximate likelihoods
4
+ (or posterior) surfaces from forecast tensors.
5
+
6
+ Conventions
7
+ -----------
8
+
9
+ This module uses a single convention throughout:
10
+
11
+ - ``delta_chi2`` is defined from the displacement ``d = theta - theta0``.
12
+ - The log posterior is returned (up to an additive constant) as::
13
+
14
+ log p(theta) = logprior(theta) - 0.5 * delta_chi2(theta)
15
+
16
+ With the forecast tensors returned by :func:`derivkit.forecasting.forecast_core.get_forecast_tensors`
17
+ (using the introduced-at-order convention):
18
+
19
+ - ``dali[1] == (F,)``
20
+ - ``dali[2] == (D1, D2)``
21
+ - ``dali[3] == (T1, T2, T3)``
22
+
23
+ the DALI ``delta_chi2`` is:
24
+
25
+ - order 1 (Fisher): ``d.T @ F @ d``
26
+ - order 2 (doublet): add ``(1/3) D1[d,d,d] + (1/12) D2[d,d,d,d]``
27
+ - order 3 (triplet): add ``(1/3) T1[d^4] + (1/6) T2[d^5] + (1/36) T3[d^6]``
28
+
29
+ GetDist convention
30
+ ------------------
31
+
32
+ GetDist expects ``loglikes`` to be the negative log posterior, up to a constant.
33
+ Since this module defines::
34
+
35
+ log p = logprior - 0.5 * delta_chi2 + const
36
+
37
+ a compatible choice for GetDist is::
38
+
39
+ loglikes = -logprior + 0.5 * delta_chi2
40
+
41
+ (optionally shifted by an additive constant for numerical stability).
42
+ """
43
+
44
+ from __future__ import annotations
45
+
46
+ from typing import Any, Callable, Sequence
47
+
48
+ import numpy as np
49
+ from numpy.typing import NDArray
50
+
51
+ from derivkit.forecasting.forecast_core import SUPPORTED_FORECAST_ORDERS
52
+ from derivkit.forecasting.priors_core import build_prior
53
+ from derivkit.utils.validate import (
54
+ validate_dali_shape,
55
+ validate_fisher_shape,
56
+ )
57
+
58
+ __all__ = [
59
+ "build_subspace",
60
+ "build_delta_chi2_fisher",
61
+ "build_delta_chi2_dali",
62
+ "build_logposterior_fisher",
63
+ "build_logposterior_dali",
64
+ ]
65
+
66
+
67
+ def _validate_and_normalize_idx(idx: Sequence[int], *, p: int) -> list[int]:
68
+ """Validates and normalizes a sequence of parameter indices.
69
+
70
+ Args:
71
+ idx: Sequence of parameter indices.
72
+ p: Total number of parameters.
73
+
74
+ Returns:
75
+ Indices as a list of Python ``int``.
76
+
77
+ Raises:
78
+ TypeError: If any entry of ``idx`` is not an integer.
79
+ IndexError: If any index is out of bounds for ``p``.
80
+ """
81
+ idx_list = list(idx)
82
+ if not all(isinstance(i, (int, np.integer)) for i in idx_list):
83
+ raise TypeError("idx must contain integer indices")
84
+ if any((i < 0) or (i >= p) for i in idx_list):
85
+ raise IndexError(f"idx contains out-of-bounds indices for p={p}: {idx_list}")
86
+ return idx_list
87
+
88
+
89
+ def _slice_param_tensor(t: NDArray[np.floating], idx: list[int]) -> NDArray[np.float64]:
90
+ """Slices a parameter-space tensor along all axes.
91
+
92
+ This helper assumes ``t`` is a tensor whose every axis indexes parameters,
93
+ e.g. Fisher ``(p, p)``, a cubic tensor ``(p, p, p)``, etc.
94
+
95
+ Args:
96
+ t: Tensor to slice.
97
+ idx: Parameter indices to keep.
98
+
99
+ Returns:
100
+ Sliced tensor as ``float64``.
101
+ """
102
+ t64 = np.asarray(t, np.float64)
103
+ sl = np.ix_(*([idx] * t64.ndim))
104
+ return t64[sl]
105
+
106
+
107
def build_subspace(
    idx: Sequence[int],
    *,
    theta0: NDArray[np.floating],
    fisher: NDArray[np.floating] | None = None,
    dali: dict[int, tuple[NDArray[np.floating], ...]] | None = None,
) -> dict[str, Any]:
    """Extract a parameter subspace of a Fisher or DALI expansion.

    The result is a *slice* through parameter space: parameters not listed in
    ``idx`` stay fixed at their expansion values. This is not a marginalization.

    Exactly one of ``fisher`` or ``dali`` must be supplied:

    - Fisher: ``fisher`` has shape ``(p, p)``; the result is
      ``{"theta0": theta0_sub, "fisher": fisher_sub}``.
    - DALI: ``dali`` is the dict returned by
      :func:`derivkit.forecasting.forecast_core.get_forecast_tensors` (in the
      introduced-at-order convention); the result is
      ``{"theta0": theta0_sub, "dali": dali_sub}``, where every tensor is sliced
      along all of its axes.

    Args:
        idx: Parameter indices to keep.
        theta0: Expansion point of shape ``(p,)`` (flattened internally).
        fisher: Fisher matrix of shape ``(p, p)``.
        dali: Forecast tensors as a dict mapping ``order -> multiplet``.

    Returns:
        A dict that always contains ``"theta0"`` plus either ``"fisher"`` or
        ``"dali"``, depending on which input was supplied. Sliced arrays are
        ``float64``.

    Raises:
        ValueError: If not exactly one of ``fisher`` or ``dali`` is supplied, or
            if the supplied arrays have incompatible shapes.
        TypeError: If ``idx`` contains non-integers, or ``dali`` is not a dict.
        IndexError: If any index in ``idx`` is out of bounds.
    """
    expansion_point = np.asarray(theta0, np.float64).reshape(-1)
    n_params = int(expansion_point.shape[0])

    fisher_given = fisher is not None
    dali_given = dali is not None
    if fisher_given == dali_given:
        raise ValueError("Provide exactly one of `fisher` or `dali`.")

    if fisher_given:
        fisher_full = np.asarray(fisher, np.float64)
        validate_fisher_shape(expansion_point, fisher_full)
        keep = _validate_and_normalize_idx(idx, p=n_params)
        return {
            "theta0": expansion_point[keep],
            "fisher": fisher_full[np.ix_(keep, keep)],
        }

    # Only the DALI input remains at this point.
    if not isinstance(dali, dict):
        raise TypeError("dali must be the dict form returned by get_forecast_tensors.")

    validate_dali_shape(expansion_point, dali)
    keep = _validate_and_normalize_idx(idx, p=n_params)

    sliced: dict[int, tuple[NDArray[np.float64], ...]] = {
        int(order): tuple(_slice_param_tensor(tensor, keep) for tensor in multiplet)
        for order, multiplet in dali.items()
    }
    return {"theta0": expansion_point[keep], "dali": sliced}
174
+
175
+
176
def build_delta_chi2_fisher(
    theta: NDArray[np.floating],
    theta0: NDArray[np.floating],
    fisher: NDArray[np.floating],
) -> float:
    """Compute the displacement ``delta_chi2`` under the Fisher approximation.

    Evaluates the quadratic form ``d.T @ F @ d`` with ``d = theta - theta0``.

    Args:
        theta: Evaluation point in parameter space, i.e. the trial parameter
            vector at which the Fisher expansion is evaluated. Flattened to 1D.
        theta0: Expansion point (reference parameter vector). The Fisher matrix
            is assumed to have been computed here, and the expansion is taken in
            the displacement ``theta - theta0``. Flattened to 1D.
        fisher: Fisher matrix of shape ``(p, p)``, with ``p`` the number of
            parameters.

    Returns:
        The scalar ``delta_chi2`` between ``theta`` and ``theta0``.

    Raises:
        ValueError: If ``theta`` and ``theta0`` do not have the same number of
            elements. Any error raised by ``validate_fisher_shape`` for an
            incompatible ``fisher`` is propagated.
    """
    theta = np.asarray(theta, float).reshape(-1)
    theta0 = np.asarray(theta0, float).reshape(-1)
    fisher = np.asarray(fisher, float)
    validate_fisher_shape(theta0, fisher)

    # Without this check a scalar (or otherwise broadcastable) theta would be
    # silently broadcast against theta0 and yield a meaningless chi2.
    if theta.shape != theta0.shape:
        raise ValueError(
            f"theta and theta0 must have the same shape; got {theta.shape} and {theta0.shape}."
        )

    displacement = theta - theta0
    return float(displacement @ fisher @ displacement)
201
+
202
+
203
+ def _resolve_logprior(
204
+ *,
205
+ prior_terms: Sequence[tuple[str, dict[str, Any]] | dict[str, Any]] | None,
206
+ prior_bounds: Sequence[tuple[float | None, float | None]] | None,
207
+ logprior: Callable[[NDArray[np.floating]], float] | None,
208
+ ) -> Callable[[NDArray[np.floating]], float] | None:
209
+ """Determines which log-prior to use for likelihoods expansion evaluation.
210
+
211
+ This helper allows callers to specify a prior in one of two ways: either by passing
212
+ a pre-built ``logprior(theta)`` callable directly, or by providing a lightweight
213
+ prior specification (``prior_terms`` and/or ``prior_bounds``) that is compiled
214
+ internally using :func:`derivkit.forecasting.priors.core.build_prior`.
215
+
216
+ Only one of these input styles may be used at a time. Providing both results in a
217
+ ``ValueError``. If neither is provided, the function returns ``None``, indicating
218
+ that no prior is applied.
219
+
220
+ Args:
221
+ prior_terms: Prior term specification passed to
222
+ :func:`derivkit.forecasting.priors.core.build_prior`.
223
+ prior_bounds: Global hard bounds passed to
224
+ :func:`derivkit.forecasting.priors.core.build_prior`.
225
+ logprior: Optional custom log-prior callable. If it returns a non-finite value,
226
+ the posterior is treated as zero at that point and the function returns ``-np.inf``.
227
+
228
+ Returns:
229
+ A function that computes the log-prior contribution to the posterior, or
230
+ ``None`` if the likelihoods should be evaluated without a prior.
231
+ """
232
+ if logprior is not None and (prior_terms is not None or prior_bounds is not None):
233
+ raise ValueError("Use either `logprior` or (`prior_terms`/`prior_bounds`), not both.")
234
+
235
+ if logprior is None and (prior_terms is not None or prior_bounds is not None):
236
+ return build_prior(terms=prior_terms, bounds=prior_bounds)
237
+
238
+ return logprior
239
+
240
+
241
def build_logposterior_fisher(
    theta: NDArray[np.floating],
    theta0: NDArray[np.floating],
    fisher: NDArray[np.floating],
    *,
    prior_terms: Sequence[tuple[str, dict[str, Any]] | dict[str, Any]] | None = None,
    prior_bounds: Sequence[tuple[float | None, float | None]] | None = None,
    logprior: Callable[[NDArray[np.floating]], float] | None = None,
) -> float:
    """Compute the log posterior under the Fisher approximation.

    The Fisher approximation uses a purely quadratic ``delta_chi2`` surface::

        delta_chi2 = d.T @ F @ d,   with d = theta - theta0

    and the returned value is::

        log p(theta) = logprior(theta) - 0.5 * delta_chi2

    It is defined up to an additive constant in log space, i.e. an overall
    multiplicative normalization of the posterior density. If no prior is
    supplied, ``logprior`` is taken as ``0`` (flat prior, no hard cutoffs).

    This uses the same ``-0.5 * delta_chi2`` normalization as
    :func:`build_logposterior_dali`, whose ``delta_chi2`` is computed by
    :func:`build_delta_chi2_dali`. For the Fisher case the likelihood is Gaussian
    in the parameters, so fixed ``delta_chi2`` values correspond to fixed
    probability content (e.g. 68%, 95%).

    Args:
        theta: Evaluation point in parameter space (trial parameter vector).
            Flattened to 1D.
        theta0: Expansion point (reference parameter vector) at which the Fisher
            matrix was computed; the expansion is taken in ``theta - theta0``.
            Flattened to 1D.
        fisher: Fisher matrix of shape ``(p, p)``, with ``p`` the number of
            parameters.
        prior_terms: Prior term specification passed to
            :func:`derivkit.forecasting.priors_core.build_prior`.
        prior_bounds: Global hard bounds passed to
            :func:`derivkit.forecasting.priors_core.build_prior`.
        logprior: Optional custom log-prior callable. Cannot be combined with
            ``prior_terms``/``prior_bounds``.

    Returns:
        Scalar log posterior value, defined up to an additive constant. If the
        prior evaluates to a non-finite value, ``-np.inf`` is returned.

    Raises:
        ValueError: If ``theta`` and ``theta0`` do not have the same number of
            elements, or if ``logprior`` is combined with a prior specification.
            Any error raised by ``validate_fisher_shape`` is propagated.
    """
    theta = np.asarray(theta, float).reshape(-1)
    theta0 = np.asarray(theta0, float).reshape(-1)
    fisher = np.asarray(fisher, float)
    validate_fisher_shape(theta0, fisher)

    # Guard against silent broadcasting of a mis-shaped theta (e.g. a scalar).
    if theta.shape != theta0.shape:
        raise ValueError(
            f"theta and theta0 must have the same shape; got {theta.shape} and {theta0.shape}."
        )

    logprior_fn = _resolve_logprior(prior_terms=prior_terms, prior_bounds=prior_bounds, logprior=logprior)

    logprior_val = 0.0
    if logprior_fn is not None:
        logprior_val = float(logprior_fn(theta))
        # A non-finite prior means zero posterior density at this point.
        if not np.isfinite(logprior_val):
            return -np.inf

    displacement = theta - theta0
    chi2 = float(displacement @ fisher @ displacement)
    return logprior_val - 0.5 * chi2
310
+
311
+
312
def build_delta_chi2_dali(
    theta: NDArray[np.floating],
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    forecast_order: int | None = 2,
) -> float:
    """Compute ``delta_chi2`` under the DALI approximation.

    This evaluates a scalar ``delta_chi2`` from the displacement
    ``d = theta - theta0`` using forecast tensors returned by
    :func:`derivkit.forecasting.forecast_core.get_forecast_tensors`.

    The input must be the dict form using the introduced-at-order convention:

    - ``dali[1] == (F,)`` with ``F`` of shape ``(p, p)``
    - ``dali[2] == (D1, D2)`` with shapes ``(p, p, p)`` and ``(p, p, p, p)``
    - ``dali[3] == (T1, T2, T3)`` with shapes ``(p,)*4``, ``(p,)*5``, ``(p,)*6``

    The evaluated quantity is:

    - order 2: ``d.T @ F @ d + (1/3) D1[d^3] + (1/12) D2[d^4]``
    - order 3: order 2 plus ``(1/3) T1[d^4] + (1/6) T2[d^5] + (1/36) T3[d^6]``.

    Args:
        theta: Evaluation point in parameter space. Flattened to 1D.
        theta0: Expansion point (fiducial parameters). Flattened to 1D.
        dali: Forecast tensors as a dict.
        forecast_order: Maximum order to include (2 or 3). If ``None``, uses the
            highest key in ``dali``, which must then be at least 2.

    Returns:
        Scalar ``delta_chi2``.

    Raises:
        TypeError: If ``dali`` is not a dict, or ``forecast_order`` cannot be
            converted to ``int``.
        ValueError: If ``theta`` and ``theta0`` differ in shape; if the chosen
            order is not in ``SUPPORTED_FORECAST_ORDERS``, is below 2, or is above
            3 (the highest order evaluated here); or if required tensor orders
            are missing or have incompatible shapes.
    """
    theta = np.asarray(theta, float).reshape(-1)
    theta0 = np.asarray(theta0, float).reshape(-1)

    if theta.shape != theta0.shape:
        raise ValueError(
            f"theta and theta0 must have the same shape; got {theta.shape} and {theta0.shape}.")

    # DALI evaluation requires the dict form (needs Fisher inside dali[1]).
    if not isinstance(dali, dict):
        raise TypeError(
            "build_delta_chi2_dali expects the dict form from get_forecast_tensors "
            "(needs dali[1]=(F,) plus higher-order tensors)."
        )

    validate_dali_shape(theta0, dali)

    # Choose order: explicit value, or the highest order present in ``dali``.
    if forecast_order is None:
        chosen = max(dali.keys())
    else:
        try:
            chosen = int(forecast_order)
        except (TypeError, ValueError, OverflowError) as e:
            raise TypeError(
                f"forecast_order must be an int or None;"
                f" got {type(forecast_order)}.") from e

    if chosen not in SUPPORTED_FORECAST_ORDERS:
        raise ValueError(
            f"forecast_order={chosen} is not supported."
            f" Supported values: {SUPPORTED_FORECAST_ORDERS}."
        )

    if chosen < 2:
        raise ValueError(
            "build_delta_chi2_dali requires forecast_order >= 2. "
            "Use your Fisher delta-chi2 function for forecast_order=1."
        )

    # Only the doublet and triplet terms are implemented below. Refuse higher
    # orders instead of silently truncating the expansion at order 3.
    if chosen > 3:
        raise ValueError(
            f"build_delta_chi2_dali evaluates forecast_order <= 3; got {chosen}."
        )

    # Require the needed keys exist
    if 1 not in dali or 2 not in dali:
        raise ValueError(
            "dali must contain keys 1 and 2 (Fisher + doublet tensors).")
    if chosen >= 3 and 3 not in dali:
        raise ValueError(
            "forecast_order=3 requires dali to contain key 3 (triplet tensors).")

    fisher = np.asarray(dali[1][0], dtype=np.float64)
    d = theta - theta0
    chi2 = float(d @ fisher @ d)

    # Doublet: full contractions of D1 (rank 3) and D2 (rank 4) with d.
    d1 = np.asarray(dali[2][0], dtype=np.float64)
    d2 = np.asarray(dali[2][1], dtype=np.float64)
    chi2 += (1.0 / 3.0) * float(np.einsum("ijk,i,j,k->", d1, d, d, d))
    chi2 += (1.0 / 12.0) * float(np.einsum("ijkl,i,j,k,l->", d2, d, d, d, d))

    if chosen == 2:
        return chi2

    # Triplet: full contractions of T1, T2, T3 (ranks 4, 5, 6) with d.
    t1 = np.asarray(dali[3][0], dtype=np.float64)
    t2 = np.asarray(dali[3][1], dtype=np.float64)
    t3 = np.asarray(dali[3][2], dtype=np.float64)

    t1_4 = float(np.einsum("ijkl,i,j,k,l->", t1, d, d, d, d))
    t2_5 = float(np.einsum("ijklm,i,j,k,l,m->", t2, d, d, d, d, d))
    t3_6 = float(np.einsum("ijklmn,i,j,k,l,m,n->", t3, d, d, d, d, d, d))

    chi2 = chi2 + (1.0 / 3.0) * t1_4 + (1.0 / 6.0) * t2_5 + (1.0 / 36.0) * t3_6
    return chi2
424
+
425
+
426
def build_logposterior_dali(
    theta: NDArray[np.floating],
    theta0: NDArray[np.floating],
    dali: Any,
    *,
    forecast_order: int | None = 2,
    prior_terms: Sequence[tuple[str, dict[str, Any]] | dict[str, Any]] | None = None,
    prior_bounds: Sequence[tuple[float | None, float | None]] | None = None,
    logprior: Callable[[NDArray[np.floating]], float] | None = None,
) -> float:
    """Compute the log posterior under the DALI approximation.

    The returned value is::

        log p(theta) = logprior(theta) - 0.5 * delta_chi2(theta)

    where ``delta_chi2`` comes from :func:`build_delta_chi2_dali` applied to the
    dict-form forecast tensors ``dali``. The prior is evaluated first; if it is
    non-finite, ``delta_chi2`` is not computed.

    Args:
        theta: Evaluation point in parameter space. Flattened to 1D.
        theta0: Expansion point (fiducial parameters). Flattened to 1D.
        dali: Forecast tensors as a dict in the introduced-at-order convention.
        forecast_order: Maximum order to include in ``delta_chi2``; ``None``
            selects the highest key in ``dali``.
        prior_terms: Prior term specification passed to :func:`build_prior`.
        prior_bounds: Global hard bounds passed to :func:`build_prior`.
        logprior: Optional custom log-prior callable (mutually exclusive with
            ``prior_terms``/``prior_bounds``).

    Returns:
        Scalar log posterior value, up to an additive constant, or ``-np.inf``
        when the prior evaluates to a non-finite value.

    Raises:
        ValueError: If ``theta`` and ``theta0`` differ in shape, or for the
            conditions reported by the prior resolution / ``delta_chi2``
            evaluation.
        TypeError: If ``dali`` is not a dict.
    """
    flat_theta = np.asarray(theta, float).reshape(-1)
    flat_theta0 = np.asarray(theta0, float).reshape(-1)

    if flat_theta.shape != flat_theta0.shape:
        raise ValueError(
            f"theta and theta0 must have the same shape; got {flat_theta.shape} and {flat_theta0.shape}."
        )

    if not isinstance(dali, dict):
        raise TypeError(
            "build_logposterior_dali expects the dict form from get_forecast_tensors."
        )

    validate_dali_shape(flat_theta0, dali)

    prior_fn = _resolve_logprior(
        prior_terms=prior_terms,
        prior_bounds=prior_bounds,
        logprior=logprior,
    )

    if prior_fn is None:
        log_prior = 0.0
    else:
        log_prior = float(prior_fn(flat_theta))
        # Zero prior density: skip the (potentially expensive) DALI evaluation.
        if not np.isfinite(log_prior):
            return -np.inf

    delta_chi2 = build_delta_chi2_dali(
        flat_theta,
        flat_theta0,
        dali,
        forecast_order=forecast_order,
    )
    return log_prior - 0.5 * delta_chi2