westpa 2022.12__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of westpa might be problematic. Click here for more details.
- westpa/__init__.py +14 -0
- westpa/_version.py +21 -0
- westpa/analysis/__init__.py +5 -0
- westpa/analysis/core.py +746 -0
- westpa/analysis/statistics.py +27 -0
- westpa/analysis/trajectories.py +360 -0
- westpa/cli/__init__.py +0 -0
- westpa/cli/core/__init__.py +0 -0
- westpa/cli/core/w_fork.py +152 -0
- westpa/cli/core/w_init.py +230 -0
- westpa/cli/core/w_run.py +77 -0
- westpa/cli/core/w_states.py +212 -0
- westpa/cli/core/w_succ.py +99 -0
- westpa/cli/core/w_truncate.py +68 -0
- westpa/cli/tools/__init__.py +0 -0
- westpa/cli/tools/ploterr.py +506 -0
- westpa/cli/tools/plothist.py +706 -0
- westpa/cli/tools/w_assign.py +596 -0
- westpa/cli/tools/w_bins.py +166 -0
- westpa/cli/tools/w_crawl.py +119 -0
- westpa/cli/tools/w_direct.py +547 -0
- westpa/cli/tools/w_dumpsegs.py +94 -0
- westpa/cli/tools/w_eddist.py +506 -0
- westpa/cli/tools/w_fluxanl.py +376 -0
- westpa/cli/tools/w_ipa.py +833 -0
- westpa/cli/tools/w_kinavg.py +127 -0
- westpa/cli/tools/w_kinetics.py +96 -0
- westpa/cli/tools/w_multi_west.py +414 -0
- westpa/cli/tools/w_ntop.py +213 -0
- westpa/cli/tools/w_pdist.py +515 -0
- westpa/cli/tools/w_postanalysis_matrix.py +82 -0
- westpa/cli/tools/w_postanalysis_reweight.py +53 -0
- westpa/cli/tools/w_red.py +491 -0
- westpa/cli/tools/w_reweight.py +780 -0
- westpa/cli/tools/w_select.py +226 -0
- westpa/cli/tools/w_stateprobs.py +111 -0
- westpa/cli/tools/w_trace.py +599 -0
- westpa/core/__init__.py +0 -0
- westpa/core/_rc.py +673 -0
- westpa/core/binning/__init__.py +55 -0
- westpa/core/binning/_assign.cpython-313-x86_64-linux-gnu.so +0 -0
- westpa/core/binning/assign.py +455 -0
- westpa/core/binning/binless.py +96 -0
- westpa/core/binning/binless_driver.py +54 -0
- westpa/core/binning/binless_manager.py +190 -0
- westpa/core/binning/bins.py +47 -0
- westpa/core/binning/mab.py +506 -0
- westpa/core/binning/mab_driver.py +54 -0
- westpa/core/binning/mab_manager.py +198 -0
- westpa/core/data_manager.py +1694 -0
- westpa/core/extloader.py +74 -0
- westpa/core/h5io.py +995 -0
- westpa/core/kinetics/__init__.py +24 -0
- westpa/core/kinetics/_kinetics.cpython-313-x86_64-linux-gnu.so +0 -0
- westpa/core/kinetics/events.py +147 -0
- westpa/core/kinetics/matrates.py +156 -0
- westpa/core/kinetics/rate_averaging.py +266 -0
- westpa/core/progress.py +218 -0
- westpa/core/propagators/__init__.py +54 -0
- westpa/core/propagators/executable.py +719 -0
- westpa/core/reweight/__init__.py +14 -0
- westpa/core/reweight/_reweight.cpython-313-x86_64-linux-gnu.so +0 -0
- westpa/core/reweight/matrix.py +126 -0
- westpa/core/segment.py +119 -0
- westpa/core/sim_manager.py +835 -0
- westpa/core/states.py +359 -0
- westpa/core/systems.py +93 -0
- westpa/core/textio.py +74 -0
- westpa/core/trajectory.py +330 -0
- westpa/core/we_driver.py +910 -0
- westpa/core/wm_ops.py +43 -0
- westpa/core/yamlcfg.py +391 -0
- westpa/fasthist/__init__.py +34 -0
- westpa/fasthist/_fasthist.cpython-313-x86_64-linux-gnu.so +0 -0
- westpa/mclib/__init__.py +271 -0
- westpa/mclib/__main__.py +28 -0
- westpa/mclib/_mclib.cpython-313-x86_64-linux-gnu.so +0 -0
- westpa/oldtools/__init__.py +4 -0
- westpa/oldtools/aframe/__init__.py +35 -0
- westpa/oldtools/aframe/atool.py +75 -0
- westpa/oldtools/aframe/base_mixin.py +26 -0
- westpa/oldtools/aframe/binning.py +178 -0
- westpa/oldtools/aframe/data_reader.py +560 -0
- westpa/oldtools/aframe/iter_range.py +200 -0
- westpa/oldtools/aframe/kinetics.py +117 -0
- westpa/oldtools/aframe/mcbs.py +153 -0
- westpa/oldtools/aframe/output.py +39 -0
- westpa/oldtools/aframe/plotting.py +90 -0
- westpa/oldtools/aframe/trajwalker.py +126 -0
- westpa/oldtools/aframe/transitions.py +469 -0
- westpa/oldtools/cmds/__init__.py +0 -0
- westpa/oldtools/cmds/w_ttimes.py +361 -0
- westpa/oldtools/files.py +34 -0
- westpa/oldtools/miscfn.py +23 -0
- westpa/oldtools/stats/__init__.py +4 -0
- westpa/oldtools/stats/accumulator.py +35 -0
- westpa/oldtools/stats/edfs.py +129 -0
- westpa/oldtools/stats/mcbs.py +96 -0
- westpa/tools/__init__.py +33 -0
- westpa/tools/binning.py +472 -0
- westpa/tools/core.py +340 -0
- westpa/tools/data_reader.py +159 -0
- westpa/tools/dtypes.py +31 -0
- westpa/tools/iter_range.py +198 -0
- westpa/tools/kinetics_tool.py +340 -0
- westpa/tools/plot.py +283 -0
- westpa/tools/progress.py +17 -0
- westpa/tools/selected_segs.py +154 -0
- westpa/tools/wipi.py +751 -0
- westpa/trajtree/__init__.py +4 -0
- westpa/trajtree/_trajtree.cpython-313-x86_64-linux-gnu.so +0 -0
- westpa/trajtree/trajtree.py +117 -0
- westpa/westext/__init__.py +0 -0
- westpa/westext/adaptvoronoi/__init__.py +3 -0
- westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
- westpa/westext/hamsm_restarting/__init__.py +3 -0
- westpa/westext/hamsm_restarting/example_overrides.py +35 -0
- westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
- westpa/westext/stringmethod/__init__.py +11 -0
- westpa/westext/stringmethod/fourier_fitting.py +69 -0
- westpa/westext/stringmethod/string_driver.py +253 -0
- westpa/westext/stringmethod/string_method.py +306 -0
- westpa/westext/weed/BinCluster.py +180 -0
- westpa/westext/weed/ProbAdjustEquil.py +100 -0
- westpa/westext/weed/UncertMath.py +247 -0
- westpa/westext/weed/__init__.py +10 -0
- westpa/westext/weed/weed_driver.py +192 -0
- westpa/westext/wess/ProbAdjust.py +101 -0
- westpa/westext/wess/__init__.py +6 -0
- westpa/westext/wess/wess_driver.py +217 -0
- westpa/work_managers/__init__.py +57 -0
- westpa/work_managers/core.py +396 -0
- westpa/work_managers/environment.py +134 -0
- westpa/work_managers/mpi.py +318 -0
- westpa/work_managers/processes.py +187 -0
- westpa/work_managers/serial.py +28 -0
- westpa/work_managers/threads.py +79 -0
- westpa/work_managers/zeromq/__init__.py +20 -0
- westpa/work_managers/zeromq/core.py +641 -0
- westpa/work_managers/zeromq/node.py +131 -0
- westpa/work_managers/zeromq/work_manager.py +526 -0
- westpa/work_managers/zeromq/worker.py +320 -0
- westpa-2022.12.dist-info/AUTHORS +22 -0
- westpa-2022.12.dist-info/LICENSE +21 -0
- westpa-2022.12.dist-info/METADATA +193 -0
- westpa-2022.12.dist-info/RECORD +149 -0
- westpa-2022.12.dist-info/WHEEL +6 -0
- westpa-2022.12.dist-info/entry_points.txt +29 -0
- westpa-2022.12.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1165 @@
|
|
|
1
|
+
import h5py
|
|
2
|
+
import logging
|
|
3
|
+
import operator
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
import westpa
|
|
7
|
+
from westpa.cli.core import w_init
|
|
8
|
+
from westpa.cli.core import w_run
|
|
9
|
+
from westpa.core.extloader import get_object
|
|
10
|
+
from westpa.core.segment import Segment
|
|
11
|
+
from westpa import analysis
|
|
12
|
+
from westpa.core._rc import bins_from_yaml_dict
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import shutil
|
|
18
|
+
import pickle
|
|
19
|
+
import importlib.util
|
|
20
|
+
|
|
21
|
+
import tqdm
|
|
22
|
+
|
|
23
|
+
import mdtraj as md
|
|
24
|
+
from rich.logging import RichHandler
|
|
25
|
+
|
|
26
|
+
from matplotlib import pyplot as plt
|
|
27
|
+
|
|
28
|
+
# Ensure this is installed via pip. msm_we's setup.py is all set up for that.
|
|
29
|
+
# Navigate to the folder where msm_we is, and run python3 -m pip install .
|
|
30
|
+
# If you're doing development on msm_we, add the -e flag to pip, i.e. "python3 -m pip install -e ."
|
|
31
|
+
# -e will install it in editable mode, so changes to msm_we will take effect next time it's imported.
|
|
32
|
+
# Otherwise, if you modify the msm_we code, you'll need to re-install it through pip.
|
|
33
|
+
from msm_we import msm_we
|
|
34
|
+
|
|
35
|
+
import ray
|
|
36
|
+
import tempfile
|
|
37
|
+
|
|
38
|
+
# Machine epsilon for float64; available as a guard against degenerate
# (zero / near-zero) weights and divisions elsewhere in the plugin.
EPS = np.finfo(np.float64).eps

# Plugin logger: rich console formatting, and propagation disabled so
# messages are not duplicated by WESTPA's root logging configuration.
log = logging.getLogger(__name__)
log.setLevel("INFO")
log.propagate = False
log.addHandler(RichHandler())

# Separate handle on msm_we's own logger so the plugin can raise its
# verbosity together with its own (see the `debug` plugin option).
msm_we_logger = logging.getLogger("msm_we.msm_we")
msm_we_logger.setLevel("INFO")

# Map structure types to extensions.
# This tells the plugin what extension to put on generated start-state files.
STRUCT_EXTENSIONS = {
    md.formats.PDBTrajectoryFile: "pdb",
    md.formats.AmberRestartFile: "rst7",
}

# Marker filename used to flag that an extension run is in progress.
EXTENSION_LOCKFILE = 'doing_extension'
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def check_target_reached(h5_filename):
    """
    Check if the target state was reached, given the data in a WEST H5 file.

    Scans iterations from most recent to oldest and reports whether any
    segment was recycled (i.e. reached the target state).

    Parameters
    ----------
    h5_filename: string
        Path to a WESTPA HDF5 data file

    Returns
    -------
    bool
        True if any completed iteration contains a segment with endpoint type
        ``Segment.SEG_ENDPOINT_RECYCLED``, False otherwise.
    """
    with h5py.File(h5_filename, 'r') as h5_file:
        # Get the key to the final iteration. Need to do -2 instead of -1 because there's an empty-ish final iteration
        # written.
        # Fix: the previous slice [-2:0:-1] stopped at index 1 and silently
        # skipped the very first iteration; [-2::-1] includes it.
        for iteration_key in list(h5_file['iterations'].keys())[-2::-1]:
            endpoint_types = h5_file[f'iterations/{iteration_key}/seg_index']['endpoint_type']
            if Segment.SEG_ENDPOINT_RECYCLED in endpoint_types:
                log.debug(f"recycled segment found in file {h5_filename} at iteration {iteration_key}")
                return True
    return False
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def fix_deprecated_initialization(initialization_state):
    """
    Translate deprecated hyphenated initialization keys to underscored ones.

    The initialization JSON schema uses underscores so entries can be expanded
    directly into keyword arguments for w_init. This shim rewrites any
    old-style hyphenated keys (from older JSON files) so they don't choke and die.
    """

    log.debug(f"Starting processing, dict is now {initialization_state}")

    # Old-style hyphenated keys and their modern replacements. Handle them
    # for now, but remove eventually.
    deprecated_to_current = {
        'tstate-file': 'tstate_file',
        'bstate-file': 'bstate_file',
        'sstate-file': 'sstate_file',
        'segs-per-state': 'segs_per_state',
    }

    for old_key, new_key in deprecated_to_current.items():
        if old_key in initialization_state.keys():
            log.warning(
                f"This initialization JSON file uses the deprecated " f"hyphenated form for {old_key}. Replace with underscores."
            )

            initialization_state[new_key] = initialization_state.pop(old_key)

    log.debug(f"Finished processing, dict is now {initialization_state}")
    return initialization_state
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# TODO: Break this out into a separate module, let it be specified (if it's necessary) as a plugin option
|
|
106
|
+
# This may not always be required -- i.e. you may be able to directly output to the h5 file in your propagator
|
|
107
|
+
def prepare_coordinates(plugin_config, h5file, we_h5filename):
|
|
108
|
+
"""
|
|
109
|
+
Copy relevant coordinates from trajectory files into <iteration>/auxdata/coord of the h5 file.
|
|
110
|
+
|
|
111
|
+
Directly modifies the input h5 file.
|
|
112
|
+
|
|
113
|
+
Adds ALL coordinates to auxdata/coord.
|
|
114
|
+
|
|
115
|
+
Adapted from original msmWE collectCoordinates.py script.
|
|
116
|
+
|
|
117
|
+
Parameters
|
|
118
|
+
----------
|
|
119
|
+
plugin_config: YAMLConfig object
|
|
120
|
+
Stores the configuration options provided to the plugin in the WESTPA configuration file
|
|
121
|
+
|
|
122
|
+
h5file: h5py.File
|
|
123
|
+
WESTPA h5 data file
|
|
124
|
+
|
|
125
|
+
we_h5filename: string
|
|
126
|
+
Name of the WESTPA h5 file
|
|
127
|
+
"""
|
|
128
|
+
|
|
129
|
+
refPDBfile = plugin_config.get('ref_pdb_file')
|
|
130
|
+
modelName = plugin_config.get('model_name')
|
|
131
|
+
|
|
132
|
+
# TODO: Don't need this explicit option, use WEST_SIM_ROOT or something
|
|
133
|
+
WEfolder = plugin_config.get('we_folder')
|
|
134
|
+
|
|
135
|
+
parentTraj = plugin_config.get('parent_traj_filename')
|
|
136
|
+
childTraj = plugin_config.get('child_traj_filename')
|
|
137
|
+
pcoord_ndim = plugin_config.get('pcoord_ndim', 1)
|
|
138
|
+
|
|
139
|
+
model = msm_we.modelWE()
|
|
140
|
+
log.info('Preparing coordinates...')
|
|
141
|
+
|
|
142
|
+
# Only need the model to get the number of iterations and atoms
|
|
143
|
+
# TODO: Replace this with something more lightweight, get directly from WE
|
|
144
|
+
log.debug(f'Doing collectCoordinates on WE file {we_h5filename}')
|
|
145
|
+
model.initialize(
|
|
146
|
+
[we_h5filename],
|
|
147
|
+
refPDBfile,
|
|
148
|
+
modelName,
|
|
149
|
+
# Pass some dummy arguments -- these aren't important, this model is just created for convenience
|
|
150
|
+
# in the coordinate collection. Dummy arguments prevent warnings from being raised.
|
|
151
|
+
basis_pcoord_bounds=None,
|
|
152
|
+
target_pcoord_bounds=None,
|
|
153
|
+
tau=1,
|
|
154
|
+
pcoord_ndim=pcoord_ndim,
|
|
155
|
+
_suppress_boundary_warning=True,
|
|
156
|
+
)
|
|
157
|
+
model.get_iterations()
|
|
158
|
+
|
|
159
|
+
log.debug(f"Found {model.maxIter} iterations")
|
|
160
|
+
|
|
161
|
+
n_iter = None
|
|
162
|
+
for n_iter in tqdm.tqdm(range(1, model.maxIter + 1)):
|
|
163
|
+
nS = model.numSegments[n_iter - 1].astype(int)
|
|
164
|
+
coords = np.zeros((nS, 2, model.nAtoms, 3))
|
|
165
|
+
dsetName = "/iterations/iter_%08d/auxdata/coord" % int(n_iter)
|
|
166
|
+
|
|
167
|
+
coords_exist = False
|
|
168
|
+
try:
|
|
169
|
+
dset = h5file.create_dataset(dsetName, np.shape(coords))
|
|
170
|
+
except (RuntimeError, ValueError):
|
|
171
|
+
log.debug('coords exist for iteration ' + str(n_iter) + ' NOT overwritten')
|
|
172
|
+
coords_exist = True
|
|
173
|
+
continue
|
|
174
|
+
|
|
175
|
+
for iS in range(nS):
|
|
176
|
+
trajpath = WEfolder + "/traj_segs/%06d/%06d" % (n_iter, iS)
|
|
177
|
+
|
|
178
|
+
try:
|
|
179
|
+
coord0 = np.squeeze(md.load(f'{trajpath}/{parentTraj}', top=model.reference_structure.topology)._xyz)
|
|
180
|
+
except OSError:
|
|
181
|
+
log.warning("Parent traj file doesn't exist, loading reference structure coords")
|
|
182
|
+
coord0 = np.squeeze(model.reference_structure._xyz)
|
|
183
|
+
|
|
184
|
+
coord1 = np.squeeze(md.load(f'{trajpath}/{childTraj}', top=model.reference_structure.topology)._xyz)
|
|
185
|
+
|
|
186
|
+
coords[iS, 0, :, :] = coord0
|
|
187
|
+
coords[iS, 1, :, :] = coord1
|
|
188
|
+
|
|
189
|
+
if not coords_exist:
|
|
190
|
+
dset[:] = coords
|
|
191
|
+
|
|
192
|
+
log.debug(f"Wrote coords for {n_iter} iterations.")
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def msmwe_compute_ss(plugin_config, west_files):
    """
    Prepare and initialize an msm_we model, and use it to predict a steady-state distribution.

    1. Load coordinate data
    2. Perform dimensionality reduction
    3. Compute flux and transition matrices
    4. Compute steady-state distribution (via eigenvectors of transition matrix)
    5. Compute target-state flux

    TODO
    ----
    This function does far too many things. Break it up a bit.

    Parameters
    ----------
    plugin_config: YAMLConfig object
        Stores the configuration options provided to the plugin in the WESTPA configuration file

    west_files: list of str
        Paths to the WESTPA H5 files to build the haMSM from.

    Returns
    -------
    ss_alg: np.ndarray
        The steady-state distribution

    ss_flux: float
        Flux into target state

    model: modelWE object
        The modelWE object produced for analysis.
    """

    # Lag (in units of tau) used when building the flux matrix.
    n_lag = 0

    log.debug("Initializing msm_we")

    # TODO: Refactor this to use westpa.core.extloader.get_object
    # I'm reinventing the wheel a bit here, I can replace almost all this code w/ that
    # ##### Monkey-patch modelWE with the user-override functions
    override_file = plugin_config.get('user_functions')

    # First, import the file with the user-override functions
    # This is a decently janky implementation, but it seems to work, and I don't know of a better way of doing it.
    # This is nice because it avoids mucking around with the path, which I think is a Good Thing.

    # We're given a path to the user-specified file containing overrides
    # This comes from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path

    # I don't think the name provided here actually matters
    user_override_spec = importlib.util.spec_from_file_location("override_module", override_file)
    user_overrides = importlib.util.module_from_spec(user_override_spec)

    # Make the functions that were overriden in override_file available in the namespace under user_overrides
    user_override_spec.loader.exec_module(user_overrides)

    # So now we can do the actual monkey-patching of modelWE.
    # We monkey-patch at the module level rather than just override the function in the instanced object
    # so that the functions retain access to self.
    msm_we.modelWE.processCoordinates = user_overrides.processCoordinates
    # ##### Done with monkey-patching.

    model = msm_we.modelWE()
    streaming = plugin_config.get('streaming', False)

    refPDBfile = plugin_config.get('ref_pdb_file')
    modelName = plugin_config.get('model_name')
    n_clusters = plugin_config.get('n_clusters')
    tau = plugin_config.get('tau', None)
    pcoord_ndim = plugin_config.get('pcoord_ndim', 1)

    # NaN is used as a sentinel for "not provided" so the check below can
    # flag missing bounds with np.isnan.
    basis_pcoord_bounds = np.array(plugin_config.get('basis_pcoord_bounds', np.nan), dtype=float)
    target_pcoord_bounds = np.array(plugin_config.get('target_pcoord_bounds', np.nan), dtype=float)

    # Optional explicit bin mapper for clustering, built from a YAML spec.
    user_bin_mapper = plugin_config.get('user_bin_mapper', None)
    if user_bin_mapper is not None:
        user_bin_mapper = bins_from_yaml_dict(user_bin_mapper)

    if np.isnan(basis_pcoord_bounds).any() or np.isnan(target_pcoord_bounds).any():
        log.critical(
            "Target and/or basis pcoord bounds were not specified. "
            "Set them using the 'basis_pcoord_bounds' or 'target_pcoord_bounds' parameters. "
            "'basis/target_pcoord1_min/max' and 'basis/target_pcoord1' are no longer supported. "
            "See https://jdrusso.github.io/msm_we/api.html#msm_we.msm_we.modelWE.initialize for details."
        )

    if tau is None:
        log.warning('No tau provided to restarting plugin. Defaulting to 1.')
        tau = 1

    # Fire up the model object
    model.initialize(
        west_files,
        refPDBfile,
        modelName,
        basis_pcoord_bounds=basis_pcoord_bounds,
        target_pcoord_bounds=target_pcoord_bounds,
        tau=tau,
        pcoord_ndim=pcoord_ndim,
    )

    model.dimReduceMethod = plugin_config.get('dim_reduce_method')

    model.n_lag = n_lag

    log.debug("Loading in iteration data.. (this could take a while)")

    # First dimension is the total number of segments
    model.get_iterations()

    model.get_coordSet(model.maxIter)

    model.dimReduce()

    first_iter, last_iter = model.first_iter, model.maxIter

    clusterFile = modelName + "_clusters_s" + str(first_iter) + "_e" + str(last_iter) + "_nC" + str(n_clusters) + ".h5"
    # TODO: Uncomment this to actually load the clusterFile if it exists. For now, disable for development.
    exists = os.path.isfile(clusterFile)
    # NOTE: deliberately forced to False below, so the `if exists:` branch is
    # currently unreachable and clustering always recomputes (dev hack).
    exists = False
    log.warning("Skipping any potential cluster reloading!")

    log.info(f"Launching Ray with {plugin_config.get('n_cpus', 1)} cpus")

    # Optionally place Ray's scratch space under a user-provided directory.
    ray_tempdir_root = plugin_config.get('ray_temp_dir', None)
    if ray_tempdir_root is not None:
        ray_tempdir = tempfile.TemporaryDirectory(dir=ray_tempdir_root)
        log.info(f"Using {ray_tempdir.name} as temp_dir for Ray")
        ray.init(
            num_cpus=plugin_config.get('n_cpus', 1), _temp_dir=ray_tempdir.name, ignore_reinit_error=True, include_dashboard=False
        )
    else:
        ray.init(num_cpus=plugin_config.get('n_cpus', 1), ignore_reinit_error=True, include_dashboard=False)

    # If a cluster file with the name corresponding to these parameters exists, load clusters from it.
    if exists:
        log.debug("loading clusters...")
        model.load_clusters(clusterFile)
    # Otherwise, do the clustering (which will create and save to that file)
    else:
        # FIXME: This gives the wrong shape, but loading from the clusterfile gives the right shape
        log.debug("clustering coordinates into " + str(n_clusters) + " clusters...")
        model.cluster_coordinates(n_clusters, streaming=streaming, user_bin_mapper=user_bin_mapper)

    first_iter = 1
    model.get_fluxMatrix(n_lag, first_iter, last_iter)  # extracts flux matrix, output model.fluxMatrixRaw
    log.debug(f"Unprocessed flux matrix has shape {model.fluxMatrixRaw.shape}")
    model.organize_fluxMatrix()  # gets rid of bins with no connectivity, sorts along p1, output model.fluxMatrix
    model.get_Tmatrix()  # normalizes fluxMatrix to transition matrix, output model.Tmatrix

    log.debug(f"Processed flux matrix has shape {model.fluxMatrix.shape}")

    model.get_steady_state()  # gets steady-state from eigen decomp, output model.pSS
    model.get_steady_state_target_flux()  # gets steady-state target flux, output model.JtargetSS

    # Why is model.pss sometimes the wrong shape? It's "occasionally" returned as a nested array.
    # Squeeze fixes it and removes the dimension of length 1, but why does it happen in the first place?

    if type(model.pSS) is np.matrix:
        ss_alg = np.squeeze(model.pSS.A)
    else:
        ss_alg = np.squeeze(model.pSS)

    ss_flux = model.JtargetSS

    log.debug("Got steady state:")
    log.debug(ss_alg)
    log.debug(ss_flux)

    log.info("Completed flux matrix calculation and steady-state estimation!")

    log.info("Starting block validation")
    num_validation_groups = plugin_config.get('n_validation_groups', 2)
    num_validation_blocks = plugin_config.get('n_validation_blocks', 4)
    try:
        model.do_block_validation(num_validation_groups, num_validation_blocks, use_ray=True)
    except Exception as e:
        # Validation failure is non-fatal: log it loudly, but continue.
        log.exception(e)
        log.error("Failed block validation! Continuing with restart, but BEWARE!")
    ray.shutdown()

    return ss_alg, ss_flux, model
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class RestartDriver:
|
|
381
|
+
"""
|
|
382
|
+
WESTPA plugin to automatically handle estimating steady-state from a WE run, re-initializing a new WE run in that
|
|
383
|
+
steady-state, and then running that initialized WE run.
|
|
384
|
+
|
|
385
|
+
Data from the previous run will be stored in the restart<restart_number>/ subdirectory of $WEST_SIM_ROOT.
|
|
386
|
+
|
|
387
|
+
This plugin depends on having the start-states implementation in the main WESTPA code, which allows initializing
|
|
388
|
+
a WE run using states that are NOT later used for recycling.
|
|
389
|
+
|
|
390
|
+
These are used so that when the new WE run is initialized, initial structure selection is chosen by w_init, using
|
|
391
|
+
weights assigned to the start-states based on MSM bin weight and WE segment weight.
|
|
392
|
+
|
|
393
|
+
Since it closes out the current WE run and starts a new one, this plugin should run LAST, after all other plugins.
|
|
394
|
+
"""
|
|
395
|
+
|
|
396
|
+
    def __init__(self, sim_manager, plugin_config):
        """
        Initialize the RestartDriver plugin.

        Pulls the data_manager and sim_manager from the WESTPA run that just completed, along with
        the plugin configuration, and registers :meth:`prepare_new_we` to run when the simulation
        finalizes. On non-master workers this is a no-op.

        Parameters
        ----------
        sim_manager: westpa sim_manager object
            The simulation manager of the completed run.

        plugin_config: dict-like
            Plugin options from the WESTPA configuration file.
        """

        westpa.rc.pstatus("Restart plugin initialized")

        # Only the master work manager drives restarting; workers bail out early.
        if not sim_manager.work_manager.is_master:
            westpa.rc.pstatus("Reweighting not master, skipping")
            return

        self.data_manager = sim_manager.data_manager
        self.sim_manager = sim_manager

        self.plugin_config = plugin_config

        # Bookkeeping files for restart state and w_init arguments.
        self.restart_file = plugin_config.get('restart_file', 'restart.dat')
        self.initialization_file = plugin_config.get('initialization_file', 'restart_initialization.json')

        # Extension settings: how many extra iterations to run, and the
        # configured ceiling from west.cfg (None if unset).
        self.extension_iters = plugin_config.get('extension_iters', 0)
        self.max_total_iterations = westpa.rc.config.get(['west', 'propagation', 'max_total_iterations'], default=None)
        self.base_total_iterations = self.max_total_iterations

        self.coord_len = plugin_config.get('coord_len', 2)
        self.n_restarts = plugin_config.get('n_restarts', -1)
        self.n_runs = plugin_config.get('n_runs', 1)

        # Number of CPUs available for parallelizing msm_we calculations
        self.parallel_cpus = plugin_config.get('n_cpus', 1)
        self.ray_tempdir = plugin_config.get('ray_temp_dir', None)

        # .get() might return this as a bool anyways, but be safe
        self.debug = bool(plugin_config.get('debug', False))
        if self.debug:
            log.setLevel("DEBUG")
            msm_we_logger.setLevel("DEBUG")

        # Default to using all restarts
        self.restarts_to_use = plugin_config.get('n_restarts_to_use', self.n_restarts)
        assert self.restarts_to_use > 0 or self.restarts_to_use == -1, "Invalid number of restarts to use"
        if self.restarts_to_use >= 1:
            # Values >= 1 must be whole numbers; fractions are only meaningful in (0, 1).
            assert (
                self.restarts_to_use == self.restarts_to_use // 1
            ), "If choosing a decimal restarts_to_use, must be between 0 and 1."

        # Resolve the start-state structure writer class (dotted path -> object).
        struct_filetype = plugin_config.get('struct_filetype', 'mdtraj.formats.PDBTrajectoryFile')
        self.struct_filetype = get_object(struct_filetype)

        # This should be low priority, because it closes the H5 file and starts a new WE run. So it should run LAST
        # after any other plugins.
        self.priority = plugin_config.get('priority', 100)  # I think a big number is lower priority...

        sim_manager.register_callback(sim_manager.finalize_run, self.prepare_new_we, self.priority)

        # Initialize data
        self.ss_alg = None
        self.ss_dist = None
        self.model = None
|
|
456
|
+
|
|
457
|
+
def get_original_bins(self):
|
|
458
|
+
"""
|
|
459
|
+
Obtains the WE bins and their probabilities at the end of the previous iteration.
|
|
460
|
+
|
|
461
|
+
Returns
|
|
462
|
+
-------
|
|
463
|
+
bins : np.ndarray
|
|
464
|
+
Array of WE bins
|
|
465
|
+
|
|
466
|
+
binprobs: np.ndarray
|
|
467
|
+
WE bin weights
|
|
468
|
+
"""
|
|
469
|
+
|
|
470
|
+
we_driver = self.sim_manager.we_driver
|
|
471
|
+
bins = we_driver.next_iter_binning
|
|
472
|
+
n_bins = len(bins)
|
|
473
|
+
binprobs = np.fromiter(map(operator.attrgetter('weight'), bins), dtype=np.float64, count=n_bins)
|
|
474
|
+
|
|
475
|
+
return bins, binprobs
|
|
476
|
+
|
|
477
|
+
@property
|
|
478
|
+
def cur_iter(self):
|
|
479
|
+
"""
|
|
480
|
+
Get the current WE iteration.
|
|
481
|
+
|
|
482
|
+
Returns
|
|
483
|
+
-------
|
|
484
|
+
int: The current iteration. Subtract one, because in finalize_run the iter has been incremented
|
|
485
|
+
"""
|
|
486
|
+
return self.sim_manager.n_iter - 1
|
|
487
|
+
|
|
488
|
+
@property
|
|
489
|
+
def is_last_iteration(self):
|
|
490
|
+
"""
|
|
491
|
+
Get whether this is, or is past, the last iteration in this WE run.
|
|
492
|
+
|
|
493
|
+
Returns
|
|
494
|
+
-------
|
|
495
|
+
bool: Whether the current iteration is the final iteration
|
|
496
|
+
"""
|
|
497
|
+
|
|
498
|
+
final_iter = self.sim_manager.max_total_iterations
|
|
499
|
+
|
|
500
|
+
return self.cur_iter >= final_iter
|
|
501
|
+
|
|
502
|
+
def prepare_extension_run(self, run_number, restart_state, first_extension=False):
|
|
503
|
+
"""
|
|
504
|
+
Copy the necessary files for an extension run (versus initializing a fresh run)
|
|
505
|
+
|
|
506
|
+
Parameters
|
|
507
|
+
----------
|
|
508
|
+
run_number: int
|
|
509
|
+
The index of this run (should be 1-indexed!)
|
|
510
|
+
|
|
511
|
+
restart_state: dict
|
|
512
|
+
Dictionary holding the current state of the restarting procedure
|
|
513
|
+
|
|
514
|
+
first_extension: bool
|
|
515
|
+
True if this is the first run of an extension set. If True, then back up west.cfg, and write the extended
|
|
516
|
+
west.cfg file.
|
|
517
|
+
"""
|
|
518
|
+
|
|
519
|
+
log.debug(f"Linking run files from restart0/run{run_number}")
|
|
520
|
+
|
|
521
|
+
# Copy traj_segs, seg_logs, and west.h5 for restart0/runXX back into ./
|
|
522
|
+
# Later: (May only need to copy the latest iteration traj_segs, to avoid tons of back and forth)
|
|
523
|
+
try:
|
|
524
|
+
shutil.rmtree('traj_segs')
|
|
525
|
+
shutil.rmtree('seg_logs')
|
|
526
|
+
except OSError as e:
|
|
527
|
+
if str(e) == 'Cannot call rmtree on a symbolic link':
|
|
528
|
+
os.unlink('traj_segs')
|
|
529
|
+
os.unlink('seg_logs')
|
|
530
|
+
|
|
531
|
+
os.remove(self.data_manager.we_h5filename)
|
|
532
|
+
|
|
533
|
+
os.symlink(f'restart0/run{run_number}/traj_segs', 'traj_segs')
|
|
534
|
+
os.symlink(f'restart0/run{run_number}/seg_logs', 'seg_logs')
|
|
535
|
+
|
|
536
|
+
if first_extension:
|
|
537
|
+
# Get lines to make a new west.cfg by extending west.propagation.max_total_iterations
|
|
538
|
+
with open('west.cfg', 'r') as west_config:
|
|
539
|
+
lines = west_config.readlines()
|
|
540
|
+
for i, line in enumerate(lines):
|
|
541
|
+
# Parse out the number of maximum iterations
|
|
542
|
+
if 'max_total_iterations' in line:
|
|
543
|
+
max_iters = [int(i) for i in line.replace(':', ' ').replace('\n', ' ').split() if i.isdigit()]
|
|
544
|
+
new_max_iters = max_iters[0] + self.extension_iters
|
|
545
|
+
new_line = f"{line.split(':')[0]}: {new_max_iters}\n"
|
|
546
|
+
lines[i] = new_line
|
|
547
|
+
break
|
|
548
|
+
|
|
549
|
+
with open(self.restart_file, 'w') as fp:
|
|
550
|
+
json.dump(restart_state, fp)
|
|
551
|
+
|
|
552
|
+
log.info("First WE extension run ready!")
|
|
553
|
+
westpa.rc.pstatus(
|
|
554
|
+
f"\n\n===== Restart {restart_state['restarts_completed']}, "
|
|
555
|
+
+ f"Run {restart_state['runs_completed'] + 1} extension running =====\n"
|
|
556
|
+
)
|
|
557
|
+
|
|
558
|
+
# TODO: I can't just go straight into a w_run here. w_run expects some things to be set I think, that aren't.
|
|
559
|
+
# I can do w_init, and do a new simulation just fine...
|
|
560
|
+
# I can do this on ONE run repeatedly just fine
|
|
561
|
+
# But if I try to just copy files and continue like this, there's something screwy in state somewhere that
|
|
562
|
+
# causes it to fail.
|
|
563
|
+
# The error has to do with offsets in the HDF5 file?
|
|
564
|
+
# Need to figure out what state would be cleared by w_init
|
|
565
|
+
|
|
566
|
+
# Frankly, this is a really sketchy way of doing this, but it seems to work...
|
|
567
|
+
# I remain skeptical there's not something weird under the hood that isn't being addressed correctly with
|
|
568
|
+
# regard to state, but if it works, it's good enough for now..
|
|
569
|
+
westpa.rc.sim_manager.segments = None
|
|
570
|
+
shutil.copy(f'restart0/run{run_number}/west.h5', self.data_manager.we_h5filename)
|
|
571
|
+
self.data_manager.open_backing()
|
|
572
|
+
|
|
573
|
+
log.debug(f"Sim manager thought n_iter was {westpa.rc.sim_manager.n_iter}")
|
|
574
|
+
log.debug(f"Data manager thought current_iteration was {self.data_manager.current_iteration}")
|
|
575
|
+
log.debug(f"{self.sim_manager} vs {westpa.rc.sim_manager}")
|
|
576
|
+
|
|
577
|
+
if run_number == 1:
|
|
578
|
+
westpa.rc.sim_manager.max_total_iterations += self.extension_iters
|
|
579
|
+
|
|
580
|
+
w_run.run_simulation()
|
|
581
|
+
return
|
|
582
|
+
|
|
583
|
+
def generate_plots(self, restart_directory):
    """
    Produce diagnostic plots for the restart that was just completed.

    Three PDF figures are written into ``restart_directory``:

    * ``flux_plot.pdf`` — haMSM flux profile vs. progress coordinate.
    * ``pseudocomm-flux_plot.pdf`` — flux vs. pseudocommittor.
    * ``hamsm_vs_direct_flux_comparison_plot.pdf`` — haMSM target-flux
      estimates (main + validation models) against the direct WE flux
      measured from the last iteration of each run.

    Parameters
    ----------
    restart_directory : str
        Path (relative to the simulation root) where plots are saved.
    """
    model = self.model

    log.info("Producing flux-profile, pseudocommittor, and target flux comparison plots.")

    # --- Flux profile vs. progress coordinate ---
    profile_fig, profile_ax = plt.subplots()
    model.plot_flux(ax=profile_ax, suppress_validation=True)
    profile_fig.text(x=0.1, y=-0.05, s='This flux profile should become flatter after restarting', fontsize=12)
    profile_ax.legend(bbox_to_anchor=(1.01, 1.0), loc="upper left")
    profile_fig.savefig(f'{restart_directory}/flux_plot.pdf', bbox_inches="tight")

    # --- Flux profile vs. pseudocommittor ---
    pseudocomm_fig, pseudocomm_ax = plt.subplots()
    model.plot_flux_committor(ax=pseudocomm_ax, suppress_validation=True)
    pseudocomm_fig.text(
        x=0.1,
        y=-0.05,
        s='This flux profile should become flatter after restarting.'
        '\nThe x-axis is a "pseudo"committor, since it may be '
        'calculated from WE trajectories in the one-way ensemble.',
        fontsize=12,
    )
    pseudocomm_ax.legend(bbox_to_anchor=(1.01, 1.0), loc="upper left")
    pseudocomm_fig.savefig(f'{restart_directory}/pseudocomm-flux_plot.pdf', bbox_inches="tight")

    # --- haMSM vs. direct-flux comparison ---
    comparison_fig, comparison_ax = plt.subplots(figsize=(7, 3))

    # haMSM flux estimates: main model first, then each validation model.
    all_models = [model, *model.validation_models]
    flux_estimates = [m.JtargetSS for m in all_models]

    hamsm_flux_colors = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    direct_flux_colors = iter(plt.cm.cool(np.linspace(0.2, 0.8, len(model.fileList))))

    # Direct WE flux estimate: recycled-walker weight in the last iteration
    # of each run, divided by tau.
    for west_path in model.fileList:
        run = analysis.Run(west_path)
        final_iter = run.num_iterations
        recycled = list(run.iteration(final_iter - 1).recycled_walkers)
        target_flux = sum(walker.weight for walker in recycled) / model.tau

        # TODO: Correct for time!
        # Abbreviate long paths so legend entries stay readable.
        if len(west_path) >= 15:
            short_filename = f"....{west_path[-12:]}"
        else:
            short_filename = west_path

        # Runs that never recycled contribute no line to the plot.
        if target_flux == 0:
            continue

        comparison_ax.axhline(
            target_flux,
            color=next(direct_flux_colors),
            label=f"Last iter WE direct {target_flux:.2e}" f"\n ({short_filename})",
            linestyle='--',
        )

    comparison_ax.axhline(
        flux_estimates[0], label=f"Main model estimate\n {flux_estimates[0]:.2e}", color=next(hamsm_flux_colors)
    )
    # Validation-model estimates, numbered from 0 in the legend.
    for i, validation_estimate in enumerate(flux_estimates[1:], start=1):
        comparison_ax.axhline(
            validation_estimate,
            label=f"Validation model {i - 1} estimate\n {validation_estimate:.2e}",
            color=next(hamsm_flux_colors),
        )

    comparison_ax.legend(bbox_to_anchor=(1.01, 0.9), loc='upper left')
    comparison_ax.set_yscale('log')
    comparison_ax.set_ylabel('Flux')
    comparison_ax.set_xticks([])
    comparison_fig.tight_layout()
    comparison_fig.savefig(f'{restart_directory}/hamsm_vs_direct_flux_comparison_plot.pdf', bbox_inches="tight")
def prepare_new_we(self):
    """
    This function prepares a new WESTPA simulation using haMSM analysis to accelerate convergence.

    The marathon functionality does re-implement some of the functionality of w_multi_west.
    However, w_multi_west merges independent WE simulations, which may or may not be desirable.
    I think for the purposes of this, it's good to keep the runs completely independent until haMSM model building.
    Either that, or I'm just justifying not having known about w_multi_west when I wrote this. TBD.

    # TODO: Replace all manual path-building with pathlib

    The algorithm is as follows:
        1. Check to see if we've just completed the final iteration
        2. Handle launching multiple runs, if desired
        2. Build haMSM
        3. Obtain structures for each haMSM bin
        4. Make each structure a start-state, with probability set by (MSM-bin SS prob / # structures in bin)
        5. Potentially some renormalization?
        6. Start new WE simulation

    NOTE(review): this method may re-enter the WE machinery (w_init/w_run) and
    return early at several points; the haMSM build below only runs once the
    final run of a marathon has completed and no extension is needed.
    """

    # Do nothing if it's not the final iteration
    if not self.is_last_iteration:
        print(self.cur_iter)
        return

    log.debug("Final iteration, preparing restart")

    # Default state, used when no restart.dat exists yet (very first run).
    restart_state = {'restarts_completed': 0, 'runs_completed': 0}

    # Check for the existence of the extension lockfile here
    doing_extension = os.path.exists(EXTENSION_LOCKFILE)

    # Look for a restart.dat file to get the current state (how many restarts have been performed already)
    if os.path.exists(self.restart_file):
        with open(self.restart_file, 'r') as fp:
            restart_state = json.load(fp)

    # This is the final iteration of a run, so mark this run as completed
    restart_state['runs_completed'] += 1

    # Make the folder to store data for this marathon
    restart_directory = f"restart{restart_state['restarts_completed']}"
    run_directory = f"{restart_directory}/run{restart_state['runs_completed']}"
    if not os.path.exists(run_directory):
        os.makedirs(run_directory)

    # Write coordinates to h5
    prepare_coordinates(self.plugin_config, self.data_manager.we_h5file, self.data_manager.we_h5filename)

    # Archive this run's per-segment data folders into the run directory,
    # leaving fresh empty folders behind for the next run.
    for data_folder in ['traj_segs', 'seg_logs']:
        old_path = data_folder

        # If you're doing an extension, this will be a symlink. So no need to copy, just unlink it and move on
        if doing_extension and os.path.islink(old_path):
            log.debug('Unlinking symlink')
            os.unlink(old_path)
            os.mkdir(old_path)
            continue

        new_path = f"{run_directory}/{old_path}"

        log.debug(f"Moving {old_path} to {new_path}")

        if os.path.exists(new_path):
            log.info(f"{new_path} already exists. Removing and overwriting.")
            shutil.rmtree(new_path)

        try:
            os.rename(old_path, new_path)
        except FileNotFoundError:
            log.warning(f"Folder {old_path} was not found." "This may be normal, but check your configuration.")
        else:
            # Make a new data folder for the next run
            os.mkdir(old_path)

    last_run = restart_state['runs_completed'] >= self.n_runs
    last_restart = restart_state['restarts_completed'] >= self.n_restarts

    # We've just finished a run. Let's check if we have to do any more runs in this marathon before doing a restart.
    # In the case of n_runs == 1, then we're just doing a single run and restarting it every so often.
    # Otherwise, a marathon consists of multiple runs, and restarts are performed between marathons.
    if last_run:
        log.info(f"All {self.n_runs} runs in this marathon completed.")

        if last_restart:
            log.info("All restarts completed! Performing final analysis.")

        else:
            log.info("Proceeding to prepare a restart.")

            # Duplicating this is gross, but given the structure here, my options are either put it above these ifs
            #   entirely, meaning it'll be unnecessarily run at the end of the final restart, or duplicate it below.
            log.info("Preparing coordinates for this run.")

        # Now, continue on to haMSM calculation below.

    # If we have more runs left to do in this marathon, prepare them
    elif not last_run:
        log.info(f"Run {restart_state['runs_completed']}/{self.n_runs} completed.")

        # TODO: Initialize a new run, from the same configuration as this run was
        #   On the 1st run, I can write bstates/tstates/sstates into restart files, and use those for spawning
        #   subsequent runs in the marathon. That way, I don't make unnecessary copies of all those.
        # Basis and target states are unchanged. Can I get the original parameters passed to w_init?
        # Ideally, I should be able to call w_init with the exact same parameters that went to it the first time
        initialization_state = {
            'tstate_file': None,
            'bstate_file': None,
            'sstate_file': None,
            'tstates': None,
            'bstates': None,
            'sstates': None,
            'segs_per_state': None,
        }

        # TODO: Implement this, and get rid of the initialization_file usage right below. Placeholder for now.
        if restart_state['runs_completed'] == 1:
            # Get and write basis, target, start states and segs per state for this marathon to disk
            pass

        # Save the WESTPA h5 data from this run
        self.data_manager.finalize_run()
        shutil.copyfile('west.h5', f"{run_directory}/west.h5")

        # If this is a regular, fresh run (not an extension)
        if not doing_extension:
            if os.path.exists(self.initialization_file):
                with open(self.initialization_file, 'r') as fp:
                    initialization_dict = json.load(fp)
                    initialization_dict = fix_deprecated_initialization(initialization_dict)
                    initialization_state.update(initialization_dict)
            else:
                raise Exception(
                    "No initialization JSON file provided -- " "I don't know how to start new runs in this marathon."
                )

            westpa.rc.pstatus(
                f"\n\n===== Restart {restart_state['restarts_completed']}, "
                + f"Run {restart_state['runs_completed'] + 1} initializing =====\n"
            )

            westpa.rc.pstatus(
                f"\nRun: \n\t w_init --tstate-file {initialization_state['tstate_file']} "
                + f"--bstate-file {initialization_state['bstate_file']} "
                f"--sstate-file {initialization_state['sstate_file']} "
                f"--segs-per-state {initialization_state['segs_per_state']}\n"
            )

            w_init.initialize(
                **initialization_state,
                shotgun=False,
            )

            # Persist the updated run counter before launching the next run.
            with open(self.restart_file, 'w') as fp:
                json.dump(restart_state, fp)

            log.info("New WE run ready!")
            westpa.rc.pstatus(
                f"\n\n===== Restart {restart_state['restarts_completed']}, "
                + f"Run {restart_state['runs_completed'] + 1} running =====\n"
            )

            w_run.run_simulation()
            return

        # If we're doing an extension set
        # Instead of w_initting a new iteration, copy the files from restart0/runXX back into ./
        elif doing_extension:
            self.prepare_extension_run(run_number=restart_state['runs_completed'] + 1, restart_state=restart_state)
            return

    log.debug(f"{restart_state['restarts_completed']}/{self.n_restarts} restarts completed")

    # Build the haMSM
    log.debug("Initializing haMSM")

    # Need to write the h5 file and close it out, but I need to get the current bstates first.
    original_bstates = self.sim_manager.current_iter_bstates
    if original_bstates is None:
        original_bstates = self.data_manager.get_basis_states(self.sim_manager.n_iter - 1)

    assert original_bstates is not None, "Bstates are none in the current iteration"

    original_tstates = self.data_manager.get_target_states(self.cur_iter)

    # Flush h5 file writes and copy it to the run directory
    self.data_manager.finalize_run()
    shutil.copyfile(self.data_manager.we_h5filename, f"{run_directory}/west.h5")

    # Use all files in all restarts
    # Restarts index at 0, because there's a 0th restart before you've... restarted anything.
    # Runs index at 1, because Run 1 is the first run.
    # TODO: Let the user pick last half or something in the plugin config.
    marathon_west_files = []
    # When doing the first restart, restarts_completed is 0 (because the first restart isn't complete yet) and
    #   the data generated during this restart is in /restart0.
    # So when doing the Nth restart, restarts_completed is N-1

    # If set to -1, use all restarts
    if self.restarts_to_use == -1:
        last_N_restarts = 1 + restart_state['restarts_completed']
    # If this is an integer, use the last N restarts
    elif self.restarts_to_use >= 1:
        last_N_restarts = self.restarts_to_use
    # If it's a decimal between 0 and 1, use it as a fraction
    # At restart 1, and a fraction of 0.5, this should just use restart 1
    elif 0 < self.restarts_to_use < 1:
        last_N_restarts = int(self.restarts_to_use * (1 + restart_state['restarts_completed']))

        # If this fraction is <1, use all until it's not
        if last_N_restarts < 1:
            last_N_restarts = 1 + restart_state['restarts_completed']

    log.debug(f"Last N is {last_N_restarts}")
    first_restart = max(1 + restart_state['restarts_completed'] - last_N_restarts, 0)
    usable_restarts = range(first_restart, 1 + restart_state['restarts_completed'])

    log.info(
        f"At restart {restart_state['restarts_completed']}, building haMSM using data from restarts {list(usable_restarts)}"
    )
    for restart_number in usable_restarts:
        for run_number in range(1, 1 + restart_state['runs_completed']):
            west_file_path = f"restart{restart_number}/run{run_number}/west.h5"
            marathon_west_files.append(west_file_path)

    log.debug(f"WESTPA datafile for analysis are {marathon_west_files}")

    #
    # If this is the first restart, check to see if you got any target state flux
    if restart_state['restarts_completed'] == 0:
        pass

        # Check to see if you got any target flux in ANY runs
        target_reached = False
        for west_file_path in marathon_west_files:
            if check_target_reached(west_file_path):
                target_reached = True
                break

        # If you reached the target, clean up from the extensions and then continue as normal
        # If extension_iters is set to 0, then don't do extensions.
        if target_reached or self.extension_iters == 0:
            log.info("All runs reached target!")

            # Do some cleanup from the extension run
            if doing_extension and not self.extension_iters == 0:
                # Remove the doing_extensions.lck lockfile
                os.remove(EXTENSION_LOCKFILE)

                westpa.rc.sim_manager.max_total_iterations = self.base_total_iterations

            # Otherwise, just continue as normal
            pass

        # If no runs reached the target, then we need to extend them
        elif not target_reached:
            log.info("Target not reached. Preparing for extensions.")

            # Create the doing_extensions.lck "lockfile" to indicate we're in extend mode (or keep if exists)
            #   and write the initial number of iterations to it.
            if not os.path.exists(EXTENSION_LOCKFILE):
                with open(EXTENSION_LOCKFILE, 'w') as lockfile:
                    lockfile.write(str(self.max_total_iterations))

            # Reset runs_completed to 0, and rewrite restart.dat accordingly
            restart_state['runs_completed'] = 0

            self.prepare_extension_run(run_number=1, restart_state=restart_state, first_extension=True)
            return

    log.debug("Building haMSM and computing steady-state")
    log.debug(f"Cur iter is {self.cur_iter}")
    ss_dist, ss_flux, model = msmwe_compute_ss(self.plugin_config, marathon_west_files)
    self.ss_dist = ss_dist
    self.model = model

    log.debug(f'Steady-state distribution: {ss_dist}')
    log.info(f"Target steady-state flux is {ss_flux}")

    # Obtain cluster-structures
    log.debug("Obtaining cluster-structures")
    model.update_cluster_structures()

    # TODO: Do this with pathlib
    struct_directory = f"{restart_directory}/structs"
    if not os.path.exists(struct_directory):
        os.makedirs(struct_directory)

    flux_filename = f"{restart_directory}/JtargetSS.txt"
    with open(flux_filename, 'w') as fp:
        log.debug(f"Writing flux to {flux_filename}")
        fp.write(str(model.JtargetSS))
        # NOTE(review): redundant close inside a `with` block — harmless but unnecessary.
        fp.close()

    ss_filename = f"{restart_directory}/pSS.txt"
    with open(ss_filename, 'w') as fp:
        log.debug(f"Writing pSS to {ss_filename}")
        np.savetxt(fp, model.pSS)
        # NOTE(review): redundant close inside a `with` block — harmless but unnecessary.
        fp.close()

    # If this is the last run of the last restart, do nothing and exit.
    # if restart_state['runs_completed'] >= self.n_runs and restart_state['restarts_completed'] >= self.n_restarts:
    #     log.info("All restarts completed!")
    #     return

    # Construct start-state file with all structures and their weights
    # TODO: Don't explicitly write EVERY structure to disk, or this will be a nightmare for large runs.
    #   However, for now, it's fine...
    log.debug("Writing structures")
    # TODO: Include start states from previous runs
    sstates_filename = f"{restart_directory}/startstates.txt"
    with open(sstates_filename, 'w') as fp:
        # Track the total number of segments iterated over
        seg_idx = 0

        log.info(f"Obtaining potential start structures ({len(model.cluster_structures.items())} bins avail)")

        # Can use these for sanity checks
        total_weight = 0.0
        total_bin_weights = []

        # Loop over each set of (bin index, all the structures in that bin)
        for msm_bin_idx, structures in tqdm.tqdm(model.cluster_structures.items()):
            total_bin_weights.append(0)

            # Don't put structures in the basis or target
            if msm_bin_idx in [model.n_clusters, model.n_clusters + 1]:
                continue

            # The per-segment bin probability.
            # Map a cluster number onto a cluster INDEX, because after cleaning the cluster numbers may no longer
            #   be consecutive.
            bin_prob = self.ss_dist[model.cluster_mapping[msm_bin_idx]]  # / len(structures)

            if bin_prob == 0:
                log.info(f"MSM-Bin {msm_bin_idx} has probability 0, so not saving any structs from it.")
                continue

            # The total amount of WE weight in this MSM microbin
            msm_bin_we_weight = sum(model.cluster_structure_weights[msm_bin_idx])

            # Write each structure to disk. Loop over each structure within a bin.
            msm_bin_we_weight_tracker = 0
            for struct_idx, structure in enumerate(structures):
                structure_filename = (
                    f"{struct_directory}/bin{msm_bin_idx}_" f"struct{struct_idx}.{STRUCT_EXTENSIONS[self.struct_filetype]}"
                )

                with self.struct_filetype(structure_filename, 'w') as struct_file:
                    # One structure per segment
                    seg_we_weight = model.cluster_structure_weights[msm_bin_idx][struct_idx]
                    msm_bin_we_weight_tracker += seg_we_weight

                    # Structure weights are set according to Algorithm 5.3 in
                    #   Aristoff, D. & Zuckerman, D. M. Optimizing Weighted Ensemble Sampling of Steady States.
                    #   Multiscale Model Sim 18, 646–673 (2020).
                    structure_weight = seg_we_weight * (bin_prob / msm_bin_we_weight)

                    total_bin_weights[-1] += structure_weight
                    total_weight += structure_weight

                    topology = model.reference_structure.topology

                    try:
                        angles = model.reference_structure.unitcell_angles[0]
                        lengths = model.reference_structure.unitcell_lengths[0] * 10
                    # This throws typeerror if reference_structure.unitcell_angles is None, or AttributeError
                    #   if reference_structure.unitcell_angles doesn't exist.
                    except (TypeError, AttributeError):
                        angles, lengths = None, None

                    # Correct units — presumably nm -> Angstrom; TODO confirm against reference_structure units.
                    coords = structure * 10

                    # Write the structure file
                    if self.struct_filetype is md.formats.PDBTrajectoryFile:
                        struct_file.write(coords, topology, modelIndex=1, unitcell_angles=angles, unitcell_lengths=lengths)

                    elif self.struct_filetype is md.formats.AmberRestartFile:
                        # AmberRestartFile takes slightly differently named keyword args
                        struct_file.write(coords, time=None, cell_angles=angles, cell_lengths=lengths)

                    else:
                        # Otherwise, YOLO just hope all the positional arguments are in the right place
                        log.warning(
                            f"This output filetype ({self.struct_filetype}) is probably supported, "
                            f"but not explicitly handled."
                            " You should ensure that it takes argument as (coords, topology)"
                        )
                        struct_file.write(coords, topology)
                        # NOTE(review): this raise makes the preceding write unreachable in effect —
                        #   any unhandled filetype aborts the restart after writing one file.
                        raise Exception("Don't know what extension to use for this filetype")

                    # Add this start-state to the start-states file
                    # This path is relative to WEST_SIM_ROOT
                    fp.write(f'b{msm_bin_idx}_s{struct_idx} {structure_weight} {structure_filename}\n')
                    seg_idx += 1

            # log.info(f"WE weight ({msm_bin_we_weight_tracker:.5e} / {msm_bin_we_weight:.5e})")

        # TODO: Fix this check. It's never quite worked right, nor has it ever caught an actual problem, so just
        #   disable for now.
        # In equilibrium, all probabilities count, but in steady-state the last 2 are the target/basis
        # Subtract off the probabilities of the basis and target states, since those don't have structures
        #   assigned to them.
        # assert np.isclose(total_weight, 1 - sum(model.pSS[model.n_clusters :])), (
        #     f"Total steady-state structure weights not normalized! (Total: {total_weight}) "
        #     f"\n\t pSS: {model.pSS}"
        #     f"\n\t Total bin weights {total_bin_weights}"
        #     f"\n\t pSS sum: {sum(model.pSS)}"
        #     f"\n\t pSS -2 sum: {sum(model.pSS[:-2])}"
        #     f"\n\t pSS (+target, no basis) sum: {sum(model.pSS[:-2]) + model.pSS[-1]}"
        # )

    ### Start the new simulation

    # Re-emit the original basis states so w_init can reuse them verbatim.
    bstates_str = ""
    for original_bstate in original_bstates:
        orig_bstate_prob = original_bstate.probability
        orig_bstate_label = original_bstate.label
        orig_bstate_aux = original_bstate.auxref

        bstate_str = f"{orig_bstate_label} {orig_bstate_prob} {orig_bstate_aux}\n"

        bstates_str += bstate_str

    bstates_filename = f"{restart_directory}/basisstates.txt"
    with open(bstates_filename, 'w') as fp:
        fp.write(bstates_str)

    tstates_str = ""
    for original_tstate in original_tstates:
        orig_tstate_label = original_tstate.label
        # TODO: Handle multidimensional pcoords
        orig_tstate_pcoord = original_tstate.pcoord[0]

        tstate_str = f"{orig_tstate_label} {orig_tstate_pcoord}\n"
        tstates_str += tstate_str
    tstates_filename = f"{restart_directory}/targetstates.txt"
    with open(tstates_filename, 'w') as fp:
        fp.write(tstates_str)

    # Pickle the model
    objFile = f"{restart_directory}/hamsm.obj"
    with open(objFile, "wb") as objFileHandler:
        log.debug("Pickling model")
        pickle.dump(model, objFileHandler, protocol=4)
        # NOTE(review): redundant close inside a `with` block — harmless but unnecessary.
        objFileHandler.close()

    # Before finishing this restart, make a plot of the flux profile.
    # This is made so the user can see whether

    self.generate_plots(restart_directory)

    # At this point, the restart is completed, and the data for the next one is ready (though still need to make the
    #   initialization file and such).

    if last_restart:
        log.info("All restarts completed! Finished.")
        return

    # Update restart_file file
    restart_state['restarts_completed'] += 1
    # If we're doing a restart, then reset the number of completed runs to 0 for the next marathon.
    restart_state['runs_completed'] = 0
    with open(self.restart_file, 'w') as fp:
        json.dump(restart_state, fp)

    log.info("Initializing new run")

    # TODO: Read this from config if available
    segs_per_state = 1

    old_initialization_path = self.initialization_file
    new_initialization_path = f"{restart_directory}/{self.initialization_file}"
    log.debug(f"Moving initialization file from {old_initialization_path} to {new_initialization_path}.")
    shutil.move(old_initialization_path, new_initialization_path)

    initialization_state = {
        'tstate_file': tstates_filename,
        'bstate_file': bstates_filename,
        'sstate_file': sstates_filename,
        'tstates': None,
        'bstates': None,
        'sstates': None,
        'segs_per_state': segs_per_state,
    }

    with open(self.initialization_file, 'w') as fp:
        json.dump(initialization_state, fp)

    westpa.rc.pstatus(
        f"\n\n"
        f"===== Restart {restart_state['restarts_completed']}, "
        + f"Run {restart_state['runs_completed'] + 1} initializing =====\n"
    )

    westpa.rc.pstatus(
        f"\nRun: \n\t w_init --tstate-file {tstates_filename} "
        + f"--bstate-file {bstates_filename} --sstate-file {sstates_filename} --segs-per-state {segs_per_state}\n"
    )

    w_init.initialize(**initialization_state, shotgun=False)

    log.info("New WE run ready!")
    westpa.rc.pstatus(f"\n\n===== Restart {restart_state['restarts_completed']} running =====\n")

    w_run.run_simulation()
|