derivkit-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- derivkit/__init__.py +22 -0
- derivkit/calculus/__init__.py +17 -0
- derivkit/calculus/calculus_core.py +152 -0
- derivkit/calculus/gradient.py +97 -0
- derivkit/calculus/hessian.py +528 -0
- derivkit/calculus/hyper_hessian.py +296 -0
- derivkit/calculus/jacobian.py +156 -0
- derivkit/calculus_kit.py +128 -0
- derivkit/derivative_kit.py +315 -0
- derivkit/derivatives/__init__.py +6 -0
- derivkit/derivatives/adaptive/__init__.py +5 -0
- derivkit/derivatives/adaptive/adaptive_fit.py +238 -0
- derivkit/derivatives/adaptive/batch_eval.py +179 -0
- derivkit/derivatives/adaptive/diagnostics.py +325 -0
- derivkit/derivatives/adaptive/grid.py +333 -0
- derivkit/derivatives/adaptive/polyfit_utils.py +513 -0
- derivkit/derivatives/adaptive/spacing.py +66 -0
- derivkit/derivatives/adaptive/transforms.py +245 -0
- derivkit/derivatives/autodiff/__init__.py +1 -0
- derivkit/derivatives/autodiff/jax_autodiff.py +95 -0
- derivkit/derivatives/autodiff/jax_core.py +217 -0
- derivkit/derivatives/autodiff/jax_utils.py +146 -0
- derivkit/derivatives/finite/__init__.py +5 -0
- derivkit/derivatives/finite/batch_eval.py +91 -0
- derivkit/derivatives/finite/core.py +84 -0
- derivkit/derivatives/finite/extrapolators.py +511 -0
- derivkit/derivatives/finite/finite_difference.py +247 -0
- derivkit/derivatives/finite/stencil.py +206 -0
- derivkit/derivatives/fornberg.py +245 -0
- derivkit/derivatives/local_polynomial_derivative/__init__.py +1 -0
- derivkit/derivatives/local_polynomial_derivative/diagnostics.py +90 -0
- derivkit/derivatives/local_polynomial_derivative/fit.py +199 -0
- derivkit/derivatives/local_polynomial_derivative/local_poly_config.py +95 -0
- derivkit/derivatives/local_polynomial_derivative/local_polynomial_derivative.py +205 -0
- derivkit/derivatives/local_polynomial_derivative/sampling.py +72 -0
- derivkit/derivatives/tabulated_model/__init__.py +1 -0
- derivkit/derivatives/tabulated_model/one_d.py +247 -0
- derivkit/forecast_kit.py +783 -0
- derivkit/forecasting/__init__.py +1 -0
- derivkit/forecasting/dali.py +78 -0
- derivkit/forecasting/expansions.py +486 -0
- derivkit/forecasting/fisher.py +298 -0
- derivkit/forecasting/fisher_gaussian.py +171 -0
- derivkit/forecasting/fisher_xy.py +357 -0
- derivkit/forecasting/forecast_core.py +313 -0
- derivkit/forecasting/getdist_dali_samples.py +429 -0
- derivkit/forecasting/getdist_fisher_samples.py +235 -0
- derivkit/forecasting/laplace.py +259 -0
- derivkit/forecasting/priors_core.py +860 -0
- derivkit/forecasting/sampling_utils.py +388 -0
- derivkit/likelihood_kit.py +114 -0
- derivkit/likelihoods/__init__.py +1 -0
- derivkit/likelihoods/gaussian.py +136 -0
- derivkit/likelihoods/poisson.py +176 -0
- derivkit/utils/__init__.py +13 -0
- derivkit/utils/concurrency.py +213 -0
- derivkit/utils/extrapolation.py +254 -0
- derivkit/utils/linalg.py +513 -0
- derivkit/utils/logger.py +26 -0
- derivkit/utils/numerics.py +262 -0
- derivkit/utils/sandbox.py +74 -0
- derivkit/utils/types.py +15 -0
- derivkit/utils/validate.py +811 -0
- derivkit-1.0.0.dist-info/METADATA +50 -0
- derivkit-1.0.0.dist-info/RECORD +68 -0
- derivkit-1.0.0.dist-info/WHEEL +5 -0
- derivkit-1.0.0.dist-info/licenses/LICENSE +21 -0
- derivkit-1.0.0.dist-info/top_level.txt +1 -0
+++ derivkit/derivatives/adaptive/transforms.py
@@ -0,0 +1,245 @@
"""Helpers for parameter transformations and converting derivatives between coordinate systems.

This module provides small, self-contained transforms that make adaptive
polynomial fitting robust near parameter boundaries.
"""

from __future__ import annotations

from typing import Optional, Tuple

import numpy as np

__all__ = [
    "signed_log_forward",
    "signed_log_to_physical",
    "signed_log_derivatives_to_x",
    "sqrt_domain_forward",
    "sqrt_to_physical",
    "sqrt_derivatives_to_x_at_zero",
]


def signed_log_forward(x0: float) -> Tuple[float, float]:
    """Computes the signed-log coordinates for an expansion point.

    The *signed-log* map represents a physical coordinate ``x`` as
    ``x = sgn * exp(q)``, where ``q = log(|x|)`` and ``sgn = sign(x)``.
    Here, **physical** means the model's native parameter (``x``), while
    **internal** means the reparameterized coordinate used for numerics (``q``).
    This reparameterization keeps multiplicative variation (orders of magnitude)
    well-behaved and avoids crossing through zero during local polynomial fits.

    Args:
        x0: Expansion point in physical coordinates. Must be finite and non-zero.

    Returns:
        Tuple[float, float]: ``(q0, sgn)``, where ``q0 = log(|x0|)`` and
            ``sgn = +1.0`` if ``x0 > 0`` else ``-1.0``.

    Raises:
        ValueError: If ``x0`` is not finite or equals zero.
    """
    if not np.isfinite(x0):
        raise ValueError("signed_log_forward requires a finite value of x0.")
    if x0 == 0.0:
        raise ValueError("signed_log_forward requires that x0 is non-zero.")
    sgn = 1.0 if x0 > 0.0 else -1.0
    q0 = np.log(abs(x0))
    return q0, sgn


def signed_log_to_physical(q: np.ndarray, sgn: float) -> np.ndarray:
    """Maps internal signed-log coordinate(s) to physical coordinate(s).

    Args:
        q: Internal coordinate(s) q = log(abs(x)).
        sgn: Fixed sign (+1 or -1) taken from sign(x0).

    Returns:
        Physical coordinate(s) x = sgn * exp(q).

    Raises:
        ValueError: If `sgn` is not +1 or -1, or if `q` contains non-finite values.
    """
    try:
        sgn = _normalize_sign(sgn)
    except ValueError as e:
        raise ValueError(f"signed_log_to_physical: invalid `sgn`: {e}") from None

    q = np.asarray(q, dtype=float)
    try:
        _require_finite("q", q)
    except ValueError as e:
        raise ValueError(f"signed_log_to_physical: {e}") from None

    return sgn * np.exp(q)


def signed_log_derivatives_to_x(
    order: int,
    x0: float,
    dfdq: np.ndarray,
    d2fdq2: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Converts derivatives from the signed-log coordinate ``q`` to the original parameter ``x`` at ``x0 ≠ 0``.

    This function uses the chain rule to convert derivatives computed in the
    internal signed-log coordinate q back to physical coordinates x at a
    non-zero expansion point x0.

    Args:
        order: Derivative order to return (1 or 2).
        x0: Expansion point in the original parameter ``x`` (the model's native
            coordinate); must be finite and non-zero.
        dfdq: First derivative in q (shape: (n_components,) or broadcastable).
        d2fdq2: Second derivative with respect to ``q``; required when ``order == 2``.
            A 1-D array with one value per component (shape ``(n_components,)``) or
            broadcastable to that.

    Returns:
        The derivative(s) in physical coordinates at x0.

    Raises:
        ValueError: If `x0 == 0`, if required inputs (d2fdq2) are missing for order=2,
            or if `x0` is not finite.
        NotImplementedError: If `order` not in {1, 2}.
    """
    if not np.isfinite(x0) or x0 == 0.0:
        raise ValueError("signed_log_derivatives_to_x requires finite x0 != 0.")
    dfdq = np.asarray(dfdq, dtype=float)
    if order == 1:
        return dfdq / x0
    elif order == 2:
        if d2fdq2 is None:
            raise ValueError("order=2 conversion requires d2fdq2.")
        d2fdq2 = np.asarray(d2fdq2, dtype=float)
        return (d2fdq2 - dfdq) / (x0 ** 2)
    raise NotImplementedError("signed_log_derivatives_to_x supports orders 1 and 2.")


def sqrt_domain_forward(x0: float) -> tuple[float, float]:
    """Computes the internal domain coordinate u0 for the square-root domain transformation.

    The *square-root domain* transform re-expresses a parameter ``x`` as
    ``x = s * u**2``, where ``s`` is the domain sign (+1 or -1). This mapping
    flattens steep behavior near a boundary such as ``x = 0`` and allows
    smooth polynomial fitting on either the positive or negative side.

    Args:
        x0: Expansion point in physical coordinates (finite). May be ±0.0.

    Returns:
        Tuple[float, float]: ``(u0, sgn)``, with ``u0 >= 0`` and ``sgn`` in
            ``{+1.0, -1.0}``.

    Raises:
        ValueError: If ``x0`` is not finite.
    """
    if not np.isfinite(x0):
        raise ValueError("sqrt_domain_forward requires finite x0.")
    sgn = _sgn_from_x0(x0)
    u0 = 0.0 if x0 == 0.0 else float(np.sqrt(abs(x0)))
    return u0, sgn


def sqrt_to_physical(u: np.ndarray, sign: float) -> np.ndarray:
    """Maps internal domain coordinate(s) to physical coordinate(s).

    This function maps internal coordinate(s) u to physical coordinate(s) x
    using the relation x = sign * u^2.

    Args:
        u: Internal coordinate(s).
        sign: Domain sign (+1 for x ≥ 0, -1 for x ≤ 0).

    Returns:
        Physical coordinate(s) x = sign * u^2.

    Raises:
        ValueError: If `sign` is not +1 or -1, or if `u` contains non-finite values.
    """
    u = np.asarray(u, dtype=float)
    _require_finite("u", u)
    s = _normalize_sign(sign)
    return s * (u ** 2)


def sqrt_derivatives_to_x_at_zero(
    order: int,
    x0: float,
    g2: Optional[np.ndarray] = None,
    g4: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Pulls back derivatives at x0 = 0 from u-space (sqrt-domain) to physical x.

    This function maps derivatives computed in the internal sqrt-domain coordinate u
    back to physical coordinates x at the expansion point x0 = 0 using the chain rule.

    Args:
        order: Derivative order to return (1 or 2).
        x0: Expansion point in physical coordinates (finite). May be +0.0 or -0.0 to
            select the domain side at the boundary. The domain sign is inferred
            solely from x0 (including the sign of zero).
        g2: Second derivative of g with respect to u at u=0; required for order=1.
        g4: Fourth derivative of g with respect to u at u=0; required for order=2.

    Returns:
        The derivative(s) in physical coordinates at x0=0.

    Raises:
        ValueError: If required inputs (g2/g4) are missing for the requested order.
        NotImplementedError: If `order` not in {1, 2}.
    """
    s = _sgn_from_x0(x0)
    if order == 1:
        if g2 is None:
            raise ValueError("order=1 conversion requires g2 (g'' at u=0).")
        return np.asarray(g2, dtype=float) / (2.0 * s)
    if order == 2:
        if g4 is None:
            raise ValueError("order=2 conversion requires g4 (g'''' at u=0).")
        return np.asarray(g4, dtype=float) / (12.0 * s * s)
    raise NotImplementedError("sqrt_derivatives_to_x_at_zero supports orders 1 and 2.")


def _normalize_sign(s: float) -> float:
    """Validates and normalizes a sign value to exactly +1.0 or -1.0.

    Args:
        s: Input sign value (must be approximately ±1).

    Returns:
        +1.0 or -1.0.

    Raises:
        ValueError: If s is not finite or not approximately ±1.
    """
    if not np.isfinite(s):
        raise ValueError("sign must be finite.")
    if np.isclose(abs(s), 1.0, rtol=0.0, atol=1e-12):
        return 1.0 if s > 0.0 else -1.0
    raise ValueError("sign must be +1 or -1.")


def _require_finite(name: str, arr: np.ndarray) -> None:
    """Raises a ``ValueError`` if an array contains any non-finite values.

    Args:
        name: Name of the array (for error message).
        arr: Array to check.

    Raises:
        ValueError: If arr contains any non-finite values.
    """
    if not np.all(np.isfinite(arr)):
        raise ValueError(f"{name} must be finite.")


def _sgn_from_x0(x0: float) -> float:
    """Returns the sign of x0, disambiguating zero using np.signbit."""
    if not np.isfinite(x0):
        raise ValueError("x0 must be finite.")
    if x0 > 0.0:
        return 1.0
    if x0 < 0.0:
        return -1.0
    # x0 == 0.0: disambiguate +0.0 vs -0.0
    return -1.0 if np.signbit(x0) else 1.0
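Editorial note: the transforms above are self-contained, so the two chain-rule conversions can be spot-checked against analytic derivatives. The sketch below is illustrative only (it is not part of the wheel) and takes the hunk to be derivkit/derivatives/adaptive/transforms.py, the only added file with 245 lines. For f(x) = x^3 the signed-log derivatives at x0 = 2 are dg/dq = 3*x0^3 and d2g/dq2 = 9*x0^3; for f(x) = x^2 at x0 = 0 the sqrt-domain values are g''(0) = 0 and g''''(0) = 24.

    # Illustrative check, not part of the package: verify the chain-rule
    # conversions against known analytic derivatives.
    import numpy as np

    from derivkit.derivatives.adaptive.transforms import (
        signed_log_derivatives_to_x,
        signed_log_forward,
        sqrt_derivatives_to_x_at_zero,
    )

    # Signed-log: f(x) = x**3 at x0 = 2; with x = exp(q), g(q) = exp(3q).
    x0 = 2.0
    q0, sgn = signed_log_forward(x0)
    assert sgn == 1.0 and np.isclose(q0, np.log(2.0))
    dfdq, d2fdq2 = 3.0 * x0**3, 9.0 * x0**3              # dg/dq, d2g/dq2 at q0
    assert np.isclose(signed_log_derivatives_to_x(1, x0, dfdq), 3 * x0**2)
    assert np.isclose(signed_log_derivatives_to_x(2, x0, dfdq, d2fdq2), 6 * x0)

    # Sqrt-domain at the boundary: f(x) = x**2 at x0 = +0.0, so g(u) = u**4.
    assert np.isclose(sqrt_derivatives_to_x_at_zero(1, +0.0, g2=0.0), 0.0)
    assert np.isclose(sqrt_derivatives_to_x_at_zero(2, +0.0, g4=24.0), 2.0)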
+++ derivkit/derivatives/autodiff/__init__.py
@@ -0,0 +1 @@
"""JAX autodiff backend for DerivativeKit."""
+++ derivkit/derivatives/autodiff/jax_autodiff.py
@@ -0,0 +1,95 @@
r"""JAX-based autodiff backend for DerivativeKit.

This backend is intentionally minimal: it only supports scalar derivatives of
functions $f: R \mapsto R$ via JAX autodiff, and must be registered explicitly as
explained in the example.

Example:
--------
Basic usage (opt-in registration):

>>> from derivkit.derivative_kit import DerivativeKit  # doctest: +SKIP
>>> from derivkit.derivatives.autodiff.jax_autodiff import register_jax_autodiff_backend  # doctest: +SKIP
>>> register_jax_autodiff_backend()  # doctest: +SKIP
>>>
>>> def func(x):  # doctest: +SKIP
...     import jax.numpy as jnp
...     return jnp.sin(x) + 0.5 * x**2
...
>>> dk = DerivativeKit(func, 1.0)  # doctest: +SKIP
>>> dk.differentiate(method="autodiff", order=1)  # doctest: +SKIP
>>> dk.differentiate(method="autodiff", order=2)  # doctest: +SKIP

Notes:
------
- This backend is scalar-only. For gradients/Jacobians/Hessians of functions
  with vector inputs/outputs, use the standalone helpers in
  ``derivkit.derivatives.autodiff.jax_core`` (e.g. ``autodiff_gradient``).
- To enable this backend, install the JAX extra: ``pip install "derivkit[jax]"``.
"""


from __future__ import annotations

from typing import Any, Callable

from derivkit.derivative_kit import register_method
from derivkit.derivatives.autodiff.jax_core import autodiff_derivative
from derivkit.derivatives.autodiff.jax_utils import require_jax

__all__ = [
    "AutodiffDerivative",
    "register_jax_autodiff_backend",
]


class AutodiffDerivative:
    """DerivativeKit engine for JAX-based autodiff.

    Supports scalar functions f: R -> R with JAX-differentiable bodies.
    """

    def __init__(self, function: Callable[[float], Any], x0: float):
        """Initializes the JAX autodiff derivative engine."""
        self.function = function
        self.x0 = float(x0)

    def differentiate(self, *, order: int = 1, **_: Any) -> float:
        """Computes the k-th derivative via JAX autodiff.

        Args:
            order: Derivative order (>=1).

        Returns:
            Derivative value as a float.
        """
        return autodiff_derivative(self.function, self.x0, order=order)


def register_jax_autodiff_backend(
    *,
    name: str = "autodiff",
    aliases: tuple[str, ...] = ("jax", "jax-autodiff", "jax-diff", "jd"),
) -> None:
    """Registers the experimental JAX autodiff backend with DerivativeKit.

    After calling this, you can use:

        DerivativeKit(f, x0).differentiate(method=name, order=...)

    Args:
        name: Name of the method to register.
        aliases: Alternative names for the method.

    Returns:
        None

    Raises:
        AutodiffUnavailable: If JAX is not available.
    """
    require_jax()

    register_method(
        name=name,
        cls=AutodiffDerivative,
        aliases=aliases,
    )
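Editorial note: the registration hook above is opt-in, so callers that want autodiff when available and a sampling-based fallback otherwise can guard it themselves. The sketch below is illustrative only (not part of the wheel) and rests on two assumptions beyond this file: that `DerivativeKit(f, x0).differentiate(method=..., order=...)` behaves as in the docstring example, and that `"adaptive"` is an accepted built-in method name, as the error text in jax_core.py suggests.

    # Illustrative fallback pattern, not part of the package.
    from derivkit.derivative_kit import DerivativeKit
    from derivkit.derivatives.autodiff.jax_autodiff import register_jax_autodiff_backend
    from derivkit.derivatives.autodiff.jax_utils import AutodiffUnavailable


    def f(x):
        return x**3  # plain arithmetic, so it is also JAX-traceable


    try:
        register_jax_autodiff_backend()   # raises AutodiffUnavailable without jax/jaxlib
        method = "autodiff"
    except AutodiffUnavailable:
        method = "adaptive"               # assumed built-in fallback method name

    dk = DerivativeKit(f, 2.0)
    print(dk.differentiate(method=method, order=1))   # ~= 12.0 with either backend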
+++ derivkit/derivatives/autodiff/jax_core.py
@@ -0,0 +1,217 @@
r"""JAX-based autodiff helpers for DerivKit.

This module does not register any DerivKit backend by default.

Use these functions directly, or see
:func:`derivkit.derivatives.autodiff.jax_autodiff.register_jax_autodiff_backend`
for an opt-in integration.

Use only with JAX-differentiable functions. For arbitrary models, prefer
the "adaptive" or "finite" methods.

Shape conventions (aligned with :mod:`derivkit.calculus` builders):

- ``autodiff_derivative``:
  :math:`f:\mathbb{R}\mapsto\mathbb{R}` → returns ``float`` (scalar)

- ``autodiff_gradient``:
  :math:`f:\mathbb{R}^n\mapsto\mathbb{R}` → returns array of shape ``(n,)``

- ``autodiff_jacobian``:
  :math:`f:\mathbb{R}^n\mapsto\mathbb{R}^m` (or tensor output) → returns array of
  shape ``(m, n)``, where ``m = prod(out_shape)``

- ``autodiff_hessian``:
  :math:`f:\mathbb{R}^n\mapsto\mathbb{R}` → returns array of shape ``(n, n)``
"""

from __future__ import annotations

from functools import partial
from typing import Callable

import numpy as np

from derivkit.derivatives.autodiff.jax_utils import (
    AutodiffUnavailable,
    apply_array_nd,
    apply_scalar_1d,
    apply_scalar_nd,
    jax,
    jnp,
    require_jax,
)

__all__ = [
    "autodiff_derivative",
    "autodiff_gradient",
    "autodiff_jacobian",
    "autodiff_hessian",
]


def autodiff_derivative(func: Callable, x0: float, order: int = 1) -> float:
    """Calculates the k-th derivative of a function f: R -> R via JAX autodiff.

    Args:
        func: Callable mapping float -> scalar.
        x0: Point at which to evaluate the derivative.
        order: Derivative order (>=1); uses repeated grad for higher orders.

    Returns:
        Derivative value as a float.

    Raises:
        AutodiffUnavailable: If JAX is not available or function is not differentiable.
        ValueError: If order < 1.
        TypeError: If func(x) is not scalar-valued.
    """
    require_jax()

    if order < 1:
        raise ValueError("autodiff_derivative: order must be >= 1.")

    f_jax = partial(apply_scalar_1d, func, "autodiff_derivative")

    g = f_jax
    for _ in range(order):
        g = jax.grad(g)

    try:
        val = g(x0)
    except (TypeError, ValueError) as exc:
        raise AutodiffUnavailable(
            "autodiff_derivative: function is not JAX-differentiable at x0. "
            "Use JAX primitives / jax.numpy or fall back to 'adaptive'/'finite'."
        ) from exc

    return float(val)


def autodiff_gradient(func: Callable, x0) -> np.ndarray:
    """Computes the gradient of a scalar-valued function f: R^n -> R via JAX autodiff.

    Args:
        func: Function to be differentiated.
        x0: Point at which to evaluate the gradient.

    Returns:
        A gradient vector as a 1D numpy.ndarray with shape (n,).

    Raises:
        AutodiffUnavailable: If JAX is not available or function is not differentiable.
        TypeError: If func(theta) is not scalar-valued.
    """
    require_jax()

    x0_arr = np.asarray(x0, float).ravel()

    f_jax = partial(apply_scalar_nd, func, "autodiff_gradient")
    grad_f = jax.grad(f_jax)

    try:
        g = grad_f(jnp.asarray(x0_arr))
    except (TypeError, ValueError) as exc:
        raise AutodiffUnavailable(
            "autodiff_gradient: function is not JAX-differentiable."
        ) from exc

    return np.asarray(g, dtype=float).reshape(-1)


def autodiff_jacobian(
    func: Callable,
    x0,
    *,
    mode: str | None = None,
) -> np.ndarray:
    """Calculates the Jacobian of a vector-valued function via JAX autodiff.

    Output convention matches DerivKit Jacobian builders: we flatten the function
    output to length m = prod(out_shape), and return a 2D Jacobian of shape (m, n),
    with n = input dimension.

    Args:
        func: Function to be differentiated.
        x0: Point at which to evaluate the Jacobian; array-like, shape (n,).
        mode: Differentiation mode; None (auto), 'fwd', or 'rev'.
            If None, chooses 'rev' if m <= n, else 'fwd'. For more details about
            modes, see JAX documentation for `jax.jacrev` and `jax.jacfwd`.

    Returns:
        A Jacobian matrix as a 2D numpy.ndarray with shape (m, n).

    Raises:
        AutodiffUnavailable: If JAX is not available or function is not differentiable.
        ValueError: If mode is invalid.
        TypeError: If func(theta) is scalar-valued.
    """
    require_jax()

    x0_arr = np.asarray(x0, float).ravel()
    x0_jax = jnp.asarray(x0_arr)

    f_jax = partial(apply_array_nd, func, "autodiff_jacobian")

    try:
        y0 = f_jax(x0_jax)
    except (TypeError, ValueError) as exc:
        raise AutodiffUnavailable(
            "autodiff_jacobian: function is not JAX-differentiable at x0."
        ) from exc

    in_dim = x0_arr.size
    out_dim = int(np.prod(y0.shape))

    if mode is None:
        use_rev = out_dim <= in_dim
    elif mode == "rev":
        use_rev = True
    elif mode == "fwd":
        use_rev = False
    else:
        raise ValueError("autodiff_jacobian: mode must be None, 'fwd', or 'rev'.")

    jac_fun = jax.jacrev if use_rev else jax.jacfwd

    try:
        jac = jac_fun(f_jax)(x0_jax)
    except (TypeError, ValueError) as exc:
        raise AutodiffUnavailable(
            "autodiff_jacobian: failed to trace function with JAX."
        ) from exc

    jac_np = np.asarray(jac, dtype=float)
    return jac_np.reshape(out_dim, in_dim)


def autodiff_hessian(func: Callable, x0) -> np.ndarray:
    """Calculates the full Hessian of a scalar-valued function.

    Args:
        func: A function to be differentiated.
        x0: Point at which to evaluate the Hessian; array-like, shape (n,) with
            n = input dimension.

    Returns:
        A Hessian matrix as a 2D numpy.ndarray with shape (n, n).

    Raises:
        AutodiffUnavailable: If JAX is not available or function is not differentiable.
        TypeError: If func(theta) is not scalar-valued.
    """
    require_jax()

    x0_arr = np.asarray(x0, float).ravel()
    x0_jax = jnp.asarray(x0_arr)

    f_jax = partial(apply_scalar_nd, func, "autodiff_hessian")

    try:
        hess = jax.hessian(f_jax)(x0_jax)
    except (TypeError, ValueError) as exc:
        raise AutodiffUnavailable(
            "autodiff_hessian: function is not JAX-differentiable."
        ) from exc

    return np.asarray(hess, dtype=float)
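Editorial note: a small sketch (not part of the wheel) exercising the shape conventions documented in the module docstring above. It uses only the helpers defined in this file and assumes the `jax` extra is installed; the example models and parameter values are illustrative.

    # Illustrative shape check, not part of the package.
    import jax.numpy as jnp
    import numpy as np

    from derivkit.derivatives.autodiff.jax_core import (
        autodiff_gradient,
        autodiff_hessian,
        autodiff_jacobian,
    )

    theta0 = np.array([1.0, 2.0, 3.0])                      # n = 3


    def scalar_model(theta):
        return jnp.sum(theta**2)                            # R^3 -> R


    def vector_model(theta):
        return jnp.stack([theta[0] * theta[1], jnp.sin(theta[2])])   # R^3 -> R^2


    grad = autodiff_gradient(scalar_model, theta0)   # shape (3,), equals 2*theta0
    hess = autodiff_hessian(scalar_model, theta0)    # shape (3, 3), equals 2*I
    jac = autodiff_jacobian(vector_model, theta0)    # shape (2, 3); m=2 <= n=3 -> jacrev

    assert grad.shape == (3,) and np.allclose(grad, 2 * theta0)
    assert hess.shape == (3, 3) and np.allclose(hess, 2 * np.eye(3))
    assert jac.shape == (2, 3)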
+++ derivkit/derivatives/autodiff/jax_utils.py
@@ -0,0 +1,146 @@
"""Utilities for JAX-based autodiff in DerivKit."""

from __future__ import annotations

from typing import Any, Callable

try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None
    jnp = None
    _HAS_JAX = False
else:
    _HAS_JAX = True

has_jax: bool = _HAS_JAX

__all__ = [
    "AutodiffUnavailable",
    "require_jax",
    "to_jax_scalar",
    "to_jax_array",
    "apply_scalar_1d",
    "apply_scalar_nd",
    "apply_array_nd",
]


class AutodiffUnavailable(RuntimeError):
    """Raised when JAX-based autodiff is unavailable."""


def require_jax() -> None:
    """Raises if JAX is not available.

    Args:
        None.

    Returns:
        None.

    Raises:
        AutodiffUnavailable: If JAX is not installed.
    """
    if not _HAS_JAX:
        raise AutodiffUnavailable(
            "JAX autodiff requires `jax` + `jaxlib`.\n"
            'Install with `pip install "derivkit[jax]"` '
            "(or follow JAX's official install instructions for GPU)."
        )


def to_jax_scalar(y: Any, *, where: str) -> jnp.ndarray:
    """Ensures that output is scalar and returns it as a JAX array.

    Args:
        y: Output to check.
        where: Context string for error messages.

    Returns:
        Scalar (0-d) JAX array with shape ().

    Raises:
        TypeError: If output is not scalar.
    """
    arr = jnp.asarray(y)
    if arr.ndim != 0:
        raise TypeError(f"{where}: expected scalar output; got shape {tuple(arr.shape)}.")
    return arr


def to_jax_array(y: Any, *, where: str) -> jnp.ndarray:
    """Ensures that output is array-like (not scalar) and returns it as a JAX array.

    Args:
        y: Output to check.
        where: Context string for error messages.

    Returns:
        Non-scalar JAX array with shape (m,) or higher.

    Raises:
        TypeError: If output is scalar or cannot be converted to a JAX array.
    """
    try:
        arr = jnp.asarray(y)
    except TypeError as exc:
        raise TypeError(f"{where}: output could not be converted to a JAX array.") from exc
    if arr.ndim == 0:
        raise TypeError(f"{where}: output is scalar; use autodiff_derivative/gradient instead.")
    return arr


def apply_scalar_1d(
    func: Callable[[float], Any],
    where: str,
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Applies a scalar-input function and enforces a scalar output.

    Args:
        func: Function to apply.
        where: Context string for error messages.
        x: Scalar (0-d or float-like) JAX input.

    Returns:
        Scalar (0-d) JAX array output.
    """
    return to_jax_scalar(func(x), where=where)


def apply_scalar_nd(
    func: Callable,
    where: str,
    theta: jnp.ndarray,
) -> jnp.ndarray:
    """Applies a vector-input function and enforces a scalar output.

    Args:
        func: Function to apply.
        where: Context string for error messages.
        theta: 1D (or ND) JAX array of inputs.

    Returns:
        Scalar (0-d) JAX array output.
    """
    return to_jax_scalar(func(theta), where=where)


def apply_array_nd(
    func: Callable,
    where: str,
    theta: jnp.ndarray,
) -> jnp.ndarray:
    """Applies a vector-input function and enforces an array (non-scalar) output.

    Args:
        func: Function to apply.
        where: Context string for error messages.
        theta: 1D (or ND) JAX array of inputs.

    Returns:
        Non-scalar JAX array output.
    """
    return to_jax_array(func(theta), where=where)
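Editorial note: the output-shape guards in this file are what turn a silently wrong call into an early TypeError. Below is a minimal sketch of their behaviour (not part of the wheel), assuming the `jax` extra is installed; the `where="demo"` label is just an illustrative context string.

    # Illustrative behaviour of the output-shape guards, not part of the package.
    import jax.numpy as jnp

    from derivkit.derivatives.autodiff.jax_utils import has_jax, to_jax_array, to_jax_scalar

    assert has_jax  # this sketch assumes jax/jaxlib are importable

    to_jax_scalar(jnp.asarray(1.5), where="demo")     # accepted: 0-d array
    to_jax_array(jnp.ones(4), where="demo")           # accepted: shape (4,)

    try:
        to_jax_scalar(jnp.ones(4), where="demo")      # vector where a scalar is required
    except TypeError as exc:
        print(exc)  # demo: expected scalar output; got shape (4,).

    try:
        to_jax_array(jnp.asarray(1.5), where="demo")  # scalar where an array is required
    except TypeError as exc:
        print(exc)  # demo: output is scalar; use autodiff_derivative/gradient instead.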