eryn 1.2.4__py3-none-any.whl → 1.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/utils/transform.py CHANGED
@@ -11,6 +11,10 @@ class TransformContainer:
11
11
  """Container for helpful transformations
12
12
 
13
13
  Args:
14
+ input_basis (list): List of integers or strings representing each
15
+ basis element from the input basis.
16
+ output_basis (list): List of integers or strings representing each
17
+ basis element for the output basis.
14
18
  parameter_transforms (dict, optional): Keys are ``int`` or ``tuple``
15
19
  of ``int`` that contain the indexes into the parameters
16
20
  that correspond to the transformation added as the Values to the
@@ -31,53 +35,70 @@ class TransformContainer:
31
35
 
32
36
  """
33
37
 
34
- def __init__(self, parameter_transforms=None, fill_dict=None):
38
+ def __init__(self, input_basis=None, output_basis=None, parameter_transforms=None, fill_dict=None, key_map={}):
35
39
 
40
+
36
41
  # store originals
37
42
  self.original_parameter_transforms = parameter_transforms
43
+ self.ndim_full = len(output_basis)
44
+ self.ndim = len(input_basis)
45
+
46
+ self.input_basis, self.output_basis = input_basis, output_basis
47
+
48
+ test_inds = []
49
+ for key in input_basis:
50
+ if key not in output_basis and key not in key_map:
51
+ raise ValueError("All keys in input_basis must be present in output basis, or you must provide a key_map")
52
+ key_in = key if key not in key_map else key_map[key]
53
+ test_inds.append(output_basis.index(key_in))
54
+
55
+ self.test_inds = test_inds = np.asarray(test_inds)
38
56
  if parameter_transforms is not None:
39
57
  # differentiate between single and multi parameter transformations
40
58
  self.base_transforms = {"single_param": {}, "mult_param": {}}
41
59
 
42
60
  # iterate through transforms and setup single and multiparameter transforms
43
61
  for key, item in parameter_transforms.items():
44
- if isinstance(key, int):
45
- self.base_transforms["single_param"][key] = item
62
+ if isinstance(key, str) or isinstance(key, int):
63
+ if key not in output_basis:
64
+ assert key in key_map
65
+ key = key_map[key]
66
+ key_in = output_basis.index(key)
67
+ self.base_transforms["single_param"][key_in] = item
46
68
  elif isinstance(key, tuple):
47
- self.base_transforms["mult_param"][key] = item
69
+ _tmp = []
70
+ for i in range(len(key)):
71
+ key_tmp = key[i]
72
+ if key_tmp not in output_basis:
73
+ assert key_tmp in key_map
74
+ key_tmp = key_map[key_tmp]
75
+ _tmp.append(output_basis.index(key_tmp))
76
+ self.base_transforms["mult_param"][tuple(_tmp)] = item
48
77
  else:
49
78
  raise ValueError(
50
- "Parameter transform keys must be int or tuple of ints. {} is neither.".format(
79
+ "Parameter transform keys must be str (or int) or tuple of strs (or ints). {} is neither.".format(
51
80
  key
52
81
  )
53
82
  )
54
83
  else:
55
84
  self.base_transforms = None
56
85
 
86
+ self.original_fill_dict = fill_dict
57
87
  if fill_dict is not None:
58
88
  if not isinstance(fill_dict, dict):
59
89
  raise ValueError("fill_dict must be a dictionary.")
60
90
 
61
- self.fill_dict = fill_dict
62
- fill_dict_keys = list(self.fill_dict.keys())
63
- for key in ["ndim_full", "fill_inds", "fill_values"]:
64
- # check to make sure it has all necessary pieces
65
- if key not in fill_dict_keys:
66
- raise ValueError(
67
- f"If providing fill_inds, dictionary must have {key} as a key."
68
- )
69
- # check all the inputs
70
- if not isinstance(fill_dict["ndim_full"], int):
71
- raise ValueError("fill_dict['ndim_full'] must be an int.")
72
- if not isinstance(fill_dict["fill_inds"], np.ndarray):
73
- raise ValueError("fill_dict['fill_inds'] must be an np.ndarray.")
74
- if not isinstance(fill_dict["fill_values"], np.ndarray):
75
- raise ValueError("fill_dict['fill_values'] must be an np.ndarray.")
91
+ self.fill_dict = {}
92
+ self.fill_dict["fill_inds"] = []
93
+ self.fill_dict["fill_values"] = []
94
+ for key in fill_dict.keys():
95
+ self.fill_dict["fill_inds"].append(output_basis.index(key))
96
+ self.fill_dict["fill_values"].append(fill_dict[key])
76
97
 
77
98
  # set up test_inds accordingly
78
- self.fill_dict["test_inds"] = np.delete(
79
- np.arange(self.fill_dict["ndim_full"]), self.fill_dict["fill_inds"]
80
- )
99
+ self.fill_dict["test_inds"] = test_inds
100
+ self.fill_dict["fill_inds"] = np.asarray(self.fill_dict["fill_inds"])
101
+ self.fill_dict["fill_values"] = np.asarray(self.fill_dict["fill_values"])
81
102
 
82
103
  else:
83
104
  self.fill_dict = None
@@ -134,6 +155,8 @@ class TransformContainer:
134
155
  def fill_values(self, params, xp=None):
135
156
  """fill fixed parameters
136
157
 
158
+ This also adjusts parameter order as needed between the two bases.
159
+
137
160
  Args:
138
161
  params (np.ndarray[..., ndim]): Array with coordinates. This array is
139
162
  filled with values according to the ``self.fill_dict`` dictionary.
@@ -152,7 +175,7 @@ class TransformContainer:
152
175
  shape = params.shape
153
176
 
154
177
  # setup new array to fill
155
- params_filled = xp.zeros(shape[:-1] + (self.fill_dict["ndim_full"],))
178
+ params_filled = xp.zeros(shape[:-1] + (self.ndim_full,))
156
179
  test_inds = xp.asarray(self.fill_dict["test_inds"])
157
180
  # special indexing to properly fill array with params
158
181
  indexing_test_inds = tuple([slice(0, temp) for temp in shape[:-1]]) + (
@@ -179,7 +202,7 @@ class TransformContainer:
179
202
  return params
180
203
 
181
204
  def both_transforms(
182
- self, params, copy=True, return_transpose=False, reverse=False, xp=None
205
+ self, params, copy=True, return_transpose=False, xp=None
183
206
  ):
184
207
  """Transform the parameters and fill fixed parameters
185
208
 
@@ -197,9 +220,6 @@ class TransformContainer:
197
220
  (default: ``True``)
198
221
  return_transpose (bool, optional): If ``True``, return the transpose of the
199
222
  array. (default: ``False``)
200
- reverse (bool, optional): If ``True`` perform the filling after the transforms. This makes
201
- indexing easier, but removes the ability of fixed parameters to affect transforms.
202
- (default: ``False``)
203
223
  xp (object, optional): ``numpy`` or ``cupy``. If ``None``, use ``numpy``.
204
224
  (default: ``None``)
205
225
 
@@ -212,15 +232,8 @@ class TransformContainer:
212
232
  xp = np
213
233
 
214
234
  # run transforms first
215
- if reverse:
216
- temp = self.transform_base_parameters(
217
- params, copy=copy, return_transpose=return_transpose, xp=xp
218
- )
219
- temp = self.fill_values(temp, xp=xp)
220
-
221
- else:
222
- temp = self.fill_values(params, xp=xp)
223
- temp = self.transform_base_parameters(
224
- temp, copy=copy, return_transpose=return_transpose, xp=xp
225
- )
235
+ temp = self.fill_values(params, xp=xp)
236
+ temp = self.transform_base_parameters(
237
+ temp, copy=copy, return_transpose=return_transpose, xp=xp
238
+ )
226
239
  return temp
eryn/utils/updates.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # -*- coding: utf-8 -*-
2
2
 
3
3
  from abc import ABC
4
+ import dataclasses
4
5
 
5
6
  import numpy as np
6
7
 
@@ -20,6 +21,111 @@ class Update(ABC, object):
20
21
  """
21
22
  raise NotImplementedError
22
23
 
24
+ class CompositeUpdate(Update):
25
+ """A composite update that chains multiple Update objects together."""
26
+
27
+ def __init__(self, updates: list):
28
+ """
29
+ Args:
30
+ updates (list): List of Update objects to chain together.
31
+ """
32
+ self._updates = updates
33
+
34
+ def __call__(self, iter, last_sample, sampler):
35
+ """Call all chained updates in sequence."""
36
+ for update in self._updates:
37
+ update(iter, last_sample, sampler)
38
+
39
+ def __add__(self, other):
40
+ """Concatenate with another Update or CompositeUpdate."""
41
+ if isinstance(other, CompositeUpdate):
42
+ return CompositeUpdate(self._updates + other._updates)
43
+ elif isinstance(other, Update):
44
+ return CompositeUpdate(self._updates + [other])
45
+ else:
46
+ raise NotImplementedError
47
+
48
+ def __radd__(self, other):
49
+ """Support other + self."""
50
+ if isinstance(other, CompositeUpdate):
51
+ return CompositeUpdate(other._updates + self._updates)
52
+ elif isinstance(other, Update):
53
+ return CompositeUpdate([other] + self._updates)
54
+ else:
55
+ raise NotImplementedError
56
+
57
+ def __repr__(self):
58
+ return f"CompositeUpdate({self._updates})"
59
+
60
+
61
+ @dataclasses.dataclass
62
+ class UpdateStep(Update):
63
+ """
64
+ Base class for chainable update steps.
65
+
66
+ Attributes:
67
+ nsteps (int): Base number of steps between updates.
68
+ increment (int): Factor by which to increase the interval.
69
+ increment_every (int): Number of iterations after which to increase the interval.
70
+ stop (int): Optional iteration to stop updates.
71
+ """
72
+ nsteps: int = 100
73
+ increment: int = 1
74
+ increment_every: int = 500
75
+ stop: int = None
76
+
77
+ def __add__(self, other):
78
+ """Concatenate with another Update or CompositeUpdate."""
79
+ if isinstance(other, CompositeUpdate):
80
+ return CompositeUpdate([self] + other._updates)
81
+ elif isinstance(other, Update):
82
+ return CompositeUpdate([self, other])
83
+ else:
84
+ return NotImplemented
85
+
86
+ def __radd__(self, other):
87
+ """Support other + self."""
88
+ if isinstance(other, CompositeUpdate):
89
+ return CompositeUpdate(other._updates + [self])
90
+ elif isinstance(other, Update):
91
+ return CompositeUpdate([other, self])
92
+ else:
93
+ return NotImplemented
94
+
95
+ def check_step(self, iteration):
96
+ """Check if the update should be applied at this iteration.
97
+
98
+ The diagnostic frequency decreases over time. The interval between
99
+ diagnostics is multiplied by `increment` every `increment_every` steps.
100
+
101
+ Example with nsteps=100, increment=2, increment_every=500:
102
+ - iterations 0-499: check every 100 steps (but not at 0)
103
+ - iterations 500-999: check every 200 steps
104
+ - iterations 1000-1499: check every 400 steps
105
+ - etc.
106
+ """
107
+ if iteration == 0:
108
+ return False
109
+
110
+ exponent = iteration // self.increment_every
111
+ interval = self.nsteps * (self.increment ** exponent)
112
+
113
+ if self.stop is not None and iteration >= self.stop:
114
+ return False
115
+
116
+ return (iteration % interval == 0)
117
+
118
+ def update(self, iteration, last_sample, sampler):
119
+ """Override this method in subclasses to define the update behavior."""
120
+ raise NotImplementedError("Subclasses must implement update() method.")
121
+
122
+ def __call__(self, iteration, last_sample, sampler):
123
+ """Call the update if the step condition is met."""
124
+ if self.check_step(iteration):
125
+ print(f'Calling {self.__class__.__name__} at iteration {iteration}')
126
+ self.update(iteration, last_sample, sampler)
127
+
128
+
23
129
 
24
130
  class AdjustStretchProposalScale(Update):
25
131
  def __init__(
eryn/utils/utility.py CHANGED
@@ -218,7 +218,7 @@ def stepping_stone_log_evidence(betas, logls, block_len=50, repeats=100):
218
218
 
219
219
  Based on
220
220
  a. https://arxiv.org/abs/1810.04488 and
221
- b. https://pubmed.ncbi.nlm.nih.gov/21187451/.
221
+ b. doi: 10.1093/sysbio/syq085
222
222
 
223
223
  Args:
224
224
  betas (np.ndarray[ntemps]): The inverse temperatures to use for the quadrature.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: eryn
3
- Version: 1.2.4
3
+ Version: 1.2.6
4
4
  Summary: Eryn: an omni-MCMC sampling package.
5
5
  Author: Michael Katz
6
6
  Author-email: Michael Katz <mikekatz04@gmail.com>
@@ -9,7 +9,6 @@ Classifier: Natural Language :: English
9
9
  Classifier: Programming Language :: C++
10
10
  Classifier: Programming Language :: Cython
11
11
  Classifier: Programming Language :: Python :: 3 :: Only
12
- Classifier: Programming Language :: Python :: 3.9
13
12
  Classifier: Programming Language :: Python :: 3.10
14
13
  Classifier: Programming Language :: Python :: 3.11
15
14
  Classifier: Programming Language :: Python :: 3.12
@@ -19,6 +18,9 @@ Requires-Dist: h5py
19
18
  Requires-Dist: jsonschema
20
19
  Requires-Dist: matplotlib
21
20
  Requires-Dist: numpy
21
+ Requires-Dist: pandas
22
+ Requires-Dist: corner
23
+ Requires-Dist: seaborn
22
24
  Requires-Dist: nvidia-ml-py
23
25
  Requires-Dist: platformdirs
24
26
  Requires-Dist: pydantic
@@ -41,7 +43,8 @@ Requires-Dist: corner ; extra == 'doc'
41
43
  Requires-Dist: matplotlib ; extra == 'testing'
42
44
  Requires-Dist: corner ; extra == 'testing'
43
45
  Requires-Dist: chainconsumer ; extra == 'testing'
44
- Requires-Python: >=3.9
46
+ Requires-Dist: scienceplots ; extra == 'testing'
47
+ Requires-Python: >=3.10
45
48
  Provides-Extra: doc
46
49
  Provides-Extra: testing
47
50
  Description-Content-Type: text/markdown
@@ -90,7 +93,7 @@ python -m unittest discover
90
93
 
91
94
  ## Contributing
92
95
 
93
- Please read [CONTRIBUTING.md](CONTRIBUTING) for details on our code of conduct, and the process for submitting pull requests to us.
96
+ Please read [CONTRIBUTING](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us. See [CONTRIBUTORS](CONTRIBUTORS.md) for those that have authored code for and contributed to Eryn.
94
97
 
95
98
  ## Versioning
96
99
 
@@ -146,17 +149,6 @@ archivePrefix = {arXiv},
146
149
 
147
150
  Depending on which proposals are used, you may be required to cite more sources. Please make sure you do this properly.
148
151
 
149
- ## Authors
150
-
151
- * **Michael Katz**
152
- * Nikos Karnesis
153
- * Natalia Korsakova
154
- * Jonathan Gair
155
-
156
- ### Contibutors
157
-
158
- * Maybe you!
159
-
160
152
  ## License
161
153
 
162
154
  This project is licensed under the GNU License - see the [LICENSE.md](LICENSE) file for details.
@@ -1,9 +1,9 @@
1
1
  eryn/CMakeLists.txt,sha256=rs-_qMYpJryM_FyvERto4RgQQ_NV4lkYvFzCNU7vvFc,1736
2
2
  eryn/__init__.py,sha256=eMxCEUQyqtaUM8zTr6kDCxeuFWpxZsfY41TefWUNHXI,821
3
3
  eryn/backends/__init__.py,sha256=yRQszA4WSofDDsSpTsA1V9eNw-pLVO_qalP5wpKjyZQ,380
4
- eryn/backends/backend.py,sha256=VitOOK3vkzVlpzYj-y-_N0Q5GA6DBdm9ZwIMKvQjBOE,47011
5
- eryn/backends/hdfbackend.py,sha256=njW1KA2Anw9zxpLTYLkpNErNRBgNMA4VKidZXidkh-A,29414
6
- eryn/ensemble.py,sha256=TqpTLun3iydLOycEi2Gtlg0enLEL2raVrPyVMIQgn-o,71998
4
+ eryn/backends/backend.py,sha256=Udw_29TelVxAZsMNfsMiuIOUSlEUIRKTqXhfpNwnIng,47327
5
+ eryn/backends/hdfbackend.py,sha256=dYJZtgTyNmW7zGpN6bVP8iY16hViDq2Gj2h970_w-pw,30062
6
+ eryn/ensemble.py,sha256=-swmSbfrfxrQzP4yRUuG-GtQKvcaHPIgZMmyt595r0Q,72342
7
7
  eryn/git_version.py.in,sha256=dZ5WklaoF4dDsCVqhgw5jwr3kJCc8zjRX_LR90byZOw,139
8
8
  eryn/model.py,sha256=5TeWTI6V-Xcuy5C2LI6AmtZZU-EkRSSuA7VojXNALk8,284
9
9
  eryn/moves/__init__.py,sha256=9pWsSZSKLt05Ihd46vPASHwotTOHOPk_zEsCm8jWiw8,1081
@@ -24,14 +24,15 @@ eryn/moves/rj.py,sha256=6krjJ5EsvgLZMTMgE9rStjjKtBIW6nw87ywRYbYtROU,15915
24
24
  eryn/moves/stretch.py,sha256=auKjeN5elf9fqLR1-oDeR0pF1vdaRnKJEQcpI0mLgVU,8242
25
25
  eryn/moves/tempering.py,sha256=e2doT8jVWSuaPpVUKIkWQjRe20T0i98w70wi-dz7buo,23977
26
26
  eryn/pbar.py,sha256=uDDn8dMVHLD6EqZyk6vGhkOQwxgFm21Us9dz-nZE4oI,1330
27
- eryn/prior.py,sha256=x4E5NS4v7Odag7a30OXQ-kJuoU3a6M6JnJuKlWGO6F4,14393
27
+ eryn/prior.py,sha256=3JhtkcD3TqZCpu54T_CO_M0_4wfabA9CLf_V_OVKoFA,16034
28
28
  eryn/state.py,sha256=x4HZNrGhxnR6Ia2JrVskJGDS1Uk3AgQHgxJ4384Hpzs,31456
29
- eryn/utils/__init__.py,sha256=HzlQs1wg3J1xdrZjIMO34QHd0ZT58SQFCKEdclj7vpM,250
30
- eryn/utils/periodic.py,sha256=Q07HKMNeUN8V_rauUjT7fKRwlYOd2AFsa9DekuRYUbk,4135
29
+ eryn/utils/__init__.py,sha256=2jn40OhyhBETx0O6PExDQ104UigIe3xx76KBVZaYZSQ,248
30
+ eryn/utils/periodic.py,sha256=w0C7YT6v7iVkvhP2OAg0UZ3aMZRUUCyLXrxAwT1pOeY,4745
31
+ eryn/utils/plot.py,sha256=RAzhocHyk7xyxPVnisnZ4ri8OObFL16KJG7XxkSITEM,56841
31
32
  eryn/utils/stopping.py,sha256=fX1np10U3B-fpI3dGqEPZfqeYt8dc0x3PQGwrvYbbFU,5095
32
- eryn/utils/transform.py,sha256=wzOYow7xHjqVOi8ZQDXBeoFj9y53cCtIeLggrQuo_sc,8895
33
- eryn/utils/updates.py,sha256=U3T9UxPLabJzJuuB9s2OuX3vMD_2P7486SkgaFEkbLw,2137
34
- eryn/utils/utility.py,sha256=mgmfoL0BFFb3hho7OAQSJLO7T_erx6f6t38V-5yKSA4,11296
35
- eryn-1.2.4.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
36
- eryn-1.2.4.dist-info/METADATA,sha256=IBGDzBc3Esx7RPZ_RKSFQewgaitRnFyVEaacPn3-9MA,6240
37
- eryn-1.2.4.dist-info/RECORD,,
33
+ eryn/utils/transform.py,sha256=PqtGG__DAtIuRJm2SR9_ZuRyPvlUmh-eplgbpmZZghQ,9467
34
+ eryn/utils/updates.py,sha256=W9oWUkUXzofTX21MDx8K89GMR_x-SYJERGphvGUcGF8,6003
35
+ eryn/utils/utility.py,sha256=9i3xe0SllA31jWvO_D39dGN_mi8r_va8Pe-dOWhasaM,11280
36
+ eryn-1.2.6.dist-info/WHEEL,sha256=eh7sammvW2TypMMMGKgsM83HyA_3qQ5Lgg3ynoecH3M,79
37
+ eryn-1.2.6.dist-info/METADATA,sha256=qN0g_VqhB1a9_uYFKzzSorCQBfLnecSB9MyfCU-N3gY,6290
38
+ eryn-1.2.6.dist-info/RECORD,,
File without changes