eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/state.py ADDED
@@ -0,0 +1,775 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from copy import deepcopy
4
+
5
+ try:
6
+ import cupy as xp
7
+
8
+ except (ModuleNotFoundError, ImportError) as e:
9
+ import numpy as xp
10
+
11
+ import numpy as np
12
+
13
# Public API of this module: every class defined here is used by other modules
# or referenced in user-facing documentation.
__all__ = ["State", "ParaState", "Branch", "BranchSupplemental"]
14
+
15
+
16
+ class BranchSupplemental(object):
17
+ """Special object to carry information through sampler.
18
+
19
+ The :class:`BranchSupplemental` object is a holder of information that is
20
+ passed through the sampler. It can also be indexed similar to other quantities
21
+ carried throughout the sampler.
22
+
23
+ This indexing is based on the ``base_shape``. You can store many objects that have the same base
24
+ shape and then index across all of them. For example, if you want to store individual leaf
25
+ information, the base shape will be ``(ntemps, nwalkers, nleaves_max)``.
26
+ If you want to store a 2D array per individual leaf, the overall shape will be
27
+ ``(ntemps, nwalkers, nleaves_max, dim2_extra, dim1_extra)``. Another type of information is
28
+ stored in a class object (for example). Using ``numpy`` object arrays,
29
+ ``ntemps * nwalkers * nleaves_max`` number of class objects can be stored in the array. Then,
30
+ using special indexing functions, information can be updated/accessed across all objects
31
+ stored simultaneously. If you index this class, it will give you back a dictionary with
32
+ all objects stored indexed for each leaf. So if you index (0, 0, 0) in our running example,
33
+ you will get back a dictionary with one 2D array and one class object from the ``numpy`` object
34
+ array.
35
+
36
+ All of these objects are stored in ``self.holder``.
37
+
38
+ Args:
39
+ obj_info (dict): Initial information for storage. Keys are the names to be stored under
40
+ and values are arrays. These arrays should have a base shape that is equivalent to
41
+ ``base_shape``, meaning ``array.shape[:len(base_shape)] == self.base_shape``.
42
+ The dimensions beyond the base shape can be anything.
43
+ base_shape (tuple): Base shape for indexing. Objects stored in the supplemental object
44
+ will have a shape that at minimum is equivalent to ``base_shape``.
45
+ copy (bool, optional): If ``True``, copy whatever information is given in before it is stored.
46
+ if ``False``, store directly the input information. (default: ``False``)
47
+
48
+ Attributes:
49
+ holder (dict): All of the objects stored for this supplemental object.
50
+
51
+
52
+ """
53
+
54
+ def __init__(self, obj_info: dict, base_shape: tuple, copy: bool = False):
55
+ # store initial information
56
+ self.holder = {}
57
+ self.base_shape = base_shape
58
+ self.ndim = len(self.base_shape)
59
+
60
+ # add initial set of objects
61
+ self.add_objects(obj_info, copy=copy)
62
+
63
+ def add_objects(self, obj_info: dict, copy=False):
64
+ """Add objects to the holder.
65
+
66
+ Args:
67
+ obj_info (dict): Information for storage. Keys are the names to be stored under
68
+ and values are arrays. These arrays should have a base shape that is equivalent to
69
+ ``base_shape``, meaning ``array.shape[:len(base_shape)] == self.base_shape``.
70
+ The dimensions beyond the base shape can be anything.
71
+ copy (bool, optional): If ``True``, copy whatever information is given in before it is stored.
72
+ if ``False``, store directly the input information. (default: ``False``)
73
+
74
+ Raises:
75
+ ValueError: Shape matching issues.
76
+
77
+ """
78
+
79
+ # whether a copy is requested
80
+ dc = deepcopy if copy else (lambda x: x)
81
+
82
+ # iterate through the dictionary of incoming objects to add
83
+ for name, obj_contained in obj_info.items():
84
+ if (
85
+ isinstance(obj_contained, np.ndarray)
86
+ and obj_contained.dtype.name == "object"
87
+ ):
88
+ self.holder[name] = dc(obj_contained)
89
+ if self.base_shape is None:
90
+ self.base_shape = self.holder[name].shape
91
+ self.ndim = ndim = len(self.base_shape)
92
+ else:
93
+ if self.holder[name].shape != self.base_shape:
94
+ raise ValueError(
95
+ f"Outer shapes of all input objects must be the same. {name} object array has shape {self.holder[name].shape}. The original shape found was {self.base_shape}."
96
+ )
97
+
98
+ else:
99
+ self.ndim = ndim = len(self.base_shape)
100
+
101
+ # xp for GPU
102
+ if isinstance(obj_contained, np.ndarray) or isinstance(
103
+ obj_contained, xp.ndarray
104
+ ):
105
+ self.holder[name] = obj_contained.copy()
106
+
107
+ # fill object array from list
108
+ # adjust based on how many dimensions found
109
+ else:
110
+ # objects to be stored
111
+ self.holder[name] = np.empty(self.base_shape, dtype=object)
112
+ if len(obj_contained) != self.base_shape[0]:
113
+ raise ValueError(
114
+ "Shapes of obj_contained does not match base_shape along axis 0."
115
+ )
116
+
117
+ if ndim > 1:
118
+ for i in range(self.base_shape[0]):
119
+ if len(obj_contained[i]) != self.base_shape[1]:
120
+ raise ValueError(
121
+ "Shapes of obj_contained does not match obj_contained_sha along axis 1."
122
+ )
123
+
124
+ if ndim > 2:
125
+ for j in range(self.base_shape[1]):
126
+ if len(obj_contained[i][j]) != self.base_shape[2]:
127
+ raise ValueError(
128
+ "Shapes of obj_contained does not match base_shape along axis 2."
129
+ )
130
+
131
+ for k in range(self.base_shape[2]):
132
+ self.holder[name][i, j, k] = obj_contained[i][
133
+ j
134
+ ][k]
135
+ else:
136
+ for j in range(self.base_shape[1]):
137
+ self.holder[name][i, j] = obj_contained[i][j]
138
+
139
+ else:
140
+ for i in range(self.base_shape[0]):
141
+ self.holder[name][i] = obj_contained[i]
142
+
143
+ def remove_objects(self, names):
144
+ """Remove objects from the holder.
145
+
146
+
147
+ Args:
148
+ names (str or list of str): Strings associated with information to delete.
149
+ Please note it does not return the information.
150
+
151
+ Raises:
152
+ ValueError: Input issues.
153
+
154
+
155
+ """
156
+ # check inputs
157
+ if not isinstance(names, list):
158
+ if not isinstance(names, str):
159
+ raise ValueError("names must be a string or list of strings.")
160
+
161
+ names = [names]
162
+
163
+ # iterate and remove items from holder
164
+ for name in names:
165
+ self.holder.pop(name)
166
+
167
+ @property
168
+ def contained_objects(self):
169
+ """The list of keys of contained objects."""
170
+ return list(self.holder.keys())
171
+
172
+ def __contains__(self, name: str):
173
+ """Check if the holder holds a specific key."""
174
+ return name in self.holder
175
+
176
+ def __getitem__(self, tmp):
177
+ """Special indexing for retrieval.
178
+
179
+ When indexing the overall class, this will return the slice of each object
180
+
181
+ Args:
182
+ tmp (int, np.ndarray, or slice): indexing slice of some form.
183
+
184
+ Returns:
185
+ dict: Keys are names of the objects contained. Values are the slices of those objects.
186
+
187
+ """
188
+ # slice each object contained
189
+ return {name: values[tmp] for name, values in self.holder.items()}
190
+
191
+ def __setitem__(self, tmp, new_value):
192
+ """Special indexing for setting elements.
193
+
194
+ When indexing the overall class, this will set object information.
195
+
196
+ **Please note**: If you try to input information that is not already stored,
197
+ it will ignore it.
198
+
199
+ Args:
200
+ tmp (int, np.ndarray, or slice): indexing slice of some form.
201
+
202
+ """
203
+ # loop through values already in holder
204
+ for name, values in self.holder.items():
205
+ if name not in new_value:
206
+ continue
207
+ # if the name is already contained, update with incoming value
208
+ self.holder[name][tmp] = new_value[name]
209
+
210
+ def take_along_axis(self, indices, axis: int, skip_names=[]):
211
+ """Take information from contained arrays along an axis.
212
+
213
+ See ```numpy.take_along_axis`` <https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html>`_.
214
+
215
+ Args:
216
+ indices (xp.ndarray): Indices to take along each 1d slice of arr. This must match the dimension
217
+ of ``self.base_shape``, but other dimensions only need to broadcast against
218
+ ``self.base_shape``.
219
+ axis (int): The axis to take 1d slices along.
220
+ skip_names (list of str, optional): By default, this function returns the results for
221
+ all stored objects. This list gives the strings of objects to leave behind and
222
+ not return.
223
+
224
+ Returns:
225
+ dict: Keys are names of stored objects and values are the proper array slices.
226
+
227
+
228
+ """
229
+ # prepare output dictionary
230
+ out = {}
231
+
232
+ # iterate through holder
233
+ for name, values in self.holder.items():
234
+ # skip names if desired
235
+ if name in skip_names:
236
+ continue
237
+
238
+ indices_temp = indices.copy()
239
+ # adjust indices properly for specific object within the holder
240
+ if (
241
+ isinstance(values, np.ndarray) and values.dtype.name != "object"
242
+ ) or isinstance(values, xp.ndarray):
243
+ # expand the dimensions of the indexing values for non-object arrays
244
+ for _ in range(values.ndim - indices_temp.ndim):
245
+ if isinstance(values, np.ndarray):
246
+ indices_temp = np.expand_dims(np.asarray(indices_temp), (-1,))
247
+ elif isinstance(values, xp.ndarray):
248
+ indices_temp = xp.expand_dims(xp.asarray(indices_temp), (-1,))
249
+
250
+ # store the output for either numpy or cupy
251
+ if isinstance(values, np.ndarray):
252
+ out[name] = np.take_along_axis(values, indices_temp, axis)
253
+
254
+ elif isinstance(values, xp.ndarray):
255
+ out[name] = xp.take_along_axis(values, indices_temp, axis)
256
+
257
+ return out
258
+
259
+ def put_along_axis(self, indices, values_in: dict, axis: int):
260
+ """Put information information into contained arrays along an axis.
261
+
262
+ See ```numpy.put_along_axis`` <https://numpy.org/doc/stable/reference/generated/numpy.put_along_axis.html>`_.
263
+
264
+ **Please note** this function is not implemented in ``cupy``, so this is a custom implementation
265
+ for both ``cupy`` and ``numpy``.
266
+
267
+ Args:
268
+ indices (xp.ndarray): Indices to put values along each 1d slice of arr. This must match
269
+ the dimension of ``self.base_shape``, but other dimensions only need to broadcast against
270
+ ``self.base_shape``.
271
+ axis (int): The axis to put 1d slices along.
272
+ values_in (dict): Keys are the objects contained to update. Values are the arrays of these
273
+ objects with shape and dimension that can broadcast to match that of indices.
274
+
275
+ """
276
+ # iterate through all objects in the holder
277
+ for name, values in self.holder.items():
278
+ # skip names that are not to be updated
279
+ if name not in values_in:
280
+ continue
281
+
282
+ # will need to have flexibility to broadcast
283
+ indices_temp = indices.copy()
284
+
285
+ if (
286
+ isinstance(values, np.ndarray) and values.dtype.name != "object"
287
+ ) or isinstance(values, xp.ndarray):
288
+ # prepare indices for proper broadcasting
289
+ for _ in range(values.ndim - indices_temp.ndim):
290
+ if isinstance(values, np.ndarray):
291
+ indices_temp = np.expand_dims(np.asarray(indices_temp), (-1,))
292
+ elif isinstance(values, xp.ndarray):
293
+ indices_temp = xp.expand_dims(xp.asarray(indices_temp), (-1,))
294
+
295
+ # prepare slicing information for entry
296
+ if isinstance(values, np.ndarray):
297
+ inds0 = np.repeat(
298
+ np.arange(len(indices_temp))[:, None], indices_temp.shape[1], axis=1
299
+ )
300
+ elif isinstance(values, xp.ndarray):
301
+ inds0 = xp.repeat(
302
+ np.arange(len(indices_temp))[:, None], indices_temp.shape[1], axis=1
303
+ )
304
+ # self.xp.put_along_axis(self.holder[name], indices_temp, values_in[name], axis)
305
+ # because cupy does not have put_along_axis
306
+ self.holder[name][(inds0.flatten(), indices_temp.flatten())] = values_in[
307
+ name
308
+ ].reshape((-1,) + values_in[name].shape[2:])
309
+
310
+ @property
311
+ def flat(self):
312
+ """Get flattened arrays from the stored objects.
313
+
314
+ Here "flat" is in relation to ``self.base_shape``. Beyond ``self.base_shape``, the shape is mainted.
315
+
316
+ """
317
+ out = {}
318
+ # loop through holder
319
+ for name, values in self.holder.items():
320
+ if (
321
+ isinstance(values, np.ndarray) and values.dtype.name != "object"
322
+ ) or isinstance(values, xp.ndarray):
323
+ # need to account for higher dimensional arrays.
324
+ out[name] = values.reshape((-1,) + values.shape[2:])
325
+ else:
326
+ out[name] = values.flatten()
327
+ return out
328
+
329
+
330
class Branch(object):
    """Special container for one branch (model).

    This class is a key component of Eryn. It is this type of object
    that allows for different models to be considered simultaneously
    within an MCMC run.

    Args:
        coords (4D double np.ndarray[ntemps, nwalkers, nleaves_max, ndim]): The coordinates
            in parameter space of all walkers.
        inds (3D bool np.ndarray[ntemps, nwalkers, nleaves_max], optional): The information
            on which leaves are used and which are not used. A value of True means the specific leaf
            was used in this step. Parameters from unused leaves are still kept. When they
            are output to the backend, the backend saves a special number (default: ``np.nan``) for all coords
            related to unused leaves at that step. If None, inds will fill with all True values.
            (default: ``None``)
        branch_supplemental (object): :class:`BranchSupplemental` object specific to this branch.
            Its ``base_shape`` must equal the ``inds`` shape. (default: ``None``)

    Raises:
        ValueError: ``inds`` is not an ``np.ndarray`` or has the wrong shape, or
            ``branch_supplemental.base_shape`` does not match the ``inds`` shape.

    """

    def __init__(self, coords, inds=None, branch_supplemental=None):
        # store branch info; coords must be 4D for this unpacking to succeed
        self.coords = coords
        self.ntemps, self.ntrees, self.nleaves_max, self.ndim = coords.shape
        self.shape = coords.shape
        expected_inds_shape = (self.ntemps, self.ntrees, self.nleaves_max)

        # make sure inds is correct
        if inds is None:
            self.inds = np.full(expected_inds_shape, True)
        elif not isinstance(inds, np.ndarray):
            raise ValueError("inds must be np.ndarray in Branch.")
        elif inds.shape != expected_inds_shape:
            raise ValueError("inds has wrong shape.")
        else:
            self.inds = inds

        if branch_supplemental is not None:
            # compare as tuples so a list ``base_shape`` with matching values is accepted
            if tuple(branch_supplemental.base_shape) != tuple(self.inds.shape):
                raise ValueError(
                    f"branch_supplemental shape ( {branch_supplemental.base_shape} ) does not match inds shape ( {self.inds.shape} )."
                )

        # store
        self.branch_supplemental = branch_supplemental

    @property
    def nleaves(self):
        """Number of leaves for each walker (``inds`` summed over the last axis)."""
        return np.sum(self.inds, axis=-1)
385
+
386
+
387
class State(object):
    """The state of the ensemble during an MCMC run.

    Args:
        coords (double ndarray[ntemps, nwalkers, nleaves_max, ndim], dict, or :class:`.State`): The current positions of the walkers
            in the parameter space. If dict, need to use ``branch_names`` for the keys. A bare array is
            stored under the branch name ``"model_0"``. 2D arrays are expanded with
            ``arr[None, :, None, :]`` and 3D arrays (assumed ``(ntemps, nwalkers, ndim)``) with
            ``arr[:, :, None, :]``. An input dict is not modified.
        inds (dict, optional): The information on which leaves are used and which are not used,
            keyed by ``branch_names``. Each value (a bool ndarray[ntemps, nwalkers, nleaves_max]
            or ``None``) is passed to :class:`Branch`. A value of True means the specific leaf
            was used in this step.
            Input should be ``None`` if a complete :class:`.State` object is input for ``coords``.
            (default: ``None``)
        branch_supplemental (dict, optional): :class:`BranchSupplemental` objects keyed by branch
            name. Branches missing from the dict get ``None``. (default: ``None``)
        supplemental (object, optional): State-level supplemental information, stored as given.
            (default: ``None``)
        log_like (ndarray[ntemps, nwalkers], optional): Log likelihoods
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete :class:`.State` object is input for ``coords``.
            (default: ``None``)
        log_prior (ndarray[ntemps, nwalkers], optional): Log priors
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete :class:`.State` object is input for ``coords``.
            (default: ``None``)
        betas (ndarray[ntemps], optional): Temperatures in the sampler at the current step.
            Input should be ``None`` if a complete :class:`.State` object is input for ``coords``.
            (default: ``None``)
        blobs (ndarray[ntemps, nwalkers, nblobs], Optional): The metadata "blobs"
            associated with the current position. The value is only returned if
            lnpostfn returns blobs too.
            Input should be ``None`` if a complete :class:`.State` object is input for ``coords``.
            (default: ``None``)
        random_state (Optional): The current state of the random number
            generator.
            Input should be ``None`` if a complete :class:`.State` object is input for ``coords``.
            (default: ``None``)
        copy (bool, optional): If True, deep-copy the inputs (or the attributes of the
            :class:`.State` object passed as ``coords``). (default: ``False``)

    Raises:
        ValueError: Dimensions of inputs or input types are incorrect.

    """

    def __init__(
        self,
        coords,
        inds=None,
        branch_supplemental=None,
        supplemental=None,
        log_like=None,
        log_prior=None,
        betas=None,
        blobs=None,
        random_state=None,
        copy=False,
    ):
        # decide if copying input info
        dc = deepcopy if copy else (lambda x: x)

        # check if coords is a State object
        if hasattr(coords, "branches"):
            self.branches = dc(coords.branches)
            self.log_like = dc(coords.log_like)
            self.log_prior = dc(coords.log_prior)
            self.blobs = dc(coords.blobs)
            self.betas = dc(coords.betas)
            self.supplemental = dc(coords.supplemental)
            self.random_state = dc(coords.random_state)
            return

        # protect against simplifying settings
        if isinstance(coords, (np.ndarray, xp.ndarray)):
            coords = {"model_0": coords}
        elif isinstance(coords, dict):
            # shallow copy so the reshaping below does not mutate the caller's dict
            coords = dict(coords)
        else:
            raise ValueError(
                "Input coords need to be np.ndarray, dict, or State object."
            )

        for name in list(coords):
            arr = coords[name]
            if arr.ndim == 2:
                arr = arr[None, :, None, :]
            elif arr.ndim == 3:
                # assume (ntemps, nwalkers) provided
                arr = arr[:, :, None, :]
            elif arr.ndim < 2 or arr.ndim > 4:
                raise ValueError(
                    "Dimension off coordinates must be between 2 and 4. coords dimension is {0}.".format(
                        arr.ndim
                    )
                )
            coords[name] = arr

        # if no inds given, make sure this is clear for all Branch objects
        if inds is None:
            inds = {key: None for key in coords}
        elif not isinstance(inds, dict):
            raise ValueError("inds must be None or dict.")

        # not all branches need a supplemental object; missing keys map to None
        if branch_supplemental is None:
            branch_supplemental = {}
        elif not isinstance(branch_supplemental, dict):
            raise ValueError("branch_supplemental must be None or dict.")

        # setup all information for storage
        self.branches = {
            key: Branch(
                dc(temp_coords),
                inds=inds[key],
                branch_supplemental=branch_supplemental.get(key),
            )
            for key, temp_coords in coords.items()
        }
        self.log_like = dc(np.atleast_2d(log_like)) if log_like is not None else None
        self.log_prior = dc(np.atleast_2d(log_prior)) if log_prior is not None else None
        self.blobs = dc(np.atleast_3d(blobs)) if blobs is not None else None
        self.betas = dc(np.atleast_1d(betas)) if betas is not None else None
        self.supplemental = dc(supplemental)
        self.random_state = dc(random_state)

    @property
    def branches_inds(self):
        """Get the ``inds`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {name: branch.inds for name, branch in self.branches.items()}

    @property
    def branches_coords(self):
        """Get the ``coords`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {name: branch.coords for name, branch in self.branches.items()}

    @property
    def branches_supplemental(self):
        """Get the ``branch.supplemental`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {
            name: branch.branch_supplemental for name, branch in self.branches.items()
        }

    @property
    def branch_names(self):
        """Get the branch names in this state."""
        return list(self.branches.keys())

    def copy_into_self(self, state_to_copy):
        """Shallow-copy every attribute of ``state_to_copy`` onto this object.

        Uses ``__slots__`` when the source defines it, otherwise its instance ``__dict__``.
        Attribute values are shared, not copied.
        """
        names = getattr(state_to_copy, "__slots__", None)
        if names is None:
            names = list(vars(state_to_copy))
        for name in names:
            setattr(self, name, getattr(state_to_copy, name))

    def get_log_posterior(self, temper: bool = False):
        """Get the posterior probability.

        Args:
            temper (bool, optional): If ``True``, apply tempering to the posterior computation:
                ``betas * log_like + log_prior``, with ``betas`` aligned to the leading
                (temperature) axis of ``log_like``. Otherwise ``log_like + log_prior``.

        Returns:
            np.ndarray[ntemps, nwalkers]: Log of the posterior probability.

        Raises:
            ValueError: ``temper`` is ``True`` but ``betas`` is ``None``.

        """
        if not temper:
            return self.log_like + self.log_prior

        if self.betas is None:
            raise ValueError("betas must be set to compute a tempered log posterior.")

        # append trailing singleton axes so betas broadcast over the leading
        # temperature axis rather than the trailing walker axis
        betas = self.betas.reshape(
            self.betas.shape + (1,) * (self.log_like.ndim - self.betas.ndim)
        )
        return betas * self.log_like + self.log_prior

    # TODO: implement ``__repr__`` and ``__iter__``; an earlier draft referenced a
    # nonexistent ``self.coords`` attribute and was left disabled.
586
+
587
+
588
class ParaState(object):
    """The state of the ensemble during a parallel MCMC run.

    Unlike :class:`State`, no ``inds`` are accepted: every :class:`Branch` is created with
    ``inds=None`` (all leaves marked as used).

    Args:
        coords (double ndarray[ntemps, nwalkers, nleaves_max, ndim], dict, or :class:`.ParaState`): The current positions of the walkers
            in the parameter space. If dict, need to use ``branch_names`` for the keys. A bare array is
            stored under the branch name ``"model_0"``. 2D arrays are expanded with
            ``arr[None, :, None, :]`` and 3D arrays with ``arr[:, :, None, :]``.
            An input dict is not modified.
        groups_running (array-like, optional): Stored as ``np.atleast_1d(groups_running)``.
            NOTE(review): presumably flags which parallel groups are still running — confirm
            against the sampler that builds this object.
            Input should be ``None`` if a complete :class:`.ParaState` object is input for ``coords``.
            (default: ``None``)
        branch_supplemental (dict, optional): :class:`BranchSupplemental` objects keyed by branch
            name. Branches missing from the dict get ``None``. (default: ``None``)
        supplemental (object, optional): State-level supplemental information, stored as given.
            (default: ``None``)
        log_like (ndarray, optional): Log likelihoods
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete :class:`.ParaState` object is input for ``coords``.
            (default: ``None``)
        log_prior (ndarray, optional): Log priors
            for the walkers at positions given by ``coords``.
            Input should be ``None`` if a complete :class:`.ParaState` object is input for ``coords``.
            (default: ``None``)
        betas (ndarray, optional): Temperatures in the sampler at the current step.
            Input should be ``None`` if a complete :class:`.ParaState` object is input for ``coords``.
            (default: ``None``)
        blobs (ndarray, Optional): The metadata "blobs"
            associated with the current position. The value is only returned if
            lnpostfn returns blobs too.
            Input should be ``None`` if a complete :class:`.ParaState` object is input for ``coords``.
            (default: ``None``)
        random_state (Optional): The current state of the random number
            generator. It is never deep-copied.
            Input should be ``None`` if a complete :class:`.ParaState` object is input for ``coords``.
            (default: ``None``)
        copy (bool, optional): If True, deep-copy the inputs (or the attributes of the
            :class:`.ParaState` object passed as ``coords``), except ``random_state``.
            (default: ``False``)

    Raises:
        ValueError: Dimensions of inputs or input types are incorrect.

    """

    def __init__(
        self,
        coords,
        groups_running=None,
        branch_supplemental=None,
        supplemental=None,
        log_like=None,
        log_prior=None,
        betas=None,
        blobs=None,
        random_state=None,
        copy=False,
    ):
        # decide if copying input info
        dc = deepcopy if copy else (lambda x: x)

        # check if coords is a ParaState object
        if hasattr(coords, "branches"):
            self.branches = dc(coords.branches)
            self.groups_running = dc(coords.groups_running)
            self.log_like = dc(coords.log_like)
            self.log_prior = dc(coords.log_prior)
            self.blobs = dc(coords.blobs)
            self.betas = dc(coords.betas)
            self.supplemental = dc(coords.supplemental)
            # TODO: check whether random_state should be deep-copied as well
            self.random_state = coords.random_state
            return

        # protect against simplifying settings
        if isinstance(coords, (np.ndarray, xp.ndarray)):
            coords = {"model_0": coords}
        elif isinstance(coords, dict):
            # shallow copy so the reshaping below does not mutate the caller's dict
            coords = dict(coords)
        else:
            raise ValueError(
                "Input coords need to be np.ndarray, dict, or State object."
            )

        for name in list(coords):
            arr = coords[name]
            if arr.ndim == 2:
                arr = arr[None, :, None, :]
            elif arr.ndim == 3:
                # assume (ntemps, nwalkers) provided
                arr = arr[:, :, None, :]
            elif arr.ndim < 2 or arr.ndim > 4:
                raise ValueError(
                    "Dimension off coordinates must be between 2 and 4. coords dimension is {0}.".format(
                        arr.ndim
                    )
                )
            coords[name] = arr

        # not all branches need a supplemental object; missing keys map to None
        if branch_supplemental is None:
            branch_supplemental = {}
        elif not isinstance(branch_supplemental, dict):
            raise ValueError("branch_supplemental must be None or dict.")

        # setup all information for storage
        self.branches = {
            key: Branch(
                dc(temp_coords),
                inds=None,
                branch_supplemental=branch_supplemental.get(key),
            )
            for key, temp_coords in coords.items()
        }

        self.groups_running = (
            dc(np.atleast_1d(groups_running)) if groups_running is not None else None
        )
        self.log_like = dc(np.atleast_2d(log_like)) if log_like is not None else None
        self.log_prior = dc(np.atleast_2d(log_prior)) if log_prior is not None else None
        self.blobs = dc(np.atleast_3d(blobs)) if blobs is not None else None
        self.betas = dc(np.atleast_1d(betas)) if betas is not None else None
        self.supplemental = dc(supplemental)
        self.random_state = dc(random_state)

    @property
    def branches_coords(self):
        """Get the ``coords`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {name: branch.coords for name, branch in self.branches.items()}

    @property
    def branches_supplemental(self):
        """Get the ``branch.supplemental`` from all branch objects returned as a dictionary with ``branch_names`` as keys."""
        return {
            name: branch.branch_supplemental for name, branch in self.branches.items()
        }

    @property
    def branch_names(self):
        """Get the branch names in this state."""
        return list(self.branches.keys())

    def copy_into_self(self, state_to_copy):
        """Shallow-copy every attribute of ``state_to_copy`` onto this object.

        Uses ``__slots__`` when the source defines it, otherwise its instance ``__dict__``.
        Attribute values are shared, not copied.
        """
        names = getattr(state_to_copy, "__slots__", None)
        if names is None:
            names = list(vars(state_to_copy))
        for name in names:
            setattr(self, name, getattr(state_to_copy, name))

    def get_log_posterior(self, temper: bool = False):
        """Get the posterior probability.

        Args:
            temper (bool, optional): If ``True``, apply tempering to the posterior computation:
                ``betas * log_like + log_prior``, with ``betas`` aligned to the leading axes of
                ``log_like``. Otherwise ``log_like + log_prior``.

        Returns:
            np.ndarray: Log of the posterior probability, shaped like ``log_like``.

        Raises:
            ValueError: ``temper`` is ``True`` but ``betas`` is ``None``.

        """
        if not temper:
            return self.log_like + self.log_prior

        if self.betas is None:
            raise ValueError("betas must be set to compute a tempered log posterior.")

        # append trailing singleton axes so betas broadcast over the leading axes
        # of log_like rather than its trailing walker axis
        betas = self.betas.reshape(
            self.betas.shape + (1,) * (self.log_like.ndim - self.betas.ndim)
        )
        return betas * self.log_like + self.log_prior

    # TODO: implement ``__repr__`` and ``__iter__``; an earlier draft referenced a
    # nonexistent ``self.coords`` attribute and was left disabled.