westpa 2022.10__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of westpa might be problematic. Click here for more details.
- westpa/__init__.py +14 -0
- westpa/_version.py +21 -0
- westpa/analysis/__init__.py +5 -0
- westpa/analysis/core.py +746 -0
- westpa/analysis/statistics.py +27 -0
- westpa/analysis/trajectories.py +360 -0
- westpa/cli/__init__.py +0 -0
- westpa/cli/core/__init__.py +0 -0
- westpa/cli/core/w_fork.py +152 -0
- westpa/cli/core/w_init.py +230 -0
- westpa/cli/core/w_run.py +77 -0
- westpa/cli/core/w_states.py +212 -0
- westpa/cli/core/w_succ.py +99 -0
- westpa/cli/core/w_truncate.py +59 -0
- westpa/cli/tools/__init__.py +0 -0
- westpa/cli/tools/ploterr.py +506 -0
- westpa/cli/tools/plothist.py +706 -0
- westpa/cli/tools/w_assign.py +596 -0
- westpa/cli/tools/w_bins.py +166 -0
- westpa/cli/tools/w_crawl.py +119 -0
- westpa/cli/tools/w_direct.py +547 -0
- westpa/cli/tools/w_dumpsegs.py +94 -0
- westpa/cli/tools/w_eddist.py +506 -0
- westpa/cli/tools/w_fluxanl.py +378 -0
- westpa/cli/tools/w_ipa.py +833 -0
- westpa/cli/tools/w_kinavg.py +127 -0
- westpa/cli/tools/w_kinetics.py +96 -0
- westpa/cli/tools/w_multi_west.py +414 -0
- westpa/cli/tools/w_ntop.py +213 -0
- westpa/cli/tools/w_pdist.py +515 -0
- westpa/cli/tools/w_postanalysis_matrix.py +82 -0
- westpa/cli/tools/w_postanalysis_reweight.py +53 -0
- westpa/cli/tools/w_red.py +486 -0
- westpa/cli/tools/w_reweight.py +780 -0
- westpa/cli/tools/w_select.py +226 -0
- westpa/cli/tools/w_stateprobs.py +111 -0
- westpa/cli/tools/w_trace.py +599 -0
- westpa/core/__init__.py +0 -0
- westpa/core/_rc.py +673 -0
- westpa/core/binning/__init__.py +55 -0
- westpa/core/binning/_assign.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/core/binning/assign.py +449 -0
- westpa/core/binning/binless.py +96 -0
- westpa/core/binning/binless_driver.py +54 -0
- westpa/core/binning/binless_manager.py +190 -0
- westpa/core/binning/bins.py +47 -0
- westpa/core/binning/mab.py +427 -0
- westpa/core/binning/mab_driver.py +54 -0
- westpa/core/binning/mab_manager.py +198 -0
- westpa/core/data_manager.py +1694 -0
- westpa/core/extloader.py +74 -0
- westpa/core/h5io.py +995 -0
- westpa/core/kinetics/__init__.py +24 -0
- westpa/core/kinetics/_kinetics.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/core/kinetics/events.py +147 -0
- westpa/core/kinetics/matrates.py +156 -0
- westpa/core/kinetics/rate_averaging.py +266 -0
- westpa/core/progress.py +218 -0
- westpa/core/propagators/__init__.py +54 -0
- westpa/core/propagators/executable.py +715 -0
- westpa/core/reweight/__init__.py +14 -0
- westpa/core/reweight/_reweight.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/core/reweight/matrix.py +126 -0
- westpa/core/segment.py +119 -0
- westpa/core/sim_manager.py +830 -0
- westpa/core/states.py +359 -0
- westpa/core/systems.py +93 -0
- westpa/core/textio.py +74 -0
- westpa/core/trajectory.py +330 -0
- westpa/core/we_driver.py +908 -0
- westpa/core/wm_ops.py +43 -0
- westpa/core/yamlcfg.py +391 -0
- westpa/fasthist/__init__.py +34 -0
- westpa/fasthist/__main__.py +110 -0
- westpa/fasthist/_fasthist.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/mclib/__init__.py +264 -0
- westpa/mclib/__main__.py +28 -0
- westpa/mclib/_mclib.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/oldtools/__init__.py +4 -0
- westpa/oldtools/aframe/__init__.py +35 -0
- westpa/oldtools/aframe/atool.py +75 -0
- westpa/oldtools/aframe/base_mixin.py +26 -0
- westpa/oldtools/aframe/binning.py +178 -0
- westpa/oldtools/aframe/data_reader.py +560 -0
- westpa/oldtools/aframe/iter_range.py +200 -0
- westpa/oldtools/aframe/kinetics.py +117 -0
- westpa/oldtools/aframe/mcbs.py +146 -0
- westpa/oldtools/aframe/output.py +39 -0
- westpa/oldtools/aframe/plotting.py +90 -0
- westpa/oldtools/aframe/trajwalker.py +126 -0
- westpa/oldtools/aframe/transitions.py +469 -0
- westpa/oldtools/cmds/__init__.py +0 -0
- westpa/oldtools/cmds/w_ttimes.py +358 -0
- westpa/oldtools/files.py +34 -0
- westpa/oldtools/miscfn.py +23 -0
- westpa/oldtools/stats/__init__.py +4 -0
- westpa/oldtools/stats/accumulator.py +35 -0
- westpa/oldtools/stats/edfs.py +129 -0
- westpa/oldtools/stats/mcbs.py +89 -0
- westpa/tools/__init__.py +33 -0
- westpa/tools/binning.py +472 -0
- westpa/tools/core.py +340 -0
- westpa/tools/data_reader.py +159 -0
- westpa/tools/dtypes.py +31 -0
- westpa/tools/iter_range.py +198 -0
- westpa/tools/kinetics_tool.py +340 -0
- westpa/tools/plot.py +283 -0
- westpa/tools/progress.py +17 -0
- westpa/tools/selected_segs.py +154 -0
- westpa/tools/wipi.py +751 -0
- westpa/trajtree/__init__.py +4 -0
- westpa/trajtree/_trajtree.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/trajtree/trajtree.py +117 -0
- westpa/westext/__init__.py +0 -0
- westpa/westext/adaptvoronoi/__init__.py +3 -0
- westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
- westpa/westext/hamsm_restarting/__init__.py +3 -0
- westpa/westext/hamsm_restarting/example_overrides.py +35 -0
- westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
- westpa/westext/stringmethod/__init__.py +11 -0
- westpa/westext/stringmethod/fourier_fitting.py +69 -0
- westpa/westext/stringmethod/string_driver.py +253 -0
- westpa/westext/stringmethod/string_method.py +306 -0
- westpa/westext/weed/BinCluster.py +180 -0
- westpa/westext/weed/ProbAdjustEquil.py +100 -0
- westpa/westext/weed/UncertMath.py +247 -0
- westpa/westext/weed/__init__.py +10 -0
- westpa/westext/weed/weed_driver.py +182 -0
- westpa/westext/wess/ProbAdjust.py +101 -0
- westpa/westext/wess/__init__.py +6 -0
- westpa/westext/wess/wess_driver.py +207 -0
- westpa/work_managers/__init__.py +57 -0
- westpa/work_managers/core.py +396 -0
- westpa/work_managers/environment.py +134 -0
- westpa/work_managers/mpi.py +318 -0
- westpa/work_managers/processes.py +187 -0
- westpa/work_managers/serial.py +28 -0
- westpa/work_managers/threads.py +79 -0
- westpa/work_managers/zeromq/__init__.py +20 -0
- westpa/work_managers/zeromq/core.py +641 -0
- westpa/work_managers/zeromq/node.py +131 -0
- westpa/work_managers/zeromq/work_manager.py +526 -0
- westpa/work_managers/zeromq/worker.py +320 -0
- westpa-2022.10.dist-info/AUTHORS +22 -0
- westpa-2022.10.dist-info/LICENSE +21 -0
- westpa-2022.10.dist-info/METADATA +183 -0
- westpa-2022.10.dist-info/RECORD +150 -0
- westpa-2022.10.dist-info/WHEEL +6 -0
- westpa-2022.10.dist-info/entry_points.txt +29 -0
- westpa-2022.10.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,830 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
import operator
|
|
4
|
+
import random
|
|
5
|
+
import time
|
|
6
|
+
from datetime import timedelta
|
|
7
|
+
from pickle import PickleError
|
|
8
|
+
from itertools import zip_longest
|
|
9
|
+
from collections import Counter
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
import westpa
|
|
14
|
+
from .data_manager import weight_dtype
|
|
15
|
+
from .segment import Segment
|
|
16
|
+
from .states import InitialState
|
|
17
|
+
from . import extloader
|
|
18
|
+
from . import wm_ops
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
log = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
EPS = np.finfo(weight_dtype).eps
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def grouper(n, iterable, fillvalue=None):
    """Collect data into fixed-length chunks or blocks.

    Example: grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
    """
    # n references to the same iterator advance in lockstep, yielding n-tuples.
    iterators = [iter(iterable)] * n
    return zip_longest(*iterators, fillvalue=fillvalue)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PropagationError(RuntimeError):
    """Raised when propagation of segments cannot proceed."""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class WESimManager:
|
|
38
|
+
def process_config(self):
|
|
39
|
+
config = self.rc.config
|
|
40
|
+
for entry, type_ in [('gen_istates', bool), ('block_size', int), ('save_transition_matrices', bool)]:
|
|
41
|
+
config.require_type_if_present(['west', 'propagation', entry], type_)
|
|
42
|
+
|
|
43
|
+
self.do_gen_istates = config.get(['west', 'propagation', 'gen_istates'], False)
|
|
44
|
+
self.propagator_block_size = config.get(['west', 'propagation', 'block_size'], 1)
|
|
45
|
+
self.save_transition_matrices = config.get(['west', 'propagation', 'save_transition_matrices'], False)
|
|
46
|
+
self.max_run_walltime = config.get(['west', 'propagation', 'max_run_wallclock'], default=None)
|
|
47
|
+
self.max_total_iterations = config.get(['west', 'propagation', 'max_total_iterations'], default=None)
|
|
48
|
+
|
|
49
|
+
def __init__(self, rc=None):
    """Set up the simulation manager.

    Wires in the work, data, WE, and system drivers from ``rc`` (or the global
    ``westpa.rc``), registers the valid callback hook points, applies
    configuration, and initializes per-iteration bookkeeping attributes.
    """
    self.rc = rc or westpa.rc
    self.work_manager = self.rc.get_work_manager()
    self.data_manager = self.rc.get_data_manager()
    self.we_driver = self.rc.get_we_driver()
    self.system = self.rc.get_system_driver()

    # A table of hook -> set of (priority, name, callback) tuples
    self._callback_table = {}
    hook_points = (
        self.prepare_run,
        self.finalize_run,
        self.prepare_iteration,
        self.finalize_iteration,
        self.pre_propagation,
        self.post_propagation,
        self.pre_we,
        self.post_we,
        self.prepare_new_iteration,
    )
    self._valid_callbacks = set(hook_points)
    self._callbacks_by_name = {callback.__name__: callback for callback in self._valid_callbacks}
    self.n_propagated = 0

    # Configuration defaults; process_config() may override them from the run-control file.
    self.do_gen_istates = False
    self.propagator_block_size = 1
    self.save_transition_matrices = False
    self.max_run_walltime = None
    self.max_total_iterations = None
    self.process_config()

    # Per-iteration variables
    self.n_iter = None  # current iteration

    # Basis and initial states for this iteration, in case the propagator needs them
    self.current_iter_bstates = None  # BasisStates valid at this iteration
    self.current_iter_istates = None  # InitialStates used in this iteration

    # Basis states for next iteration
    self.next_iter_bstates = None  # BasisStates valid for the next iteration
    self.next_iter_bstate_cprobs = None  # Cumulative probabilities for basis states, used for selection

    # Tracking of this iteration's segments
    self.segments = None  # Mapping of seg_id to segment for all segments in this iteration
    self.completed_segments = None  # Mapping of seg_id to segment for all completed segments in this iteration
    self.incomplete_segments = None  # Mapping of seg_id to segment for all incomplete segments in this iteration

    # Tracking of binning
    self.bin_mapper_hash = None  # Hash of bin mapper from most recently-run WE, for use by post-WE analysis plugins
|
|
100
|
+
|
|
101
|
+
def register_callback(self, hook, function, priority=0):
    '''Registers a callback to execute during the given ``hook`` into the simulation loop. The optional
    priority is used to order when the function is called relative to other registered callbacks.

    ``hook`` may be one of the valid hook methods themselves (e.g. ``self.prepare_iteration``) or
    the name of one as a string, which is resolved via ``self._callbacks_by_name``.

    Raises KeyError if ``hook`` is neither a valid hook method nor the name of one.
    '''

    if hook not in self._valid_callbacks:
        try:
            hook = self._callbacks_by_name[hook]
        except KeyError:
            raise KeyError('invalid hook {!r}'.format(hook))

    # It's possible to register a callback that's a duplicate function, but at a different place in memory.
    # For example, if you launch a run without clearing state from a previous run.
    # More details on this are available in https://github.com/westpa/westpa/issues/182; comparing
    # function hashes here skips re-registration of an identical callback.
    # (Fix: the original wrapped the membership test in an ``except KeyError`` handler, but a
    # membership test on a list cannot raise KeyError, so the dead handler has been removed;
    # ``setdefault`` replaces the original two-step try/except insertion.)
    registered = self._callback_table.setdefault(hook, set())
    if hash(function) in {hash(callback[2]) for callback in registered}:
        log.info('{!r} has already been loaded, skipping'.format(function))
        return

    registered.add((priority, function.__name__, function))

    # Raise warning if there are multiple callbacks with same priority: their relative execution
    # order is not guaranteed. (Fix: loop variable renamed so it no longer shadows the
    # ``priority`` parameter.)
    for prio, count in Counter(callback[0] for callback in registered).items():
        if count > 1:
            log.warning(
                f'{count} callbacks in {hook} have identical priority {prio}. The order of callback execution is not guaranteed.'
            )
            log.warning(f'{hook}: {self._callback_table[hook]}')

    log.debug('registered callback {!r} for hook {!r}'.format(function, hook))
|
|
146
|
+
|
|
147
|
+
def invoke_callbacks(self, hook, *args, **kwargs):
    """Invoke every callback registered for ``hook``, passing ``*args``/``**kwargs`` through.

    Callbacks run in ascending order of (priority, function name, module name).
    """
    registered = self._callback_table.get(hook, [])

    def ordering(entry):
        priority, name, callback = entry
        return (priority, name, callback.__module__)

    for _priority, _name, callback in sorted(registered, key=ordering):
        log.debug('invoking callback {!r} for hook {!r}'.format(callback, hook))
        callback(*args, **kwargs)
|
|
154
|
+
|
|
155
|
+
def load_plugins(self, plugins=None):
    """Instantiate simulation plugins.

    Combines the explicitly supplied ``plugins`` list with any plugin entries found
    under ``west.plugins`` in the run-control configuration, then loads and
    constructs each plugin that is not explicitly disabled.
    """
    if plugins is None:
        plugins = []

    # Pull additional plugin entries from the configuration, tolerating both a
    # missing section and an explicitly empty (None) one.
    try:
        configured = westpa.rc.config['west', 'plugins']
    except KeyError:
        configured = []
    if configured is None:
        configured = []

    plugins += configured

    for plugin_config in plugins:
        plugin_name = plugin_config['plugin']
        if not plugin_config.get('enabled', True):
            continue
        log.info('loading plugin {!r}'.format(plugin_name))
        plugin = extloader.get_object(plugin_name)(self, plugin_config)
        log.debug('loaded plugin {!r}'.format(plugin))
|
|
175
|
+
|
|
176
|
+
def report_bin_statistics(self, bins, target_states, save_summary=False):
    '''Report bin-population and weight statistics for the segments currently in
    ``self.segments``, printing via the run-control status stream.

    Parameters
    ----------
    bins
        Iterable of bins; each bin must be sized (``len``) and expose a ``weight``
        attribute.
    target_states
        Target states for this iteration; one bin per target state is excluded
        from the active-bin count.
    save_summary : bool
        If True, also write the statistics into the current iteration summary
        record via the data manager.
    '''
    segments = list(self.segments.values())
    bin_counts = np.fromiter(map(len, bins), dtype=np.int_, count=len(bins))
    target_counts = self.we_driver.bin_target_counts

    # Do not include bins with target count zero (e.g. sinks, never-filled bins) in the (non)empty bins statistics
    n_active_bins = len(target_counts[target_counts != 0])

    if target_states:
        n_active_bins -= len(target_states)

    seg_probs = np.fromiter(map(operator.attrgetter('weight'), segments), dtype=weight_dtype, count=len(segments))
    bin_probs = np.fromiter(map(operator.attrgetter('weight'), bins), dtype=weight_dtype, count=len(bins))
    norm = seg_probs.sum()

    # Total weight must equal 1 to within accumulated floating-point error.
    assert abs(1 - norm) < EPS * (len(segments) + n_active_bins)

    # Dynamic ranges below are natural-log ratios of max to min non-zero weight
    # (labelled "kT" in the status output).
    min_seg_prob = seg_probs[seg_probs != 0].min()
    max_seg_prob = seg_probs.max()
    seg_drange = math.log(max_seg_prob / min_seg_prob)
    min_bin_prob = bin_probs[bin_probs != 0].min()
    max_bin_prob = bin_probs.max()
    bin_drange = math.log(max_bin_prob / min_bin_prob)
    n_pop = len(bin_counts[bin_counts != 0])

    self.rc.pstatus('{:d} of {:d} ({:%}) active bins are populated'.format(n_pop, n_active_bins, n_pop / n_active_bins))
    self.rc.pstatus('per-bin minimum non-zero probability: {:g}'.format(min_bin_prob))
    self.rc.pstatus('per-bin maximum probability: {:g}'.format(max_bin_prob))
    self.rc.pstatus('per-bin probability dynamic range (kT): {:g}'.format(bin_drange))
    self.rc.pstatus('per-segment minimum non-zero probability: {:g}'.format(min_seg_prob))
    self.rc.pstatus('per-segment maximum non-zero probability: {:g}'.format(max_seg_prob))
    self.rc.pstatus('per-segment probability dynamic range (kT): {:g}'.format(seg_drange))
    self.rc.pstatus('norm = {:g}, error in norm = {:g} ({:.2g}*epsilon)'.format(norm, (norm - 1), (norm - 1) / EPS))
    self.rc.pflush()

    if min_seg_prob < 1e-100:
        log.warning(
            '\nMinimum segment weight is < 1e-100 and might not be physically relevant. Please reconsider your progress coordinate or binning scheme.'
        )

    if save_summary:
        iter_summary = self.data_manager.get_iter_summary()
        iter_summary['n_particles'] = len(segments)
        iter_summary['norm'] = norm
        iter_summary['min_bin_prob'] = min_bin_prob
        iter_summary['max_bin_prob'] = max_bin_prob
        iter_summary['min_seg_prob'] = min_seg_prob
        iter_summary['max_seg_prob'] = max_seg_prob
        # NaN timings presumably mean no timing data has been recorded yet for
        # this iteration — TODO confirm against data_manager; store 0.0 instead.
        if np.isnan(iter_summary['cputime']):
            iter_summary['cputime'] = 0.0
        if np.isnan(iter_summary['walltime']):
            iter_summary['walltime'] = 0.0
        self.data_manager.update_iter_summary(iter_summary)
|
|
229
|
+
|
|
230
|
+
def get_bstate_pcoords(self, basis_states, label='basis'):
    '''For each of the given ``basis_states``, calculate progress coordinate values
    as necessary. The HDF5 file is not updated.'''

    self.rc.pstatus('Calculating progress coordinate values for {} states.'.format(label))

    # Submit one pcoord-calculation task per state, remembering which state each
    # future belongs to so results can be stored as they complete (in any order).
    futures = []
    position_of = {}
    for position, state in enumerate(basis_states):
        future = self.work_manager.submit(wm_ops.get_pcoord, args=(state,))
        position_of[future] = position
        futures.append(future)

    for completed in self.work_manager.as_completed(futures):
        basis_states[position_of[completed]].pcoord = completed.get_result().pcoord
|
|
239
|
+
|
|
240
|
+
def report_basis_states(self, basis_states, label='basis'):
|
|
241
|
+
pstatus = self.rc.pstatus
|
|
242
|
+
pstatus('{:d} {} state(s) present'.format(len(basis_states), label), end='')
|
|
243
|
+
if self.rc.verbose_mode:
|
|
244
|
+
pstatus(':')
|
|
245
|
+
pstatus(
|
|
246
|
+
'{:6s} {:12s} {:20s} {:20s} {}'.format(
|
|
247
|
+
'ID', 'Label', 'Probability', 'Aux Reference', 'Progress Coordinate'
|
|
248
|
+
)
|
|
249
|
+
)
|
|
250
|
+
for basis_state in basis_states:
|
|
251
|
+
pstatus(
|
|
252
|
+
'{:<6d} {:12s} {:<20.14g} {:20s} {}'.format(
|
|
253
|
+
basis_state.state_id,
|
|
254
|
+
basis_state.label,
|
|
255
|
+
basis_state.probability,
|
|
256
|
+
basis_state.auxref or '',
|
|
257
|
+
', '.join(map(str, basis_state.pcoord)),
|
|
258
|
+
)
|
|
259
|
+
)
|
|
260
|
+
pstatus()
|
|
261
|
+
self.rc.pflush()
|
|
262
|
+
|
|
263
|
+
def report_target_states(self, target_states):
|
|
264
|
+
pstatus = self.rc.pstatus
|
|
265
|
+
pstatus('{:d} target state(s) present'.format(len(target_states)), end='')
|
|
266
|
+
if self.rc.verbose_mode and target_states:
|
|
267
|
+
pstatus(':')
|
|
268
|
+
pstatus('{:6s} {:12s} {}'.format('ID', 'Label', 'Progress Coordinate'))
|
|
269
|
+
for target_state in target_states:
|
|
270
|
+
pstatus(
|
|
271
|
+
'{:<6d} {:12s} {}'.format(
|
|
272
|
+
target_state.state_id, target_state.label, ','.join(map(str, target_state.pcoord))
|
|
273
|
+
)
|
|
274
|
+
)
|
|
275
|
+
pstatus()
|
|
276
|
+
self.rc.pflush()
|
|
277
|
+
|
|
278
|
+
def initialize_simulation(self, basis_states, target_states, start_states, segs_per_state=1, suppress_we=False):
    '''Initialize a new weighted ensemble simulation, taking ``segs_per_state`` initial
    states from each of the given ``basis_states``.

    ``w_init`` is the forward-facing version of this function.

    Creates and populates the HDF5 backing file, records target/basis states,
    assigns state IDs to ``start_states``, builds initial states (generated via
    the propagator when ``do_gen_istates`` is set, otherwise copied from the
    basis-state pcoords), and commits iteration 1 to disk. When ``suppress_we``
    is True, the initial WE resampling step is skipped and the driver's current
    segments/binning are used directly.'''

    data_manager = self.data_manager
    work_manager = self.work_manager
    pstatus = self.rc.pstatus
    system = self.system

    pstatus('Creating HDF5 file {!r}'.format(self.data_manager.we_h5filename))
    data_manager.prepare_backing()

    # Process target states
    data_manager.save_target_states(target_states)
    self.report_target_states(target_states)

    # Process basis states
    self.get_bstate_pcoords(basis_states)
    self.data_manager.create_ibstate_group(basis_states)
    self.data_manager.create_ibstate_iter_h5file(basis_states)
    self.report_basis_states(basis_states)

    # Process start states
    # Unlike the above, does not create an ibstate group.
    # TODO: Should it? I don't think so, if needed it can be traced back through basis_auxref

    # Here, we are trying to assign a state_id to the start state to be initialized, without actually
    # saving it to the ibstates records in any of the h5 files. It might actually be a problem
    # when tracing trajectories with westpa.analysis (especially with HDF5 framework) since it would
    # try to look for a basis state id > len(basis_state).
    # Since start states are only used while initializing and iteration 1, it's ok to not save it to save space. If necessary,
    # the structure can be traced directly to the parent file using the standard basis state logic referencing
    # west[iterations/iter_00000001/ibstates/istate_index/basis_auxref] of that istate.

    if len(start_states) > 0 and start_states[0].state_id is None:
        # Continue numbering after the last basis-state ID.
        last_id = basis_states[-1].state_id
        for start_state in start_states:
            start_state.state_id = last_id + 1
            last_id += 1

    self.get_bstate_pcoords(start_states, label='start')
    self.report_basis_states(start_states, label='start')

    pstatus('Preparing initial states')
    initial_states = []
    weights = []
    if self.do_gen_istates:
        istate_type = InitialState.ISTATE_TYPE_GENERATED
    else:
        istate_type = InitialState.ISTATE_TYPE_BASIS

    # One initial state per basis state per requested segment; each carries an
    # equal share of its basis state's probability.
    for basis_state in basis_states:
        for _iseg in range(segs_per_state):
            initial_state = data_manager.create_initial_states(1, 1)[0]
            initial_state.basis_state_id = basis_state.state_id
            initial_state.basis_state = basis_state
            initial_state.istate_type = istate_type
            weights.append(basis_state.probability / segs_per_state)
            initial_states.append(initial_state)

    for start_state in start_states:
        for _iseg in range(segs_per_state):
            initial_state = data_manager.create_initial_states(1, 1)[0]
            initial_state.basis_state_id = start_state.state_id
            initial_state.basis_state = start_state
            initial_state.basis_auxref = start_state.auxref

            # Start states are assigned their own type, so they can be identified later
            initial_state.istate_type = InitialState.ISTATE_TYPE_START
            weights.append(start_state.probability / segs_per_state)
            initial_state.iter_used = 1
            initial_states.append(initial_state)

    if self.do_gen_istates:
        # Generate initial-state pcoords through the propagator via the work manager.
        futures = [
            work_manager.submit(wm_ops.gen_istate, args=(initial_state.basis_state, initial_state))
            for initial_state in initial_states
        ]
        for future in work_manager.as_completed(futures):
            rbstate, ristate = future.get_result()
            initial_states[ristate.state_id].pcoord = ristate.pcoord
    else:
        # No generation needed: initial states inherit their basis states' pcoords.
        for initial_state in initial_states:
            basis_state = initial_state.basis_state
            initial_state.pcoord = basis_state.pcoord
            initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED

    for initial_state in initial_states:
        log.debug('initial state created: {!r}'.format(initial_state))

    # save list of initial states just generated
    # some of these may not be used, depending on how WE shakes out
    data_manager.update_initial_states(initial_states, n_iter=1)

    if not suppress_we:
        self.we_driver.populate_initial(initial_states, weights, system)
        segments = list(self.we_driver.next_iter_segments)
        binning = self.we_driver.next_iter_binning
    else:
        segments = list(self.we_driver.current_iter_segments)
        binning = self.we_driver.final_binning

    bin_occupancies = np.fromiter(map(len, binning), dtype=np.uint, count=self.we_driver.bin_mapper.nbins)
    target_occupancies = np.require(self.we_driver.bin_target_counts, dtype=np.uint)

    # total_bins/replicas defined here to remove target state bin from "active" bins
    total_bins = len(bin_occupancies) - len(target_states)
    total_replicas = int(sum(target_occupancies)) - int(self.we_driver.bin_target_counts[-1]) * len(target_states)

    # Make sure every segment is marked prepared for iteration 1 and traces back
    # to an initial state (negative parent_id) that was used in iteration 1.
    for segment in segments:
        segment.n_iter = 1
        segment.status = Segment.SEG_STATUS_PREPARED
        assert segment.parent_id < 0
        assert initial_states[segment.initial_state_id].iter_used == 1

    data_manager.prepare_iteration(1, segments)
    data_manager.update_initial_states(initial_states, n_iter=1)

    if self.rc.verbose_mode:
        pstatus('\nSegments generated:')
        for segment in segments:
            pstatus('{!r}'.format(segment))

    pstatus(
        '''
Total bins: {total_bins:d}
Initial replicas: {init_replicas:d} in {occ_bins:d} bins, total weight = {weight:g}
Total target replicas: {total_replicas:d}
'''.format(
            total_bins=total_bins,
            init_replicas=int(sum(bin_occupancies)),
            occ_bins=len(bin_occupancies[bin_occupancies > 0]),
            weight=float(sum(segment.weight for segment in segments)),
            total_replicas=total_replicas,
        )
    )

    total_prob = float(sum(segment.weight for segment in segments))
    pstatus(f'1-prob: {1 - total_prob:.4e}')

    target_counts = self.we_driver.bin_target_counts
    # Do not include bins with target count zero (e.g. sinks, never-filled bins) in the (non)empty bins statistics
    n_active_bins = len(target_counts[target_counts != 0])
    seg_probs = np.fromiter(map(operator.attrgetter('weight'), segments), dtype=weight_dtype, count=len(segments))
    norm = seg_probs.sum()

    # If accumulated floating-point error exceeds the tolerance, renormalize in place.
    if not abs(1 - norm) < EPS * (len(segments) + n_active_bins):
        pstatus("Normalization check failed at w_init, explicitly renormalizing")
        for segment in segments:
            segment.weight /= norm

    # Send the segments over to the data manager to commit to disk
    data_manager.current_iteration = 1

    # Report statistics
    pstatus('Simulation prepared.')
    self.segments = {segment.seg_id: segment for segment in segments}
    self.report_bin_statistics(binning, target_states, save_summary=True)
    data_manager.flush_backing()
    data_manager.close_backing()
|
|
441
|
+
|
|
442
|
+
def prepare_iteration(self):
    '''Prepare the current iteration (``self.n_iter``) for propagation.

    Sets up the WE driver for the new iteration, loads (or reuses) this
    iteration's segments and partitions them into complete/incomplete, reports
    initial bin statistics, loads restart data for incomplete segments,
    prepares next-iteration basis/initial-state bookkeeping for recycling,
    invokes the ``prepare_iteration`` callbacks, and runs the propagator's
    per-iteration preparation through the work manager.'''
    log.debug('beginning iteration {:d}'.format(self.n_iter))

    # the WE driver needs a list of all target states for this iteration
    # along with information about any new weights introduced (e.g. by recycling)
    target_states = self.data_manager.get_target_states(self.n_iter)
    new_weights = self.data_manager.get_new_weight_data(self.n_iter)

    self.we_driver.new_iteration(target_states=target_states, new_weights=new_weights)

    # Get basis states used in this iteration
    self.current_iter_bstates = self.data_manager.get_basis_states(self.n_iter)

    # Get the segments for this iteration and separate into complete and incomplete
    if self.segments is None:
        segments = self.segments = {segment.seg_id: segment for segment in self.data_manager.get_segments()}
        log.debug('loaded {:d} segments'.format(len(segments)))
    else:
        segments = self.segments
        log.debug('using {:d} pre-existing segments'.format(len(segments)))

    completed_segments = self.completed_segments = {}
    incomplete_segments = self.incomplete_segments = {}
    for segment in segments.values():
        if segment.status == Segment.SEG_STATUS_COMPLETE:
            completed_segments[segment.seg_id] = segment
        else:
            incomplete_segments[segment.seg_id] = segment
    log.debug('{:d} segments are complete; {:d} are incomplete'.format(len(completed_segments), len(incomplete_segments)))

    if len(incomplete_segments) == len(segments):
        # Starting a new iteration
        self.rc.pstatus('Beginning iteration {:d}'.format(self.n_iter))
    elif incomplete_segments:
        # Resuming a partially-propagated iteration (e.g. after a restart)
        self.rc.pstatus('Continuing iteration {:d}'.format(self.n_iter))
        self.rc.pstatus(
            '{:d} segments remain in iteration {:d} ({:d} total)'.format(len(incomplete_segments), self.n_iter, len(segments))
        )

    # Get the initial states active for this iteration (so that the propagator has them if necessary)
    self.current_iter_istates = {
        state.state_id: state for state in self.data_manager.get_segment_initial_states(list(segments.values()))
    }
    log.debug('This iteration uses {:d} initial states'.format(len(self.current_iter_istates)))

    # Assign this iteration's segments' initial points to bins and report on bin population
    initial_pcoords = self.system.new_pcoord_array(len(segments))
    initial_binning = self.system.bin_mapper.construct_bins()
    for iseg, segment in enumerate(segments.values()):
        # First pcoord frame of each segment is its starting point.
        initial_pcoords[iseg] = segment.pcoord[0]
    initial_assignments = self.system.bin_mapper.assign(initial_pcoords)
    for segment, assignment in zip(iter(segments.values()), initial_assignments):
        initial_binning[assignment].add(segment)
    self.report_bin_statistics(initial_binning, [], save_summary=True)
    del initial_pcoords, initial_binning

    self.rc.pstatus('Waiting for segments to complete...')

    # Let the WE driver assign completed segments
    if completed_segments:
        self.we_driver.assign(list(completed_segments.values()))

    # load restart data
    self.data_manager.prepare_segment_restarts(
        incomplete_segments.values(), self.current_iter_bstates, self.current_iter_istates
    )

    # Get the basis states and initial states for the next iteration, necessary for doing on-the-fly recycling
    self.next_iter_bstates = self.data_manager.get_basis_states(self.n_iter + 1)
    self.next_iter_bstate_cprobs = np.add.accumulate([bstate.probability for bstate in self.next_iter_bstates])

    self.we_driver.avail_initial_states = {
        istate.state_id: istate for istate in self.data_manager.get_unused_initial_states(n_iter=self.n_iter + 1)
    }
    log.debug('{:d} unused initial states found'.format(len(self.we_driver.avail_initial_states)))

    # Invoke callbacks
    self.invoke_callbacks(self.prepare_iteration)

    # dispatch and immediately wait on result for prep_iter
    log.debug('dispatching propagator prep_iter to work manager')
    self.work_manager.submit(wm_ops.prep_iter, args=(self.n_iter, segments)).get_result()
|
|
524
|
+
|
|
525
|
+
def finalize_iteration(self):
    '''Clean up after an iteration and prepare for the next.'''
    log.debug('finalizing iteration {:d}'.format(self.n_iter))

    self.invoke_callbacks(self.finalize_iteration)

    # dispatch and immediately wait on result for post_iter
    log.debug('dispatching propagator post_iter to work manager')
    current_segments = list(self.segments.values())
    post_iter_future = self.work_manager.submit(wm_ops.post_iter, args=(self.n_iter, current_segments))
    post_iter_future.get_result()

    # Replace this iteration's segments with the ones constructed for the next
    del self.segments
    next_segments = {}
    for seg in self.we_driver.next_iter_segments:
        next_segments[seg.seg_id] = seg
    self.segments = next_segments

    self.rc.pstatus("Iteration completed successfully")
def get_istate_futures(self):
    '''Add ``n_states`` initial states to the internal list of initial states assigned to
    recycled particles. Spare states are used if available, otherwise new states are created.
    If created new initial states requires generation, then a set of futures is returned
    representing work manager tasks corresponding to the necessary generation work.'''

    n_recycled = self.we_driver.n_recycled_segs
    n_istates_needed = self.we_driver.n_istates_needed

    log.debug('{:d} unused initial states available'.format(len(self.we_driver.avail_initial_states)))
    log.debug('{:d} new initial states required for recycling {:d} walkers'.format(n_istates_needed, n_recycled))

    futures = set()          # work-manager futures for asynchronous istate generation
    updated_states = []      # every InitialState created here, for one batch DB update
    for _i in range(n_istates_needed):
        # Select a basis state according to its weight
        # (next_iter_bstate_cprobs holds cumulative probabilities, so digitize
        # maps a uniform draw in [0,1) onto a basis-state index)
        ibstate = np.digitize([random.random()], self.next_iter_bstate_cprobs)
        basis_state = self.next_iter_bstates[ibstate[0]]
        initial_state = self.data_manager.create_initial_states(1, n_iter=self.n_iter + 1)[0]
        initial_state.iter_created = self.n_iter
        initial_state.basis_state_id = basis_state.state_id
        initial_state.istate_status = InitialState.ISTATE_STATUS_PENDING

        if self.do_gen_istates:
            # Generation runs asynchronously; the caller harvests the result
            # via the returned futures, so the state stays PENDING here.
            log.debug('generating new initial state from basis state {!r}'.format(basis_state))
            initial_state.istate_type = InitialState.ISTATE_TYPE_GENERATED
            futures.add(self.work_manager.submit(wm_ops.gen_istate, args=(basis_state, initial_state)))
        else:
            # The basis state's pcoord is copied directly, so the new initial
            # state is immediately prepared and made available for recycling.
            log.debug('using basis state {!r} directly'.format(basis_state))
            initial_state.istate_type = InitialState.ISTATE_TYPE_BASIS
            initial_state.pcoord = basis_state.pcoord.copy()
            initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED
            self.we_driver.avail_initial_states[initial_state.state_id] = initial_state
        updated_states.append(initial_state)
    self.data_manager.update_initial_states(updated_states, n_iter=self.n_iter + 1)
    return futures
def propagate(self):
    '''Dispatch propagation of all incomplete segments to the work manager,
    then collect results as they complete, dispatching any additional
    initial-state generation work that recycling requires along the way.'''
    segments = list(self.incomplete_segments.values())
    log.debug('iteration {:d}: propagating {:d} segments'.format(self.n_iter, len(segments)))

    # all futures dispatched for this iteration
    futures = set()
    segment_futures = set()

    # Immediately dispatch any necessary initial state generation
    istate_gen_futures = self.get_istate_futures()
    futures.update(istate_gen_futures)

    # Dispatch propagation tasks using work manager
    for segment_block in grouper(self.propagator_block_size, segments):
        # drop falsy entries (presumably fill padding from grouper — TODO confirm)
        segment_block = [_f for _f in segment_block if _f]
        pbstates, pistates = westpa.core.states.pare_basis_initial_states(
            self.current_iter_bstates, list(self.current_iter_istates.values()), segment_block
        )
        future = self.work_manager.submit(wm_ops.propagate, args=(pbstates, pistates, segment_block))
        futures.add(future)
        segment_futures.add(future)

    while futures:
        # TODO: add capacity for timeout or SIGINT here
        future = self.work_manager.wait_any(futures)
        futures.remove(future)

        if future in segment_futures:
            # A block of segments finished propagating.
            segment_futures.remove(future)
            incoming = future.get_result()
            self.n_propagated += 1

            self.segments.update({segment.seg_id: segment for segment in incoming})
            self.completed_segments.update({segment.seg_id: segment for segment in incoming})

            # Assigning completed segments to the WE driver can create a need
            # for more initial states; dispatch that generation work now so it
            # overlaps with the remaining propagation.
            self.we_driver.assign(incoming)
            new_istate_futures = self.get_istate_futures()
            istate_gen_futures.update(new_istate_futures)
            futures.update(new_istate_futures)

            with self.data_manager.expiring_flushing_lock():
                self.data_manager.update_segments(self.n_iter, incoming)

        elif future in istate_gen_futures:
            # An initial-state generation task finished; record the prepared
            # state and make it available for recycling.
            istate_gen_futures.remove(future)
            _basis_state, initial_state = future.get_result()
            log.debug('received newly-prepared initial state {!r}'.format(initial_state))
            initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED
            with self.data_manager.expiring_flushing_lock():
                self.data_manager.update_initial_states([initial_state], n_iter=self.n_iter + 1)
            self.we_driver.avail_initial_states[initial_state.state_id] = initial_state
        else:
            # Every dispatched future belongs to exactly one of the two sets
            # above; anything else indicates a bookkeeping bug.
            log.error('unknown future {!r} received from work manager'.format(future))
            raise AssertionError('untracked future {!r}'.format(future))

    log.debug('done with propagation')
    self.save_bin_data()
    self.data_manager.flush_backing()
def save_bin_data(self):
    '''Write the transition count and flux matrices for the current iteration
    to HDF5. Population and rate matrices are likely useless at the single-tau
    level and are no longer written.'''
    if not self.save_transition_matrices:
        return

    datasets = {
        'bin_ntrans': self.we_driver.transition_matrix,
        'bin_fluxes': self.we_driver.flux_matrix,
    }
    with self.data_manager.expiring_flushing_lock():
        iter_group = self.data_manager.get_iter_group(self.n_iter)
        for name, matrix in datasets.items():
            # remove any stale dataset before (re)writing it
            try:
                del iter_group[name]
            except KeyError:
                pass
            iter_group[name] = matrix
def check_propagation(self):
    '''Check for failures in propagation or initial state generation, and raise an exception
    if any are found.'''

    # Collect segments that did not finish propagation.
    failed_segments = []
    for segment in self.segments.values():
        if segment.status != Segment.SEG_STATUS_COMPLETE:
            failed_segments.append(segment)

    if not failed_segments:
        log.debug('propagation complete for iteration {:d}'.format(self.n_iter))
    else:
        failed_ids = ' \n'.join(str(segment.seg_id) for segment in failed_segments)
        log.error('propagation failed for {:d} segment(s):\n{}'.format(len(failed_segments), failed_ids))
        raise PropagationError('propagation failed for {:d} segments'.format(len(failed_segments)))

    # Collect initial states that were never successfully prepared.
    failed_istates = []
    for istate in self.we_driver.used_initial_states.values():
        if istate.istate_status != InitialState.ISTATE_STATUS_PREPARED:
            failed_istates.append(istate)
    log.debug('{!r}'.format(failed_istates))

    if not failed_istates:
        log.debug('initial state generation complete for iteration {:d}'.format(self.n_iter))
    else:
        failed_ids = ' \n'.join(str(istate.state_id) for istate in failed_istates)
        log.error('initial state generation failed for {:d} states:\n{}'.format(len(failed_istates), failed_ids))
        raise PropagationError('initial state generation failed for {:d} states'.format(len(failed_istates)))
def run_we(self):
    '''Run the weighted ensemble algorithm based on the binning in self.final_bins and
    the recycled particles in self.to_recycle, creating and committing the next iteration's
    segments to storage as well.'''

    # The WE driver now does almost everything; we just have to record the
    # mapper used for binning this iteration, and update initial states
    # that have been used

    try:
        pickled, hashed = self.we_driver.bin_mapper.pickle_and_hash()
    except PickleError:
        # An unpicklable mapper is simply recorded as empty strings.
        pickled = hashed = ''

    self.bin_mapper_hash = hashed
    self.we_driver.construct_next()

    # Mark every initial state consumed this iteration as used in the next one.
    if self.we_driver.used_initial_states:
        for initial_state in self.we_driver.used_initial_states.values():
            initial_state.iter_used = self.n_iter + 1
        self.data_manager.update_initial_states(list(self.we_driver.used_initial_states.values()))

    self.data_manager.update_segments(self.n_iter, list(self.segments.values()))

    self.data_manager.require_iter_group(self.n_iter + 1)
    self.data_manager.save_iter_binning(self.n_iter + 1, hashed, pickled, self.we_driver.bin_target_counts)

    # Report on recycling: group the recycled walkers' weights by the target
    # state that absorbed them. (setdefault replaces the previous
    # try/except-KeyError + redundant list([...]) grouping pattern.)
    recycling_events = {}
    for nw in self.we_driver.new_weights:
        recycling_events.setdefault(nw.target_state_id, []).append(nw.weight)

    tstates_by_id = {state.state_id: state for state in self.we_driver.target_states.values()}

    for tstate_id, weights in recycling_events.items():
        tstate = tstates_by_id[tstate_id]
        self.rc.pstatus(
            'Recycled {:g} probability ({:d} walkers) from target state {!r}'.format(sum(weights), len(weights), tstate.label)
        )
def prepare_new_iteration(self):
    '''Commit data for the coming iteration to the HDF5 file.'''
    self.invoke_callbacks(self.prepare_new_iteration)

    next_iter = self.n_iter + 1

    # In debug mode, dump each newly generated segment with its starting pcoord.
    if self.rc.debug_mode:
        self.rc.pstatus('\nSegments generated:')
        for seg in self.we_driver.next_iter_segments:
            self.rc.pstatus('{!r} pcoord[0]={!r}'.format(seg, seg.pcoord[0]))

    self.data_manager.prepare_iteration(next_iter, list(self.we_driver.next_iter_segments))
    self.data_manager.save_new_weight_data(next_iter, self.we_driver.new_weights)
def run(self):
    '''Run weighted-ensemble iterations until the requested iteration count is
    reached or another iteration would exceed the walltime budget.'''
    run_starttime = time.time()
    max_walltime = self.max_run_walltime
    if max_walltime:
        run_killtime = run_starttime + max_walltime
        self.rc.pstatus('Maximum wallclock time: %s' % timedelta(seconds=max_walltime or 0))
    else:
        run_killtime = None

    self.n_iter = self.data_manager.current_iteration
    max_iter = self.max_total_iterations or self.n_iter + 1

    iter_elapsed = 0
    while self.n_iter <= max_iter:
        # Stop early if the coming iteration — estimated at 1.1x the elapsed
        # time of the previous one — would overrun the walltime budget.
        if max_walltime and time.time() + 1.1 * iter_elapsed >= run_killtime:
            self.rc.pstatus('Iteration {:d} would require more than the allotted time. Ending run.'.format(self.n_iter))
            return

        try:
            iter_start_time = time.time()

            self.rc.pstatus('\n%s' % time.asctime())
            self.rc.pstatus('Iteration %d (%d requested)' % (self.n_iter, max_iter))

            self.prepare_iteration()
            self.rc.pflush()

            self.pre_propagation()
            self.propagate()
            self.rc.pflush()
            self.check_propagation()
            self.rc.pflush()
            self.post_propagation()

            # Total CPU time consumed by all segments this iteration
            cputime = sum(segment.cputime for segment in self.segments.values())

            self.rc.pflush()
            self.pre_we()
            self.run_we()
            self.post_we()
            self.rc.pflush()

            self.prepare_new_iteration()

            self.finalize_iteration()

            # Record per-iteration timing in the iteration summary
            iter_elapsed = time.time() - iter_start_time
            iter_summary = self.data_manager.get_iter_summary()
            iter_summary['walltime'] += iter_elapsed
            iter_summary['cputime'] = cputime
            self.data_manager.update_iter_summary(iter_summary)

            self.n_iter += 1
            self.data_manager.current_iteration += 1

            try:
                # This may give NaN if starting a truncated simulation
                walltime = timedelta(seconds=float(iter_summary['walltime']))
            except ValueError:
                walltime = 0.0

            try:
                cputime = timedelta(seconds=float(iter_summary['cputime']))
            except ValueError:
                cputime = 0.0

            self.rc.pstatus('Iteration wallclock: {0!s}, cputime: {1!s}\n'.format(walltime, cputime))
            self.rc.pflush()
        finally:
            # Always persist whatever was written during this iteration
            self.data_manager.flush_backing()

    self.rc.pstatus('\n%s' % time.asctime())
    self.rc.pstatus('WEST run complete.')
def prepare_run(self):
    '''Prepare a new run.'''
    # Set up storage first, then the system, then notify registered callbacks.
    self.data_manager.prepare_run()
    self.system.prepare_run()
    self.invoke_callbacks(self.prepare_run)
def finalize_run(self):
    '''Perform cleanup at the normal end of a run'''
    # Teardown mirrors prepare_run in reverse order:
    # callbacks, then system, then storage.
    self.invoke_callbacks(self.finalize_run)
    self.system.finalize_run()
    self.data_manager.finalize_run()
def pre_propagation(self):
    '''Hook: invoke callbacks registered for the pre-propagation phase.'''
    self.invoke_callbacks(self.pre_propagation)
def post_propagation(self):
    '''Hook: invoke callbacks registered for the post-propagation phase.'''
    self.invoke_callbacks(self.post_propagation)
def pre_we(self):
    '''Hook: invoke callbacks registered to run just before the WE step.'''
    self.invoke_callbacks(self.pre_we)
def post_we(self):
    '''Hook: invoke callbacks registered to run just after the WE step.'''
    self.invoke_callbacks(self.post_we)