eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/git_version.py.in
ADDED
eryn/model.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from collections import namedtuple
|
|
4
|
+
|
|
5
|
+
__all__ = ["Model"]
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Immutable record passed to every move as ``model``. The moves read the
# prior/likelihood evaluators and the random generator from it by field name.
Model = namedtuple(
    "Model",
    [
        "log_like_fn",
        "compute_log_like_fn",
        "compute_log_prior_fn",
        "temperature_control",
        "map_fn",
        "random",
    ],
)
|
eryn/moves/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
# from .de import DEMove
|
|
4
|
+
# from .de_snooker import DESnookerMove
|
|
5
|
+
from .gaussian import GaussianMove
|
|
6
|
+
|
|
7
|
+
# from .kde import KDEMove
|
|
8
|
+
from .mh import MHMove
|
|
9
|
+
from .move import Move
|
|
10
|
+
from .red_blue import RedBlueMove
|
|
11
|
+
from .stretch import StretchMove
|
|
12
|
+
|
|
13
|
+
# from .walk import WalkMove
|
|
14
|
+
from .tempering import TemperatureControl
|
|
15
|
+
from .rj import ReversibleJumpMove
|
|
16
|
+
from .distgenrj import DistributionGenerateRJ
|
|
17
|
+
from .distgen import DistributionGenerate
|
|
18
|
+
from .multipletry import MultipleTryMove
|
|
19
|
+
from .group import GroupMove
|
|
20
|
+
from .groupstretch import GroupStretchMove
|
|
21
|
+
from .combine import CombineMove
|
|
22
|
+
|
|
23
|
+
# from .basicmodelswaprj import BasicSymmetricModelSwapRJMove
|
|
24
|
+
from .mtdistgen import MTDistGenMove
|
|
25
|
+
from .mtdistgenrj import MTDistGenMoveRJ
|
|
26
|
+
from .multipletry import MultipleTryMove
|
|
27
|
+
|
|
28
|
+
# Public API of the ``moves`` package. Every move class imported into this
# namespace is listed, including the multiple-try distribution moves that
# were imported but previously missing from the export list.
__all__ = [
    "Move",
    "MHMove",
    "GaussianMove",
    "RedBlueMove",
    "StretchMove",
    "DistributionGenerateRJ",
    "DistributionGenerate",
    "TemperatureControl",
    "ReversibleJumpMove",
    "MultipleTryMove",
    "GroupMove",
    "GroupStretchMove",
    "CombineMove",
    "MTDistGenMove",
    "MTDistGenMoveRJ",
]
|
eryn/moves/combine.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from ..state import BranchSupplemental
|
|
4
|
+
from . import Move
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tqdm
|
|
7
|
+
|
|
8
|
+
__all__ = ["CombineMove"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CombineMove(Move):
    """Move that combines specific moves in order.

    Args:
        moves (list): List of moves, similar to how ``moves`` is submitted
            to :class:`eryn.ensemble.EnsembleSampler`. If weights are provided,
            they will be ignored.
        *args (tuple, optional): args to be passed to :class:`Move`.
        verbose (bool, optional): If ``True``, use ``tqdm`` to show progress through steps.
            This can be very helpful when debugging.
        **kwargs (dict, optional): kwargs to be passed to :class:`Move`.

    """

    def __init__(self, moves, *args, verbose=False, **kwargs):
        # store moves first: the property setters of this class iterate over
        # ``self.moves`` (presumably they are triggered by ``Move.__init__``)
        self.moves = moves
        self.verbose = verbose
        Move.__init__(self, *args, **kwargs)

    def _combined_moves(self):
        """Yield each stored move, dropping the weight of ``(move, weight)`` entries."""
        for move in self.moves:
            yield move[0] if isinstance(move, tuple) else move

    @property
    def accepted(self):
        """Accepted counts for each move (list with one entry per move)."""
        # NOTE(review): ``_accepted`` is not assigned anywhere in this class;
        # it is expected to be provided by the base ``Move`` -- confirm.
        if self._accepted is None:
            raise ValueError(
                "accepted must be inititalized with the init_accepted function if you want to use it."
            )
        # this retrieves the accepted arrays from the individual moves
        return [move.accepted for move in self._combined_moves()]

    @accepted.setter
    def accepted(self, accepted):
        # set the accepted arrays for all moves; every move receives its own
        # copy so that the per-move counters stay independent
        assert isinstance(accepted, np.ndarray)
        for move in self._combined_moves():
            move.accepted = accepted.copy()

    @property
    def acceptance_fraction(self):
        """get acceptance fraction averaged over all moves"""
        return np.mean(
            [move.acceptance_fraction for move in self._combined_moves()], axis=0
        )

    @property
    def acceptance_fraction_separate(self):
        """get acceptance fraction from each move"""
        return [move.acceptance_fraction for move in self._combined_moves()]

    @property
    def temperature_control(self):
        """temperature controller"""
        return self._temperature_control

    @temperature_control.setter
    def temperature_control(self, temperature_control):
        # when setting the temperature control object
        # need to apply it to each move (weights, if given, are skipped)
        for move in self._combined_moves():
            move.temperature_control = temperature_control

        # main temperature control here for reference
        self._temperature_control = temperature_control

    @property
    def periodic(self):
        """periodic parameter information"""
        return self._periodic

    @periodic.setter
    def periodic(self, periodic):
        # when setting the periodic parameters
        # need to apply it to each move (weights, if given, are skipped)
        for move in self._combined_moves():
            move.periodic = periodic
        self._periodic = periodic

    def propose(self, model, state):
        """Propose a combined move.

        Each stored move is applied in order; the state returned by one move
        is the input of the next one.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: (state, accepted)
                The first return is the state of the sampler after the move.
                The second return value is the accepted count array for each walker
                counting for all proposals. It is ``None`` if ``moves`` is empty.

        """

        # prepare for verbosity if needed
        iterator = self._combined_moves()
        if self.verbose:
            iterator = tqdm.tqdm(iterator)

        # we will set this inside the loop during the first iteration
        accepted_out = None
        for move in iterator:
            # run move
            state, accepted = move.propose(model, state)

            # set (first iteration) or add (after first iteration) to accepted_out
            if accepted_out is None:
                accepted_out = accepted.copy()
            else:
                accepted_out += accepted

        return state, accepted_out
|
|
eryn/moves/delayedrejection.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
|
|
6
|
+
from ..state import State
|
|
7
|
+
from .move import Move
|
|
8
|
+
from ..state import BranchSupplemental
|
|
9
|
+
|
|
10
|
+
# Export the move defined in this module. The previous entry ("MHMove") is
# neither defined nor imported here, which made ``import *`` fail.
__all__ = ["DelayedRejection"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DelayedRejectionContainer:
    """Plain record of the quantities collected along a delayed-rejection run.

    Any keyword arguments are attached as attributes. Afterwards the four
    history attributes ``coords``, ``log_prob``, ``log_prior`` and ``alpha``
    are set to fresh empty lists (replacing same-named keyword arguments).
    """

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

        # every history starts out empty
        self.coords, self.log_prob, self.log_prior, self.alpha = [], [], [], []

    def append(self, new_coords, new_log_prob, new_log_prior, new_alpha):
        """Push one entry onto each of the four histories."""
        histories = (self.coords, self.log_prob, self.log_prior, self.alpha)
        entries = (new_coords, new_log_prob, new_log_prior, new_alpha)
        for history, entry in zip(histories, entries):
            history.append(entry)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DelayedRejection(Move):
    r"""
    Delayed Rejection scheme assuming symmetric and non-adaptive proposal distribution.
    We apply the DR algorithm only on the cases where we have rejected a +1 proposal for
    a given Reversible Jump proposal and branch.

    Args:
        proposal (object): Object with a
            ``get_proposal(branches_coords, random, branches_inds=...)`` method
            that is used to draw every delayed-rejection candidate.
        max_iter (int, optional): Bound on the delayed-rejection loop counter
            used in :meth:`propose`. (default: ``10``)
        **kwargs (dict, optional): kwargs to be passed to :class:`Move`.

    References:

        Tierney L and Mira A, Stat. Med. 18 2507 (1999)
        Haario et al, Stat. Comput. 16:339-354 (2006)
        Mira A, Metron - International Journal of Statistics, vol. LIX, issue 3-4, 231-241 (2001)
        M. Trias, et al, https://arxiv.org/abs/0904.2207

    """

    def __init__(self, proposal, max_iter=10, **kwargs):
        self.proposal = proposal
        self.max_iter = max_iter
        self.dr_container = None
        super(DelayedRejection, self).__init__(**kwargs)

    def dr_scheme(
        self,
        state,
        new_state,
        keep_rejected,
        model,
        ntemps,
        nwalkers,
        inds_for_change,
        inds=None,
        dr_iter=0,
    ):
        """Run one delayed-rejection stage and calculate its acceptance ratio.

        Args:
            state (:class:`State`): Current state of the sampler.
            new_state (:class:`State`): Previously proposed state. Its
                supplemental must provide ``"past_alpha"``.
            keep_rejected (np.ndarray): Boolean mask. Entries that are ``False``
                get a log prior of ``-inf`` for the new candidate.
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            ntemps (int): First dimension of the uniform random draw.
            nwalkers (int): Second dimension of the uniform random draw.
            inds_for_change (dict): Not used in this method.
            inds (optional): Not used in this method.
            dr_iter (int, optional): Not used in this method.

        Returns:
            tuple: (state, new_accepted, new_state)
                ``state`` is the sampler state after accepted candidates were
                merged in, ``new_accepted`` is the boolean acceptance array of
                this stage and ``new_state`` is the candidate state.
        """
        # Draw a uniform random for the previously rejected points
        randU = model.random.rand(
            ntemps, nwalkers
        )  # We draw for all temps x walkers but we ignore
        # previously accepted points by setting prior[rej] = - inf

        # keep a copy of the previous candidate: ``new_state`` is rebound below
        old_new_state = State(new_state, copy=True)

        # Propose a new point
        new_state, log_proposal_ratio = self.get_new_state(
            model, new_state, keep_rejected
        )  # Get a new state

        # Log-likelihood and log-prior of the new candidate
        logp = new_state.log_prior
        logl = new_state.log_like

        # Compute the logposterior for all
        logP = self.compute_log_posterior(logl, logp)

        # Log-likelihood and log-prior of the previous candidate
        prev_logp = old_new_state.log_prior
        prev_logl = old_new_state.log_like

        # Compute the logposterior for all
        prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

        # Compute the acceptance ratio
        lndiff = logP - prev_logP + log_proposal_ratio
        alpha_1 = np.exp(lndiff)
        alpha_1[alpha_1 > 1.0] = 1.0  # np.min((1, alpha))

        # update delayed rejection alpha
        dr_alpha = np.exp(
            lndiff
            + np.log(1.0 - alpha_1)
            - np.log(1.0 - old_new_state.supplemental[:]["past_alpha"])
        )
        dr_alpha[dr_alpha > 1.0] = 1.0  # np.min((1., dr_alpha ))
        dr_alpha = np.nan_to_num(dr_alpha)  # Automatically reject NaNs

        new_state.supplemental[:] = {"alpha": dr_alpha}  # Replace current dr alpha

        new_accepted = np.logical_or(
            dr_alpha >= 1.0, randU < dr_alpha
        )  # Decide on accepted points

        # Update state with the new accepted points
        state = self.update(state, new_state, new_accepted)

        return state, new_accepted, new_state

    def get_new_state(self, model, state, keep):
        """A utility function to propose new points

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): State the candidate is generated from.
            keep (np.ndarray): Boolean mask. The log prior is set to ``-inf``
                wherever it is ``False``.

        Returns:
            tuple: (new_state, factors)
                ``new_state`` is a :class:`State` built from the proposed
                coordinates, ``factors`` is the second output of the
                proposal's ``get_proposal``.
        """
        # ``get_proposal`` takes the random generator as second positional
        # argument and the leaf indices as the ``branches_inds`` keyword
        qn, factors = self.proposal.get_proposal(
            state.branches_coords, model.random, branches_inds=state.branches_inds
        )

        # Compute prior of the proposed position
        logp = model.compute_log_prior_fn(qn, inds=state.branches_inds)
        logp[~keep] = -np.inf  # This trick helps us compute only the indices of interest

        # Compute the lnprobs of the proposed position.
        logl, new_blobs = model.compute_log_like_fn(
            qn, inds=state.branches_inds, logp=logp
        )

        # Update the parameters, update the state. TODO: Fix blobs?
        new_state = State(
            qn,
            log_like=logl,
            log_prior=logp,
            blobs=new_blobs,
            inds=state.branches_inds,
            supplemental=state.supplemental,
        )  # I create a new initial state that all are accepted
        return new_state, factors

    def propose(
        self,
        log_diff_0,
        accepted,
        model,
        state,
        new_state,
        inds,
        inds_for_change,
        factors,
    ):
        """Use the move to generate a proposal and compute the acceptance

        Args:
            log_diff_0 (np.ndarray): Log acceptance ratio of the initial
                proposal. Its exponential, capped at 1, is stored as
                ``"past_alpha"``.
            accepted (np.ndarray): Boolean array marking the walkers whose
                initial proposal was accepted. It is updated in place.
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.
            new_state (:class:`State`): State proposed by the initial move.
            inds: Forwarded to :meth:`dr_scheme`.
            inds_for_change (dict): Keys are model names. ``inds_for_change[name]["+1"]``
                holds the indices where the Reversible Jump move proposed the
                "birth" of a model; its first two columns index ``accepted``.
                We operate a Delayed Rejection type of move only on those cases.
            factors: Not used in this method.

        Returns:
            tuple: (state, accepted)
                State of sampler after proposal is complete and the updated
                acceptance array.

        """

        # Get the base shape from the first branch
        ntemps, nwalkers, _, _ = state.branches[list(state.branches.keys())[0]].shape

        alpha_0 = np.exp(log_diff_0)
        alpha_0[alpha_0 > 1.0] = 1.0  # np.min((1.0, alpha_0))
        new_state.supplemental = BranchSupplemental(
            {"past_alpha": alpha_0}, base_shape=(ntemps, nwalkers)
        )

        dr_iter = 0  # Initialize

        # Begin main DR loop. Stop when we exceed the maximum iterations, or (extreme case) all proposals are accepted
        # NOTE(review): with ``<=`` the body can run ``max_iter + 1`` times -- confirm this is intended.
        while dr_iter <= self.max_iter and not np.all(accepted):
            rejected = ~accepted  # Get rejected points

            # Get the +1 proposals that got previously rejected
            plus_one_rej_inds = {}
            for name in inds_for_change:
                plus_one_inds = inds_for_change[name]["+1"][:, :2]
                plus_one_rej_inds[name] = plus_one_inds[
                    rejected[(plus_one_inds[:, 0], plus_one_inds[:, 1])]
                ]

            # Generate the indices of the proposals that got rejected
            keep_rejected = np.unique(
                np.concatenate(list(plus_one_rej_inds.values())), axis=0
            )
            run_dr = np.zeros_like(rejected)
            run_dr[tuple(keep_rejected.T)] = True

            # Pass into the DR scheme
            state, new_accepted, new_state = self.dr_scheme(
                state,
                new_state,
                run_dr,
                model,
                ntemps,
                nwalkers,
                inds_for_change,
                inds=inds,
            )

            # Update the accepted (in place), increment current iteration
            accepted += new_accepted
            dr_iter += 1

        if self.temperature_control is not None:
            state, accepted = self.temperature_control.temper_comps(state, accepted)

        return state, accepted
|
eryn/moves/distgen.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from .mh import MHMove
|
|
5
|
+
from ..prior import ProbDistContainer
|
|
6
|
+
|
|
7
|
+
__all__ = ["DistributionGenerate"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DistributionGenerate(MHMove):
    """Generate proposals from a distribution

    Args:
        generate_dist (dict): Dictionary with keys as branch names and items as
            :class:`ProbDistContainer` objects that have ``logpdf``
            and ``rvs`` methods.
        *args (tuple, optional): args to be passed to :class:`MHMove`.
        **kwargs (dict, optional): kwargs to be passed to :class:`MHMove`.

    Raises:
        ValueError: If ``generate_dist`` is not a dictionary or one of its
            items is not a :class:`ProbDistContainer`.

    """

    def __init__(self, generate_dist, *args, **kwargs):
        if not isinstance(generate_dist, dict):
            raise ValueError(
                "When entering directly into the DistributionGenerate class, generate_dist must be a dictionary. The keys are branch names and the items are ProbDistContainer objects."
            )

        for key in generate_dist:
            if not isinstance(generate_dist[key], ProbDistContainer):
                raise ValueError(
                    "Distributions need to be eryn.prior.ProbDistContainer object."
                )
        self.generate_dist = generate_dist
        super(DistributionGenerate, self).__init__(*args, **kwargs)

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Make a proposal

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers.
            random (object): Current random state object. It is not used by
                this method; points are drawn with ``rvs`` of the distributions.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                boolean np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. If ``None``, all leaves are
                treated as used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in this function.

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                ``double`` np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
                proposed coordinates. Second entry
                is the factors associated with the
                proposal necessary for detailed balance. This is effectively
                any term in the detailed balance fraction. +log of factors if
                in the numerator. -log of factors if in the denominator.
                It is an np.ndarray[ntemps, nwalkers] (an empty ``dict`` if
                ``branches_coords`` is empty).

        """

        # set up all initial holders
        q = {}
        factors = {}
        if branches_inds is None:
            # no index information given: every leaf is proposed
            branches_inds = {
                name: np.ones(coords.shape[:-1], dtype=bool)
                for name, coords in branches_coords.items()
            }

        # iterate through branches and propose new points where inds == True
        for i, (name, coords) in enumerate(branches_coords.items()):
            # pair coordinates and indices by branch name, not by dict order
            inds = branches_inds[name]

            # copy over previous info
            ntemps, nwalkers, _, _ = coords.shape
            q[name] = coords.copy()

            if i == 0:
                factors = np.zeros((ntemps, nwalkers))

            # locate the leaves that get a new point
            current_generate_dist = self.generate_dist[name]
            inds_here = np.where(inds)
            num_inds_change = len(inds_here[0])

            # (temp, walker) pairs; repeated when a walker has several leaves
            walker_inds = inds_here[:2]

            old_points = coords[inds_here]

            # old points so + log(qold). ``np.add.at`` accumulates over
            # repeated (temp, walker) pairs, which ``factors[...] += ...``
            # would silently collapse to a single contribution
            np.add.at(
                factors, walker_inds, +1 * current_generate_dist.logpdf(old_points)
            )

            # Draw
            new_points = current_generate_dist.rvs(size=num_inds_change)

            # new point, so -log(qnew)
            np.add.at(
                factors, walker_inds, -1 * current_generate_dist.logpdf(new_points)
            )

            q[name][inds_here] = new_points

        return q, factors
|