eryn-1.2.5-py3-none-any.whl → eryn-1.2.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/ensemble.py +7 -10
- eryn/utils/__init__.py +1 -1
- eryn/utils/plot.py +1393 -0
- eryn/utils/updates.py +106 -0
- eryn/utils/utility.py +1 -1
- {eryn-1.2.5.dist-info → eryn-1.2.6.dist-info}/METADATA +7 -15
- {eryn-1.2.5.dist-info → eryn-1.2.6.dist-info}/RECORD +8 -7
- {eryn-1.2.5.dist-info → eryn-1.2.6.dist-info}/WHEEL +0 -0
eryn/ensemble.py
CHANGED
|
@@ -13,7 +13,7 @@ from .pbar import get_progress_bar
|
|
|
13
13
|
from .state import State
|
|
14
14
|
from .prior import ProbDistContainer
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
from .utils import PlotContainer
|
|
17
17
|
from .utils import PeriodicContainer
|
|
18
18
|
from .utils.utility import groups_from_inds
|
|
19
19
|
|
|
@@ -233,7 +233,7 @@ class EnsembleSampler(object):
|
|
|
233
233
|
blobs_dtype=None, # TODO check this
|
|
234
234
|
plot_iterations=-1, # TODO: do plot stuff?
|
|
235
235
|
plot_generator=None,
|
|
236
|
-
|
|
236
|
+
plot_folder=None,
|
|
237
237
|
periodic=None,
|
|
238
238
|
update_fn=None,
|
|
239
239
|
update_iterations=-1,
|
|
@@ -662,18 +662,15 @@ class EnsembleSampler(object):
|
|
|
662
662
|
self.plot_iterations = plot_iterations
|
|
663
663
|
|
|
664
664
|
if plot_generator is None and self.plot_iterations > 0:
|
|
665
|
-
raise NotImplementedError
|
|
666
665
|
# set to default if not provided
|
|
667
|
-
if
|
|
668
|
-
|
|
669
|
-
else:
|
|
670
|
-
name = "output"
|
|
666
|
+
if plot_folder is None:
|
|
667
|
+
plot_folder = "./runtime_plots"
|
|
671
668
|
self.plot_generator = PlotContainer(
|
|
672
|
-
|
|
669
|
+
backend=self.backend, plots=['base', 'rj'], parent_folder=plot_folder, discard=0.2
|
|
673
670
|
)
|
|
674
671
|
elif self.plot_iterations > 0:
|
|
675
|
-
raise NotImplementedError
|
|
676
672
|
self.plot_generator = plot_generator
|
|
673
|
+
self.plot_generator.backend = self.backend # make sure backend is correctly set
|
|
677
674
|
|
|
678
675
|
# prepare stopping functions
|
|
679
676
|
self.stopping_fn = stopping_fn
|
|
@@ -1107,7 +1104,7 @@ class EnsembleSampler(object):
|
|
|
1107
1104
|
# diagnostic plots
|
|
1108
1105
|
# TODO: adjust diagnostic plots
|
|
1109
1106
|
if self.plot_iterations > 0 and (i + 1) % (self.plot_iterations) == 0:
|
|
1110
|
-
self.plot_generator.
|
|
1107
|
+
self.plot_generator.produce_plots(sampler=self) # TODO: remove defaults
|
|
1111
1108
|
|
|
1112
1109
|
# check for stopping before updating
|
|
1113
1110
|
if (
|