eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,222 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+
5
+ from .rj import ReversibleJumpMove
6
+ from ..prior import ProbDistContainer
7
+
8
+ __all__ = ["DistributionGenerateRJ"]
9
+
10
+
11
class DistributionGenerateRJ(ReversibleJumpMove):
    """Generate Reversible-Jump proposals from a distribution.

    Births draw the new leaf's coordinates from ``generate_dist`` and deaths
    score the removed leaf against it. The prior can be entered as the
    ``generate_dist`` to generate proposals directly from the prior.

    Args:
        generate_dist (dict): Keys are branch names and values are
            :class:`ProbDistContainer` objects that have ``logpdf`` and ``rvs``
            methods.
        *args (tuple, optional): Additional arguments to pass to parent classes.
        **kwargs (dict, optional): Keyword arguments passed to parent classes.

    Raises:
        ValueError: If any value in ``generate_dist`` is not a
            :class:`ProbDistContainer`.

    """

    def __init__(self, generate_dist, *args, **kwargs):
        # make sure all inputs are distribution containers
        for dist in generate_dist.values():
            if not isinstance(dist, ProbDistContainer):
                raise ValueError(
                    "Distributions need to be eryn.prior.ProbDistContainer objects."
                )
        self.generate_dist = generate_dist
        super(DistributionGenerateRJ, self).__init__(*args, **kwargs)

    def get_model_change_proposal(self, inds, random, nleaves_min, nleaves_max):
        """Choose, per walker, whether to add or remove one leaf and which one.

        This helper function works with nested models where you want to add or
        remove one leaf at a time.

        Args:
            inds (np.ndarray): ``inds`` values for this specific branch with shape
                ``(ntemps, nwalkers, nleaves_max)``.
            random (object): Current random state of the sampler.
            nleaves_min (int): Minimum allowable leaf count for this branch.
            nleaves_max (int): Maximum allowable leaf count for this branch.

        Returns:
            dict: Keys are ``"+1"`` and ``"-1"``, indicating whether a leaf is
            being added or removed, respectively. Each value is an integer array
            with shape ``(number changing, 3)``; each row is an index into
            ``(ntemps, nwalkers, nleaves_max)``.

        """
        ntemps, nwalkers, _ = inds.shape

        nleaves = inds.sum(axis=-1)

        # choose whether to add or remove
        if self.fix_change is None:
            change = random.choice([-1, +1], size=nleaves.shape)
        else:
            change = np.full(nleaves.shape, self.fix_change)

        # Fix edge cases: a walker at (or below) the minimum must add a leaf and
        # one at (or above) the maximum must remove one. Using <= / >= rather
        # than == means a walker outside [nleaves_min, nleaves_max] is pushed
        # back toward the allowed range instead of trying to remove a leaf from
        # an empty set. When both limits apply (nleaves_min == nleaves_max) the
        # leaf count does not change.
        at_min = nleaves <= nleaves_min
        at_max = nleaves >= nleaves_max
        change = np.where(at_min, +1, change)
        change = np.where(at_max, -1, change)
        change = np.where(at_min & at_max, 0, change)

        # Pick the leaf slot for each changing walker. The loop order
        # (temperature, then walker) fixes the order of draws from ``random``.
        changes = {"+1": [], "-1": []}
        for t in range(ntemps):
            for w in range(nwalkers):
                change_tw = change[t, w]
                inds_tw = inds[t, w]

                if change_tw == +1:
                    # birth: choose among the currently unused leaf slots
                    ind_change = random.choice(np.flatnonzero(np.logical_not(inds_tw)))
                    changes["+1"].append((t, w, ind_change))
                elif change_tw == -1:
                    # death: choose among the currently used leaves
                    ind_change = random.choice(np.flatnonzero(inds_tw))
                    changes["-1"].append((t, w, ind_change))
                # change_tw == 0: leaf count for this walker is not changing

        # Discarded coordinates are left where they are; only ``inds`` marks
        # them as unused.
        return {
            key: np.asarray(rows, dtype=int).reshape(-1, 3)
            for key, rows in changes.items()
        }

    def get_proposal(
        self, all_coords, all_inds, nleaves_min_all, nleaves_max_all, random, **kwargs
    ):
        """Make a proposal.

        Args:
            all_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the
                current coordinates for all the walkers.
            all_inds (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
                arrays marking which leaves are currently used within each walker.
            nleaves_min_all (dict): Minimum leaf count for each branch, keyed by
                branch name. Must have the same number of entries as ``all_coords``.
            nleaves_max_all (dict): Maximum leaf count for each branch, keyed by
                branch name. Must have the same number of entries as ``all_coords``.
            random (object): Current random state of the sampler.
            **kwargs (ignored): For modularity.

        Returns:
            tuple: ``(q, new_inds, factors)``.

            - ``q`` maps branch names to copies of the coordinates, with newly
              born leaves filled from ``generate_dist``.
            - ``new_inds`` maps branch names to copies of ``inds``, with the
              added or removed leaves flipped.
            - ``factors`` has shape ``(ntemps, nwalkers)`` and holds the log
              proposal terms for detailed balance: ``+log q(x)`` for each
              removed leaf and ``-log q(x)`` for each added leaf.

        Raises:
            ValueError: If ``nleaves_min > nleaves_max`` for any branch.

        """
        # every branch needs both leaf-count limits
        assert len(all_coords) > 0
        assert len(nleaves_min_all) == len(all_coords)
        assert len(nleaves_max_all) == len(all_coords)

        # decide, per branch, which leaves are born or killed
        all_inds_for_change = {}
        for name, inds in all_inds.items():
            nleaves_max = nleaves_max_all[name]
            nleaves_min = nleaves_min_all[name]
            if nleaves_min == nleaves_max:
                # fixed leaf count: nothing to propose for this branch
                continue
            elif nleaves_min > nleaves_max:
                raise ValueError("nleaves_min is greater than nleaves_max. Not allowed.")

            all_inds_for_change[name] = self.get_model_change_proposal(
                inds, random, nleaves_min, nleaves_max
            )

        # loop through branches and propose new points from the generating dist
        q = {}
        new_inds = {}
        factors = None
        for name, coords in all_coords.items():
            ntemps, nwalkers, _, _ = coords.shape
            # look inds up by name so pairing does not depend on dict ordering
            new_inds[name] = all_inds[name].copy()
            q[name] = coords.copy()

            if factors is None:
                factors = np.zeros((ntemps, nwalkers))

            # branch not taking part in this move
            if name not in all_inds_for_change:
                continue

            inds_for_change = all_inds_for_change[name]
            current_generate_dist = self.generate_dist[name]

            # deaths: True -> False; factor is +log q(removed point)
            death_inds = tuple(inds_for_change["-1"].T)
            new_inds[name][death_inds] = False
            if len(death_inds[0]) > 0:
                # np.add.at accumulates correctly even if (t, w) pairs repeat
                np.add.at(
                    factors,
                    death_inds[:2],
                    current_generate_dist.logpdf(q[name][death_inds]),
                )

            # births: False -> True; draw the new leaves, factor is -log q(new point)
            birth_inds = tuple(inds_for_change["+1"].T)
            new_inds[name][birth_inds] = True
            num_births = len(birth_inds[0])
            if num_births > 0:
                q[name][birth_inds] = current_generate_dist.rvs(size=num_births)
                np.add.at(
                    factors,
                    birth_inds[:2],
                    -current_generate_dist.logpdf(q[name][birth_inds]),
                )

        return q, new_inds, factors
eryn/moves/gaussian.py ADDED
@@ -0,0 +1,190 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+
5
+ from .mh import MHMove
6
+
7
+ __all__ = ["GaussianMove"]
8
+
9
+
10
class GaussianMove(MHMove):
    """A Metropolis step with a Gaussian proposal function.

    This class is heavily based on the same class in ``emcee``.

    Args:
        cov_all (dict): The covariance of the proposal function. The keys are
            branch names and the values are covariance information. This
            information can be provided as a scalar, vector, or matrix and the
            proposal will be assumed isotropic, axis-aligned, or general,
            respectively.
        mode (str, optional): Select the method used for updating parameters. This
            can be one of ``"vector"``, ``"random"``, or ``"sequential"``. The
            ``"vector"`` mode updates all dimensions simultaneously,
            ``"random"`` randomly selects a dimension and only updates that
            one, and ``"sequential"`` loops over dimensions and updates each
            one in turn. (default: ``"vector"``)
        factor (float, optional): If provided the proposal will be made with a
            standard deviation uniformly selected from the range
            ``exp(U(-log(factor), log(factor))) * cov``. This is invalid for
            the ``"vector"`` mode. (default: ``None``)
        **kwargs (dict, optional): Kwargs for parent classes. (default: ``{}``)

    Raises:
        ValueError: If the proposal dimensions are invalid or if any of the
            other arguments are inconsistent.

    """

    def __init__(self, cov_all, mode="vector", factor=None, **kwargs):
        self.all_proposal = {}
        for name, cov in cov_all.items():
            # Parse the proposal type: scalar -> isotropic, 1-D -> axis-aligned,
            # square 2-D -> full covariance.
            try:
                float(cov)

            except TypeError:
                cov = np.atleast_1d(cov)
                if cov.ndim == 1:
                    # A diagonal proposal was given; pass standard deviations.
                    proposal = _diagonal_proposal(np.sqrt(cov), factor, mode)

                elif cov.ndim == 2 and cov.shape[0] == cov.shape[1]:
                    # The full, square covariance matrix was given.
                    proposal = _proposal(cov, factor, mode)

                else:
                    raise ValueError("Invalid proposal scale dimensions")

            else:
                # This was a scalar proposal; pass the standard deviation.
                proposal = _isotropic_proposal(np.sqrt(cov), factor, mode)

            self.all_proposal[name] = proposal

        super(GaussianMove, self).__init__(**kwargs)

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Get proposal from Gaussian distribution.

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers.
            random (object): Current random state object.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. If ``None``, all leaves are
                treated as used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in
                this function.

        Returns:
            tuple: (Proposed coordinates, factors) -> (dict, np.ndarray). The
            factors are zeros with shape ``(ntemps, nwalkers)``.

        """
        # initialize output
        q = {}
        for name, coords in branches_coords.items():
            ntemps, nwalkers, nleaves_max, _ = coords.shape

            # setup inds accordingly
            if branches_inds is None:
                inds = np.ones((ntemps, nwalkers, nleaves_max), dtype=bool)
            else:
                inds = branches_inds[name]

            # get the proposal for this branch
            proposal_fn = self.all_proposal[name]
            inds_here = np.where(inds == True)  # noqa: E712

            # only the used leaves are moved; the rest are copied unchanged
            q[name] = coords.copy()
            new_coords, _ = proposal_fn(coords[inds_here], random)
            q[name][inds_here] = new_coords

        # handle periodic parameters (wrap expects a flattened walker axis)
        if self.periodic is not None:
            q = self.periodic.wrap(
                {
                    key: tmp.reshape((ntemps * nwalkers,) + tmp.shape[-2:])
                    for key, tmp in q.items()
                },
                xp=self.xp,
            )

        # restore the (ntemps, nwalkers, nleaves_max, ndim) layout
        q = {
            key: tmp.reshape((ntemps, nwalkers) + tmp.shape[-2:])
            for key, tmp in q.items()
        }

        return q, np.zeros((ntemps, nwalkers))
132
+
133
+
134
class _isotropic_proposal(object):
    """Gaussian proposal sharing one scale across all dimensions.

    Also serves as the base class for :class:`_diagonal_proposal` and
    :class:`_proposal`, which override :meth:`get_updated_vector`.

    Args:
        scale (float or np.ndarray): Proposal scale. :class:`GaussianMove`
            passes the square root of the covariance for scalar and vector
            inputs, and the full covariance matrix for matrix inputs.
        factor (float or None): If given (must be >= 1), each call scales the
            jump by ``exp(U(-log(factor), log(factor)))``.
        mode (str): One of :attr:`allowed_modes`.

    Raises:
        ValueError: If ``factor < 1`` or ``mode`` is not an allowed mode.

    """

    allowed_modes = ["vector", "random", "sequential"]

    def __init__(self, scale, factor, mode):
        # next dimension to update in "sequential" mode
        self.index = 0
        self.scale = scale

        # Inverse of the proposal's "standard deviation". np.linalg.cholesky
        # requires at least 2-D input, so applying it unconditionally broke
        # scalar (isotropic) and 1-D (diagonal) proposals at construction time.
        scale_arr = np.asarray(scale)
        if scale_arr.ndim >= 2:
            self.invscale = np.linalg.inv(np.linalg.cholesky(scale_arr))
        else:
            self.invscale = 1.0 / scale_arr

        if factor is None:
            self._log_factor = None
        else:
            if factor < 1.0:
                raise ValueError("'factor' must be >= 1.0")
            self._log_factor = np.log(factor)

        if mode not in self.allowed_modes:
            raise ValueError(
                ("'{0}' is not a recognized mode. " "Please select from: {1}").format(
                    mode, self.allowed_modes
                )
            )
        self.mode = mode

    def get_factor(self, rng):
        """Return the random jump-size multiplier (1.0 when no factor is set)."""
        if self._log_factor is None:
            return 1.0
        return np.exp(rng.uniform(-self._log_factor, self._log_factor))

    def get_updated_vector(self, rng, x0):
        """Return ``x0`` displaced by scaled standard-normal noise."""
        return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))

    def __call__(self, x0, rng):
        """Propose new positions for the 2-D array ``x0`` (npoints, ndim).

        Returns:
            tuple: (proposed positions, zeros of length npoints).

        """
        nw, nd = x0.shape
        xnew = self.get_updated_vector(rng, x0)
        if self.mode == "random":
            # update one randomly chosen dimension per point
            m = (range(nw), rng.randint(x0.shape[-1], size=nw))
        elif self.mode == "sequential":
            # update the same dimension for all points, cycling each call
            m = (range(nw), self.index % nd + np.zeros(nw, dtype=int))
            self.index = (self.index + 1) % nd
        else:
            return xnew, np.zeros(nw)
        x = np.array(x0)
        x[m] = xnew[m]
        return x, np.zeros(nw)
177
+
178
+
179
class _diagonal_proposal(_isotropic_proposal):
    """Axis-aligned Gaussian proposal with one standard deviation per dimension."""

    def get_updated_vector(self, rng, x0):
        """Return ``x0`` displaced by per-dimension scaled standard-normal noise."""
        # draw the jump multiplier before the noise so the RNG stream order is kept
        jump = self.get_factor(rng)
        noise = rng.randn(*x0.shape)
        return x0 + jump * self.scale * noise
182
+
183
+
184
class _proposal(_isotropic_proposal):
    """Gaussian proposal with a full covariance matrix (``"vector"`` mode only)."""

    allowed_modes = ["vector"]

    def get_updated_vector(self, rng, x0):
        """Return ``x0`` displaced by correlated multivariate-normal offsets."""
        # draw the jump multiplier before the offsets so the RNG stream order is kept
        jump = self.get_factor(rng)
        mean = np.zeros(len(self.scale))
        offsets = rng.multivariate_normal(mean, self.scale, size=len(x0))
        return x0 + jump * offsets
eryn/moves/group.py ADDED
@@ -0,0 +1,281 @@
1
+ # -*- coding: utf-8 -*-
2
+ from abc import ABC
3
+ from copy import deepcopy
4
+ import numpy as np
5
+ import warnings
6
+
7
+ from ..state import BranchSupplemental, State
8
+ from .move import Move
9
+
10
+
11
+ __all__ = ["GroupMove"]
12
+
13
+
14
class GroupMove(Move, ABC):
    """
    A "group" ensemble move based on the :class:`eryn.moves.RedBlueMove`.

    In moves like the :class:`eryn.moves.StretchMove`, the complementary
    group for which the proposal is used is chosen from the current points in
    the ensemble. In "group" moves the complementary group is a stationary group
    that is updated every ``n_iter_update`` iterations. This update is performed
    with the last set of coordinates to maintain detailed balance.

    Subclasses are expected to implement :meth:`setup_friends`,
    :meth:`find_friends`, and :meth:`get_proposal`; the base versions raise
    ``NotImplementedError``.

    Args:
        nfriends (int, optional): The number of friends to draw from as the
            complementary ensemble. This group is determined from the stationary
            group. If ``None``, it will be set to the number of walkers the first
            time :meth:`propose` runs. (default: ``None``)
        n_iter_update (int, optional): Number of iterations to run before updating
            the stationary distribution. (default: 100).
        live_dangerously (bool, optional): If ``True``, allow
            ``n_iter_update <= 1``. (default: ``False``)

    ``kwargs`` are passed to :class:`Move` class.

    Raises:
        ValueError: If ``n_iter_update <= 1`` and ``live_dangerously`` is ``False``.

    """

    def __init__(
        self, nfriends=None, n_iter_update=100, live_dangerously=False, **kwargs
    ):

        Move.__init__(self, **kwargs)
        # ``None`` (the default) is resolved to nwalkers in ``propose``;
        # calling int(None) here used to raise TypeError.
        self.nfriends = None if nfriends is None else int(nfriends)
        self.n_iter_update = n_iter_update

        if self.n_iter_update <= 1 and not live_dangerously:
            raise ValueError("n_iter_update must be greater than or equal to 2.")

        # number of completed ``propose`` calls; drives the friend updates
        self.iter = 0

    def find_friends(self, name, s, s_inds=None, branch_supps=None):
        """Function for finding friends.

        Args:
            name (str): Branch name for proposal coordinates.
            s (np.ndarray): Coordinates array for the points to be moved.
            s_inds (np.ndarray, optional): ``inds`` arrays that represent which
                leaves are present. (default: ``None``)
            branch_supps (dict, optional): Keys are ``branch_names`` and values are
                :class:`BranchSupplemental` objects. For group proposals,
                ``branch_supps`` are the best device for passing and tracking
                useful information. (default: ``None``)

        Return:
            np.ndarray: Complementary values.

        Raises:
            NotImplementedError: Must be implemented by subclasses.

        """
        raise NotImplementedError

    def choose_c_vals(self, name, s, s_inds=None, branch_supps=None):
        """Get the complementary values by delegating to :meth:`find_friends`."""
        return self.find_friends(name, s, s_inds=s_inds, branch_supps=branch_supps)

    def setup(self, branches):
        """Any setup necessary for the proposal. Does nothing by default."""
        pass

    def setup_friends(self, branches):
        """Setup anything for finding friends.

        Args:
            branches (dict): Dictionary with all the current branches in the sampler.

        Raises:
            NotImplementedError: Must be implemented by subclasses.

        """
        raise NotImplementedError

    def fix_friends(self, branches):
        """Fix any friends that were born through RJ.

        This function is not required. If not implemented, it will just return
        immediately.

        Args:
            branches (dict): Dictionary with all the current branches in the sampler.

        """
        return

    def get_proposal(self, s_all, random, gibbs_ndim=None, s_inds_all=None, **kwargs):
        """Generate group proposal coordinates.

        Must be implemented by subclasses. (The previous ``@classmethod``
        decorator contradicted the ``self`` signature that :meth:`propose`
        relies on when it calls ``self.get_proposal``.)

        Args:
            s_all (dict): Keys are ``branch_names`` and values are coordinates
                for which a proposal is to be generated.
            random (object): Random state object.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this
                indicates the true dimension. If given as an array, must have shape
                ``(ntemps, nwalkers)``. See the tutorial for more information.
                (default: ``None``)
            s_inds_all (dict, optional): Keys are ``branch_names`` and values are
                ``inds`` arrays indicating which leaves are currently used.
                (default: ``None``)
            **kwargs: :meth:`propose` also passes ``branch_supps`` here.

        Returns:
            tuple: First entry is new positions. Second entry is detailed balance
            factors.

        Raises:
            NotImplementedError: Must be implemented by subclasses.

        """

        raise NotImplementedError("The proposal must be implemented by " "subclasses")

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: (state, accepted)
                The first return is the state of the sampler after the move.
                The second return value is the accepted array of shape
                ``(ntemps, nwalkers)`` from the last Gibbs block that ran.

        """
        # All branches share (ntemps, nwalkers); read them from the branch shapes.
        for branch in state.branches.values():
            ntemps, nwalkers, _, _ = branch.shape

        if self.nfriends is None:
            self.nfriends = nwalkers

        # Run any move-specific setup.
        self.setup(state.branches)

        is_first_iter = self.iter == 0
        # iterations (other than the first) on which the stationary group is refreshed
        is_update_iter = not is_first_iter and self.iter % self.n_iter_update == 0

        # NOTE(review): on update iterations this refreshes the friends from the
        # *current* coordinates before proposing, and the same coordinates are
        # applied again after the move via ``old_branches``. If the intent is for
        # this step to still use the previous friends, this call may only be
        # needed on the first iteration — confirm.
        if is_first_iter or is_update_iter:
            self.setup_friends(state.branches)

        if is_update_iter:
            # store old values to maintain detailed balance when updating
            old_branches = deepcopy(state.branches)
        elif not is_first_iter:
            # fix any friends that may have come through rj
            self.fix_friends(state.branches)

        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        all_branch_names = list(state.branches.keys())

        # get gibbs sampling information
        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):

            # copy supplemental information so the likelihood can adjust it in place
            if any(
                supp is not None for supp in state.branches_supplemental.values()
            ):
                new_branch_supps = deepcopy(state.branches_supplemental)
            else:
                new_branch_supps = None

            if state.supplemental is not None:
                new_supps = deepcopy(state.supplemental)
            else:
                new_supps = None

            # setup proposals based on Gibbs sampling
            (
                coords_going_for_proposal,
                inds_going_for_proposal,
                at_least_one_proposal,
            ) = self.setup_proposals(
                branch_names_run, inds_run, state.branches_coords, state.branches_inds
            )

            if not at_least_one_proposal:
                continue

            # need to trick stretch proposal into using the dimensionality associated
            # with Gibbs sampling if it is being used
            gibbs_ndim = 0
            for brn, ir in zip(branch_names_run, inds_run):
                if ir is not None:
                    gibbs_ndim += ir.sum()
                else:
                    # nleaves * ndim
                    gibbs_ndim += np.prod(state.branches[brn].shape[-2:])

            self.current_model = model
            self.current_state = state
            # Get the move-specific proposal.
            q, factors = self.get_proposal(
                coords_going_for_proposal,
                model.random,
                gibbs_ndim=gibbs_ndim,
                s_inds_all=inds_going_for_proposal,
                branch_supps=new_branch_supps,
            )

            # account for gibbs sampling
            self.cleanup_proposals_gibbs(
                branch_names_run, inds_run, q, state.branches_coords
            )

            # order everything properly
            q, _, new_branch_supps = self.ensure_ordering(
                list(state.branches.keys()), q, state.branches_inds, new_branch_supps
            )

            # Compute prior of the proposed position
            # new_inds_prior is adjusted if product-space is used
            logp = model.compute_log_prior_fn(q, inds=state.branches_inds)

            self.fix_logp_gibbs(branch_names_run, inds_run, logp, state.branches_inds)

            # Can adjust supplementals in place
            logl, new_blobs = model.compute_log_like_fn(
                q,
                inds=state.branches_inds,
                logp=logp,
                supps=new_supps,
                branch_supps=new_branch_supps,
            )

            # get log posterior
            logP = self.compute_log_posterior(logl, logp)

            # get previous information
            prev_logl = state.log_like

            prev_logp = state.log_prior

            # takes care of tempering
            prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

            # difference
            lnpdiff = factors + logP - prev_logP

            # draw against acceptance fraction
            accepted = lnpdiff > np.log(model.random.rand(ntemps, nwalkers))

            # Update the parameters
            new_state = State(
                q,
                log_like=logl,
                log_prior=logp,
                blobs=new_blobs,
                inds=state.branches_inds,
                supplemental=new_supps,
                branch_supplemental=new_branch_supps,
            )
            state = self.update(state, new_state, accepted)

            # add to move-specific accepted information
            self.accepted += accepted
            self.num_proposals += 1

        if self.temperature_control is not None:
            state = self.temperature_control.temper_comps(state)

        if is_update_iter:
            # use old values to maintain detailed balance when updating
            # nfriends
            self.setup_friends(old_branches)

        self.iter += 1
        return state, accepted