westpa 2022.10__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of westpa might be problematic. Click here for more details.

Files changed (150)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +746 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +360 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +59 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +596 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +547 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +378 -0
  25. westpa/cli/tools/w_ipa.py +833 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +486 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_trace.py +599 -0
  38. westpa/core/__init__.py +0 -0
  39. westpa/core/_rc.py +673 -0
  40. westpa/core/binning/__init__.py +55 -0
  41. westpa/core/binning/_assign.cpython-312-x86_64-linux-gnu.so +0 -0
  42. westpa/core/binning/assign.py +449 -0
  43. westpa/core/binning/binless.py +96 -0
  44. westpa/core/binning/binless_driver.py +54 -0
  45. westpa/core/binning/binless_manager.py +190 -0
  46. westpa/core/binning/bins.py +47 -0
  47. westpa/core/binning/mab.py +427 -0
  48. westpa/core/binning/mab_driver.py +54 -0
  49. westpa/core/binning/mab_manager.py +198 -0
  50. westpa/core/data_manager.py +1694 -0
  51. westpa/core/extloader.py +74 -0
  52. westpa/core/h5io.py +995 -0
  53. westpa/core/kinetics/__init__.py +24 -0
  54. westpa/core/kinetics/_kinetics.cpython-312-x86_64-linux-gnu.so +0 -0
  55. westpa/core/kinetics/events.py +147 -0
  56. westpa/core/kinetics/matrates.py +156 -0
  57. westpa/core/kinetics/rate_averaging.py +266 -0
  58. westpa/core/progress.py +218 -0
  59. westpa/core/propagators/__init__.py +54 -0
  60. westpa/core/propagators/executable.py +715 -0
  61. westpa/core/reweight/__init__.py +14 -0
  62. westpa/core/reweight/_reweight.cpython-312-x86_64-linux-gnu.so +0 -0
  63. westpa/core/reweight/matrix.py +126 -0
  64. westpa/core/segment.py +119 -0
  65. westpa/core/sim_manager.py +830 -0
  66. westpa/core/states.py +359 -0
  67. westpa/core/systems.py +93 -0
  68. westpa/core/textio.py +74 -0
  69. westpa/core/trajectory.py +330 -0
  70. westpa/core/we_driver.py +908 -0
  71. westpa/core/wm_ops.py +43 -0
  72. westpa/core/yamlcfg.py +391 -0
  73. westpa/fasthist/__init__.py +34 -0
  74. westpa/fasthist/__main__.py +110 -0
  75. westpa/fasthist/_fasthist.cpython-312-x86_64-linux-gnu.so +0 -0
  76. westpa/mclib/__init__.py +264 -0
  77. westpa/mclib/__main__.py +28 -0
  78. westpa/mclib/_mclib.cpython-312-x86_64-linux-gnu.so +0 -0
  79. westpa/oldtools/__init__.py +4 -0
  80. westpa/oldtools/aframe/__init__.py +35 -0
  81. westpa/oldtools/aframe/atool.py +75 -0
  82. westpa/oldtools/aframe/base_mixin.py +26 -0
  83. westpa/oldtools/aframe/binning.py +178 -0
  84. westpa/oldtools/aframe/data_reader.py +560 -0
  85. westpa/oldtools/aframe/iter_range.py +200 -0
  86. westpa/oldtools/aframe/kinetics.py +117 -0
  87. westpa/oldtools/aframe/mcbs.py +146 -0
  88. westpa/oldtools/aframe/output.py +39 -0
  89. westpa/oldtools/aframe/plotting.py +90 -0
  90. westpa/oldtools/aframe/trajwalker.py +126 -0
  91. westpa/oldtools/aframe/transitions.py +469 -0
  92. westpa/oldtools/cmds/__init__.py +0 -0
  93. westpa/oldtools/cmds/w_ttimes.py +358 -0
  94. westpa/oldtools/files.py +34 -0
  95. westpa/oldtools/miscfn.py +23 -0
  96. westpa/oldtools/stats/__init__.py +4 -0
  97. westpa/oldtools/stats/accumulator.py +35 -0
  98. westpa/oldtools/stats/edfs.py +129 -0
  99. westpa/oldtools/stats/mcbs.py +89 -0
  100. westpa/tools/__init__.py +33 -0
  101. westpa/tools/binning.py +472 -0
  102. westpa/tools/core.py +340 -0
  103. westpa/tools/data_reader.py +159 -0
  104. westpa/tools/dtypes.py +31 -0
  105. westpa/tools/iter_range.py +198 -0
  106. westpa/tools/kinetics_tool.py +340 -0
  107. westpa/tools/plot.py +283 -0
  108. westpa/tools/progress.py +17 -0
  109. westpa/tools/selected_segs.py +154 -0
  110. westpa/tools/wipi.py +751 -0
  111. westpa/trajtree/__init__.py +4 -0
  112. westpa/trajtree/_trajtree.cpython-312-x86_64-linux-gnu.so +0 -0
  113. westpa/trajtree/trajtree.py +117 -0
  114. westpa/westext/__init__.py +0 -0
  115. westpa/westext/adaptvoronoi/__init__.py +3 -0
  116. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  117. westpa/westext/hamsm_restarting/__init__.py +3 -0
  118. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  119. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  120. westpa/westext/stringmethod/__init__.py +11 -0
  121. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  122. westpa/westext/stringmethod/string_driver.py +253 -0
  123. westpa/westext/stringmethod/string_method.py +306 -0
  124. westpa/westext/weed/BinCluster.py +180 -0
  125. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  126. westpa/westext/weed/UncertMath.py +247 -0
  127. westpa/westext/weed/__init__.py +10 -0
  128. westpa/westext/weed/weed_driver.py +182 -0
  129. westpa/westext/wess/ProbAdjust.py +101 -0
  130. westpa/westext/wess/__init__.py +6 -0
  131. westpa/westext/wess/wess_driver.py +207 -0
  132. westpa/work_managers/__init__.py +57 -0
  133. westpa/work_managers/core.py +396 -0
  134. westpa/work_managers/environment.py +134 -0
  135. westpa/work_managers/mpi.py +318 -0
  136. westpa/work_managers/processes.py +187 -0
  137. westpa/work_managers/serial.py +28 -0
  138. westpa/work_managers/threads.py +79 -0
  139. westpa/work_managers/zeromq/__init__.py +20 -0
  140. westpa/work_managers/zeromq/core.py +641 -0
  141. westpa/work_managers/zeromq/node.py +131 -0
  142. westpa/work_managers/zeromq/work_manager.py +526 -0
  143. westpa/work_managers/zeromq/worker.py +320 -0
  144. westpa-2022.10.dist-info/AUTHORS +22 -0
  145. westpa-2022.10.dist-info/LICENSE +21 -0
  146. westpa-2022.10.dist-info/METADATA +183 -0
  147. westpa-2022.10.dist-info/RECORD +150 -0
  148. westpa-2022.10.dist-info/WHEEL +6 -0
  149. westpa-2022.10.dist-info/entry_points.txt +29 -0
  150. westpa-2022.10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,830 @@
1
+ import logging
2
+ import math
3
+ import operator
4
+ import random
5
+ import time
6
+ from datetime import timedelta
7
+ from pickle import PickleError
8
+ from itertools import zip_longest
9
+ from collections import Counter
10
+
11
+ import numpy as np
12
+
13
+ import westpa
14
+ from .data_manager import weight_dtype
15
+ from .segment import Segment
16
+ from .states import InitialState
17
+ from . import extloader
18
+ from . import wm_ops
19
+
20
+
21
# Module-level logger; all diagnostic output from the simulation manager goes here.
log = logging.getLogger(__name__)

# Machine epsilon of the weight dtype; used as the unit of tolerance when
# checking that segment weights sum to 1.
EPS = np.finfo(weight_dtype).eps
24
+
25
+
26
def grouper(n, iterable, fillvalue=None):
    """Yield the elements of ``iterable`` in fixed-length tuples of size ``n``,
    padding the final tuple with ``fillvalue`` if needed.

    Example: grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
    """
    shared_iter = iter(iterable)
    return zip_longest(*([shared_iter] * n), fillvalue=fillvalue)
31
+
32
+
33
class PropagationError(RuntimeError):
    """Error raised when segment propagation fails irrecoverably."""

    pass
35
+
36
+
37
+ class WESimManager:
38
+ def process_config(self):
39
+ config = self.rc.config
40
+ for entry, type_ in [('gen_istates', bool), ('block_size', int), ('save_transition_matrices', bool)]:
41
+ config.require_type_if_present(['west', 'propagation', entry], type_)
42
+
43
+ self.do_gen_istates = config.get(['west', 'propagation', 'gen_istates'], False)
44
+ self.propagator_block_size = config.get(['west', 'propagation', 'block_size'], 1)
45
+ self.save_transition_matrices = config.get(['west', 'propagation', 'save_transition_matrices'], False)
46
+ self.max_run_walltime = config.get(['west', 'propagation', 'max_run_wallclock'], default=None)
47
+ self.max_total_iterations = config.get(['west', 'propagation', 'max_total_iterations'], default=None)
48
+
49
    def __init__(self, rc=None):
        """Set up a weighted-ensemble simulation manager.

        Parameters
        ----------
        rc : run control object, optional
            Supplies the work manager, data manager, WE driver, and system
            driver. Defaults to the global ``westpa.rc``.
        """
        self.rc = rc or westpa.rc
        # Major simulation components, all resolved through the run control object.
        self.work_manager = self.rc.get_work_manager()
        self.data_manager = self.rc.get_data_manager()
        self.we_driver = self.rc.get_we_driver()
        self.system = self.rc.get_system_driver()

        # A table of function -> list of (priority, name, callback) tuples
        self._callback_table = {}
        # Only these bound methods are accepted as hooks by register_callback().
        self._valid_callbacks = set(
            (
                self.prepare_run,
                self.finalize_run,
                self.prepare_iteration,
                self.finalize_iteration,
                self.pre_propagation,
                self.post_propagation,
                self.pre_we,
                self.post_we,
                self.prepare_new_iteration,
            )
        )
        # Allow hooks to be looked up by name as well as by bound method.
        self._callbacks_by_name = {fn.__name__: fn for fn in self._valid_callbacks}
        self.n_propagated = 0

        # config items; these defaults are overwritten by process_config() below
        self.do_gen_istates = False
        self.propagator_block_size = 1
        self.save_transition_matrices = False
        self.max_run_walltime = None
        self.max_total_iterations = None
        self.process_config()

        # Per-iteration variables
        self.n_iter = None  # current iteration

        # Basis and initial states for this iteration, in case the propagator needs them
        self.current_iter_bstates = None  # BasisStates valid at this iteration
        self.current_iter_istates = None  # InitialStates used in this iteration

        # Basis states for next iteration
        self.next_iter_bstates = None  # BasisStates valid for the next iteration
        self.next_iter_bstate_cprobs = None  # Cumulative probabilities for basis states, used for selection

        # Tracking of this iteration's segments
        self.segments = None  # Mapping of seg_id to segment for all segments in this iteration
        self.completed_segments = None  # Mapping of seg_id to segment for all completed segments in this iteration
        self.incomplete_segments = None  # Mapping of seg_id to segment for all incomplete segments in this iteration

        # Tracking of binning
        self.bin_mapper_hash = None  # Hash of bin mapper from most recently-run WE, for use by post-WE analysis plugins
100
+
101
+ def register_callback(self, hook, function, priority=0):
102
+ '''Registers a callback to execute during the given ``hook`` into the simulation loop. The optional
103
+ priority is used to order when the function is called relative to other registered callbacks.'''
104
+
105
+ if hook not in self._valid_callbacks:
106
+ try:
107
+ hook = self._callbacks_by_name[hook]
108
+ except KeyError:
109
+ raise KeyError('invalid hook {!r}'.format(hook))
110
+
111
+ # It's possible to register a callback that's a duplicate function, but at a different place in memory.
112
+ # For example, if you launch a run without clearing state from a previous run.
113
+ # More details on this are available in https://github.com/westpa/westpa/issues/182 but the below code
114
+ # handles specifically the problem that causes in plugin loading.
115
+ try:
116
+ # Before checking for set membership of (priority, function.__name__, function), just check
117
+ # function hash for collisions in this hook.
118
+ hook_function_hash = [hash(callback[2]) for callback in self._callback_table[hook]]
119
+ except KeyError:
120
+ # If there's no entry in self._callback_table for this hook, then there definitely aren't any collisions
121
+ # because no plugins are registered to it yet in the first place.
122
+ pass
123
+ else:
124
+ # If there are plugins registered to this hook, check for duplicate hash, which will definitely have the same name, module, function.
125
+ try:
126
+ if hash(function) in hook_function_hash:
127
+ log.info('{!r} has already been loaded, skipping'.format(function))
128
+ return
129
+ except KeyError:
130
+ pass
131
+
132
+ try:
133
+ self._callback_table[hook].add((priority, function.__name__, function))
134
+ except KeyError:
135
+ self._callback_table[hook] = set([(priority, function.__name__, function)])
136
+
137
+ # Raise warning if there are multiple callback with same priority.
138
+ for priority, count in Counter([callback[0] for callback in self._callback_table[hook]]).items():
139
+ if count > 1:
140
+ log.warning(
141
+ f'{count} callbacks in {hook} have identical priority {priority}. The order of callback execution is not guaranteed.'
142
+ )
143
+ log.warning(f'{hook}: {self._callback_table[hook]}')
144
+
145
+ log.debug('registered callback {!r} for hook {!r}'.format(function, hook))
146
+
147
+ def invoke_callbacks(self, hook, *args, **kwargs):
148
+ callbacks = self._callback_table.get(hook, [])
149
+ # Sort by priority, function name, then module name
150
+ sorted_callbacks = sorted(callbacks, key=lambda x: (x[0], x[1], x[2].__module__))
151
+ for priority, name, fn in sorted_callbacks:
152
+ log.debug('invoking callback {!r} for hook {!r}'.format(fn, hook))
153
+ fn(*args, **kwargs)
154
+
155
+ def load_plugins(self, plugins=None):
156
+ if plugins is None:
157
+ plugins = []
158
+
159
+ try:
160
+ plugins_config = westpa.rc.config['west', 'plugins']
161
+ except KeyError:
162
+ plugins_config = []
163
+
164
+ if plugins_config is None:
165
+ plugins_config = []
166
+
167
+ plugins += plugins_config
168
+
169
+ for plugin_config in plugins:
170
+ plugin_name = plugin_config['plugin']
171
+ if plugin_config.get('enabled', True):
172
+ log.info('loading plugin {!r}'.format(plugin_name))
173
+ plugin = extloader.get_object(plugin_name)(self, plugin_config)
174
+ log.debug('loaded plugin {!r}'.format(plugin))
175
+
176
    def report_bin_statistics(self, bins, target_states, save_summary=False):
        """Print per-bin and per-segment weight statistics for the current
        iteration's segments, and optionally record them in the iteration
        summary table.

        Parameters
        ----------
        bins : sequence
            Bins (each a sized iterable of segments, with a ``weight``
            attribute) to report on.
        target_states : sequence
            Target states; when present, their count is subtracted from the
            number of active bins.
        save_summary : bool
            When True, write the statistics into the data manager's iteration
            summary record.
        """
        segments = list(self.segments.values())
        bin_counts = np.fromiter(map(len, bins), dtype=np.int_, count=len(bins))
        target_counts = self.we_driver.bin_target_counts

        # Do not include bins with target count zero (e.g. sinks, never-filled bins) in the (non)empty bins statistics
        n_active_bins = len(target_counts[target_counts != 0])

        if target_states:
            n_active_bins -= len(target_states)

        seg_probs = np.fromiter(map(operator.attrgetter('weight'), segments), dtype=weight_dtype, count=len(segments))
        bin_probs = np.fromiter(map(operator.attrgetter('weight'), bins), dtype=weight_dtype, count=len(bins))
        norm = seg_probs.sum()

        # Total weight must be 1 to within accumulated floating-point error (EPS per addition).
        assert abs(1 - norm) < EPS * (len(segments) + n_active_bins)

        # Dynamic ranges below are log weight ratios, reported in units of kT.
        min_seg_prob = seg_probs[seg_probs != 0].min()
        max_seg_prob = seg_probs.max()
        seg_drange = math.log(max_seg_prob / min_seg_prob)
        min_bin_prob = bin_probs[bin_probs != 0].min()
        max_bin_prob = bin_probs.max()
        bin_drange = math.log(max_bin_prob / min_bin_prob)
        n_pop = len(bin_counts[bin_counts != 0])

        self.rc.pstatus('{:d} of {:d} ({:%}) active bins are populated'.format(n_pop, n_active_bins, n_pop / n_active_bins))
        self.rc.pstatus('per-bin minimum non-zero probability: {:g}'.format(min_bin_prob))
        self.rc.pstatus('per-bin maximum probability: {:g}'.format(max_bin_prob))
        self.rc.pstatus('per-bin probability dynamic range (kT): {:g}'.format(bin_drange))
        self.rc.pstatus('per-segment minimum non-zero probability: {:g}'.format(min_seg_prob))
        self.rc.pstatus('per-segment maximum non-zero probability: {:g}'.format(max_seg_prob))
        self.rc.pstatus('per-segment probability dynamic range (kT): {:g}'.format(seg_drange))
        self.rc.pstatus('norm = {:g}, error in norm = {:g} ({:.2g}*epsilon)'.format(norm, (norm - 1), (norm - 1) / EPS))
        self.rc.pflush()

        if min_seg_prob < 1e-100:
            log.warning(
                '\nMinimum segment weight is < 1e-100 and might not be physically relevant. Please reconsider your progress coordinate or binning scheme.'
            )

        if save_summary:
            iter_summary = self.data_manager.get_iter_summary()
            iter_summary['n_particles'] = len(segments)
            iter_summary['norm'] = norm
            iter_summary['min_bin_prob'] = min_bin_prob
            iter_summary['max_bin_prob'] = max_bin_prob
            iter_summary['min_seg_prob'] = min_seg_prob
            iter_summary['max_seg_prob'] = max_seg_prob
            # Timing fields may be NaN before any propagation has been recorded; zero them out.
            if np.isnan(iter_summary['cputime']):
                iter_summary['cputime'] = 0.0
            if np.isnan(iter_summary['walltime']):
                iter_summary['walltime'] = 0.0
            self.data_manager.update_iter_summary(iter_summary)
229
+
230
+ def get_bstate_pcoords(self, basis_states, label='basis'):
231
+ '''For each of the given ``basis_states``, calculate progress coordinate values
232
+ as necessary. The HDF5 file is not updated.'''
233
+
234
+ self.rc.pstatus('Calculating progress coordinate values for {} states.'.format(label))
235
+ futures = [self.work_manager.submit(wm_ops.get_pcoord, args=(basis_state,)) for basis_state in basis_states]
236
+ fmap = {future: i for (i, future) in enumerate(futures)}
237
+ for future in self.work_manager.as_completed(futures):
238
+ basis_states[fmap[future]].pcoord = future.get_result().pcoord
239
+
240
+ def report_basis_states(self, basis_states, label='basis'):
241
+ pstatus = self.rc.pstatus
242
+ pstatus('{:d} {} state(s) present'.format(len(basis_states), label), end='')
243
+ if self.rc.verbose_mode:
244
+ pstatus(':')
245
+ pstatus(
246
+ '{:6s} {:12s} {:20s} {:20s} {}'.format(
247
+ 'ID', 'Label', 'Probability', 'Aux Reference', 'Progress Coordinate'
248
+ )
249
+ )
250
+ for basis_state in basis_states:
251
+ pstatus(
252
+ '{:<6d} {:12s} {:<20.14g} {:20s} {}'.format(
253
+ basis_state.state_id,
254
+ basis_state.label,
255
+ basis_state.probability,
256
+ basis_state.auxref or '',
257
+ ', '.join(map(str, basis_state.pcoord)),
258
+ )
259
+ )
260
+ pstatus()
261
+ self.rc.pflush()
262
+
263
+ def report_target_states(self, target_states):
264
+ pstatus = self.rc.pstatus
265
+ pstatus('{:d} target state(s) present'.format(len(target_states)), end='')
266
+ if self.rc.verbose_mode and target_states:
267
+ pstatus(':')
268
+ pstatus('{:6s} {:12s} {}'.format('ID', 'Label', 'Progress Coordinate'))
269
+ for target_state in target_states:
270
+ pstatus(
271
+ '{:<6d} {:12s} {}'.format(
272
+ target_state.state_id, target_state.label, ','.join(map(str, target_state.pcoord))
273
+ )
274
+ )
275
+ pstatus()
276
+ self.rc.pflush()
277
+
278
    def initialize_simulation(self, basis_states, target_states, start_states, segs_per_state=1, suppress_we=False):
        '''Initialize a new weighted ensemble simulation, taking ``segs_per_state`` initial
        states from each of the given ``basis_states``.

        ``w_init`` is the forward-facing version of this function'''

        # Local aliases for frequently-used components
        data_manager = self.data_manager
        work_manager = self.work_manager
        pstatus = self.rc.pstatus
        system = self.system

        pstatus('Creating HDF5 file {!r}'.format(self.data_manager.we_h5filename))
        data_manager.prepare_backing()

        # Process target states
        data_manager.save_target_states(target_states)
        self.report_target_states(target_states)

        # Process basis states: compute pcoords, then record them in the HDF5 files
        self.get_bstate_pcoords(basis_states)
        self.data_manager.create_ibstate_group(basis_states)
        self.data_manager.create_ibstate_iter_h5file(basis_states)
        self.report_basis_states(basis_states)

        # Process start states
        # Unlike the above, does not create an ibstate group.
        # TODO: Should it? I don't think so, if needed it can be traced back through basis_auxref

        # Here, we are trying to assign a state_id to the start state to be initialized, without actually
        # saving it to the ibstates records in any of the h5 files. It might actually be a problem
        # when tracing trajectories with westpa.analysis (especially with HDF5 framework) since it would
        # try to look for a basis state id > len(basis_state).
        # Since start states are only used while initializing and iteration 1, it's ok to not save it to save space. If necessary,
        # the structure can be traced directly to the parent file using the standard basis state logic referencing
        # west[iterations/iter_00000001/ibstates/istate_index/basis_auxref] of that istate.

        if len(start_states) > 0 and start_states[0].state_id is None:
            # Assign start-state IDs continuing after the last basis state ID
            last_id = basis_states[-1].state_id
            for start_state in start_states:
                start_state.state_id = last_id + 1
                last_id += 1

        self.get_bstate_pcoords(start_states, label='start')
        self.report_basis_states(start_states, label='start')

        pstatus('Preparing initial states')
        initial_states = []
        weights = []
        if self.do_gen_istates:
            istate_type = InitialState.ISTATE_TYPE_GENERATED
        else:
            istate_type = InitialState.ISTATE_TYPE_BASIS

        # One initial state per basis state per requested segment; weight is the
        # basis state's probability split evenly among its segments.
        for basis_state in basis_states:
            for _iseg in range(segs_per_state):
                initial_state = data_manager.create_initial_states(1, 1)[0]
                initial_state.basis_state_id = basis_state.state_id
                initial_state.basis_state = basis_state
                initial_state.istate_type = istate_type
                weights.append(basis_state.probability / segs_per_state)
                initial_states.append(initial_state)

        for start_state in start_states:
            for _iseg in range(segs_per_state):
                initial_state = data_manager.create_initial_states(1, 1)[0]
                initial_state.basis_state_id = start_state.state_id
                initial_state.basis_state = start_state
                initial_state.basis_auxref = start_state.auxref

                # Start states are assigned their own type, so they can be identified later
                initial_state.istate_type = InitialState.ISTATE_TYPE_START
                weights.append(start_state.probability / segs_per_state)
                initial_state.iter_used = 1
                initial_states.append(initial_state)

        if self.do_gen_istates:
            # Generate initial-state pcoords in parallel via the work manager
            futures = [
                work_manager.submit(wm_ops.gen_istate, args=(initial_state.basis_state, initial_state))
                for initial_state in initial_states
            ]
            for future in work_manager.as_completed(futures):
                rbstate, ristate = future.get_result()
                initial_states[ristate.state_id].pcoord = ristate.pcoord
        else:
            # Without generation, each initial state inherits its basis state's pcoord
            for initial_state in initial_states:
                basis_state = initial_state.basis_state
                initial_state.pcoord = basis_state.pcoord
                initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED

        for initial_state in initial_states:
            log.debug('initial state created: {!r}'.format(initial_state))

        # save list of initial states just generated
        # some of these may not be used, depending on how WE shakes out
        data_manager.update_initial_states(initial_states, n_iter=1)

        if not suppress_we:
            self.we_driver.populate_initial(initial_states, weights, system)
            segments = list(self.we_driver.next_iter_segments)
            binning = self.we_driver.next_iter_binning
        else:
            segments = list(self.we_driver.current_iter_segments)
            binning = self.we_driver.final_binning

        bin_occupancies = np.fromiter(map(len, binning), dtype=np.uint, count=self.we_driver.bin_mapper.nbins)
        target_occupancies = np.require(self.we_driver.bin_target_counts, dtype=np.uint)

        # total_bins/replicas defined here to remove target state bin from "active" bins
        total_bins = len(bin_occupancies) - len(target_states)
        total_replicas = int(sum(target_occupancies)) - int(self.we_driver.bin_target_counts[-1]) * len(target_states)

        # Make sure we have a consistent first-iteration segment set: every segment
        # belongs to iteration 1, is ready to run, descends from an initial point
        # (negative parent_id), and references an initial state marked used at iter 1.
        for segment in segments:
            segment.n_iter = 1
            segment.status = Segment.SEG_STATUS_PREPARED
            assert segment.parent_id < 0
            assert initial_states[segment.initial_state_id].iter_used == 1

        data_manager.prepare_iteration(1, segments)
        data_manager.update_initial_states(initial_states, n_iter=1)

        if self.rc.verbose_mode:
            pstatus('\nSegments generated:')
            for segment in segments:
                pstatus('{!r}'.format(segment))

        pstatus(
            '''
        Total bins: {total_bins:d}
        Initial replicas: {init_replicas:d} in {occ_bins:d} bins, total weight = {weight:g}
        Total target replicas: {total_replicas:d}
        '''.format(
                total_bins=total_bins,
                init_replicas=int(sum(bin_occupancies)),
                occ_bins=len(bin_occupancies[bin_occupancies > 0]),
                weight=float(sum(segment.weight for segment in segments)),
                total_replicas=total_replicas,
            )
        )

        total_prob = float(sum(segment.weight for segment in segments))
        pstatus(f'1-prob: {1 - total_prob:.4e}')

        target_counts = self.we_driver.bin_target_counts
        # Do not include bins with target count zero (e.g. sinks, never-filled bins) in the (non)empty bins statistics
        n_active_bins = len(target_counts[target_counts != 0])
        seg_probs = np.fromiter(map(operator.attrgetter('weight'), segments), dtype=weight_dtype, count=len(segments))
        norm = seg_probs.sum()

        # Renormalize explicitly if weights do not sum to 1 within floating-point tolerance
        if not abs(1 - norm) < EPS * (len(segments) + n_active_bins):
            pstatus("Normalization check failed at w_init, explicitly renormalizing")
            for segment in segments:
                segment.weight /= norm

        # Send the segments over to the data manager to commit to disk
        data_manager.current_iteration = 1

        # Report statistics
        pstatus('Simulation prepared.')
        self.segments = {segment.seg_id: segment for segment in segments}
        self.report_bin_statistics(binning, target_states, save_summary=True)
        data_manager.flush_backing()
        data_manager.close_backing()
441
+
442
    def prepare_iteration(self):
        """Prepare the current iteration (``self.n_iter``) for propagation:
        set up the WE driver, load or reuse this iteration's segments, sort
        them into complete/incomplete, report initial bin statistics, load
        restart data, and stage next-iteration basis/initial states."""
        log.debug('beginning iteration {:d}'.format(self.n_iter))

        # the WE driver needs a list of all target states for this iteration
        # along with information about any new weights introduced (e.g. by recycling)
        target_states = self.data_manager.get_target_states(self.n_iter)
        new_weights = self.data_manager.get_new_weight_data(self.n_iter)

        self.we_driver.new_iteration(target_states=target_states, new_weights=new_weights)

        # Get basis states used in this iteration
        self.current_iter_bstates = self.data_manager.get_basis_states(self.n_iter)

        # Get the segments for this iteration and separate into complete and incomplete
        if self.segments is None:
            segments = self.segments = {segment.seg_id: segment for segment in self.data_manager.get_segments()}
            log.debug('loaded {:d} segments'.format(len(segments)))
        else:
            segments = self.segments
            log.debug('using {:d} pre-existing segments'.format(len(segments)))

        completed_segments = self.completed_segments = {}
        incomplete_segments = self.incomplete_segments = {}
        for segment in segments.values():
            if segment.status == Segment.SEG_STATUS_COMPLETE:
                completed_segments[segment.seg_id] = segment
            else:
                incomplete_segments[segment.seg_id] = segment
        log.debug('{:d} segments are complete; {:d} are incomplete'.format(len(completed_segments), len(incomplete_segments)))

        if len(incomplete_segments) == len(segments):
            # Starting a new iteration
            self.rc.pstatus('Beginning iteration {:d}'.format(self.n_iter))
        elif incomplete_segments:
            # Resuming a partially-propagated iteration (e.g. after a restart)
            self.rc.pstatus('Continuing iteration {:d}'.format(self.n_iter))
        self.rc.pstatus(
            '{:d} segments remain in iteration {:d} ({:d} total)'.format(len(incomplete_segments), self.n_iter, len(segments))
        )

        # Get the initial states active for this iteration (so that the propagator has them if necessary)
        self.current_iter_istates = {
            state.state_id: state for state in self.data_manager.get_segment_initial_states(list(segments.values()))
        }
        log.debug('This iteration uses {:d} initial states'.format(len(self.current_iter_istates)))

        # Assign this iteration's segments' initial points to bins and report on bin population
        initial_pcoords = self.system.new_pcoord_array(len(segments))
        initial_binning = self.system.bin_mapper.construct_bins()
        for iseg, segment in enumerate(segments.values()):
            # Only the first pcoord frame of each segment is binned here
            initial_pcoords[iseg] = segment.pcoord[0]
        initial_assignments = self.system.bin_mapper.assign(initial_pcoords)
        for segment, assignment in zip(iter(segments.values()), initial_assignments):
            initial_binning[assignment].add(segment)
        self.report_bin_statistics(initial_binning, [], save_summary=True)
        del initial_pcoords, initial_binning

        self.rc.pstatus('Waiting for segments to complete...')

        # Let the WE driver assign completed segments
        if completed_segments:
            self.we_driver.assign(list(completed_segments.values()))

        # load restart data
        self.data_manager.prepare_segment_restarts(
            incomplete_segments.values(), self.current_iter_bstates, self.current_iter_istates
        )

        # Get the basis states and initial states for the next iteration, necessary for doing on-the-fly recycling
        self.next_iter_bstates = self.data_manager.get_basis_states(self.n_iter + 1)
        # Cumulative probabilities, used by get_istate_futures() to sample basis states by weight
        self.next_iter_bstate_cprobs = np.add.accumulate([bstate.probability for bstate in self.next_iter_bstates])

        self.we_driver.avail_initial_states = {
            istate.state_id: istate for istate in self.data_manager.get_unused_initial_states(n_iter=self.n_iter + 1)
        }
        log.debug('{:d} unused initial states found'.format(len(self.we_driver.avail_initial_states)))

        # Invoke callbacks
        self.invoke_callbacks(self.prepare_iteration)

        # dispatch and immediately wait on result for prep_iter
        log.debug('dispatching propagator prep_iter to work manager')
        self.work_manager.submit(wm_ops.prep_iter, args=(self.n_iter, segments)).get_result()
524
+
525
+ def finalize_iteration(self):
526
+ '''Clean up after an iteration and prepare for the next.'''
527
+ log.debug('finalizing iteration {:d}'.format(self.n_iter))
528
+
529
+ self.invoke_callbacks(self.finalize_iteration)
530
+
531
+ # dispatch and immediately wait on result for post_iter
532
+ log.debug('dispatching propagator post_iter to work manager')
533
+ self.work_manager.submit(wm_ops.post_iter, args=(self.n_iter, list(self.segments.values()))).get_result()
534
+
535
+ # Move existing segments into place as new segments
536
+ del self.segments
537
+ self.segments = {segment.seg_id: segment for segment in self.we_driver.next_iter_segments}
538
+
539
+ self.rc.pstatus("Iteration completed successfully")
540
+
541
    def get_istate_futures(self):
        '''Add as many initial states as the WE driver reports are needed
        (``self.we_driver.n_istates_needed``) to the internal pool of initial
        states available for recycled particles. Spare states are used if
        available, otherwise new states are created. If creating new initial
        states requires generation, then a set of futures is returned
        representing work manager tasks corresponding to the necessary
        generation work; otherwise the returned set is empty.'''

        n_recycled = self.we_driver.n_recycled_segs
        n_istates_needed = self.we_driver.n_istates_needed

        log.debug('{:d} unused initial states available'.format(len(self.we_driver.avail_initial_states)))
        log.debug('{:d} new initial states required for recycling {:d} walkers'.format(n_istates_needed, n_recycled))

        futures = set()
        updated_states = []
        for _i in range(n_istates_needed):
            # Select a basis state according to its weight, by locating a uniform
            # random draw within the cumulative basis-state probabilities
            ibstate = np.digitize([random.random()], self.next_iter_bstate_cprobs)
            basis_state = self.next_iter_bstates[ibstate[0]]
            initial_state = self.data_manager.create_initial_states(1, n_iter=self.n_iter + 1)[0]
            initial_state.iter_created = self.n_iter
            initial_state.basis_state_id = basis_state.state_id
            initial_state.istate_status = InitialState.ISTATE_STATUS_PENDING

            if self.do_gen_istates:
                # Generation runs asynchronously; the state remains PENDING until
                # its future completes (see propagate(), which flips it to PREPARED)
                log.debug('generating new initial state from basis state {!r}'.format(basis_state))
                initial_state.istate_type = InitialState.ISTATE_TYPE_GENERATED
                futures.add(self.work_manager.submit(wm_ops.gen_istate, args=(basis_state, initial_state)))
            else:
                # The basis state's progress coordinate is usable directly, so the
                # new initial state is immediately PREPARED and made available
                log.debug('using basis state {!r} directly'.format(basis_state))
                initial_state.istate_type = InitialState.ISTATE_TYPE_BASIS
                initial_state.pcoord = basis_state.pcoord.copy()
                initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED
                self.we_driver.avail_initial_states[initial_state.state_id] = initial_state
            updated_states.append(initial_state)
        # Persist all newly created states (pending or prepared) in one call
        self.data_manager.update_initial_states(updated_states, n_iter=self.n_iter + 1)
        return futures
577
+
578
    def propagate(self):
        '''Propagate all incomplete segments of the current iteration through the
        work manager, harvesting finished segment blocks and newly generated
        initial states as their futures complete. Results are committed to
        storage incrementally under the data manager's flushing lock.'''
        segments = list(self.incomplete_segments.values())
        log.debug('iteration {:d}: propagating {:d} segments'.format(self.n_iter, len(segments)))

        # all futures dispatched for this iteration
        futures = set()
        segment_futures = set()

        # Immediately dispatch any necessary initial state generation
        istate_gen_futures = self.get_istate_futures()
        futures.update(istate_gen_futures)

        # Dispatch propagation tasks using work manager
        for segment_block in grouper(self.propagator_block_size, segments):
            # grouper pads the final block with None; drop the padding
            segment_block = [_f for _f in segment_block if _f]
            # Ship only the basis/initial states this block actually references
            pbstates, pistates = westpa.core.states.pare_basis_initial_states(
                self.current_iter_bstates, list(self.current_iter_istates.values()), segment_block
            )
            future = self.work_manager.submit(wm_ops.propagate, args=(pbstates, pistates, segment_block))
            futures.add(future)
            segment_futures.add(future)

        while futures:
            # TODO: add capacity for timeout or SIGINT here
            future = self.work_manager.wait_any(futures)
            futures.remove(future)

            if future in segment_futures:
                # A block of segments finished propagating
                segment_futures.remove(future)
                incoming = future.get_result()
                self.n_propagated += 1

                self.segments.update({segment.seg_id: segment for segment in incoming})
                self.completed_segments.update({segment.seg_id: segment for segment in incoming})

                # Bin the finished segments; recycling may create demand for more
                # initial states, which are dispatched immediately and tracked here
                self.we_driver.assign(incoming)
                new_istate_futures = self.get_istate_futures()
                istate_gen_futures.update(new_istate_futures)
                futures.update(new_istate_futures)

                with self.data_manager.expiring_flushing_lock():
                    self.data_manager.update_segments(self.n_iter, incoming)

            elif future in istate_gen_futures:
                # An initial-state generation task finished; mark the state
                # prepared, persist it, and make it available for recycling
                istate_gen_futures.remove(future)
                _basis_state, initial_state = future.get_result()
                log.debug('received newly-prepared initial state {!r}'.format(initial_state))
                initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED
                with self.data_manager.expiring_flushing_lock():
                    self.data_manager.update_initial_states([initial_state], n_iter=self.n_iter + 1)
                self.we_driver.avail_initial_states[initial_state.state_id] = initial_state
            else:
                log.error('unknown future {!r} received from work manager'.format(future))
                raise AssertionError('untracked future {!r}'.format(future))

        log.debug('done with propagation')
        self.save_bin_data()
        self.data_manager.flush_backing()
636
+
637
+ def save_bin_data(self):
638
+ '''Calculate and write flux and transition count matrices to HDF5. Population and rate matrices
639
+ are likely useless at the single-tau level and are no longer written.'''
640
+ # save_bin_data(self, populations, n_trans, fluxes, rates, n_iter=None)
641
+
642
+ if self.save_transition_matrices:
643
+ with self.data_manager.expiring_flushing_lock():
644
+ iter_group = self.data_manager.get_iter_group(self.n_iter)
645
+ for key in ['bin_ntrans', 'bin_fluxes']:
646
+ try:
647
+ del iter_group[key]
648
+ except KeyError:
649
+ pass
650
+ iter_group['bin_ntrans'] = self.we_driver.transition_matrix
651
+ iter_group['bin_fluxes'] = self.we_driver.flux_matrix
652
+
653
+ def check_propagation(self):
654
+ '''Check for failures in propagation or initial state generation, and raise an exception
655
+ if any are found.'''
656
+
657
+ failed_segments = [segment for segment in self.segments.values() if segment.status != Segment.SEG_STATUS_COMPLETE]
658
+
659
+ if failed_segments:
660
+ failed_ids = ' \n'.join(str(segment.seg_id) for segment in failed_segments)
661
+ log.error('propagation failed for {:d} segment(s):\n{}'.format(len(failed_segments), failed_ids))
662
+ raise PropagationError('propagation failed for {:d} segments'.format(len(failed_segments)))
663
+ else:
664
+ log.debug('propagation complete for iteration {:d}'.format(self.n_iter))
665
+
666
+ failed_istates = [
667
+ istate
668
+ for istate in self.we_driver.used_initial_states.values()
669
+ if istate.istate_status != InitialState.ISTATE_STATUS_PREPARED
670
+ ]
671
+ log.debug('{!r}'.format(failed_istates))
672
+ if failed_istates:
673
+ failed_ids = ' \n'.join(str(istate.state_id) for istate in failed_istates)
674
+ log.error('initial state generation failed for {:d} states:\n{}'.format(len(failed_istates), failed_ids))
675
+ raise PropagationError('initial state generation failed for {:d} states'.format(len(failed_istates)))
676
+ else:
677
+ log.debug('initial state generation complete for iteration {:d}'.format(self.n_iter))
678
+
679
    def run_we(self):
        '''Run the weighted ensemble algorithm based on the binning in self.final_bins and
        the recycled particles in self.to_recycle, creating and committing the next iteration's
        segments to storage as well.'''

        # The WE driver now does almost everything; we just have to record the
        # mapper used for binning this iteration, and update initial states
        # that have been used

        try:
            pickled, hashed = self.we_driver.bin_mapper.pickle_and_hash()
        except PickleError:
            # An unpicklable mapper is recorded with empty pickle/hash strings
            pickled = hashed = ''

        self.bin_mapper_hash = hashed
        self.we_driver.construct_next()

        # Mark initial states consumed by recycling as used in the next iteration
        if self.we_driver.used_initial_states:
            for initial_state in self.we_driver.used_initial_states.values():
                initial_state.iter_used = self.n_iter + 1
            self.data_manager.update_initial_states(list(self.we_driver.used_initial_states.values()))

        self.data_manager.update_segments(self.n_iter, list(self.segments.values()))

        self.data_manager.require_iter_group(self.n_iter + 1)
        self.data_manager.save_iter_binning(self.n_iter + 1, hashed, pickled, self.we_driver.bin_target_counts)

        # Report on recycling: group recycled weights by the target state that
        # absorbed them
        recycling_events = {}
        for nw in self.we_driver.new_weights:
            try:
                recycling_events[nw.target_state_id].append(nw.weight)
            except KeyError:
                recycling_events[nw.target_state_id] = list([nw.weight])

        tstates_by_id = {state.state_id: state for state in self.we_driver.target_states.values()}

        for tstate_id, weights in recycling_events.items():
            tstate = tstates_by_id[tstate_id]
            self.rc.pstatus(
                'Recycled {:g} probability ({:d} walkers) from target state {!r}'.format(sum(weights), len(weights), tstate.label)
            )
721
+
722
    def prepare_new_iteration(self):
        '''Commit data for the coming iteration to the HDF5 file.'''
        self.invoke_callbacks(self.prepare_new_iteration)

        if self.rc.debug_mode:
            self.rc.pstatus('\nSegments generated:')
            for segment in self.we_driver.next_iter_segments:
                self.rc.pstatus('{!r} pcoord[0]={!r}'.format(segment, segment.pcoord[0]))

        # Persist the next iteration's segments and any recycled-walker
        # ("new weight") records produced by the WE driver
        self.data_manager.prepare_iteration(self.n_iter + 1, list(self.we_driver.next_iter_segments))
        self.data_manager.save_new_weight_data(self.n_iter + 1, self.we_driver.new_weights)
733
+
734
    def run(self):
        '''Run the main weighted-ensemble loop, one iteration at a time, until the
        requested number of iterations completes or the next iteration is
        projected to exceed the wallclock budget.'''
        run_starttime = time.time()
        max_walltime = self.max_run_walltime
        if max_walltime:
            run_killtime = run_starttime + max_walltime
            self.rc.pstatus('Maximum wallclock time: %s' % timedelta(seconds=max_walltime or 0))
        else:
            run_killtime = None

        self.n_iter = self.data_manager.current_iteration
        max_iter = self.max_total_iterations or self.n_iter + 1

        iter_elapsed = 0
        while self.n_iter <= max_iter:
            # Projection: assume the next iteration takes ~1.1x as long as the
            # last; stop cleanly rather than be cut off mid-iteration
            if max_walltime and time.time() + 1.1 * iter_elapsed >= run_killtime:
                self.rc.pstatus('Iteration {:d} would require more than the allotted time. Ending run.'.format(self.n_iter))
                return

            try:
                iter_start_time = time.time()

                self.rc.pstatus('\n%s' % time.asctime())
                self.rc.pstatus('Iteration %d (%d requested)' % (self.n_iter, max_iter))

                self.prepare_iteration()
                self.rc.pflush()

                self.pre_propagation()
                self.propagate()
                self.rc.pflush()
                self.check_propagation()
                self.rc.pflush()
                self.post_propagation()

                # Total CPU time consumed by all segments this iteration
                cputime = sum(segment.cputime for segment in self.segments.values())

                self.rc.pflush()
                self.pre_we()
                self.run_we()
                self.post_we()
                self.rc.pflush()

                self.prepare_new_iteration()

                self.finalize_iteration()

                # Accumulate timing into the per-iteration summary record
                iter_elapsed = time.time() - iter_start_time
                iter_summary = self.data_manager.get_iter_summary()
                iter_summary['walltime'] += iter_elapsed
                iter_summary['cputime'] = cputime
                self.data_manager.update_iter_summary(iter_summary)

                self.n_iter += 1
                self.data_manager.current_iteration += 1

                try:
                    # This may give NaN if starting a truncated simulation
                    walltime = timedelta(seconds=float(iter_summary['walltime']))
                except ValueError:
                    walltime = 0.0

                try:
                    cputime = timedelta(seconds=float(iter_summary['cputime']))
                except ValueError:
                    cputime = 0.0

                self.rc.pstatus('Iteration wallclock: {0!s}, cputime: {1!s}\n'.format(walltime, cputime))
                self.rc.pflush()
            finally:
                # Always flush HDF5 backing, even if the iteration raised
                self.data_manager.flush_backing()

        self.rc.pstatus('\n%s' % time.asctime())
        self.rc.pstatus('WEST run complete.')
807
+
808
    def prepare_run(self):
        '''Prepare a new run.'''
        # Ready storage and the system driver before firing user callbacks
        self.data_manager.prepare_run()
        self.system.prepare_run()
        self.invoke_callbacks(self.prepare_run)
813
+
814
    def finalize_run(self):
        '''Perform cleanup at the normal end of a run'''
        # Unwind in reverse order of prepare_run(): callbacks, system, storage
        self.invoke_callbacks(self.finalize_run)
        self.system.finalize_run()
        self.data_manager.finalize_run()
819
+
820
    def pre_propagation(self):
        '''Invoke callbacks registered for the pre-propagation hook.'''
        self.invoke_callbacks(self.pre_propagation)
822
+
823
    def post_propagation(self):
        '''Invoke callbacks registered for the post-propagation hook.'''
        self.invoke_callbacks(self.post_propagation)
825
+
826
    def pre_we(self):
        '''Invoke callbacks registered for the pre-WE (before resampling) hook.'''
        self.invoke_callbacks(self.pre_we)
828
+
829
    def post_we(self):
        '''Invoke callbacks registered for the post-WE (after resampling) hook.'''
        self.invoke_callbacks(self.post_we)