eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/ensemble.py
ADDED
|
@@ -0,0 +1,1690 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from itertools import count
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
|
|
9
|
+
from .backends import Backend, HDFBackend
|
|
10
|
+
from .model import Model
|
|
11
|
+
from .moves import StretchMove, TemperatureControl, DistributionGenerateRJ, GaussianMove
|
|
12
|
+
from .pbar import get_progress_bar
|
|
13
|
+
from .state import State
|
|
14
|
+
from .prior import ProbDistContainer
|
|
15
|
+
|
|
16
|
+
# from .utils import PlotContainer
|
|
17
|
+
from .utils import PeriodicContainer
|
|
18
|
+
from .utils.utility import groups_from_inds
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
__all__ = ["EnsembleSampler", "walkers_independent"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ``Iterable`` lives in ``collections.abc`` since Python 3.3; the plain
# ``collections.Iterable`` alias was removed entirely in Python 3.10, so the
# fallback only matters on very old interpreters.
try:
    from collections.abc import Iterable
except ImportError:
    # for py2.7, will be an Exception in 3.8
    from collections import Iterable
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class EnsembleSampler(object):
|
|
32
|
+
"""An ensemble MCMC sampler
|
|
33
|
+
|
|
34
|
+
The class controls the entire sampling run. It can handle
|
|
35
|
+
everything from a basic non-tempered MCMC to a parallel-tempered,
|
|
36
|
+
global fit containing multiple branches (models) and a variable
|
|
37
|
+
number of leaves (sources) per branch.
|
|
38
|
+
See `here <https://mikekatz04.github.io/Eryn/html/tutorial/Eryn_tutorial.html#The-Tree-Metaphor>`_
|
|
39
|
+
for a basic explainer.
|
|
40
|
+
|
|
41
|
+
Parameters related to parallelization can be controlled via the ``pool`` argument.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
nwalkers (int): The number of walkers in the ensemble per temperature.
|
|
45
|
+
ndims (int, list of ints, or dict): The number of dimensions for each branch. If
|
|
46
|
+
``dict``, keys should be the branch names and values the associated dimensionality.
|
|
47
|
+
log_like_fn (callable): A function that returns the natural logarithm of the
|
|
48
|
+
likelihood for that position. The inputs to ``log_like_fn`` depend on whether
|
|
49
|
+
the function is vectorized (kwarg ``vectorize`` below), if you are using reversible jump,
|
|
50
|
+
and how many branches you have.
|
|
51
|
+
|
|
52
|
+
In the simplest case where ``vectorize == False``, no reversible jump, and only one
|
|
53
|
+
type of model, the inputs are just the array of parameters for one walker, so shape is ``(ndim,)``.
|
|
54
|
+
|
|
55
|
+
If ``vectorize == True``, no reversible jump, and only one type of model, the inputs will
|
|
56
|
+
be a 2D array of parameters of all the walkers going in. Shape: ``(num positions, ndim)``.
|
|
57
|
+
|
|
58
|
+
If using reversible jump, the leaves that go together in the same Likelihood will be grouped
|
|
59
|
+
together into a single function call. If ``vectorize == False``, then each group is sent as
|
|
60
|
+
an individual computation. With ``N`` different branches (``N > 1``), inputs would be a list
|
|
61
|
+
of 2D arrays of the coordinates for all leaves within each branch: ``([x0, x1,...,xN])``
|
|
62
|
+
where ``xi`` is 2D with shape ``(number of leaves in this branch, ndim)``. If ``N == 1``, then
|
|
63
|
+
a list is not provided, just x0, the 2D array of coordinates for the one branch considered.
|
|
64
|
+
|
|
65
|
+
If using reversible jump and ``vectorize == True``, then the arrays of parameters will be output
|
|
66
|
+
with information as regards the grouping of branch and leaf set. Inputs will be
|
|
67
|
+
``([X0, X1,..XN], [group0, group1,...,groupN])`` where ``Xi`` is a 2D array of all
|
|
68
|
+
leaves in the sampler for branch ``i``. ``groupi`` is an index indicating which unique group
|
|
69
|
+
that sources belongs. For example, if we have 3 walkers with (1, 2, 1) leaves for model ``i``,
|
|
70
|
+
respectively, we will have an ``Xi = array([params0, params1, params2, params3])`` and
|
|
71
|
+
``groupsi = array([0, 1, 1, 2])``.
|
|
72
|
+
If ``N == 1``, then the lists are removed and the inputs become ``(X0, group0)``.
|
|
73
|
+
|
|
74
|
+
Extra ``args`` and ``kwargs`` for the Likelihood function can be added with the kwargs
|
|
75
|
+
``args`` and ``kwargs`` below.
|
|
76
|
+
|
|
77
|
+
Please see the
|
|
78
|
+
`tutorial <https://mikekatz04.github.io/Eryn/html/tutorial/Eryn_tutorial.html#>`_
|
|
79
|
+
for more information.
|
|
80
|
+
|
|
81
|
+
priors (dict): The prior dictionary can take four forms.
|
|
82
|
+
1) A dictionary with keys as int or tuple containing the int or tuple of int
|
|
83
|
+
that describe the parameter number over which to assess the prior, and values that
|
|
84
|
+
are prior probability distributions that must have a ``logpdf`` class method.
|
|
85
|
+
2) A :class:`eryn.prior.ProbDistContainer` object.
|
|
86
|
+
3) A dictionary with keys that are ``branch_names`` and values that are dictionaries for
|
|
87
|
+
each branch as described for (1).
|
|
88
|
+
4) A dictionary with keys that are ``branch_names`` and values are
|
|
89
|
+
:class:`eryn.prior.ProbDistContainer` objects.
|
|
90
|
+
If the priors dictionary has the specific key ``"all_models_together"`` in it, then a special
|
|
91
|
+
prior must be input by the user that can produce the prior ``logpdf`` information that can depend on all of the models
|
|
92
|
+
together rather than handling them separately as is the default. In this case, the user must input
|
|
93
|
+
as the item attached to this special key a class object with a ``logpdf`` method that takes as input
|
|
94
|
+
two arguments: ``(coords, inds)``, which are the coordinate and index dictionaries across all models with
|
|
95
|
+
shapes of ``(ntemps, nwalkers, nleaves_max, ndim)`` and ``(ntemps, nwalkers, nleaves_max)`` for each
|
|
96
|
+
individual model, respectively. This function must then return the numpy array of logpdf values for the
|
|
97
|
+
prior value with shape ``(ntemps, nwalkers)``.
|
|
98
|
+
provide_groups (bool, optional): If ``True``, provide groups as described in ``log_like_fn`` above.
|
|
99
|
+
A group parameter is added for each branch. (default: ``False``)
|
|
100
|
+
provide_supplemental (bool, optional): If ``True``, it will provide keyword arguments to
|
|
101
|
+
the Likelihood function: ``supps`` and ``branch_supps``. Please see the `Tutorial <https://mikekatz04.github.io/Eryn/html/tutorial/Eryn_tutorial.html#>`_
|
|
102
|
+
and :class:`eryn.state.BranchSupplemental` for more information.
|
|
103
|
+
tempering_kwargs (dict, optional): Keyword arguments for initialization of the
|
|
104
|
+
tempering class: :class:`eryn.moves.tempering.TemperatureControl`. (default: ``{}``)
|
|
105
|
+
branch_names (list, optional): List of branch names. If ``None``, models will be assigned
|
|
106
|
+
names as ``f"model_{index}"`` based on ``nbranches``. (default: ``None``)
|
|
107
|
+
nbranches (int, optional): Number of branches (models) tested.
|
|
108
|
+
Only used if ``branch_names is None``.
|
|
109
|
+
(default: ``1``)
|
|
110
|
+
nleaves_max (int, list of ints, or dict, optional): Maximum allowable leaf count for each branch.
|
|
111
|
+
It should have the same length as the number of branches.
|
|
112
|
+
If ``dict``, keys should be the branch names and values the associated maximal leaf value.
|
|
113
|
+
(default: ``1``)
|
|
114
|
+
nleaves_min (int, list of ints, or dict, optional): Minimum allowable leaf count for each branch.
|
|
115
|
+
It should have the same length as the number of branches. Only used with Reversible Jump.
|
|
116
|
+
If ``dict``, keys should be the branch names and values the associated maximal leaf value.
|
|
117
|
+
If ``None`` and using Reversible Jump, will fill all branches with zero.
|
|
118
|
+
(default: ``0``)
|
|
119
|
+
pool (object, optional): An object with a ``map`` method that follows the same
|
|
120
|
+
calling sequence as the built-in ``map`` function. This is
|
|
121
|
+
generally used to compute the log-probabilities for the ensemble
|
|
122
|
+
in parallel.
|
|
123
|
+
moves (list or object, optional): This can be a single move object, a list of moves,
|
|
124
|
+
or a "weighted" list of the form ``[(eryn.moves.StretchMove(),
|
|
125
|
+
0.1), ...]``. When running, the sampler will randomly select a
|
|
126
|
+
move from this list (optionally with weights) for each proposal.
|
|
127
|
+
If ``None``, the default will be :class:`StretchMove`.
|
|
128
|
+
(default: ``None``)
|
|
129
|
+
rj_moves (list or object or bool or str, optional): If ``None`` or ``False``, reversible jump will not be included in the run.
|
|
130
|
+
This can be a single move object, a list of moves,
|
|
131
|
+
or a "weighted" list of the form ``[(eryn.moves.DistributionGenerateRJ(),
|
|
132
|
+
0.1), ...]``. When running, the sampler will randomly select a
|
|
133
|
+
move from this list (optionally with weights) for each proposal.
|
|
134
|
+
If ``True``, it defaults to :class:`DistributionGenerateRJ`. When running just the :class:`DistributionGenerateRJ`
|
|
135
|
+
with multiple branches, it will propose changes to all branches simultaneously.
|
|
136
|
+
When running with more than one branch, useful options for ``rj_moves`` are ``"iterate_branches"``,
|
|
137
|
+
``"separate_branches"``, or ``"together"``. If ``rj_moves == "iterate_branches"``, sample one branch by one branch in order of
|
|
138
|
+
the branch names. This occurs within one RJ proposal, for each RJ proposal. If ``rj_moves == "separate_branches"``,
|
|
139
|
+
there will be one RJ move per branch. During each individual RJ move, one of these proposals is chosen at random with
|
|
140
|
+
equal probability. This is generally recommended when using multiple branches.
|
|
141
|
+
If ``rj_moves == "together"``, this is equivalent to ``rj_moves == True``.
|
|
142
|
+
(default: ``None``)
|
|
143
|
+
dr_moves (bool, optional): If ``None`` or ``False``, delayed rejection when proposing "birth"
|
|
144
|
+
of new components/models will be switched off for this run. Requires ``rj_moves`` set to ``True``.
|
|
145
|
+
Not implemented yet. Working on it.
|
|
146
|
+
(default: ``None``)
|
|
147
|
+
dr_max_iter (int, optional): Maximum number of iterations used with delayed rejection. (default: 5)
|
|
148
|
+
args (optional): A list of extra positional arguments for
|
|
149
|
+
``log_like_fn``. ``log_like_fn`` will be called as
|
|
150
|
+
``log_like_fn(sampler added args, *args, sampler added kwargs, **kwargs)``.
|
|
151
|
+
kwargs (optional): A dict of extra keyword arguments for
|
|
152
|
+
``log_like_fn``. ``log_like_fn`` will be called as
|
|
153
|
+
``log_like_fn(sampler added args, *args, sampler added kwargs, **kwargs)``.
|
|
154
|
+
backend (optional): Either a :class:`backends.Backend` or a subclass
|
|
155
|
+
(like :class:`backends.HDFBackend`) that is used to store and
|
|
156
|
+
serialize the state of the chain. By default, the chain is stored
|
|
157
|
+
as a set of numpy arrays in memory, but new backends can be
|
|
158
|
+
written to support other mediums.
|
|
159
|
+
vectorize (bool, optional): If ``True``, ``log_like_fn`` is expected
|
|
160
|
+
to accept an array of position vectors instead of just one. Note
|
|
161
|
+
that ``pool`` will be ignored if this is ``True``. See ``log_like_fn`` information
|
|
162
|
+
above to understand the arguments of ``log_like_fn`` based on whether
|
|
163
|
+
``vectorize`` is ``True``.
|
|
164
|
+
(default: ``False``)
|
|
165
|
+
periodic (dict, optional): Keys are ``branch_names``. Values are dictionaries
|
|
166
|
+
that have (key: value) pairs as (index to parameter: period). Periodic
|
|
167
|
+
parameters are treated as having periodic boundary conditions in proposals.
|
|
168
|
+
update_fn (callable, optional): :class:`eryn.utils.updates.AdjustStretchProposalScale`
|
|
169
|
+
object that allows the user to update the sampler in any preferred way
|
|
170
|
+
every ``update_iterations`` sampler iterations. The callable must have signature:
|
|
171
|
+
``(sampler iteration, last sample state object, EnsembleSampler object)``.
|
|
172
|
+
update_iterations (int, optional): Number of iterations between sampler
|
|
173
|
+
updates using ``update_fn``. Updates are only performed at the thinning rate.
|
|
174
|
+
If ``thin_by>1`` when :func:`EnsembleSampler.run_mcmc` is used, the sampler
|
|
175
|
+
is updated every ``thin_by * update_iterations`` iterations.
|
|
176
|
+
stopping_fn (callable, optional): :class:`eryn.utils.stopping.Stopping` object that
|
|
177
|
+
allows the user to end the sampler if specified criteria are met.
|
|
178
|
+
The callable must have signature:
|
|
179
|
+
``(sampler iteration, last sample state object, EnsembleSampler object)``.
|
|
180
|
+
stopping_iterations (int, optional): Number of iterations between sampler
|
|
181
|
+
attempts to evaluate the ``stopping_fn``. Stopping checks are only performed at the thinning rate.
|
|
182
|
+
If ``thin_by>1`` when :func:`EnsembleSampler.run_mcmc` is used, the sampler
|
|
183
|
+
is checked for the stopping criterion every ``thin_by * stopping_iterations`` iterations.
|
|
184
|
+
fill_zero_leaves_val (double, optional): When there are zero leaves in a
|
|
185
|
+
given walker (across all branches), fill the likelihood value with
|
|
186
|
+
``fill_zero_leaves_val``. If wanting to keep zero leaves as a possible
|
|
187
|
+
model, this should be set to the value of the contribution to the Likelihood
|
|
188
|
+
from the data. (Default: ``-1e300``).
|
|
189
|
+
num_repeats_in_model (int, optional): Number of times to repeat the in-model step
|
|
190
|
+
within in one sampler iteration. When analyzing the acceptance fraction, you must
|
|
191
|
+
include the value of ``num_repeats_in_model`` to get the proper denominator.
|
|
192
|
+
num_repeats_rj (int, optional): Number of time to repeat the reversible jump step
|
|
193
|
+
within in one sampler iteration. When analyzing the acceptance fraction, you must
|
|
194
|
+
include the value of ``num_repeats_rj`` to get the proper denominator.
|
|
195
|
+
track_moves (bool, optional): If ``True``, track acceptance fraction of each move
|
|
196
|
+
in the backend. If ``False``, no tracking is done. If ``True`` and run is interrupted, it will check
|
|
197
|
+
that the move configuration has not changed. It will not allow the run to go on
|
|
198
|
+
if it is changed. In this case, the user should declare a new backend and use the last
|
|
199
|
+
state from the previous backend. **Warning**: If the order of moves of the same move class
|
|
200
|
+
is changed, the check may not catch it, so the tracking may mix move acceptance fractions together.
|
|
201
|
+
info (dict, optional): Key and value pairs representing any information
|
|
202
|
+
the user wants to add to the backend if the user is not inputting
|
|
203
|
+
their own backend.
|
|
204
|
+
|
|
205
|
+
Raises:
|
|
206
|
+
ValueError: Any startup issues.
|
|
207
|
+
|
|
208
|
+
"""
|
|
209
|
+
|
|
210
|
+
def __init__(
    self,
    nwalkers,
    ndims,  # assumes ndim_max
    log_like_fn,
    priors,
    provide_groups=False,
    provide_supplemental=False,
    tempering_kwargs={},
    branch_names=None,
    nbranches=1,
    nleaves_max=1,
    nleaves_min=0,
    pool=None,
    moves=None,
    rj_moves=None,
    dr_moves=None,
    dr_max_iter=5,
    args=None,
    kwargs=None,
    backend=None,
    vectorize=False,
    blobs_dtype=None,  # TODO check this
    plot_iterations=-1,  # TODO: do plot stuff?
    plot_generator=None,
    plot_name=None,
    periodic=None,
    update_fn=None,
    update_iterations=-1,
    stopping_fn=None,
    stopping_iterations=-1,
    fill_zero_leaves_val=-1e300,
    num_repeats_in_model=1,
    num_repeats_rj=1,
    track_moves=True,
    info={},
):
    """Initialize the sampler.

    Normalizes the branch/dimension inputs to dicts keyed by branch name,
    builds the in-model and reversible-jump move schedules, wires
    temperature and periodic information into the moves, and attaches
    (or validates) the backend. See the class docstring for the meaning
    of each argument.

    NOTE(review): ``tempering_kwargs={}`` and ``info={}`` are mutable
    default arguments. They are only read here (never mutated), so the
    behavior is correct, but ``None`` sentinels would be safer style.
    """
    # store priors (validated/converted by the ``priors`` property setter)
    self.priors = priors

    # store some kwargs
    self.provide_groups = provide_groups
    self.provide_supplemental = provide_supplemental
    self.fill_zero_leaves_val = fill_zero_leaves_val
    self.num_repeats_in_model = num_repeats_in_model
    self.num_repeats_rj = num_repeats_rj
    self.track_moves = track_moves

    # setup emcee-like basics
    self.pool = pool
    self.vectorize = vectorize
    self.blobs_dtype = blobs_dtype

    # turn things into lists/dicts if needed
    if branch_names is not None:
        if isinstance(branch_names, str):
            branch_names = [branch_names]
        elif not isinstance(branch_names, list):
            raise ValueError("branch_names must be string or list of strings.")
    else:
        # auto-generate names; ``nbranches`` is only honored in this case
        branch_names = ["model_{}".format(i) for i in range(nbranches)]

    nbranches = len(branch_names)

    # normalize ``ndims`` to a dict keyed by branch name
    if isinstance(ndims, int):
        assert len(branch_names) == 1
        ndims = {branch_names[0]: ndims}
    elif isinstance(ndims, list) or isinstance(ndims, np.ndarray):
        assert len(branch_names) == len(ndims)
        ndims = {bn: nd for bn, nd in zip(branch_names, ndims)}
    elif isinstance(ndims, dict):
        assert len(list(ndims.keys())) == len(branch_names)
        for key in ndims:
            if key not in branch_names:
                raise ValueError(
                    f"{key} is in ndims but does not appear in branch_names: {branch_names}."
                )
    else:
        raise ValueError("ndims is to be a scalar int, list or dict.")

    # normalize ``nleaves_max`` to a dict keyed by branch name
    if isinstance(nleaves_max, int):
        assert len(branch_names) == 1
        nleaves_max = {branch_names[0]: nleaves_max}
    elif isinstance(nleaves_max, list) or isinstance(nleaves_max, np.ndarray):
        assert len(branch_names) == len(nleaves_max)
        nleaves_max = {bn: nl for bn, nl in zip(branch_names, nleaves_max)}
    elif isinstance(nleaves_max, dict):
        assert len(list(nleaves_max.keys())) == len(branch_names)
        for key in nleaves_max:
            if key not in branch_names:
                raise ValueError(
                    f"{key} is in nleaves_max but does not appear in branch_names: {branch_names}."
                )
    else:
        raise ValueError("nleaves_max is to be a scalar int, list, or dict.")

    self.nbranches = len(branch_names)

    self.branch_names = branch_names
    self.ndims = ndims
    self.nleaves_max = nleaves_max

    # set up tempering information
    # default is no temperatures
    if tempering_kwargs == {}:
        self.ntemps = 1
        self.temperature_control = None
    else:
        # get effective total dimension: sum over branches of
        # (max leaves) * (parameters per leaf)
        total_ndim = 0
        for key in self.branch_names:
            total_ndim += self.nleaves_max[key] * self.ndims[key]
        self.temperature_control = TemperatureControl(
            total_ndim, nwalkers, **tempering_kwargs
        )
        self.ntemps = self.temperature_control.ntemps

    # set basic variables for sampling settings
    self.nwalkers = nwalkers
    # NOTE(review): redundant re-assignment, same value as set above
    self.nbranches = nbranches

    # eryn wraps periodic parameters
    if periodic is not None:
        if not isinstance(periodic, PeriodicContainer) and not isinstance(
            periodic, dict
        ):
            raise ValueError(
                "periodic must be PeriodicContainer or dict if not None."
            )
        elif isinstance(periodic, dict):
            periodic = PeriodicContainer(periodic)

    # Parse the move schedule
    if moves is None:
        if rj_moves is not None:
            raise ValueError(
                "If providing rj_moves, must provide moves kwarg as well."
            )

        # defaults to stretch move
        self.moves = [
            StretchMove(
                temperature_control=self.temperature_control,
                periodic=periodic,
                a=2.0,
            )
        ]
        self.weights = [1.0]
    elif isinstance(moves, Iterable):
        # either a list of (move, weight) pairs or a plain list of moves;
        # zip(*moves) raises TypeError for the plain-list form
        try:
            self.moves, self.weights = [list(tmp) for tmp in zip(*moves)]
        except TypeError:
            self.moves = moves
            self.weights = np.ones(len(moves))
    else:
        # single move object
        self.moves = [moves]
        self.weights = [1.0]

    # normalize move weights to probabilities
    self.weights = np.atleast_1d(self.weights).astype(float)
    self.weights /= np.sum(self.weights)

    # parse the reversible jump move schedule
    if rj_moves is None:
        self.has_reversible_jump = False
    elif (isinstance(rj_moves, bool) and rj_moves) or isinstance(rj_moves, str):
        self.has_reversible_jump = True

        # NOTE(review): this inner guard is always True in this branch
        if self.has_reversible_jump:
            # normalize ``nleaves_min`` to a dict keyed by branch name
            if nleaves_min is None:
                nleaves_min = {bn: 0 for bn in branch_names}
            elif isinstance(nleaves_min, int):
                assert len(branch_names) == 1
                nleaves_min = {branch_names[0]: nleaves_min}
            elif isinstance(nleaves_min, list) or isinstance(
                nleaves_min, np.ndarray
            ):
                assert len(branch_names) == len(nleaves_min)
                nleaves_min = {bn: nl for bn, nl in zip(branch_names, nleaves_min)}
            elif isinstance(nleaves_min, dict):
                assert len(list(nleaves_min.keys())) == len(branch_names)
                for key in nleaves_min:
                    if key not in branch_names:
                        raise ValueError(
                            f"{key} is in nleaves_min but does not appear in branch_names: {branch_names}."
                        )
            else:
                raise ValueError(
                    "If providing nleaves_min, nleaves_min is to be a scalar int, list, or dict."
                )

            self.nleaves_min = nleaves_min

            if (isinstance(rj_moves, bool) and rj_moves) or (
                isinstance(rj_moves, str) and rj_moves == "together"
            ):
                # default to DistributionGenerateRJ

                # gibbs sampling setup here means run all of them together
                gibbs_sampling_setup = None

                rj_move = DistributionGenerateRJ(
                    self.priors,
                    nleaves_max=self.nleaves_max,
                    nleaves_min=self.nleaves_min,
                    dr=dr_moves,
                    dr_max_iter=dr_max_iter,
                    tune=False,
                    temperature_control=self.temperature_control,
                    gibbs_sampling_setup=gibbs_sampling_setup,
                )
                self.rj_moves = [rj_move]
                self.rj_weights = [1.0]

            elif isinstance(rj_moves, str) and rj_moves == "iterate_branches":
                # will iterate through all branches within one RJ proposal
                gibbs_sampling_setup = deepcopy(branch_names)

                # default to DistributionGenerateRJ
                rj_move = DistributionGenerateRJ(
                    self.priors,
                    nleaves_max=self.nleaves_max,
                    nleaves_min=self.nleaves_min,
                    dr=dr_moves,
                    dr_max_iter=dr_max_iter,
                    tune=False,
                    temperature_control=self.temperature_control,
                    gibbs_sampling_setup=gibbs_sampling_setup,
                )
                self.rj_moves = [rj_move]
                self.rj_weights = [1.0]

            elif isinstance(rj_moves, str) and rj_moves == "separate_branches":
                # one independent RJ move per branch, chosen at random at run time
                rj_moves = []
                rj_weights = []
                for branch_name in branch_names:

                    # only do one branch per move
                    gibbs_sampling_setup = [branch_name]

                    # default to DistributionGenerateRJ
                    rj_move_tmp = DistributionGenerateRJ(
                        self.priors,
                        nleaves_max=self.nleaves_max,
                        nleaves_min=self.nleaves_min,
                        dr=dr_moves,
                        dr_max_iter=dr_max_iter,
                        tune=False,
                        temperature_control=self.temperature_control,
                        gibbs_sampling_setup=gibbs_sampling_setup,
                    )
                    rj_moves.append(rj_move_tmp)
                    # will renormalize after
                    rj_weights.append(1.0)
                self.rj_moves = rj_moves
                self.rj_weights = rj_weights

            elif isinstance(rj_moves, str):
                raise ValueError(
                    f"When providing a str for rj_moves, must be 'together', 'iterate_branches', or 'separate_branches'. Input is {rj_moves}"
                )

    # same as above for moves
    elif isinstance(rj_moves, Iterable):
        self.has_reversible_jump = True

        try:
            self.rj_moves, self.rj_weights = zip(*rj_moves)
        except TypeError:
            self.rj_moves = rj_moves
            self.rj_weights = np.ones(len(rj_moves))

    elif isinstance(rj_moves, bool) and not rj_moves:
        self.has_reversible_jump = False
        self.rj_moves = None
        self.rj_weights = None

    elif not isinstance(rj_moves, bool):
        # a single RJ move object
        self.has_reversible_jump = True

        self.rj_moves = [rj_moves]
        self.rj_weights = [1.0]

    # adjust rj weights properly
    if self.has_reversible_jump:
        self.rj_weights = np.atleast_1d(self.rj_weights).astype(float)
        self.rj_weights /= np.sum(self.rj_weights)

        # warn if base stretch is used
        # NOTE(review): indentation reconstructed from a mangled source; the
        # warning is assumed to apply only when reversible jump is active,
        # which matches the warning text — confirm against upstream.
        for move in self.moves:
            if type(move) == StretchMove:
                warnings.warn(
                    "If using revisible jump, using the Stretch Move for in-model proposals is not advised. It will run and work, but it will not be using the correct complientary group of parameters meaning it will most likely be very inefficient."
                )

    # make sure moves have temperature module
    if self.temperature_control is not None:
        for move in self.moves:
            if move.temperature_control is None:
                move.temperature_control = self.temperature_control

        if self.has_reversible_jump:
            for move in self.rj_moves:
                if move.temperature_control is None:
                    move.temperature_control = self.temperature_control

    # make sure moves have the periodic parameter information
    if periodic is not None:
        for move in self.moves:
            if move.periodic is None:
                move.periodic = periodic

        if self.has_reversible_jump:
            for move in self.rj_moves:
                if move.periodic is None:
                    move.periodic = periodic

    # prepare the per proposal accepted values that are held as attributes in the specific classes
    for move in self.moves:
        move.accepted = np.zeros((self.ntemps, self.nwalkers))

    if self.has_reversible_jump:
        for move in self.rj_moves:
            move.accepted = np.zeros((self.ntemps, self.nwalkers))

    # setup backend if not provided or initialized
    if backend is None:
        self.backend = Backend()
    elif isinstance(backend, str):
        # a plain string is interpreted as an HDF5 file path
        self.backend = HDFBackend(backend)
    else:
        self.backend = backend

    self.info = info

    # flat list of all moves (in-model + RJ) for acceptance-tracking bookkeeping
    all_moves_tmp = list(
        tuple(self.moves)
        if not self.has_reversible_jump
        else tuple(self.moves) + tuple(self.rj_moves)
    )

    self.all_moves = {}
    if self.track_moves:
        current_indices_move_keys = {}
        for move in all_moves_tmp:
            # get out of tuple if weights are given
            if isinstance(move, tuple):
                move = move[0]

            # get the name of the class instance as a string
            move_name = move.__class__.__name__

            # need to keep track how many times each type of move class has been used
            if move_name not in current_indices_move_keys:
                current_indices_move_keys[move_name] = 0
            else:
                current_indices_move_keys[move_name] += 1

            # get the full name including the index
            full_move_name = move_name + f"_{current_indices_move_keys[move_name]}"
            self.all_moves[full_move_name] = move

        # get move keys out
        move_keys = list(self.all_moves.keys())
    else:
        move_keys = None

    self.move_keys = move_keys

    # Deal with re-used backends
    if not self.backend.initialized:
        self._previous_state = None
        self.reset(
            branch_names=branch_names,
            ntemps=self.ntemps,
            nleaves_max=nleaves_max,
            rj=self.has_reversible_jump,
            moves=move_keys,
            **info,
        )
        state = np.random.get_state()
    else:
        if self.track_moves:
            # verify the move configuration matches what the backend recorded
            moves_okay = True
            if len(self.move_keys) != len(self.backend.move_keys):
                moves_okay = False

            for key in self.move_keys:
                if key not in self.backend.move_keys:
                    moves_okay = False

            if not moves_okay:
                raise ValueError(
                    "Configuration of moves has changed. Cannot use the same backend. Declare a new backend and start from the previous state. If you would prefer not to track move acceptance fraction, set track_moves to False in the EnsembleSampler."
                )

        # Check the backend shape
        for i, (name, shape) in enumerate(self.backend.shape.items()):
            test_shape = (
                self.ntemps,
                self.nwalkers,
                self.nleaves_max[name],
                self.ndims[name],
            )
            if shape != test_shape:
                raise ValueError(
                    (
                        "the shape of the backend ({0}) is incompatible with the "
                        "shape of the sampler ({1} for model {2})"
                    ).format(shape, test_shape, name)
                )

        # Get the last random state
        state = self.backend.random_state
        if state is None:
            state = np.random.get_state()

        # Grab the last step so that we can restart
        # NOTE(review): if the backend is initialized but has 0 iterations,
        # ``self._previous_state`` is never assigned here — confirm callers
        # tolerate the missing attribute in that case.
        it = self.backend.iteration
        if it > 0:
            self._previous_state = self.get_last_sample()

    # This is a random number generator that we can easily set the state
    # of without affecting the numpy-wide generator
    self._random = np.random.mtrand.RandomState()
    self._random.set_state(state)

    # Do a little bit of _magic_ to make the likelihood call with
    # ``args`` and ``kwargs`` pickleable.
    self.log_like_fn = _FunctionWrapper(log_like_fn, args, kwargs)

    self.all_walkers = self.nwalkers * self.ntemps

    # prepare plotting
    # TODO: adjust plotting maybe?
    self.plot_iterations = plot_iterations

    if plot_generator is None and self.plot_iterations > 0:
        raise NotImplementedError
        # NOTE(review): everything below this raise is unreachable dead code;
        # ``PlotContainer`` is also not imported at module top (the import is
        # commented out), so enabling this path would raise NameError.
        # set to default if not provided
        if plot_name is not None:
            name = plot_name
        else:
            name = "output"
        self.plot_generator = PlotContainer(
            fp=name, backend=self.backend, thin_chain_by_ac=True
        )
    elif self.plot_iterations > 0:
        raise NotImplementedError
        self.plot_generator = plot_generator  # unreachable (see raise above)

    # prepare stopping functions
    self.stopping_fn = stopping_fn
    self.stopping_iterations = stopping_iterations

    # prepare update functions
    self.update_fn = update_fn
    self.update_iterations = update_iterations
|
|
680
|
+
|
|
681
|
+
@property
def random_state(self):
    """
    The state of the internal random number generator. In practice, it's
    the result of calling ``get_state()`` on a
    ``numpy.random.mtrand.RandomState`` object. You can try to set this
    property but be warned that if you do this and it fails, it will do
    so silently.

    """
    return self._random.get_state()

@random_state.setter  # NOQA
def random_state(self, state):
    """
    Try to set the state of the random number generator but fail silently
    if it doesn't work. Don't say I didn't warn you...

    """
    try:
        self._random.set_state(state)
    except Exception:
        # The silent failure is deliberate (see docstring), but a bare
        # ``except:`` would also swallow SystemExit/KeyboardInterrupt;
        # restrict it to ordinary exceptions (e.g. a malformed state tuple).
        pass
@property
def priors(self):
    """Dictionary of prior containers used by the sampler."""
    return self._priors

@priors.setter
def priors(self, priors):
    """Set priors information.

    This performs checks to make sure the inputs are okay.

    """
    if isinstance(priors, dict):
        # fill a fresh dictionary, validating each entry as we go
        self._priors = {}
        for key, test in priors.items():
            if isinstance(test, dict):
                # every distribution in the sub-dictionary must expose logpdf
                for ind, dist in test.items():
                    if not hasattr(dist, "logpdf"):
                        raise ValueError(
                            "Distribution for model {0} and index {1} does not have logpdf method.".format(
                                key, ind
                            )
                        )
                # wrap the validated sub-dictionary in a container
                self._priors[key] = ProbDistContainer(test)
            elif isinstance(test, ProbDistContainer) or hasattr(test, "logpdf"):
                # already a container, or any object with a logpdf method
                self._priors[key] = test
            else:
                raise ValueError(
                    "priors dictionary items must be dictionaries with prior information or instances of the ProbDistContainer class."
                )

    elif isinstance(priors, ProbDistContainer):
        # single container: treat as the one-and-only model's prior
        self._priors = {"model_0": priors}

    else:
        raise ValueError("Priors must be a dictionary.")

    return
@property
def iteration(self):
    """Current iteration count, as tracked by the backend."""
    backend = self.backend
    return backend.iteration
def reset(self, **info):
    """
    Reset the backend.

    Args:
        **info (dict, optional): information to pass to backend reset method.

    """
    # the backend owns all stored chain data; simply delegate
    backend = self.backend
    backend.reset(self.nwalkers, self.ndims, **info)
def __getstate__(self):
    """Return a picklable copy of the instance dictionary.

    The ``pool`` object (e.g. a ``multiprocessing.Pool``) is generally not
    picklable, so it is dropped from the pickled state.
    """
    # Work on a shallow copy: mutating ``self.__dict__`` directly would set
    # ``self.pool = None`` on the live sampler as a side effect of pickling
    # it, silently disabling parallelism for all subsequent runs.
    d = self.__dict__.copy()
    d["pool"] = None
    return d
def get_model(self):
    """Get ``Model`` object from sampler

    The model object is used to pass necessary information to the
    proposals. This method can be used to retrieve the ``model`` used
    in the sampler from outside the sampler.

    Returns:
        :class:`Model`: ``Model`` object used by sampler.

    """
    # use the pool's map when parallelism is configured, else the builtin
    map_fn = map if self.pool is None else self.pool.map

    # bundle everything the proposals need into a single container
    return Model(
        self.log_like_fn,
        self.compute_log_like,
        self.compute_log_prior,
        self.temperature_control,
        map_fn,
        self._random,
    )
def sample(
    self,
    initial_state,
    iterations=1,
    tune=False,
    skip_initial_state_check=True,
    thin_by=1,
    store=True,
    progress=False,
):
    """Advance the chain as a generator

    Args:
        initial_state (State or ndarray[ntemps, nwalkers, nleaves_max, ndim] or dict): The initial
            :class:`State` or positions of the walkers in the
            parameter space. If multiple branches used, must be dict with keys
            as the ``branch_names`` and values as the positions. If ``betas`` are
            provided in the state object, they will be loaded into the
            ``temperature_control``.
        iterations (int or None, optional): The number of steps to generate.
            ``None`` generates an infinite stream (requires ``store=False``).
            (default: 1)
        tune (bool, optional): If ``True``, the parameters of some moves
            will be automatically tuned. (default: ``False``)
        thin_by (int, optional): If you only want to store and yield every
            ``thin_by`` samples in the chain, set ``thin_by`` to an
            integer greater than 1. When this is set, ``iterations *
            thin_by`` proposals will be made. (default: 1)
        store (bool, optional): By default, the sampler stores in the backend
            the positions (and other information) of the samples in the
            chain. If you are using another method to store the samples to
            a file or if you don't need to analyze the samples after the
            fact (for burn-in for example) set ``store`` to ``False``. (default: ``True``)
        progress (bool or str, optional): If ``True``, a progress bar will
            be shown as the sampler progresses. If a string, will select a
            specific ``tqdm`` progress bar - most notable is
            ``'notebook'``, which shows a progress bar suitable for
            Jupyter notebooks. If ``False``, no progress bar will be
            shown. (default: ``False``)
        skip_initial_state_check (bool, optional): If ``True``, a check
            that the initial_state can fully explore the space will be
            skipped. If using reversible jump, the user needs to ensure this on their own
            (``skip_initial_state_check`` is set to ``False`` in this case).
            (default: ``True``)

    Returns:
        State: Every ``thin_by`` steps, this generator yields the :class:`State` of the ensemble.

    Raises:
        ValueError: Improper initialization.

    """
    # an infinite stream cannot be stored in a fixed-size backend
    if iterations is None and store:
        raise ValueError("'store' must be False when 'iterations' is None")

    # Interpret the input as a walker state and check the dimensions.

    # initial_state.__class__ rather than State in case it is a subclass
    # of State
    # NOTE(review): ``isinstance(initial_state.__class__, State)`` tests the
    # class object itself and is always False, so this condition reduces to
    # the issubclass test — presumably ``isinstance(initial_state, State)``
    # was intended; confirm before changing.
    if (
        hasattr(initial_state, "__class__")
        and issubclass(initial_state.__class__, State)
        and not isinstance(initial_state.__class__, State)
    ):
        state = initial_state.__class__(initial_state, copy=True)
    else:
        state = State(initial_state, copy=True)

    # Check the backend shape
    for i, (name, branch) in enumerate(state.branches.items()):
        ntemps_, nwalkers_, nleaves_, ndim_ = branch.shape
        if (ntemps_, nwalkers_, nleaves_, ndim_) != (
            self.ntemps,
            self.nwalkers,
            self.nleaves_max[name],
            self.ndims[name],
        ):
            raise ValueError("incompatible input dimensions")

    # do an initial state check if is requested and we are not using reversible jump
    if (not skip_initial_state_check) and (
        not walkers_independent(state.coords) and not self.has_reversible_jump
    ):
        raise ValueError(
            "Initial state has a large condition number. "
            "Make sure that your walkers are linearly independent for the "
            "best performance"
        )

    # get log prior and likelihood if not provided in the initial state
    if state.log_prior is None:
        coords = state.branches_coords
        inds = state.branches_inds
        state.log_prior = self.compute_log_prior(coords, inds=inds)

    if state.log_like is None:
        coords = state.branches_coords
        inds = state.branches_inds
        state.log_like, state.blobs = self.compute_log_like(
            coords,
            inds=inds,
            logp=state.log_prior,
            supps=state.supplemental,  # only used if self.provide_supplemental is True
            branch_supps=state.branches_supplemental,  # only used if self.provide_supplemental is True
        )

    # get betas out of state object if they are there
    if state.betas is not None:
        if state.betas.shape[0] != self.ntemps:
            raise ValueError(
                "Input state has inverse temperatures (betas), but not the correct number of temperatures according to sampler inputs."
            )

        self.temperature_control.betas = state.betas.copy()

    else:
        # no betas in the state: copy the controller's current ladder in
        if hasattr(self, "temperature_control") and hasattr(
            self.temperature_control, "betas"
        ):
            state.betas = self.temperature_control.betas.copy()

    if np.shape(state.log_like) != (self.ntemps, self.nwalkers):
        raise ValueError("incompatible input dimensions")
    if np.shape(state.log_prior) != (self.ntemps, self.nwalkers):
        raise ValueError("incompatible input dimensions")

    # Check to make sure that the probability function didn't return
    # ``np.nan``.
    if np.any(np.isnan(state.log_like)):
        raise ValueError("The initial log_like was NaN")

    if np.any(np.isinf(state.log_like)):
        raise ValueError("The initial log_like was +/- infinite")

    if np.any(np.isnan(state.log_prior)):
        raise ValueError("The initial log_prior was NaN")

    if np.any(np.isinf(state.log_prior)):
        raise ValueError("The initial log_prior was +/- infinite")

    # Check that the thin keyword is reasonable.
    thin_by = int(thin_by)
    if thin_by <= 0:
        raise ValueError("Invalid thinning argument")

    yield_step = thin_by
    checkpoint_step = thin_by
    if store:
        self.backend.grow(iterations, state.blobs)

    # get the model object
    model = self.get_model()

    # Inject the progress bar
    total = None if iterations is None else iterations * yield_step
    with get_progress_bar(progress, total) as pbar:
        i = 0
        for _ in count() if iterations is None else range(iterations):
            for _ in range(yield_step):
                # in model moves
                accepted = np.zeros((self.ntemps, self.nwalkers))
                for repeat in range(self.num_repeats_in_model):
                    # Choose a random move
                    move = self._random.choice(self.moves, p=self.weights)

                    # Propose (in model)
                    state, accepted_out = move.propose(model, state)
                    accepted += accepted_out
                # swaps_accepted is read from the last move proposed above
                if self.ntemps > 1:
                    in_model_swaps = move.temperature_control.swaps_accepted
                else:
                    in_model_swaps = None

                state.random_state = self.random_state

                if tune:
                    # NOTE: tunes only the last move with its own acceptance
                    move.tune(state, accepted_out)

                if self.has_reversible_jump:
                    rj_accepted = np.zeros((self.ntemps, self.nwalkers))
                    for repeat in range(self.num_repeats_rj):
                        rj_move = self._random.choice(
                            self.rj_moves, p=self.rj_weights
                        )

                        # Propose (Between models)
                        state, rj_accepted_out = rj_move.propose(model, state)
                        rj_accepted += rj_accepted_out
                    # Again commenting out this section: We do not control temperature on RJ moves
                    # if self.ntemps > 1:
                    #     rj_swaps = rj_move.temperature_control.swaps_accepted
                    # else:
                    #     rj_swaps = None
                    rj_swaps = None

                    state.random_state = self.random_state

                    if tune:
                        rj_move.tune(state, rj_accepted_out)

                else:
                    # no RJ: nothing to record for RJ acceptance/swaps
                    rj_accepted = None
                    rj_swaps = None

                # Save the new step
                # NOTE(review): ``rj_swaps`` is assigned above but never
                # passed to ``save_step`` — only ``in_model_swaps`` is.
                if store and (i + 1) % checkpoint_step == 0:
                    if self.track_moves:
                        moves_accepted_fraction = {
                            key: move_tmp.acceptance_fraction
                            for key, move_tmp in self.all_moves.items()
                        }
                    else:
                        moves_accepted_fraction = None

                    self.backend.save_step(
                        state,
                        accepted,
                        rj_accepted=rj_accepted,
                        swaps_accepted=in_model_swaps,
                        moves_accepted_fraction=moves_accepted_fraction,
                    )

                # update after diagnostic and stopping check
                # if updating and using burn_in, need to make sure it does not use
                # previous chain samples since they are not stored.
                if (
                    self.update_iterations > 0
                    and self.update_fn is not None
                    and (i + 1) % (self.update_iterations) == 0
                ):
                    self.update_fn(i, state, self)

                pbar.update(1)
                i += 1

            # Yield the result as an iterator so that the user can do all
            # sorts of fun stuff with the results so far.
            yield state
def run_mcmc(
    self, initial_state, nsteps, burn=None, post_burn_update=False, **kwargs
):
    """
    Iterate :func:`sample` for ``nsteps`` iterations and return the result.

    Args:
        initial_state (State or ndarray[ntemps, nwalkers, nleaves_max, ndim] or dict): The initial
            :class:`State` or positions of the walkers in the
            parameter space. If multiple branches used, must be dict with keys
            as the ``branch_names`` and values as the positions. If ``betas`` are
            provided in the state object, they will be loaded into the
            ``temperature_control``.
        nsteps (int): The number of steps to generate. The total number of proposals is ``nsteps * thin_by``.
        burn (int, optional): Number of burn steps to run before storing information. The ``thin_by`` kwarg is ignored when counting burn steps since there is no storage (equivalent to ``thin_by=1``).
        post_burn_update (bool, optional): If ``True``, run ``update_fn`` after burn in.

    Other parameters are directly passed to :func:`sample`.

    Returns:
        State: This method returns the most recent result from :func:`sample`.

    Raises:
        ValueError: ``If initial_state`` is None and ``run_mcmc`` has never been called.

    """
    if initial_state is None:
        if self._previous_state is None:
            raise ValueError(
                "Cannot have `initial_state=None` if run_mcmc has never "
                "been called."
            )
        # resume from where the previous call left off
        initial_state = self._previous_state

    # run burn in
    if burn is not None and burn != 0:
        # prepare kwargs that relate to burn
        # burn-in is never stored and never thinned
        burn_kwargs = deepcopy(kwargs)
        burn_kwargs["store"] = False
        burn_kwargs["thin_by"] = 1
        i = 0
        # drain the generator; only the final state is kept
        for results in self.sample(initial_state, iterations=burn, **burn_kwargs):
            i += 1

        # run post-burn update
        if post_burn_update and self.update_fn is not None:
            self.update_fn(i, results, self)

        # the burned-in state seeds the stored run below
        initial_state = results

    if nsteps == 0:
        return initial_state

    results = None

    i = 0
    for results in self.sample(initial_state, iterations=nsteps, **kwargs):
        # diagnostic plots
        # TODO: adjust diagnostic plots
        if self.plot_iterations > 0 and (i + 1) % (self.plot_iterations) == 0:
            self.plot_generator.generate_plot_info()  # TODO: remove defaults

        # check for stopping before updating
        if (
            self.stopping_iterations > 0
            and self.stopping_fn is not None
            and (i + 1) % (self.stopping_iterations) == 0
        ):
            stop = self.stopping_fn(i, results, self)

            if stop:
                break

        i += 1

    # Store so that the ``initial_state=None`` case will work
    self._previous_state = results

    return results
def compute_log_prior(self, coords, inds=None, supps=None, branch_supps=None):
    """Calculate the vector of log-prior for the walkers

    Args:
        coords (dict): Keys are ``branch_names`` and values are
            the position np.arrays[ntemps, nwalkers, nleaves_max, ndim].
            This dictionary is created with the ``branches_coords`` attribute
            from :class:`State`.
        inds (dict, optional): Keys are ``branch_names`` and values are
            the ``inds`` np.arrays[ntemps, nwalkers, nleaves_max] that indicates
            which leaves are being used. This dictionary is created with the
            ``branches_inds`` attribute from :class:`State`.
            (default: ``None``)
        supps (optional): Overall supplemental information. Only forwarded
            when an ``"all_models_together"`` prior is in use; ignored by the
            other two paths. (default: ``None``)
        branch_supps (optional): Per-branch supplemental information. Only
            forwarded when an ``"all_models_together"`` prior is in use.
            (default: ``None``)

    Returns:
        np.ndarray[ntemps, nwalkers]: Prior Values

    """

    # get number of temperature and walkers
    ntemps, nwalkers, _, _ = coords[list(coords.keys())[0]].shape

    if inds is None:
        # default use all sources
        inds = {
            name: np.full(coords[name].shape[:-1], True, dtype=bool)
            for name in coords
        }

    # take information out of dict and spread to x1..xn
    x_in = {}

    # for completely customizable priors
    if "all_models_together" in self.priors:
        # user supplied a single prior over all branches at once
        prior_out = self.priors["all_models_together"].logpdf(
            coords, inds, supps=supps, branch_supps=branch_supps
        )
        assert prior_out.shape == (ntemps, nwalkers)

    elif self.provide_groups:
        # get group information from the inds dict
        groups = groups_from_inds(inds)

        # get the coordinates that are used
        for i, (name, coords_i) in enumerate(coords.items()):
            x_in[name] = coords_i[inds[name]]

        # one accumulator per (temp, walker) pair, flattened
        prior_out = np.zeros((ntemps * nwalkers))
        for name in x_in:
            # get prior for individual binaries
            prior_out_temp = self.priors[name].logpdf(x_in[name])

            # arrange prior values by groups
            # TODO: vectorize this?
            for i in np.unique(groups[name]):
                # which members are in the group i
                inds_temp = np.where(groups[name] == i)[0]
                # num_in_group = len(inds_temp)

                # add to the prior for this group
                prior_out[i] += prior_out_temp[inds_temp].sum()

        # reshape
        prior_out = prior_out.reshape(ntemps, nwalkers)

    else:
        # flatten coordinate arrays
        for i, (name, coords_i) in enumerate(coords.items()):
            ntemps, nwalkers, nleaves_max, ndim = coords_i.shape

            x_in[name] = coords_i.reshape(-1, ndim)

        prior_out = np.zeros((ntemps, nwalkers))
        for name in x_in:
            ntemps, nwalkers, nleaves_max, ndim = coords[name].shape
            prior_out_temp = (
                self.priors[name]
                .logpdf(x_in[name])
                .reshape(ntemps, nwalkers, nleaves_max)
            )

            # fix any infs / nans from binaries that are not being used (inds == False)
            prior_out_temp[~inds[name]] = 0.0

            # vectorized because everything is rectangular (no groups to indicate model difference)
            prior_out += prior_out_temp.sum(axis=-1)

    return prior_out
def compute_log_like(
    self, coords, inds=None, logp=None, supps=None, branch_supps=None
):
    """Calculate the vector of log-likelihood for the walkers

    Args:
        coords (dict): Keys are ``branch_names`` and values are
            the position np.arrays[ntemps, nwalkers, nleaves_max, ndim].
            This dictionary is created with the ``branches_coords`` attribute
            from :class:`State`.
        inds (dict, optional): Keys are ``branch_names`` and values are
            the inds np.arrays[ntemps, nwalkers, nleaves_max] that indicates
            which leaves are being used. This dictionary is created with the
            ``branches_inds`` attribute from :class:`State`.
            (default: ``None``)
        logp (np.ndarray[ntemps, nwalkers], optional): Log prior values associated
            with all walkers. If not provided, it will be calculated because
            if a walker has logp = -inf, its likelihood is not calculated.
            This prevents evaluting likelihood outside the prior.
            (default: ``None``)
        supps (optional): Overall supplemental information; only used when
            ``self.provide_supplemental`` is True. (default: ``None``)
        branch_supps (optional): Per-branch supplemental information; only
            used when ``self.provide_supplemental`` is True. (default: ``None``)

    Returns:
        tuple: Carries log-likelihood and blob information.
            First entry is np.ndarray[ntemps, nwalkers] with values corresponding
            to the log likelihood of each walker. Second entry is ``blobs``.

    Raises:
        ValueError: Infinite or NaN values in parameters.

    """

    # if inds not provided, use all
    if inds is None:
        inds = {
            name: np.full(coords[name].shape[:-1], True, dtype=bool)
            for name in coords
        }

    # Check that the parameters are in physical ranges.
    for name, ptemp in coords.items():
        if np.any(np.isinf(ptemp[inds[name]])):
            raise ValueError("At least one parameter value was infinite")
        if np.any(np.isnan(ptemp[inds[name]])):
            raise ValueError("At least one parameter value was NaN")

    # if no prior values are added, compute_prior
    # this is necessary to ensure Likelihood is not evaluated outside of the prior
    if logp is None:
        logp = self.compute_log_prior(
            coords, inds=inds, supps=supps, branch_supps=branch_supps
        )

    # if all points are outside the prior
    if np.all(np.isinf(logp)):
        warnings.warn(
            "All points input for the Likelihood have a log prior of -inf."
        )
        return np.full_like(logp, -1e300), None

    # do not run log likelihood where logp = -inf
    # deepcopy so the caller's inds dictionary is not modified below
    inds_copy = deepcopy(inds)
    inds_bad = np.where(np.isinf(logp))
    for key in inds_copy:
        inds_copy[key][inds_bad] = False

        # if inds_keep in branch supps, indicate which to not keep
        if (
            branch_supps is not None
            and key in branch_supps
            and branch_supps[key] is not None
            and "inds_keep" in branch_supps[key]
        ):
            # TODO: indicate specialty of inds_keep in branch_supp
            branch_supps[key][inds_bad] = {"inds_keep": False}

    # take information out of dict and spread to x1..xn
    x_in = {}
    if self.provide_supplemental:
        if supps is None and branch_supps is None:
            raise ValueError(
                """supps and branch_supps are both None. If self.provide_supplemental
                is True, must provide some supplemental information."""
            )
        if branch_supps is not None:
            branch_supps_in = {}

    # determine groupings from inds
    groups = groups_from_inds(inds_copy)

    # need to map group inds properly
    # this is the unique group indexes
    unique_groups = np.unique(
        np.concatenate([groups_i for groups_i in groups.values()])
    )

    # this is the map to those indexes that are used in the likelihood
    groups_map = np.arange(len(unique_groups))

    # get the indices with groups_map for the Likelihood
    ll_groups = {}
    for key, group in groups.items():
        # get unique groups in this sub-group (or branch)
        temp_unique_groups, inverse = np.unique(group, return_inverse=True)

        # use groups_map by finding where temp_unique_groups overlaps with unique_groups
        # NOTE(review): np.in1d is the legacy spelling of np.isin
        keep_groups = groups_map[np.in1d(unique_groups, temp_unique_groups)]

        # fill group information for Likelihood
        ll_groups[key] = keep_groups[inverse]

    # NOTE(review): nwalkers_all is (re)assigned inside this loop and used
    # after it, so at least one branch is assumed to exist in ``coords``.
    for i, (name, coords_i) in enumerate(coords.items()):
        ntemps, nwalkers, nleaves_max, ndim = coords_i.shape
        nwalkers_all = ntemps * nwalkers

        # fill x_values properly into dictionary
        x_in[name] = coords_i[inds_copy[name]]

        # prepare branch supplementals for each branch
        if self.provide_supplemental:
            if branch_supps is not None:  # and
                if branch_supps[name] is not None:
                    # index the branch supps
                    # it will carry in a dictionary of information
                    branch_supps_in[name] = branch_supps[name][inds_copy[name]]
                else:
                    # fill with None if this branch does not have a supplemental
                    branch_supps_in[name] = None

    # deal with overall supplemental not specific to the branches
    if self.provide_supplemental:
        if supps is not None:
            # get the flattened supplemental
            # this will produce the shape (ntemps * nwalkers,...)
            temp = supps.flat

            # unique_groups will properly index the flattened array
            supps_in = {
                name: values[unique_groups] for name, values in temp.items()
            }

    # prepare group information
    # this gets the group_map indexing into a list
    groups_in = list(ll_groups.values())

    # if only one branch, take the group array out of the list
    if len(groups_in) == 1:
        groups_in = groups_in[0]

    # list of paramter arrays
    params_in = list(x_in.values())

    # Likelihoods are vectorized across groups
    if self.vectorize:
        # prepare args list
        args_in = []

        # when vectorizing, if params_in has one entry, take out of list
        if len(params_in) == 1:
            params_in = params_in[0]

        # add parameters to args
        args_in.append(params_in)

        # if providing groups, add to args
        if self.provide_groups:
            args_in.append(groups_in)

        # prepare supplementals as kwargs to the Likelihood
        kwargs_in = {}
        if self.provide_supplemental:
            if supps is not None:
                kwargs_in["supps"] = supps_in
            if branch_supps is not None:
                # get list of branch_supps values
                branch_supps_in_2 = list(branch_supps_in.values())

                # if only one entry, take out of list
                if len(branch_supps_in_2) == 1:
                    kwargs_in["branch_supps"] = branch_supps_in_2[0]

                else:
                    kwargs_in["branch_supps"] = branch_supps_in_2

        # provide args, kwargs as a tuple
        args_and_kwargs = (args_in, kwargs_in)

        # get vectorized results
        results = self.log_like_fn(args_and_kwargs)

    # each Likelihood is computed individually
    else:
        # if groups in is an array, need to put it in a list.
        if isinstance(groups_in, np.ndarray):
            groups_in = [groups_in]

        # prepare input args for all Likelihood calls
        # to be spread out with map functions below
        args_in = []

        # each individual group in the groups_map
        for group_i in groups_map:
            # args and kwargs for the individual Likelihood
            arg_i = [None for _ in self.branch_names]
            kwarg_i = {}

            # iterate over the group information from the branches
            for branch_i, groups_in_set in enumerate(groups_in):
                # which entries in this branch are in the overall group tested
                # this accounts for multiple leaves (or model counts)
                inds_keep = np.where(groups_in_set == group_i)[0]

                branch_name_i = self.branch_names[branch_i]

                if inds_keep.shape[0] > 0:
                    # get parameters

                    params = params_in[branch_i][inds_keep]

                    # if leaf count is constant and leaf count is 1
                    # just give 1D parameters
                    if not self.has_reversible_jump and params.shape[0] == 1:
                        params = params[0]

                    # add them to the specific args for this Likelihood
                    arg_i[branch_i] = params
                    if self.provide_supplemental:
                        if supps is not None:
                            # supps are specific to each group
                            kwarg_i["supps"] = {
                                key: supps_in[key][group_i] for key in supps_in
                            }
                        if branch_supps is not None:
                            # make sure there is a dictionary ready in this kwarg dictionary
                            if "branch_supps" not in kwarg_i:
                                kwarg_i["branch_supps"] = {}

                            # fill these branch supplementals for the specific group
                            if branch_supps_in[branch_name_i] is not None:
                                # get list of branch_supps values
                                kwarg_i["branch_supps"][branch_name_i] = (
                                    branch_supps_in[branch_name_i][inds_keep]
                                )
                            else:
                                kwarg_i["branch_supps"][branch_name_i] = None

            # if only one model type, will take out of groups
            add_term = arg_i[0] if len(groups_in) == 1 else arg_i

            # based on how this is dealth with in the _FunctionWrapper
            # add_term is wrapped in a list
            args_in.append([[add_term], kwarg_i])

        # If the `pool` property of the sampler has been set (i.e. we want
        # to use `multiprocessing`), use the `pool`'s map method.
        # Otherwise, just use the built-in `map` function.
        if self.pool is not None:
            map_func = self.pool.map

        else:
            map_func = map

        # get results and turn into an array
        results = np.asarray(list(map_func(self.log_like_fn, args_in)))

        assert isinstance(results, np.ndarray)

    # -1e300 because -np.inf screws up state acceptance transfer in proposals
    ll = np.full(nwalkers_all, -1e300)
    inds_fix_zeros = np.delete(np.arange(nwalkers_all), unique_groups)

    # make sure second dimension is not 1
    if results.ndim == 2 and results.shape[1] == 1:
        results = np.squeeze(results)

    # parse the results if it has blobs
    if results.ndim == 2:
        # get the results and put into groups that were analyzed
        ll[unique_groups] = results[:, 0]

        # fix groups that were not analyzed
        ll[inds_fix_zeros] = self.fill_zero_leaves_val

        # deal with blobs
        blobs_out = np.zeros((nwalkers_all, results.shape[1] - 1))
        blobs_out[unique_groups] = results[:, 1:]

    elif results.dtype == "object":
        # TODO: check blobs and add this capability
        raise NotImplementedError

    else:
        # no blobs
        ll[unique_groups] = results
        ll[inds_fix_zeros] = self.fill_zero_leaves_val

        blobs_out = None

    # NOTE(review): dead code — disabled with ``if False`` and containing a
    # leftover debug ``breakpoint()``; candidate for removal.
    if False:  # self.provide_supplemental:
        # TODO: need to think about how to return information, we may need to add a function to do that
        if branch_supps is not None:
            for name_i, name in enumerate(branch_supps):
                if branch_supps[name] is not None:
                    # TODO: better way to do this? limit to
                    if "inds_keep" in branch_supps[name]:
                        inds_back = branch_supps[name][:]["inds_keep"]
                        inds_back2 = branch_supps_in[name]["inds_keep"]
                    else:
                        inds_back = inds_copy[name]
                        inds_back2 = slice(None)
                    try:
                        branch_supps[name][inds_back] = {
                            key: branch_supps_in_2[name_i][key][inds_back2]
                            for key in branch_supps_in_2[name_i]
                        }
                    except ValueError:
                        breakpoint()
                        branch_supps[name][inds_back] = {
                            key: branch_supps_in_2[name_i][key][inds_back2]
                            for key in branch_supps_in_2[name_i]
                        }

    # return Likelihood and blobs
    return ll.reshape(ntemps, nwalkers), blobs_out
@property
def acceptance_fraction(self):
    """Fraction of proposed in-model steps that were accepted.

    Computed from the backend's running counters: accepted proposals
    divided by the number of stored iterations.
    """
    n_steps = float(self.backend.iteration)
    return self.backend.accepted / n_steps
|
|
1541
|
+
|
|
1542
|
+
@property
def rj_acceptance_fraction(self):
    """Fraction of proposed reversible-jump steps that were accepted.

    Returns ``None`` when reversible-jump moves are not in use.
    """
    if not self.has_reversible_jump:
        return None
    return self.backend.rj_accepted / float(self.backend.iteration)
|
|
1549
|
+
|
|
1550
|
+
@property
def swap_acceptance_fraction(self):
    """Fraction of proposed temperature swaps that were accepted.

    The denominator counts one swap proposal per walker per iteration.
    """
    total_proposed = float(self.backend.iteration * self.nwalkers)
    return self.backend.swaps_accepted / total_proposed
|
|
1556
|
+
|
|
1557
|
+
def get_chain(self, **kwargs):
    # Thin passthrough: the backend owns the stored chain samples.
    out = self.get_value("chain", **kwargs)
    return out

get_chain.__doc__ = Backend.get_chain.__doc__
|
|
1561
|
+
|
|
1562
|
+
def get_blobs(self, **kwargs):
    # Thin passthrough: blobs are stored and sliced by the backend.
    out = self.get_value("blobs", **kwargs)
    return out

get_blobs.__doc__ = Backend.get_blobs.__doc__
|
|
1566
|
+
|
|
1567
|
+
def get_log_like(self, **kwargs):
    # Thin passthrough: the backend owns the stored log-likelihood chain.
    return self.backend.get_log_like(**kwargs)

# Bug fix: this previously copied ``Backend.get_log_prior.__doc__`` — a
# copy-paste error. Every sibling wrapper mirrors the matching backend
# method's docstring, so this must mirror ``get_log_like``.
get_log_like.__doc__ = Backend.get_log_like.__doc__
|
|
1571
|
+
|
|
1572
|
+
def get_log_prior(self, **kwargs):
    # Thin passthrough: the backend owns the stored log-prior chain.
    out = self.backend.get_log_prior(**kwargs)
    return out

get_log_prior.__doc__ = Backend.get_log_prior.__doc__
|
|
1576
|
+
|
|
1577
|
+
def get_log_posterior(self, **kwargs):
    # Thin passthrough: the backend owns the stored log-posterior chain.
    out = self.backend.get_log_posterior(**kwargs)
    return out

get_log_posterior.__doc__ = Backend.get_log_posterior.__doc__
|
|
1581
|
+
|
|
1582
|
+
def get_inds(self, **kwargs):
    # Thin passthrough: leaf-index arrays live in the backend store.
    out = self.get_value("inds", **kwargs)
    return out

get_inds.__doc__ = Backend.get_inds.__doc__
|
|
1586
|
+
|
|
1587
|
+
def get_nleaves(self, **kwargs):
    # Thin passthrough: leaf counts are computed by the backend.
    out = self.backend.get_nleaves(**kwargs)
    return out

get_nleaves.__doc__ = Backend.get_nleaves.__doc__
|
|
1591
|
+
|
|
1592
|
+
def get_last_sample(self, **kwargs):
    # NOTE(review): ``**kwargs`` is accepted but not forwarded to the
    # backend call, so any keyword arguments are silently ignored —
    # confirm against ``Backend.get_last_sample`` before forwarding.
    out = self.backend.get_last_sample()
    return out

get_last_sample.__doc__ = Backend.get_last_sample.__doc__
|
|
1596
|
+
|
|
1597
|
+
def get_betas(self, **kwargs):
    # Thin passthrough: inverse-temperature history lives in the backend.
    out = self.backend.get_betas(**kwargs)
    return out

get_betas.__doc__ = Backend.get_betas.__doc__
|
|
1601
|
+
|
|
1602
|
+
def get_value(self, name, **kwargs):
    """Fetch a named stored quantity (e.g. ``"chain"``, ``"inds"``) from the backend."""
    out = self.backend.get_value(name, **kwargs)
    return out
|
|
1605
|
+
|
|
1606
|
+
def get_autocorr_time(self, **kwargs):
    """Compute the autocorrelation time via the backend."""
    out = self.backend.get_autocorr_time(**kwargs)
    return out

get_autocorr_time.__doc__ = Backend.get_autocorr_time.__doc__
|
|
1611
|
+
|
|
1612
|
+
|
|
1613
|
+
class _FunctionWrapper(object):
|
|
1614
|
+
"""
|
|
1615
|
+
This is a hack to make the likelihood function pickleable when ``args``
|
|
1616
|
+
or ``kwargs`` are also included.
|
|
1617
|
+
|
|
1618
|
+
"""
|
|
1619
|
+
|
|
1620
|
+
def __init__(
|
|
1621
|
+
self,
|
|
1622
|
+
f,
|
|
1623
|
+
args,
|
|
1624
|
+
kwargs,
|
|
1625
|
+
):
|
|
1626
|
+
self.f = f
|
|
1627
|
+
self.args = [] if args is None else args
|
|
1628
|
+
self.kwargs = {} if kwargs is None else kwargs
|
|
1629
|
+
|
|
1630
|
+
def __call__(self, args_and_kwargs):
|
|
1631
|
+
"""
|
|
1632
|
+
Internal function that takes a tuple (args, kwargs) for entrance into the Likelihood.
|
|
1633
|
+
|
|
1634
|
+
``self.args`` and ``self.kwargs`` are added to these inputs.
|
|
1635
|
+
|
|
1636
|
+
"""
|
|
1637
|
+
|
|
1638
|
+
args_in_add, kwargs_in_add = args_and_kwargs
|
|
1639
|
+
|
|
1640
|
+
try:
|
|
1641
|
+
args_in = args_in_add + type(args_in_add)(self.args)
|
|
1642
|
+
kwargs_in = {**kwargs_in_add, **self.kwargs}
|
|
1643
|
+
|
|
1644
|
+
out = self.f(*args_in, **kwargs_in)
|
|
1645
|
+
return out
|
|
1646
|
+
|
|
1647
|
+
except: # pragma: no cover
|
|
1648
|
+
import traceback
|
|
1649
|
+
|
|
1650
|
+
print("eryn: Exception while calling your likelihood function:")
|
|
1651
|
+
print(" args added:", args_in_add)
|
|
1652
|
+
print(" args:", self.args)
|
|
1653
|
+
print(" kwargs added:", kwargs_in_add)
|
|
1654
|
+
print(" kwargs:", self.kwargs)
|
|
1655
|
+
print(" exception:")
|
|
1656
|
+
traceback.print_exc()
|
|
1657
|
+
raise
|
|
1658
|
+
|
|
1659
|
+
|
|
1660
|
+
def walkers_independent(coords_in):
    """Determine whether the walkers span a non-degenerate subspace.

    Originally from ``emcee``.

    Args:
        coords_in (np.ndarray[ntemps, nwalkers, nleaves_max, ndim]): Coordinates of the walkers.

    Returns:
        bool: If walkers are independent.

    """
    # Require the full 4-D layout, then flatten so each row is one
    # walker (per temperature) and each column one leaf-parameter.
    assert coords_in.ndim == 4
    ntemps, nwalkers, nleaves_max, ndim = coords_in.shape
    flat = coords_in.reshape(ntemps * nwalkers, nleaves_max * ndim)

    # Any non-finite entry means the ensemble is unusable.
    if not np.isfinite(flat).all():
        return False

    # Center the columns; a column with zero spread means every walker
    # shares that coordinate, i.e. the walkers are degenerate.
    centered = flat - flat.mean(axis=0)[None, :]
    col_peak = np.abs(centered).max(axis=0)
    if (col_peak == 0).any():
        return False

    # Scale columns to comparable magnitude, then to unit norm, and use
    # the condition number as the degeneracy measure.
    centered = centered / col_peak
    centered = centered / np.sqrt((centered**2).sum(axis=0))
    return np.linalg.cond(centered.astype(float)) <= 1e8
|