eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # -*- coding: utf-8 -*-
+
+ # from .plot import PlotContainer
+ from .utility import *
+ from .periodic import *
+ from .transform import *
+ from .updates import *
+ from .stopping import *
+
+ # PlotContainer is omitted here while its import above stays commented out,
+ # so that ``from eryn.utils import *`` does not raise an AttributeError.
+ __all__ = ["PeriodicContainer", "TransformContainer"]
eryn/utils/periodic.py ADDED
@@ -0,0 +1,134 @@
+ import numpy as np
+
+ try:
+     import cupy as xp
+
+ except (ModuleNotFoundError, ImportError):
+     pass
+
+
+ class PeriodicContainer:
+     """Perform operations for periodic parameters.
+
+     Args:
+         periodic (dict): Keys are ``branch_names``. Values are
+             dictionaries whose keys are parameter indexes and whose
+             values are the associated periods.
+
+     """
+
+     def __init__(self, periodic):
+
+         # store all the information
+         self.periodic = periodic
+         self.inds_periodic = {
+             key: np.asarray([i for i in periodic[key].keys()]) for key in periodic
+         }
+         self.periods = {
+             key: np.asarray([i for i in periodic[key].values()]) for key in periodic
+         }
+
+     def distance(self, p1, p2, xp=None):
+         """Distance from p1 to p2 with periodic wrapping.
+
+         Args:
+             p1 (dict): Keys are ``branch_names`` and values are positions
+                 with parameters along the final dimension.
+             p2 (dict): Keys are ``branch_names`` and values are positions
+                 with parameters along the final dimension.
+             xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
+                 (default: ``None``)
+
+         Returns:
+             dict: Distances accounting for periodicity.
+                 Keys are branch names and values are distance arrays.
+
+         """
+
+         # cupy or numpy
+         if xp is None:
+             xp = np
+
+         # make sure both have the same branches
+         assert list(p1.keys()) == list(p2.keys())
+
+         names = list(p1.keys())
+
+         # prepare output
+         out_diff = {}
+         for key in names:
+
+             # get basic distance
+             diff = p2[key] - p1[key]
+
+             # no periodic parameters for this key
+             if key not in self.periods:
+                 out_diff[key] = diff
+                 continue
+
+             # get period info
+             periods = xp.asarray(self.periods[key])
+             inds_periodic = xp.asarray(self.inds_periodic[key])
+
+             if len(self.periods[key]) > 0:
+                 # get the specific periodic parameters
+                 diff_periodic = diff[:, :, inds_periodic]
+
+                 # fix entries where the distance is more than 1/2 period away
+                 inds_fix = (
+                     xp.abs(diff_periodic) > periods[xp.newaxis, xp.newaxis, :] / 2.0
+                 )
+
+                 # wrap back to get the proper periodic distance
+                 new_s = -(
+                     periods[xp.newaxis, xp.newaxis, :] - p1[key][:, :, inds_periodic]
+                 ) * (diff_periodic < 0.0) + (
+                     periods[xp.newaxis, xp.newaxis, :] + p1[key][:, :, inds_periodic]
+                 ) * (
+                     diff_periodic >= 0.0
+                 )
+
+                 # fill in the corrected distances
+                 diff_periodic[inds_fix] = (
+                     p2[key][:, :, inds_periodic][inds_fix] - new_s[inds_fix]
+                 )
+                 diff[:, :, inds_periodic] = diff_periodic
+
+             out_diff[key] = diff
+
+         return out_diff
+
+     def wrap(self, p, xp=None):
+         """Wrap periodic parameters back into their range.
+
+         Args:
+             p (dict): Keys are ``branch_names`` and values are positions
+                 with parameters along the final dimension.
+             xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
+                 (default: ``None``)
+
+         Returns:
+             dict: ``p`` with periodic parameters wrapped into ``[0, period)``.
+
+         """
+
+         # cupy or numpy
+         if xp is None:
+             xp = np
+
+         names = list(p.keys())
+         # wrap for each branch
+         for key in names:
+             pos = p[key]
+
+             if len(self.periods[key]) > 0:
+                 # get periodic information
+                 periods = xp.asarray(self.periods[key])
+                 inds_periodic = xp.asarray(self.inds_periodic[key])
+                 # wrap
+                 pos[:, :, inds_periodic] = (
+                     pos[:, :, inds_periodic] % periods[xp.newaxis, xp.newaxis, :]
+                 )
+
+             # fill new info
+             p[key] = pos
+
+         return p
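A minimal usage sketch of ``PeriodicContainer`` (not part of the package; the branch name, parameter values, and (ntemps, nwalkers, ndim)-shaped arrays below are illustrative assumptions): parameter 2 of the hypothetical branch ``"model_0"`` is treated as an angle with period 2*pi, so a naive difference of roughly +2*pi - 0.2 becomes a periodic distance of -0.2, and ``wrap`` maps out-of-range angles back into [0, 2*pi).

import numpy as np
from eryn.utils.periodic import PeriodicContainer

# parameter index 2 has period 2*pi for branch "model_0" (illustrative values)
periodic = PeriodicContainer({"model_0": {2: 2 * np.pi}})

# positions shaped (ntemps, nwalkers, ndim) = (1, 1, 3)
p1 = {"model_0": np.array([[[0.5, 1.0, 0.1]]])}
p2 = {"model_0": np.array([[[0.7, 1.2, 2 * np.pi - 0.1]]])}

dist = periodic.distance(p1, p2)
print(dist["model_0"][0, 0])  # approximately [ 0.2  0.2 -0.2 ]: the angle difference is wrapped

# wrap positions that have drifted outside [0, period)
p3 = {"model_0": np.array([[[0.5, 1.0, 2 * np.pi + 0.3]]])}
wrapped = periodic.wrap(p3)
print(wrapped["model_0"][0, 0, 2])  # ~0.3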
eryn/utils/stopping.py ADDED
@@ -0,0 +1,164 @@
+ # -*- coding: utf-8 -*-
+
+ from abc import ABC
+
+ import numpy as np
+
+
+ class Stopping(ABC):
+     """Base class for stopping functions.
+
+     Stopping checks are only performed every ``thin_by`` iterations.
+
+     """
+
+     def __call__(self, iter, last_sample, sampler):
+         """Call the stopping function.
+
+         Args:
+             iter (int): Iteration of the sampler.
+             last_sample (obj): Last state of the sampler (:class:`eryn.state.State`).
+             sampler (obj): Full sampler object (:class:`eryn.ensemble.EnsembleSampler`).
+
+         Returns:
+             bool: Value of ``stop``. If ``True``, stop sampling.
+
+         """
+         raise NotImplementedError
+
+
+ class SearchConvergeStopping(Stopping):
+     """Stopping function based on convergence of the maximum Likelihood.
+
+     Stopping checks are only performed every ``thin_by`` iterations.
+     Therefore, ``n_iters`` consecutive checks correspond to
+     ``n_iters * thin_by`` iterations of the underlying sampler.
+
+     All arguments are stored as attributes.
+
+     Args:
+         n_iters (int, optional): Number of consecutive stopping checks that need to pass
+             in order to stop the sampler. (default: ``30``)
+         diff (float, optional): Change in the Likelihood needed to fail the stopping check.
+             In other words, if the new maximum Likelihood is more than ``diff`` greater
+             than the old one, all consecutive checks reset. (default: ``0.1``)
+         start_iteration (int, optional): Iteration of the sampler at which to start
+             checking whether to stop. (default: ``0``)
+         verbose (bool, optional): If ``True``, print information. (default: ``False``)
+
+     Attributes:
+         iters_consecutive (int): Number of consecutive passes of the stopping check.
+         past_like_best (float): Previous best Likelihood. The initial value is ``-np.inf``.
+
+     """
+
+     def __init__(self, n_iters=30, diff=0.1, start_iteration=0, verbose=False):
+
+         # store all the relevant information
+         self.n_iters = n_iters
+
+         self.diff = diff
+         self.verbose = verbose
+         self.start_iteration = start_iteration
+
+         # initialize important info
+         self.iters_consecutive = 0
+         self.past_like_best = -np.inf
+
+     def __call__(self, iter, last_sample, sampler):
+         """Call the stopping function.
+
+         Args:
+             iter (int): Iteration of the sampler.
+             last_sample (obj): Last state of the sampler (:class:`eryn.state.State`).
+             sampler (obj): Full sampler object (:class:`eryn.ensemble.EnsembleSampler`).
+
+         Returns:
+             bool: Value of ``stop``. If ``True``, stop sampling.
+
+         """
+
+         # if we have not reached the start iteration, do not stop
+         if iter < self.start_iteration:
+             return False
+
+         # get the best Likelihood so far
+         like_best = sampler.get_log_like(discard=self.start_iteration).max()
+
+         # compare to the last best value
+         # if the change is less than ``diff``, the check passes
+         if np.abs(like_best - self.past_like_best) < self.diff:
+             self.iters_consecutive += 1
+
+         else:
+             # if it fails, reset the consecutive counter
+             self.iters_consecutive = 0
+
+             # store new best
+             self.past_like_best = like_best
+
+         # print information
+         if self.verbose:
+             print(
+                 f"\nITERS CONSECUTIVE: {self.iters_consecutive}",
+                 f"Previous best LL: {self.past_like_best}",
+                 f"Current best LL: {like_best}\n",
+             )
+
+         if self.iters_consecutive >= self.n_iters:
+             # if we have passed the required number of checks, return True and reset
+             self.iters_consecutive = 0
+             return True
+
+         else:
+             return False
+
+
+ """
+ class AutoCorrelationStop(Stopping):
+     # TODO: check and doc this
+     def __init__(self, autocorr_multiplier=50, verbose=False):
+         self.autocorr_multiplier = autocorr_multiplier
+         self.verbose = verbose
+
+         self.time = 0
+
+     def __call__(self, iter, last_sample, sampler):
+
+         tau = sampler.backend.get_autocorr_time(multiply_thin=False)
+
+         if self.time > 0:
+             # backend iteration
+             iteration = sampler.backend.iteration
+
+             finish = []
+
+             for name, values in tau.items():
+                 converged = np.all(tau[name] * self.autocorr_multiplier < iteration)
+                 converged &= np.all(
+                     np.abs(self.old_tau[name] - tau[name]) / tau[name] < 0.01
+                 )
+
+                 finish.append(converged)
+
+             stop = True if np.all(finish) else False
+             if self.verbose:
+                 print(
+                     "\ntau:",
+                     tau,
+                     "\nIteration:",
+                     iteration,
+                     "\nAutocorrelation multiplier:",
+                     self.autocorr_multiplier,
+                     "\nStopping:",
+                     stop,
+                     "\n",
+                 )
+
+         else:
+             stop = False
+
+         self.old_tau = tau
+         self.time += 1
+         return stop
+ """
eryn/utils/transform.py ADDED
@@ -0,0 +1,226 @@
+ try:
+     import cupy as xp
+
+ except (ModuleNotFoundError, ImportError):
+     pass
+
+ import numpy as np
+
+
+ class TransformContainer:
+     """Container for helpful transformations.
+
+     Args:
+         parameter_transforms (dict, optional): Keys are ``int`` or ``tuple``
+             of ``int`` giving the indexes of the parameters that the
+             transformation stored as the corresponding value applies to.
+             If using ``fill_values``, be careful to apply parameter transforms
+             properly before or after filling values. ``int`` keys indicate
+             single-parameter transforms; these are performed first. ``tuple``
+             of ``int`` keys indicate multi-parameter transforms; these are
+             performed after the single-parameter transforms.
+             (default: ``None``)
+         fill_dict (dict, optional): Keys must contain ``'ndim_full'``, ``'fill_inds'``,
+             and ``'fill_values'``. ``'ndim_full'`` is the full last dimension of the final
+             array after fill values are added. ``'fill_inds'`` and ``'fill_values'`` are
+             np.ndarray[number of fill values] containing the indexes and corresponding
+             values used for filling. (default: ``None``)
+
+     Raises:
+         ValueError: Input information is not correct.
+
+     """
+
+     def __init__(self, parameter_transforms=None, fill_dict=None):
+
+         # store originals
+         self.original_parameter_transforms = parameter_transforms
+         if parameter_transforms is not None:
+             # differentiate between single- and multi-parameter transformations
+             self.base_transforms = {"single_param": {}, "mult_param": {}}
+
+             # iterate through transforms and set up single- and multi-parameter transforms
+             for key, item in parameter_transforms.items():
+                 if isinstance(key, int):
+                     self.base_transforms["single_param"][key] = item
+                 elif isinstance(key, tuple):
+                     self.base_transforms["mult_param"][key] = item
+                 else:
+                     raise ValueError(
+                         "Parameter transform keys must be int or tuple of ints. {} is neither.".format(
+                             key
+                         )
+                     )
+         else:
+             self.base_transforms = None
+
+         if fill_dict is not None:
+             if not isinstance(fill_dict, dict):
+                 raise ValueError("fill_dict must be a dictionary.")
+
+             self.fill_dict = fill_dict
+             fill_dict_keys = list(self.fill_dict.keys())
+             for key in ["ndim_full", "fill_inds", "fill_values"]:
+                 # check to make sure it has all the necessary pieces
+                 if key not in fill_dict_keys:
+                     raise ValueError(
+                         f"If providing fill_dict, the dictionary must have {key} as a key."
+                     )
+             # check all the inputs
+             if not isinstance(fill_dict["ndim_full"], int):
+                 raise ValueError("fill_dict['ndim_full'] must be an int.")
+             if not isinstance(fill_dict["fill_inds"], np.ndarray):
+                 raise ValueError("fill_dict['fill_inds'] must be an np.ndarray.")
+             if not isinstance(fill_dict["fill_values"], np.ndarray):
+                 raise ValueError("fill_dict['fill_values'] must be an np.ndarray.")
+
+             # set up test_inds accordingly
+             self.fill_dict["test_inds"] = np.delete(
+                 np.arange(self.fill_dict["ndim_full"]), self.fill_dict["fill_inds"]
+             )
+
+         else:
+             self.fill_dict = None
+
+     def transform_base_parameters(
+         self, params, copy=True, return_transpose=False, xp=None
+     ):
+         """Transform the base parameters.
+
+         Args:
+             params (np.ndarray[..., ndim]): Array with coordinates. This array is
+                 transformed according to the ``self.base_transforms`` dictionary.
+             copy (bool, optional): If ``True``, copy the input array.
+                 (default: ``True``)
+             return_transpose (bool, optional): If ``True``, return the transpose of the
+                 array. (default: ``False``)
+             xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
+                 (default: ``None``)
+
+         Returns:
+             np.ndarray[..., ndim]: Transformed ``params`` array.
+
+         """
+
+         # cupy or numpy
+         if xp is None:
+             xp = np
+
+         if self.base_transforms is not None:
+             params_temp = params.copy() if copy else params
+             params_temp = params_temp.T
+             # single-parameter transforms
+             for ind, trans_fn in self.base_transforms["single_param"].items():
+                 params_temp[ind] = trans_fn(params_temp[ind])
+
+             # multi-parameter transforms
+             for inds, trans_fn in self.base_transforms["mult_param"].items():
+                 temp = trans_fn(*[params_temp[i] for i in inds])
+                 for j, i in enumerate(inds):
+                     params_temp[i] = temp[j]
+
+             # the array is currently transposed, so undo that unless the transpose is requested
+             if return_transpose:
+                 return params_temp
+             else:
+                 return params_temp.T
+
+         else:
+             if return_transpose:
+                 return params.T
+             else:
+                 return params
+
+     def fill_values(self, params, xp=None):
+         """Fill fixed parameters.
+
+         Args:
+             params (np.ndarray[..., ndim]): Array with coordinates. This array is
+                 filled with values according to the ``self.fill_dict`` dictionary.
+             xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
+                 (default: ``None``)
+
+         Returns:
+             np.ndarray[..., ndim_full]: Filled ``params`` array.
+
+         """
+         if self.fill_dict is not None:
+             if xp is None:
+                 xp = np
+
+             # get shape
+             shape = params.shape
+
+             # set up a new array to fill
+             params_filled = xp.zeros(shape[:-1] + (self.fill_dict["ndim_full"],))
+             test_inds = xp.asarray(self.fill_dict["test_inds"])
+             # special indexing to properly fill the array with params
+             indexing_test_inds = tuple([slice(0, temp) for temp in shape[:-1]]) + (
+                 test_inds,
+             )
+
+             # fill values directly from the params array
+             params_filled[indexing_test_inds] = params
+
+             fill_inds = xp.asarray(self.fill_dict["fill_inds"])
+             # special indexing to fill fill_values
+             indexing_fill_inds = tuple([slice(0, temp) for temp in shape[:-1]]) + (
+                 fill_inds,
+             )
+
+             # add fill_values at fill_inds
+             params_filled[indexing_fill_inds] = xp.asarray(
+                 self.fill_dict["fill_values"]
+             )
+
+             return params_filled
+
+         else:
+             return params
+
+     def both_transforms(
+         self, params, copy=True, return_transpose=False, reverse=False, xp=None
+     ):
+         """Fill fixed parameters and transform the parameters.
+
+         By default this fills the fixed parameters and then transforms all of them,
+         so the user must be careful with the indexes input.
+
+         Filling first is generally recommended because fixed (filled) parameters may
+         enter the transformations of non-fixed parameters. This order can be reversed
+         with the ``reverse`` kwarg.
+
+         Args:
+             params (np.ndarray[..., ndim]): Array with coordinates. This array is
+                 transformed according to the ``self.base_transforms`` dictionary.
+             copy (bool, optional): If ``True``, copy the input array.
+                 (default: ``True``)
+             return_transpose (bool, optional): If ``True``, return the transpose of the
+                 array. (default: ``False``)
+             reverse (bool, optional): If ``True``, perform the filling after the transforms.
+                 This makes indexing easier, but removes the ability of fixed parameters to
+                 affect the transforms. (default: ``False``)
+             xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
+                 (default: ``None``)
+
+         Returns:
+             np.ndarray[..., ndim_full]: Transformed and filled ``params`` array.
+
+         """
+         # numpy or cupy
+         if xp is None:
+             xp = np
+
+         if reverse:
+             # run the transforms first, then fill
+             temp = self.transform_base_parameters(
+                 params, copy=copy, return_transpose=return_transpose, xp=xp
+             )
+             temp = self.fill_values(temp, xp=xp)
+
+         else:
+             # fill first, then run the transforms
+             temp = self.fill_values(params, xp=xp)
+             temp = self.transform_base_parameters(
+                 temp, copy=copy, return_transpose=return_transpose, xp=xp
+             )
+         return temp
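A short sketch of how ``TransformContainer`` is meant to be used (the parameter indexes, fill value, and transform functions below are illustrative assumptions, not from the package): the sampled space has two parameters, the full model space has three with parameter 1 fixed at 0.5, and ``both_transforms`` fills first and then applies the transforms to the full array.

import numpy as np
from eryn.utils.transform import TransformContainer

# full space has 3 parameters; index 1 is fixed at 0.5 (illustrative)
fill_dict = {
    "ndim_full": 3,
    "fill_inds": np.array([1]),
    "fill_values": np.array([0.5]),
}

# indexes refer to the *full* array because filling happens before transforming
parameter_transforms = {
    0: np.exp,                       # single-parameter transform, applied first
    (0, 2): lambda x, z: (z, x),     # multi-parameter transform: swap parameters 0 and 2
}

tc = TransformContainer(parameter_transforms=parameter_transforms, fill_dict=fill_dict)

coords = np.array([[0.0, 3.0]])      # shape (..., ndim) in the sampled space
out = tc.both_transforms(coords)     # fill to ndim_full, then transform
print(out)                           # [[3.  0.5 1. ]]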
eryn/utils/updates.py ADDED
@@ -0,0 +1,69 @@
+ # -*- coding: utf-8 -*-
+
+ from abc import ABC
+
+ import numpy as np
+
+
+ class Update(ABC):
+     """Update the sampler."""
+
+     def __call__(self, iter, last_sample, sampler):
+         """Call the update function.
+
+         Args:
+             iter (int): Iteration of the sampler.
+             last_sample (obj): Last state of the sampler (:class:`eryn.state.State`).
+             sampler (obj): Full sampler object (:class:`eryn.ensemble.EnsembleSampler`).
+
+         """
+         raise NotImplementedError
+
+
+ class AdjustStretchProposalScale(Update):
+     """Adjust the scale of the stretch proposal based on the cold-chain acceptance rate."""
+
+     def __init__(
+         self,
+         target_acceptance=0.22,
+         supression_factor=0.1,
+         max_change=0.5,
+         verbose=False,
+     ):
+         self.target_acceptance = target_acceptance
+         self.verbose = verbose
+         self.max_change, self.supression_factor = max_change, supression_factor
+
+         self.time = 0
+
+     def __call__(self, iter, last_sample, sampler):
+
+         mean_af = 0.0
+         change = 1.0
+         if self.time > 0:
+             # cold chain -> 0
+             # mean acceptance fraction since the last update
+             mean_af = np.mean(
+                 (sampler.backend.accepted[:, 0] - self.previously_accepted)
+                 / (sampler.backend.iteration - self.previous_iter)
+             )
+
+             if mean_af > self.target_acceptance:
+                 # acceptance too high: increase the stretch scale ``a`` to lower it
+                 factor = self.supression_factor * (mean_af / self.target_acceptance)
+                 if factor > self.max_change:
+                     factor = self.max_change
+                 change = 1 + self.supression_factor * factor
+
+             else:
+                 # acceptance too low: decrease ``a`` to raise it
+                 factor = self.supression_factor * (self.target_acceptance / mean_af)
+                 if factor > self.max_change:
+                     factor = self.max_change
+                 change = 1 - factor
+
+         sampler._moves[0].a *= change
+
+         # store the current counts for the next update
+         self.previously_accepted = sampler.backend.accepted[:, 0].copy()
+         if self.verbose:
+             print(
+                 self.previously_accepted, "\n", mean_af, change, "\n", sampler._moves[0].a
+             )
+         self.previous_iter = sampler.backend.iteration
+         self.time += 1
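A rough sketch of what ``AdjustStretchProposalScale`` does to the stretch scale (the ``SimpleNamespace`` stand-ins below are illustrative assumptions and only expose the attributes the update touches, ``backend.accepted``, ``backend.iteration``, and ``_moves[0].a``; a real eryn sampler supplies these itself): the first call records a baseline, and later calls rescale ``a`` toward the target acceptance.

import numpy as np
from types import SimpleNamespace
from eryn.utils.updates import AdjustStretchProposalScale

# stand-in sampler with 8 walkers; purely illustrative, not eryn's real backend layout
backend = SimpleNamespace(accepted=np.zeros((8, 2), dtype=int), iteration=0)
sampler = SimpleNamespace(backend=backend, _moves=[SimpleNamespace(a=2.0)])

update = AdjustStretchProposalScale(target_acceptance=0.22)

update(0, None, sampler)        # first call: record baseline counts, leave a unchanged

backend.iteration += 100        # pretend 100 iterations happened...
backend.accepted[:, 0] += 60    # ...with ~0.6 acceptance on the tracked column
update(1, None, sampler)
print(sampler._moves[0].a)      # > 2.0: acceptance above target, so a is increased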