eryn 1.2.4__tar.gz → 1.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {eryn-1.2.4 → eryn-1.2.6}/PKG-INFO +7 -15
- {eryn-1.2.4 → eryn-1.2.6}/README.md +1 -12
- {eryn-1.2.4 → eryn-1.2.6}/pyproject.toml +6 -3
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/backends/backend.py +6 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/backends/hdfbackend.py +15 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/ensemble.py +16 -13
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/prior.py +47 -2
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/utils/__init__.py +1 -1
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/utils/periodic.py +19 -3
- eryn-1.2.6/src/eryn/utils/plot.py +1393 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/utils/transform.py +52 -39
- eryn-1.2.6/src/eryn/utils/updates.py +175 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/utils/utility.py +1 -1
- eryn-1.2.4/src/eryn/utils/updates.py +0 -69
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/CMakeLists.txt +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/__init__.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/backends/__init__.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/git_version.py.in +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/model.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/__init__.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/combine.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/delayedrejection.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/distgen.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/distgenrj.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/gaussian.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/group.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/groupstretch.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/mh.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/move.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/mtdistgen.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/mtdistgenrj.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/multipletry.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/red_blue.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/rj.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/stretch.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/moves/tempering.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/pbar.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/state.py +0 -0
- {eryn-1.2.4 → eryn-1.2.6}/src/eryn/utils/stopping.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: eryn
|
|
3
|
-
Version: 1.2.4
|
|
3
|
+
Version: 1.2.6
|
|
4
4
|
Summary: Eryn: an omni-MCMC sampling package.
|
|
5
5
|
Author: Michael Katz
|
|
6
6
|
Author-email: Michael Katz <mikekatz04@gmail.com>
|
|
@@ -9,7 +9,6 @@ Classifier: Natural Language :: English
|
|
|
9
9
|
Classifier: Programming Language :: C++
|
|
10
10
|
Classifier: Programming Language :: Cython
|
|
11
11
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
13
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
@@ -19,6 +18,9 @@ Requires-Dist: h5py
|
|
|
19
18
|
Requires-Dist: jsonschema
|
|
20
19
|
Requires-Dist: matplotlib
|
|
21
20
|
Requires-Dist: numpy
|
|
21
|
+
Requires-Dist: pandas
|
|
22
|
+
Requires-Dist: corner
|
|
23
|
+
Requires-Dist: seaborn
|
|
22
24
|
Requires-Dist: nvidia-ml-py
|
|
23
25
|
Requires-Dist: platformdirs
|
|
24
26
|
Requires-Dist: pydantic
|
|
@@ -41,7 +43,8 @@ Requires-Dist: corner ; extra == 'doc'
|
|
|
41
43
|
Requires-Dist: matplotlib ; extra == 'testing'
|
|
42
44
|
Requires-Dist: corner ; extra == 'testing'
|
|
43
45
|
Requires-Dist: chainconsumer ; extra == 'testing'
|
|
44
|
-
Requires-Python: >=3.9
|
|
46
|
+
Requires-Dist: scienceplots ; extra == 'testing'
|
|
47
|
+
Requires-Python: >=3.10
|
|
45
48
|
Provides-Extra: doc
|
|
46
49
|
Provides-Extra: testing
|
|
47
50
|
Description-Content-Type: text/markdown
|
|
@@ -90,7 +93,7 @@ python -m unittest discover
|
|
|
90
93
|
|
|
91
94
|
## Contributing
|
|
92
95
|
|
|
93
|
-
Please read [CONTRIBUTING
|
|
96
|
+
Please read [CONTRIBUTING](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us. See [CONTRIBUTORS](CONTRIBUTORS.md) for those that have authored code for and contributed to Eryn.
|
|
94
97
|
|
|
95
98
|
## Versioning
|
|
96
99
|
|
|
@@ -146,17 +149,6 @@ archivePrefix = {arXiv},
|
|
|
146
149
|
|
|
147
150
|
Depending on which proposals are used, you may be required to cite more sources. Please make sure you do this properly.
|
|
148
151
|
|
|
149
|
-
## Authors
|
|
150
|
-
|
|
151
|
-
* **Michael Katz**
|
|
152
|
-
* Nikos Karnesis
|
|
153
|
-
* Natalia Korsakova
|
|
154
|
-
* Jonathan Gair
|
|
155
|
-
|
|
156
|
-
### Contibutors
|
|
157
|
-
|
|
158
|
-
* Maybe you!
|
|
159
|
-
|
|
160
152
|
## License
|
|
161
153
|
|
|
162
154
|
This project is licensed under the GNU License - see the [LICENSE.md](LICENSE) file for details.
|
|
@@ -42,7 +42,7 @@ python -m unittest discover
|
|
|
42
42
|
|
|
43
43
|
## Contributing
|
|
44
44
|
|
|
45
|
-
Please read [CONTRIBUTING
|
|
45
|
+
Please read [CONTRIBUTING](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us. See [CONTRIBUTORS](CONTRIBUTORS.md) for those that have authored code for and contributed to Eryn.
|
|
46
46
|
|
|
47
47
|
## Versioning
|
|
48
48
|
|
|
@@ -98,17 +98,6 @@ archivePrefix = {arXiv},
|
|
|
98
98
|
|
|
99
99
|
Depending on which proposals are used, you may be required to cite more sources. Please make sure you do this properly.
|
|
100
100
|
|
|
101
|
-
## Authors
|
|
102
|
-
|
|
103
|
-
* **Michael Katz**
|
|
104
|
-
* Nikos Karnesis
|
|
105
|
-
* Natalia Korsakova
|
|
106
|
-
* Jonathan Gair
|
|
107
|
-
|
|
108
|
-
### Contibutors
|
|
109
|
-
|
|
110
|
-
* Maybe you!
|
|
111
|
-
|
|
112
101
|
## License
|
|
113
102
|
|
|
114
103
|
This project is licensed under the GNU License - see the [LICENSE.md](LICENSE) file for details.
|
|
@@ -8,7 +8,7 @@ requires = [
|
|
|
8
8
|
[project]
|
|
9
9
|
name = "eryn" #@NAMESUFFIX@
|
|
10
10
|
|
|
11
|
-
version = "1.2.4"
|
|
11
|
+
version = "1.2.6"
|
|
12
12
|
|
|
13
13
|
description = "Eryn: an omni-MCMC sampling package."
|
|
14
14
|
|
|
@@ -17,7 +17,7 @@ readme = "README.md"
|
|
|
17
17
|
authors = [
|
|
18
18
|
{ name = "Michael Katz", email = "mikekatz04@gmail.com" },
|
|
19
19
|
]
|
|
20
|
-
requires-python = ">=3.9"
|
|
20
|
+
requires-python = ">=3.10"
|
|
21
21
|
|
|
22
22
|
classifiers = [
|
|
23
23
|
"License :: OSI Approved :: Apache Software License",
|
|
@@ -25,7 +25,6 @@ classifiers = [
|
|
|
25
25
|
"Programming Language :: C++",
|
|
26
26
|
"Programming Language :: Cython",
|
|
27
27
|
"Programming Language :: Python :: 3 :: Only",
|
|
28
|
-
"Programming Language :: Python :: 3.9",
|
|
29
28
|
"Programming Language :: Python :: 3.10",
|
|
30
29
|
"Programming Language :: Python :: 3.11",
|
|
31
30
|
"Programming Language :: Python :: 3.12",
|
|
@@ -38,6 +37,9 @@ dependencies = [
|
|
|
38
37
|
"jsonschema", # To validate content of file registry
|
|
39
38
|
"matplotlib",
|
|
40
39
|
"numpy",
|
|
40
|
+
"pandas",
|
|
41
|
+
"corner",
|
|
42
|
+
"seaborn", # for runtime plots
|
|
41
43
|
"nvidia-ml-py", # To detect CUDA version if any
|
|
42
44
|
"platformdirs", # To locate config and data dir on all platforms
|
|
43
45
|
"pydantic", # To handle citations and references with advanced dataclasses
|
|
@@ -68,6 +70,7 @@ optional-dependencies.testing = [
|
|
|
68
70
|
"matplotlib",
|
|
69
71
|
"corner",
|
|
70
72
|
"chainconsumer",
|
|
73
|
+
"scienceplots", # for runtime plots
|
|
71
74
|
]
|
|
72
75
|
|
|
73
76
|
[tool.pyproject-fmt]
|
|
@@ -83,6 +83,7 @@ class Backend(object):
|
|
|
83
83
|
nbranches=1,
|
|
84
84
|
rj=False,
|
|
85
85
|
moves=None,
|
|
86
|
+
key_order=None,
|
|
86
87
|
**info,
|
|
87
88
|
):
|
|
88
89
|
"""Clear the state of the chain and empty the backend
|
|
@@ -106,6 +107,9 @@ class Backend(object):
|
|
|
106
107
|
(default: ``False``)
|
|
107
108
|
moves (list, optional): List of all of the move classes input into the sampler.
|
|
108
109
|
(default: ``None``)
|
|
110
|
+
key_order (dict, optional): Keys are ``branch_names`` and values are lists of key ordering for each
|
|
111
|
+
branch. For example, ``{"model_0": ["x1", "x2", "x3"]}``.
|
|
112
|
+
(default: ``None``)
|
|
109
113
|
**info (dict, optional): Any other key-value pairs to be added
|
|
110
114
|
as attributes to the backend.
|
|
111
115
|
|
|
@@ -118,6 +122,7 @@ class Backend(object):
|
|
|
118
122
|
branch_names=branch_names,
|
|
119
123
|
rj=rj,
|
|
120
124
|
moves=moves,
|
|
125
|
+
key_order=key_order,
|
|
121
126
|
info=info,
|
|
122
127
|
)
|
|
123
128
|
|
|
@@ -184,6 +189,7 @@ class Backend(object):
|
|
|
184
189
|
self.branch_names = branch_names
|
|
185
190
|
self.ndims = ndims
|
|
186
191
|
self.nleaves_max = nleaves_max
|
|
192
|
+
self.key_order = key_order
|
|
187
193
|
|
|
188
194
|
self.iteration = 0
|
|
189
195
|
|
|
@@ -176,6 +176,7 @@ class HDFBackend(Backend):
|
|
|
176
176
|
nbranches=1,
|
|
177
177
|
rj=False,
|
|
178
178
|
moves=None,
|
|
179
|
+
key_order=None,
|
|
179
180
|
**info,
|
|
180
181
|
):
|
|
181
182
|
"""Clear the state of the chain and empty the backend
|
|
@@ -199,6 +200,9 @@ class HDFBackend(Backend):
|
|
|
199
200
|
(default: ``False``)
|
|
200
201
|
moves (list, optional): List of all of the move classes input into the sampler.
|
|
201
202
|
(default: ``None``)
|
|
203
|
+
key_order (dict, optional): Keys are ``branch_names`` and values are lists of key ordering for each
|
|
204
|
+
branch. For example, ``{"model_0": ["x1", "x2", "x3"]}``.
|
|
205
|
+
(default: ``None``)
|
|
202
206
|
**info (dict, optional): Any other key-value pairs to be added
|
|
203
207
|
as attributes to the backend. These are also added to the HDF5 file.
|
|
204
208
|
|
|
@@ -343,6 +347,7 @@ class HDFBackend(Backend):
|
|
|
343
347
|
|
|
344
348
|
chain = g.create_group("chain")
|
|
345
349
|
inds = g.create_group("inds")
|
|
350
|
+
k_o_g = g.create_group("key_order")
|
|
346
351
|
|
|
347
352
|
for name in branch_names:
|
|
348
353
|
nleaves = self.nleaves_max[name]
|
|
@@ -365,6 +370,9 @@ class HDFBackend(Backend):
|
|
|
365
370
|
compression_opts=self.compression_opts,
|
|
366
371
|
)
|
|
367
372
|
|
|
373
|
+
if key_order is not None:
|
|
374
|
+
k_o_g.attrs[name] = key_order[name]
|
|
375
|
+
|
|
368
376
|
# store move specific information
|
|
369
377
|
if moves is not None:
|
|
370
378
|
move_group = g.create_group("moves")
|
|
@@ -388,6 +396,12 @@ class HDFBackend(Backend):
|
|
|
388
396
|
|
|
389
397
|
self.blobs = None
|
|
390
398
|
|
|
399
|
+
@property
|
|
400
|
+
def key_order(self):
|
|
401
|
+
"""Key order of parameters for each model."""
|
|
402
|
+
with self.open() as f:
|
|
403
|
+
return {key: value for key, value in f[self.name]["key_order"].attrs.items()}
|
|
404
|
+
|
|
391
405
|
@property
|
|
392
406
|
def nwalkers(self):
|
|
393
407
|
"""Get nwalkers from h5 file."""
|
|
@@ -456,6 +470,7 @@ class HDFBackend(Backend):
|
|
|
456
470
|
branch_names=self.branch_names,
|
|
457
471
|
rj=self.rj,
|
|
458
472
|
moves=self.moves,
|
|
473
|
+
key_order=self.key_order,
|
|
459
474
|
)
|
|
460
475
|
|
|
461
476
|
@property
|
|
@@ -13,7 +13,7 @@ from .pbar import get_progress_bar
|
|
|
13
13
|
from .state import State
|
|
14
14
|
from .prior import ProbDistContainer
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
from .utils import PlotContainer
|
|
17
17
|
from .utils import PeriodicContainer
|
|
18
18
|
from .utils.utility import groups_from_inds
|
|
19
19
|
|
|
@@ -198,6 +198,7 @@ class EnsembleSampler(object):
|
|
|
198
198
|
if it is changed. In this case, the user should declare a new backend and use the last
|
|
199
199
|
state from the previous backend. **Warning**: If the order of moves of the same move class
|
|
200
200
|
is changed, the check may not catch it, so the tracking may mix move acceptance fractions together.
|
|
201
|
+
(default: ``True'')
|
|
201
202
|
info (dict, optional): Key and value pairs reprenting any information
|
|
202
203
|
the user wants to add to the backend if the user is not inputing
|
|
203
204
|
their own backend.
|
|
@@ -232,7 +233,7 @@ class EnsembleSampler(object):
|
|
|
232
233
|
blobs_dtype=None, # TODO check this
|
|
233
234
|
plot_iterations=-1, # TODO: do plot stuff?
|
|
234
235
|
plot_generator=None,
|
|
235
|
-
|
|
236
|
+
plot_folder=None,
|
|
236
237
|
periodic=None,
|
|
237
238
|
update_fn=None,
|
|
238
239
|
update_iterations=-1,
|
|
@@ -597,6 +598,7 @@ class EnsembleSampler(object):
|
|
|
597
598
|
nleaves_max=nleaves_max,
|
|
598
599
|
rj=self.has_reversible_jump,
|
|
599
600
|
moves=move_keys,
|
|
601
|
+
key_order=self.key_order,
|
|
600
602
|
**info,
|
|
601
603
|
)
|
|
602
604
|
state = np.random.get_state()
|
|
@@ -615,6 +617,9 @@ class EnsembleSampler(object):
|
|
|
615
617
|
"Configuration of moves has changed. Cannot use the same backend. Declare a new backend and start from the previous state. If you would prefer not to track move acceptance fraction, set track_moves to False in the EnsembleSampler."
|
|
616
618
|
)
|
|
617
619
|
|
|
620
|
+
if self.key_order != self.backend.key_order:
|
|
621
|
+
raise ValueError("Input key order from priors does not match backend.")
|
|
622
|
+
|
|
618
623
|
# Check the backend shape
|
|
619
624
|
for i, (name, shape) in enumerate(self.backend.shape.items()):
|
|
620
625
|
test_shape = (
|
|
@@ -657,18 +662,15 @@ class EnsembleSampler(object):
|
|
|
657
662
|
self.plot_iterations = plot_iterations
|
|
658
663
|
|
|
659
664
|
if plot_generator is None and self.plot_iterations > 0:
|
|
660
|
-
raise NotImplementedError
|
|
661
665
|
# set to default if not provided
|
|
662
|
-
if
|
|
663
|
-
|
|
664
|
-
else:
|
|
665
|
-
name = "output"
|
|
666
|
+
if plot_folder is None:
|
|
667
|
+
plot_folder = "./runtime_plots"
|
|
666
668
|
self.plot_generator = PlotContainer(
|
|
667
|
-
|
|
669
|
+
backend=self.backend, plots=['base', 'rj'], parent_folder=plot_folder, discard=0.2
|
|
668
670
|
)
|
|
669
671
|
elif self.plot_iterations > 0:
|
|
670
|
-
raise NotImplementedError
|
|
671
672
|
self.plot_generator = plot_generator
|
|
673
|
+
self.plot_generator.backend = self.backend # make sure backend is correctly set
|
|
672
674
|
|
|
673
675
|
# prepare stopping functions
|
|
674
676
|
self.stopping_fn = stopping_fn
|
|
@@ -689,7 +691,7 @@ class EnsembleSampler(object):
|
|
|
689
691
|
|
|
690
692
|
"""
|
|
691
693
|
return self._random.get_state()
|
|
692
|
-
|
|
694
|
+
|
|
693
695
|
@random_state.setter # NOQA
|
|
694
696
|
def random_state(self, state):
|
|
695
697
|
"""
|
|
@@ -751,13 +753,14 @@ class EnsembleSampler(object):
|
|
|
751
753
|
else:
|
|
752
754
|
raise ValueError("Priors must be a dictionary.")
|
|
753
755
|
|
|
756
|
+
self.key_order = {key: value.key_order for key, value in self._priors.items()}
|
|
754
757
|
return
|
|
755
758
|
|
|
756
759
|
@property
|
|
757
760
|
def iteration(self):
|
|
758
761
|
return self.backend.iteration
|
|
759
762
|
|
|
760
|
-
def reset(self, **info):
|
|
763
|
+
def reset(self, **kwargs):
|
|
761
764
|
"""
|
|
762
765
|
Reset the backend.
|
|
763
766
|
|
|
@@ -765,7 +768,7 @@ class EnsembleSampler(object):
|
|
|
765
768
|
**info (dict, optional): information to pass to backend reset method.
|
|
766
769
|
|
|
767
770
|
"""
|
|
768
|
-
self.backend.reset(self.nwalkers, self.ndims, **info)
|
|
771
|
+
self.backend.reset(self.nwalkers, self.ndims, **kwargs)
|
|
769
772
|
|
|
770
773
|
def __getstate__(self):
|
|
771
774
|
# In order to be generally picklable, we need to discard the pool
|
|
@@ -1101,7 +1104,7 @@ class EnsembleSampler(object):
|
|
|
1101
1104
|
# diagnostic plots
|
|
1102
1105
|
# TODO: adjust diagnostic plots
|
|
1103
1106
|
if self.plot_iterations > 0 and (i + 1) % (self.plot_iterations) == 0:
|
|
1104
|
-
self.plot_generator.
|
|
1107
|
+
self.plot_generator.produce_plots(sampler=self) # TODO: remove defaults
|
|
1105
1108
|
|
|
1106
1109
|
# check for stopping before updating
|
|
1107
1110
|
if (
|
|
@@ -248,24 +248,69 @@ class ProbDistContainer:
|
|
|
248
248
|
# to separate out in list form
|
|
249
249
|
self.priors = []
|
|
250
250
|
|
|
251
|
+
self.has_strings = False
|
|
252
|
+
self.has_ints = False
|
|
253
|
+
|
|
254
|
+
# this is for the strings (for the ints it just counts them)
|
|
255
|
+
current_ind = 0
|
|
256
|
+
key_order = []
|
|
257
|
+
|
|
251
258
|
# setup lists
|
|
252
259
|
temp_inds = []
|
|
253
260
|
for inds, dist in priors_in.items():
|
|
254
261
|
# multiple index
|
|
255
262
|
if isinstance(inds, tuple):
|
|
256
|
-
|
|
263
|
+
inds_tmp = []
|
|
264
|
+
for i in range(len(inds)):
|
|
265
|
+
if isinstance(inds[i], str):
|
|
266
|
+
assert not self.has_ints
|
|
267
|
+
self.has_strings = True
|
|
268
|
+
inds_tmp.append(current_ind)
|
|
269
|
+
key_order.append(inds[i])
|
|
270
|
+
|
|
271
|
+
elif isinstance(inds[i], int):
|
|
272
|
+
assert not self.has_strings
|
|
273
|
+
self.has_ints = True
|
|
274
|
+
inds_tmp.append(i)
|
|
275
|
+
|
|
276
|
+
else:
|
|
277
|
+
raise ValueError("Index in tuple must be int or str and all be the same type.")
|
|
278
|
+
|
|
279
|
+
current_ind += 1
|
|
280
|
+
|
|
281
|
+
inds_in = np.asarray(inds_tmp)
|
|
257
282
|
self.priors.append([inds_in, dist])
|
|
258
283
|
|
|
259
284
|
# single index
|
|
260
285
|
elif isinstance(inds, int):
|
|
286
|
+
self.has_ints = True
|
|
287
|
+
assert not self.has_strings
|
|
261
288
|
inds_in = np.array([inds])
|
|
262
289
|
self.priors.append([inds_in, dist])
|
|
290
|
+
current_ind += 1
|
|
291
|
+
|
|
292
|
+
elif isinstance(inds, str):
|
|
293
|
+
assert not self.has_ints
|
|
294
|
+
self.has_strings = True
|
|
295
|
+
key_order.append(inds)
|
|
296
|
+
inds_in = np.array([current_ind])
|
|
297
|
+
current_ind += 1
|
|
298
|
+
self.priors.append([inds_in, dist])
|
|
263
299
|
|
|
264
300
|
else:
|
|
265
301
|
raise ValueError(
|
|
266
|
-
"Keys for prior dictionary must be an integer or tuple."
|
|
302
|
+
"Keys for prior dictionary must be an integer, string, or tuple."
|
|
267
303
|
)
|
|
268
304
|
|
|
305
|
+
if self.has_strings:
|
|
306
|
+
assert not self.has_ints
|
|
307
|
+
# key order is already set
|
|
308
|
+
self.key_order = key_order
|
|
309
|
+
|
|
310
|
+
if self.has_ints:
|
|
311
|
+
self.key_order = [i for i in range(current_ind)] # here current_ind is the total count
|
|
312
|
+
assert not self.has_strings
|
|
313
|
+
|
|
269
314
|
temp_inds.append(np.asarray([inds_in]))
|
|
270
315
|
|
|
271
316
|
uni_inds = np.unique(np.concatenate(temp_inds, axis=1).flatten())
|
|
@@ -18,15 +18,31 @@ class PeriodicContainer:
|
|
|
18
18
|
|
|
19
19
|
"""
|
|
20
20
|
|
|
21
|
-
def __init__(self, periodic):
|
|
21
|
+
def __init__(self, periodic, key_order=None):
|
|
22
22
|
|
|
23
23
|
# store all the information
|
|
24
24
|
self.periodic = periodic
|
|
25
|
+
inds_periodic = {}
|
|
26
|
+
periods = {}
|
|
27
|
+
for key in periodic:
|
|
28
|
+
if periodic[key] is None:
|
|
29
|
+
continue
|
|
30
|
+
inds_periodic[key] = []
|
|
31
|
+
periods[key] = []
|
|
32
|
+
for var, period in periodic[key].items():
|
|
33
|
+
if isinstance(var, str):
|
|
34
|
+
if key_order is None:
|
|
35
|
+
raise ValueError(f"If providing str values for the variable names, must provide key_order argument.")
|
|
36
|
+
|
|
37
|
+
index = key_order[key].index(var)
|
|
38
|
+
inds_periodic[key].append(index)
|
|
39
|
+
periods[key].append(period)
|
|
40
|
+
|
|
25
41
|
self.inds_periodic = {
|
|
26
|
-
key: np.asarray(
|
|
42
|
+
key: np.asarray(tmp) for key, tmp in inds_periodic.items()
|
|
27
43
|
}
|
|
28
44
|
self.periods = {
|
|
29
|
-
key: np.asarray(
|
|
45
|
+
key: np.asarray(tmp) for key, tmp in periods.items()
|
|
30
46
|
}
|
|
31
47
|
|
|
32
48
|
def distance(self, p1, p2, xp=None):
|