eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/ensemble.py ADDED
@@ -0,0 +1,1690 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from itertools import count
7
+ from copy import deepcopy
8
+
9
+ from .backends import Backend, HDFBackend
10
+ from .model import Model
11
+ from .moves import StretchMove, TemperatureControl, DistributionGenerateRJ, GaussianMove
12
+ from .pbar import get_progress_bar
13
+ from .state import State
14
+ from .prior import ProbDistContainer
15
+
16
+ # from .utils import PlotContainer
17
+ from .utils import PeriodicContainer
18
+ from .utils.utility import groups_from_inds
19
+
20
+
21
+ __all__ = ["EnsembleSampler", "walkers_independent"]
22
+
23
+
24
+ try:
25
+ from collections.abc import Iterable
26
+ except ImportError:
27
+ # fallback for very old Python; the direct ``collections`` import was removed in Python 3.10
28
+ from collections import Iterable
29
+
30
+
31
+ class EnsembleSampler(object):
32
+ """An ensemble MCMC sampler
33
+
34
+ The class controls the entire sampling run. It can handle
35
+ everything from a basic non-tempered MCMC to a parallel-tempered,
36
+ global fit containing multiple branches (models) and a variable
37
+ number of leaves (sources) per branch.
38
+ See `here <https://mikekatz04.github.io/Eryn/html/tutorial/Eryn_tutorial.html#The-Tree-Metaphor>`_
39
+ for a basic explainer.
40
+
41
+ Parameters related to parallelization can be controlled via the ``pool`` argument.
42
+
43
+ Args:
44
+ nwalkers (int): The number of walkers in the ensemble per temperature.
45
+ ndims (int, list of ints, or dict): The number of dimensions for each branch. If
46
+ ``dict``, keys should be the branch names and values the associated dimensionality.
47
+ log_like_fn (callable): A function that returns the natural logarithm of the
48
+ likelihood for that position. The inputs to ``log_like_fn`` depend on whether
49
+ the function is vectorized (kwarg ``vectorize`` below), if you are using reversible jump,
50
+ and how many branches you have.
51
+
52
+ In the simplest case where ``vectorize == False``, no reversible jump, and only one
53
+ type of model, the inputs are just the array of parameters for one walker, so shape is ``(ndim,)``.
54
+
55
+ If ``vectorize == True``, no reversible jump, and only one type of model, the inputs will
56
+ be a 2D array of parameters of all the walkers going in. Shape: ``(num positions, ndim)``.
57
+
58
+ If using reversible jump, the leaves that go together in the same Likelihood will be grouped
59
+ together into a single function call. If ``vectorize == False``, then each group is sent as
60
+ an individual computation. With ``N`` different branches (``N > 1``), inputs would be a list
61
+ of 2D arrays of the coordinates for all leaves within each branch: ``([x0, x1,...,xN])``
62
+ where ``xi`` is 2D with shape ``(number of leaves in this branch, ndim)``. If ``N == 1``, then
63
+ a list is not provided, just x0, the 2D array of coordinates for the one branch considered.
64
+
65
+ If using reversible jump and ``vectorize == True``, the arrays of parameters will be provided
66
+ along with information about how branches and leaves are grouped. Inputs will be
67
+ ``([X0, X1,..XN], [group0, group1,...,groupN])`` where ``Xi`` is a 2D array of all
68
+ leaves in the sampler for branch ``i``. ``groupi`` is an index indicating which unique group
69
+ each source belongs to. For example, if we have 3 walkers with (1, 2, 1) leaves for model ``i``,
70
+ respectively, we will have an ``Xi = array([params0, params1, params2, params3])`` and
71
+ ``groupsi = array([0, 1, 1, 2])``.
72
+ If ``N == 1``, then the lists are removed and the inputs become ``(X0, group0)``.
73
+
74
+ Extra ``args`` and ``kwargs`` for the Likelihood function can be added with the kwargs
75
+ ``args`` and ``kwargs`` below.
76
+
77
+ Please see the
78
+ `tutorial <https://mikekatz04.github.io/Eryn/html/tutorial/Eryn_tutorial.html#>`_
79
+ for more information.
80
+
81
+ priors (dict): The prior dictionary can take four forms.
82
+ 1) A dictionary with keys that are ints or tuples of ints describing the parameter
83
+ indices over which to assess the prior, and values that
84
+ are prior probability distributions that must have a ``logpdf`` method.
85
+ 2) A :class:`eryn.prior.ProbDistContainer` object.
86
+ 3) A dictionary with keys that are ``branch_names`` and values that are dictionaries for
87
+ each branch as described for (1).
88
+ 4) A dictionary with keys that are ``branch_names`` and values are
89
+ :class:`eryn.prior.ProbDistContainer` objects.
90
+ If the priors dictionary contains the special key ``"all_models_together"``, the user must supply
91
+ a prior whose ``logpdf`` can depend on all of the models
92
+ together rather than handling them separately, as is the default. In this case, the item
93
+ attached to this special key must be a class object with a ``logpdf`` method that takes as input
94
+ two arguments: ``(coords, inds)``, which are the coordinate and index dictionaries across all models with
95
+ shapes of ``(ntemps, nwalkers, nleaves_max, ndim)`` and ``(ntemps, nwalkers, nleaves_max)`` for each
96
+ individual model, respectively. This function must then return the numpy array of logpdf values for the
97
+ prior value with shape ``(ntemps, nwalkers)``. A commented construction sketch covering ``priors`` and the reversible-jump options follows this docstring.
98
+ provide_groups (bool, optional): If ``True``, provide groups as described in ``log_like_fn`` above.
99
+ A group parameter is added for each branch. (default: ``False``)
100
+ provide_supplemental (bool, optional): If ``True``, it will provide keyword arguments to
101
+ the Likelihood function: ``supps`` and ``branch_supps``. Please see the `Tutorial <https://mikekatz04.github.io/Eryn/html/tutorial/Eryn_tutorial.html#>`_
102
+ and :class:`eryn.state.BranchSupplemental` for more information.
103
+ tempering_kwargs (dict, optional): Keyword arguments for initialization of the
104
+ tempering class: :class:`eryn.moves.tempering.TemperatureControl`. (default: ``{}``)
105
+ branch_names (list, optional): List of branch names. If ``None``, models will be assigned
106
+ names as ``f"model_{index}"`` based on ``nbranches``. (default: ``None``)
107
+ nbranches (int, optional): Number of branches (models) tested.
108
+ Only used if ``branch_names is None``.
109
+ (default: ``1``)
110
+ nleaves_max (int, list of ints, or dict, optional): Maximum allowable leaf count for each branch.
111
+ It should have the same length as the number of branches.
112
+ If ``dict``, keys should be the branch names and values the associated maximal leaf value.
113
+ (default: ``1``)
114
+ nleaves_min (int, list of ints, or dict, optional): Minimum allowable leaf count for each branch.
115
+ It should have the same length as the number of branches. Only used with Reversible Jump.
116
+ If ``dict``, keys should be the branch names and values the associated minimum leaf count.
117
+ If ``None`` and using Reversible Jump, will fill all branches with zero.
118
+ (default: ``0``)
119
+ pool (object, optional): An object with a ``map`` method that follows the same
120
+ calling sequence as the built-in ``map`` function. This is
121
+ generally used to compute the log-probabilities for the ensemble
122
+ in parallel.
123
+ moves (list or object, optional): This can be a single move object, a list of moves,
124
+ or a "weighted" list of the form ``[(eryn.moves.StretchMove(),
125
+ 0.1), ...]``. When running, the sampler will randomly select a
126
+ move from this list (optionally with weights) for each proposal.
127
+ If ``None``, the default will be :class:`StretchMove`.
128
+ (default: ``None``)
129
+ rj_moves (list or object or bool or str, optional): If ``None`` or ``False``, reversible jump will not be included in the run.
130
+ This can be a single move object, a list of moves,
131
+ or a "weighted" list of the form ``[(eryn.moves.DistributionGenerateRJ(),
132
+ 0.1), ...]``. When running, the sampler will randomly select a
133
+ move from this list (optionally with weights) for each proposal.
134
+ If ``True``, it defaults to :class:`DistributionGenerateRJ`. When running just the :class:`DistributionGenerateRJ`
135
+ with multiple branches, it will propose changes to all branches simultaneously.
136
+ When running with more than one branch, useful options for ``rj_moves`` are ``"iterate_branches"``,
137
+ ``"separate_branches"``, or ``"together"``. If ``rj_moves == "iterate_branches"``, sample one branch by one branch in order of
138
+ the branch names. This occurs within one RJ proposal, for each RJ proposal. If ``rj_moves == "separate_branches"``,
139
+ there will be one RJ move per branch. During each individual RJ move, one of these proposals is chosen at random with
140
+ equal probability. This is generally recommended when using multiple branches.
141
+ If ``rj_moves == "together"``, this is equivalent to ``rj_moves == True``.
142
+ (default: ``None``)
143
+ dr_moves (bool, optional): If ``None`` or ``False``, delayed rejection when proposing "birth"
144
+ of new components/models will be switched off for this run. Requires ``rj_moves`` set to ``True``.
145
+ Not implemented yet. Working on it.
146
+ (default: ``None``)
147
+ dr_max_iter (int, optional): Maximum number of iterations used with delayed rejection. (default: 5)
148
+ args (optional): A list of extra positional arguments for
149
+ ``log_like_fn``. ``log_like_fn`` will be called as
150
+ ``log_like_fn(sampler added args, *args, sampler added kwargs, **kwargs)``.
151
+ kwargs (optional): A dict of extra keyword arguments for
152
+ ``log_like_fn``. ``log_like_fn`` will be called as
153
+ ``log_like_fn(sampler added args, *args, sampler added kwargs, **kwargs)``.
154
+ backend (optional): Either a :class:`backends.Backend` or a subclass
155
+ (like :class:`backends.HDFBackend`) that is used to store and
156
+ serialize the state of the chain. By default, the chain is stored
157
+ as a set of numpy arrays in memory, but new backends can be
158
+ written to support other mediums.
159
+ vectorize (bool, optional): If ``True``, ``log_like_fn`` is expected
160
+ to accept an array of position vectors instead of just one. Note
161
+ that ``pool`` will be ignored if this is ``True``. See ``log_like_fn`` information
162
+ above to understand the arguments of ``log_like_fn`` based on whether
163
+ ``vectorize`` is ``True``.
164
+ (default: ``False``)
165
+ periodic (dict, optional): Keys are ``branch_names``. Values are dictionaries
166
+ that have (key: value) pairs as (index to parameter: period). Periodic
167
+ parameters are treated as having periodic boundary conditions in proposals.
168
+ update_fn (callable, optional): :class:`eryn.utils.updates.AdjustStretchProposalScale`
169
+ object that allows the user to update the sampler in any preferred way
170
+ every ``update_iterations`` sampler iterations. The callable must have signature:
171
+ ``(sampler iteration, last sample state object, EnsembleSampler object)``.
172
+ update_iterations (int, optional): Number of iterations between sampler
173
+ updates using ``update_fn``. Updates are only performed at the thinning rate.
174
+ If ``thin_by>1`` when :func:`EnsembleSampler.run_mcmc` is used, the sampler
175
+ is updated every ``thin_by * update_iterations`` iterations.
176
+ stopping_fn (callable, optional): :class:`eryn.utils.stopping.Stopping` object that
177
+ allows the user to end the sampler if specified criteria are met.
178
+ The callable must have signature:
179
+ ``(sampler iteration, last sample state object, EnsembleSampler object)``.
180
+ stopping_iterations (int, optional): Number of iterations between sampler
181
+ attempts to evaluate the ``stopping_fn``. Stopping checks are only performed at the thinning rate.
182
+ If ``thin_by>1`` when :func:`EnsembleSampler.run_mcmc` is used, the sampler
183
+ is checked for the stopping criterion every ``thin_by * stopping_iterations`` iterations.
184
+ fill_zero_leaves_val (double, optional): When there are zero leaves in a
185
+ given walker (across all branches), fill the likelihood value with
186
+ ``fill_zero_leaves_val``. If wanting to keep zero leaves as a possible
187
+ model, this should be set to the value of the contribution to the Likelihood
188
+ from the data. (Default: ``-1e300``).
189
+ num_repeats_in_model (int, optional): Number of times to repeat the in-model step
190
+ within one sampler iteration. When analyzing the acceptance fraction, you must
191
+ include the value of ``num_repeats_in_model`` to get the proper denominator.
192
+ num_repeats_rj (int, optional): Number of times to repeat the reversible jump step
193
+ within one sampler iteration. When analyzing the acceptance fraction, you must
194
+ include the value of ``num_repeats_rj`` to get the proper denominator.
195
+ track_moves (bool, optional): If ``True``, track acceptance fraction of each move
196
+ in the backend. If ``False``, no tracking is done. If ``True`` and run is interrupted, it will check
197
+ that the move configuration has not changed. It will not allow the run to go on
198
+ if it is changed. In this case, the user should declare a new backend and use the last
199
+ state from the previous backend. **Warning**: If the order of moves of the same move class
200
+ is changed, the check may not catch it, so the tracking may mix move acceptance fractions together.
201
+ info (dict, optional): Key and value pairs representing any information
202
+ the user wants to add to the backend if the user is not inputting
203
+ their own backend.
204
+
205
+ Raises:
206
+ ValueError: Any startup issues.
207
+
208
+ """
209
+
210
+ def __init__(
211
+ self,
212
+ nwalkers,
213
+ ndims, # assumes ndim_max
214
+ log_like_fn,
215
+ priors,
216
+ provide_groups=False,
217
+ provide_supplemental=False,
218
+ tempering_kwargs={},
219
+ branch_names=None,
220
+ nbranches=1,
221
+ nleaves_max=1,
222
+ nleaves_min=0,
223
+ pool=None,
224
+ moves=None,
225
+ rj_moves=None,
226
+ dr_moves=None,
227
+ dr_max_iter=5,
228
+ args=None,
229
+ kwargs=None,
230
+ backend=None,
231
+ vectorize=False,
232
+ blobs_dtype=None, # TODO check this
233
+ plot_iterations=-1, # TODO: do plot stuff?
234
+ plot_generator=None,
235
+ plot_name=None,
236
+ periodic=None,
237
+ update_fn=None,
238
+ update_iterations=-1,
239
+ stopping_fn=None,
240
+ stopping_iterations=-1,
241
+ fill_zero_leaves_val=-1e300,
242
+ num_repeats_in_model=1,
243
+ num_repeats_rj=1,
244
+ track_moves=True,
245
+ info={},
246
+ ):
247
+ # store priors
248
+ self.priors = priors
249
+
250
+ # store some kwargs
251
+ self.provide_groups = provide_groups
252
+ self.provide_supplemental = provide_supplemental
253
+ self.fill_zero_leaves_val = fill_zero_leaves_val
254
+ self.num_repeats_in_model = num_repeats_in_model
255
+ self.num_repeats_rj = num_repeats_rj
256
+ self.track_moves = track_moves
257
+
258
+ # setup emcee-like basics
259
+ self.pool = pool
260
+ self.vectorize = vectorize
261
+ self.blobs_dtype = blobs_dtype
262
+
263
+ # turn things into lists/dicts if needed
264
+ if branch_names is not None:
265
+ if isinstance(branch_names, str):
266
+ branch_names = [branch_names]
267
+
268
+ elif not isinstance(branch_names, list):
269
+ raise ValueError("branch_names must be string or list of strings.")
270
+
271
+ else:
272
+ branch_names = ["model_{}".format(i) for i in range(nbranches)]
273
+
274
+ nbranches = len(branch_names)
275
+
276
+ if isinstance(ndims, int):
277
+ assert len(branch_names) == 1
278
+ ndims = {branch_names[0]: ndims}
279
+
280
+ elif isinstance(ndims, list) or isinstance(ndims, np.ndarray):
281
+ assert len(branch_names) == len(ndims)
282
+ ndims = {bn: nd for bn, nd in zip(branch_names, ndims)}
283
+
284
+ elif isinstance(ndims, dict):
285
+ assert len(list(ndims.keys())) == len(branch_names)
286
+ for key in ndims:
287
+ if key not in branch_names:
288
+ raise ValueError(
289
+ f"{key} is in ndims but does not appear in branch_names: {branch_names}."
290
+ )
291
+ else:
292
+ raise ValueError("ndims is to be a scalar int, list or dict.")
293
+
294
+ if isinstance(nleaves_max, int):
295
+ assert len(branch_names) == 1
296
+ nleaves_max = {branch_names[0]: nleaves_max}
297
+
298
+ elif isinstance(nleaves_max, list) or isinstance(nleaves_max, np.ndarray):
299
+ assert len(branch_names) == len(nleaves_max)
300
+ nleaves_max = {bn: nl for bn, nl in zip(branch_names, nleaves_max)}
301
+
302
+ elif isinstance(nleaves_max, dict):
303
+ assert len(list(nleaves_max.keys())) == len(branch_names)
304
+ for key in nleaves_max:
305
+ if key not in branch_names:
306
+ raise ValueError(
307
+ f"{key} is in nleaves_max but does not appear in branch_names: {branch_names}."
308
+ )
309
+ else:
310
+ raise ValueError("nleaves_max is to be a scalar int, list, or dict.")
311
+
312
+ self.nbranches = len(branch_names)
313
+
314
+ self.branch_names = branch_names
315
+ self.ndims = ndims
316
+ self.nleaves_max = nleaves_max
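+
+ # Note (descriptive): at this point ``ndims`` and ``nleaves_max`` have been normalized to
+ # dicts keyed by branch name, e.g. the single-branch default becomes {"model_0": ndims},
+ # while an explicit two-branch setup might look like ndims={"gauss": 3, "line": 2} and
+ # nleaves_max={"gauss": 5, "line": 1} (the branch names here are placeholders).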
317
+
318
+ # set up tempering information
319
+ # default is no temperatures
320
+ if tempering_kwargs == {}:
321
+ self.ntemps = 1
322
+ self.temperature_control = None
323
+ else:
324
+ # get effective total dimension
325
+ total_ndim = 0
326
+ for key in self.branch_names:
327
+ total_ndim += self.nleaves_max[key] * self.ndims[key]
328
+ self.temperature_control = TemperatureControl(
329
+ total_ndim, nwalkers, **tempering_kwargs
330
+ )
331
+ self.ntemps = self.temperature_control.ntemps
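+ # Example (hedged): passing tempering_kwargs=dict(ntemps=10) would build a
+ # 10-temperature ladder here, assuming ``ntemps`` is among TemperatureControl's
+ # accepted keyword arguments; with the default empty dict the sampler stays
+ # untempered (single temperature, see the branch above).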
332
+
333
+ # set basic variables for sampling settings
334
+ self.nwalkers = nwalkers
335
+ self.nbranches = nbranches
336
+
337
+ # eryn wraps periodic parameters
338
+ if periodic is not None:
339
+ if not isinstance(periodic, PeriodicContainer) and not isinstance(
340
+ periodic, dict
341
+ ):
342
+ raise ValueError(
343
+ "periodic must be PeriodicContainer or dict if not None."
344
+ )
345
+ elif isinstance(periodic, dict):
346
+ periodic = PeriodicContainer(periodic)
347
+
348
+ # Parse the move schedule
349
+ if moves is None:
350
+ if rj_moves is not None:
351
+ raise ValueError(
352
+ "If providing rj_moves, must provide moves kwarg as well."
353
+ )
354
+
355
+ # defaults to stretch move
356
+ self.moves = [
357
+ StretchMove(
358
+ temperature_control=self.temperature_control,
359
+ periodic=periodic,
360
+ a=2.0,
361
+ )
362
+ ]
363
+ self.weights = [1.0]
364
+
365
+ elif isinstance(moves, Iterable):
366
+ try:
367
+ self.moves, self.weights = [list(tmp) for tmp in zip(*moves)]
368
+
369
+ except TypeError:
370
+ self.moves = moves
371
+ self.weights = np.ones(len(moves))
372
+ else:
373
+ self.moves = [moves]
374
+ self.weights = [1.0]
375
+
376
+ self.weights = np.atleast_1d(self.weights).astype(float)
377
+ self.weights /= np.sum(self.weights)
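+ # Example (hedged): passing moves=[(StretchMove(a=2.0), 0.7), (GaussianMove(cov), 0.3)]
+ # unpacks above into two moves with normalized weights [0.7, 0.3]; ``cov`` is a
+ # user-supplied covariance specification (check GaussianMove's signature), and the
+ # temperature control / periodic containers are attached to the moves further below.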
378
+
379
+ # parse the reversible jump move schedule
380
+ if rj_moves is None:
381
+ self.has_reversible_jump = False
382
+ elif (isinstance(rj_moves, bool) and rj_moves) or isinstance(rj_moves, str):
383
+ self.has_reversible_jump = True
384
+
385
+ if self.has_reversible_jump:
386
+ if nleaves_min is None:
387
+ nleaves_min = {bn: 0 for bn in branch_names}
388
+ elif isinstance(nleaves_min, int):
389
+ assert len(branch_names) == 1
390
+ nleaves_min = {branch_names[0]: nleaves_min}
391
+
392
+ elif isinstance(nleaves_min, list) or isinstance(
393
+ nleaves_min, np.ndarray
394
+ ):
395
+ assert len(branch_names) == len(nleaves_min)
396
+ nleaves_min = {bn: nl for bn, nl in zip(branch_names, nleaves_min)}
397
+
398
+ elif isinstance(nleaves_min, dict):
399
+ assert len(list(nleaves_min.keys())) == len(branch_names)
400
+ for key in nleaves_min:
401
+ if key not in branch_names:
402
+ raise ValueError(
403
+ f"{key} is in nleaves_min but does not appear in branch_names: {branch_names}."
404
+ )
405
+ else:
406
+ raise ValueError(
407
+ "If providing nleaves_min, nleaves_min is to be a scalar int, list, or dict."
408
+ )
409
+
410
+ self.nleaves_min = nleaves_min
411
+
412
+ if (isinstance(rj_moves, bool) and rj_moves) or (
413
+ isinstance(rj_moves, str) and rj_moves == "together"
414
+ ):
415
+ # default to DistributionGenerateRJ
416
+
417
+ # gibbs sampling setup here means run all of them together
418
+ gibbs_sampling_setup = None
419
+
420
+ rj_move = DistributionGenerateRJ(
421
+ self.priors,
422
+ nleaves_max=self.nleaves_max,
423
+ nleaves_min=self.nleaves_min,
424
+ dr=dr_moves,
425
+ dr_max_iter=dr_max_iter,
426
+ tune=False,
427
+ temperature_control=self.temperature_control,
428
+ gibbs_sampling_setup=gibbs_sampling_setup,
429
+ )
430
+ self.rj_moves = [rj_move]
431
+ self.rj_weights = [1.0]
432
+
433
+ elif isinstance(rj_moves, str) and rj_moves == "iterate_branches":
434
+ # will iterate through all branches within one RJ proposal
435
+ gibbs_sampling_setup = deepcopy(branch_names)
436
+
437
+ # default to DistributionGenerateRJ
438
+ rj_move = DistributionGenerateRJ(
439
+ self.priors,
440
+ nleaves_max=self.nleaves_max,
441
+ nleaves_min=self.nleaves_min,
442
+ dr=dr_moves,
443
+ dr_max_iter=dr_max_iter,
444
+ tune=False,
445
+ temperature_control=self.temperature_control,
446
+ gibbs_sampling_setup=gibbs_sampling_setup,
447
+ )
448
+ self.rj_moves = [rj_move]
449
+ self.rj_weights = [1.0]
450
+
451
+ elif isinstance(rj_moves, str) and rj_moves == "separate_branches":
452
+ # build one RJ move per branch; one of them is chosen at random for each RJ proposal
453
+ rj_moves = []
454
+ rj_weights = []
455
+ for branch_name in branch_names:
456
+
457
+ # only do one branch per move
458
+ gibbs_sampling_setup = [branch_name]
459
+
460
+ # default to DistributionGenerateRJ
461
+ rj_move_tmp = DistributionGenerateRJ(
462
+ self.priors,
463
+ nleaves_max=self.nleaves_max,
464
+ nleaves_min=self.nleaves_min,
465
+ dr=dr_moves,
466
+ dr_max_iter=dr_max_iter,
467
+ tune=False,
468
+ temperature_control=self.temperature_control,
469
+ gibbs_sampling_setup=gibbs_sampling_setup,
470
+ )
471
+ rj_moves.append(rj_move_tmp)
472
+ # will renormalize after
473
+ rj_weights.append(1.0)
474
+ self.rj_moves = rj_moves
475
+ self.rj_weights = rj_weights
476
+
477
+ elif isinstance(rj_moves, str):
478
+ raise ValueError(
479
+ f"When providing a str for rj_moves, must be 'together', 'iterate_branches', or 'separate_branches'. Input is {rj_moves}"
480
+ )
481
+
482
+ # same as above for moves
483
+ elif isinstance(rj_moves, Iterable):
484
+ self.has_reversible_jump = True
485
+
486
+ try:
487
+ self.rj_moves, self.rj_weights = zip(*rj_moves)
488
+ except TypeError:
489
+ self.rj_moves = rj_moves
490
+ self.rj_weights = np.ones(len(rj_moves))
491
+
492
+ elif isinstance(rj_moves, bool) and not rj_moves:
493
+ self.has_reversible_jump = False
494
+ self.rj_moves = None
495
+ self.rj_weights = None
496
+
497
+ elif not isinstance(rj_moves, bool):
498
+ self.has_reversible_jump = True
499
+
500
+ self.rj_moves = [rj_moves]
501
+ self.rj_weights = [1.0]
502
+
503
+ # adjust rj weights properly
504
+ if self.has_reversible_jump:
505
+ self.rj_weights = np.atleast_1d(self.rj_weights).astype(float)
506
+ self.rj_weights /= np.sum(self.rj_weights)
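+ # Example: with branch_names ["gauss", "line"] (placeholders), rj_moves="separate_branches"
+ # produces two DistributionGenerateRJ moves above, one per branch, each ending up with
+ # weight 0.5 after this normalization, while rj_moves=True or "together" produces a single
+ # move that proposes on all branches at once.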
507
+
508
+ # warn if base stretch is used
509
+ for move in self.moves:
510
+ if type(move) == StretchMove:
511
+ warnings.warn(
512
+ "If using revisible jump, using the Stretch Move for in-model proposals is not advised. It will run and work, but it will not be using the correct complientary group of parameters meaning it will most likely be very inefficient."
513
+ )
514
+
515
+ # make sure moves have temperature module
516
+ if self.temperature_control is not None:
517
+ for move in self.moves:
518
+ if move.temperature_control is None:
519
+ move.temperature_control = self.temperature_control
520
+
521
+ if self.has_reversible_jump:
522
+ for move in self.rj_moves:
523
+ if move.temperature_control is None:
524
+ move.temperature_control = self.temperature_control
525
+
526
+ # make sure moves have the periodic parameter container
527
+ if periodic is not None:
528
+ for move in self.moves:
529
+ if move.periodic is None:
530
+ move.periodic = periodic
531
+
532
+ if self.has_reversible_jump:
533
+ for move in self.rj_moves:
534
+ if move.periodic is None:
535
+ move.periodic = periodic
536
+
537
+ # prepare the per proposal accepted values that are held as attributes in the specific classes
538
+ for move in self.moves:
539
+ move.accepted = np.zeros((self.ntemps, self.nwalkers))
540
+
541
+ if self.has_reversible_jump:
542
+ for move in self.rj_moves:
543
+ move.accepted = np.zeros((self.ntemps, self.nwalkers))
544
+
545
+ # setup backend if not provided or initialized
546
+ if backend is None:
547
+ self.backend = Backend()
548
+ elif isinstance(backend, str):
549
+ self.backend = HDFBackend(backend)
550
+ else:
551
+ self.backend = backend
552
+
553
+ self.info = info
554
+
555
+ all_moves_tmp = list(
556
+ tuple(self.moves)
557
+ if not self.has_reversible_jump
558
+ else tuple(self.moves) + tuple(self.rj_moves)
559
+ )
560
+
561
+ self.all_moves = {}
562
+ if self.track_moves:
563
+ current_indices_move_keys = {}
564
+ for move in all_moves_tmp:
565
+ # get out of tuple if weights are given
566
+ if isinstance(move, tuple):
567
+ move = move[0]
568
+
569
+ # get the name of the class instance as a string
570
+ move_name = move.__class__.__name__
571
+
572
+ # need to keep track of how many times each move class has been used
573
+ if move_name not in current_indices_move_keys:
574
+ current_indices_move_keys[move_name] = 0
575
+
576
+ else:
577
+ current_indices_move_keys[move_name] += 1
578
+
579
+ # get the full name including the index
580
+ full_move_name = move_name + f"_{current_indices_move_keys[move_name]}"
581
+ self.all_moves[full_move_name] = move
582
+
583
+ # get move keys out
584
+ move_keys = list(self.all_moves.keys())
585
+
586
+ else:
587
+ move_keys = None
588
+
589
+ self.move_keys = move_keys
590
+
591
+ # Deal with re-used backends
592
+ if not self.backend.initialized:
593
+ self._previous_state = None
594
+ self.reset(
595
+ branch_names=branch_names,
596
+ ntemps=self.ntemps,
597
+ nleaves_max=nleaves_max,
598
+ rj=self.has_reversible_jump,
599
+ moves=move_keys,
600
+ **info,
601
+ )
602
+ state = np.random.get_state()
603
+ else:
604
+ if self.track_moves:
605
+ moves_okay = True
606
+ if len(self.move_keys) != len(self.backend.move_keys):
607
+ moves_okay = False
608
+
609
+ for key in self.move_keys:
610
+ if key not in self.backend.move_keys:
611
+ moves_okay = False
612
+
613
+ if not moves_okay:
614
+ raise ValueError(
615
+ "Configuration of moves has changed. Cannot use the same backend. Declare a new backend and start from the previous state. If you would prefer not to track move acceptance fraction, set track_moves to False in the EnsembleSampler."
616
+ )
617
+
618
+ # Check the backend shape
619
+ for i, (name, shape) in enumerate(self.backend.shape.items()):
620
+ test_shape = (
621
+ self.ntemps,
622
+ self.nwalkers,
623
+ self.nleaves_max[name],
624
+ self.ndims[name],
625
+ )
626
+ if shape != test_shape:
627
+ raise ValueError(
628
+ (
629
+ "the shape of the backend ({0}) is incompatible with the "
630
+ "shape of the sampler ({1} for model {2})"
631
+ ).format(shape, test_shape, name)
632
+ )
633
+
634
+ # Get the last random state
635
+ state = self.backend.random_state
636
+ if state is None:
637
+ state = np.random.get_state()
638
+
639
+ # Grab the last step so that we can restart
640
+ it = self.backend.iteration
641
+ if it > 0:
642
+ self._previous_state = self.get_last_sample()
643
+
644
+ # This is a random number generator that we can easily set the state
645
+ # of without affecting the numpy-wide generator
646
+ self._random = np.random.mtrand.RandomState()
647
+ self._random.set_state(state)
648
+
649
+ # Do a little bit of _magic_ to make the likelihood call with
650
+ # ``args`` and ``kwargs`` pickleable.
651
+ self.log_like_fn = _FunctionWrapper(log_like_fn, args, kwargs)
652
+
653
+ self.all_walkers = self.nwalkers * self.ntemps
654
+
655
+ # prepare plotting
656
+ # TODO: adjust plotting maybe?
657
+ self.plot_iterations = plot_iterations
658
+
659
+ if plot_generator is None and self.plot_iterations > 0:
660
+ raise NotImplementedError
661
+ # set to default if not provided
662
+ if plot_name is not None:
663
+ name = plot_name
664
+ else:
665
+ name = "output"
666
+ self.plot_generator = PlotContainer(
667
+ fp=name, backend=self.backend, thin_chain_by_ac=True
668
+ )
669
+ elif self.plot_iterations > 0:
670
+ raise NotImplementedError
671
+ self.plot_generator = plot_generator
672
+
673
+ # prepare stopping functions
674
+ self.stopping_fn = stopping_fn
675
+ self.stopping_iterations = stopping_iterations
676
+
677
+ # prepare update functions
678
+ self.update_fn = update_fn
679
+ self.update_iterations = update_iterations
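+
+ # Example (hedged): a stopping function only needs the documented
+ # (iteration, last_state, sampler) call signature, e.g.
+ #
+ # def stop_if_flat(iteration, last_state, sampler):
+ #     # illustrative criterion, not part of Eryn
+ #     return np.std(last_state.log_like) < 1e-3
+ #
+ # passed as stopping_fn=stop_if_flat with, e.g., stopping_iterations=50; update_fn
+ # uses the same signature.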
680
+
681
+ @property
682
+ def random_state(self):
683
+ """
684
+ The state of the internal random number generator. In practice, it's
685
+ the result of calling ``get_state()`` on a
686
+ ``numpy.random.mtrand.RandomState`` object. You can try to set this
687
+ property but be warned that if you do this and it fails, it will do
688
+ so silently.
689
+
690
+ """
691
+ return self._random.get_state()
692
+
693
+ @random_state.setter # NOQA
694
+ def random_state(self, state):
695
+ """
696
+ Try to set the state of the random number generator but fail silently
697
+ if it doesn't work. Don't say I didn't warn you...
698
+
699
+ """
700
+ try:
701
+ self._random.set_state(state)
702
+ except:
703
+ pass
704
+
705
+ @property
706
+ def priors(self):
707
+ """
708
+ Return the priors in the sampler.
709
+
710
+ """
711
+ return self._priors
712
+
713
+ @priors.setter
714
+ def priors(self, priors):
715
+ """Set priors information.
716
+
717
+ This performs checks to make sure the inputs are okay.
718
+
719
+ """
720
+ if isinstance(priors, dict):
721
+ self._priors = {}
722
+
723
+ for key in priors.keys():
724
+ test = priors[key]
725
+ if isinstance(test, dict):
726
+ # check all dists
727
+ for ind, dist in test.items():
728
+ if not hasattr(dist, "logpdf"):
729
+ raise ValueError(
730
+ "Distribution for model {0} and index {1} does not have logpdf method.".format(
731
+ key, ind
732
+ )
733
+ )
734
+
735
+ self._priors[key] = ProbDistContainer(test)
736
+
737
+ elif isinstance(test, ProbDistContainer):
738
+ self._priors[key] = test
739
+
740
+ elif hasattr(test, "logpdf"):
741
+ self._priors[key] = test
742
+
743
+ else:
744
+ raise ValueError(
745
+ "priors dictionary items must be dictionaries with prior information or instances of the ProbDistContainer class."
746
+ )
747
+
748
+ elif isinstance(priors, ProbDistContainer):
749
+ self._priors = {"model_0": priors}
750
+
751
+ else:
752
+ raise ValueError("Priors must be a dictionary.")
753
+
754
+ return
755
+
756
+ @property
757
+ def iteration(self):
758
+ return self.backend.iteration
759
+
760
+ def reset(self, **info):
761
+ """
762
+ Reset the backend.
763
+
764
+ Args:
765
+ **info (dict, optional): information to pass to backend reset method.
766
+
767
+ """
768
+ self.backend.reset(self.nwalkers, self.ndims, **info)
769
+
770
+ def __getstate__(self):
771
+ # In order to be generally picklable, we need to discard the pool
772
+ # object before trying.
773
+ d = self.__dict__.copy()  # copy so that pickling does not clear the pool on the live sampler
774
+ d["pool"] = None
775
+ return d
776
+
777
+ def get_model(self):
778
+ """Get ``Model`` object from sampler
779
+
780
+ The model object is used to pass necessary information to the
781
+ proposals. This method can be used to retrieve the ``model`` used
782
+ in the sampler from outside the sampler.
783
+
784
+ Returns:
785
+ :class:`Model`: ``Model`` object used by sampler.
786
+
787
+ """
788
+ # Set up a wrapper around the relevant model functions
789
+ if self.pool is not None:
790
+ map_fn = self.pool.map
791
+ else:
792
+ map_fn = map
793
+
794
+ # setup model framework for passing necessary items
795
+ model = Model(
796
+ self.log_like_fn,
797
+ self.compute_log_like,
798
+ self.compute_log_prior,
799
+ self.temperature_control,
800
+ map_fn,
801
+ self._random,
802
+ )
803
+ return model
804
+
805
+ def sample(
806
+ self,
807
+ initial_state,
808
+ iterations=1,
809
+ tune=False,
810
+ skip_initial_state_check=True,
811
+ thin_by=1,
812
+ store=True,
813
+ progress=False,
814
+ ):
815
+ """Advance the chain as a generator
816
+
817
+ Args:
818
+ initial_state (State or ndarray[ntemps, nwalkers, nleaves_max, ndim] or dict): The initial
819
+ :class:`State` or positions of the walkers in the
820
+ parameter space. If multiple branches used, must be dict with keys
821
+ as the ``branch_names`` and values as the positions. If ``betas`` are
822
+ provided in the state object, they will be loaded into the
823
+ ``temperature_control``.
824
+ iterations (int or None, optional): The number of steps to generate.
825
+ ``None`` generates an infinite stream (requires ``store=False``).
826
+ (default: 1)
827
+ tune (bool, optional): If ``True``, the parameters of some moves
828
+ will be automatically tuned. (default: ``False``)
829
+ thin_by (int, optional): If you only want to store and yield every
830
+ ``thin_by`` samples in the chain, set ``thin_by`` to an
831
+ integer greater than 1. When this is set, ``iterations *
832
+ thin_by`` proposals will be made. (default: 1)
833
+ store (bool, optional): By default, the sampler stores in the backend
834
+ the positions (and other information) of the samples in the
835
+ chain. If you are using another method to store the samples to
836
+ a file or if you don't need to analyze the samples after the
837
+ fact (for burn-in for example) set ``store`` to ``False``. (default: ``True``)
838
+ progress (bool or str, optional): If ``True``, a progress bar will
839
+ be shown as the sampler progresses. If a string, will select a
840
+ specific ``tqdm`` progress bar - most notable is
841
+ ``'notebook'``, which shows a progress bar suitable for
842
+ Jupyter notebooks. If ``False``, no progress bar will be
843
+ shown. (default: ``False``)
844
+ skip_initial_state_check (bool, optional): If ``True``, a check
845
+ that the initial_state can fully explore the space will be
846
+ skipped. If using reversible jump, the check is always skipped and the user needs
847
+ to ensure a valid initial state on their own, even if ``skip_initial_state_check`` is ``False``.
848
+ (default: ``True``)
849
+
850
+ Returns:
851
+ State: Every ``thin_by`` steps, this generator yields the :class:`State` of the ensemble.
852
+
853
+ Raises:
854
+ ValueError: Improper initialization.
855
+
856
+ """
857
+ if iterations is None and store:
858
+ raise ValueError("'store' must be False when 'iterations' is None")
859
+
860
+ # Interpret the input as a walker state and check the dimensions.
861
+
862
+ # initial_state.__class__ rather than State in case it is a subclass
863
+ # of State
864
+ if (
865
+ hasattr(initial_state, "__class__")
866
+ and issubclass(initial_state.__class__, State)
867
+ and not isinstance(initial_state.__class__, State)
868
+ ):
869
+ state = initial_state.__class__(initial_state, copy=True)
870
+ else:
871
+ state = State(initial_state, copy=True)
872
+
873
+ # Check the backend shape
874
+ for i, (name, branch) in enumerate(state.branches.items()):
875
+ ntemps_, nwalkers_, nleaves_, ndim_ = branch.shape
876
+ if (ntemps_, nwalkers_, nleaves_, ndim_) != (
877
+ self.ntemps,
878
+ self.nwalkers,
879
+ self.nleaves_max[name],
880
+ self.ndims[name],
881
+ ):
882
+ raise ValueError("incompatible input dimensions")
883
+
884
+ # do an initial state check if it is requested and we are not using reversible jump
885
+ if (not skip_initial_state_check) and (
886
+ not walkers_independent(state.coords) and not self.has_reversible_jump
887
+ ):
888
+ raise ValueError(
889
+ "Initial state has a large condition number. "
890
+ "Make sure that your walkers are linearly independent for the "
891
+ "best performance"
892
+ )
893
+
894
+ # get log prior and likelihood if not provided in the initial state
895
+ if state.log_prior is None:
896
+ coords = state.branches_coords
897
+ inds = state.branches_inds
898
+ state.log_prior = self.compute_log_prior(coords, inds=inds)
899
+
900
+ if state.log_like is None:
901
+ coords = state.branches_coords
902
+ inds = state.branches_inds
903
+ state.log_like, state.blobs = self.compute_log_like(
904
+ coords,
905
+ inds=inds,
906
+ logp=state.log_prior,
907
+ supps=state.supplemental, # only used if self.provide_supplemental is True
908
+ branch_supps=state.branches_supplemental, # only used if self.provide_supplemental is True
909
+ )
910
+
911
+ # get betas out of state object if they are there
912
+ if state.betas is not None:
913
+ if state.betas.shape[0] != self.ntemps:
914
+ raise ValueError(
915
+ "Input state has inverse temperatures (betas), but not the correct number of temperatures according to sampler inputs."
916
+ )
917
+
918
+ self.temperature_control.betas = state.betas.copy()
919
+
920
+ else:
921
+ if hasattr(self, "temperature_control") and hasattr(
922
+ self.temperature_control, "betas"
923
+ ):
924
+ state.betas = self.temperature_control.betas.copy()
925
+
926
+ if np.shape(state.log_like) != (self.ntemps, self.nwalkers):
927
+ raise ValueError("incompatible input dimensions")
928
+ if np.shape(state.log_prior) != (self.ntemps, self.nwalkers):
929
+ raise ValueError("incompatible input dimensions")
930
+
931
+ # Check to make sure that the probability function didn't return
932
+ # ``np.nan``.
933
+ if np.any(np.isnan(state.log_like)):
934
+ raise ValueError("The initial log_like was NaN")
935
+
936
+ if np.any(np.isinf(state.log_like)):
937
+ raise ValueError("The initial log_like was +/- infinite")
938
+
939
+ if np.any(np.isnan(state.log_prior)):
940
+ raise ValueError("The initial log_prior was NaN")
941
+
942
+ if np.any(np.isinf(state.log_prior)):
943
+ raise ValueError("The initial log_prior was +/- infinite")
944
+
945
+ # Check that the thin keyword is reasonable.
946
+ thin_by = int(thin_by)
947
+ if thin_by <= 0:
948
+ raise ValueError("Invalid thinning argument")
949
+
950
+ yield_step = thin_by
951
+ checkpoint_step = thin_by
952
+ if store:
953
+ self.backend.grow(iterations, state.blobs)
954
+
955
+ # get the model object
956
+ model = self.get_model()
957
+
958
+ # Inject the progress bar
959
+ total = None if iterations is None else iterations * yield_step
960
+ with get_progress_bar(progress, total) as pbar:
961
+ i = 0
962
+ for _ in count() if iterations is None else range(iterations):
963
+ for _ in range(yield_step):
964
+ # in model moves
965
+ accepted = np.zeros((self.ntemps, self.nwalkers))
966
+ for repeat in range(self.num_repeats_in_model):
967
+ # Choose a random move
968
+ move = self._random.choice(self.moves, p=self.weights)
969
+
970
+ # Propose (in model)
971
+ state, accepted_out = move.propose(model, state)
972
+ accepted += accepted_out
973
+ if self.ntemps > 1:
974
+ in_model_swaps = move.temperature_control.swaps_accepted
975
+ else:
976
+ in_model_swaps = None
977
+
978
+ state.random_state = self.random_state
979
+
980
+ if tune:
981
+ move.tune(state, accepted_out)
982
+
983
+ if self.has_reversible_jump:
984
+ rj_accepted = np.zeros((self.ntemps, self.nwalkers))
985
+ for repeat in range(self.num_repeats_rj):
986
+ rj_move = self._random.choice(
987
+ self.rj_moves, p=self.rj_weights
988
+ )
989
+
990
+ # Propose (Between models)
991
+ state, rj_accepted_out = rj_move.propose(model, state)
992
+ rj_accepted += rj_accepted_out
993
+ # Again commenting out this section: We do not control temperature on RJ moves
994
+ # if self.ntemps > 1:
995
+ # rj_swaps = rj_move.temperature_control.swaps_accepted
996
+ # else:
997
+ # rj_swaps = None
998
+ rj_swaps = None
999
+
1000
+ state.random_state = self.random_state
1001
+
1002
+ if tune:
1003
+ rj_move.tune(state, rj_accepted_out)
1004
+
1005
+ else:
1006
+ rj_accepted = None
1007
+ rj_swaps = None
1008
+
1009
+ # Save the new step
1010
+ if store and (i + 1) % checkpoint_step == 0:
1011
+ if self.track_moves:
1012
+ moves_accepted_fraction = {
1013
+ key: move_tmp.acceptance_fraction
1014
+ for key, move_tmp in self.all_moves.items()
1015
+ }
1016
+ else:
1017
+ moves_accepted_fraction = None
1018
+
1019
+ self.backend.save_step(
1020
+ state,
1021
+ accepted,
1022
+ rj_accepted=rj_accepted,
1023
+ swaps_accepted=in_model_swaps,
1024
+ moves_accepted_fraction=moves_accepted_fraction,
1025
+ )
1026
+
1027
+ # update after diagnostic and stopping check
1028
+ # if updating and using burn_in, need to make sure it does not use
1029
+ # previous chain samples since they are not stored.
1030
+ if (
1031
+ self.update_iterations > 0
1032
+ and self.update_fn is not None
1033
+ and (i + 1) % (self.update_iterations) == 0
1034
+ ):
1035
+ self.update_fn(i, state, self)
1036
+
1037
+ pbar.update(1)
1038
+ i += 1
1039
+
1040
+ # Yield the result as an iterator so that the user can do all
1041
+ # sorts of fun stuff with the results so far.
1042
+ yield state
1043
+
1044
+ def run_mcmc(
1045
+ self, initial_state, nsteps, burn=None, post_burn_update=False, **kwargs
1046
+ ):
1047
+ """
1048
+ Iterate :func:`sample` for ``nsteps`` iterations and return the result.
1049
+
1050
+ Args:
1051
+ initial_state (State or ndarray[ntemps, nwalkers, nleaves_max, ndim] or dict): The initial
1052
+ :class:`State` or positions of the walkers in the
1053
+ parameter space. If multiple branches used, must be dict with keys
1054
+ as the ``branch_names`` and values as the positions. If ``betas`` are
1055
+ provided in the state object, they will be loaded into the
1056
+ ``temperature_control``.
1057
+ nsteps (int): The number of steps to generate. The total number of proposals is ``nsteps * thin_by``.
1058
+ burn (int, optional): Number of burn steps to run before storing information. The ``thin_by`` kwarg is ignored when counting burn steps since there is no storage (equivalent to ``thin_by=1``).
1059
+ post_burn_update (bool, optional): If ``True``, run ``update_fn`` after burn in.
1060
+
1061
+ Other parameters are directly passed to :func:`sample`.
1062
+
1063
+ Returns:
1064
+ State: This method returns the most recent result from :func:`sample`.
1065
+
1066
+ Raises:
1067
+ ValueError: If ``initial_state`` is None and ``run_mcmc`` has never been called.
1068
+
1069
+ """
1070
+ if initial_state is None:
1071
+ if self._previous_state is None:
1072
+ raise ValueError(
1073
+ "Cannot have `initial_state=None` if run_mcmc has never "
1074
+ "been called."
1075
+ )
1076
+ initial_state = self._previous_state
1077
+
1078
+ # run burn in
1079
+ if burn is not None and burn != 0:
1080
+ # prepare kwargs that relate to burn
1081
+ burn_kwargs = deepcopy(kwargs)
1082
+ burn_kwargs["store"] = False
1083
+ burn_kwargs["thin_by"] = 1
1084
+ i = 0
1085
+ for results in self.sample(initial_state, iterations=burn, **burn_kwargs):
1086
+ i += 1
1087
+
1088
+ # run post-burn update
1089
+ if post_burn_update and self.update_fn is not None:
1090
+ self.update_fn(i, results, self)
1091
+
1092
+ initial_state = results
1093
+
1094
+ if nsteps == 0:
1095
+ return initial_state
1096
+
1097
+ results = None
1098
+
1099
+ i = 0
1100
+ for results in self.sample(initial_state, iterations=nsteps, **kwargs):
1101
+ # diagnostic plots
1102
+ # TODO: adjust diagnostic plots
1103
+ if self.plot_iterations > 0 and (i + 1) % (self.plot_iterations) == 0:
1104
+ self.plot_generator.generate_plot_info() # TODO: remove defaults
1105
+
1106
+ # check for stopping before updating
1107
+ if (
1108
+ self.stopping_iterations > 0
1109
+ and self.stopping_fn is not None
1110
+ and (i + 1) % (self.stopping_iterations) == 0
1111
+ ):
1112
+ stop = self.stopping_fn(i, results, self)
1113
+
1114
+ if stop:
1115
+ break
1116
+
1117
+ i += 1
1118
+
1119
+ # Store so that the ``initial_state=None`` case will work
1120
+ self._previous_state = results
1121
+
1122
+ return results
1123
+
1124
+ def compute_log_prior(self, coords, inds=None, supps=None, branch_supps=None):
1125
+ """Calculate the vector of log-prior for the walkers
1126
+
1127
+ Args:
1128
+ coords (dict): Keys are ``branch_names`` and values are
1129
+ the position np.arrays[ntemps, nwalkers, nleaves_max, ndim].
1130
+ This dictionary is created with the ``branches_coords`` attribute
1131
+ from :class:`State`.
1132
+ inds (dict, optional): Keys are ``branch_names`` and values are
1133
+ the ``inds`` np.arrays[ntemps, nwalkers, nleaves_max] that indicates
1134
+ which leaves are being used. This dictionary is created with the
1135
+ ``branches_inds`` attribute from :class:`State`.
1136
+ (default: ``None``)
1137
+
1138
+ Returns:
1139
+ np.ndarray[ntemps, nwalkers]: Prior Values
1140
+
1141
+ """
1142
+
1143
+ # get number of temperatures and walkers
1144
+ ntemps, nwalkers, _, _ = coords[list(coords.keys())[0]].shape
1145
+
1146
+ if inds is None:
1147
+ # default use all sources
1148
+ inds = {
1149
+ name: np.full(coords[name].shape[:-1], True, dtype=bool)
1150
+ for name in coords
1151
+ }
1152
+
1153
+ # take information out of dict and spread to x1..xn
1154
+ x_in = {}
1155
+
1156
+ # for completely customizable priors
1157
+ if "all_models_together" in self.priors:
1158
+ prior_out = self.priors["all_models_together"].logpdf(
1159
+ coords, inds, supps=supps, branch_supps=branch_supps
1160
+ )
1161
+ assert prior_out.shape == (ntemps, nwalkers)
1162
+
1163
+ elif self.provide_groups:
1164
+ # get group information from the inds dict
1165
+ groups = groups_from_inds(inds)
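+ # Illustration (from the class docstring): with 3 walkers carrying (1, 2, 1) leaves of a
+ # branch, groups_from_inds gives groups[name] == array([0, 1, 1, 2]), i.e. one entry per
+ # active leaf pointing to its flattened (temperature, walker) index.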
1166
+
1167
+ # get the coordinates that are used
1168
+ for i, (name, coords_i) in enumerate(coords.items()):
1169
+ x_in[name] = coords_i[inds[name]]
1170
+
1171
+ prior_out = np.zeros((ntemps * nwalkers))
1172
+ for name in x_in:
1173
+ # get prior for the individual leaves (sources)
1174
+ prior_out_temp = self.priors[name].logpdf(x_in[name])
1175
+
1176
+ # arrange prior values by groups
1177
+ # TODO: vectorize this?
1178
+ for i in np.unique(groups[name]):
1179
+ # which members are in the group i
1180
+ inds_temp = np.where(groups[name] == i)[0]
1181
+ # num_in_group = len(inds_temp)
1182
+
1183
+ # add to the prior for this group
1184
+ prior_out[i] += prior_out_temp[inds_temp].sum()
1185
+
1186
+ # reshape
1187
+ prior_out = prior_out.reshape(ntemps, nwalkers)
1188
+
1189
+ else:
1190
+ # flatten coordinate arrays
1191
+ for i, (name, coords_i) in enumerate(coords.items()):
1192
+ ntemps, nwalkers, nleaves_max, ndim = coords_i.shape
1193
+
1194
+ x_in[name] = coords_i.reshape(-1, ndim)
1195
+
1196
+ prior_out = np.zeros((ntemps, nwalkers))
1197
+ for name in x_in:
1198
+ ntemps, nwalkers, nleaves_max, ndim = coords[name].shape
1199
+ prior_out_temp = (
1200
+ self.priors[name]
1201
+ .logpdf(x_in[name])
1202
+ .reshape(ntemps, nwalkers, nleaves_max)
1203
+ )
1204
+
1205
+ # fix any infs / nans from leaves that are not being used (inds == False)
1206
+ prior_out_temp[~inds[name]] = 0.0
1207
+
1208
+ # vectorized because everything is rectangular (no groups to indicate model difference)
1209
+ prior_out += prior_out_temp.sum(axis=-1)
1210
+
1211
+ return prior_out
1212
+
1213
+ def compute_log_like(
1214
+ self, coords, inds=None, logp=None, supps=None, branch_supps=None
1215
+ ):
1216
+ """Calculate the vector of log-likelihood for the walkers
1217
+
1218
+ Args:
1219
+ coords (dict): Keys are ``branch_names`` and values are
1220
+ the position np.arrays[ntemps, nwalkers, nleaves_max, ndim].
1221
+ This dictionary is created with the ``branches_coords`` attribute
1222
+ from :class:`State`.
1223
+ inds (dict, optional): Keys are ``branch_names`` and values are
1224
+ the inds np.arrays[ntemps, nwalkers, nleaves_max] that indicates
1225
+ which leaves are being used. This dictionary is created with the
1226
+ ``branches_inds`` attribute from :class:`State`.
1227
+ (default: ``None``)
1228
+ logp (np.ndarray[ntemps, nwalkers], optional): Log prior values associated
1229
+ with all walkers. If not provided, it will be calculated because
1230
+ if a walker has logp = -inf, its likelihood is not calculated.
1231
+ This prevents evaluating the likelihood outside the prior.
1232
+ (default: ``None``)
1233
+
1234
+ Returns:
1235
+ tuple: Carries log-likelihood and blob information.
1236
+ First entry is np.ndarray[ntemps, nwalkers] with values corresponding
1237
+ to the log likelihood of each walker. Second entry is ``blobs``.
1238
+
1239
+ Raises:
1240
+ ValueError: Infinite or NaN values in parameters.
1241
+
1242
+ """
1243
+
1244
+ # if inds not provided, use all
1245
+ if inds is None:
1246
+ inds = {
1247
+ name: np.full(coords[name].shape[:-1], True, dtype=bool)
1248
+ for name in coords
1249
+ }
1250
+
1251
+ # Check that the parameters are in physical ranges.
1252
+ for name, ptemp in coords.items():
1253
+ if np.any(np.isinf(ptemp[inds[name]])):
1254
+ raise ValueError("At least one parameter value was infinite")
1255
+ if np.any(np.isnan(ptemp[inds[name]])):
1256
+ raise ValueError("At least one parameter value was NaN")
1257
+
1258
+ # if no prior values are added, compute_prior
1259
+ # this is necessary to ensure Likelihood is not evaluated outside of the prior
1260
+ if logp is None:
1261
+ logp = self.compute_log_prior(
1262
+ coords, inds=inds, supps=supps, branch_supps=branch_supps
1263
+ )
1264
+
1265
+ # if all points are outside the prior
1266
+ if np.all(np.isinf(logp)):
1267
+ warnings.warn(
1268
+ "All points input for the Likelihood have a log prior of -inf."
1269
+ )
1270
+ return np.full_like(logp, -1e300), None
1271
+
1272
+ # do not run log likelihood where logp = -inf
1273
+ inds_copy = deepcopy(inds)
1274
+ inds_bad = np.where(np.isinf(logp))
1275
+ for key in inds_copy:
1276
+ inds_copy[key][inds_bad] = False
1277
+
1278
+ # if inds_keep in branch supps, indicate which to not keep
1279
+ if (
1280
+ branch_supps is not None
1281
+ and key in branch_supps
1282
+ and branch_supps[key] is not None
1283
+ and "inds_keep" in branch_supps[key]
1284
+ ):
1285
+ # TODO: indicate specialty of inds_keep in branch_supp
1286
+ branch_supps[key][inds_bad] = {"inds_keep": False}
1287
+
1288
+ # take information out of dict and spread to x1..xn
1289
+ x_in = {}
1290
+ if self.provide_supplemental:
1291
+ if supps is None and branch_supps is None:
1292
+ raise ValueError(
1293
+ """supps and branch_supps are both None. If self.provide_supplemental
1294
+ is True, must provide some supplemental information."""
1295
+ )
1296
+ if branch_supps is not None:
1297
+ branch_supps_in = {}
1298
+
1299
+ # determine groupings from inds
1300
+ groups = groups_from_inds(inds_copy)
1301
+
1302
+ # need to map group inds properly
1303
+ # this is the unique group indexes
1304
+ unique_groups = np.unique(
1305
+ np.concatenate([groups_i for groups_i in groups.values()])
1306
+ )
1307
+
1308
+ # this is the map to those indexes that are used in the likelihood
1309
+ groups_map = np.arange(len(unique_groups))
1310
+
1311
+ # get the indices with groups_map for the Likelihood
1312
+ ll_groups = {}
1313
+ for key, group in groups.items():
1314
+ # get unique groups in this sub-group (or branch)
1315
+ temp_unique_groups, inverse = np.unique(group, return_inverse=True)
1316
+
1317
+ # use groups_map by finding where temp_unique_groups overlaps with unique_groups
1318
+ keep_groups = groups_map[np.in1d(unique_groups, temp_unique_groups)]
1319
+
1320
+ # fill group information for Likelihood
1321
+ ll_groups[key] = keep_groups[inverse]
1322
+
1323
+ for i, (name, coords_i) in enumerate(coords.items()):
1324
+ ntemps, nwalkers, nleaves_max, ndim = coords_i.shape
1325
+ nwalkers_all = ntemps * nwalkers
1326
+
1327
+ # fill x_values properly into dictionary
1328
+ x_in[name] = coords_i[inds_copy[name]]
1329
+
1330
+ # prepare branch supplementals for each branch
1331
+ if self.provide_supplemental:
1332
+ if branch_supps is not None: # and
1333
+ if branch_supps[name] is not None:
1334
+ # index the branch supps
1335
+ # it will carry in a dictionary of information
1336
+ branch_supps_in[name] = branch_supps[name][inds_copy[name]]
1337
+ else:
1338
+ # fill with None if this branch does not have a supplemental
1339
+ branch_supps_in[name] = None
1340
+
1341
+ # deal with overall supplemental not specific to the branches
1342
+ if self.provide_supplemental:
1343
+ if supps is not None:
1344
+ # get the flattened supplemental
1345
+ # this will produce the shape (ntemps * nwalkers,...)
1346
+ temp = supps.flat
1347
+
1348
+ # unique_groups will properly index the flattened array
1349
+ supps_in = {
1350
+ name: values[unique_groups] for name, values in temp.items()
1351
+ }
1352
+
1353
+ # prepare group information
1354
+ # this gets the group_map indexing into a list
1355
+ groups_in = list(ll_groups.values())
1356
+
1357
+ # if only one branch, take the group array out of the list
1358
+ if len(groups_in) == 1:
1359
+ groups_in = groups_in[0]
1360
+
1361
+ # list of parameter arrays
1362
+ params_in = list(x_in.values())
1363
+
1364
+ # Likelihoods are vectorized across groups
1365
+ if self.vectorize:
1366
+ # prepare args list
1367
+ args_in = []
1368
+
1369
+ # when vectorizing, if params_in has one entry, take out of list
1370
+ if len(params_in) == 1:
1371
+ params_in = params_in[0]
1372
+
1373
+ # add parameters to args
1374
+ args_in.append(params_in)
1375
+
1376
+ # if providing groups, add to args
1377
+ if self.provide_groups:
1378
+ args_in.append(groups_in)
1379
+
1380
+ # prepare supplementals as kwargs to the Likelihood
1381
+ kwargs_in = {}
1382
+ if self.provide_supplemental:
1383
+ if supps is not None:
1384
+ kwargs_in["supps"] = supps_in
1385
+ if branch_supps is not None:
1386
+ # get list of branch_supps values
1387
+ branch_supps_in_2 = list(branch_supps_in.values())
1388
+
1389
+ # if only one entry, take out of list
1390
+ if len(branch_supps_in_2) == 1:
1391
+ kwargs_in["branch_supps"] = branch_supps_in_2[0]
1392
+
1393
+ else:
1394
+ kwargs_in["branch_supps"] = branch_supps_in_2
1395
+
1396
+ # provide args, kwargs as a tuple
1397
+ args_and_kwargs = (args_in, kwargs_in)
1398
+
1399
+ # get vectorized results
1400
+ results = self.log_like_fn(args_and_kwargs)
1401
+
1402
+ # each Likelihood is computed individually
1403
+ else:
1404
+ # if groups in is an array, need to put it in a list.
1405
+ if isinstance(groups_in, np.ndarray):
1406
+ groups_in = [groups_in]
1407
+
1408
+ # prepare input args for all Likelihood calls
1409
+ # to be spread out with map functions below
1410
+ args_in = []
1411
+
1412
+ # each individual group in the groups_map
1413
+ for group_i in groups_map:
1414
+ # args and kwargs for the individual Likelihood
1415
+ arg_i = [None for _ in self.branch_names]
1416
+ kwarg_i = {}
1417
+
1418
+ # iterate over the group information from the branches
1419
+ for branch_i, groups_in_set in enumerate(groups_in):
1420
+ # which entries in this branch are in the overall group tested
1421
+ # this accounts for multiple leaves (or model counts)
1422
+ inds_keep = np.where(groups_in_set == group_i)[0]
1423
+
1424
+ branch_name_i = self.branch_names[branch_i]
1425
+
1426
+ if inds_keep.shape[0] > 0:
1427
+ # get parameters
1428
+
1429
+ params = params_in[branch_i][inds_keep]
1430
+
1431
+ # if leaf count is constant and leaf count is 1
1432
+ # just give 1D parameters
1433
+ if not self.has_reversible_jump and params.shape[0] == 1:
1434
+ params = params[0]
1435
+
1436
+ # add them to the specific args for this Likelihood
1437
+ arg_i[branch_i] = params
1438
+ if self.provide_supplemental:
1439
+ if supps is not None:
1440
+ # supps are specific to each group
1441
+ kwarg_i["supps"] = {
1442
+ key: supps_in[key][group_i] for key in supps_in
1443
+ }
1444
+ if branch_supps is not None:
1445
+ # make sure there is a dictionary ready in this kwarg dictionary
1446
+ if "branch_supps" not in kwarg_i:
1447
+ kwarg_i["branch_supps"] = {}
1448
+
1449
+ # fill these branch supplementals for the specific group
1450
+ if branch_supps_in[branch_name_i] is not None:
1451
+ # get list of branch_supps values
1452
+ kwarg_i["branch_supps"][branch_name_i] = (
1453
+ branch_supps_in[branch_name_i][inds_keep]
1454
+ )
1455
+ else:
1456
+ kwarg_i["branch_supps"][branch_name_i] = None
1457
+
1458
+ # if only one model type, will take out of groups
1459
+ add_term = arg_i[0] if len(groups_in) == 1 else arg_i
1460
+
1461
+ # based on how this is dealt with in the _FunctionWrapper
1462
+ # add_term is wrapped in a list
1463
+ args_in.append([[add_term], kwarg_i])
1464
+
1465
+ # If the `pool` property of the sampler has been set (i.e. we want
1466
+ # to use `multiprocessing`), use the `pool`'s map method.
1467
+ # Otherwise, just use the built-in `map` function.
1468
+ if self.pool is not None:
1469
+ map_func = self.pool.map
1470
+
1471
+ else:
1472
+ map_func = map
1473
+
1474
+ # get results and turn into an array
1475
+ results = np.asarray(list(map_func(self.log_like_fn, args_in)))
1476
+
1477
+ assert isinstance(results, np.ndarray)
1478
+
1479
+ # -1e300 because -np.inf screws up state acceptance transfer in proposals
1480
+ ll = np.full(nwalkers_all, -1e300)
1481
+ inds_fix_zeros = np.delete(np.arange(nwalkers_all), unique_groups)
1482
+
1483
+ # make sure second dimension is not 1
1484
+ if results.ndim == 2 and results.shape[1] == 1:
1485
+ results = np.squeeze(results)
1486
+
1487
+ # parse the results if it has blobs
1488
+ if results.ndim == 2:
1489
+ # get the results and put into groups that were analyzed
1490
+ ll[unique_groups] = results[:, 0]
1491
+
1492
+ # fix groups that were not analyzed
1493
+ ll[inds_fix_zeros] = self.fill_zero_leaves_val
1494
+
1495
+ # deal with blobs
1496
+ blobs_out = np.zeros((nwalkers_all, results.shape[1] - 1))
1497
+ blobs_out[unique_groups] = results[:, 1:]
1498
+
1499
+ elif results.dtype == "object":
1500
+ # TODO: check blobs and add this capability
1501
+ raise NotImplementedError
1502
+
1503
+ else:
1504
+ # no blobs
1505
+ ll[unique_groups] = results
1506
+ ll[inds_fix_zeros] = self.fill_zero_leaves_val
1507
+
1508
+ blobs_out = None
1509
+
1510
+ if False:  # disabled for now; originally gated on self.provide_supplemental (see TODO below)
1511
+ # TODO: need to think about how to return information, we may need to add a function to do that
1512
+ if branch_supps is not None:
1513
+ for name_i, name in enumerate(branch_supps):
1514
+ if branch_supps[name] is not None:
1515
+ # TODO: better way to do this? limit to
1516
+ if "inds_keep" in branch_supps[name]:
1517
+ inds_back = branch_supps[name][:]["inds_keep"]
1518
+ inds_back2 = branch_supps_in[name]["inds_keep"]
1519
+ else:
1520
+ inds_back = inds_copy[name]
1521
+ inds_back2 = slice(None)
1522
+ try:
1523
+ branch_supps[name][inds_back] = {
1524
+ key: branch_supps_in_2[name_i][key][inds_back2]
1525
+ for key in branch_supps_in_2[name_i]
1526
+ }
1527
+ except ValueError:
1528
+ breakpoint()
1529
+ branch_supps[name][inds_back] = {
1530
+ key: branch_supps_in_2[name_i][key][inds_back2]
1531
+ for key in branch_supps_in_2[name_i]
1532
+ }
1533
+
1534
+ # return Likelihood and blobs
1535
+ return ll.reshape(ntemps, nwalkers), blobs_out
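# Illustrative sketch (not from the eryn source): likelihood signatures that
# match the argument packing above. The function names are hypothetical, and
# the exact expected shapes and output ordering depend on the ``vectorize``
# and ``provide_groups`` settings described in the class docstring.
import numpy as np

def log_like_single(x):
    # non-vectorized, single-branch, single-leaf case: ``x`` has shape (ndim,)
    return -0.5 * np.sum(x**2)

def log_like_vectorized(params, groups):
    # vectorized, single-branch case with groups provided:
    # ``params`` is 2D (all leaves, ndim); ``groups`` labels which walker slot
    # each leaf belongs to. Return one value per unique group label (assumed
    # here to be expected in sorted-label order).
    return np.array(
        [-0.5 * np.sum(params[groups == g] ** 2) for g in np.unique(groups)]
    )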
1536
+
1537
+ @property
1538
+ def acceptance_fraction(self):
1539
+ """The fraction of proposed steps that were accepted"""
1540
+ return self.backend.accepted / float(self.backend.iteration)
1541
+
1542
+ @property
1543
+ def rj_acceptance_fraction(self):
1544
+ """The fraction of proposed reversible jump steps that were accepted"""
1545
+ if self.has_reversible_jump:
1546
+ return self.backend.rj_accepted / float(self.backend.iteration)
1547
+ else:
1548
+ return None
1549
+
1550
+ @property
1551
+ def swap_acceptance_fraction(self):
1552
+ """The fraction of proposed temperature swaps that were accepted"""
1553
+ return self.backend.swaps_accepted / float(
1554
+ self.backend.iteration * self.nwalkers
1555
+ )
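# Hedged usage sketch: the properties above are simple ratios of backend
# counters to the iteration count, so they can be checked during or after a
# run. ``sampler`` is assumed to be an already-constructed EnsembleSampler;
# the per-temperature, per-walker shape of the counters is an assumption here.
af = sampler.acceptance_fraction
print("mean acceptance fraction:", af.mean())
if sampler.rj_acceptance_fraction is not None:
    print("mean RJ acceptance fraction:", sampler.rj_acceptance_fraction.mean())
print("mean swap acceptance fraction:", sampler.swap_acceptance_fraction.mean())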
1556
+
1557
+ def get_chain(self, **kwargs):
1558
+ return self.get_value("chain", **kwargs)
1559
+
1560
+ get_chain.__doc__ = Backend.get_chain.__doc__
1561
+
1562
+ def get_blobs(self, **kwargs):
1563
+ return self.get_value("blobs", **kwargs)
1564
+
1565
+ get_blobs.__doc__ = Backend.get_blobs.__doc__
1566
+
1567
+ def get_log_like(self, **kwargs):
1568
+ return self.backend.get_log_like(**kwargs)
1569
+
1570
+ get_log_like.__doc__ = Backend.get_log_like.__doc__
1571
+
1572
+ def get_log_prior(self, **kwargs):
1573
+ return self.backend.get_log_prior(**kwargs)
1574
+
1575
+ get_log_prior.__doc__ = Backend.get_log_prior.__doc__
1576
+
1577
+ def get_log_posterior(self, **kwargs):
1578
+ return self.backend.get_log_posterior(**kwargs)
1579
+
1580
+ get_log_posterior.__doc__ = Backend.get_log_posterior.__doc__
1581
+
1582
+ def get_inds(self, **kwargs):
1583
+ return self.get_value("inds", **kwargs)
1584
+
1585
+ get_inds.__doc__ = Backend.get_inds.__doc__
1586
+
1587
+ def get_nleaves(self, **kwargs):
1588
+ return self.backend.get_nleaves(**kwargs)
1589
+
1590
+ get_nleaves.__doc__ = Backend.get_nleaves.__doc__
1591
+
1592
+ def get_last_sample(self, **kwargs):
1593
+ return self.backend.get_last_sample()
1594
+
1595
+ get_last_sample.__doc__ = Backend.get_last_sample.__doc__
1596
+
1597
+ def get_betas(self, **kwargs):
1598
+ return self.backend.get_betas(**kwargs)
1599
+
1600
+ get_betas.__doc__ = Backend.get_betas.__doc__
1601
+
1602
+ def get_value(self, name, **kwargs):
1603
+ """Get a specific value"""
1604
+ return self.backend.get_value(name, **kwargs)
1605
+
1606
+ def get_autocorr_time(self, **kwargs):
1607
+ """Compute autocorrelation time through backend."""
1608
+ return self.backend.get_autocorr_time(**kwargs)
1609
+
1610
+ get_autocorr_time.__doc__ = Backend.get_autocorr_time.__doc__
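# Hedged usage sketch of the backend getters forwarded above. ``sampler`` is
# an assumed, already-run EnsembleSampler; "model_0" is a hypothetical branch
# name, and the (nsteps, ntemps, nwalkers, nleaves_max, ndim) layout is the
# assumed chain convention. Check the backend documentation for exact shapes.
chain = sampler.get_chain(discard=100, thin=5)    # dict of arrays, one per branch
logl = sampler.get_log_like(discard=100, thin=5)  # log-likelihood values
cold_samples = chain["model_0"][:, 0]             # keep only the cold (beta = 1) chain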
1611
+
1612
+
1613
+ class _FunctionWrapper(object):
1614
+ """
1615
+ This is a hack to make the likelihood function pickleable when ``args``
1616
+ or ``kwargs`` are also included.
1617
+
1618
+ """
1619
+
1620
+ def __init__(
1621
+ self,
1622
+ f,
1623
+ args,
1624
+ kwargs,
1625
+ ):
1626
+ self.f = f
1627
+ self.args = [] if args is None else args
1628
+ self.kwargs = {} if kwargs is None else kwargs
1629
+
1630
+ def __call__(self, args_and_kwargs):
1631
+ """
1632
+ Internal method that takes a tuple ``(args, kwargs)`` and passes it into the Likelihood call.
1633
+
1634
+ ``self.args`` and ``self.kwargs`` are added to these inputs.
1635
+
1636
+ """
1637
+
1638
+ args_in_add, kwargs_in_add = args_and_kwargs
1639
+
1640
+ try:
1641
+ args_in = args_in_add + type(args_in_add)(self.args)
1642
+ kwargs_in = {**kwargs_in_add, **self.kwargs}
1643
+
1644
+ out = self.f(*args_in, **kwargs_in)
1645
+ return out
1646
+
1647
+ except: # pragma: no cover
1648
+ import traceback
1649
+
1650
+ print("eryn: Exception while calling your likelihood function:")
1651
+ print(" args added:", args_in_add)
1652
+ print(" args:", self.args)
1653
+ print(" kwargs added:", kwargs_in_add)
1654
+ print(" kwargs:", self.kwargs)
1655
+ print(" exception:")
1656
+ traceback.print_exc()
1657
+ raise
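# Minimal sketch of how the wrapper above combines per-call inputs with the
# fixed ``args``/``kwargs`` supplied at construction. The toy likelihood and
# values are illustrative only, and the example relies on the
# ``_FunctionWrapper`` class defined above.
import numpy as np

def toy_log_like(x, scale):
    # ``x`` comes from the sampler call; ``scale`` from the wrapper's fixed args
    return -0.5 * np.sum((x / scale) ** 2)

wrapped = _FunctionWrapper(toy_log_like, args=[2.0], kwargs=None)

# the sampler passes a ([per-call args], {per-call kwargs}) tuple;
# this call is equivalent to toy_log_like(np.array([1.0, 2.0]), 2.0)
value = wrapped(([np.array([1.0, 2.0])], {}))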
1658
+
1659
+
1660
+ def walkers_independent(coords_in):
1661
+ """Determine if walkers are independent
1662
+
1663
+ Originally from ``emcee``.
1664
+
1665
+ Args:
1666
+ coords_in (np.ndarray[ntemps, nwalkers, nleaves_max, ndim]): Coordinates of the walkers.
1667
+
1668
+ Returns:
1669
+ bool: If walkers are independent.
1670
+
1671
+ """
1672
+ # make sure it is 4-dimensional and reshape
1673
+ # so each row is one walker (for every temperature) and the columns are the flattened leaf parameters
1674
+ assert coords_in.ndim == 4
1675
+ ntemps, nwalkers, nleaves_max, ndim = coords_in.shape
1676
+ coords = coords_in.reshape(ntemps * nwalkers, nleaves_max * ndim)
1677
+
1678
+ # make sure all coordinates are finite
1679
+ if not np.all(np.isfinite(coords)):
1680
+ return False
1681
+
1682
+ # roughly determine covariance information
1683
+ C = coords - np.mean(coords, axis=0)[None, :]
1684
+ C_colmax = np.amax(np.abs(C), axis=0)
1685
+ if np.any(C_colmax == 0):
1686
+ return False
1687
+ C /= C_colmax
1688
+ C_colsum = np.sqrt(np.sum(C**2, axis=0))
1689
+ C /= C_colsum
1690
+ return np.linalg.cond(C.astype(float)) <= 1e8
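# Quick sketch exercising the check above with the documented
# (ntemps, nwalkers, nleaves_max, ndim) shape; the values are illustrative.
import numpy as np

rng = np.random.default_rng(42)
coords = rng.normal(size=(2, 16, 1, 3))  # (ntemps, nwalkers, nleaves_max, ndim)
print(walkers_independent(coords))       # random draws: expected True

coords[..., 0] = 1.0                     # pin one parameter across all walkers
print(walkers_independent(coords))       # degenerate column: expected False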