eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
try:
|
|
3
|
+
import cupy as xp
|
|
4
|
+
except (ModuleNotFoundError, ImportError):
|
|
5
|
+
pass
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from .group import GroupMove
|
|
10
|
+
from .stretch import StretchMove
|
|
11
|
+
|
|
12
|
+
__all__ = ["GroupStretchMove"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GroupStretchMove(GroupMove, StretchMove):
    """Proposal like stretch with stationary complement.

    This move uses the stretch proposal method and math, but the complement
    of walkers used to propose a new point is chosen from a stationary group
    rather than the current walkers in the ensemble.

    This move allows for "stretch"-like proposal to be used in Reversible Jump MCMC.

    Args:
        **kwargs (dict, optional): Keyword arguments passed to :class:`GroupMove` and
            :class:`StretchMove`.

    """

    def __init__(self, **kwargs):
        # Both parents are initialised explicitly with the same keyword
        # arguments; the call order (group first, stretch second) is kept.
        GroupMove.__init__(self, **kwargs)
        StretchMove.__init__(self, **kwargs)

    def get_proposal(
        self,
        s_all,
        random,
        gibbs_ndim=None,
        s_inds_all=None,
        branch_supps=None,
        **kwargs
    ):
        """Generate group stretch proposal coordinates

        Args:
            s_all (dict): Keys are ``branch_names`` and values are coordinates
                for which a proposal is to be generated.
            random (object): Random state object.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this indicates
                the true dimension. If given as an array, must have shape ``(ntemps, nwalkers)``.
                See the tutorial for more information.
                (default: ``None``)
            s_inds_all (dict, optional): Keys are ``branch_names`` and values are
                ``inds`` arrays indicating which leaves are currently used. (default: ``None``)
            branch_supps (dict, optional): Keys are ``branch_names`` and values are
                :class:`BranchSupplemental` objects. For the group stretch,
                ``branch_supps`` are the best device for passing and tracking useful
                information. (default: ``None``)
            **kwargs (ignored): Accepted for compatibility; not used here.

        Returns:
            tuple: First entry is new positions. Second entry is detailed balance factors.

        Raises:
            ValueError: Issues with dimensionality: no branches were supplied,
                or the branches disagree on the number of walkers.

        """
        # Reset the stretch factor; it is read after the loop, so it is
        # expected to be filled in by ``get_new_points`` (defined outside
        # this class) while the branches are processed.
        self.zz = None

        # Without at least one branch there is no dimension to build the
        # detailed balance factor from.
        if not s_all:
            raise ValueError("No branches were provided for the proposal.")

        random_number_generator = random if not self.use_gpu else self.xp.random
        newpos = {}

        # total dimensionality accumulated over every branch
        ndim = 0
        # walker count of the first branch; all other branches must match it
        Ns_check = None

        # iterate over branches
        for i, name in enumerate(s_all):
            # get points to move
            s = self.xp.asarray(s_all[name])

            if s_inds_all is not None:
                s_inds = self.xp.asarray(s_inds_all[name])
            else:
                s_inds = None

            ntemps, nwalkers, nleaves_max, ndim_here = s.shape

            # each branch contributes (max leaves) x (dimension per leaf)
            ndim += nleaves_max * ndim_here

            if i == 0:
                Ns_check = nwalkers
            elif Ns_check != nwalkers:
                raise ValueError("Different number of walkers across models.")

            # get actual complement values
            c_temp = self.choose_c_vals(
                name, s, s_inds=s_inds, branch_supps=branch_supps
            )

            # use stretch to get new proposals
            newpos[name] = self.get_new_points(
                name, s, c_temp, nwalkers, s.shape, i, random_number_generator
            )

        # proper factors
        factors = (ndim - 1.0) * self.xp.log(self.zz)
        if self.use_gpu and not self.return_gpu:
            # move the factors off the GPU array type
            factors = factors.get()

        if gibbs_ndim is not None:
            # adjust factors in place
            self.adjust_factors(factors, ndim, gibbs_ndim)

        return newpos, factors
|
eryn/moves/mh.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from ..state import State
|
|
6
|
+
from .move import Move
|
|
7
|
+
|
|
8
|
+
__all__ = ["MHMove"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MHMove(Move):
    r"""A general Metropolis-Hastings proposal

    Concrete implementations can be made by providing a ``get_proposal`` method.
    For standard Gaussian Metropolis moves, :class:`moves.GaussianMove` can be used.

    """

    def __init__(self, **kwargs):
        Move.__init__(self, **kwargs)

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Get proposal

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers.
            random (object): Current random state object.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in this function.

        Returns:
            tuple: (Proposed coordinates, factors) -> (dict, np.ndarray)

        Raises:
            NotImplementedError: If proposal is not implemented in a subclass.

        """
        raise NotImplementedError("The proposal must be implemented by subclasses")

    def setup(self, branches_coords):
        """Any setup for the proposal.

        The base implementation does nothing; ``propose`` calls it once at
        the start of every proposal.

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers.

        """

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: ``(state, accepted)``. ``state`` is the :class:`State` of the
            sampler after the proposal is complete. ``accepted`` is the boolean
            np.ndarray[ntemps, nwalkers] from the last Gibbs iteration that
            produced a proposal (all ``False`` if none did).

        """
        self.setup(state.branches_coords)

        # get all branch names for gibbs setup
        all_branch_names = list(state.branches.keys())

        # get initial shape information
        ntemps, nwalkers, _, _ = state.branches[all_branch_names[0]].shape

        # in case there are no leaves yet
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        # iterate through gibbs setup
        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):
            # setup supplemental information: copy the per-branch supplementals
            # only when at least one branch actually carries one. An identity
            # check is used so the supplemental objects are never coerced
            # into a NumPy array.
            has_branch_supps = any(
                supp is not None for supp in state.branches_supplemental.values()
            )
            if has_branch_supps:
                new_branch_supps = deepcopy(state.branches_supplemental)
            else:
                new_branch_supps = None

            if state.supplemental is not None:
                new_supps = deepcopy(state.supplemental)
            else:
                new_supps = None

            # setup information according to gibbs info
            (
                coords_going_for_proposal,
                inds_going_for_proposal,
                at_least_one_proposal,
            ) = self.setup_proposals(
                branch_names_run, inds_run, state.branches_coords, state.branches_inds
            )

            # if no walkers are actually being proposed
            if not at_least_one_proposal:
                continue

            # expose the current model/state on the move for use by proposals
            self.current_model = model
            self.current_state = state

            # Get the move-specific proposal.
            q, factors = self.get_proposal(
                coords_going_for_proposal,
                model.random,
                branches_inds=inds_going_for_proposal,
                supps=new_supps,
                branch_supps=new_branch_supps,
            )

            # account for gibbs sampling
            self.cleanup_proposals_gibbs(
                branch_names_run, inds_run, q, state.branches_coords
            )

            # order everything properly
            q, _, new_branch_supps = self.ensure_ordering(
                list(state.branches.keys()), q, state.branches_inds, new_branch_supps
            )

            # if wrapping with multiple try, both values are already stored
            # on the move; otherwise take the normal route and compute them
            if hasattr(self, "mt_ll") and hasattr(self, "mt_lp"):
                # already computed in multiple try
                logl = self.mt_ll
                logp = self.mt_lp
                new_blobs = None
            else:
                # Compute prior of the proposed position
                logp = model.compute_log_prior_fn(q, inds=state.branches_inds)

                self.fix_logp_gibbs(
                    branch_names_run, inds_run, logp, state.branches_inds
                )

                # Compute the lnprobs of the proposed position.
                # Can adjust supplementals in place
                logl, new_blobs = model.compute_log_like_fn(
                    q,
                    inds=state.branches_inds,
                    logp=logp,
                    supps=new_supps,
                    branch_supps=new_branch_supps,
                )

            # get log posterior
            logP = self.compute_log_posterior(logl, logp)

            # get previous information
            prev_logl = state.log_like
            prev_logp = state.log_prior

            # takes care of tempering
            prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

            # difference
            lnpdiff = factors + logP - prev_logP

            # draw against acceptance fraction
            accepted = lnpdiff > np.log(model.random.rand(ntemps, nwalkers))

            # Update the parameters
            new_state = State(
                q,
                log_like=logl,
                log_prior=logp,
                blobs=new_blobs,
                inds=state.branches_inds,
                supplemental=new_supps,
                branch_supplemental=new_branch_supps,
            )
            state = self.update(state, new_state, accepted)

            # add to move-specific accepted information
            self.accepted += accepted
            self.num_proposals += 1

        # temperature swaps
        if self.temperature_control is not None:
            state = self.temperature_control.temper_comps(state)

        return state, accepted
|