eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+ try:
3
+ import cupy as xp
4
+ except (ModuleNotFoundError, ImportError):
5
+ pass
6
+
7
+ import numpy as np
8
+
9
+ from .group import GroupMove
10
+ from .stretch import StretchMove
11
+
12
+ __all__ = ["GroupStretchMove"]
13
+
14
+
15
class GroupStretchMove(GroupMove, StretchMove):
    """Stretch-like proposal whose complement comes from a stationary group.

    This move uses the stretch proposal method and math, but the complement
    of walkers used to propose a new point is chosen from a stationary group
    rather than the current walkers in the ensemble.

    This move allows a "stretch"-like proposal to be used in Reversible Jump MCMC.

    Args:
        **kwargs (dict, optional): Keyword arguments passed to :class:`GroupMove` and
            :class:`StretchMove`.

    """

    def __init__(self, **kwargs):
        # Both parent classes are initialized explicitly with the same kwargs.
        GroupMove.__init__(self, **kwargs)
        StretchMove.__init__(self, **kwargs)

    def get_proposal(
        self,
        s_all,
        random,
        gibbs_ndim=None,
        s_inds_all=None,
        branch_supps=None,
        **kwargs
    ):
        """Generate group stretch proposal coordinates.

        Args:
            s_all (dict): Keys are ``branch_names`` and values are coordinates
                for which a proposal is to be generated. Each value is unpacked
                as ``(ntemps, nwalkers, nleaves_max, ndim)``.
            random (object): Random state object. Ignored when ``self.use_gpu``
                is set, in which case ``self.xp.random`` is used instead.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this indicates
                the true dimension. If given as an array, must have shape ``(ntemps, nwalkers)``.
                See the tutorial for more information.
                (default: ``None``)
            s_inds_all (dict, optional): Keys are ``branch_names`` and values are
                ``inds`` arrays indicating which leaves are currently used. (default: ``None``)
            branch_supps (dict, optional): Keys are ``branch_names`` and values are
                :class:`BranchSupplemental` objects. For the group stretch,
                ``branch_supps`` are the best device for passing and tracking useful
                information. (default: ``None``)
            **kwargs (ignored): Accepted for interface compatibility.

        Returns:
            tuple: First entry is a dict of new positions keyed by branch name.
                Second entry is the detailed balance factors.

        Raises:
            ValueError: If ``s_all`` contains no branches, or if the number of
                walkers differs between branches.

        """
        # Reset before proposing; ``get_new_points`` is responsible for filling
        # ``self.zz`` so a stale value from a previous call is never reused.
        self.zz = None

        if not s_all:
            # With no branches there is no total dimension and no stretch
            # factor, so the detailed-balance factors cannot be formed.
            raise ValueError("s_all must contain at least one branch.")

        random_number_generator = random if not self.use_gpu else self.xp.random
        newpos = {}

        # total dimension summed over every branch (nleaves_max * ndim each)
        ndim = 0
        nwalkers_check = None

        for i, name in enumerate(s_all):
            # points to move for this branch
            s = self.xp.asarray(s_all[name])

            if s_inds_all is not None:
                s_inds = self.xp.asarray(s_inds_all[name])
            else:
                s_inds = None

            ntemps, nwalkers, nleaves_max, ndim_here = s.shape

            ndim += nleaves_max * ndim_here

            # every branch must share the same walker count
            if nwalkers_check is None:
                nwalkers_check = nwalkers
            elif nwalkers != nwalkers_check:
                raise ValueError("Different number of walkers across models.")

            # complement values drawn from the stationary group
            c_temp = self.choose_c_vals(
                name, s, s_inds=s_inds, branch_supps=branch_supps
            )

            # use the stretch machinery to build the new proposals
            newpos[name] = self.get_new_points(
                name, s, c_temp, nwalkers, s.shape, i, random_number_generator
            )

        # stretch-move detailed balance factor: (d - 1) * log(z)
        factors = (ndim - 1.0) * self.xp.log(self.zz)
        if self.use_gpu and not self.return_gpu:
            factors = factors.get()

        if gibbs_ndim is not None:
            # adjust factors in place
            self.adjust_factors(factors, ndim, gibbs_ndim)

        return newpos, factors
eryn/moves/mh.py ADDED
@@ -0,0 +1,193 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ from copy import deepcopy
5
+ from ..state import State
6
+ from .move import Move
7
+
8
+ __all__ = ["MHMove"]
9
+
10
+
11
class MHMove(Move):
    r"""A general Metropolis-Hastings proposal.

    Concrete implementations can be made by providing a ``get_proposal`` method.
    For standard Gaussian Metropolis moves, :class:`moves.GaussianMove` can be used.

    """

    def __init__(self, **kwargs):
        Move.__init__(self, **kwargs)

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Get proposal.

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers.
            random (object): Current random state object.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in this function.

        Returns:
            tuple: (Proposed coordinates, factors) -> (dict, np.ndarray)

        Raises:
            NotImplementedError: If proposal is not implemented in a subclass.

        """
        raise NotImplementedError("The proposal must be implemented by subclasses")

    def setup(self, branches_coords):
        """Any setup for the proposal. No-op here; subclasses may override.

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers.

        """

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: ``(state, accepted)`` where ``state`` is the :class:`State`
                after the proposal is complete, and ``accepted`` is the boolean
                array of shape ``(ntemps, nwalkers)`` from the last Gibbs step
                that proposed anything (all ``False`` if none did).

        """
        self.setup(state.branches_coords)

        # get all branch names for gibbs setup
        all_branch_names = list(state.branches.keys())

        # get initial shape information
        ntemps, nwalkers, _, _ = state.branches[all_branch_names[0]].shape

        # in case there are no leaves yet
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        # iterate through gibbs setup
        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):
            # Copy branch supplementals only if at least one branch has one.
            # Identity checks avoid coercing arbitrary supplemental objects
            # into a numpy array just to compare them against None.
            if any(
                supp is not None for supp in state.branches_supplemental.values()
            ):
                new_branch_supps = deepcopy(state.branches_supplemental)
            else:
                new_branch_supps = None

            if state.supplemental is not None:
                new_supps = deepcopy(state.supplemental)
            else:
                new_supps = None

            # setup information according to gibbs info
            (
                coords_going_for_proposal,
                inds_going_for_proposal,
                at_least_one_proposal,
            ) = self.setup_proposals(
                branch_names_run, inds_run, state.branches_coords, state.branches_inds
            )

            # if no walkers are actually being proposed
            if not at_least_one_proposal:
                continue

            self.current_model = model
            self.current_state = state

            # Get the move-specific proposal.
            q, factors = self.get_proposal(
                coords_going_for_proposal,
                model.random,
                branches_inds=inds_going_for_proposal,
                supps=new_supps,
                branch_supps=new_branch_supps,
            )

            # account for gibbs sampling
            self.cleanup_proposals_gibbs(
                branch_names_run, inds_run, q, state.branches_coords
            )

            # order everything properly
            q, _, new_branch_supps = self.ensure_ordering(
                list(state.branches.keys()), q, state.branches_inds, new_branch_supps
            )

            # if not wrapping with multiple try (normal route)
            if not hasattr(self, "mt_ll") or not hasattr(self, "mt_lp"):
                # Compute prior of the proposed position
                logp = model.compute_log_prior_fn(q, inds=state.branches_inds)

                self.fix_logp_gibbs(
                    branch_names_run, inds_run, logp, state.branches_inds
                )

                # Compute the lnprobs of the proposed position.
                # Can adjust supplementals in place
                logl, new_blobs = model.compute_log_like_fn(
                    q,
                    inds=state.branches_inds,
                    logp=logp,
                    supps=new_supps,
                    branch_supps=new_branch_supps,
                )

            else:
                # if already computed in multiple try
                logl = self.mt_ll
                logp = self.mt_lp
                new_blobs = None

            # get log posterior
            logP = self.compute_log_posterior(logl, logp)

            # get previous information
            prev_logl = state.log_like
            prev_logp = state.log_prior

            # takes care of tempering
            prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

            # difference
            lnpdiff = factors + logP - prev_logP

            # draw against acceptance fraction
            accepted = lnpdiff > np.log(model.random.rand(ntemps, nwalkers))

            # Update the parameters
            new_state = State(
                q,
                log_like=logl,
                log_prior=logp,
                blobs=new_blobs,
                inds=state.branches_inds,
                supplemental=new_supps,
                branch_supplemental=new_branch_supps,
            )
            state = self.update(state, new_state, accepted)

            # add to move-specific accepted information
            self.accepted += accepted
            self.num_proposals += 1

        # temperature swaps
        if self.temperature_control is not None:
            state = self.temperature_control.temper_comps(state)

        return state, accepted