derivkit 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. derivkit/__init__.py +22 -0
  2. derivkit/calculus/__init__.py +17 -0
  3. derivkit/calculus/calculus_core.py +152 -0
  4. derivkit/calculus/gradient.py +97 -0
  5. derivkit/calculus/hessian.py +528 -0
  6. derivkit/calculus/hyper_hessian.py +296 -0
  7. derivkit/calculus/jacobian.py +156 -0
  8. derivkit/calculus_kit.py +128 -0
  9. derivkit/derivative_kit.py +315 -0
  10. derivkit/derivatives/__init__.py +6 -0
  11. derivkit/derivatives/adaptive/__init__.py +5 -0
  12. derivkit/derivatives/adaptive/adaptive_fit.py +238 -0
  13. derivkit/derivatives/adaptive/batch_eval.py +179 -0
  14. derivkit/derivatives/adaptive/diagnostics.py +325 -0
  15. derivkit/derivatives/adaptive/grid.py +333 -0
  16. derivkit/derivatives/adaptive/polyfit_utils.py +513 -0
  17. derivkit/derivatives/adaptive/spacing.py +66 -0
  18. derivkit/derivatives/adaptive/transforms.py +245 -0
  19. derivkit/derivatives/autodiff/__init__.py +1 -0
  20. derivkit/derivatives/autodiff/jax_autodiff.py +95 -0
  21. derivkit/derivatives/autodiff/jax_core.py +217 -0
  22. derivkit/derivatives/autodiff/jax_utils.py +146 -0
  23. derivkit/derivatives/finite/__init__.py +5 -0
  24. derivkit/derivatives/finite/batch_eval.py +91 -0
  25. derivkit/derivatives/finite/core.py +84 -0
  26. derivkit/derivatives/finite/extrapolators.py +511 -0
  27. derivkit/derivatives/finite/finite_difference.py +247 -0
  28. derivkit/derivatives/finite/stencil.py +206 -0
  29. derivkit/derivatives/fornberg.py +245 -0
  30. derivkit/derivatives/local_polynomial_derivative/__init__.py +1 -0
  31. derivkit/derivatives/local_polynomial_derivative/diagnostics.py +90 -0
  32. derivkit/derivatives/local_polynomial_derivative/fit.py +199 -0
  33. derivkit/derivatives/local_polynomial_derivative/local_poly_config.py +95 -0
  34. derivkit/derivatives/local_polynomial_derivative/local_polynomial_derivative.py +205 -0
  35. derivkit/derivatives/local_polynomial_derivative/sampling.py +72 -0
  36. derivkit/derivatives/tabulated_model/__init__.py +1 -0
  37. derivkit/derivatives/tabulated_model/one_d.py +247 -0
  38. derivkit/forecast_kit.py +783 -0
  39. derivkit/forecasting/__init__.py +1 -0
  40. derivkit/forecasting/dali.py +78 -0
  41. derivkit/forecasting/expansions.py +486 -0
  42. derivkit/forecasting/fisher.py +298 -0
  43. derivkit/forecasting/fisher_gaussian.py +171 -0
  44. derivkit/forecasting/fisher_xy.py +357 -0
  45. derivkit/forecasting/forecast_core.py +313 -0
  46. derivkit/forecasting/getdist_dali_samples.py +429 -0
  47. derivkit/forecasting/getdist_fisher_samples.py +235 -0
  48. derivkit/forecasting/laplace.py +259 -0
  49. derivkit/forecasting/priors_core.py +860 -0
  50. derivkit/forecasting/sampling_utils.py +388 -0
  51. derivkit/likelihood_kit.py +114 -0
  52. derivkit/likelihoods/__init__.py +1 -0
  53. derivkit/likelihoods/gaussian.py +136 -0
  54. derivkit/likelihoods/poisson.py +176 -0
  55. derivkit/utils/__init__.py +13 -0
  56. derivkit/utils/concurrency.py +213 -0
  57. derivkit/utils/extrapolation.py +254 -0
  58. derivkit/utils/linalg.py +513 -0
  59. derivkit/utils/logger.py +26 -0
  60. derivkit/utils/numerics.py +262 -0
  61. derivkit/utils/sandbox.py +74 -0
  62. derivkit/utils/types.py +15 -0
  63. derivkit/utils/validate.py +811 -0
  64. derivkit-1.0.0.dist-info/METADATA +50 -0
  65. derivkit-1.0.0.dist-info/RECORD +68 -0
  66. derivkit-1.0.0.dist-info/WHEEL +5 -0
  67. derivkit-1.0.0.dist-info/licenses/LICENSE +21 -0
  68. derivkit-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,245 @@
+"""Helpers for parameter transformations and converting derivatives between coordinate systems.
+
+This module provides small, self-contained transforms that make adaptive
+polynomial fitting robust near parameter boundaries.
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import numpy as np
+
+__all__ = [
+    "signed_log_forward",
+    "signed_log_to_physical",
+    "signed_log_derivatives_to_x",
+    "sqrt_domain_forward",
+    "sqrt_to_physical",
+    "sqrt_derivatives_to_x_at_zero",
+]
+
+
+def signed_log_forward(x0: float) -> Tuple[float, float]:
+    """Computes the signed-log coordinates for an expansion point.
+
+    The *signed-log* map represents a physical coordinate ``x`` as
+    ``x = sgn * exp(q)``, where ``q = log(|x|)`` and ``sgn = sign(x)``.
+    Here, **physical** means the model’s native parameter (``x``), while
+    **internal** means the reparameterized coordinate used for numerics (``q``).
+    This reparameterization keeps multiplicative variation (orders of magnitude)
+    well-behaved and avoids crossing through zero during local polynomial fits.
+
+    Args:
+        x0: Expansion point in physical coordinates. Must be finite and non-zero.
+
+    Returns:
+        Tuple[float, float]: ``(q0, sgn)``, where ``q0 = log(|x0|)`` and
+        ``sgn = +1.0`` if ``x0 > 0`` else ``-1.0``.
+
+    Raises:
+        ValueError: If ``x0`` is not finite or equals zero.
+    """
+    if not np.isfinite(x0):
+        raise ValueError("signed_log_forward requires a finite value of x0.")
+    if x0 == 0.0:
+        raise ValueError("signed_log_forward requires that x0 is non-zero.")
+    sgn = 1.0 if x0 > 0.0 else -1.0
+    q0 = np.log(abs(x0))
+    return q0, sgn
+
+
+def signed_log_to_physical(q: np.ndarray, sgn: float) -> np.ndarray:
+    """Maps internal signed-log coordinate(s) to physical coordinate(s).
+
+    Args:
+        q: Internal coordinate(s) ``q = log(abs(x))``.
+        sgn: Fixed sign (+1 or -1) taken from ``sign(x0)``.
+
+    Returns:
+        Physical coordinate(s) ``x = sgn * exp(q)``.
+
+    Raises:
+        ValueError: If ``sgn`` is not +1 or -1, or if ``q`` contains non-finite values.
+    """
+    try:
+        sgn = _normalize_sign(sgn)
+    except ValueError as e:
+        raise ValueError(f"signed_log_to_physical: invalid `sgn`: {e}") from None
+
+    q = np.asarray(q, dtype=float)
+    try:
+        _require_finite("q", q)
+    except ValueError as e:
+        raise ValueError(f"signed_log_to_physical: {e}") from None
+
+    return sgn * np.exp(q)
+
+
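A quick round trip makes the convention concrete. A minimal sketch (the import path follows the file listing above; requires only numpy):

```python
import numpy as np

from derivkit.derivatives.adaptive.transforms import (
    signed_log_forward,
    signed_log_to_physical,
)

# Forward map at a negative expansion point: q0 = log(|x0|), sgn = sign(x0).
q0, sgn = signed_log_forward(-2.5)  # q0 ≈ 0.9163, sgn = -1.0

# Mapping internal offsets back recovers physical points on the same side of zero.
x = signed_log_to_physical(np.array([q0 - 0.1, q0, q0 + 0.1]), sgn)
assert np.isclose(x[1], -2.5)
```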
+def signed_log_derivatives_to_x(
+    order: int,
+    x0: float,
+    dfdq: np.ndarray,
+    d2fdq2: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """Converts derivatives from the signed-log coordinate ``q`` to the physical parameter ``x``.
+
+    Uses the chain rule to convert derivatives computed in the internal
+    signed-log coordinate ``q`` back to physical coordinates ``x`` at a
+    non-zero expansion point ``x0``.
+
+    Args:
+        order: Derivative order to return (1 or 2).
+        x0: Expansion point in the physical parameter ``x`` (the model’s native
+            coordinate); must be finite and non-zero.
+        dfdq: First derivative with respect to ``q`` (shape ``(n_components,)``
+            or broadcastable).
+        d2fdq2: Second derivative with respect to ``q``; required when ``order == 2``.
+            A 1-D array with one value per component (shape ``(n_components,)``)
+            or broadcastable to that.
+
+    Returns:
+        The derivative(s) in physical coordinates at ``x0``.
+
+    Raises:
+        ValueError: If ``x0`` is not finite, if ``x0 == 0``, or if ``d2fdq2``
+            is missing for ``order == 2``.
+        NotImplementedError: If ``order`` is not in {1, 2}.
+    """
+    if not np.isfinite(x0) or x0 == 0.0:
+        raise ValueError("signed_log_derivatives_to_x requires finite x0 != 0.")
+    dfdq = np.asarray(dfdq, dtype=float)
+    if order == 1:
+        return dfdq / x0
+    elif order == 2:
+        if d2fdq2 is None:
+            raise ValueError("order=2 conversion requires d2fdq2.")
+        d2fdq2 = np.asarray(d2fdq2, dtype=float)
+        return (d2fdq2 - dfdq) / (x0 ** 2)
+    raise NotImplementedError("signed_log_derivatives_to_x supports orders 1 and 2.")
+
+
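The conversion implements the chain rule for ``x = sgn * exp(q)``: ``df/dx = (df/dq) / x`` and ``d2f/dx2 = (d2f/dq2 - df/dq) / x**2``. A hand check with the illustrative choice ``f(x) = x**3`` at ``x0 = 2``:

```python
import numpy as np

from derivkit.derivatives.adaptive.transforms import signed_log_derivatives_to_x

x0 = 2.0
# For f(x) = x^3 with x = exp(q): df/dq = 3*x^3 and d2f/dq2 = 9*x^3.
dfdq = np.array([3.0 * x0**3])
d2fdq2 = np.array([9.0 * x0**3])

d1 = signed_log_derivatives_to_x(1, x0, dfdq)           # expect f'(x0)  = 3*x0^2 = 12
d2 = signed_log_derivatives_to_x(2, x0, dfdq, d2fdq2)   # expect f''(x0) = 6*x0   = 12
assert np.isclose(d1[0], 12.0) and np.isclose(d2[0], 12.0)
```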
+def sqrt_domain_forward(x0: float) -> Tuple[float, float]:
+    """Computes the internal coordinate ``u0`` for the square-root domain transformation.
+
+    The *square-root domain* transform re-expresses a parameter ``x`` as
+    ``x = sgn * u**2``, where ``sgn`` is the domain sign (+1 or -1). This
+    mapping flattens steep behavior near a boundary such as ``x = 0`` and
+    allows smooth polynomial fitting on either the positive or negative side.
+
+    Args:
+        x0: Expansion point in physical coordinates (finite). May be +0.0 or
+            -0.0; the sign of zero selects the domain side.
+
+    Returns:
+        Tuple[float, float]: ``(u0, sgn)``, with ``u0 >= 0`` and ``sgn`` in
+        {+1.0, -1.0}.
+
+    Raises:
+        ValueError: If ``x0`` is not finite.
+    """
+    if not np.isfinite(x0):
+        raise ValueError("sqrt_domain_forward requires finite x0.")
+    sgn = _sgn_from_x0(x0)
+    u0 = 0.0 if x0 == 0.0 else float(np.sqrt(abs(x0)))
+    return u0, sgn
+
+
+def sqrt_to_physical(u: np.ndarray, sign: float) -> np.ndarray:
+    """Maps internal domain coordinate(s) to physical coordinate(s).
+
+    Maps internal coordinate(s) ``u`` to physical coordinate(s) ``x`` using
+    the relation ``x = sign * u**2``.
+
+    Args:
+        u: Internal coordinate(s).
+        sign: Domain sign (+1 for x >= 0, -1 for x <= 0).
+
+    Returns:
+        Physical coordinate(s) ``x = sign * u**2``.
+
+    Raises:
+        ValueError: If ``sign`` is not +1 or -1, or if ``u`` contains
+            non-finite values.
+    """
+    u = np.asarray(u, dtype=float)
+    _require_finite("u", u)
+    s = _normalize_sign(sign)
+    return s * (u ** 2)
+
+
+def sqrt_derivatives_to_x_at_zero(
+    order: int,
+    x0: float,
+    g2: Optional[np.ndarray] = None,
+    g4: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """Pulls back derivatives at ``x0 = 0`` from u-space (sqrt-domain) to physical ``x``.
+
+    Maps derivatives of ``g(u) = f(sgn * u**2)`` computed in the internal
+    sqrt-domain coordinate ``u`` back to physical coordinates ``x`` at the
+    expansion point ``x0 = 0`` using the chain rule.
+
+    Args:
+        order: Derivative order to return (1 or 2).
+        x0: Expansion point in physical coordinates (finite). May be +0.0 or
+            -0.0 to select the domain side at the boundary. The domain sign is
+            inferred solely from ``x0`` (including the sign of zero).
+        g2: Second derivative of ``g`` with respect to ``u`` at ``u = 0``;
+            required for ``order == 1``.
+        g4: Fourth derivative of ``g`` with respect to ``u`` at ``u = 0``;
+            required for ``order == 2``.
+
+    Returns:
+        The derivative(s) in physical coordinates at ``x0 = 0``.
+
+    Raises:
+        ValueError: If required inputs (``g2``/``g4``) are missing for the
+            requested order.
+        NotImplementedError: If ``order`` is not in {1, 2}.
+    """
+    s = _sgn_from_x0(x0)
+    if order == 1:
+        if g2 is None:
+            raise ValueError("order=1 conversion requires g2 (g'' at u=0).")
+        return np.asarray(g2, dtype=float) / (2.0 * s)
+    if order == 2:
+        if g4 is None:
+            raise ValueError("order=2 conversion requires g4 (g'''' at u=0).")
+        return np.asarray(g4, dtype=float) / (12.0 * s * s)
+    raise NotImplementedError("sqrt_derivatives_to_x_at_zero supports orders 1 and 2.")
+
+
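Because ``g(u) = f(sgn * u**2)``, the chain rule gives ``g''(0) = 2 * sgn * f'(0)`` and ``g''''(0) = 12 * sgn**2 * f''(0)``; the divisions above undo exactly these factors. A small check, including the signed-zero convention (illustrative sketch):

```python
import numpy as np

from derivkit.derivatives.adaptive.transforms import (
    sqrt_derivatives_to_x_at_zero,
    sqrt_domain_forward,
)

# The sign of zero selects the domain side at the boundary.
assert sqrt_domain_forward(+0.0) == (0.0, 1.0)
assert sqrt_domain_forward(-0.0) == (0.0, -1.0)

# For f(x) = a*x + b*x**2 on x >= 0, g(u) = a*u**2 + b*u**4,
# so g''(0) = 2a and g''''(0) = 24b.
a, b = 1.5, 0.25
d1 = sqrt_derivatives_to_x_at_zero(1, +0.0, g2=np.array([2.0 * a]))
d2 = sqrt_derivatives_to_x_at_zero(2, +0.0, g4=np.array([24.0 * b]))
assert np.isclose(d1[0], a)        # f'(0)  = a
assert np.isclose(d2[0], 2.0 * b)  # f''(0) = 2b
```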
+def _normalize_sign(s: float) -> float:
+    """Validates and normalizes a sign value to exactly +1.0 or -1.0.
+
+    Args:
+        s: Input sign value (must be approximately ±1).
+
+    Returns:
+        +1.0 or -1.0.
+
+    Raises:
+        ValueError: If ``s`` is not finite or not approximately ±1.
+    """
+    if not np.isfinite(s):
+        raise ValueError("sign must be finite.")
+    if np.isclose(abs(s), 1.0, rtol=0.0, atol=1e-12):
+        return 1.0 if s > 0.0 else -1.0
+    raise ValueError("sign must be +1 or -1.")
+
+
+def _require_finite(name: str, arr: np.ndarray) -> None:
+    """Raises a ``ValueError`` if an array contains any non-finite values.
+
+    Args:
+        name: Name of the array (for error messages).
+        arr: Array to check.
+
+    Raises:
+        ValueError: If ``arr`` contains any non-finite values.
+    """
+    if not np.all(np.isfinite(arr)):
+        raise ValueError(f"{name} must be finite.")
+
+
+def _sgn_from_x0(x0: float) -> float:
+    """Returns the sign of ``x0``, disambiguating zero using ``np.signbit``."""
+    if not np.isfinite(x0):
+        raise ValueError("x0 must be finite.")
+    if x0 > 0.0:
+        return 1.0
+    if x0 < 0.0:
+        return -1.0
+    # x0 == 0.0: disambiguate +0.0 vs -0.0
+    return -1.0 if np.signbit(x0) else 1.0
@@ -0,0 +1 @@
+"""JAX autodiff backend for DerivativeKit."""
@@ -0,0 +1,95 @@
+r"""JAX-based autodiff backend for DerivativeKit.
+
+This backend is intentionally minimal: it only supports scalar derivatives
+``f: R -> R`` via JAX autodiff, and must be registered explicitly as shown
+in the example below.
+
+Example:
+    Basic usage (opt-in registration):
+
+    >>> from derivkit.derivative_kit import DerivativeKit  # doctest: +SKIP
+    >>> from derivkit.derivatives.autodiff.jax_autodiff import register_jax_autodiff_backend  # doctest: +SKIP
+    >>> register_jax_autodiff_backend()  # doctest: +SKIP
+    >>>
+    >>> def func(x):  # doctest: +SKIP
+    ...     import jax.numpy as jnp
+    ...     return jnp.sin(x) + 0.5 * x**2
+    ...
+    >>> dk = DerivativeKit(func, 1.0)  # doctest: +SKIP
+    >>> dk.differentiate(method="autodiff", order=1)  # doctest: +SKIP
+    >>> dk.differentiate(method="autodiff", order=2)  # doctest: +SKIP
+
+Notes:
+    - This backend is scalar-only. For gradients/Jacobians/Hessians of functions
+      with vector inputs/outputs, use the standalone helpers in
+      ``derivkit.derivatives.autodiff.jax_core`` (e.g. ``autodiff_gradient``).
+    - To enable this backend, install the JAX extra: ``pip install "derivkit[jax]"``.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
+from derivkit.derivative_kit import register_method
+from derivkit.derivatives.autodiff.jax_core import autodiff_derivative
+from derivkit.derivatives.autodiff.jax_utils import require_jax
+
+__all__ = [
+    "AutodiffDerivative",
+    "register_jax_autodiff_backend",
+]
+
+
+class AutodiffDerivative:
+    """DerivativeKit engine for JAX-based autodiff.
+
+    Supports scalar functions ``f: R -> R`` with JAX-differentiable bodies.
+    """
+
+    def __init__(self, function: Callable[[float], Any], x0: float):
+        """Initializes the JAX autodiff derivative engine."""
+        self.function = function
+        self.x0 = float(x0)
+
+    def differentiate(self, *, order: int = 1, **_: Any) -> float:
+        """Computes the ``order``-th derivative via JAX autodiff.
+
+        Args:
+            order: Derivative order (>= 1).
+
+        Returns:
+            Derivative value as a float.
+        """
+        return autodiff_derivative(self.function, self.x0, order=order)
+
+
+def register_jax_autodiff_backend(
+    *,
+    name: str = "autodiff",
+    aliases: tuple[str, ...] = ("jax", "jax-autodiff", "jax-diff", "jd"),
+) -> None:
+    """Registers the experimental JAX autodiff backend with DerivativeKit.
+
+    After calling this, you can use::
+
+        DerivativeKit(f, x0).differentiate(method=name, order=...)
+
+    Args:
+        name: Name of the method to register.
+        aliases: Alternative names for the method.
+
+    Returns:
+        None
+
+    Raises:
+        AutodiffUnavailable: If JAX is not available.
+    """
+    require_jax()
+
+    register_method(
+        name=name,
+        cls=AutodiffDerivative,
+        aliases=aliases,
+    )
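A registration sketch mirroring the module docstring (assumes the ``jax`` extra is installed; the model function is illustrative):

```python
from derivkit.derivative_kit import DerivativeKit
from derivkit.derivatives.autodiff.jax_autodiff import register_jax_autodiff_backend

# Register once per process; later DerivativeKit calls select the engine by name.
register_jax_autodiff_backend(name="autodiff", aliases=("jax",))

def f(x):
    import jax.numpy as jnp
    return jnp.exp(-x) * jnp.sin(x)

dk = DerivativeKit(f, 0.5)
d1 = dk.differentiate(method="jax", order=1)  # aliases resolve to the same engine
```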
@@ -0,0 +1,217 @@
+r"""JAX-based autodiff helpers for DerivKit.
+
+This module does not register any DerivKit backend by default.
+
+Use these functions directly, or see
+:func:`derivkit.derivatives.autodiff.jax_autodiff.register_jax_autodiff_backend`
+for an opt-in integration.
+
+Use only with JAX-differentiable functions. For arbitrary models, prefer
+the "adaptive" or "finite" methods.
+
+Shape conventions (aligned with :mod:`derivkit.calculus` builders):
+
+- ``autodiff_derivative``:
+  :math:`f:\mathbb{R}\mapsto\mathbb{R}` → returns ``float`` (scalar)
+
+- ``autodiff_gradient``:
+  :math:`f:\mathbb{R}^n\mapsto\mathbb{R}` → returns array of shape ``(n,)``
+
+- ``autodiff_jacobian``:
+  :math:`f:\mathbb{R}^n\mapsto\mathbb{R}^m` (or tensor output) → returns array of
+  shape ``(m, n)``, where :math:`m = \prod(\text{out\_shape})`
+
+- ``autodiff_hessian``:
+  :math:`f:\mathbb{R}^n\mapsto\mathbb{R}` → returns array of shape ``(n, n)``
+"""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Callable
+
+import numpy as np
+
+from derivkit.derivatives.autodiff.jax_utils import (
+    AutodiffUnavailable,
+    apply_array_nd,
+    apply_scalar_1d,
+    apply_scalar_nd,
+    jax,
+    jnp,
+    require_jax,
+)
+
+__all__ = [
+    "autodiff_derivative",
+    "autodiff_gradient",
+    "autodiff_jacobian",
+    "autodiff_hessian",
+]
+
+
+
53
+ def autodiff_derivative(func: Callable, x0: float, order: int = 1) -> float:
54
+ """Calculates the k-th derivative of a function f: R -> R via JAX autodiff.
55
+
56
+ Args:
57
+ func: Callable mapping float -> scalar.
58
+ x0: Point at which to evaluate the derivative.
59
+ order: Derivative order (>=1); uses repeated grad for higher orders.
60
+
61
+ Returns:
62
+ Derivative value as a float.
63
+
64
+ Raises:
65
+ AutodiffUnavailable: If JAX is not available or function is not differentiable.
66
+ ValueError: If order < 1.
67
+ TypeError: If func(x) is not scalar-valued.
68
+ """
69
+ require_jax()
70
+
71
+ if order < 1:
72
+ raise ValueError("autodiff_derivative: order must be >= 1.")
73
+
74
+ f_jax = partial(apply_scalar_1d, func, "autodiff_derivative")
75
+
76
+ g = f_jax
77
+ for _ in range(order):
78
+ g = jax.grad(g)
79
+
80
+ try:
81
+ val = g(x0)
82
+ except (TypeError, ValueError) as exc:
83
+ raise AutodiffUnavailable(
84
+ "autodiff_derivative: function is not JAX-differentiable at x0. "
85
+ "Use JAX primitives / jax.numpy or fall back to 'adaptive'/'finite'."
86
+ ) from exc
87
+
88
+ return float(val)
89
+
90
+
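A sanity check of the repeated-``jax.grad`` composition (sketch; requires the ``jax`` extra):

```python
import math

from derivkit.derivatives.autodiff.jax_core import autodiff_derivative

def f(x):
    import jax.numpy as jnp
    return jnp.sin(x)

# Repeated jax.grad: d/dx sin = cos, d2/dx2 sin = -sin.
assert math.isclose(autodiff_derivative(f, 0.3, order=1), math.cos(0.3), rel_tol=1e-5)
assert math.isclose(autodiff_derivative(f, 0.3, order=2), -math.sin(0.3), rel_tol=1e-5)
```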
+def autodiff_gradient(func: Callable, x0) -> np.ndarray:
+    """Computes the gradient of a scalar-valued function ``f: R^n -> R`` via JAX autodiff.
+
+    Args:
+        func: Function to be differentiated.
+        x0: Point at which to evaluate the gradient.
+
+    Returns:
+        A gradient vector as a 1D ``numpy.ndarray`` with shape ``(n,)``.
+
+    Raises:
+        AutodiffUnavailable: If JAX is not available or the function is not
+            differentiable.
+        TypeError: If ``func(theta)`` is not scalar-valued.
+    """
+    require_jax()
+
+    x0_arr = np.asarray(x0, float).ravel()
+
+    f_jax = partial(apply_scalar_nd, func, "autodiff_gradient")
+    grad_f = jax.grad(f_jax)
+
+    try:
+        g = grad_f(jnp.asarray(x0_arr))
+    except (TypeError, ValueError) as exc:
+        raise AutodiffUnavailable(
+            "autodiff_gradient: function is not JAX-differentiable."
+        ) from exc
+
+    return np.asarray(g, dtype=float).reshape(-1)
+
+
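For instance (illustrative sketch; ``rosenbrock`` is a hypothetical test function, not part of the package):

```python
import numpy as np

from derivkit.derivatives.autodiff.jax_core import autodiff_gradient

def rosenbrock(theta):
    import jax.numpy as jnp
    x, y = theta
    return (1.0 - x) ** 2 + 100.0 * (y - x**2) ** 2

# The gradient vanishes at the global minimum (1, 1).
g = autodiff_gradient(rosenbrock, np.array([1.0, 1.0]))
assert np.allclose(g, 0.0)
```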
+def autodiff_jacobian(
+    func: Callable,
+    x0,
+    *,
+    mode: str | None = None,
+) -> np.ndarray:
+    """Calculates the Jacobian of a vector-valued function via JAX autodiff.
+
+    The output convention matches the DerivKit Jacobian builders: the function
+    output is flattened to length ``m = prod(out_shape)``, and a 2D Jacobian of
+    shape ``(m, n)`` is returned, with ``n`` the input dimension.
+
+    Args:
+        func: Function to be differentiated.
+        x0: Point at which to evaluate the Jacobian; array-like, shape ``(n,)``.
+        mode: Differentiation mode; ``None`` (auto), ``'fwd'``, or ``'rev'``.
+            If ``None``, chooses ``'rev'`` if ``m <= n``, else ``'fwd'``. For
+            more details about modes, see the JAX documentation for
+            ``jax.jacrev`` and ``jax.jacfwd``.
+
+    Returns:
+        A Jacobian matrix as a 2D ``numpy.ndarray`` with shape ``(m, n)``.
+
+    Raises:
+        AutodiffUnavailable: If JAX is not available or the function is not
+            differentiable.
+        ValueError: If ``mode`` is invalid.
+        TypeError: If ``func(theta)`` is scalar-valued.
+    """
+    require_jax()
+
+    x0_arr = np.asarray(x0, float).ravel()
+    x0_jax = jnp.asarray(x0_arr)
+
+    f_jax = partial(apply_array_nd, func, "autodiff_jacobian")
+
+    # Evaluate once to learn the output shape (needed for mode selection).
+    try:
+        y0 = f_jax(x0_jax)
+    except (TypeError, ValueError) as exc:
+        raise AutodiffUnavailable(
+            "autodiff_jacobian: function is not JAX-differentiable at x0."
+        ) from exc
+
+    in_dim = x0_arr.size
+    out_dim = int(np.prod(y0.shape))
+
+    # Reverse mode costs one pass per output; forward mode one pass per input.
+    if mode is None:
+        use_rev = out_dim <= in_dim
+    elif mode == "rev":
+        use_rev = True
+    elif mode == "fwd":
+        use_rev = False
+    else:
+        raise ValueError("autodiff_jacobian: mode must be None, 'fwd', or 'rev'.")
+
+    jac_fun = jax.jacrev if use_rev else jax.jacfwd
+
+    try:
+        jac = jac_fun(f_jax)(x0_jax)
+    except (TypeError, ValueError) as exc:
+        raise AutodiffUnavailable(
+            "autodiff_jacobian: failed to trace function with JAX."
+        ) from exc
+
+    jac_np = np.asarray(jac, dtype=float)
+    return jac_np.reshape(out_dim, in_dim)
+
+
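A shape-and-values check with a small map ``f: R^2 -> R^3`` (illustrative sketch; auto mode picks ``'fwd'`` here since ``m = 3 > n = 2``):

```python
import numpy as np

from derivkit.derivatives.autodiff.jax_core import autodiff_jacobian

def f(theta):
    import jax.numpy as jnp
    x, y = theta
    return jnp.array([x * y, x + y, x**2])

theta0 = np.array([2.0, 3.0])
J = autodiff_jacobian(f, theta0)
expected = np.array([[3.0, 2.0],   # d(xy)    = (y, x)
                     [1.0, 1.0],   # d(x + y) = (1, 1)
                     [4.0, 0.0]])  # d(x^2)   = (2x, 0)
assert J.shape == (3, 2) and np.allclose(J, expected)
```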
+def autodiff_hessian(func: Callable, x0) -> np.ndarray:
+    """Calculates the full Hessian of a scalar-valued function.
+
+    Args:
+        func: A function to be differentiated.
+        x0: Point at which to evaluate the Hessian; array-like, shape ``(n,)``
+            with ``n`` the input dimension.
+
+    Returns:
+        A Hessian matrix as a 2D ``numpy.ndarray`` with shape ``(n, n)``.
+
+    Raises:
+        AutodiffUnavailable: If JAX is not available or the function is not
+            differentiable.
+        TypeError: If ``func(theta)`` is not scalar-valued.
+    """
+    require_jax()
+
+    x0_arr = np.asarray(x0, float).ravel()
+    x0_jax = jnp.asarray(x0_arr)
+
+    f_jax = partial(apply_scalar_nd, func, "autodiff_hessian")
+
+    try:
+        hess = jax.hessian(f_jax)(x0_jax)
+    except (TypeError, ValueError) as exc:
+        raise AutodiffUnavailable(
+            "autodiff_hessian: function is not JAX-differentiable."
+        ) from exc
+
+    return np.asarray(hess, dtype=float)
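For a quadratic form the Hessian is constant, which makes a convenient check (illustrative sketch):

```python
import numpy as np

from derivkit.derivatives.autodiff.jax_core import autodiff_hessian

A = np.array([[2.0, 1.0],
              [1.0, 3.0]])

def quad(theta):
    import jax.numpy as jnp
    return 0.5 * theta @ jnp.asarray(A) @ theta

# For f(x) = 0.5 * x^T A x with symmetric A, the Hessian equals A everywhere.
H = autodiff_hessian(quad, np.zeros(2))
assert np.allclose(H, A)
```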
@@ -0,0 +1,146 @@
+"""Utilities for JAX-based autodiff in DerivKit."""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
+try:
+    import jax
+    import jax.numpy as jnp
+except ImportError:
+    jax = None
+    jnp = None
+    _HAS_JAX = False
+else:
+    _HAS_JAX = True
+
+has_jax: bool = _HAS_JAX
+
+__all__ = [
+    "AutodiffUnavailable",
+    "require_jax",
+    "to_jax_scalar",
+    "to_jax_array",
+    "apply_scalar_1d",
+    "apply_scalar_nd",
+    "apply_array_nd",
+]
+
+
+class AutodiffUnavailable(RuntimeError):
+    """Raised when JAX-based autodiff is unavailable."""
+
+
+def require_jax() -> None:
+    """Raises if JAX is not available.
+
+    Returns:
+        None.
+
+    Raises:
+        AutodiffUnavailable: If JAX is not installed.
+    """
+    if not _HAS_JAX:
+        raise AutodiffUnavailable(
+            "JAX autodiff requires `jax` + `jaxlib`.\n"
+            'Install with `pip install "derivkit[jax]"` '
+            "(or follow JAX's official install instructions for GPU)."
+        )
+
+
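Downstream code can either probe the module-level ``has_jax`` flag or let ``require_jax`` raise; a minimal sketch of both patterns (assumption: ``has_jax`` is importable even though it is not listed in ``__all__``):

```python
from derivkit.derivatives.autodiff.jax_utils import (
    AutodiffUnavailable,
    has_jax,
    require_jax,
)

if has_jax:
    # Safe to import and use the JAX-backed helpers.
    from derivkit.derivatives.autodiff.jax_core import autodiff_gradient  # noqa: F401

try:
    require_jax()  # raises AutodiffUnavailable with install instructions
except AutodiffUnavailable as exc:
    print(f"JAX backend disabled: {exc}")
```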
+def to_jax_scalar(y: Any, *, where: str) -> jnp.ndarray:
+    """Ensures that an output is scalar and returns it as a JAX array.
+
+    Args:
+        y: Output to check.
+        where: Context string for error messages.
+
+    Returns:
+        Scalar (0-d) JAX array with shape ``()``.
+
+    Raises:
+        TypeError: If the output is not scalar.
+    """
+    arr = jnp.asarray(y)
+    if arr.ndim != 0:
+        raise TypeError(f"{where}: expected scalar output; got shape {tuple(arr.shape)}.")
+    return arr
+
+
+def to_jax_array(y: Any, *, where: str) -> jnp.ndarray:
+    """Ensures that an output is array-like (not scalar) and returns it as a JAX array.
+
+    Args:
+        y: Output to check.
+        where: Context string for error messages.
+
+    Returns:
+        Non-scalar JAX array with shape ``(m,)`` or higher.
+
+    Raises:
+        TypeError: If the output is scalar or cannot be converted to a JAX array.
+    """
+    try:
+        arr = jnp.asarray(y)
+    except TypeError as exc:
+        raise TypeError(f"{where}: output could not be converted to a JAX array.") from exc
+    if arr.ndim == 0:
+        raise TypeError(f"{where}: output is scalar; use autodiff_derivative/gradient instead.")
+    return arr
+
+
+def apply_scalar_1d(
+    func: Callable[[float], Any],
+    where: str,
+    x: jnp.ndarray,
+) -> jnp.ndarray:
+    """Applies a scalar function ``f: R -> R`` to one input, enforcing a scalar output.
+
+    Args:
+        func: Function to apply.
+        where: Context string for error messages.
+        x: Scalar (0-d) input.
+
+    Returns:
+        Scalar (0-d) JAX array.
+    """
+    return to_jax_scalar(func(x), where=where)
+
+
+def apply_scalar_nd(
+    func: Callable,
+    where: str,
+    theta: jnp.ndarray,
+) -> jnp.ndarray:
+    """Applies a function ``f: R^n -> R`` to a vector input, enforcing a scalar output.
+
+    Args:
+        func: Function to apply.
+        where: Context string for error messages.
+        theta: JAX array of inputs, shape ``(n,)``.
+
+    Returns:
+        Scalar (0-d) JAX array.
+    """
+    return to_jax_scalar(func(theta), where=where)
+
+
+def apply_array_nd(
+    func: Callable,
+    where: str,
+    theta: jnp.ndarray,
+) -> jnp.ndarray:
+    """Applies a function ``f: R^n -> R^m`` to a vector input, enforcing an array output.
+
+    Args:
+        func: Function to apply.
+        where: Context string for error messages.
+        theta: JAX array of inputs, shape ``(n,)``.
+
+    Returns:
+        Non-scalar JAX array of outputs.
+    """
+    return to_jax_array(func(theta), where=where)