eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/state.py
ADDED
|
@@ -0,0 +1,775 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
import cupy as xp
|
|
7
|
+
|
|
8
|
+
except (ModuleNotFoundError, ImportError) as e:
|
|
9
|
+
import numpy as xp
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
# Public star-import API of this module: only ``State`` is exported by
# ``from <module> import *``. The other classes defined in this file
# (``BranchSupplemental``, ``Branch``, ``ParaState``) must be imported by name.
__all__ = ["State"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BranchSupplemental(object):
    """Special object to carry information through sampler.

    The :class:`BranchSupplemental` object is a holder of information that is
    passed through the sampler. It can also be indexed similar to other quantities
    carried throughout the sampler.

    This indexing is based on the ``base_shape``. You can store many objects that have
    the same base shape and then index across all of them. For example, if you want to
    store individual leaf information, the base shape will be
    ``(ntemps, nwalkers, nleaves_max)``. If you want to store a 2D array per individual
    leaf, the overall shape will be
    ``(ntemps, nwalkers, nleaves_max, dim2_extra, dim1_extra)``. Another type of
    information is stored in a class object (for example). Using ``numpy`` object arrays,
    ``ntemps * nwalkers * nleaves_max`` number of class objects can be stored in the
    array. Then, using special indexing functions, information can be updated/accessed
    across all objects stored simultaneously. If you index this class, it will give you
    back a dictionary with all objects stored indexed for each leaf. So if you index
    ``(0, 0, 0)`` in our running example, you will get back a dictionary with one 2D
    array and one class object from the ``numpy`` object array.

    All of these objects are stored in ``self.holder``.

    Args:
        obj_info (dict): Initial information for storage. Keys are the names to be
            stored under and values are arrays (or nested sequences of objects). These
            arrays should have a base shape that is equivalent to ``base_shape``,
            meaning ``array.shape[:len(base_shape)] == self.base_shape``.
            The dimensions beyond the base shape can be anything.
        base_shape (tuple): Base shape for indexing. Objects stored in the supplemental
            object will have a shape that at minimum is equivalent to ``base_shape``.
            May be ``None`` if the first entry added is a ``numpy`` object array, in
            which case the base shape is taken from that array.
        copy (bool, optional): If ``True``, deep-copy object arrays before they are
            stored. Numeric ``numpy``/``cupy`` arrays are always copied, regardless of
            this flag. (default: ``False``)

    Attributes:
        holder (dict): All of the objects stored for this supplemental object.
        base_shape (tuple): Base shape shared by every stored object.
        ndim (int): ``len(base_shape)`` (``None`` until the base shape is known).

    """

    def __init__(self, obj_info: dict, base_shape: tuple, copy: bool = False):
        # store initial information
        self.holder = {}
        # normalize to a tuple so that comparisons against ``ndarray.shape`` behave
        self.base_shape = None if base_shape is None else tuple(base_shape)
        self.ndim = None if self.base_shape is None else len(self.base_shape)

        # add initial set of objects
        self.add_objects(obj_info, copy=copy)

    def add_objects(self, obj_info: dict, copy=False):
        """Add objects to the holder.

        Args:
            obj_info (dict): Information for storage. Keys are the names to be stored
                under and values are arrays (or nested sequences of objects). These
                arrays should have a base shape that is equivalent to ``base_shape``,
                meaning ``array.shape[:len(base_shape)] == self.base_shape``.
                The dimensions beyond the base shape can be anything.
            copy (bool, optional): If ``True``, deep-copy object arrays before they are
                stored. Numeric arrays are always copied. (default: ``False``)

        Raises:
            ValueError: Shape matching issues, or ``base_shape`` is unknown when a
                non-object entry is added.

        """
        # whether a copy is requested
        dc = deepcopy if copy else (lambda x: x)

        for name, obj_contained in obj_info.items():
            is_numpy = isinstance(obj_contained, np.ndarray)

            # numpy object arrays: must match the base shape exactly
            if is_numpy and obj_contained.dtype.name == "object":
                candidate = dc(obj_contained)
                if self.base_shape is None:
                    # first entry defines the base shape
                    self.base_shape = candidate.shape
                    self.ndim = len(self.base_shape)
                elif candidate.shape != self.base_shape:
                    raise ValueError(
                        f"Outer shapes of all input objects must be the same. {name} object array has shape {candidate.shape}. The original shape found was {self.base_shape}."
                    )
                # only store after validation so a failure leaves the holder untouched
                self.holder[name] = candidate
                continue

            if self.base_shape is None:
                raise ValueError(
                    "base_shape must be provided (or inferred from an object array) "
                    "before adding non-object entries."
                )
            self.ndim = len(self.base_shape)

            # numeric arrays (xp for GPU): always stored as a copy
            if is_numpy or isinstance(obj_contained, xp.ndarray):
                self.holder[name] = obj_contained.copy()
                continue

            # anything else is treated as a nested sequence of objects
            self.holder[name] = self._object_array_from_nested(obj_contained)

    def _object_array_from_nested(self, nested):
        """Build an object array of shape ``base_shape`` from a nested sequence.

        Raises:
            ValueError: A nesting level has a length that differs from ``base_shape``.

        """
        out = np.empty(self.base_shape, dtype=object)
        depth = len(self.base_shape)

        def fill(sub, prefix):
            axis = len(prefix)
            if axis == depth:
                out[prefix] = sub
                return
            if len(sub) != self.base_shape[axis]:
                raise ValueError(
                    f"Shapes of obj_contained does not match base_shape along axis {axis}."
                )
            for i in range(self.base_shape[axis]):
                fill(sub[i], prefix + (i,))

        fill(nested, ())
        return out

    def remove_objects(self, names):
        """Remove objects from the holder.

        Args:
            names (str or list of str): Strings associated with information to delete.
                Please note it does not return the information.

        Raises:
            ValueError: Input issues.
            KeyError: A name is not stored in the holder.

        """
        # check inputs
        if not isinstance(names, list):
            if not isinstance(names, str):
                raise ValueError("names must be a string or list of strings.")
            names = [names]

        # iterate and remove items from holder
        for name in names:
            self.holder.pop(name)

    @property
    def contained_objects(self):
        """The list of keys of contained objects."""
        return list(self.holder.keys())

    def __contains__(self, name: str):
        """Check if the holder holds a specific key."""
        return name in self.holder

    def __getitem__(self, tmp):
        """Special indexing for retrieval.

        When indexing the overall class, this will return the slice of each object.

        Args:
            tmp (int, np.ndarray, or slice): indexing slice of some form.

        Returns:
            dict: Keys are names of the objects contained. Values are the slices of
                those objects.

        """
        # slice each object contained
        return {name: values[tmp] for name, values in self.holder.items()}

    def __setitem__(self, tmp, new_value):
        """Special indexing for setting elements.

        When indexing the overall class, this will set object information.

        **Please note**: If you try to input information that is not already stored,
        it will ignore it.

        Args:
            tmp (int, np.ndarray, or slice): indexing slice of some form.
            new_value (dict): Keys are names of stored objects; values are the data
                to write at ``tmp``.

        """
        # only names already in the holder are updated
        for name in self.holder:
            if name in new_value:
                self.holder[name][tmp] = new_value[name]

    def take_along_axis(self, indices, axis: int, skip_names=None):
        """Take information from contained arrays along an axis.

        See ```numpy.take_along_axis`` <https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html>`_.

        Args:
            indices (xp.ndarray): Indices to take along each 1d slice of arr. This must
                match the dimension of ``self.base_shape``, but other dimensions only
                need to broadcast against ``self.base_shape``.
            axis (int): The axis to take 1d slices along. It is passed unchanged to
                ``take_along_axis`` after ``indices`` has been padded with trailing
                length-1 axes.
            skip_names (list of str, optional): By default, this function returns the
                results for all stored objects. This list gives the strings of objects
                to leave behind and not return. (default: ``None``)

        Returns:
            dict: Keys are names of stored objects and values are the proper array
                slices.

        """
        skipped = () if skip_names is None else skip_names
        out = {}

        for name, values in self.holder.items():
            if name in skipped:
                continue

            # use the array module that matches the stored array
            lib = np if isinstance(values, np.ndarray) else xp
            idx = lib.asarray(indices)

            is_object = isinstance(values, np.ndarray) and values.dtype.name == "object"
            if not is_object and values.ndim > idx.ndim:
                # pad with trailing singleton axes so the indices broadcast over
                # the dimensions beyond the base shape
                idx = idx.reshape(idx.shape + (1,) * (values.ndim - idx.ndim))

            out[name] = lib.take_along_axis(values, idx, axis)

        return out

    def put_along_axis(self, indices, values_in: dict, axis: int):
        """Put information into contained arrays along an axis.

        See ```numpy.put_along_axis`` <https://numpy.org/doc/stable/reference/generated/numpy.put_along_axis.html>`_.

        **Please note** this is a custom implementation based on advanced indexing
        that is used for both ``cupy`` and ``numpy`` arrays.

        Args:
            indices (xp.ndarray): Indices to put values along each 1d slice of arr.
                Its axes are matched against the leading axes of each stored array.
            values_in (dict): Keys are the objects contained to update. Values are the
                arrays of these objects with shape that can broadcast to
                ``indices.shape`` plus any trailing dimensions of the stored array.
            axis (int): The axis to put 1d slices along. Negative values count from
                the last axis of ``indices``.

        Raises:
            ValueError: ``axis`` is out of bounds for ``indices``.

        """
        for name, values in self.holder.items():
            # skip names that are not to be updated
            if name not in values_in:
                continue

            lib = np if isinstance(values, np.ndarray) else xp
            idx = lib.asarray(indices)

            ax = axis + idx.ndim if axis < 0 else axis
            if not 0 <= ax < idx.ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for indices with {idx.ndim} dimensions."
                )

            # open-mesh index: identity along every axis except ``ax``, where the
            # requested indices are used
            selector = []
            for dim, size in enumerate(idx.shape):
                if dim == ax:
                    selector.append(idx)
                else:
                    mesh_shape = [1] * idx.ndim
                    mesh_shape[dim] = size
                    selector.append(lib.arange(size).reshape(mesh_shape))

            # in-place update of the stored array
            values[tuple(selector)] = values_in[name]

    @property
    def flat(self):
        """Get flattened arrays from the stored objects.

        Object arrays are flattened completely. For numeric arrays the first two
        axes are merged and the remaining shape is maintained.

        """
        # NOTE(review): numeric arrays merge exactly two leading axes even when
        # ``len(self.base_shape) != 2``; confirm with callers before changing this
        # to ``self.ndim``.
        out = {}
        for name, values in self.holder.items():
            if (
                isinstance(values, np.ndarray) and values.dtype.name != "object"
            ) or isinstance(values, xp.ndarray):
                out[name] = values.reshape((-1,) + values.shape[2:])
            else:
                out[name] = values.flatten()
        return out
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
class Branch(object):
    """Special container for one branch (model)

    This class is a key component of Eryn. It is this type of object
    that allows for different models to be considered simultaneously
    within an MCMC run.

    Args:
        coords (4D double np.ndarray[ntemps, nwalkers, nleaves_max, ndim]): The
            coordinates in parameter space of all walkers.
        inds (3D bool np.ndarray[ntemps, nwalkers, nleaves_max], optional): The
            information on which leaves are used and which are not used. A value of
            True means the specific leaf was used in this step. If None, inds will
            fill with all True values. (default: ``None``)
        branch_supplemental (object): :class:`BranchSupplemental` object specific to
            this branch. Its ``base_shape`` must equal the shape of ``inds``.
            (default: ``None``)

    Raises:
        ValueError: ``coords`` is not 4D, ``inds`` is not an ``np.ndarray`` or has
            the wrong shape, or ``branch_supplemental`` has a mismatched base shape.

    """

    def __init__(self, coords, inds=None, branch_supplemental=None):
        # coordinates and the sizes derived from their 4D shape
        self.coords = coords
        self.ntemps, self.ntrees, self.nleaves_max, self.ndim = coords.shape
        self.shape = coords.shape

        leaf_shape = (self.ntemps, self.ntrees, self.nleaves_max)

        if inds is None:
            # no mask supplied: every leaf counts as used
            inds = np.full(leaf_shape, True)
        else:
            if not isinstance(inds, np.ndarray):
                raise ValueError("inds must be np.ndarray in Branch.")
            if inds.shape != leaf_shape:
                raise ValueError("inds has wrong shape.")
        self.inds = inds

        # the supplemental holder, when given, must be indexed like ``inds``
        if (
            branch_supplemental is not None
            and branch_supplemental.base_shape != self.inds.shape
        ):
            raise ValueError(
                f"branch_supplemental shape ( {branch_supplemental.base_shape} ) does not match inds shape ( {self.inds.shape} )."
            )

        self.branch_supplemental = branch_supplemental

    @property
    def nleaves(self):
        """Number of leaves for each walker"""
        # count used leaves per walker by summing ``inds`` over the leaf axis
        return np.sum(self.inds, axis=-1)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
class State(object):
    """The state of the ensemble during an MCMC run

    Args:
        coords (double ndarray[ntemps, nwalkers, nleaves_max, ndim], dict, or :class:`.State`):
            The current positions of the walkers in the parameter space. If dict, need
            to use ``branch_names`` for the keys. 2D arrays are interpreted as
            ``(nwalkers, ndim)`` and 3D arrays as ``(ntemps, nwalkers, ndim)``; the
            missing axes are added with length 1.
        inds (dict, optional): Per-branch bool ndarray[ntemps, nwalkers, nleaves_max]
            with the information on which leaves are used and which are not used.
            A value of True means the specific leaf was used in this step. Branches
            without an entry use all leaves.
            Input should be ``None`` if a complete :class:`.State` object is input for
            ``coords``. (default: ``None``)
        branch_supplemental (dict, optional): Per-branch :class:`BranchSupplemental`
            objects. Branches without an entry get ``None``. (default: ``None``)
        supplemental (object, optional): Supplemental information stored on the state
            as ``self.supplemental``. (default: ``None``)
        log_like (ndarray[ntemps, nwalkers], optional): Log likelihoods
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete :class:`.State` object is input for
            ``coords``. (default: ``None``)
        log_prior (ndarray[ntemps, nwalkers], optional): Log priors
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete :class:`.State` object is input for
            ``coords``. (default: ``None``)
        betas (ndarray[ntemps], optional): Temperatures in the sampler at the current
            step. Input should be ``None`` if a complete :class:`.State` object is
            input for ``coords``. (default: ``None``)
        blobs (ndarray[ntemps, nwalkers, nblobs], Optional): The metadata "blobs"
            associated with the current position.
            Input should be ``None`` if a complete :class:`.State` object is input for
            ``coords``. (default: ``None``)
        random_state (Optional): The current state of the random number generator.
            Input should be ``None`` if a complete :class:`.State` object is input for
            ``coords``. (default: ``None``)
        copy (bool, optional): If True, deep-copy the input information (or the
            arrays in the former :class:`.State` object). (default: ``False``)

    Raises:
        ValueError: Dimensions of inputs or input types are incorrect.

    """

    # attributes that make up a state (used as a fallback by ``copy_into_self``)
    _state_attributes = (
        "branches",
        "log_like",
        "log_prior",
        "blobs",
        "betas",
        "supplemental",
        "random_state",
    )

    def __init__(
        self,
        coords,
        inds=None,
        branch_supplemental=None,
        supplemental=None,
        log_like=None,
        log_prior=None,
        betas=None,
        blobs=None,
        random_state=None,
        copy=False,
    ):
        # decide if copying input info
        dc = deepcopy if copy else (lambda x: x)

        # check if coords is a State object
        if hasattr(coords, "branches"):
            self.branches = dc(coords.branches)
            self.log_like = dc(coords.log_like)
            self.log_prior = dc(coords.log_prior)
            self.blobs = dc(coords.blobs)
            self.betas = dc(coords.betas)
            self.supplemental = dc(coords.supplemental)
            self.random_state = dc(coords.random_state)
            return

        # protect against simplifying settings
        if isinstance(coords, (np.ndarray, xp.ndarray)):
            coords = {"model_0": coords}
        elif not isinstance(coords, dict):
            raise ValueError(
                "Input coords need to be np.ndarray, dict, or State object."
            )

        # build a new dict so the caller's dictionary is not modified
        coords = {name: self._as_4d_coords(arr) for name, arr in coords.items()}

        # branches without an entry in ``inds`` use all leaves
        if inds is None:
            inds = {}
        elif not isinstance(inds, dict):
            raise ValueError("inds must be None or dict.")

        # not all branches need a supplemental object
        if branch_supplemental is None:
            branch_supplemental = {}
        elif not isinstance(branch_supplemental, dict):
            raise ValueError("branch_supplemental must be None or dict.")

        # setup all information for storage
        self.branches = {
            key: Branch(
                dc(temp_coords),
                inds=inds.get(key),
                branch_supplemental=branch_supplemental.get(key),
            )
            for key, temp_coords in coords.items()
        }
        self.log_like = dc(np.atleast_2d(log_like)) if log_like is not None else None
        self.log_prior = dc(np.atleast_2d(log_prior)) if log_prior is not None else None
        self.blobs = dc(np.atleast_3d(blobs)) if blobs is not None else None
        self.betas = dc(np.atleast_1d(betas)) if betas is not None else None
        self.supplemental = dc(supplemental)
        self.random_state = dc(random_state)

    @staticmethod
    def _as_4d_coords(arr):
        """Return ``arr`` viewed as ``(ntemps, nwalkers, nleaves_max, ndim)``.

        Raises:
            ValueError: ``arr`` does not have 2, 3 or 4 dimensions.

        """
        if arr.ndim == 2:
            # (nwalkers, ndim): single temperature, single leaf
            return arr[None, :, None, :]
        if arr.ndim == 3:
            # assume (ntemps, nwalkers, ndim) provided: single leaf
            return arr[:, :, None, :]
        if arr.ndim == 4:
            return arr
        raise ValueError(
            "Dimension of coordinates must be between 2 and 4. coords dimension is {0}.".format(
                arr.ndim
            )
        )

    @property
    def branches_inds(self):
        """Get the ``inds`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {name: branch.inds for name, branch in self.branches.items()}

    @property
    def branches_coords(self):
        """Get the ``coords`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {name: branch.coords for name, branch in self.branches.items()}

    @property
    def branches_supplemental(self):
        """Get the ``branch.supplemental`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {
            name: branch.branch_supplemental for name, branch in self.branches.items()
        }

    @property
    def branch_names(self):
        """Get the branch names in this state."""
        return list(self.branches.keys())

    def copy_into_self(self, state_to_copy):
        """Copy (by reference) every attribute of ``state_to_copy`` onto this state.

        Args:
            state_to_copy (object): State-like object to take attributes from. Its
                ``__slots__`` are used when defined; otherwise its instance
                attributes are used.

        """
        names = getattr(state_to_copy, "__slots__", None)
        if names is None:
            # ``__slots__`` is not defined on the state classes in this module
            names = list(vars(state_to_copy))

        for name in names:
            setattr(self, name, getattr(state_to_copy, name))

    def get_log_posterior(self, temper: bool = False):
        """Get the posterior probability

        Args:
            temper (bool, optional): If ``True``, apply tempering to the posterior
                computation (the log likelihood is scaled by ``self.betas``).

        Returns:
            np.ndarray[ntemps, nwalkers]: Log of the posterior probability.

        """
        if not temper:
            return self.log_like + self.log_prior

        betas = self.betas
        # betas is (ntemps,) while log_like is (ntemps, nwalkers): add trailing
        # axes so that each temperature scales its own row of walkers
        missing = np.ndim(self.log_like) - np.ndim(betas)
        if missing > 0:
            betas = betas[(Ellipsis,) + (None,) * missing]

        return betas * self.log_like + self.log_prior

    # TODO: implement ``__repr__`` and ``__iter__`` in terms of ``self.branches``.
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
class ParaState(object):
    """The state of the ensemble during an MCMC run

    Args:
        coords (double ndarray[ntemps, nwalkers, nleaves_max, ndim], dict, or :class:`.State`):
            The current positions of the walkers in the parameter space. If dict, need
            to use ``branch_names`` for the keys. 2D arrays are interpreted as
            ``(nwalkers, ndim)`` and 3D arrays as ``(ntemps, nwalkers, ndim)``; the
            missing axes are added with length 1.
        groups_running (array-like, optional): Stored as ``self.groups_running`` after
            passing through ``np.atleast_1d``.
            NOTE(review): the meaning of this array is not visible in this class;
            confirm against the callers.
            Input should be ``None`` if a complete state object is input for
            ``coords``. (default: ``None``)
        branch_supplemental (dict, optional): Per-branch :class:`BranchSupplemental`
            objects. Branches without an entry get ``None``. (default: ``None``)
        supplemental (object, optional): Supplemental information stored on the state
            as ``self.supplemental``. (default: ``None``)
        log_like (ndarray[ntemps, nwalkers], optional): Log likelihoods
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete state object is input for
            ``coords``. (default: ``None``)
        log_prior (ndarray[ntemps, nwalkers], optional): Log priors
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete state object is input for
            ``coords``. (default: ``None``)
        betas (ndarray[ntemps], optional): Temperatures in the sampler at the current
            step. Input should be ``None`` if a complete state object is input for
            ``coords``. (default: ``None``)
        blobs (ndarray[ntemps, nwalkers, nblobs], Optional): The metadata "blobs"
            associated with the current position.
            Input should be ``None`` if a complete state object is input for
            ``coords``. (default: ``None``)
        random_state (Optional): The current state of the random number generator.
            Input should be ``None`` if a complete state object is input for
            ``coords``. (default: ``None``)
        copy (bool, optional): If True, deep-copy the input information (or the
            arrays in the former state object). (default: ``False``)

    Raises:
        ValueError: Dimensions of inputs or input types are incorrect.

    """

    def __init__(
        self,
        coords,
        groups_running=None,
        branch_supplemental=None,
        supplemental=None,
        log_like=None,
        log_prior=None,
        betas=None,
        blobs=None,
        random_state=None,
        copy=False,
    ):
        # decide if copying input info
        dc = deepcopy if copy else (lambda x: x)

        # check if coords is a state object
        if hasattr(coords, "branches"):
            self.branches = dc(coords.branches)
            # a plain ``State`` carries no ``groups_running``
            self.groups_running = dc(getattr(coords, "groups_running", None))
            self.log_like = dc(coords.log_like)
            self.log_prior = dc(coords.log_prior)
            self.blobs = dc(coords.blobs)
            self.betas = dc(coords.betas)
            self.supplemental = dc(coords.supplemental)
            # TODO: check this -- the random state is shared, never deep-copied
            self.random_state = coords.random_state
            return

        # protect against simplifying settings
        if isinstance(coords, (np.ndarray, xp.ndarray)):
            coords = {"model_0": coords}
        elif not isinstance(coords, dict):
            raise ValueError(
                "Input coords need to be np.ndarray, dict, or State object."
            )

        # build a new dict so the caller's dictionary is not modified
        coords = {name: self._as_4d_coords(arr) for name, arr in coords.items()}

        # not all branches need a supplemental object
        if branch_supplemental is None:
            branch_supplemental = {}
        elif not isinstance(branch_supplemental, dict):
            raise ValueError("branch_supplemental must be None or dict.")

        # setup all information for storage (every leaf is marked as used)
        self.branches = {
            key: Branch(
                dc(temp_coords),
                inds=None,
                branch_supplemental=branch_supplemental.get(key),
            )
            for key, temp_coords in coords.items()
        }

        self.groups_running = (
            dc(np.atleast_1d(groups_running)) if groups_running is not None else None
        )
        self.log_like = dc(np.atleast_2d(log_like)) if log_like is not None else None
        self.log_prior = dc(np.atleast_2d(log_prior)) if log_prior is not None else None
        self.blobs = dc(np.atleast_3d(blobs)) if blobs is not None else None
        self.betas = dc(np.atleast_1d(betas)) if betas is not None else None
        self.supplemental = dc(supplemental)
        self.random_state = dc(random_state)

    @staticmethod
    def _as_4d_coords(arr):
        """Return ``arr`` viewed as ``(ntemps, nwalkers, nleaves_max, ndim)``.

        Raises:
            ValueError: ``arr`` does not have 2, 3 or 4 dimensions.

        """
        if arr.ndim == 2:
            # (nwalkers, ndim): single temperature, single leaf
            return arr[None, :, None, :]
        if arr.ndim == 3:
            # assume (ntemps, nwalkers, ndim) provided: single leaf
            return arr[:, :, None, :]
        if arr.ndim == 4:
            return arr
        raise ValueError(
            "Dimension of coordinates must be between 2 and 4. coords dimension is {0}.".format(
                arr.ndim
            )
        )

    @property
    def branches_coords(self):
        """Get the ``coords`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {name: branch.coords for name, branch in self.branches.items()}

    @property
    def branches_supplemental(self):
        """Get the ``branch.supplemental`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {
            name: branch.branch_supplemental for name, branch in self.branches.items()
        }

    @property
    def branch_names(self):
        """Get the branch names in this state."""
        return list(self.branches.keys())

    def copy_into_self(self, state_to_copy):
        """Copy (by reference) every attribute of ``state_to_copy`` onto this state.

        Args:
            state_to_copy (object): State-like object to take attributes from. Its
                ``__slots__`` are used when defined; otherwise its instance
                attributes are used.

        """
        names = getattr(state_to_copy, "__slots__", None)
        if names is None:
            # ``__slots__`` is not defined on the state classes in this module
            names = list(vars(state_to_copy))

        for name in names:
            setattr(self, name, getattr(state_to_copy, name))

    def get_log_posterior(self, temper: bool = False):
        """Get the posterior probability

        Args:
            temper (bool, optional): If ``True``, apply tempering to the posterior
                computation (the log likelihood is scaled by ``self.betas``).

        Returns:
            np.ndarray[ntemps, nwalkers]: Log of the posterior probability.

        """
        if not temper:
            return self.log_like + self.log_prior

        betas = self.betas
        # betas is (ntemps,) while log_like is (ntemps, nwalkers): add trailing
        # axes so that each temperature scales its own row of walkers
        missing = np.ndim(self.log_like) - np.ndim(betas)
        if missing > 0:
            betas = betas[(Ellipsis,) + (None,) * missing]

        return betas * self.log_like + self.log_prior

    # TODO: implement ``__repr__`` and ``__iter__`` in terms of ``self.branches``.
|