eryn 1.2.4__py3-none-any.whl → 1.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/backends/backend.py CHANGED
@@ -83,6 +83,7 @@ class Backend(object):
83
83
  nbranches=1,
84
84
  rj=False,
85
85
  moves=None,
86
+ key_order=None,
86
87
  **info,
87
88
  ):
88
89
  """Clear the state of the chain and empty the backend
@@ -106,6 +107,9 @@ class Backend(object):
106
107
  (default: ``False``)
107
108
  moves (list, optional): List of all of the move classes input into the sampler.
108
109
  (default: ``None``)
110
+ key_order (dict, optional): Keys are ``branch_names`` and values are lists of key ordering for each
111
+ branch. For example, ``{"model_0": ["x1", "x2", "x3"]}``.
112
+ (default: ``None``)
109
113
  **info (dict, optional): Any other key-value pairs to be added
110
114
  as attributes to the backend.
111
115
 
@@ -118,6 +122,7 @@ class Backend(object):
118
122
  branch_names=branch_names,
119
123
  rj=rj,
120
124
  moves=moves,
125
+ key_order=key_order,
121
126
  info=info,
122
127
  )
123
128
 
@@ -184,6 +189,7 @@ class Backend(object):
184
189
  self.branch_names = branch_names
185
190
  self.ndims = ndims
186
191
  self.nleaves_max = nleaves_max
192
+ self.key_order = key_order
187
193
 
188
194
  self.iteration = 0
189
195
 
@@ -176,6 +176,7 @@ class HDFBackend(Backend):
176
176
  nbranches=1,
177
177
  rj=False,
178
178
  moves=None,
179
+ key_order=None,
179
180
  **info,
180
181
  ):
181
182
  """Clear the state of the chain and empty the backend
@@ -199,6 +200,9 @@ class HDFBackend(Backend):
199
200
  (default: ``False``)
200
201
  moves (list, optional): List of all of the move classes input into the sampler.
201
202
  (default: ``None``)
203
+ key_order (dict, optional): Keys are ``branch_names`` and values are lists of key ordering for each
204
+ branch. For example, ``{"model_0": ["x1", "x2", "x3"]}``.
205
+ (default: ``None``)
202
206
  **info (dict, optional): Any other key-value pairs to be added
203
207
  as attributes to the backend. These are also added to the HDF5 file.
204
208
 
@@ -343,6 +347,7 @@ class HDFBackend(Backend):
343
347
 
344
348
  chain = g.create_group("chain")
345
349
  inds = g.create_group("inds")
350
+ k_o_g = g.create_group("key_order")
346
351
 
347
352
  for name in branch_names:
348
353
  nleaves = self.nleaves_max[name]
@@ -365,6 +370,9 @@ class HDFBackend(Backend):
365
370
  compression_opts=self.compression_opts,
366
371
  )
367
372
 
373
+ if key_order is not None:
374
+ k_o_g.attrs[name] = key_order[name]
375
+
368
376
  # store move specific information
369
377
  if moves is not None:
370
378
  move_group = g.create_group("moves")
@@ -388,6 +396,12 @@ class HDFBackend(Backend):
388
396
 
389
397
  self.blobs = None
390
398
 
399
+ @property
400
+ def key_order(self):
401
+ """Key order of parameters for each model."""
402
+ with self.open() as f:
403
+ return {key: value for key, value in f[self.name]["key_order"].attrs.items()}
404
+
391
405
  @property
392
406
  def nwalkers(self):
393
407
  """Get nwalkers from h5 file."""
@@ -456,6 +470,7 @@ class HDFBackend(Backend):
456
470
  branch_names=self.branch_names,
457
471
  rj=self.rj,
458
472
  moves=self.moves,
473
+ key_order=self.key_order,
459
474
  )
460
475
 
461
476
  @property
eryn/ensemble.py CHANGED
@@ -13,7 +13,7 @@ from .pbar import get_progress_bar
13
13
  from .state import State
14
14
  from .prior import ProbDistContainer
15
15
 
16
- # from .utils import PlotContainer
16
+ from .utils import PlotContainer
17
17
  from .utils import PeriodicContainer
18
18
  from .utils.utility import groups_from_inds
19
19
 
@@ -198,6 +198,7 @@ class EnsembleSampler(object):
198
198
  if it is changed. In this case, the user should declare a new backend and use the last
199
199
  state from the previous backend. **Warning**: If the order of moves of the same move class
200
200
  is changed, the check may not catch it, so the tracking may mix move acceptance fractions together.
201
+ (default: ``True'')
201
202
  info (dict, optional): Key and value pairs reprenting any information
202
203
  the user wants to add to the backend if the user is not inputing
203
204
  their own backend.
@@ -232,7 +233,7 @@ class EnsembleSampler(object):
232
233
  blobs_dtype=None, # TODO check this
233
234
  plot_iterations=-1, # TODO: do plot stuff?
234
235
  plot_generator=None,
235
- plot_name=None,
236
+ plot_folder=None,
236
237
  periodic=None,
237
238
  update_fn=None,
238
239
  update_iterations=-1,
@@ -597,6 +598,7 @@ class EnsembleSampler(object):
597
598
  nleaves_max=nleaves_max,
598
599
  rj=self.has_reversible_jump,
599
600
  moves=move_keys,
601
+ key_order=self.key_order,
600
602
  **info,
601
603
  )
602
604
  state = np.random.get_state()
@@ -615,6 +617,9 @@ class EnsembleSampler(object):
615
617
  "Configuration of moves has changed. Cannot use the same backend. Declare a new backend and start from the previous state. If you would prefer not to track move acceptance fraction, set track_moves to False in the EnsembleSampler."
616
618
  )
617
619
 
620
+ if self.key_order != self.backend.key_order:
621
+ raise ValueError("Input key order from priors does not match backend.")
622
+
618
623
  # Check the backend shape
619
624
  for i, (name, shape) in enumerate(self.backend.shape.items()):
620
625
  test_shape = (
@@ -657,18 +662,15 @@ class EnsembleSampler(object):
657
662
  self.plot_iterations = plot_iterations
658
663
 
659
664
  if plot_generator is None and self.plot_iterations > 0:
660
- raise NotImplementedError
661
665
  # set to default if not provided
662
- if plot_name is not None:
663
- name = plot_name
664
- else:
665
- name = "output"
666
+ if plot_folder is None:
667
+ plot_folder = "./runtime_plots"
666
668
  self.plot_generator = PlotContainer(
667
- fp=name, backend=self.backend, thin_chain_by_ac=True
669
+ backend=self.backend, plots=['base', 'rj'], parent_folder=plot_folder, discard=0.2
668
670
  )
669
671
  elif self.plot_iterations > 0:
670
- raise NotImplementedError
671
672
  self.plot_generator = plot_generator
673
+ self.plot_generator.backend = self.backend # make sure backend is correctly set
672
674
 
673
675
  # prepare stopping functions
674
676
  self.stopping_fn = stopping_fn
@@ -689,7 +691,7 @@ class EnsembleSampler(object):
689
691
 
690
692
  """
691
693
  return self._random.get_state()
692
-
694
+
693
695
  @random_state.setter # NOQA
694
696
  def random_state(self, state):
695
697
  """
@@ -751,13 +753,14 @@ class EnsembleSampler(object):
751
753
  else:
752
754
  raise ValueError("Priors must be a dictionary.")
753
755
 
756
+ self.key_order = {key: value.key_order for key, value in self._priors.items()}
754
757
  return
755
758
 
756
759
  @property
757
760
  def iteration(self):
758
761
  return self.backend.iteration
759
762
 
760
- def reset(self, **info):
763
+ def reset(self, **kwargs):
761
764
  """
762
765
  Reset the backend.
763
766
 
@@ -765,7 +768,7 @@ class EnsembleSampler(object):
765
768
  **info (dict, optional): information to pass to backend reset method.
766
769
 
767
770
  """
768
- self.backend.reset(self.nwalkers, self.ndims, **info)
771
+ self.backend.reset(self.nwalkers, self.ndims, **kwargs)
769
772
 
770
773
  def __getstate__(self):
771
774
  # In order to be generally picklable, we need to discard the pool
@@ -1101,7 +1104,7 @@ class EnsembleSampler(object):
1101
1104
  # diagnostic plots
1102
1105
  # TODO: adjust diagnostic plots
1103
1106
  if self.plot_iterations > 0 and (i + 1) % (self.plot_iterations) == 0:
1104
- self.plot_generator.generate_plot_info() # TODO: remove defaults
1107
+ self.plot_generator.produce_plots(sampler=self) # TODO: remove defaults
1105
1108
 
1106
1109
  # check for stopping before updating
1107
1110
  if (
eryn/prior.py CHANGED
@@ -248,24 +248,69 @@ class ProbDistContainer:
248
248
  # to separate out in list form
249
249
  self.priors = []
250
250
 
251
+ self.has_strings = False
252
+ self.has_ints = False
253
+
254
+ # this is for the strings (for the ints it just counts them)
255
+ current_ind = 0
256
+ key_order = []
257
+
251
258
  # setup lists
252
259
  temp_inds = []
253
260
  for inds, dist in priors_in.items():
254
261
  # multiple index
255
262
  if isinstance(inds, tuple):
256
- inds_in = np.asarray(inds)
263
+ inds_tmp = []
264
+ for i in range(len(inds)):
265
+ if isinstance(inds[i], str):
266
+ assert not self.has_ints
267
+ self.has_strings = True
268
+ inds_tmp.append(current_ind)
269
+ key_order.append(inds[i])
270
+
271
+ elif isinstance(inds[i], int):
272
+ assert not self.has_strings
273
+ self.has_ints = True
274
+ inds_tmp.append(i)
275
+
276
+ else:
277
+ raise ValueError("Index in tuple must be int or str and all be the same type.")
278
+
279
+ current_ind += 1
280
+
281
+ inds_in = np.asarray(inds_tmp)
257
282
  self.priors.append([inds_in, dist])
258
283
 
259
284
  # single index
260
285
  elif isinstance(inds, int):
286
+ self.has_ints = True
287
+ assert not self.has_strings
261
288
  inds_in = np.array([inds])
262
289
  self.priors.append([inds_in, dist])
290
+ current_ind += 1
291
+
292
+ elif isinstance(inds, str):
293
+ assert not self.has_ints
294
+ self.has_strings = True
295
+ key_order.append(inds)
296
+ inds_in = np.array([current_ind])
297
+ current_ind += 1
298
+ self.priors.append([inds_in, dist])
263
299
 
264
300
  else:
265
301
  raise ValueError(
266
- "Keys for prior dictionary must be an integer or tuple."
302
+ "Keys for prior dictionary must be an integer, string, or tuple."
267
303
  )
268
304
 
305
+ if self.has_strings:
306
+ assert not self.has_ints
307
+ # key order is already set
308
+ self.key_order = key_order
309
+
310
+ if self.has_ints:
311
+ self.key_order = [i for i in range(current_ind)] # here current_ind is the total count
312
+ assert not self.has_strings
313
+
269
314
  temp_inds.append(np.asarray([inds_in]))
270
315
 
271
316
  uni_inds = np.unique(np.concatenate(temp_inds, axis=1).flatten())
eryn/utils/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
- # from .plot import PlotContainer
3
+ from .plot import PlotContainer
4
4
  from .utility import *
5
5
  from .periodic import *
6
6
  from .transform import *
eryn/utils/periodic.py CHANGED
@@ -18,15 +18,31 @@ class PeriodicContainer:
18
18
 
19
19
  """
20
20
 
21
- def __init__(self, periodic):
21
+ def __init__(self, periodic, key_order=None):
22
22
 
23
23
  # store all the information
24
24
  self.periodic = periodic
25
+ inds_periodic = {}
26
+ periods = {}
27
+ for key in periodic:
28
+ if periodic[key] is None:
29
+ continue
30
+ inds_periodic[key] = []
31
+ periods[key] = []
32
+ for var, period in periodic[key].items():
33
+ if isinstance(var, str):
34
+ if key_order is None:
35
+ raise ValueError(f"If providing str values for the variable names, must provide key_order argument.")
36
+
37
+ index = key_order[key].index(var)
38
+ inds_periodic[key].append(index)
39
+ periods[key].append(period)
40
+
25
41
  self.inds_periodic = {
26
- key: np.asarray([i for i in periodic[key].keys()]) for key in periodic
42
+ key: np.asarray(tmp) for key, tmp in inds_periodic.items()
27
43
  }
28
44
  self.periods = {
29
- key: np.asarray([i for i in periodic[key].values()]) for key in periodic
45
+ key: np.asarray(tmp) for key, tmp in periods.items()
30
46
  }
31
47
 
32
48
  def distance(self, p1, p2, xp=None):