eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/moves/move.py
ADDED
|
@@ -0,0 +1,703 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from ..state import BranchSupplemental
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import cupy as cp
|
|
10
|
+
except (ModuleNotFoundError, ImportError):
|
|
11
|
+
import numpy as cp
|
|
12
|
+
|
|
13
|
+
# Public API of this module: only the ``Move`` base class is exported.
__all__ = ["Move"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Move(object):
    """Parent class for proposals or "moves"

    Args:
        temperature_control (:class:`tempering.TemperatureControl`, optional):
            This object controls the tempering. It is passed to the parent class
            to moves so that all proposals can share and use temperature settings.
            (default: ``None``)
        periodic (:class:`eryn.utils.PeriodicContainer`, optional):
            This object holds periodic information and methods for periodic parameters.
            It is passed to the parent class to moves so that all proposals can share
            and use periodic information. (default: ``None``)
        gibbs_sampling_setup (str, tuple, dict, or list, optional): This sets the Gibbs Sampling setup if
            desired. The Gibbs sampling setup is completely customizable down to the leaf and parameters.
            All of the separate Gibbs sampling splits will be run within 1 call to this proposal.
            If ``None``, run all branches and all parameters. If ``str``, run all parameters within the
            branch given as the string. To enter a branch with a specific set of parameters, you can
            provide a 2-tuple with the first entry as the branch name and the second entry as a 2D
            boolean array of shape ``(nleaves_max, ndim)`` that indicates which leaves and/or parameters
            you want to run. ``None`` can also be entered in the second entry if all parameters are to be run.
            A dictionary is also possible with keys as branch names and values as the same 2D boolean array
            of shape ``(nleaves_max, ndim)`` that indicates which leaves and/or parameters
            you want to run. ``None`` can also be entered in the value of the dictionary
            if all parameters are to be run. If multiple keys are provided in the dictionary, those
            branches will be run simultaneously in the proposal as one iteration of the proposing loop.
            The final option is a list. This is how you make sure to run all the Gibbs splits. Each entry
            of the list can be a string, 2-tuple, or dictionary as described above. The list controls
            the order in which all of these splits are run. (default: ``None``)
        prevent_swaps (bool, optional): If ``True``, do not perform temperature swaps in this move.
            (default: ``False``)
        skip_supp_names_update (list, optional): List of names (`str`), that can be in any
            :class:`eryn.state.BranchSupplemental`,
            to skip when updating states (:func:`Move.update`). This is useful if a
            large amount of memory is stored in the branch supplementals.
            ``None`` is treated as an empty list. (default: ``None``)
        is_rj (bool, optional): If using RJ, this should be ``True``. (default: ``False``)
        use_gpu (bool, optional): If ``True``, use ``CuPy`` for computations.
            Use ``NumPy`` if ``use_gpu == False``. (default: ``False``)
        random_seed (int, optional): Set the random seed in ``CuPy/NumPy`` if not ``None``.
            (default: ``None``)
        **kwargs: Accepted so subclasses can pass extra options through; this base
            class ignores them.

    Raises:
        ValueError: Incorrect inputs.

    Attributes:
        num_proposals (int): the number of times this move has been run. This is needed to
            compute the acceptance fraction.
        gibbs_sampling_setup (list): All of the Gibbs sampling splits as described above.
        branch_names_run_all (list): For each Gibbs split, the branch names run together
            (``[None]`` when no Gibbs setup is given, meaning all branches).
        inds_run_all (list): For each Gibbs split, the leaf/parameter masks matching
            ``branch_names_run_all`` (``None`` entries mean the whole branch).
        xp (obj): ``NumPy`` or ``CuPy``.
        use_gpu (bool): Whether ``Cupy`` (``True``) is used or not (``False``).

    """

    def __init__(
        self,
        temperature_control=None,
        periodic=None,
        gibbs_sampling_setup=None,
        prevent_swaps=False,
        skip_supp_names_update=None,
        is_rj=False,
        use_gpu=False,
        random_seed=None,
        **kwargs
    ):
        # store all information
        self.temperature_control = temperature_control
        self.periodic = periodic
        # fresh list per instance: a shared mutable default would leak entries
        # between Move instances if one of them mutated it
        self.skip_supp_names_update = (
            [] if skip_supp_names_update is None else skip_supp_names_update
        )
        self.prevent_swaps = prevent_swaps

        self._initialize_branch_setup(gibbs_sampling_setup, is_rj=is_rj)

        # keep track of the number of proposals
        self.num_proposals = 0
        self.time = 0

        self.use_gpu = use_gpu

        # set the random seed of the library if desired
        if random_seed is not None:
            self.xp.random.seed(random_seed)

    @property
    def use_gpu(self):
        """Whether ``CuPy`` (``True``) or ``NumPy`` (``False``) is used."""
        return self._use_gpu

    @use_gpu.setter
    def use_gpu(self, use_gpu):
        self._use_gpu = use_gpu

    @property
    def xp(self):
        """Array library in use: ``cp`` if ``use_gpu`` else ``np``.

        Raises:
            ValueError: If ``use_gpu`` is ``None``.

        """
        if self._use_gpu is None:
            raise ValueError("use_gpu has not been set.")
        xp = cp if self.use_gpu else np
        return xp

    def _initialize_branch_setup(self, gibbs_sampling_setup, is_rj=False):
        """Initialize the Gibbs sampling setup.

        Normalizes ``gibbs_sampling_setup`` (see the class docstring for the accepted
        forms) into a list of splits. It then derives ``branch_names_run_all`` and
        ``inds_run_all``: for every split, the branch names proposed together and the
        matching leaf/parameter masks (``None`` meaning the whole branch).

        Args:
            gibbs_sampling_setup (str, tuple, dict, list, or None): Gibbs setup input.
            is_rj (bool, optional): If ``True``, leaf/parameter masks are rejected,
                because RJ proposals only accept branch names. (default: ``False``)

        Raises:
            ValueError: Incorrect inputs.

        """
        self.gibbs_sampling_setup = gibbs_sampling_setup

        message_rj = """inputting gibbs indexing at the leaf/parameter level is not allowed
        with an RJ proposal. Only branch names."""

        message_non_rj = """When inputing gibbs indexing and using a 2-tuple, second item must be None or 2D np.ndarray of shape (nleaves_max, ndim)."""

        def check_inds(inds):
            # ``None`` always means "every leaf/parameter of the branch"
            if inds is None:
                return
            if is_rj:
                raise ValueError(message_rj)
            if not isinstance(inds, np.ndarray) or inds.ndim != 2:
                raise ValueError(message_non_rj)

        # no Gibbs sampling: a single split running all branches and parameters
        if self.gibbs_sampling_setup is None:
            self.branch_names_run_all = [None]
            self.inds_run_all = [None]
            return

        if not isinstance(self.gibbs_sampling_setup, (str, tuple, list, dict)):
            raise ValueError(
                "gibbs_sampling_setup must be string, dict, tuple, or list."
            )

        # a single split is wrapped so only the list form needs handling below
        if not isinstance(self.gibbs_sampling_setup, list):
            self.gibbs_sampling_setup = [self.gibbs_sampling_setup]

        gibbs_sampling_setup_tmp = []
        for item in self.gibbs_sampling_setup:
            if isinstance(item, str):
                # string: a single branch, all parameters
                gibbs_sampling_setup_tmp.append(item)

            elif isinstance(item, tuple):
                # 2-tuple: one branch with a split in the leaves/parameters
                if len(item) != 2:
                    raise ValueError(message_non_rj)
                # validate the mask (second entry), not the tuple itself
                check_inds(item[1])
                gibbs_sampling_setup_tmp.append(item)

            elif isinstance(item, dict):
                # dict: several branches (and splits) run in one iteration
                tmp = []
                for key, value in item.items():
                    check_inds(value)
                    tmp.append((key, value))
                gibbs_sampling_setup_tmp.append(tmp)

            else:
                raise ValueError(
                    "If providing a list for gibbs_sampling_setup, each item needs to be a string, tuple, or dict."
                )

        # copy the original for information if needed
        self.gibbs_sampling_setup_input = deepcopy(self.gibbs_sampling_setup)

        # store as the setup that all proposals will follow
        self.gibbs_sampling_setup = gibbs_sampling_setup_tmp

        # sort into branch names and indices to be run for each Gibbs split
        branch_names_run_all = []
        inds_run_all = []
        for proposal_iteration in self.gibbs_sampling_setup:
            if isinstance(proposal_iteration, tuple):
                branch_names_run_all.append([proposal_iteration[0]])
                inds_run_all.append([proposal_iteration[1]])
            elif isinstance(proposal_iteration, str):
                branch_names_run_all.append([proposal_iteration])
                inds_run_all.append([None])
            else:
                # list of (branch name, inds) pairs built from a dict above
                branch_names_run_all.append([name for name, _ in proposal_iteration])
                inds_run_all.append([inds for _, inds in proposal_iteration])

        # store information
        self.branch_names_run_all = branch_names_run_all
        self.inds_run_all = inds_run_all

    def gibbs_sampling_setup_iterator(self, all_branch_names):
        """Iterate through the gibbs splits as a generator

        Args:
            all_branch_names (list): List of all branch names.

        Yields:
            2-tuple: Gibbs sampling split.
                First entry is the branch names to run and the second entry is the index
                into the leaves/parameters for this Gibbs split.

        """
        for branch_names_run, inds_run in zip(
            self.branch_names_run_all, self.inds_run_all
        ):
            # ``None`` means no Gibbs setup: run every branch in full
            if branch_names_run is None:
                branch_names_run = all_branch_names
                inds_run = [None for _ in branch_names_run]
            yield (branch_names_run, inds_run)

    @staticmethod
    def _gibbs_leaf_mask(ir, inds_branch):
        """Return the leaf mask of one branch restricted to a Gibbs split.

        Args:
            ir (np.ndarray or None): 2D ``(nleaves_max, ndim)`` split mask, or ``None``.
            inds_branch (np.ndarray): Leaf ``inds`` array of the branch.

        Returns:
            np.ndarray: ``inds_branch`` itself if ``ir`` is ``None``; otherwise a new
                boolean array marking leaves that are both present and in the split.

        """
        if ir is None:
            return inds_branch
        mask = np.zeros_like(inds_branch, dtype=bool)
        # flatten the parameter axis: a leaf is in the split if any parameter is
        ir_keep = ir.astype(int).sum(axis=-1).astype(bool)
        mask[:, :, ir_keep] = True
        # make sure leaves that are actually not there are not counted
        mask[~inds_branch] = False
        return mask

    def setup_proposals(
        self, branch_names_run, inds_run, branches_coords, branches_inds
    ):
        """Setup proposals when gibbs sampling.

        Get inputs into the proposal including Gibbs split information.

        Args:
            branch_names_run (list): List of branch names to run concurrently.
            inds_run (list): List of ``inds`` arrays including Gibbs sampling information.
            branches_coords (dict): Dictionary of coordinate arrays for all branches.
            branches_inds (dict): Dictionary of ``inds`` arrays for all branches.

        Returns:
            tuple: (coords, inds, at_least_one_proposal)
                * Coords including Gibbs sampling info.
                * ``inds`` including Gibbs sampling info.
                * ``at_least_one_proposal`` is boolean. It is passed out to
                  indicate there is at least one leaf available for the requested branch names.

        """
        inds_going_for_proposal = {}
        coords_going_for_proposal = {}

        at_least_one_proposal = False
        for bnr, ir in zip(branch_names_run, inds_run):
            inds_going_for_proposal[bnr] = self._gibbs_leaf_mask(ir, branches_inds[bnr])

            if np.any(inds_going_for_proposal[bnr]):
                at_least_one_proposal = True

            coords_going_for_proposal[bnr] = branches_coords[bnr]

        return (
            coords_going_for_proposal,
            inds_going_for_proposal,
            at_least_one_proposal,
        )

    def cleanup_proposals_gibbs(
        self,
        branch_names_run,
        inds_run,
        q,
        branches_coords,
        new_inds=None,
        branches_inds=None,
        new_branch_supps=None,
        branches_supplemental=None,
    ):
        """Set all not Gibbs-sampled parameters back

        ``q``, ``new_inds`` and ``new_branch_supps`` are modified in place.

        Args:
            branch_names_run (list): List of branch names to run concurrently.
            inds_run (list): List of ``inds`` arrays including Gibbs sampling information.
            q (dict): Dictionary of new coordinate arrays for all proposal branches.
            branches_coords (dict): Dictionary of old coordinate arrays for all branches.
            new_inds (dict, optional): Dictionary of new inds arrays for all proposal branches.
            branches_inds (dict, optional): Dictionary of old inds arrays for all branches.
            new_branch_supps (dict, optional): Dictionary of new branches supplemental for all proposal branches.
            branches_supplemental (dict, optional): Dictionary of old branches supplemental for all branches.

        """
        # add back any parameters that are fixed for this round
        for bnr, ir in zip(branch_names_run, inds_run):
            if ir is not None:
                q[bnr][:, :, ~ir] = branches_coords[bnr][:, :, ~ir]

        # add other models that were not included
        for key, value in branches_coords.items():
            if key not in q:
                q[key] = value.copy()
            if new_inds is not None and key not in new_inds:
                assert branches_inds is not None
                new_inds[key] = branches_inds[key].copy()

            if new_branch_supps is not None and key not in new_branch_supps:
                assert branches_supplemental is not None
                new_branch_supps[key] = branches_supplemental[key].copy()

    def ensure_ordering(self, correct_key_order, q, new_inds, new_branch_supps):
        """Ensure proper order of key in dictionaries.

        Args:
            correct_key_order (list): Keys in correct order.
            q (dict): Dictionary of new coordinate arrays for all branches.
            new_inds (dict): Dictionary of new inds arrays for all branches.
            new_branch_supps (dict or None): Dictionary of new branches supplemental for all proposal branches.

        Returns:
            Tuple: (q, new_inds, new_branch_supps) in correct key order.

        """
        if list(q.keys()) != correct_key_order:
            q = {key: q[key] for key in correct_key_order}

        if list(new_inds.keys()) != correct_key_order:
            new_inds = {key: new_inds[key] for key in correct_key_order}

        if (
            new_branch_supps is not None
            and list(new_branch_supps.keys()) != correct_key_order
        ):
            # missing branches get ``None`` supplementals
            temp = {key: None for key in correct_key_order}
            for key in new_branch_supps:
                temp[key] = new_branch_supps[key]
            new_branch_supps = deepcopy(temp)

        return q, new_inds, new_branch_supps

    def fix_logp_gibbs(self, branch_names_run, inds_run, logp, inds):
        """Adjust ``logp`` in place for walkers with no leaves in this Gibbs split.

        Walkers with no active leaves in the branches being run, but leaves in
        other branches, get ``logp = -np.inf``: there is nothing to change. Walkers
        with no leaves at all get ``logp = 0.0``.

        Args:
            branch_names_run (list): List of branch names to run concurrently.
            inds_run (list): List of ``inds`` arrays including Gibbs sampling information.
            logp (np.ndarray): Log of the prior going into final posterior computation.
            inds (dict): Dictionary of ``inds`` arrays for all branches.

        """
        total_leaves = np.zeros_like(logp, dtype=int)
        total_leaves_here = np.zeros_like(logp, dtype=int)
        for bnr, ir in zip(branch_names_run, inds_run):
            leaves_here = self._gibbs_leaf_mask(ir, inds[bnr]).sum(axis=-1)
            total_leaves += leaves_here
            total_leaves_here += leaves_here

        for name, inds_val in inds.items():
            if name not in branch_names_run:
                total_leaves += inds_val.sum(axis=-1)

        # adjust
        logp[(total_leaves != 0) & (total_leaves_here == 0)] = -np.inf  # no use in running because no change
        logp[(total_leaves == 0) & (total_leaves_here == 0)] = 0.0  # there is nothing in the model currently

    @property
    def accepted(self):
        """Accepted counts for this move.

        Raises:
            ValueError: If it has not been set yet.

        """
        # ``_accepted`` only exists once the setter has run; getattr lets an
        # unset move raise the explanatory ValueError instead of AttributeError
        if getattr(self, "_accepted", None) is None:
            raise ValueError(
                "accepted must be inititalized with the init_accepted function if you want to use it."
            )
        return self._accepted

    @accepted.setter
    def accepted(self, accepted):
        assert isinstance(accepted, np.ndarray)
        self._accepted = accepted

    @property
    def acceptance_fraction(self):
        """Acceptance fraction for this move (``accepted / num_proposals``)."""
        return self.accepted / self.num_proposals

    @property
    def temperature_control(self):
        """Temperature controller"""
        return self._temperature_control

    @temperature_control.setter
    def temperature_control(self, temperature_control):
        self._temperature_control = temperature_control

        # use the setting of the temperature control to determine which log posterior function to use
        # tempered or basic
        if temperature_control is None:
            self.compute_log_posterior = self.compute_log_posterior_basic
        else:
            self.compute_log_posterior = (
                self.temperature_control.compute_log_posterior_tempered
            )

            self.ntemps = self.temperature_control.ntemps

    def compute_log_posterior_basic(self, logl, logp):
        r"""Compute the log of posterior

        :math:`\log{P} = \log{L} + \log{p}`

        This method is to mesh with the tempered log posterior computation.

        Args:
            logl (np.ndarray[ntemps, nwalkers]): Log-likelihood values.
            logp (np.ndarray[ntemps, nwalkers]): Log-prior values.

        Returns:
            np.ndarray[ntemps, nwalkers]: Log-Posterior values.
        """
        return logl + logp

    def tune(self, state, accepted):
        """Tune a proposal

        This is a place holder for tuning.

        Args:
            state (:class:`eryn.state.State`): Current state of sampler.
            accepted (np.ndarray[ntemps, nwalkers]): Accepted values for last pass
                through proposal.

        """
        pass

    def _accept_mix(self, new_value, old_value, accepted_temp, ref):
        """Combine supplemental values: ``new_value`` where accepted, else ``old_value``.

        For non-object arrays the acceptance mask gets trailing axes appended until
        it matches ``ref.ndim``. The mask is rebuilt from ``accepted_temp`` on every
        call, so entries with different ndim do not affect each other. Arithmetic
        mixing is kept here (rather than ``np.where``) so object and GPU arrays
        keep their previous behavior.

        Args:
            new_value: Proposed supplemental values.
            old_value: Current supplemental values (same subset of walkers).
            accepted_temp (np.ndarray): Boolean acceptance mask for the subset.
            ref: Array whose dtype/ndim decide how the mask is broadcast.

        Returns:
            The combined values.

        """
        accepted_here = accepted_temp.copy()
        if ref.dtype.name != "object":
            for _ in range(ref.ndim - accepted_here.ndim):
                accepted_here = np.expand_dims(accepted_here, (-1,))
        try:
            return new_value * accepted_here + old_value * (~accepted_here)
        except TypeError:
            # for GPUs: move the masks into the array library in use
            return new_value * self.xp.asarray(accepted_here) + old_value * (
                self.xp.asarray(~accepted_here)
            )

    def update(self, old_state, new_state, accepted, subset=None):
        """Update a given subset of the ensemble with an accepted proposal

        This method was updated from ``emcee`` to handle the added structure
        of Eryn.

        Args:
            old_state (:class:`eryn.state.State`): State with current information.
                New information is added to this state.
            new_state (:class:`eryn.state.State`): State with information from proposed
                points.
            accepted (np.ndarray[ntemps, nwalkers]): A vector of booleans indicating
                which walkers were accepted.
            subset (np.ndarray[ntemps, nwalkers], optional): A boolean mask
                indicating which walkers were included in the subset.
                This can be used, for example, when updating only the
                primary ensemble in a :class:`RedBlueMove`.
                (default: ``None``)

        Returns:
            :class:`eryn.state.State`: ``old_state`` with accepted points added from ``new_state``.

        Raises:
            ValueError: If ``new_state`` has blobs but ``old_state`` does not.

        """
        if subset is None:
            # subset of everything
            subset = np.tile(
                np.arange(old_state.log_like.shape[1]), (old_state.log_like.shape[0], 1)
            )

        # each quantity is handled the same way:
        # 1. take the subset of old values (take_along_axis)
        # 2. choose new vs. old per walker based on ``accepted``
        # 3. put the combined subset back into the full arrays (put_along_axis)
        # np.where is used for step 2 instead of ``new * acc + old * ~acc``,
        # because the arithmetic form turns a non-finite value on the
        # unselected side (e.g. a -inf log-likelihood of a rejected proposal)
        # into NaN (inf * 0 is NaN).
        accepted_temp = np.take_along_axis(accepted, subset, axis=1)

        # log likelihood
        old_log_likes = np.take_along_axis(old_state.log_like, subset, axis=1)
        temp_change_log_like = np.where(
            accepted_temp, new_state.log_like, old_log_likes
        )
        np.put_along_axis(old_state.log_like, subset, temp_change_log_like, axis=1)

        # log prior
        old_log_priors = np.take_along_axis(old_state.log_prior, subset, axis=1)
        new_log_priors = new_state.log_prior.copy()
        # infinite proposed priors are stored as 0.0 if accepted (as before)
        new_log_priors[np.isinf(new_log_priors)] = 0.0
        temp_change_log_prior = np.where(accepted_temp, new_log_priors, old_log_priors)
        np.put_along_axis(old_state.log_prior, subset, temp_change_log_prior, axis=1)

        # inds
        old_inds = {
            name: np.take_along_axis(branch.inds, subset[:, :, None], axis=1)
            for name, branch in old_state.branches.items()
        }
        new_inds = {name: branch.inds for name, branch in new_state.branches.items()}
        temp_change_inds = {
            name: np.where(accepted_temp[:, :, None], new_inds[name], old_inds[name])
            for name in old_inds
        }
        for name in new_inds:
            np.put_along_axis(
                old_state.branches[name].inds,
                subset[:, :, None],
                temp_change_inds[name],
                axis=1,
            )

        # branch supplementals: only processed if at least one branch has one
        run_branches_supplemental = any(
            value is not None for value in old_state.branches_supplemental.values()
        )

        if run_branches_supplemental:
            temp_change_branch_supplemental = {}
            for name in old_state.branches:
                if old_state.branches[name].branch_supplemental is None:
                    temp_change_branch_supplemental[name] = None
                    continue

                old_branch_supplemental = old_state.branches[
                    name
                ].branch_supplemental.take_along_axis(
                    subset[:, :, None],
                    axis=1,
                    skip_names=self.skip_supp_names_update,
                )
                new_branch_supplemental = new_state.branches[
                    name
                ].branch_supplemental[:]

                tmp = {}
                for key in old_branch_supplemental:
                    # need to check to see if we should skip anything
                    if key in self.skip_supp_names_update:
                        continue
                    tmp[key] = self._accept_mix(
                        new_branch_supplemental[key],
                        old_branch_supplemental[key],
                        accepted_temp,
                        new_branch_supplemental[key],
                    )

                temp_change_branch_supplemental[name] = BranchSupplemental(
                    tmp,
                    base_shape=new_state.branches_supplemental[name].base_shape,
                    copy=True,
                )

            for name in new_inds:
                if temp_change_branch_supplemental[name] is not None:
                    old_state.branches[name].branch_supplemental.put_along_axis(
                        subset[:, :, None],
                        temp_change_branch_supplemental[name][:],
                        axis=1,
                    )

        # sampler level supplemental
        if old_state.supplemental is not None:
            old_suppliment = old_state.supplemental.take_along_axis(subset, axis=1)
            new_suppliment = new_state.supplemental[:]

            temp_change_suppliment = {}
            for name in old_suppliment:
                # make sure to get rid of specific supps if requested
                if name in self.skip_supp_names_update:
                    continue
                temp_change_suppliment[name] = self._accept_mix(
                    new_suppliment[name],
                    old_suppliment[name],
                    accepted_temp,
                    old_suppliment[name],
                )
            old_state.supplemental.put_along_axis(
                subset, temp_change_suppliment, axis=1
            )

        # coords
        old_coords = {
            name: np.take_along_axis(branch.coords, subset[:, :, None, None], axis=1)
            for name, branch in old_state.branches.items()
        }
        new_coords = {
            name: branch.coords for name, branch in new_state.branches.items()
        }

        # copy then fill (rather than arithmetic mixing) to avoid producing NaNs
        temp_change_coords = {name: old_coords[name].copy() for name in old_coords}
        for name in temp_change_coords:
            temp_change_coords[name][accepted_temp] = new_coords[name][accepted_temp]

        for name in new_coords:
            np.put_along_axis(
                old_state.branches[name].coords,
                subset[:, :, None, None],
                temp_change_coords[name],
                axis=1,
            )

        # take care of blobs
        if new_state.blobs is not None:
            if old_state.blobs is None:
                raise ValueError(
                    "If you start sampling with a given log_like, "
                    "you also need to provide the current list of "
                    "blobs at that position."
                )

            old_blobs = np.take_along_axis(old_state.blobs, subset[:, :, None], axis=1)
            temp_change_blobs = np.where(
                accepted_temp[:, :, None], new_state.blobs, old_blobs
            )
            np.put_along_axis(
                old_state.blobs, subset[:, :, None], temp_change_blobs, axis=1
            )

        return old_state
|