eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/utils/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
# from .plot import PlotContainer
|
|
4
|
+
from .utility import *
|
|
5
|
+
from .periodic import *
|
|
6
|
+
from .transform import *
|
|
7
|
+
from .updates import *
|
|
8
|
+
from .stopping import *
|
|
9
|
+
|
|
10
|
+
# ``PlotContainer`` is intentionally not listed: its import
# (``from .plot import PlotContainer``) is disabled, and naming a missing
# attribute in ``__all__`` makes ``from eryn.utils import *`` raise
# AttributeError.
__all__ = ["PeriodicContainer", "TransformContainer"]
|
eryn/utils/periodic.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from ast import Import
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
import cupy as xp
|
|
6
|
+
|
|
7
|
+
except (ModuleNotFoundError, ImportError) as e:
|
|
8
|
+
pass
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PeriodicContainer:
    """Perform operations for periodic parameters.

    Args:
        periodic (dict): Keys are ``branch_names``. Values are
            dictionaries. These dictionaries have keys as the parameter
            indexes and values their associated period.

    Attributes:
        periodic (dict): The ``periodic`` input, stored as given.
        inds_periodic (dict): Keys are ``branch_names``. Values are integer
            arrays of the periodic parameter indexes.
        periods (dict): Keys are ``branch_names``. Values are arrays of the
            periods, in the same order as ``inds_periodic``.

    """

    def __init__(self, periodic):

        # store all the information
        self.periodic = periodic
        self.inds_periodic = {
            key: np.asarray(list(periodic[key].keys()), dtype=int) for key in periodic
        }
        self.periods = {
            key: np.asarray(list(periodic[key].values())) for key in periodic
        }

    def _periodic_info(self, key, xp):
        """Return ``(periods, inds_periodic)`` as ``xp`` arrays for branch ``key``.

        Returns ``None`` when the branch has no periodic parameters, either
        because it is absent from ``periodic`` or its dictionary is empty.
        """
        periods = self.periods.get(key)
        if periods is None or len(periods) == 0:
            return None
        return xp.asarray(periods), xp.asarray(self.inds_periodic[key])

    def distance(self, p1, p2, xp=None):
        """Move from p1 to p2 with periodic distance control.

        For periodic parameters, a separation larger than half a period is
        replaced by the shorter separation going the other way around
        (shifted by one full period toward zero).

        Args:
            p1 (dict): Keys are ``branch_names``
                and values are positions with parameters along the final dimension.
            p2 (dict): Keys are ``branch_names``
                and values are positions with parameters along the final dimension.
                Must contain the same branches as ``p1``.
            xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
                (default: ``None``)

        Returns:
            dict: Distances ``p2 - p1`` accounting for periodicity.
                Keys are branch names and values are distance arrays.

        """

        # cupy or numpy
        if xp is None:
            xp = np

        # make sure both have same branches (their order does not matter)
        assert set(p1.keys()) == set(p2.keys())

        # prepare output
        out_diff = {}
        for key in p1:
            # get basic distance (a new array, safe to modify in place)
            diff = p2[key] - p1[key]

            info = self._periodic_info(key, xp)
            if info is not None:
                periods, inds_periodic = info

                # parameters live on the final axis, whatever the array rank
                diff_periodic = diff[..., inds_periodic]

                # separations over 1/2 period are shorter the other way around
                too_far = xp.abs(diff_periodic) > periods / 2.0
                wrapped = xp.where(
                    diff_periodic >= 0.0,
                    diff_periodic - periods,
                    diff_periodic + periods,
                )
                diff[..., inds_periodic] = xp.where(too_far, wrapped, diff_periodic)

            out_diff[key] = diff

        return out_diff

    def wrap(self, p, xp=None):
        """Wrap p with periodic distance control.

        Periodic parameters are mapped into ``[0, period)`` with ``%``. The
        arrays in ``p`` are modified in place and ``p`` itself is returned.
        Branches without periodic parameters are left unchanged.

        Args:
            p (dict): Keys are ``branch_names``
                and values are positions with parameters along the final dimension.
            xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
                (default: ``None``)

        Returns:
            dict: ``p`` with its periodic parameters wrapped.

        """

        # cupy or numpy
        if xp is None:
            xp = np

        # wrap for each branch
        for key in list(p.keys()):
            pos = p[key]

            info = self._periodic_info(key, xp)
            if info is not None:
                periods, inds_periodic = info
                pos[..., inds_periodic] = pos[..., inds_periodic] % periods

            # fill new info
            p[key] = pos

        return p
|
eryn/utils/stopping.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Stopping(ABC, object):
    """Base class for stopping.

    Subclasses implement ``__call__(iter, last_sample, sampler)`` and return
    ``True`` when sampling should stop.

    Stopping checks are only performed every ``thin_by`` iterations.

    """

    def __call__(self, iter, last_sample, sampler):
        """Call stopping function.

        Args:
            iter (int): Iteration of the sampler.
            last_sample (obj): Last state of sampler (:class:`eryn.state.State`).
            sampler (obj): Full sampler object (:class:`eryn.ensemble.EnsembleSampler`).

        Returns:
            bool: Value of ``stop``. If ``True``, stop sampling.

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError


class SearchConvergeStopping(Stopping):
    """Stopping function based on convergence to a maximum Likelihood.

    Stopping checks are only performed every ``thin_by`` iterations.
    Therefore, the iterations of stopping checks are really every
    ``sampler iterations * thin_by``.

    All arguments are stored as attributes.

    Args:
        n_iters (int, optional): Number of consecutive stopping checks that need to pass
            in order to stop the sampler. (default: ``30``)
        diff (float, optional): Change in the Likelihood needed to fail the stopping check. In other words,
            if the new maximum Likelihood differs from the reference value
            (``past_like_best``) by ``diff`` or more, all iterative checks reset. (default: 0.1).
        start_iteration (int, optional): Iteration of sampler to start checking to stop. (default: 0)
        verbose (bool, optional): If ``True``, print information. (default: ``False``)

    Attributes:
        iters_consecutive (int): Number of consecutive passes of the stopping check.
        past_like_best (float): Reference best Likelihood that new maxima are compared
            against. It is replaced whenever a check fails. The initial value is ``-np.inf``.

    """

    def __init__(self, n_iters=30, diff=0.1, start_iteration=0, verbose=False):

        # store all the relevant information
        self.n_iters = n_iters

        self.diff = diff
        self.verbose = verbose
        self.start_iteration = start_iteration

        # initialize important info
        self.iters_consecutive = 0
        self.past_like_best = -np.inf

    def __call__(self, iter, sample, sampler):
        """Call stopping function.

        Args:
            iter (int): Iteration of the sampler.
            sample (obj): Last state of sampler (:class:`eryn.state.State`). Not used.
            sampler (obj): Full sampler object (:class:`eryn.ensemble.EnsembleSampler`).

        Returns:
            bool: Value of ``stop``. If ``True``, stop sampling.

        """

        # if we have not reached the start iteration return
        if iter < self.start_iteration:
            return False

        # get best Likelihood so far
        like_best = sampler.get_log_like(discard=self.start_iteration).max()

        # keep the value we compare against so the report below is accurate
        previous_best = self.past_like_best

        # if it is less than diff change it passes
        if np.abs(like_best - previous_best) < self.diff:
            self.iters_consecutive += 1

        else:
            # if it fails reset iters consecutive
            self.iters_consecutive = 0

            # store new best (the reference only moves when a check fails)
            self.past_like_best = like_best

        # print information
        if self.verbose:
            print(
                f"\nITERS CONSECUTIVE: {self.iters_consecutive}",
                f"Previous best LL: {previous_best}",
                f"Current best LL: {like_best}\n",
            )

        if self.iters_consecutive >= self.n_iters:
            # enough consecutive passes: reset and stop
            self.iters_consecutive = 0
            return True

        return False
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
"""
|
|
118
|
+
class AutoCorrelationStop(Stopping):
|
|
119
|
+
# TODO: check and doc this
|
|
120
|
+
def __init__(self, autocorr_multiplier=50, verbose=False):
|
|
121
|
+
self.autocorr_multiplier = autocorr_multiplier
|
|
122
|
+
self.verbose = verbose
|
|
123
|
+
|
|
124
|
+
self.time = 0
|
|
125
|
+
|
|
126
|
+
def __call__(self, iter, last_sample, sampler):
|
|
127
|
+
|
|
128
|
+
tau = sampler.backend.get_autocorr_time(multiply_thin=False)
|
|
129
|
+
|
|
130
|
+
if self.time > 0:
|
|
131
|
+
# backend iteration
|
|
132
|
+
iteration = sampler.backend.iteration
|
|
133
|
+
|
|
134
|
+
finish = []
|
|
135
|
+
|
|
136
|
+
for name, values in tau.items():
|
|
137
|
+
converged = np.all(tau[name] * self.autocorr_multiplier < iteration)
|
|
138
|
+
converged &= np.all(
|
|
139
|
+
np.abs(self.old_tau[name] - tau[name]) / tau[name] < 0.01
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
finish.append(converged)
|
|
143
|
+
|
|
144
|
+
stop = True if np.all(finish) else False
|
|
145
|
+
if self.verbose:
|
|
146
|
+
print(
|
|
147
|
+
"\ntau:",
|
|
148
|
+
tau,
|
|
149
|
+
"\nIteration:",
|
|
150
|
+
iteration,
|
|
151
|
+
"\nAutocorrelation multiplier:",
|
|
152
|
+
self.autocorr_multiplier,
|
|
153
|
+
"\nStopping:",
|
|
154
|
+
stop,
|
|
155
|
+
"\n",
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
else:
|
|
159
|
+
stop = False
|
|
160
|
+
|
|
161
|
+
self.old_tau = tau
|
|
162
|
+
self.time += 1
|
|
163
|
+
return stop
|
|
164
|
+
"""
|
eryn/utils/transform.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
try:
|
|
2
|
+
import cupy as xp
|
|
3
|
+
|
|
4
|
+
except (ModuleNotFoundError, ImportError) as e:
|
|
5
|
+
pass
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TransformContainer:
    """Container for helpful transformations

    Args:
        parameter_transforms (dict, optional): Keys are ``int`` or ``tuple``
            of ``int`` that contain the indexes into the parameters
            that correspond to the transformation added as the Values to the
            dict. If using ``fill_values``, you must be careful with
            making sure parameter transforms properly comes before or after
            filling values. ``int`` indicate single parameter transforms. These
            are performed first. ``tuple`` of ``int`` indicates multiple
            parameter transforms. These are performed after single-parameter transforms.
            (default: ``None``)
        fill_dict (dict, optional): Keys must contain ``'ndim_full'``, ``'fill_inds'``,
            and ``'fill_values'``. ``'ndim_full'`` is the full last dimension of the final
            array after fill_values are added. 'fill_inds' and 'fill_values' are
            np.ndarray[number of fill values] that contain the indexes and corresponding values
            for filling. A shallow copy is stored, so the caller's dictionary is not
            modified. (default: ``None``)

    Raises:
        ValueError: Input information is not correct.

    """

    def __init__(self, parameter_transforms=None, fill_dict=None):

        # store originals
        self.original_parameter_transforms = parameter_transforms
        if parameter_transforms is not None:
            # differentiate between single and multi parameter transformations
            self.base_transforms = {"single_param": {}, "mult_param": {}}

            # iterate through transforms and setup single and multiparameter transforms
            for key, item in parameter_transforms.items():
                # numpy integers are valid single-parameter indexes too
                if isinstance(key, (int, np.integer)):
                    self.base_transforms["single_param"][key] = item
                elif isinstance(key, tuple):
                    self.base_transforms["mult_param"][key] = item
                else:
                    raise ValueError(
                        "Parameter transform keys must be int or tuple of ints. {} is neither.".format(
                            key
                        )
                    )
        else:
            self.base_transforms = None

        if fill_dict is not None:
            if not isinstance(fill_dict, dict):
                raise ValueError("fill_dict must be a dictionary.")

            # shallow copy so adding "test_inds" below does not alter the caller's dict
            self.fill_dict = dict(fill_dict)
            for key in ["ndim_full", "fill_inds", "fill_values"]:
                # check to make sure it has all necessary pieces
                if key not in self.fill_dict:
                    raise ValueError(
                        f"If providing fill_dict, dictionary must have {key} as a key."
                    )
            # check all the inputs
            if not isinstance(self.fill_dict["ndim_full"], (int, np.integer)):
                raise ValueError("fill_dict['ndim_full'] must be an int.")
            if not isinstance(self.fill_dict["fill_inds"], np.ndarray):
                raise ValueError("fill_dict['fill_inds'] must be an np.ndarray.")
            if not isinstance(self.fill_dict["fill_values"], np.ndarray):
                raise ValueError("fill_dict['fill_values'] must be an np.ndarray.")

            # indexes of the full array that are filled from the input params
            self.fill_dict["test_inds"] = np.delete(
                np.arange(self.fill_dict["ndim_full"]), self.fill_dict["fill_inds"]
            )

        else:
            self.fill_dict = None

    def transform_base_parameters(
        self, params, copy=True, return_transpose=False, xp=None
    ):
        """Transform the base parameters

        Args:
            params (np.ndarray[..., ndim]): Array with coordinates. This array is
                transformed according to the ``self.base_transforms`` dictionary.
            copy (bool, optional): If True, copy the input array, even when there
                are no transforms. If False, ``params`` may be modified in place.
                (default: ``True``)
            return_transpose (bool, optional): If True, return the transpose of the
                array. (default: ``False``)
            xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
                (default: ``None``)

        Returns:
            np.ndarray[..., ndim]: Transformed ``params`` array.

        """

        # cupy or numpy
        if xp is None:
            xp = np

        params_temp = params.copy() if copy else params

        if self.base_transforms is None:
            return params_temp.T if return_transpose else params_temp

        # work on the transpose (a view) so params_temp[i] selects parameter i
        params_temp = params_temp.T

        # regular transforms
        for ind, trans_fn in self.base_transforms["single_param"].items():
            params_temp[ind] = trans_fn(params_temp[ind])

        # multi parameter transforms
        for inds, trans_fn in self.base_transforms["mult_param"].items():
            # pass copies so outputs cannot alias rows overwritten below
            # (e.g. a swap ``lambda a, b: (b, a)`` would otherwise duplicate a row)
            temp = trans_fn(*[params_temp[i].copy() for i in inds])
            for j, i in enumerate(inds):
                params_temp[i] = temp[j]

        # params_temp is currently transposed relative to the input
        return params_temp if return_transpose else params_temp.T

    def fill_values(self, params, xp=None):
        """fill fixed parameters

        Args:
            params (np.ndarray[..., ndim]): Array with coordinates. This array is
                filled with values according to the ``self.fill_dict`` dictionary.
            xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
                (default: ``None``)

        Returns:
            np.ndarray[..., ndim_full]: Filled ``params`` array (a new array), or
                ``params`` itself when no ``fill_dict`` was given.

        """
        if self.fill_dict is None:
            return params

        if xp is None:
            xp = np

        # setup new array to fill
        params_filled = xp.zeros(params.shape[:-1] + (self.fill_dict["ndim_full"],))

        # non-fixed parameters go to test_inds along the last axis
        test_inds = xp.asarray(self.fill_dict["test_inds"])
        params_filled[..., test_inds] = params

        # fixed values go to fill_inds along the last axis
        fill_inds = xp.asarray(self.fill_dict["fill_inds"])
        params_filled[..., fill_inds] = xp.asarray(self.fill_dict["fill_values"])

        return params_filled

    def both_transforms(
        self, params, copy=True, return_transpose=False, reverse=False, xp=None
    ):
        """Transform the parameters and fill fixed parameters

        This fills the fixed parameters and then transforms all of them. Therefore, the user
        must be careful with the indexes input.

        This is generally the direction recommended because fixed parameters may change
        non-fixed parameters during parameter transformations. This can be reversed
        with the ``reverse`` kwarg.

        Args:
            params (np.ndarray[..., ndim]): Array with coordinates. This array is
                transformed according to the ``self.base_transforms`` dictionary.
            copy (bool, optional): If True, copy the input array.
                (default: ``True``)
            return_transpose (bool, optional): If ``True``, return the transpose of the
                array. (default: ``False``)
            reverse (bool, optional): If ``True`` perform the filling after the transforms. This makes
                indexing easier, but removes the ability of fixed parameters to affect transforms.
                (default: ``False``)
            xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
                (default: ``None``)

        Returns:
            np.ndarray[..., ndim_full]: Transformed and filled ``params`` array.

        """
        # numpy or cupy
        if xp is None:
            xp = np

        if reverse:
            # transform first; keep parameters on the last axis so filling
            # happens along the right axis, and transpose only at the end
            temp = self.transform_base_parameters(
                params, copy=copy, return_transpose=False, xp=xp
            )
            temp = self.fill_values(temp, xp=xp)
            return temp.T if return_transpose else temp

        temp = self.fill_values(params, xp=xp)
        return self.transform_base_parameters(
            temp, copy=copy, return_transpose=return_transpose, xp=xp
        )
|
eryn/utils/updates.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Update(ABC, object):
    """Update the sampler.

    Subclasses implement ``__call__(iter, last_sample, sampler)``.
    """

    def __call__(self, iter, last_sample, sampler):
        """Call update function.

        Args:
            iter (int): Iteration of the sampler.
            last_sample (obj): Last state of sampler (:class:`eryn.state.State`).
            sampler (obj): Full sampler object (:class:`eryn.ensemble.EnsembleSampler`).

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        """
        raise NotImplementedError


class AdjustStretchProposalScale(Update):
    """Adjust the scale ``a`` of the first move based on the recent acceptance rate.

    On every call after the first, the mean acceptance fraction since the
    previous call is computed from ``sampler.backend.accepted[:, 0]``. The
    scale ``sampler._moves[0].a`` is then multiplied up when the fraction is
    above ``target_acceptance`` and down otherwise.

    Args:
        target_acceptance (float, optional): Desired acceptance fraction. (default: ``0.22``)
        supression_factor (float, optional): Damping applied to each adjustment.
            (default: ``0.1``)
        max_change (float, optional): Cap on the adjustment ``factor``. (default: ``0.5``)
        verbose (bool, optional): If ``True``, print diagnostic information on each call.
            (default: ``False``)

    Attributes:
        time (int): Number of times this update has been called.
        previously_accepted (np.ndarray): Copy of ``accepted[:, 0]`` from the last call.
        previous_iter (int): ``sampler.backend.iteration`` from the last call.

    """

    def __init__(
        self,
        target_acceptance=0.22,
        supression_factor=0.1,
        max_change=0.5,
        verbose=False,
    ):
        """Adjusted scale for stretch proposal based on cold chain acceptance rate"""
        self.target_acceptance = target_acceptance
        self.verbose = verbose
        self.max_change, self.supression_factor = max_change, supression_factor

        self.time = 0

    def __call__(self, iter, last_sample, sampler):
        """Adjust ``sampler._moves[0].a`` from the acceptance rate since the last call.

        Args:
            iter (int): Iteration of the sampler. Not used.
            last_sample (obj): Last state of sampler. Not used.
            sampler (obj): Full sampler object; reads ``backend.accepted``,
                ``backend.iteration`` and updates ``_moves[0].a``.

        """

        mean_af = 0.0
        change = 1.0
        if self.time > 0:
            # cold chain -> 0
            # NOTE(review): column 0 of ``accepted`` is treated as the cold chain;
            # confirm this matches the backend's axis order for ``accepted``.
            mean_af = np.mean(
                (sampler.backend.accepted[:, 0] - self.previously_accepted)
                / (sampler.backend.iteration - self.previous_iter)
            )

            if mean_af > self.target_acceptance:
                factor = self.supression_factor * (mean_af / self.target_acceptance)
                if factor > self.max_change:
                    factor = self.max_change
                # NOTE(review): ``factor`` already includes ``supression_factor``, so it
                # is applied twice on this branch but once on the decreasing branch.
                # Confirm this asymmetry is intended.
                change = 1 + self.supression_factor * factor

            else:
                factor = self.supression_factor * (self.target_acceptance / mean_af)
                if factor > self.max_change:
                    factor = self.max_change
                change = 1 - factor

            sampler._moves[0].a *= change

        self.previously_accepted = sampler.backend.accepted[:, 0].copy()

        # diagnostics are only printed when requested
        if self.verbose:
            print(
                self.previously_accepted, "\n", mean_af, change, "\n", sampler._moves[0].a
            )

        self.previous_iter = sampler.backend.iteration
        self.time += 1
|