pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +6 -4
- pymc_extras/distributions/__init__.py +2 -0
- pymc_extras/distributions/continuous.py +3 -2
- pymc_extras/distributions/discrete.py +3 -1
- pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras/distributions/transforms/partial_order.py +227 -0
- pymc_extras/inference/__init__.py +4 -2
- pymc_extras/inference/find_map.py +62 -17
- pymc_extras/inference/fit.py +6 -4
- pymc_extras/inference/laplace.py +14 -8
- pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- pymc_extras/statespace/core/statespace.py +191 -52
- pymc_extras/statespace/filters/distributions.py +15 -16
- pymc_extras/statespace/filters/kalman_filter.py +1 -18
- pymc_extras/statespace/filters/kalman_smoother.py +2 -6
- pymc_extras/statespace/models/ETS.py +10 -0
- pymc_extras/statespace/models/SARIMAX.py +26 -5
- pymc_extras/statespace/models/VARMAX.py +12 -2
- pymc_extras/statespace/models/structural.py +18 -5
- pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.6.dist-info/METADATA +318 -0
- pymc_extras-0.2.6.dist-info/RECORD +65 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
- pymc_extras/version.py +0 -11
- pymc_extras/version.txt +0 -1
- pymc_extras-0.2.4.dist-info/METADATA +0 -110
- pymc_extras-0.2.4.dist-info/RECORD +0 -105
- pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
- tests/__init__.py +0 -13
- tests/distributions/__init__.py +0 -19
- tests/distributions/test_continuous.py +0 -185
- tests/distributions/test_discrete.py +0 -210
- tests/distributions/test_discrete_markov_chain.py +0 -258
- tests/distributions/test_multivariate.py +0 -304
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +0 -132
- tests/model/marginal/test_graph_analysis.py +0 -182
- tests/model/marginal/test_marginal_model.py +0 -967
- tests/model/test_model_api.py +0 -38
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +0 -411
- tests/statespace/test_SARIMAX.py +0 -405
- tests/statespace/test_VARMAX.py +0 -184
- tests/statespace/test_coord_assignment.py +0 -116
- tests/statespace/test_distributions.py +0 -270
- tests/statespace/test_kalman_filter.py +0 -326
- tests/statespace/test_representation.py +0 -175
- tests/statespace/test_statespace.py +0 -872
- tests/statespace/test_statespace_JAX.py +0 -156
- tests/statespace/test_structural.py +0 -836
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +0 -9
- tests/statespace/utilities/statsmodel_local_level.py +0 -42
- tests/statespace/utilities/test_helpers.py +0 -310
- tests/test_blackjax_smc.py +0 -222
- tests/test_find_map.py +0 -103
- tests/test_histogram_approximation.py +0 -109
- tests/test_laplace.py +0 -265
- tests/test_linearmodel.py +0 -208
- tests/test_model_builder.py +0 -306
- tests/test_pathfinder.py +0 -203
- tests/test_pivoted_cholesky.py +0 -24
- tests/test_printing.py +0 -98
- tests/test_prior_from_trace.py +0 -172
- tests/test_splines.py +0 -77
- tests/utils.py +0 -0
- {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
pymc_extras/__init__.py
CHANGED
@@ -13,18 +13,17 @@
 # limitations under the License.
 import logging

+from importlib.metadata import version
+
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
-from pymc_extras.inference
-from pymc_extras.inference.fit import fit
-from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
     recover_marginals,
 )
 from pymc_extras.model.model_api import as_model
-from pymc_extras.version import __version__

 _log = logging.getLogger("pmx")

@@ -33,3 +32,6 @@ if not logging.root.handlers:
     if len(_log.handlers) == 0:
         handler = logging.StreamHandler()
         _log.addHandler(handler)
+
+
+__version__ = version("pymc-extras")
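For orientation, the import changes above mean the main inference entry points are now re-exported at the package root and the version string is resolved from package metadata at import time. A minimal usage sketch (the toy model and settings are illustrative, not taken from the package docs):

    import pymc as pm
    import pymc_extras as pmx

    print(pmx.__version__)  # now read via importlib.metadata

    with pm.Model():
        x = pm.Normal("x", 0, 1)
        pm.Normal("y", mu=x, sigma=1, observed=[0.2, -0.1, 0.4])

        map_point = pmx.find_MAP(method="L-BFGS-B")  # re-exported from pymc_extras.inference
        idata = pmx.fit_laplace()                    # likewise available at the package root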
pymc_extras/distributions/__init__.py
CHANGED

@@ -26,6 +26,7 @@ from pymc_extras.distributions.discrete import (
 from pymc_extras.distributions.histogram_utils import histogram_approximation
 from pymc_extras.distributions.multivariate import R2D2M2CP
 from pymc_extras.distributions.timeseries import DiscreteMarkovChain
+from pymc_extras.distributions.transforms import PartialOrder

 __all__ = [
     "Chi",
@@ -37,4 +38,5 @@ __all__ = [
     "R2D2M2CP",
     "Skellam",
     "histogram_approximation",
+    "PartialOrder",
 ]
pymc_extras/distributions/continuous.py
CHANGED

@@ -81,7 +81,7 @@ class GenExtreme(Continuous):

        \left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}.

-    Note that this parametrization is per Coles (2001), and differs from that of
+    Note that this parametrization is per Coles (2001) [1]_, and differs from that of
     Scipy in the sign of the shape parameter, :math:`\xi`.

     .. plot::
@@ -132,7 +132,7 @@ class GenExtreme(Continuous):

     References
     ----------
-    .. [
+    .. [1] Coles, S.G. (2001).
        An Introduction to the Statistical Modeling of Extreme Values
        Springer-Verlag, London

@@ -260,6 +260,7 @@ class Chi:
     Examples
     --------
     .. code-block:: python
+
         import pymc as pm
         from pymc_extras.distributions import Chi

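The Chi example touched by the last hunk can be exercised roughly as follows; a minimal sketch assuming the usual nu (degrees of freedom) parameter of this distribution:

    import pymc as pm
    from pymc_extras.distributions import Chi

    with pm.Model():
        x = Chi("x", nu=3)  # chi distribution with 3 degrees of freedom
        idata = pm.sample()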
pymc_extras/distributions/discrete.py
CHANGED

@@ -116,6 +116,7 @@ class GeneralizedPoisson(pm.distributions.Discrete):

     .. math:: f(x \mid \mu, \lambda) =
        \frac{\mu (\mu + \lambda x)^{x-1} e^{-\mu - \lambda x}}{x!}
+
     ========  ======================================
     Support   :math:`x \in \mathbb{N}_0`
     Mean      :math:`\frac{\mu}{1 - \lambda}`
@@ -135,9 +136,10 @@ class GeneralizedPoisson(pm.distributions.Discrete):
     When lam < 0, the mean is greater than the variance (underdispersion).
     When lam > 0, the mean is less than the variance (overdispersion).

+    The PMF is taken from [1]_ and the random generator function is adapted from [2]_.
+
     References
     ----------
-    The PMF is taken from [1] and the random generator function is adapted from [2].
     .. [1] Consul, PoC, and Felix Famoye. "Generalized Poisson regression model."
        Communications in Statistics-Theory and Methods 21.1 (1992): 89-109.
     .. [2] Famoye, Felix. "Generalized Poisson random variate generation." American
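To make the dispersion remarks above concrete, a minimal sketch (parameter values are illustrative, assuming the mu/lam names from the docstring):

    import pymc as pm
    from pymc_extras.distributions import GeneralizedPoisson

    with pm.Model():
        # lam > 0: variance exceeds the mean (overdispersion);
        # lam < 0: variance falls below the mean (underdispersion)
        y = GeneralizedPoisson("y", mu=5.0, lam=0.3)
        prior = pm.sample_prior_predictive()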
pymc_extras/distributions/transforms/partial_order.py
ADDED

@@ -0,0 +1,227 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytensor.tensor as pt
+
+from pymc.logprob.transforms import Transform
+
+__all__ = ["PartialOrder"]
+
+
+def dtype_minval(dtype):
+    """Find the minimum value for a given dtype"""
+    return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
+
+
+def padded_where(x, to_len, padval=-1):
+    """A padded version of np.where"""
+    w = np.where(x)
+    return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
+
+
+class PartialOrder(Transform):
+    """Create a PartialOrder transform
+
+    A more flexible version of the pymc ordered transform that
+    allows specifying a (strict) partial order on the elements.
+
+    Examples
+    --------
+    .. code:: python
+
+        import numpy as np
+        import pymc as pm
+        import pymc_extras as pmx
+
+        # Define two partial orders on 4 elements
+        # am[i,j] = 1 means i < j
+        adj_mats = np.array([
+            # 0 < {1, 2} < 3
+            [[0, 1, 1, 0],
+             [0, 0, 0, 1],
+             [0, 0, 0, 1],
+             [0, 0, 0, 0]],
+
+            # 1 < 0 < 3 < 2
+            [[0, 0, 0, 1],
+             [1, 0, 0, 0],
+             [0, 0, 0, 0],
+             [0, 0, 1, 0]],
+        ])
+
+        # Create the partial order from the adjacency matrices
+        po = pmx.PartialOrder(adj_mats)
+
+        with pm.Model() as model:
+            # Generate 3 samples from both partial orders
+            pm.Normal("po_vals", shape=(3,2,4), transform=po,
+                      initval=po.initvals((3,2,4)))
+
+            idata = pm.sample()
+
+        # Verify that for first po, the zeroth element is always the smallest
+        assert (idata.posterior['po_vals'][:,:,:,0,0] <
+                idata.posterior['po_vals'][:,:,:,0,1:]).all()
+
+        # Verify that for second po, the second element is always the largest
+        assert (idata.posterior['po_vals'][:,:,:,1,2] >=
+                idata.posterior['po_vals'][:,:,:,1,:]).all()
+
+    Technical notes
+    ----------------
+    Partial order needs to be strict, i.e. without equalities.
+    A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
+    Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
+    where N is the number of nodes in the dag and D is the maximum
+    in-degree of a node in the transitive reduction.
+    """
+
+    name = "partial_order"
+
+    def __init__(self, adj_mat):
+        """
+        Initialize the PartialOrder transform
+
+        Parameters
+        ----------
+        adj_mat: ndarray
+            adjacency matrix for the DAG that generates the partial order,
+            where ``adj_mat[i][j] = 1`` denotes ``i < j``.
+            Note this also accepts multiple DAGs if RV is multidimensional
+        """
+
+        # Basic input checks
+        if adj_mat.ndim < 2:
+            raise ValueError("Adjacency matrix must have at least 2 dimensions")
+        if adj_mat.shape[-2] != adj_mat.shape[-1]:
+            raise ValueError("Adjacency matrix is not square")
+        if adj_mat.min() != 0 or adj_mat.max() != 1:
+            raise ValueError("Adjacency matrix must contain only 0s and 1s")
+
+        # Create index over the first ellipsis dimensions
+        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])
+
+        # Transitive closure using Floyd-Warshall
+        tc = adj_mat.astype(bool)
+        for k in range(tc.shape[-1]):
+            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])
+
+        # Check if the dag is acyclic
+        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
+            raise ValueError("Partial order contains equalities")
+
+        # Transitive reduction using the closure
+        # This gives the minimum description of the partial order
+        # This is to minmax the input degree
+        adj_mat = tc * (1 - np.matmul(tc, tc))
+
+        # Find the maximum in-degree of the reduced dag
+        dag_idim = adj_mat.sum(axis=-2).max()
+
+        # Topological sort
+        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
+        dm = adj_mat.copy()
+        for i in range(adj_mat.shape[1]):
+            assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
+            nind = np.argmin(dm.sum(axis=-2), axis=-1)
+            dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
+            dm[(*idx, nind, slice(None))] = 0  # Allow it's children to show
+            ts_inds[(*idx, i)] = nind
+        self.ts_inds = ts_inds
+
+        # Change the dag to adjacency lists (with -1 for NA)
+        dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
+        self.dag = np.swapaxes(dag_T, -2, -1)
+        self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
+
+    def initvals(self, shape=None, lower=-1, upper=1):
+        """
+        Create a set of appropriate initial values for the variable.
+        NB! It is important that proper initial values are used,
+        as only properly ordered values are in the range of the transform.
+
+        Parameters
+        ----------
+        shape: tuple, default None
+            shape of the initial values. If None, adj_mat[:-1] is used
+        lower: float, default -1
+            lower bound for the initial values
+        upper: float, default 1
+            upper bound for the initial values
+
+        Returns
+        -------
+        vals: ndarray
+            initial values for the transformed variable
+        """
+
+        if shape is None:
+            shape = self.dag.shape[:-1]
+
+        if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]:
+            raise ValueError("Shape must match the shape of the adjacency matrix")
+
+        # Create the initial values
+        vals = np.linspace(lower, upper, self.dag.shape[-2])
+        inds = np.argsort(self.ts_inds, axis=-1)
+        ivals = vals[inds]
+
+        # Expand the initial values to the extra dimensions
+        extra_dims = shape[: -len(self.dag.shape[:-1])]
+        ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1])))
+
+        return ivals
+
+    def backward(self, value, *inputs):
+        minv = dtype_minval(value.dtype)
+        x = pt.concatenate(
+            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
+        )
+
+        # Indices to allow broadcasting the max over the last dimension
+        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+        idx2 = tuple(np.tile(i[:, None], self.dag.shape[-1]) for i in idx)
+
+        # Has to be done stepwise as next steps depend on previous values
+        # Also has to be done in topological order, hence the ts_inds
+        for i in range(self.dag.shape[-2]):
+            tsi = self.ts_inds[..., i]
+            if len(tsi.shape) == 0:
+                tsi = int(tsi)  # if shape 0, it's a scalar
+            ni = (*idx, tsi)  # i-th node in topological order
+            eni = (Ellipsis, *ni)
+            ist = self.is_start[ni]
+
+            mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
+            x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
+        return x[..., :-1]
+
+    def forward(self, value, *inputs):
+        y = pt.zeros_like(value)
+
+        minv = dtype_minval(value.dtype)
+        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)
+
+        # Indices to allow broadcasting the max over the last dimension
+        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+        idx = tuple(np.tile(i[:, None, None], self.dag.shape[-2:]) for i in idx)
+
+        y = self.is_start * value + (1 - self.is_start) * (
+            pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
+        )
+
+        return y
+
+    def log_jac_det(self, value, *inputs):
+        return pt.sum(value * (1 - self.is_start), axis=-1)
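The closure/reduction step described in the technical notes can be illustrated standalone. A minimal NumPy sketch over a single hypothetical adjacency matrix, mirroring the Floyd-Warshall loop and the tc * (1 - tc @ tc) reduction in __init__ above:

    import numpy as np

    # hypothetical DAG on 4 nodes: 0 < 1 < 3 and 0 < 2 < 3, plus the redundant edge 0 -> 3
    adj = np.array([
        [0, 1, 1, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 0],
    ], dtype=bool)

    # transitive closure via Floyd-Warshall, as in PartialOrder.__init__
    tc = adj.copy()
    for k in range(tc.shape[-1]):
        tc |= np.logical_and(tc[:, k, None], tc[None, k, :])

    # transitive reduction: drop any edge that is implied by a two-step path
    reduced = tc * (1 - np.matmul(tc, tc))
    print(reduced)  # the redundant 0 -> 3 edge has been removed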
pymc_extras/inference/__init__.py
CHANGED

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

-__all__ = ["fit"]
+__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
pymc_extras/inference/find_map.py
CHANGED

@@ -9,7 +9,7 @@ import pymc as pm
 import pytensor
 import pytensor.tensor as pt

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
@@ -146,7 +146,7 @@ def _compile_grad_and_hess_to_jax(
     orig_loss_fn = f_loss.vm.jit_fn

     @jax.jit
-    def loss_fn_jax_grad(x
+    def loss_fn_jax_grad(x):
         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)

     f_loss_and_grad = loss_fn_jax_grad
@@ -301,6 +301,14 @@ def scipy_optimize_funcs_from_loss(
         point=initial_point_dict, outputs=[loss], inputs=inputs
     )

+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        from pymc.sampling.jax import _replace_shared_variables
+
+        [loss] = _replace_shared_variables([loss])
+
     compute_grad = use_grad and not use_jax_gradients
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
@@ -327,7 +335,7 @@ def scipy_optimize_funcs_from_loss(


 def find_MAP(
-    method: minimize_method,
+    method: minimize_method | Literal["basinhopping"],
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -344,14 +352,17 @@ def find_MAP(
     **optimizer_kwargs,
 ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
     """
-    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.
+    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.

     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
     method : str
-        The optimization method to use.
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -379,7 +390,9 @@ def find_MAP(
     compile_kwargs: dict, optional
         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
     **optimizer_kwargs
-        Additional keyword arguments to pass to the ``scipy.optimize
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.

     Returns
     -------
@@ -405,6 +418,18 @@ def find_MAP(
     initial_params = DictToArrayBijection.map(
         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
     )
+
+    do_basinhopping = method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = method
+
     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
         method, use_grad, use_hess, use_hessp
     )
@@ -423,17 +448,37 @@ def find_MAP(
     args = optimizer_kwargs.pop("args", None)

     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-    # if so. That is why
-    optimizer_result = minimize(
-        f=f_logp,
-        x0=cast(np.ndarray[float], initial_params.data),
-        args=args,
-        hess=f_hess,
-        hessp=f_hessp,
-        progressbar=progressbar,
-        method=method,
-        **optimizer_kwargs,
-    )
+    # if so. That is why the jac argument is not passed here in either branch.
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hess" not in minimizer_kwargs:
+            minimizer_kwargs["hess"] = f_hess
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = method
+
+        optimizer_result = basinhopping(
+            func=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        optimizer_result = minimize(
+            f=f_logp,
+            x0=cast(np.ndarray[float], initial_params.data),
+            args=args,
+            hess=f_hess,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            method=method,
+            **optimizer_kwargs,
+        )

     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
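A sketch of how the new basinhopping branch might be invoked; the toy model, niter value, and inner-method choice are illustrative assumptions, not taken from the package docs:

    import pymc as pm
    from pymc_extras.inference import find_MAP

    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu=mu, sigma=1, observed=[0.1, 0.4, 0.3])

        # method="basinhopping" wraps an inner scipy.optimize.minimize call;
        # the inner method defaults to L-BFGS-B unless given in minimizer_kwargs
        best = find_MAP(
            method="basinhopping",
            minimizer_kwargs={"method": "L-BFGS-B"},
            niter=5,  # forwarded to the basinhopping routine
        )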
pymc_extras/inference/fit.py
CHANGED
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import arviz as az


-def fit(method, **kwargs):
+def fit(method: str, **kwargs) -> az.InferenceData:
     """
-    Fit a model with an inference algorithm
+    Fit a model with an inference algorithm.
+    See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.

     Parameters
     ----------
@@ -23,11 +25,11 @@ def fit(method, **kwargs):
         Which inference method to run.
         Supported: pathfinder or laplace

-    kwargs are passed on.
+    kwargs: keyword arguments are passed on to the inference method.

     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
     """
     if method == "pathfinder":
         from pymc_extras.inference.pathfinder import fit_pathfinder
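In use, the dispatcher reads roughly as follows; a minimal sketch with a hypothetical model and default settings:

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model():
        x = pm.Normal("x", 0, 1)
        pm.Normal("y", mu=x, sigma=1, observed=[0.2, -0.1])

        idata_pf = pmx.fit(method="pathfinder")  # forwards kwargs to fit_pathfinder
        idata_la = pmx.fit(method="laplace")     # forwards kwargs to fit_laplace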
pymc_extras/inference/laplace.py
CHANGED
@@ -377,7 +377,10 @@ def sample_laplace_posterior(
     posterior_dist = stats.multivariate_normal(
         mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
     )
+
     posterior_draws = posterior_dist.rvs(size=(chains, draws))
+    if mu.data.shape == (1,):
+        posterior_draws = np.expand_dims(posterior_draws, -1)

     if transform_samples:
         constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
@@ -413,7 +416,7 @@


 def fit_laplace(
-    optimize_method: minimize_method = "BFGS",
+    optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
     *,
     model: pm.Model | None = None,
     use_grad: bool | None = None,
@@ -446,8 +449,11 @@ def fit_laplace(
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
-
-        The optimization method to use.
+    method : str
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
     use_grad : bool | None, optional
         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
         the ``method``.
@@ -497,16 +503,16 @@ def fit_laplace(
     diag_jitter: float | None
         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
         If None, no jitter is added. Default is 1e-8.
-    optimizer_kwargs
-        Additional keyword arguments to pass to scipy.
-
-
+    optimizer_kwargs
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.

     Returns
     -------
-
+    :class:`~arviz.InferenceData`
         An InferenceData object containing the approximated posterior samples.

     Examples
|