eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/moves/distgenrj.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .rj import ReversibleJumpMove
|
|
6
|
+
from ..prior import ProbDistContainer
|
|
7
|
+
|
|
8
|
+
__all__ = ["DistributionGenerateRJ"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DistributionGenerateRJ(ReversibleJumpMove):
    """Generate Reversible-Jump proposals from a distribution.

    The prior can be entered as the ``generate_dist`` to generate proposals directly from the prior.

    Args:
        generate_dist (dict): Keys are branch names and values are :class:`ProbDistContainer` objects
            that have ``logpdf`` and ``rvs`` methods.
        *args (tuple, optional): Additional arguments to pass to parent classes.
        **kwargs (dict, optional): Keyword arguments passed to parent classes.

    Raises:
        ValueError: If any value of ``generate_dist`` is not a :class:`ProbDistContainer`.

    """

    def __init__(self, generate_dist, *args, **kwargs):

        # make sure all inputs are distribution Containers
        for dist in generate_dist.values():
            if not isinstance(dist, ProbDistContainer):
                raise ValueError(
                    "Distributions need to be eryn.prior.ProbDistContainer object."
                )

        self.generate_dist = generate_dist
        super(DistributionGenerateRJ, self).__init__(*args, **kwargs)

    def get_model_change_proposal(self, inds, random, nleaves_min, nleaves_max):
        """Helper function for changing the model count by 1.

        This helper function works with nested models where you want to add or remove
        one leaf at a time.

        Args:
            inds (np.ndarray): ``inds`` values for this specific branch with shape
                ``(ntemps, nwalkers, nleaves_max)``.
            random (object): Current random state of the sampler.
            nleaves_min (int): Minimum allowable leaf count for this branch.
            nleaves_max (int): Maximum allowable leaf count for this branch.

        Returns:
            dict: Keys are ``"+1"`` and ``"-1"``. Values are indexing information.
                ``"+1"`` and ``"-1"`` indicate if a source is being added or removed, respectively.
                The indexing information is a 2D array with shape ``(number changing, 3)``.
                The length 3 is the index into each of the ``(ntemps, nwalkers, nleaves_max)``.

        """

        ntemps, nwalkers, _ = inds.shape

        # number of leaves currently in use for every walker
        nleaves = inds.sum(axis=-1)

        # choose whether to add or remove
        if self.fix_change is None:
            change = random.choice([-1, +1], size=nleaves.shape)
        else:
            change = np.full(nleaves.shape, self.fix_change)

        # fix edge cases: a walker sitting at the minimum can only add and a
        # walker sitting at the maximum can only remove
        at_min = nleaves == nleaves_min
        at_max = nleaves == nleaves_max
        change = change * (~at_min & ~at_max) + (+1) * at_min + (-1) * at_max

        # The loop order (temperature-major, walker-minor) fixes the order in
        # which the random state is consumed, so it is kept as an explicit loop.
        births = []
        deaths = []
        for t in range(ntemps):
            for w in range(nwalkers):
                change_tw = change[t, w]
                inds_tw = inds[t, w]

                if change_tw == +1:
                    # pick one of the leaves that is not currently used
                    unused = np.flatnonzero(np.logical_not(inds_tw))
                    births.append((t, w, random.choice(unused)))

                elif change_tw == -1:
                    # pick one of the leaves that is currently used; its
                    # discarded coordinates simply stay in the state
                    used = np.flatnonzero(inds_tw)
                    deaths.append((t, w, random.choice(used)))

                # any other value: model component number not changing

        # reshape keeps the (0, 3) shape when nothing changes in a direction
        return {
            "+1": np.array(births, dtype=int).reshape(-1, 3),
            "-1": np.array(deaths, dtype=int).reshape(-1, 3),
        }

    def get_proposal(
        self, all_coords, all_inds, nleaves_min_all, nleaves_max_all, random, **kwargs
    ):
        """Make a proposal

        Args:
            all_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers.
            all_inds (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
                arrays marking which leaves are currently used within each walker.
            nleaves_min_all (dict): Minimum values of leaf count for each model, keyed by branch name.
            nleaves_max_all (dict): Maximum values of leaf count for each model, keyed by branch name.
            random (object): Current random state of the sampler.
            **kwargs (ignored): For modularity.

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                ``double`` np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
                proposed coordinates. Second entry is the new ``inds`` array with
                boolean values flipped for added or removed sources. Third entry
                is the factors associated with the
                proposal necessary for detailed balance. This is effectively
                any term in the detailed balance fraction. +log of factors if
                in the numerator. -log of factors if in the denominator.

        Raises:
            ValueError: If a branch has ``nleaves_min`` greater than ``nleaves_max``.

        """
        # prepare the output dictionaries
        q = {}
        new_inds = {}
        all_inds_for_change = {}

        assert len(nleaves_min_all)
        assert len(all_coords.keys()) == len(nleaves_max_all.keys())

        # decide, per branch, which leaves are born and which die
        for name, inds in all_inds.items():
            nleaves_max = nleaves_max_all[name]
            nleaves_min = nleaves_min_all[name]
            if nleaves_min == nleaves_max:
                # fixed leaf count: nothing to propose for this branch
                continue
            elif nleaves_min > nleaves_max:
                raise ValueError("nleaves_min is greater than nleaves_max. Not allowed.")

            # get the inds adjustment information
            all_inds_for_change[name] = self.get_model_change_proposal(
                inds, random, nleaves_min, nleaves_max
            )

        # loop through branches and propose new points from the distribution
        for i, (name, coords) in enumerate(all_coords.items()):
            # look the inds up by branch name so the result does not depend on
            # ``all_coords`` and ``all_inds`` sharing the same key order
            inds = all_inds[name]

            # put in base information
            ntemps, nwalkers, nleaves_max, ndim = coords.shape
            new_inds[name] = inds.copy()
            q[name] = coords.copy()

            if i == 0:
                factors = np.zeros((ntemps, nwalkers))

            # if not included
            if name not in all_inds_for_change:
                continue

            # inds changing for this branch
            inds_for_change = all_inds_for_change[name]
            current_generate_dist = self.generate_dist[name]

            # adjust deaths from True -> False
            death_index = tuple(inds_for_change["-1"].T)
            new_inds[name][death_index] = False

            # factor is +log q()
            factors[death_index[:2]] += +1 * current_generate_dist.logpdf(
                q[name][death_index]
            )

            # adjust births from False -> True
            birth_index = tuple(inds_for_change["+1"].T)
            new_inds[name][birth_index] = True

            # add coordinates for new leaves
            num_births = len(birth_index[0])
            q[name][birth_index] = current_generate_dist.rvs(size=num_births)

            # factor is -log q()
            factors[birth_index[:2]] += -1 * current_generate_dist.logpdf(
                q[name][birth_index]
            )

        return q, new_inds, factors
|
eryn/moves/gaussian.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from .mh import MHMove
|
|
6
|
+
|
|
7
|
+
__all__ = ["GaussianMove"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GaussianMove(MHMove):
    """A Metropolis step with a Gaussian proposal function.

    This class is heavily based on the same class in ``emcee``.

    Args:
        cov_all (dict): The covariance of the proposal function. The keys are branch names and the
            values are covariance information. This information can be provided as a scalar,
            vector, or matrix and the proposal will be assumed isotropic,
            axis-aligned, or general, respectively.
        mode (str, optional): Select the method used for updating parameters. This
            can be one of ``"vector"``, ``"random"``, or ``"sequential"``. The
            ``"vector"`` mode updates all dimensions simultaneously,
            ``"random"`` randomly selects a dimension and only updates that
            one, and ``"sequential"`` loops over dimensions and updates each
            one in turn. (default: ``"vector"``)
        factor (float, optional): If provided the proposal will be made with a
            standard deviation uniformly selected from the range
            ``exp(U(-log(factor), log(factor))) * cov``. This is invalid for
            the ``"vector"`` mode. (default: ``None``)
        **kwargs (dict, optional): Kwargs for parent classes. (default: ``{}``)

    Raises:
        ValueError: If the proposal dimensions are invalid or if any of
            the other arguments are inconsistent.

    """

    def __init__(self, cov_all, mode="vector", factor=None, **kwargs):
        self.all_proposal = {}

        for name, cov in cov_all.items():
            # A value that converts to ``float`` is treated as a scalar
            # variance; anything raising ``TypeError`` is treated as an array.
            try:
                float(cov)
            except TypeError:
                scalar_cov = False
            else:
                scalar_cov = True

            if scalar_cov:
                # This was a scalar proposal.
                proposal = _isotropic_proposal(np.sqrt(cov), factor, mode)
            else:
                cov = np.atleast_1d(cov)
                if cov.ndim == 1:
                    # A diagonal proposal was given.
                    proposal = _diagonal_proposal(np.sqrt(cov), factor, mode)
                elif cov.ndim == 2 and cov.shape[0] == cov.shape[1]:
                    # The full, square covariance matrix was given.
                    proposal = _proposal(cov, factor, mode)
                else:
                    raise ValueError("Invalid proposal scale dimensions")

            self.all_proposal[name] = proposal

        super(GaussianMove, self).__init__(**kwargs)

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Get proposal from Gaussian distribution

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers.
            random (object): Current random state object.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in this function.

        Returns:
            tuple: (Proposed coordinates, factors) -> (dict, np.ndarray)

        """

        # proposed coordinates, keyed by branch name
        q = {}
        for name, coords in branches_coords.items():
            ntemps, nwalkers, nleaves_max, ndim = coords.shape

            # without ``inds`` every leaf is treated as in use
            if branches_inds is not None:
                inds = branches_inds[name]
            else:
                inds = np.ones((ntemps, nwalkers, nleaves_max), dtype=bool)

            active = np.where(inds == True)

            # start from a copy so unused leaves keep their coordinates
            q[name] = coords.copy()

            # draw new points only for the leaves in use and write them back
            proposed, _ = self.all_proposal[name](coords[active], random)
            q[name][active] = proposed

        # handle periodic parameters
        if self.periodic is not None:
            # ``wrap`` is given (ntemps * nwalkers, nleaves_max, ndim) arrays
            flattened = {
                name: tmp.reshape((ntemps * nwalkers,) + tmp.shape[-2:])
                for name, tmp in q.items()
            }
            wrapped = self.periodic.wrap(flattened, xp=self.xp)

            # restore the leading (ntemps, nwalkers) axes
            q = {
                name: tmp.reshape((ntemps, nwalkers) + tmp.shape[-2:])
                for name, tmp in wrapped.items()
            }

        # the factors returned by this proposal are always zero
        return q, np.zeros((ntemps, nwalkers))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class _isotropic_proposal(object):
    """Callable that perturbs points by ``scale`` times standard normal noise.

    Args:
        scale (float or np.ndarray): Step scale. Multiplied elementwise with the
            random draws in :meth:`get_updated_vector`.
        factor (float or None): If not ``None``, every call rescales the step by
            ``exp(U(-log(factor), log(factor)))``. Must be ``>= 1.0``.
        mode (str): One of ``allowed_modes``.

    Raises:
        ValueError: If ``factor < 1.0`` or ``mode`` is not in ``allowed_modes``.

    """

    allowed_modes = ["vector", "random", "sequential"]

    def __init__(self, scale, factor, mode):
        # dimension updated next in "sequential" mode
        self.index = 0
        self.scale = scale

        # A Cholesky factor only exists for matrices. Scalar and vector scales
        # are passed in by GaussianMove for isotropic/diagonal proposals, and
        # calling ``np.linalg.cholesky`` on them raises ``LinAlgError``; for
        # those the elementwise reciprocal is stored instead.
        if np.ndim(scale) >= 2:
            self.invscale = np.linalg.inv(np.linalg.cholesky(scale))
        else:
            self.invscale = 1.0 / np.asarray(scale)

        if factor is None:
            self._log_factor = None
        else:
            if factor < 1.0:
                raise ValueError("'factor' must be >= 1.0")
            self._log_factor = np.log(factor)

        if mode not in self.allowed_modes:
            raise ValueError(
                ("'{0}' is not a recognized mode. " "Please select from: {1}").format(
                    mode, self.allowed_modes
                )
            )
        self.mode = mode

    def get_factor(self, rng):
        """Return the step multiplier: ``1.0`` or a log-uniform random draw."""
        if self._log_factor is None:
            return 1.0
        return np.exp(rng.uniform(-self._log_factor, self._log_factor))

    def get_updated_vector(self, rng, x0):
        """Return ``x0`` displaced in every dimension by scaled normal noise."""
        return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))

    def __call__(self, x0, rng):
        """Propose new points.

        Args:
            x0 (np.ndarray): Current points with shape ``(npoints, ndim)``.
            rng (object): Random state providing ``randn``, ``randint`` and ``uniform``.

        Returns:
            tuple: (proposed points, array of ``npoints`` zeros).

        """
        nw, nd = x0.shape
        xnew = self.get_updated_vector(rng, x0)

        if self.mode == "random":
            # one randomly chosen dimension per point
            m = (np.arange(nw), rng.randint(x0.shape[-1], size=nw))
        elif self.mode == "sequential":
            # the same dimension for every point, cycling from call to call
            m = (np.arange(nw), self.index % nd + np.zeros(nw, dtype=int))
            self.index = (self.index + 1) % nd
        else:
            # "vector": every dimension is updated
            return xnew, np.zeros(nw)

        # keep the old point except in the selected dimension
        x = np.array(x0)
        x[m] = xnew[m]
        return x, np.zeros(nw)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class _diagonal_proposal(_isotropic_proposal):
    """Proposal whose ``scale`` is applied elementwise to the normal draws."""

    def get_updated_vector(self, rng, x0):
        """Return ``x0`` plus ``factor * scale * noise``."""
        # the factor is drawn before the noise, so the random state is
        # consumed in that order
        step_factor = self.get_factor(rng)
        noise = rng.randn(*x0.shape)
        return x0 + step_factor * self.scale * noise
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class _proposal(_isotropic_proposal):
    """Proposal drawing steps from a zero-mean multivariate normal.

    ``self.scale`` is passed to ``multivariate_normal`` as the covariance.
    Only the ``"vector"`` mode is allowed.
    """

    allowed_modes = ["vector"]

    def get_updated_vector(self, rng, x0):
        """Return ``x0`` plus one multivariate normal step per point."""
        # the factor is drawn before the steps, so the random state is
        # consumed in that order
        step_factor = self.get_factor(rng)
        zero_mean = np.zeros(len(self.scale))
        steps = rng.multivariate_normal(zero_mean, self.scale, size=len(x0))
        return x0 + step_factor * steps
|
eryn/moves/group.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
import numpy as np
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
from ..state import BranchSupplemental, State
|
|
8
|
+
from .move import Move
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
__all__ = ["GroupMove"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GroupMove(Move, ABC):
    """
    A "group" ensemble move based on the :class:`eryn.moves.RedBlueMove`.

    In moves like the :class:`eryn.moves.StretchMove`, the complementary
    group for which the proposal is used is chosen from the current points in
    the ensemble. In "group" moves the complementary group is a stationary group
    that is updated every `n_iter_update` iterations. This update is performed with the
    last set of coordinates to maintain detailed balance.

    Args:
        nfriends (int, optional): The number of friends to draw from as the complementary
            ensemble. This group is determined from the stationary group. If ``None``, it will
            be set to the number of walkers. (default: ``None``)
        n_iter_update (int, optional): Number of iterations to run before updating the
            stationary distribution. (default: 100).
        live_dangerously (bool, optional): If ``True``, allow for ``n_iter_update == 1``.
            (default: ``False``)

    ``kwargs`` are passed to :class:`Move` class.

    Raises:
        ValueError: If ``n_iter_update <= 1`` and ``live_dangerously`` is ``False``.

    """

    def __init__(
        self, nfriends=None, n_iter_update=100, live_dangerously=False, **kwargs
    ):

        Move.__init__(self, **kwargs)

        # ``None`` is the documented default and is resolved to the number of
        # walkers in ``propose``; ``int(None)`` would raise ``TypeError``.
        self.nfriends = None if nfriends is None else int(nfriends)
        self.n_iter_update = n_iter_update

        if self.n_iter_update <= 1 and not live_dangerously:
            raise ValueError("n_iter_update must be greather than or equal to 2.")

        # number of completed ``propose`` calls
        self.iter = 0

    def find_friends(self, name, s, s_inds=None, branch_supps=None):
        """Function for finding friends.

        Args:
            name (str): Branch name for proposal coordinates.
            s (np.ndarray): Coordinates array for the points to be moved.
            s_inds (np.ndarray, optional): ``inds`` arrays that represent which leaves are present.
                (default: ``None``)
            branch_supps (dict, optional): Keys are ``branch_names`` and values are
                :class:`BranchSupplemental` objects. For group proposals,
                ``branch_supps`` are the best device for passing and tracking useful
                information. (default: ``None``)

        Return:
            np.ndarray: Complementary values.

        Raises:
            NotImplementedError: Always; subclasses must implement this.

        """
        raise NotImplementedError

    def choose_c_vals(self, name, s, s_inds=None, branch_supps=None):
        """Get the complementary values by delegating to :meth:`find_friends`."""
        return self.find_friends(name, s, s_inds=s_inds, branch_supps=branch_supps)

    def setup(self, branches):
        """Any setup necessary for the proposal"""
        pass

    def setup_friends(self, branches):
        """Setup anything for finding friends.

        Args:
            branches (dict): Dictionary with all the current branches in the sampler.

        Raises:
            NotImplementedError: Always; subclasses must implement this.

        """
        raise NotImplementedError

    def fix_friends(self, branches):
        """Fix any friends that were born through RJ.

        This function is not required. If not implemented, it will just return immediately.

        Args:
            branches (dict): Dictionary with all the current branches in the sampler.

        """
        return

    # NOTE(review): declared as a classmethod although the first parameter is
    # named ``self``; subclasses presumably override it as an instance method
    # -- confirm before changing.
    @classmethod
    def get_proposal(self, s_all, random, gibbs_ndim=None, s_inds_all=None, **kwargs):
        """Generate group stretch proposal coordinates

        Args:
            s_all (dict): Keys are ``branch_names`` and values are coordinates
                for which a proposal is to be generated.
            random (object): Random state object.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this indicates
                the true dimension. If given as an array, must have shape ``(ntemps, nwalkers)``.
                See the tutorial for more information.
                (default: ``None``)
            s_inds_all (dict, optional): Keys are ``branch_names`` and values are
                ``inds`` arrays indicating which leaves are currently used. (default: ``None``)

        Returns:
            tuple: First entry is new positions. Second entry is detailed balance factors.

        Raises:
            NotImplementedError: Always in this base class.

        """

        raise NotImplementedError("The proposal must be implemented by " "subclasses")

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: (state, accepted)
                The first return is the state of the sampler after the move.
                The second return value is the accepted count array.

        """

        # Read ``ntemps`` and ``nwalkers`` from the branches; the unpacking
        # also requires every branch to be 4-dimensional.
        for branch in state.branches.values():
            ntemps, nwalkers, nleaves_, ndim_ = branch.shape

        if self.nfriends is None:
            self.nfriends = nwalkers

        # Run any move-specific setup.
        self.setup(state.branches)

        # NOTE(review): on update iterations the friends are set up here from
        # the current branches and again at the end of this method from
        # ``old_branches`` -- confirm that this first call is intended for
        # ``self.iter != 0``.
        if self.iter == 0 or self.iter % self.n_iter_update == 0:
            self.setup_friends(state.branches)

        if self.iter != 0 and self.iter % self.n_iter_update == 0:
            # store old values to maintain detailed balance when updating
            old_branches = deepcopy(state.branches)

        # fix any friends that may have come through rj
        if self.iter != 0 and self.iter % self.n_iter_update != 0:
            self.fix_friends(state.branches)

        # default result in case no Gibbs split produces a proposal
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        all_branch_names = list(state.branches.keys())

        # get gibbs sampling information
        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):

            # work on copies of the supplemental information so rejected
            # proposals cannot alter what is stored in ``state``
            if not np.all(
                np.asarray(list(state.branches_supplemental.values())) == None
            ):
                new_branch_supps = deepcopy(state.branches_supplemental)
            else:
                new_branch_supps = None

            if state.supplemental is not None:
                new_supps = deepcopy(state.supplemental)
            else:
                new_supps = None

            # setup proposals based on Gibbs sampling
            (
                coords_going_for_proposal,
                inds_going_for_proposal,
                at_least_one_proposal,
            ) = self.setup_proposals(
                branch_names_run, inds_run, state.branches_coords, state.branches_inds
            )

            if not at_least_one_proposal:
                continue

            # need to trick stretch proposal into using the dimensionality associated
            # with Gibbs sampling if it is being used
            gibbs_ndim = 0
            for brn, ir in zip(branch_names_run, inds_run):
                if ir is not None:
                    gibbs_ndim += ir.sum()
                else:
                    # nleaves * ndim
                    gibbs_ndim += np.prod(state.branches[brn].shape[-2:])

            # make the model and state reachable from ``get_proposal``
            self.current_model = model
            self.current_state = state

            # Get the move-specific proposal.
            q, factors = self.get_proposal(
                coords_going_for_proposal,
                model.random,
                gibbs_ndim=gibbs_ndim,
                s_inds_all=inds_going_for_proposal,
                branch_supps=new_branch_supps,
            )

            # account for gibbs sampling
            self.cleanup_proposals_gibbs(
                branch_names_run, inds_run, q, state.branches_coords
            )

            # order everything properly
            q, _, new_branch_supps = self.ensure_ordering(
                list(state.branches.keys()), q, state.branches_inds, new_branch_supps
            )

            # Compute prior of the proposed position
            logp = model.compute_log_prior_fn(q, inds=state.branches_inds)

            self.fix_logp_gibbs(branch_names_run, inds_run, logp, state.branches_inds)

            # Can adjust supplementals in place
            logl, new_blobs = model.compute_log_like_fn(
                q,
                inds=state.branches_inds,
                logp=logp,
                supps=new_supps,
                branch_supps=new_branch_supps,
            )

            # get log posterior
            logP = self.compute_log_posterior(logl, logp)

            # get previous information
            prev_logl = state.log_like

            prev_logp = state.log_prior

            # takes care of tempering
            prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

            # difference
            lnpdiff = factors + logP - prev_logP

            # draw against acceptance fraction
            accepted = lnpdiff > np.log(model.random.rand(ntemps, nwalkers))

            # Update the parameters
            new_state = State(
                q,
                log_like=logl,
                log_prior=logp,
                blobs=new_blobs,
                inds=state.branches_inds,
                supplemental=new_supps,
                branch_supplemental=new_branch_supps,
            )
            state = self.update(state, new_state, accepted)

            # add to move-specific accepted information
            self.accepted += accepted
            self.num_proposals += 1

        if self.temperature_control is not None:
            state = self.temperature_control.temper_comps(state)

        if self.iter != 0 and self.iter % self.n_iter_update == 0:
            # use old values to maintain detailed balance when updating
            # nfriends
            self.setup_friends(old_branches)

        self.iter += 1
        return state, accepted
|