westpa 2022.12 (cp313-cp313-macosx_10_13_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of westpa might be problematic.

Files changed (149)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +746 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +360 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +596 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +547 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +833 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_trace.py +599 -0
  38. westpa/core/__init__.py +0 -0
  39. westpa/core/_rc.py +673 -0
  40. westpa/core/binning/__init__.py +55 -0
  41. westpa/core/binning/_assign.cpython-313-darwin.so +0 -0
  42. westpa/core/binning/assign.py +455 -0
  43. westpa/core/binning/binless.py +96 -0
  44. westpa/core/binning/binless_driver.py +54 -0
  45. westpa/core/binning/binless_manager.py +190 -0
  46. westpa/core/binning/bins.py +47 -0
  47. westpa/core/binning/mab.py +506 -0
  48. westpa/core/binning/mab_driver.py +54 -0
  49. westpa/core/binning/mab_manager.py +198 -0
  50. westpa/core/data_manager.py +1694 -0
  51. westpa/core/extloader.py +74 -0
  52. westpa/core/h5io.py +995 -0
  53. westpa/core/kinetics/__init__.py +24 -0
  54. westpa/core/kinetics/_kinetics.cpython-313-darwin.so +0 -0
  55. westpa/core/kinetics/events.py +147 -0
  56. westpa/core/kinetics/matrates.py +156 -0
  57. westpa/core/kinetics/rate_averaging.py +266 -0
  58. westpa/core/progress.py +218 -0
  59. westpa/core/propagators/__init__.py +54 -0
  60. westpa/core/propagators/executable.py +719 -0
  61. westpa/core/reweight/__init__.py +14 -0
  62. westpa/core/reweight/_reweight.cpython-313-darwin.so +0 -0
  63. westpa/core/reweight/matrix.py +126 -0
  64. westpa/core/segment.py +119 -0
  65. westpa/core/sim_manager.py +835 -0
  66. westpa/core/states.py +359 -0
  67. westpa/core/systems.py +93 -0
  68. westpa/core/textio.py +74 -0
  69. westpa/core/trajectory.py +330 -0
  70. westpa/core/we_driver.py +910 -0
  71. westpa/core/wm_ops.py +43 -0
  72. westpa/core/yamlcfg.py +391 -0
  73. westpa/fasthist/__init__.py +34 -0
  74. westpa/fasthist/_fasthist.cpython-313-darwin.so +0 -0
  75. westpa/mclib/__init__.py +271 -0
  76. westpa/mclib/__main__.py +28 -0
  77. westpa/mclib/_mclib.cpython-313-darwin.so +0 -0
  78. westpa/oldtools/__init__.py +4 -0
  79. westpa/oldtools/aframe/__init__.py +35 -0
  80. westpa/oldtools/aframe/atool.py +75 -0
  81. westpa/oldtools/aframe/base_mixin.py +26 -0
  82. westpa/oldtools/aframe/binning.py +178 -0
  83. westpa/oldtools/aframe/data_reader.py +560 -0
  84. westpa/oldtools/aframe/iter_range.py +200 -0
  85. westpa/oldtools/aframe/kinetics.py +117 -0
  86. westpa/oldtools/aframe/mcbs.py +153 -0
  87. westpa/oldtools/aframe/output.py +39 -0
  88. westpa/oldtools/aframe/plotting.py +90 -0
  89. westpa/oldtools/aframe/trajwalker.py +126 -0
  90. westpa/oldtools/aframe/transitions.py +469 -0
  91. westpa/oldtools/cmds/__init__.py +0 -0
  92. westpa/oldtools/cmds/w_ttimes.py +361 -0
  93. westpa/oldtools/files.py +34 -0
  94. westpa/oldtools/miscfn.py +23 -0
  95. westpa/oldtools/stats/__init__.py +4 -0
  96. westpa/oldtools/stats/accumulator.py +35 -0
  97. westpa/oldtools/stats/edfs.py +129 -0
  98. westpa/oldtools/stats/mcbs.py +96 -0
  99. westpa/tools/__init__.py +33 -0
  100. westpa/tools/binning.py +472 -0
  101. westpa/tools/core.py +340 -0
  102. westpa/tools/data_reader.py +159 -0
  103. westpa/tools/dtypes.py +31 -0
  104. westpa/tools/iter_range.py +198 -0
  105. westpa/tools/kinetics_tool.py +340 -0
  106. westpa/tools/plot.py +283 -0
  107. westpa/tools/progress.py +17 -0
  108. westpa/tools/selected_segs.py +154 -0
  109. westpa/tools/wipi.py +751 -0
  110. westpa/trajtree/__init__.py +4 -0
  111. westpa/trajtree/_trajtree.cpython-313-darwin.so +0 -0
  112. westpa/trajtree/trajtree.py +117 -0
  113. westpa/westext/__init__.py +0 -0
  114. westpa/westext/adaptvoronoi/__init__.py +3 -0
  115. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  116. westpa/westext/hamsm_restarting/__init__.py +3 -0
  117. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  118. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  119. westpa/westext/stringmethod/__init__.py +11 -0
  120. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  121. westpa/westext/stringmethod/string_driver.py +253 -0
  122. westpa/westext/stringmethod/string_method.py +306 -0
  123. westpa/westext/weed/BinCluster.py +180 -0
  124. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  125. westpa/westext/weed/UncertMath.py +247 -0
  126. westpa/westext/weed/__init__.py +10 -0
  127. westpa/westext/weed/weed_driver.py +192 -0
  128. westpa/westext/wess/ProbAdjust.py +101 -0
  129. westpa/westext/wess/__init__.py +6 -0
  130. westpa/westext/wess/wess_driver.py +217 -0
  131. westpa/work_managers/__init__.py +57 -0
  132. westpa/work_managers/core.py +396 -0
  133. westpa/work_managers/environment.py +134 -0
  134. westpa/work_managers/mpi.py +318 -0
  135. westpa/work_managers/processes.py +187 -0
  136. westpa/work_managers/serial.py +28 -0
  137. westpa/work_managers/threads.py +79 -0
  138. westpa/work_managers/zeromq/__init__.py +20 -0
  139. westpa/work_managers/zeromq/core.py +641 -0
  140. westpa/work_managers/zeromq/node.py +131 -0
  141. westpa/work_managers/zeromq/work_manager.py +526 -0
  142. westpa/work_managers/zeromq/worker.py +320 -0
  143. westpa-2022.12.dist-info/AUTHORS +22 -0
  144. westpa-2022.12.dist-info/LICENSE +21 -0
  145. westpa-2022.12.dist-info/METADATA +193 -0
  146. westpa-2022.12.dist-info/RECORD +149 -0
  147. westpa-2022.12.dist-info/WHEEL +6 -0
  148. westpa-2022.12.dist-info/entry_points.txt +29 -0
  149. westpa-2022.12.dist-info/top_level.txt +1 -0
westpa/westext/hamsm_restarting/restart_driver.py
@@ -0,0 +1,1165 @@
+ import h5py
+ import logging
+ import operator
+ import numpy as np
+
+ import westpa
+ from westpa.cli.core import w_init
+ from westpa.cli.core import w_run
+ from westpa.core.extloader import get_object
+ from westpa.core.segment import Segment
+ from westpa import analysis
+ from westpa.core._rc import bins_from_yaml_dict
+
+ import json
+
+ import os
+ import shutil
+ import pickle
+ import importlib.util
+
+ import tqdm
+
+ import mdtraj as md
+ from rich.logging import RichHandler
+
+ from matplotlib import pyplot as plt
+
+ # Ensure this is installed via pip. msm_we's setup.py is all set up for that.
+ # Navigate to the folder where msm_we is, and run python3 -m pip install .
+ # If you're doing development on msm_we, add the -e flag to pip, i.e. "python3 -m pip install -e ."
+ # -e will install it in editable mode, so changes to msm_we will take effect next time it's imported.
+ # Otherwise, if you modify the msm_we code, you'll need to re-install it through pip.
+ from msm_we import msm_we
+
+ import ray
+ import tempfile
+
+ EPS = np.finfo(np.float64).eps
+
+ log = logging.getLogger(__name__)
+ log.setLevel("INFO")
+ log.propagate = False
+ log.addHandler(RichHandler())
+
+ msm_we_logger = logging.getLogger("msm_we.msm_we")
+ msm_we_logger.setLevel("INFO")
+
+ # Map structure types to extensions.
+ # This tells the plugin what extension to put on generated start-state files.
+ STRUCT_EXTENSIONS = {
+     md.formats.PDBTrajectoryFile: "pdb",
+     md.formats.AmberRestartFile: "rst7",
+ }
+
+ EXTENSION_LOCKFILE = 'doing_extension'
+
+
+ def check_target_reached(h5_filename):
+     """
+     Check if the target state was reached, given the data in a WEST H5 file.
+
+     Parameters
+     ----------
+     h5_filename: string
+         Path to a WESTPA HDF5 data file
+     """
+     with h5py.File(h5_filename, 'r') as h5_file:
+         # Get the key to the final iteration. Need to do -2 instead of -1 because there's an empty-ish final iteration
+         # written.
+         for iteration_key in list(h5_file['iterations'].keys())[-2:0:-1]:
+             endpoint_types = h5_file[f'iterations/{iteration_key}/seg_index']['endpoint_type']
+             if Segment.SEG_ENDPOINT_RECYCLED in endpoint_types:
+                 log.debug(f"recycled segment found in file {h5_filename} at iteration {iteration_key}")
+                 return True
+     return False
+
+
+ def fix_deprecated_initialization(initialization_state):
+     """
+     I changed my initialization JSON schema to use underscores instead of hyphens so I can directly expand it into
+     keyword arguments to w_init. This just handles any old-style JSON files I still had, so they don't choke and die.
+     """
+
+     log.debug(f"Starting processing, dict is now {initialization_state}")
+
+     # Some of my initial files had this old-style formatting. Handle it for now, but remove eventually
+     for old_key, new_key in [
+         ('tstate-file', 'tstate_file'),
+         ('bstate-file', 'bstate_file'),
+         ('sstate-file', 'sstate_file'),
+         ('segs-per-state', 'segs_per_state'),
+     ]:
+         if old_key in initialization_state.keys():
+             log.warning(
+                 f"This initialization JSON file uses the deprecated hyphenated form for {old_key}. Replace with underscores."
+             )
+
+             value = initialization_state.pop(old_key)
+             initialization_state[new_key] = value
+
+     log.debug(f"Finished processing, dict is now {initialization_state}")
+     return initialization_state
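+ # For example (illustrative values), a legacy file containing
+ #     {"bstate-file": "bstates.txt", "segs-per-state": 5}
+ # is returned as
+ #     {"bstate_file": "bstates.txt", "segs_per_state": 5}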
+
+
+ # TODO: Break this out into a separate module, let it be specified (if it's necessary) as a plugin option
+ # This may not always be required -- i.e. you may be able to directly output to the h5 file in your propagator
+ def prepare_coordinates(plugin_config, h5file, we_h5filename):
+     """
+     Copy relevant coordinates from trajectory files into <iteration>/auxdata/coord of the h5 file.
+
+     Directly modifies the input h5 file.
+
+     Adds ALL coordinates to auxdata/coord.
+
+     Adapted from the original msmWE collectCoordinates.py script.
+
+     Parameters
+     ----------
+     plugin_config: YAMLConfig object
+         Stores the configuration options provided to the plugin in the WESTPA configuration file
+
+     h5file: h5py.File
+         WESTPA h5 data file
+
+     we_h5filename: string
+         Name of the WESTPA h5 file
+     """
+
+     refPDBfile = plugin_config.get('ref_pdb_file')
+     modelName = plugin_config.get('model_name')
+
+     # TODO: Don't need this explicit option, use WEST_SIM_ROOT or something
+     WEfolder = plugin_config.get('we_folder')
+
+     parentTraj = plugin_config.get('parent_traj_filename')
+     childTraj = plugin_config.get('child_traj_filename')
+     pcoord_ndim = plugin_config.get('pcoord_ndim', 1)
+
+     model = msm_we.modelWE()
+     log.info('Preparing coordinates...')
+
+     # Only need the model to get the number of iterations and atoms
+     # TODO: Replace this with something more lightweight, get directly from WE
+     log.debug(f'Doing collectCoordinates on WE file {we_h5filename}')
+     model.initialize(
+         [we_h5filename],
+         refPDBfile,
+         modelName,
+         # Pass some dummy arguments -- these aren't important, this model is just created for convenience
+         # in the coordinate collection. Dummy arguments prevent warnings from being raised.
+         basis_pcoord_bounds=None,
+         target_pcoord_bounds=None,
+         tau=1,
+         pcoord_ndim=pcoord_ndim,
+         _suppress_boundary_warning=True,
+     )
+     model.get_iterations()
+
+     log.debug(f"Found {model.maxIter} iterations")
+
+     n_iter = None
+     for n_iter in tqdm.tqdm(range(1, model.maxIter + 1)):
+         nS = model.numSegments[n_iter - 1].astype(int)
+         coords = np.zeros((nS, 2, model.nAtoms, 3))
+         dsetName = "/iterations/iter_%08d/auxdata/coord" % int(n_iter)
+
+         coords_exist = False
+         try:
+             dset = h5file.create_dataset(dsetName, np.shape(coords))
+         except (RuntimeError, ValueError):
+             log.debug('coords exist for iteration ' + str(n_iter) + ' NOT overwritten')
+             coords_exist = True
+             continue
+
+         for iS in range(nS):
+             trajpath = WEfolder + "/traj_segs/%06d/%06d" % (n_iter, iS)
+
+             try:
+                 coord0 = np.squeeze(md.load(f'{trajpath}/{parentTraj}', top=model.reference_structure.topology)._xyz)
+             except OSError:
+                 log.warning("Parent traj file doesn't exist, loading reference structure coords")
+                 coord0 = np.squeeze(model.reference_structure._xyz)
+
+             coord1 = np.squeeze(md.load(f'{trajpath}/{childTraj}', top=model.reference_structure.topology)._xyz)
+
+             coords[iS, 0, :, :] = coord0
+             coords[iS, 1, :, :] = coord1
+
+         if not coords_exist:
+             dset[:] = coords
+
+     log.debug(f"Wrote coords for {n_iter} iterations.")
+
+
+ def msmwe_compute_ss(plugin_config, west_files):
+     """
+     Prepare and initialize an msm_we model, and use it to predict a steady-state distribution.
+
+     1. Load coordinate data
+     2. Perform dimensionality reduction
+     3. Compute flux and transition matrices
+     4. Compute steady-state distribution (via eigenvectors of transition matrix)
+     5. Compute target-state flux
+
+     TODO
+     ----
+     This function does far too many things. Break it up a bit.
+
+     Parameters
+     ----------
+     plugin_config: YAMLConfig object
+         Stores the configuration options provided to the plugin in the WESTPA configuration file
+
+     west_files: list of strings
+         Paths to the WESTPA HDF5 files to analyze
+
+     Returns
+     -------
+     ss_alg: np.ndarray
+         The steady-state distribution
+
+     ss_flux: float
+         Flux into target state
+
+     model: modelWE object
+         The modelWE object produced for analysis.
+     """
+
+     n_lag = 0
+
+     log.debug("Initializing msm_we")
+
+     # TODO: Refactor this to use westpa.core.extloader.get_object
+     #   I'm reinventing the wheel a bit here, I can replace almost all this code w/ that
+     # ##### Monkey-patch modelWE with the user-override functions
+     override_file = plugin_config.get('user_functions')
+
+     # First, import the file with the user-override functions
+     # This is a decently janky implementation, but it seems to work, and I don't know of a better way of doing it.
+     # This is nice because it avoids mucking around with the path, which I think is a Good Thing.
+
+     # We're given a path to the user-specified file containing overrides
+     # This comes from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
+
+     # I don't think the name provided here actually matters
+     user_override_spec = importlib.util.spec_from_file_location("override_module", override_file)
+     user_overrides = importlib.util.module_from_spec(user_override_spec)
+
+     # Make the functions that were overridden in override_file available in the namespace under user_overrides
+     user_override_spec.loader.exec_module(user_overrides)
+
+     # So now we can do the actual monkey-patching of modelWE.
+     # We monkey-patch at the module level rather than just override the function in the instanced object
+     # so that the functions retain access to self.
+     msm_we.modelWE.processCoordinates = user_overrides.processCoordinates
+     # ##### Done with monkey-patching.
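+     # For reference, a minimal sketch of what the override file might contain
+     # (illustrative only -- the featurization is entirely up to the user; compare
+     # example_overrides.py shipped alongside this module):
+     #
+     #     import numpy as np
+     #
+     #     def processCoordinates(self, coords):
+     #         # Flatten each frame's (n_atoms, 3) coordinates into a feature vector
+     #         return coords.reshape(coords.shape[0], -1)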
+
+     model = msm_we.modelWE()
+     streaming = plugin_config.get('streaming', False)
+
+     refPDBfile = plugin_config.get('ref_pdb_file')
+     modelName = plugin_config.get('model_name')
+     n_clusters = plugin_config.get('n_clusters')
+     tau = plugin_config.get('tau', None)
+     pcoord_ndim = plugin_config.get('pcoord_ndim', 1)
+
+     basis_pcoord_bounds = np.array(plugin_config.get('basis_pcoord_bounds', np.nan), dtype=float)
+     target_pcoord_bounds = np.array(plugin_config.get('target_pcoord_bounds', np.nan), dtype=float)
+
+     user_bin_mapper = plugin_config.get('user_bin_mapper', None)
+     if user_bin_mapper is not None:
+         user_bin_mapper = bins_from_yaml_dict(user_bin_mapper)
+
+     if np.isnan(basis_pcoord_bounds).any() or np.isnan(target_pcoord_bounds).any():
+         log.critical(
+             "Target and/or basis pcoord bounds were not specified. "
+             "Set them using the 'basis_pcoord_bounds' or 'target_pcoord_bounds' parameters. "
+             "'basis/target_pcoord1_min/max' and 'basis/target_pcoord1' are no longer supported. "
+             "See https://jdrusso.github.io/msm_we/api.html#msm_we.msm_we.modelWE.initialize for details."
+         )
+
+     if tau is None:
+         log.warning('No tau provided to restarting plugin. Defaulting to 1.')
+         tau = 1
+
+     # Fire up the model object
+     model.initialize(
+         west_files,
+         refPDBfile,
+         modelName,
+         basis_pcoord_bounds=basis_pcoord_bounds,
+         target_pcoord_bounds=target_pcoord_bounds,
+         tau=tau,
+         pcoord_ndim=pcoord_ndim,
+     )
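+     # The options above come from this plugin's section of the WESTPA configuration file.
+     # Shown here as a Python dict for illustration -- values are examples, not defaults:
+     #
+     #     plugin_config = {
+     #         'ref_pdb_file': 'common_files/bstate.pdb',
+     #         'model_name': 'restart_model',
+     #         'n_clusters': 100,
+     #         'tau': 1,
+     #         'pcoord_ndim': 1,
+     #         'basis_pcoord_bounds': [[-np.inf, 1.0]],
+     #         'target_pcoord_bounds': [[10.0, np.inf]],
+     #         'streaming': True,
+     #     }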
+
+     model.dimReduceMethod = plugin_config.get('dim_reduce_method')
+
+     model.n_lag = n_lag
+
+     log.debug("Loading in iteration data.. (this could take a while)")
+
+     # First dimension is the total number of segments
+     model.get_iterations()
+
+     model.get_coordSet(model.maxIter)
+
+     model.dimReduce()
+
+     first_iter, last_iter = model.first_iter, model.maxIter
+
+     clusterFile = modelName + "_clusters_s" + str(first_iter) + "_e" + str(last_iter) + "_nC" + str(n_clusters) + ".h5"
+     # TODO: Uncomment this to actually load the clusterFile if it exists. For now, disable for development.
+     exists = os.path.isfile(clusterFile)
+     exists = False
+     log.warning("Skipping any potential cluster reloading!")
+
+     log.info(f"Launching Ray with {plugin_config.get('n_cpus', 1)} cpus")
+
+     ray_tempdir_root = plugin_config.get('ray_temp_dir', None)
+     if ray_tempdir_root is not None:
+         ray_tempdir = tempfile.TemporaryDirectory(dir=ray_tempdir_root)
+         log.info(f"Using {ray_tempdir.name} as temp_dir for Ray")
+         ray.init(
+             num_cpus=plugin_config.get('n_cpus', 1), _temp_dir=ray_tempdir.name, ignore_reinit_error=True, include_dashboard=False
+         )
+     else:
+         ray.init(num_cpus=plugin_config.get('n_cpus', 1), ignore_reinit_error=True, include_dashboard=False)
+
+     # If a cluster file with the name corresponding to these parameters exists, load clusters from it.
+     if exists:
+         log.debug("loading clusters...")
+         model.load_clusters(clusterFile)
+     # Otherwise, do the clustering (which will create and save to that file)
+     else:
+         # FIXME: This gives the wrong shape, but loading from the clusterfile gives the right shape
+         log.debug("clustering coordinates into " + str(n_clusters) + " clusters...")
+         model.cluster_coordinates(n_clusters, streaming=streaming, user_bin_mapper=user_bin_mapper)
+
+     first_iter = 1
+     model.get_fluxMatrix(n_lag, first_iter, last_iter)  # extracts flux matrix, output model.fluxMatrixRaw
+     log.debug(f"Unprocessed flux matrix has shape {model.fluxMatrixRaw.shape}")
+     model.organize_fluxMatrix()  # gets rid of bins with no connectivity, sorts along p1, output model.fluxMatrix
+     model.get_Tmatrix()  # normalizes fluxMatrix to transition matrix, output model.Tmatrix
+
+     log.debug(f"Processed flux matrix has shape {model.fluxMatrix.shape}")
+
+     model.get_steady_state()  # gets steady-state from eigen decomp, output model.pSS
+     model.get_steady_state_target_flux()  # gets steady-state target flux, output model.JtargetSS
+
+     # Why is model.pSS sometimes the wrong shape? It's "occasionally" returned as a nested array.
+     # Squeeze fixes it and removes the dimension of length 1, but why does it happen in the first place?
+
+     if type(model.pSS) is np.matrix:
+         ss_alg = np.squeeze(model.pSS.A)
+     else:
+         ss_alg = np.squeeze(model.pSS)
+
+     ss_flux = model.JtargetSS
+
+     log.debug("Got steady state:")
+     log.debug(ss_alg)
+     log.debug(ss_flux)
+
+     log.info("Completed flux matrix calculation and steady-state estimation!")
+
+     log.info("Starting block validation")
+     num_validation_groups = plugin_config.get('n_validation_groups', 2)
+     num_validation_blocks = plugin_config.get('n_validation_blocks', 4)
+     try:
+         model.do_block_validation(num_validation_groups, num_validation_blocks, use_ray=True)
+     except Exception as e:
+         log.exception(e)
+         log.error("Failed block validation! Continuing with restart, but BEWARE!")
+     ray.shutdown()
+
+     return ss_alg, ss_flux, model
+
+
+ class RestartDriver:
+     """
+     WESTPA plugin to automatically handle estimating steady-state from a WE run, re-initializing a new WE run in that
+     steady-state, and then running that initialized WE run.
+
+     Data from the previous run will be stored in the restart<restart_number>/ subdirectory of $WEST_SIM_ROOT.
+
+     This plugin depends on the start-states implementation in the main WESTPA code, which allows initializing
+     a WE run using states that are NOT later used for recycling.
+
+     Start-states are used so that when the new WE run is initialized, initial structure selection is done by w_init,
+     using weights assigned to the start-states from the MSM bin weights and WE segment weights.
+
+     Since it closes out the current WE run and starts a new one, this plugin should run LAST, after all other plugins.
+     """
+
+     def __init__(self, sim_manager, plugin_config):
+         """
+         Initialize the RestartDriver plugin.
+
+         Pulls the data_manager and sim_manager from the WESTPA run that just completed, along with the
+         configuration options provided to the plugin.
+         """
+
+         westpa.rc.pstatus("Restart plugin initialized")
+
+         if not sim_manager.work_manager.is_master:
+             westpa.rc.pstatus("Restart plugin: not the master process, skipping")
+             return
+
+         self.data_manager = sim_manager.data_manager
+         self.sim_manager = sim_manager
+
+         self.plugin_config = plugin_config
+
+         self.restart_file = plugin_config.get('restart_file', 'restart.dat')
+         self.initialization_file = plugin_config.get('initialization_file', 'restart_initialization.json')
+
+         self.extension_iters = plugin_config.get('extension_iters', 0)
+         self.max_total_iterations = westpa.rc.config.get(['west', 'propagation', 'max_total_iterations'], default=None)
+         self.base_total_iterations = self.max_total_iterations
+
+         self.coord_len = plugin_config.get('coord_len', 2)
+         self.n_restarts = plugin_config.get('n_restarts', -1)
+         self.n_runs = plugin_config.get('n_runs', 1)
+
+         # Number of CPUs available for parallelizing msm_we calculations
+         self.parallel_cpus = plugin_config.get('n_cpus', 1)
+         self.ray_tempdir = plugin_config.get('ray_temp_dir', None)
+
+         # .get() might return this as a bool anyways, but be safe
+         self.debug = bool(plugin_config.get('debug', False))
+         if self.debug:
+             log.setLevel("DEBUG")
+             msm_we_logger.setLevel("DEBUG")
+
+         # Default to using all restarts
+         self.restarts_to_use = plugin_config.get('n_restarts_to_use', self.n_restarts)
+         assert self.restarts_to_use > 0 or self.restarts_to_use == -1, "Invalid number of restarts to use"
+         if self.restarts_to_use >= 1:
+             assert (
+                 self.restarts_to_use == self.restarts_to_use // 1
+             ), "restarts_to_use must be a whole number if >= 1; fractional values must be between 0 and 1."
+
+         struct_filetype = plugin_config.get('struct_filetype', 'mdtraj.formats.PDBTrajectoryFile')
+         self.struct_filetype = get_object(struct_filetype)
+
+         # This should be low priority, because it closes the H5 file and starts a new WE run. So it should run LAST
+         # after any other plugins.
+         self.priority = plugin_config.get('priority', 100)  # I think a big number is lower priority...
+
+         sim_manager.register_callback(sim_manager.finalize_run, self.prepare_new_we, self.priority)
+
+         # Initialize data
+         self.ss_alg = None
+         self.ss_dist = None
+         self.model = None
+
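+         # The driver is enabled from the plugins section of west.cfg. An illustrative
+         # sketch (option names correspond to the plugin_config.get() calls above; the
+         # exact layout of the plugins block follows the standard WESTPA plugin config):
+         #
+         #     west:
+         #       plugins:
+         #         - plugin: westpa.westext.hamsm_restarting.restart_driver.RestartDriver
+         #           n_restarts: 2
+         #           n_runs: 3
+         #           n_cpus: 8
+         #           user_functions: westpa_scripts/restart_overrides.py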
+     def get_original_bins(self):
+         """
+         Obtains the WE bins and their probabilities at the end of the previous iteration.
+
+         Returns
+         -------
+         bins : np.ndarray
+             Array of WE bins
+
+         binprobs: np.ndarray
+             WE bin weights
+         """
+
+         we_driver = self.sim_manager.we_driver
+         bins = we_driver.next_iter_binning
+         n_bins = len(bins)
+         binprobs = np.fromiter(map(operator.attrgetter('weight'), bins), dtype=np.float64, count=n_bins)
+
+         return bins, binprobs
+
+     @property
+     def cur_iter(self):
+         """
+         Get the current WE iteration.
+
+         Returns
+         -------
+         int: The current iteration. Subtract one, because in finalize_run the iter has been incremented
+         """
+         return self.sim_manager.n_iter - 1
+
+     @property
+     def is_last_iteration(self):
+         """
+         Get whether this is, or is past, the last iteration in this WE run.
+
+         Returns
+         -------
+         bool: Whether the current iteration is the final iteration
+         """
+
+         final_iter = self.sim_manager.max_total_iterations
+
+         return self.cur_iter >= final_iter
+
+     def prepare_extension_run(self, run_number, restart_state, first_extension=False):
+         """
+         Copy the necessary files for an extension run (versus initializing a fresh run).
+
+         Parameters
+         ----------
+         run_number: int
+             The index of this run (should be 1-indexed!)
+
+         restart_state: dict
+             Dictionary holding the current state of the restarting procedure
+
+         first_extension: bool
+             True if this is the first run of an extension set. If True, then back up west.cfg, and write the extended
+             west.cfg file.
+         """
+
+         log.debug(f"Linking run files from restart0/run{run_number}")
+
+         # Copy traj_segs, seg_logs, and west.h5 for restart0/runXX back into ./
+         # Later: (May only need to copy the latest iteration traj_segs, to avoid tons of back and forth)
+         try:
+             shutil.rmtree('traj_segs')
+             shutil.rmtree('seg_logs')
+         except OSError as e:
+             if str(e) == 'Cannot call rmtree on a symbolic link':
+                 os.unlink('traj_segs')
+                 os.unlink('seg_logs')
+
+         os.remove(self.data_manager.we_h5filename)
+
+         os.symlink(f'restart0/run{run_number}/traj_segs', 'traj_segs')
+         os.symlink(f'restart0/run{run_number}/seg_logs', 'seg_logs')
+
+         if first_extension:
+             # Get lines to make a new west.cfg by extending west.propagation.max_total_iterations
+             with open('west.cfg', 'r') as west_config:
+                 lines = west_config.readlines()
+                 for i, line in enumerate(lines):
+                     # Parse out the number of maximum iterations
+                     if 'max_total_iterations' in line:
+                         max_iters = [int(x) for x in line.replace(':', ' ').replace('\n', ' ').split() if x.isdigit()]
+                         new_max_iters = max_iters[0] + self.extension_iters
+                         new_line = f"{line.split(':')[0]}: {new_max_iters}\n"
+                         lines[i] = new_line
+                         break
+
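+         # e.g. (illustrative): with extension_iters = 30, a west.cfg line reading
+         #     max_total_iterations: 100
+         # is updated in the lines list to
+         #     max_total_iterations: 130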
+         with open(self.restart_file, 'w') as fp:
+             json.dump(restart_state, fp)
+
+         log.info("First WE extension run ready!")
+         westpa.rc.pstatus(
+             f"\n\n===== Restart {restart_state['restarts_completed']}, "
+             + f"Run {restart_state['runs_completed'] + 1} extension running =====\n"
+         )
+
+         # TODO: I can't just go straight into a w_run here. w_run expects some things to be set I think, that aren't.
+         #   I can do w_init, and do a new simulation just fine...
+         #   I can do this on ONE run repeatedly just fine
+         #   But if I try to just copy files and continue like this, there's something screwy in state somewhere that
+         #   causes it to fail.
+         #   The error has to do with offsets in the HDF5 file?
+         #   Need to figure out what state would be cleared by w_init
+
+         # Frankly, this is a really sketchy way of doing this, but it seems to work...
+         # I remain skeptical there's not something weird under the hood that isn't being addressed correctly with
+         # regard to state, but if it works, it's good enough for now..
+         westpa.rc.sim_manager.segments = None
+         shutil.copy(f'restart0/run{run_number}/west.h5', self.data_manager.we_h5filename)
+         self.data_manager.open_backing()
+
+         log.debug(f"Sim manager thought n_iter was {westpa.rc.sim_manager.n_iter}")
+         log.debug(f"Data manager thought current_iteration was {self.data_manager.current_iteration}")
+         log.debug(f"{self.sim_manager} vs {westpa.rc.sim_manager}")
+
+         if run_number == 1:
+             westpa.rc.sim_manager.max_total_iterations += self.extension_iters
+
+         w_run.run_simulation()
+         return
+
+     def generate_plots(self, restart_directory):
+         model = self.model
+
+         log.info("Producing flux-profile, pseudocommittor, and target flux comparison plots.")
+         flux_pcoord_fig, flux_pcoord_ax = plt.subplots()
+         model.plot_flux(ax=flux_pcoord_ax, suppress_validation=True)
+         flux_pcoord_fig.text(x=0.1, y=-0.05, s='This flux profile should become flatter after restarting', fontsize=12)
+         flux_pcoord_ax.legend(bbox_to_anchor=(1.01, 1.0), loc="upper left")
+         flux_pcoord_fig.savefig(f'{restart_directory}/flux_plot.pdf', bbox_inches="tight")
+
+         flux_pseudocomm_fig, flux_pseudocomm_ax = plt.subplots()
+         model.plot_flux_committor(ax=flux_pseudocomm_ax, suppress_validation=True)
+         flux_pseudocomm_fig.text(
+             x=0.1,
+             y=-0.05,
+             s='This flux profile should become flatter after restarting.'
+             '\nThe x-axis is a "pseudo"committor, since it may be '
+             'calculated from WE trajectories in the one-way ensemble.',
+             fontsize=12,
+         )
+         flux_pseudocomm_ax.legend(bbox_to_anchor=(1.01, 1.0), loc="upper left")
+         flux_pseudocomm_fig.savefig(f'{restart_directory}/pseudocomm-flux_plot.pdf', bbox_inches="tight")
+
+         flux_comparison_fig, flux_comparison_ax = plt.subplots(figsize=(7, 3))
+         # Get haMSM flux estimates
+         models = [model]
+         models.extend(model.validation_models)
+         n_validation_models = len(model.validation_models)
+
+         flux_estimates = []
+         for _model in models:
+             flux_estimates.append(_model.JtargetSS)
+
+         hamsm_flux_colors = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
+         direct_flux_colors = iter(plt.cm.cool(np.linspace(0.2, 0.8, len(model.fileList))))
+
+         # Get WE direct flux estimate
+         for _file in model.fileList:
+             run = analysis.Run(_file)
+             last_iter = run.num_iterations
+             recycled = list(run.iteration(last_iter - 1).recycled_walkers)
+             target_flux = sum(walker.weight for walker in recycled) / model.tau
+
+             # TODO: Correct for time!
+             if len(_file) >= 15:
+                 short_filename = f"....{_file[-12:]}"
+             else:
+                 short_filename = _file
+
+             if target_flux == 0:
+                 continue
+
+             flux_comparison_ax.axhline(
+                 target_flux,
+                 color=next(direct_flux_colors),
+                 label=f"Last iter WE direct {target_flux:.2e}\n ({short_filename})",
+                 linestyle='--',
+             )
+
+         flux_comparison_ax.axhline(
+             flux_estimates[0], label=f"Main model estimate\n {flux_estimates[0]:.2e}", color=next(hamsm_flux_colors)
+         )
+         for i in range(1, n_validation_models + 1):
+             flux_comparison_ax.axhline(
+                 flux_estimates[i],
+                 label=f"Validation model {i - 1} estimate\n {flux_estimates[i]:.2e}",
+                 color=next(hamsm_flux_colors),
+             )
+
+         flux_comparison_ax.legend(bbox_to_anchor=(1.01, 0.9), loc='upper left')
+         flux_comparison_ax.set_yscale('log')
+         flux_comparison_ax.set_ylabel('Flux')
+         flux_comparison_ax.set_xticks([])
+         flux_comparison_fig.tight_layout()
+         flux_comparison_fig.savefig(f'{restart_directory}/hamsm_vs_direct_flux_comparison_plot.pdf', bbox_inches="tight")
+
+     def prepare_new_we(self):
+         """
+         This function prepares a new WESTPA simulation using haMSM analysis to accelerate convergence.
+
+         The marathon functionality does re-implement some of the functionality of w_multi_west.
+         However, w_multi_west merges independent WE simulations, which may or may not be desirable.
+         I think for the purposes of this, it's good to keep the runs completely independent until haMSM model building.
+         Either that, or I'm just justifying not having known about w_multi_west when I wrote this. TBD.
+
+         # TODO: Replace all manual path-building with pathlib
+
+         The algorithm is as follows:
+             1. Check to see if we've just completed the final iteration
+             2. Handle launching multiple runs, if desired
+             3. Build the haMSM
+             4. Obtain structures for each haMSM bin
+             5. Make each structure a start-state, with probability set by (MSM-bin SS prob / # structures in bin)
+             6. Potentially some renormalization?
+             7. Start new WE simulation
+         """
+
+         # Do nothing if it's not the final iteration
+         if not self.is_last_iteration:
+             log.debug(f"Not the final iteration (currently iteration {self.cur_iter}), skipping restart")
+             return
+
+         log.debug("Final iteration, preparing restart")
+
+         restart_state = {'restarts_completed': 0, 'runs_completed': 0}
+
+         # Check for the existence of the extension lockfile here
+         doing_extension = os.path.exists(EXTENSION_LOCKFILE)
+
+         # Look for a restart.dat file to get the current state (how many restarts have been performed already)
+         if os.path.exists(self.restart_file):
+             with open(self.restart_file, 'r') as fp:
+                 restart_state = json.load(fp)
+
+         # This is the final iteration of a run, so mark this run as completed
+         restart_state['runs_completed'] += 1
+
+         # Make the folder to store data for this marathon
+         restart_directory = f"restart{restart_state['restarts_completed']}"
+         run_directory = f"{restart_directory}/run{restart_state['runs_completed']}"
+         if not os.path.exists(run_directory):
+             os.makedirs(run_directory)
+
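+         # Illustrative layout of the data stored per marathon:
+         #     restart0/run1/{west.h5, traj_segs/, seg_logs/}
+         #     restart0/run2/...
+         #     restart1/run1/...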
+         # Write coordinates to h5
+         prepare_coordinates(self.plugin_config, self.data_manager.we_h5file, self.data_manager.we_h5filename)
+
+         for data_folder in ['traj_segs', 'seg_logs']:
+             old_path = data_folder
+
+             # If you're doing an extension, this will be a symlink. So no need to copy, just unlink it and move on
+             if doing_extension and os.path.islink(old_path):
+                 log.debug('Unlinking symlink')
+                 os.unlink(old_path)
+                 os.mkdir(old_path)
+                 continue
+
+             new_path = f"{run_directory}/{old_path}"
+
+             log.debug(f"Moving {old_path} to {new_path}")
+
+             if os.path.exists(new_path):
+                 log.info(f"{new_path} already exists. Removing and overwriting.")
+                 shutil.rmtree(new_path)
+
+             try:
+                 os.rename(old_path, new_path)
+             except FileNotFoundError:
+                 log.warning(f"Folder {old_path} was not found. This may be normal, but check your configuration.")
+             else:
+                 # Make a new data folder for the next run
+                 os.mkdir(old_path)
+
+         last_run = restart_state['runs_completed'] >= self.n_runs
+         last_restart = restart_state['restarts_completed'] >= self.n_restarts
+
+         # We've just finished a run. Let's check if we have to do any more runs in this marathon before doing a restart.
+         # In the case of n_runs == 1, then we're just doing a single run and restarting it every so often.
+         # Otherwise, a marathon consists of multiple runs, and restarts are performed between marathons.
+         if last_run:
+             log.info(f"All {self.n_runs} runs in this marathon completed.")
+
+             if last_restart:
+                 log.info("All restarts completed! Performing final analysis.")
+
+             else:
+                 log.info("Proceeding to prepare a restart.")
+
+                 # Duplicating this is gross, but given the structure here, my options are either put it above these ifs
+                 # entirely, meaning it'll be unnecessarily run at the end of the final restart, or duplicate it below.
+                 log.info("Preparing coordinates for this run.")
+
+             # Now, continue on to haMSM calculation below.
+
+         # If we have more runs left to do in this marathon, prepare them
+         elif not last_run:
+             log.info(f"Run {restart_state['runs_completed']}/{self.n_runs} completed.")
+
+             # TODO: Initialize a new run, from the same configuration as this run was
+             #   On the 1st run, I can write bstates/tstates/sstates into restart files, and use those for spawning
+             #   subsequent runs in the marathon. That way, I don't make unnecessary copies of all those.
+             #   Basis and target states are unchanged. Can I get the original parameters passed to w_init?
+             #   Ideally, I should be able to call w_init with the exact same parameters that went to it the first time
+             initialization_state = {
+                 'tstate_file': None,
+                 'bstate_file': None,
+                 'sstate_file': None,
+                 'tstates': None,
+                 'bstates': None,
+                 'sstates': None,
+                 'segs_per_state': None,
+             }
+
+             # TODO: Implement this, and get rid of the initialization_file usage right below. Placeholder for now.
+             if restart_state['runs_completed'] == 1:
+                 # Get and write basis, target, start states and segs per state for this marathon to disk
+                 pass
+
+             # Save the WESTPA h5 data from this run
+             self.data_manager.finalize_run()
+             shutil.copyfile('west.h5', f"{run_directory}/west.h5")
+
+             # If this is a regular, fresh run (not an extension)
+             if not doing_extension:
+                 if os.path.exists(self.initialization_file):
+                     with open(self.initialization_file, 'r') as fp:
+                         initialization_dict = json.load(fp)
+                         initialization_dict = fix_deprecated_initialization(initialization_dict)
+                         initialization_state.update(initialization_dict)
+                 else:
+                     raise Exception(
+                         "No initialization JSON file provided -- I don't know how to start new runs in this marathon."
+                     )
+
+                 westpa.rc.pstatus(
+                     f"\n\n===== Restart {restart_state['restarts_completed']}, "
+                     + f"Run {restart_state['runs_completed'] + 1} initializing =====\n"
+                 )
+
+                 westpa.rc.pstatus(
+                     f"\nRun: \n\t w_init --tstate-file {initialization_state['tstate_file']} "
+                     + f"--bstate-file {initialization_state['bstate_file']} "
+                     f"--sstate-file {initialization_state['sstate_file']} "
+                     f"--segs-per-state {initialization_state['segs_per_state']}\n"
+                 )
+
+                 w_init.initialize(
+                     **initialization_state,
+                     shotgun=False,
+                 )
+
+                 with open(self.restart_file, 'w') as fp:
+                     json.dump(restart_state, fp)
+
+                 log.info("New WE run ready!")
+                 westpa.rc.pstatus(
+                     f"\n\n===== Restart {restart_state['restarts_completed']}, "
+                     + f"Run {restart_state['runs_completed'] + 1} running =====\n"
+                 )
+
+                 w_run.run_simulation()
+                 return
+
+             # If we're doing an extension set
+             # Instead of w_initting a new iteration, copy the files from restart0/runXX back into ./
+             elif doing_extension:
+                 self.prepare_extension_run(run_number=restart_state['runs_completed'] + 1, restart_state=restart_state)
+                 return
+
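+         # The initialization JSON mirrors the original w_init arguments. Illustrative contents:
+         #     {"bstate_file": "bstates/bstates.txt", "tstate_file": "tstate.file",
+         #      "sstate_file": null, "segs_per_state": 5}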
+         log.debug(f"{restart_state['restarts_completed']}/{self.n_restarts} restarts completed")
+
+         # Build the haMSM
+         log.debug("Initializing haMSM")
+
+         # Need to write the h5 file and close it out, but I need to get the current bstates first.
+         original_bstates = self.sim_manager.current_iter_bstates
+         if original_bstates is None:
+             original_bstates = self.data_manager.get_basis_states(self.sim_manager.n_iter - 1)
+
+         assert original_bstates is not None, "Bstates are none in the current iteration"
+
+         original_tstates = self.data_manager.get_target_states(self.cur_iter)
+
+         # Flush h5 file writes and copy it to the run directory
+         self.data_manager.finalize_run()
+         shutil.copyfile(self.data_manager.we_h5filename, f"{run_directory}/west.h5")
+
+         # Use all files in all restarts
+         # Restarts index at 0, because there's a 0th restart before you've... restarted anything.
+         # Runs index at 1, because Run 1 is the first run.
+         # TODO: Let the user pick last half or something in the plugin config.
+         marathon_west_files = []
+         # When doing the first restart, restarts_completed is 0 (because the first restart isn't complete yet) and
+         # the data generated during this restart is in /restart0.
+         # So when doing the Nth restart, restarts_completed is N-1
+
+         # If set to -1, use all restarts
+         if self.restarts_to_use == -1:
+             last_N_restarts = 1 + restart_state['restarts_completed']
+         # If this is an integer, use the last N restarts
+         elif self.restarts_to_use >= 1:
+             last_N_restarts = self.restarts_to_use
+         # If it's a decimal between 0 and 1, use it as a fraction
+         # At restart 1, and a fraction of 0.5, this should just use restart 1
+         elif 0 < self.restarts_to_use < 1:
+             last_N_restarts = int(self.restarts_to_use * (1 + restart_state['restarts_completed']))
+
+             # If this fraction is <1, use all until it's not
+             if last_N_restarts < 1:
+                 last_N_restarts = 1 + restart_state['restarts_completed']
+
+         log.debug(f"Last N is {last_N_restarts}")
+         first_restart = max(1 + restart_state['restarts_completed'] - last_N_restarts, 0)
+         usable_restarts = range(first_restart, 1 + restart_state['restarts_completed'])
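+         # Worked example (illustrative): with restarts_completed = 3 and restarts_to_use = 0.5,
+         # last_N_restarts = int(0.5 * 4) = 2, so first_restart = max(4 - 2, 0) = 2 and
+         # usable_restarts = [2, 3].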
+
+         log.info(
+             f"At restart {restart_state['restarts_completed']}, building haMSM using data from restarts {list(usable_restarts)}"
+         )
+         for restart_number in usable_restarts:
+             for run_number in range(1, 1 + restart_state['runs_completed']):
+                 west_file_path = f"restart{restart_number}/run{run_number}/west.h5"
+                 marathon_west_files.append(west_file_path)
+
+         log.debug(f"WESTPA datafiles for analysis are {marathon_west_files}")
+
+         #
+         # If this is the first restart, check to see if you got any target state flux
+         if restart_state['restarts_completed'] == 0:
+             pass
+
+             # Check to see if you got any target flux in ANY runs
+             target_reached = False
+             for west_file_path in marathon_west_files:
+                 if check_target_reached(west_file_path):
+                     target_reached = True
+                     break
+
+             # If you reached the target, clean up from the extensions and then continue as normal
+             # If extension_iters is set to 0, then don't do extensions.
+             if target_reached or self.extension_iters == 0:
+                 log.info("All runs reached target!")
+
+                 # Do some cleanup from the extension run
+                 if doing_extension and not self.extension_iters == 0:
+                     # Remove the doing_extensions.lck lockfile
+                     os.remove(EXTENSION_LOCKFILE)
+
+                     westpa.rc.sim_manager.max_total_iterations = self.base_total_iterations
+
+                 # Otherwise, just continue as normal
+                 pass
+
+             # If no runs reached the target, then we need to extend them
+             elif not target_reached:
+                 log.info("Target not reached. Preparing for extensions.")
+
+                 # Create the doing_extensions.lck "lockfile" to indicate we're in extend mode (or keep it, if it exists)
+                 # and write the initial number of iterations to it.
+                 if not os.path.exists(EXTENSION_LOCKFILE):
+                     with open(EXTENSION_LOCKFILE, 'w') as lockfile:
+                         lockfile.write(str(self.max_total_iterations))
+
+                 # Reset runs_completed to 0, and rewrite restart.dat accordingly
+                 restart_state['runs_completed'] = 0
+
+                 self.prepare_extension_run(run_number=1, restart_state=restart_state, first_extension=True)
+                 return
+
+         log.debug("Building haMSM and computing steady-state")
+         log.debug(f"Cur iter is {self.cur_iter}")
+         ss_dist, ss_flux, model = msmwe_compute_ss(self.plugin_config, marathon_west_files)
+         self.ss_dist = ss_dist
+         self.model = model
+
+         log.debug(f'Steady-state distribution: {ss_dist}')
+         log.info(f"Target steady-state flux is {ss_flux}")
+
+         # Obtain cluster-structures
+         log.debug("Obtaining cluster-structures")
+         model.update_cluster_structures()
+
+         # TODO: Do this with pathlib
+         struct_directory = f"{restart_directory}/structs"
+         if not os.path.exists(struct_directory):
+             os.makedirs(struct_directory)
+
+         flux_filename = f"{restart_directory}/JtargetSS.txt"
+         with open(flux_filename, 'w') as fp:
+             log.debug(f"Writing flux to {flux_filename}")
+             fp.write(str(model.JtargetSS))
+             fp.close()
+
+         ss_filename = f"{restart_directory}/pSS.txt"
+         with open(ss_filename, 'w') as fp:
+             log.debug(f"Writing pSS to {ss_filename}")
+             np.savetxt(fp, model.pSS)
+             fp.close()
+
+         # If this is the last run of the last restart, do nothing and exit.
+         # if restart_state['runs_completed'] >= self.n_runs and restart_state['restarts_completed'] >= self.n_restarts:
+         #     log.info("All restarts completed!")
+         #     return
+
+         # Construct start-state file with all structures and their weights
+         # TODO: Don't explicitly write EVERY structure to disk, or this will be a nightmare for large runs.
+         #   However, for now, it's fine...
+         log.debug("Writing structures")
+         # TODO: Include start states from previous runs
+         sstates_filename = f"{restart_directory}/startstates.txt"
+         with open(sstates_filename, 'w') as fp:
+             # Track the total number of segments iterated over
+             seg_idx = 0
+
+             log.info(f"Obtaining potential start structures ({len(model.cluster_structures.items())} bins avail)")
+
+             # Can use these for sanity checks
+             total_weight = 0.0
+             total_bin_weights = []
+
+             # Loop over each set of (bin index, all the structures in that bin)
+             for msm_bin_idx, structures in tqdm.tqdm(model.cluster_structures.items()):
+                 total_bin_weights.append(0)
+
+                 # Don't put structures in the basis or target
+                 if msm_bin_idx in [model.n_clusters, model.n_clusters + 1]:
+                     continue
+
+                 # The per-segment bin probability.
+                 # Map a cluster number onto a cluster INDEX, because after cleaning the cluster numbers may no longer
+                 # be consecutive.
+                 bin_prob = self.ss_dist[model.cluster_mapping[msm_bin_idx]]  # / len(structures)
+
+                 if bin_prob == 0:
+                     log.info(f"MSM-Bin {msm_bin_idx} has probability 0, so not saving any structs from it.")
+                     continue
+
+                 # The total amount of WE weight in this MSM microbin
+                 msm_bin_we_weight = sum(model.cluster_structure_weights[msm_bin_idx])
+
+                 # Write each structure to disk. Loop over each structure within a bin.
+                 msm_bin_we_weight_tracker = 0
+                 for struct_idx, structure in enumerate(structures):
+                     structure_filename = (
+                         f"{struct_directory}/bin{msm_bin_idx}_struct{struct_idx}.{STRUCT_EXTENSIONS[self.struct_filetype]}"
+                     )
+
+                     with self.struct_filetype(structure_filename, 'w') as struct_file:
+                         # One structure per segment
+                         seg_we_weight = model.cluster_structure_weights[msm_bin_idx][struct_idx]
+                         msm_bin_we_weight_tracker += seg_we_weight
+
+                         # Structure weights are set according to Algorithm 5.3 in
+                         # Aristoff, D. & Zuckerman, D. M. Optimizing Weighted Ensemble Sampling of Steady States.
+                         # Multiscale Model. Sim. 18, 646–673 (2020).
+                         structure_weight = seg_we_weight * (bin_prob / msm_bin_we_weight)
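+                         # Numeric sketch (illustrative): a bin with steady-state probability 0.02
+                         # holding two structures with WE weights 0.003 and 0.001
+                         # (msm_bin_we_weight = 0.004) yields structure weights
+                         # 0.003 * (0.02 / 0.004) = 0.015 and 0.001 * (0.02 / 0.004) = 0.005,
+                         # which sum to the bin's probability, 0.02.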
+
+                         total_bin_weights[-1] += structure_weight
+                         total_weight += structure_weight
+
+                         topology = model.reference_structure.topology
+
+                         try:
+                             angles = model.reference_structure.unitcell_angles[0]
+                             lengths = model.reference_structure.unitcell_lengths[0] * 10
+                         # This throws TypeError if reference_structure.unitcell_angles is None, or AttributeError
+                         # if reference_structure.unitcell_angles doesn't exist.
+                         except (TypeError, AttributeError):
+                             angles, lengths = None, None
+
+                         coords = structure * 10  # Correct units
+
+                         # Write the structure file
+                         if self.struct_filetype is md.formats.PDBTrajectoryFile:
+                             struct_file.write(coords, topology, modelIndex=1, unitcell_angles=angles, unitcell_lengths=lengths)
+
+                         elif self.struct_filetype is md.formats.AmberRestartFile:
+                             # AmberRestartFile takes slightly differently named keyword args
+                             struct_file.write(coords, time=None, cell_angles=angles, cell_lengths=lengths)
+
+                         else:
+                             # Otherwise, YOLO just hope all the positional arguments are in the right place
+                             log.warning(
+                                 f"This output filetype ({self.struct_filetype}) is probably supported, "
+                                 "but not explicitly handled. You should ensure that it takes arguments as (coords, topology)"
+                             )
+                             struct_file.write(coords, topology)
+                             raise Exception("Don't know what extension to use for this filetype")
+
+                         # Add this start-state to the start-states file
+                         # This path is relative to WEST_SIM_ROOT
+                         fp.write(f'b{msm_bin_idx}_s{struct_idx} {structure_weight} {structure_filename}\n')
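+                         # Each line has the form (illustrative):
+                         #     b3_s0 0.015 restart0/structs/bin3_struct0.pdb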
+                         seg_idx += 1
+
+                 # log.info(f"WE weight ({msm_bin_we_weight_tracker:.5e} / {msm_bin_we_weight:.5e})")
+
+         # TODO: Fix this check. It's never quite worked right, nor has it ever caught an actual problem, so just
+         #   disable for now.
+         # In equilibrium, all probabilities count, but in steady-state the last 2 are the target/basis
+         # Subtract off the probabilities of the basis and target states, since those don't have structures
+         # assigned to them.
+         # assert np.isclose(total_weight, 1 - sum(model.pSS[model.n_clusters :])), (
+         #     f"Total steady-state structure weights not normalized! (Total: {total_weight}) "
+         #     f"\n\t pSS: {model.pSS}"
+         #     f"\n\t Total bin weights {total_bin_weights}"
+         #     f"\n\t pSS sum: {sum(model.pSS)}"
+         #     f"\n\t pSS -2 sum: {sum(model.pSS[:-2])}"
+         #     f"\n\t pSS (+target, no basis) sum: {sum(model.pSS[:-2]) + model.pSS[-1]}"
+         # )
+
+         ### Start the new simulation
+
+         bstates_str = ""
+         for original_bstate in original_bstates:
+             orig_bstate_prob = original_bstate.probability
+             orig_bstate_label = original_bstate.label
+             orig_bstate_aux = original_bstate.auxref
+
+             bstate_str = f"{orig_bstate_label} {orig_bstate_prob} {orig_bstate_aux}\n"
+
+             bstates_str += bstate_str
+
+         bstates_filename = f"{restart_directory}/basisstates.txt"
+         with open(bstates_filename, 'w') as fp:
+             fp.write(bstates_str)
+
+         tstates_str = ""
+         for original_tstate in original_tstates:
+             orig_tstate_label = original_tstate.label
+             # TODO: Handle multidimensional pcoords
+             orig_tstate_pcoord = original_tstate.pcoord[0]
+
+             tstate_str = f"{orig_tstate_label} {orig_tstate_pcoord}\n"
+             tstates_str += tstate_str
+         tstates_filename = f"{restart_directory}/targetstates.txt"
+         with open(tstates_filename, 'w') as fp:
+             fp.write(tstates_str)
+
+         # Pickle the model
+         objFile = f"{restart_directory}/hamsm.obj"
+         with open(objFile, "wb") as objFileHandler:
+             log.debug("Pickling model")
+             pickle.dump(model, objFileHandler, protocol=4)
+             objFileHandler.close()
+
+         # Before finishing this restart, make a plot of the flux profile.
+         # This is made so the user can see whether the flux profile is flattening out, as it should as the run
+         # approaches steady-state.
+
+         self.generate_plots(restart_directory)
+
+         # At this point, the restart is completed, and the data for the next one is ready (though we still need to
+         # make the initialization file and such).
+
+         if last_restart:
+             log.info("All restarts completed! Finished.")
+             return
+
+         # Update the restart_file
+         restart_state['restarts_completed'] += 1
+         # If we're doing a restart, then reset the number of completed runs to 0 for the next marathon.
+         restart_state['runs_completed'] = 0
+         with open(self.restart_file, 'w') as fp:
+             json.dump(restart_state, fp)
+
+         log.info("Initializing new run")
+
+         # TODO: Read this from config if available
+         segs_per_state = 1
+
+         old_initialization_path = self.initialization_file
+         new_initialization_path = f"{restart_directory}/{self.initialization_file}"
+         log.debug(f"Moving initialization file from {old_initialization_path} to {new_initialization_path}.")
+         shutil.move(old_initialization_path, new_initialization_path)
+
+         initialization_state = {
+             'tstate_file': tstates_filename,
+             'bstate_file': bstates_filename,
+             'sstate_file': sstates_filename,
+             'tstates': None,
+             'bstates': None,
+             'sstates': None,
+             'segs_per_state': segs_per_state,
+         }
+
+         with open(self.initialization_file, 'w') as fp:
+             json.dump(initialization_state, fp)
+
+         westpa.rc.pstatus(
+             f"\n\n"
+             f"===== Restart {restart_state['restarts_completed']}, "
+             + f"Run {restart_state['runs_completed'] + 1} initializing =====\n"
+         )
+
+         westpa.rc.pstatus(
+             f"\nRun: \n\t w_init --tstate-file {tstates_filename} "
+             + f"--bstate-file {bstates_filename} --sstate-file {sstates_filename} --segs-per-state {segs_per_state}\n"
+         )
+
+         w_init.initialize(**initialization_state, shotgun=False)
+
+         log.info("New WE run ready!")
+         westpa.rc.pstatus(f"\n\n===== Restart {restart_state['restarts_completed']} running =====\n")
+
+         w_run.run_simulation()