westpa 2022.12__cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of westpa might be problematic. Click here for more details.

Files changed (149) hide show
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +746 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +360 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +596 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +547 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +833 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_trace.py +599 -0
  38. westpa/core/__init__.py +0 -0
  39. westpa/core/_rc.py +673 -0
  40. westpa/core/binning/__init__.py +55 -0
  41. westpa/core/binning/_assign.cpython-313-x86_64-linux-gnu.so +0 -0
  42. westpa/core/binning/assign.py +455 -0
  43. westpa/core/binning/binless.py +96 -0
  44. westpa/core/binning/binless_driver.py +54 -0
  45. westpa/core/binning/binless_manager.py +190 -0
  46. westpa/core/binning/bins.py +47 -0
  47. westpa/core/binning/mab.py +506 -0
  48. westpa/core/binning/mab_driver.py +54 -0
  49. westpa/core/binning/mab_manager.py +198 -0
  50. westpa/core/data_manager.py +1694 -0
  51. westpa/core/extloader.py +74 -0
  52. westpa/core/h5io.py +995 -0
  53. westpa/core/kinetics/__init__.py +24 -0
  54. westpa/core/kinetics/_kinetics.cpython-313-x86_64-linux-gnu.so +0 -0
  55. westpa/core/kinetics/events.py +147 -0
  56. westpa/core/kinetics/matrates.py +156 -0
  57. westpa/core/kinetics/rate_averaging.py +266 -0
  58. westpa/core/progress.py +218 -0
  59. westpa/core/propagators/__init__.py +54 -0
  60. westpa/core/propagators/executable.py +719 -0
  61. westpa/core/reweight/__init__.py +14 -0
  62. westpa/core/reweight/_reweight.cpython-313-x86_64-linux-gnu.so +0 -0
  63. westpa/core/reweight/matrix.py +126 -0
  64. westpa/core/segment.py +119 -0
  65. westpa/core/sim_manager.py +835 -0
  66. westpa/core/states.py +359 -0
  67. westpa/core/systems.py +93 -0
  68. westpa/core/textio.py +74 -0
  69. westpa/core/trajectory.py +330 -0
  70. westpa/core/we_driver.py +910 -0
  71. westpa/core/wm_ops.py +43 -0
  72. westpa/core/yamlcfg.py +391 -0
  73. westpa/fasthist/__init__.py +34 -0
  74. westpa/fasthist/_fasthist.cpython-313-x86_64-linux-gnu.so +0 -0
  75. westpa/mclib/__init__.py +271 -0
  76. westpa/mclib/__main__.py +28 -0
  77. westpa/mclib/_mclib.cpython-313-x86_64-linux-gnu.so +0 -0
  78. westpa/oldtools/__init__.py +4 -0
  79. westpa/oldtools/aframe/__init__.py +35 -0
  80. westpa/oldtools/aframe/atool.py +75 -0
  81. westpa/oldtools/aframe/base_mixin.py +26 -0
  82. westpa/oldtools/aframe/binning.py +178 -0
  83. westpa/oldtools/aframe/data_reader.py +560 -0
  84. westpa/oldtools/aframe/iter_range.py +200 -0
  85. westpa/oldtools/aframe/kinetics.py +117 -0
  86. westpa/oldtools/aframe/mcbs.py +153 -0
  87. westpa/oldtools/aframe/output.py +39 -0
  88. westpa/oldtools/aframe/plotting.py +90 -0
  89. westpa/oldtools/aframe/trajwalker.py +126 -0
  90. westpa/oldtools/aframe/transitions.py +469 -0
  91. westpa/oldtools/cmds/__init__.py +0 -0
  92. westpa/oldtools/cmds/w_ttimes.py +361 -0
  93. westpa/oldtools/files.py +34 -0
  94. westpa/oldtools/miscfn.py +23 -0
  95. westpa/oldtools/stats/__init__.py +4 -0
  96. westpa/oldtools/stats/accumulator.py +35 -0
  97. westpa/oldtools/stats/edfs.py +129 -0
  98. westpa/oldtools/stats/mcbs.py +96 -0
  99. westpa/tools/__init__.py +33 -0
  100. westpa/tools/binning.py +472 -0
  101. westpa/tools/core.py +340 -0
  102. westpa/tools/data_reader.py +159 -0
  103. westpa/tools/dtypes.py +31 -0
  104. westpa/tools/iter_range.py +198 -0
  105. westpa/tools/kinetics_tool.py +340 -0
  106. westpa/tools/plot.py +283 -0
  107. westpa/tools/progress.py +17 -0
  108. westpa/tools/selected_segs.py +154 -0
  109. westpa/tools/wipi.py +751 -0
  110. westpa/trajtree/__init__.py +4 -0
  111. westpa/trajtree/_trajtree.cpython-313-x86_64-linux-gnu.so +0 -0
  112. westpa/trajtree/trajtree.py +117 -0
  113. westpa/westext/__init__.py +0 -0
  114. westpa/westext/adaptvoronoi/__init__.py +3 -0
  115. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  116. westpa/westext/hamsm_restarting/__init__.py +3 -0
  117. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  118. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  119. westpa/westext/stringmethod/__init__.py +11 -0
  120. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  121. westpa/westext/stringmethod/string_driver.py +253 -0
  122. westpa/westext/stringmethod/string_method.py +306 -0
  123. westpa/westext/weed/BinCluster.py +180 -0
  124. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  125. westpa/westext/weed/UncertMath.py +247 -0
  126. westpa/westext/weed/__init__.py +10 -0
  127. westpa/westext/weed/weed_driver.py +192 -0
  128. westpa/westext/wess/ProbAdjust.py +101 -0
  129. westpa/westext/wess/__init__.py +6 -0
  130. westpa/westext/wess/wess_driver.py +217 -0
  131. westpa/work_managers/__init__.py +57 -0
  132. westpa/work_managers/core.py +396 -0
  133. westpa/work_managers/environment.py +134 -0
  134. westpa/work_managers/mpi.py +318 -0
  135. westpa/work_managers/processes.py +187 -0
  136. westpa/work_managers/serial.py +28 -0
  137. westpa/work_managers/threads.py +79 -0
  138. westpa/work_managers/zeromq/__init__.py +20 -0
  139. westpa/work_managers/zeromq/core.py +641 -0
  140. westpa/work_managers/zeromq/node.py +131 -0
  141. westpa/work_managers/zeromq/work_manager.py +526 -0
  142. westpa/work_managers/zeromq/worker.py +320 -0
  143. westpa-2022.12.dist-info/AUTHORS +22 -0
  144. westpa-2022.12.dist-info/LICENSE +21 -0
  145. westpa-2022.12.dist-info/METADATA +193 -0
  146. westpa-2022.12.dist-info/RECORD +149 -0
  147. westpa-2022.12.dist-info/WHEEL +6 -0
  148. westpa-2022.12.dist-info/entry_points.txt +29 -0
  149. westpa-2022.12.dist-info/top_level.txt +1 -0
@@ -0,0 +1,835 @@
1
+ import logging
2
+ import math
3
+ import operator
4
+ import time
5
+ from datetime import timedelta
6
+ from pickle import PickleError
7
+ from itertools import zip_longest
8
+ from collections import Counter
9
+
10
+ import numpy as np
11
+ from numpy.random import Generator, MT19937
12
+
13
+ import westpa
14
+ from .data_manager import weight_dtype
15
+ from .segment import Segment
16
+ from .states import InitialState
17
+ from . import extloader
18
+ from . import wm_ops
19
+
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+ EPS = np.finfo(weight_dtype).eps
24
+
25
+
26
def grouper(n, iterable, fillvalue=None):
    """Collect ``iterable`` into fixed-length tuples of ``n`` items.

    The final tuple is padded with ``fillvalue`` when the input runs out,
    e.g. ``grouper(3, 'ABCDEFG', 'x')`` yields ABC, DEF, Gxx.
    """
    # Every slot of the zip pulls from the *same* iterator, so each output
    # tuple consumes ``n`` consecutive items of the input.
    shared = iter(iterable)
    return zip_longest(*([shared] * n), fillvalue=fillvalue)
31
+
32
+
33
class PropagationError(RuntimeError):
    """``RuntimeError`` subclass for propagation-related failures.

    The raise sites are not visible in this part of the file.
    """
35
+
36
+
37
+ class WESimManager:
38
def process_config(self):
    '''Copy the ``west.propagation`` settings from ``self.rc.config`` onto this manager.

    ``gen_istates``, ``block_size`` and ``save_transition_matrices`` are type-checked
    (when present) via ``config.require_type_if_present``; ``max_run_wallclock`` and
    ``max_total_iterations`` are read without a type check and default to ``None``.
    '''
    cfg = self.rc.config
    section = ['west', 'propagation']

    typed_entries = (
        ('gen_istates', bool),
        ('block_size', int),
        ('save_transition_matrices', bool),
    )
    for key, expected_type in typed_entries:
        cfg.require_type_if_present(section + [key], expected_type)

    self.do_gen_istates = cfg.get(section + ['gen_istates'], False)
    self.propagator_block_size = cfg.get(section + ['block_size'], 1)
    self.save_transition_matrices = cfg.get(section + ['save_transition_matrices'], False)
    # Note: the config key is ``max_run_wallclock`` but the attribute is ``max_run_walltime``.
    self.max_run_walltime = cfg.get(section + ['max_run_wallclock'], default=None)
    self.max_total_iterations = cfg.get(section + ['max_total_iterations'], default=None)
48
+
49
def __init__(self, rc=None):
    '''Build a simulation manager bound to ``rc`` (falls back to ``westpa.rc`` when ``rc`` is falsy).

    The work manager, data manager, WE driver and system driver are all obtained
    from the run-control object. Configuration defaults are set first and then
    overridden by ``process_config()``.
    '''
    self.rc = rc or westpa.rc
    self.work_manager = self.rc.get_work_manager()
    self.data_manager = self.rc.get_data_manager()
    self.we_driver = self.rc.get_we_driver()
    self.system = self.rc.get_system_driver()

    # Maps hook (bound method) -> set of (priority, name, callback) tuples
    self._callback_table = {}
    # The bound methods that may be used as hooks in register_callback()
    self._valid_callbacks = {
        self.prepare_run,
        self.finalize_run,
        self.prepare_iteration,
        self.finalize_iteration,
        self.pre_propagation,
        self.post_propagation,
        self.pre_we,
        self.post_we,
        self.prepare_new_iteration,
    }
    # Lets callers name a hook by string instead of passing the bound method
    self._callbacks_by_name = {hook.__name__: hook for hook in self._valid_callbacks}
    self.n_propagated = 0

    # Config items: defaults first, then whatever the configuration provides
    self.do_gen_istates = False
    self.propagator_block_size = 1
    self.save_transition_matrices = False
    self.max_run_walltime = None
    self.max_total_iterations = None
    self.process_config()

    # Current iteration number
    self.n_iter = None

    # Basis states (BasisStates valid at this iteration) and initial states
    # (InitialStates used in this iteration), in case the propagator needs them
    self.current_iter_bstates = None
    self.current_iter_istates = None

    # Basis states valid for the next iteration, and their cumulative
    # probabilities (used for selection)
    self.next_iter_bstates = None
    self.next_iter_bstate_cprobs = None

    # seg_id -> segment mappings for this iteration: all, completed, incomplete
    self.segments = None
    self.completed_segments = None
    self.incomplete_segments = None

    # Hash of bin mapper from most recently-run WE, for use by post-WE analysis plugins
    self.bin_mapper_hash = None

    # Pseudo-random number generator (constructed without an explicit seed)
    self.rng = Generator(MT19937())
103
+
104
def register_callback(self, hook, function, priority=0):
    '''Registers a callback to execute during the given ``hook`` into the simulation loop. The optional
    priority is used to order when the function is called relative to other registered callbacks.

    ``hook`` may be one of the bound methods in ``self._valid_callbacks`` or the name of one.
    A ``function`` whose hash matches a callback already registered for the hook is skipped.

    Raises ``KeyError`` if ``hook`` is neither a valid hook nor the name of one.
    '''

    if hook not in self._valid_callbacks:
        try:
            hook = self._callbacks_by_name[hook]
        except KeyError:
            # ``from None``: the failed name lookup is an implementation detail, not a second error.
            raise KeyError('invalid hook {!r}'.format(hook)) from None

    # It's possible to register a callback that's a duplicate function, but at a different place in memory.
    # For example, if you launch a run without clearing state from a previous run.
    # More details on this are available in https://github.com/westpa/westpa/issues/182 but the below code
    # handles specifically the problem that causes in plugin loading.
    # Compare by function hash before relying on set membership of (priority, name, function).
    already_registered = self._callback_table.get(hook, ())
    if already_registered:
        function_hash = hash(function)
        if any(hash(entry[2]) == function_hash for entry in already_registered):
            log.info('{!r} has already been loaded, skipping'.format(function))
            return

    # Build the entry before touching the table so a failure leaves the table unchanged.
    entry = (priority, function.__name__, function)
    self._callback_table.setdefault(hook, set()).add(entry)

    # Raise warning if there are multiple callback with same priority.
    priority_counts = Counter(callback[0] for callback in self._callback_table[hook])
    for shared_priority, count in priority_counts.items():
        if count > 1:
            log.warning(
                f'{count} callbacks in {hook} have identical priority {shared_priority}. The order of callback execution is not guaranteed.'
            )
            log.warning(f'{hook}: {self._callback_table[hook]}')

    log.debug('registered callback {!r} for hook {!r}'.format(function, hook))
149
+
150
def invoke_callbacks(self, hook, *args, **kwargs):
    '''Call every callback registered for ``hook`` with ``*args`` and ``**kwargs``.

    Callbacks run in ascending order of priority, then function name, then the
    name of the module that defines them. Return values are discarded.
    '''

    def ordering(entry):
        priority, name, fn = entry
        return (priority, name, fn.__module__)

    registered = self._callback_table.get(hook, [])
    for _priority, _name, callback in sorted(registered, key=ordering):
        log.debug('invoking callback {!r} for hook {!r}'.format(callback, hook))
        callback(*args, **kwargs)
157
+
158
def load_plugins(self, plugins=None):
    '''Instantiate every enabled plugin from ``plugins`` plus the ``west.plugins`` configuration.

    Each plugin entry is a mapping with a ``'plugin'`` key (the name handed to
    ``extloader.get_object``) and an optional ``'enabled'`` key (default ``True``).
    The loaded object is called as ``factory(self, plugin_config)``; the result is
    only logged here, not stored.

    The ``plugins`` argument is never modified: a new combined list is built, so a
    list passed by the caller does not accumulate the configured plugins.
    '''
    try:
        plugins_config = westpa.rc.config['west', 'plugins']
    except KeyError:
        plugins_config = None

    requested = [] if plugins is None else list(plugins)
    configured = [] if plugins_config is None else list(plugins_config)

    # Explicitly supplied plugins come first, followed by the configured ones.
    for plugin_config in requested + configured:
        plugin_name = plugin_config['plugin']
        if plugin_config.get('enabled', True):
            log.info('loading plugin {!r}'.format(plugin_name))
            plugin = extloader.get_object(plugin_name)(self, plugin_config)
            log.debug('loaded plugin {!r}'.format(plugin))
178
+
179
def report_bin_statistics(self, bins, target_states, save_summary=False):
    '''Print occupancy and probability statistics for ``bins`` and ``self.segments``.

    ``bins`` is a sequence of bin objects that support ``len()`` and expose ``weight``.
    If ``target_states`` is non-empty, its length is subtracted from the number of
    active bins. With ``save_summary`` true, the statistics are also written into the
    data manager's iteration summary.

    An ``assert`` checks that the segment weights sum to 1 within
    ``EPS * (number of segments + number of active bins)``.
    '''
    pstatus = self.rc.pstatus
    walkers = list(self.segments.values())
    n_walkers = len(walkers)
    n_bins = len(bins)

    occupancy = np.fromiter((len(bin_) for bin_ in bins), dtype=np.int_, count=n_bins)
    target_counts = self.we_driver.bin_target_counts

    # Do not include bins with target count zero (e.g. sinks, never-filled bins) in the (non)empty bins statistics
    n_active_bins = len(target_counts[target_counts != 0])
    if target_states:
        n_active_bins -= len(target_states)

    seg_probs = np.fromiter((walker.weight for walker in walkers), dtype=weight_dtype, count=n_walkers)
    bin_probs = np.fromiter((bin_.weight for bin_ in bins), dtype=weight_dtype, count=n_bins)
    norm = seg_probs.sum()

    assert abs(1 - norm) < EPS * (n_walkers + n_active_bins)

    # Minima are taken over non-zero entries only, so the log ratios below are finite.
    min_seg_prob = seg_probs[seg_probs != 0].min()
    max_seg_prob = seg_probs.max()
    seg_drange = math.log(max_seg_prob / min_seg_prob)
    min_bin_prob = bin_probs[bin_probs != 0].min()
    max_bin_prob = bin_probs.max()
    bin_drange = math.log(max_bin_prob / min_bin_prob)
    n_pop = len(occupancy[occupancy != 0])
    norm_error = norm - 1

    pstatus('{:d} of {:d} ({:%}) active bins are populated'.format(n_pop, n_active_bins, n_pop / n_active_bins))
    pstatus('per-bin minimum non-zero probability: {:g}'.format(min_bin_prob))
    pstatus('per-bin maximum probability: {:g}'.format(max_bin_prob))
    pstatus('per-bin probability dynamic range (kT): {:g}'.format(bin_drange))
    pstatus('per-segment minimum non-zero probability: {:g}'.format(min_seg_prob))
    pstatus('per-segment maximum non-zero probability: {:g}'.format(max_seg_prob))
    pstatus('per-segment probability dynamic range (kT): {:g}'.format(seg_drange))
    pstatus('norm = {:g}, error in norm = {:g} ({:.2g}*epsilon)'.format(norm, norm_error, norm_error / EPS))
    self.rc.pflush()

    if min_seg_prob < 1e-100:
        log.warning(
            '\nMinimum segment weight is < 1e-100 and might not be physically relevant. Please reconsider your progress coordinate or binning scheme.'
        )

    if not save_summary:
        return

    iter_summary = self.data_manager.get_iter_summary()
    iter_summary['n_particles'] = n_walkers
    iter_summary['norm'] = norm
    iter_summary['min_bin_prob'] = min_bin_prob
    iter_summary['max_bin_prob'] = max_bin_prob
    iter_summary['min_seg_prob'] = min_seg_prob
    iter_summary['max_seg_prob'] = max_seg_prob
    # Timing fields that are still NaN are reset to zero before the summary is stored.
    for timing_field in ('cputime', 'walltime'):
        if np.isnan(iter_summary[timing_field]):
            iter_summary[timing_field] = 0.0
    self.data_manager.update_iter_summary(iter_summary)
232
+
233
def get_bstate_pcoords(self, basis_states, label='basis'):
    '''Calculate progress coordinate values for each of the given ``basis_states``.

    One ``wm_ops.get_pcoord`` task is submitted to the work manager per state, and
    each entry of ``basis_states`` is replaced in place by the object the task
    returns. The replacement (rather than copying a field) keeps whatever the worker
    attached to the returned state, e.g. auxdata/restart files under
    ``BasisState.data`` with the ``processes`` work manager. The HDF5 file is not
    updated. ``label`` only affects the status message.
    '''
    self.rc.pstatus('Calculating progress coordinate values for {} states.'.format(label))

    pending = []
    position_of = {}
    for position, state in enumerate(basis_states):
        task = self.work_manager.submit(wm_ops.get_pcoord, args=(state,))
        pending.append(task)
        position_of[task] = position

    # Results may arrive in any order; the map restores each one to its original slot.
    for finished in self.work_manager.as_completed(pending):
        basis_states[position_of[finished]] = finished.get_result()
244
+
245
def report_basis_states(self, basis_states, label='basis'):
    '''Print how many ``label`` states are present and, in verbose mode, a table of them.

    Each table row shows the state's ID, label, probability, aux reference (blank
    when falsy) and its progress coordinate values joined by ``', '``.
    '''
    emit = self.rc.pstatus
    emit('{:d} {} state(s) present'.format(len(basis_states), label), end='')

    if self.rc.verbose_mode:
        header_format = '{:6s} {:12s} {:20s} {:20s} {}'
        row_format = '{:<6d} {:12s} {:<20.14g} {:20s} {}'

        emit(':')
        emit(header_format.format('ID', 'Label', 'Probability', 'Aux Reference', 'Progress Coordinate'))
        for state in basis_states:
            pcoord_text = ', '.join(str(value) for value in state.pcoord)
            emit(row_format.format(state.state_id, state.label, state.probability, state.auxref or '', pcoord_text))

    # Terminates the line left open by ``end=''`` above.
    emit()
    self.rc.pflush()
267
+
268
def report_target_states(self, target_states):
    '''Print how many target states are present and, in verbose mode, a table of them.

    The table is printed only when verbose mode is on and ``target_states`` is
    non-empty. Progress coordinate values are joined by ``','``.
    '''
    emit = self.rc.pstatus
    emit('{:d} target state(s) present'.format(len(target_states)), end='')

    show_table = self.rc.verbose_mode and target_states
    if show_table:
        row_format = '{:<6d} {:12s} {}'

        emit(':')
        emit('{:6s} {:12s} {}'.format('ID', 'Label', 'Progress Coordinate'))
        for state in target_states:
            pcoord_text = ','.join(str(value) for value in state.pcoord)
            emit(row_format.format(state.state_id, state.label, pcoord_text))

    # Terminates the line left open by ``end=''`` above.
    emit()
    self.rc.pflush()
282
+
283
def initialize_simulation(self, basis_states, target_states, start_states, segs_per_state=1, suppress_we=False):
    '''Initialize a new weighted ensemble simulation, taking ``segs_per_state`` initial
    states from each of the given ``basis_states`` (and from each of ``start_states``).

    ``w_init`` is the forward-facing version of this function.

    The method creates the HDF5 backing, saves the target and basis states, builds one
    initial state per (state, replica) pair with weight ``probability / segs_per_state``,
    prepares iteration 1 in the data manager, prints a summary and closes the backing.

    With ``suppress_we`` true, ``we_driver.populate_initial`` is not called; the WE
    driver's existing ``current_iter_segments`` and ``final_binning`` are used instead.
    '''

    data_manager = self.data_manager
    work_manager = self.work_manager
    pstatus = self.rc.pstatus
    system = self.system

    pstatus('Creating HDF5 file {!r}'.format(self.data_manager.we_h5filename))
    data_manager.prepare_backing()

    # Process target states
    data_manager.save_target_states(target_states)
    self.report_target_states(target_states)

    # Process basis states
    self.get_bstate_pcoords(basis_states)
    self.data_manager.create_ibstate_group(basis_states)
    self.data_manager.create_ibstate_iter_h5file(basis_states)
    self.report_basis_states(basis_states)

    # Process start states
    # Unlike the above, does not create an ibstate group.
    # TODO: Should it? I don't think so, if needed it can be traced back through basis_auxref

    # Here, we are trying to assign a state_id to the start state to be initialized, without actually
    # saving it to the ibstates records in any of the h5 files. It might actually be a problem
    # when tracing trajectories with westpa.analysis (especially with HDF5 framework) since it would
    # try to look for a basis state id > len(basis_state).
    # Since start states are only used while initializing and iteration 1, it's ok to not save it to save space. If necessary,
    # the structure can be traced directly to the parent file using the standard basis state logic referencing
    # west[iterations/iter_00000001/ibstates/istate_index/basis_auxref] of that istate.

    # Only the first start state is inspected: if it has no ID, every start state is
    # (re)numbered consecutively after the last basis state's ID.
    if len(start_states) > 0 and start_states[0].state_id is None:
        last_id = basis_states[-1].state_id
        for start_state in start_states:
            start_state.state_id = last_id + 1
            last_id += 1

    self.get_bstate_pcoords(start_states, label='start')
    self.report_basis_states(start_states, label='start')

    pstatus('Preparing initial states')
    initial_states = []
    weights = []
    if self.do_gen_istates:
        istate_type = InitialState.ISTATE_TYPE_GENERATED
    else:
        istate_type = InitialState.ISTATE_TYPE_BASIS

    # One initial state per basis state per requested replica; each carries an equal
    # share of its basis state's probability.
    for basis_state in basis_states:
        for _iseg in range(segs_per_state):
            initial_state = data_manager.create_initial_states(1, 1)[0]
            initial_state.basis_state_id = basis_state.state_id
            initial_state.basis_state = basis_state
            initial_state.istate_type = istate_type
            weights.append(basis_state.probability / segs_per_state)
            initial_states.append(initial_state)

    for start_state in start_states:
        for _iseg in range(segs_per_state):
            initial_state = data_manager.create_initial_states(1, 1)[0]
            initial_state.basis_state_id = start_state.state_id
            initial_state.basis_state = start_state
            initial_state.basis_auxref = start_state.auxref

            # Start states are assigned their own type, so they can be identified later
            initial_state.istate_type = InitialState.ISTATE_TYPE_START
            weights.append(start_state.probability / segs_per_state)
            initial_state.iter_used = 1
            initial_states.append(initial_state)

    if self.do_gen_istates:
        futures = [
            work_manager.submit(wm_ops.gen_istate, args=(initial_state.basis_state, initial_state))
            for initial_state in initial_states
        ]
        for future in work_manager.as_completed(futures):
            rbstate, ristate = future.get_result()
            # The returned state's ID is used as an index into ``initial_states``.
            # NOTE(review): this assumes state IDs equal list positions here — confirm.
            initial_states[ristate.state_id].pcoord = ristate.pcoord
    else:
        # No generation step: each initial state takes its basis state's pcoord directly.
        for initial_state in initial_states:
            basis_state = initial_state.basis_state
            initial_state.pcoord = basis_state.pcoord
            initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED

    for initial_state in initial_states:
        log.debug('initial state created: {!r}'.format(initial_state))

    # save list of initial states just generated
    # some of these may not be used, depending on how WE shakes out
    data_manager.update_initial_states(initial_states, n_iter=1)

    if not suppress_we:
        self.we_driver.populate_initial(initial_states, weights, system)
        segments = list(self.we_driver.next_iter_segments)
        binning = self.we_driver.next_iter_binning
    else:
        segments = list(self.we_driver.current_iter_segments)
        binning = self.we_driver.final_binning

    bin_occupancies = np.fromiter(map(len, binning), dtype=np.uint, count=self.we_driver.bin_mapper.nbins)
    target_occupancies = np.require(self.we_driver.bin_target_counts, dtype=np.uint)

    # total_bins/replicas defined here to remove target state bin from "active" bins
    total_bins = len(bin_occupancies) - len(target_states)
    total_replicas = int(sum(target_occupancies)) - int(self.we_driver.bin_target_counts[-1]) * len(target_states)

    # Mark every segment as a prepared iteration-1 segment and check that it starts
    # from an initial state (negative parent ID) flagged as used in iteration 1.
    for segment in segments:
        segment.n_iter = 1
        segment.status = Segment.SEG_STATUS_PREPARED
        assert segment.parent_id < 0
        assert initial_states[segment.initial_state_id].iter_used == 1

    data_manager.prepare_iteration(1, segments)
    data_manager.update_initial_states(initial_states, n_iter=1)

    if self.rc.verbose_mode:
        pstatus('\nSegments generated:')
        for segment in segments:
            pstatus('{!r}'.format(segment))

    pstatus(
        '''
Total bins: {total_bins:d}
Initial replicas: {init_replicas:d} in {occ_bins:d} bins, total weight = {weight:g}
Total target replicas: {total_replicas:d}
'''.format(
            total_bins=total_bins,
            init_replicas=int(sum(bin_occupancies)),
            occ_bins=len(bin_occupancies[bin_occupancies > 0]),
            weight=float(sum(segment.weight for segment in segments)),
            total_replicas=total_replicas,
        )
    )

    total_prob = float(sum(segment.weight for segment in segments))
    pstatus(f'1-prob: {1 - total_prob:.4e}')

    target_counts = self.we_driver.bin_target_counts
    # Do not include bins with target count zero (e.g. sinks, never-filled bins) in the (non)empty bins statistics
    n_active_bins = len(target_counts[target_counts != 0])
    seg_probs = np.fromiter(map(operator.attrgetter('weight'), segments), dtype=weight_dtype, count=len(segments))
    norm = seg_probs.sum()

    # Same tolerance as the assertion in report_bin_statistics(); renormalizing here
    # keeps that assertion from failing on the freshly created segments.
    if not abs(1 - norm) < EPS * (len(segments) + n_active_bins):
        pstatus("Normalization check failed at w_init, explicitly renormalizing")
        for segment in segments:
            segment.weight /= norm

    # Send the segments over to the data manager to commit to disk
    data_manager.current_iteration = 1

    # Report statistics
    pstatus('Simulation prepared.')
    self.segments = {segment.seg_id: segment for segment in segments}
    self.report_bin_statistics(binning, target_states, save_summary=True)
    data_manager.flush_backing()
    data_manager.close_backing()
446
+
447
def prepare_iteration(self):
    '''Load and organize everything needed to run iteration ``self.n_iter``.

    Starts a new iteration in the WE driver, loads (or reuses) this iteration's
    segments and splits them into completed and incomplete, reports bin statistics
    for the segments' initial points, loads the next iteration's basis states and
    unused initial states, invokes the ``prepare_iteration`` callbacks and finally
    runs ``wm_ops.prep_iter`` through the work manager, waiting for its result.
    '''
    log.debug('beginning iteration {:d}'.format(self.n_iter))

    # the WE driver needs a list of all target states for this iteration
    # along with information about any new weights introduced (e.g. by recycling)
    target_states = self.data_manager.get_target_states(self.n_iter)
    new_weights = self.data_manager.get_new_weight_data(self.n_iter)

    self.we_driver.new_iteration(target_states=target_states, new_weights=new_weights)

    # Get basis states used in this iteration
    self.current_iter_bstates = self.data_manager.get_basis_states(self.n_iter)

    # Get the segments for this iteration and separate into complete and incomplete
    if self.segments is None:
        segments = self.segments = {segment.seg_id: segment for segment in self.data_manager.get_segments()}
        log.debug('loaded {:d} segments'.format(len(segments)))
    else:
        segments = self.segments
        log.debug('using {:d} pre-existing segments'.format(len(segments)))

    completed_segments = self.completed_segments = {}
    incomplete_segments = self.incomplete_segments = {}
    for segment in segments.values():
        if segment.status == Segment.SEG_STATUS_COMPLETE:
            completed_segments[segment.seg_id] = segment
        else:
            incomplete_segments[segment.seg_id] = segment
    log.debug('{:d} segments are complete; {:d} are incomplete'.format(len(completed_segments), len(incomplete_segments)))

    if len(incomplete_segments) == len(segments):
        # Starting a new iteration
        self.rc.pstatus('Beginning iteration {:d}'.format(self.n_iter))
    elif incomplete_segments:
        # Some, but not all, segments are already complete: resuming a partial iteration
        self.rc.pstatus('Continuing iteration {:d}'.format(self.n_iter))
    self.rc.pstatus(
        '{:d} segments remain in iteration {:d} ({:d} total)'.format(len(incomplete_segments), self.n_iter, len(segments))
    )

    # Get the initial states active for this iteration (so that the propagator has them if necessary)
    self.current_iter_istates = {
        state.state_id: state for state in self.data_manager.get_segment_initial_states(list(segments.values()))
    }
    log.debug('This iteration uses {:d} initial states'.format(len(self.current_iter_istates)))

    # Assign this iteration's segments' initial points to bins and report on bin population
    initial_pcoords = self.system.new_pcoord_array(len(segments))
    initial_binning = self.system.bin_mapper.construct_bins()
    for iseg, segment in enumerate(segments.values()):
        # Only the first pcoord entry (the segment's starting point) is binned here
        initial_pcoords[iseg] = segment.pcoord[0]
    initial_assignments = self.system.bin_mapper.assign(initial_pcoords)
    for segment, assignment in zip(iter(segments.values()), initial_assignments):
        initial_binning[assignment].add(segment)
    self.report_bin_statistics(initial_binning, [], save_summary=True)
    del initial_pcoords, initial_binning

    self.rc.pstatus('Waiting for segments to complete...')

    # Let the WE driver assign completed segments
    if completed_segments:
        self.we_driver.assign(list(completed_segments.values()))

    # load restart data
    self.data_manager.prepare_segment_restarts(
        incomplete_segments.values(), self.current_iter_bstates, self.current_iter_istates
    )

    # Get the basis states and initial states for the next iteration, necessary for doing on-the-fly recycling
    self.next_iter_bstates = self.data_manager.get_basis_states(self.n_iter + 1)
    # Running sum of the basis state probabilities; get_istate_futures() draws against it
    self.next_iter_bstate_cprobs = np.add.accumulate([bstate.probability for bstate in self.next_iter_bstates])

    self.we_driver.avail_initial_states = {
        istate.state_id: istate for istate in self.data_manager.get_unused_initial_states(n_iter=self.n_iter + 1)
    }
    log.debug('{:d} unused initial states found'.format(len(self.we_driver.avail_initial_states)))

    # Invoke callbacks
    self.invoke_callbacks(self.prepare_iteration)

    # dispatch and immediately wait on result for prep_iter
    log.debug('dispatching propagator prep_iter to work manager')
    self.work_manager.submit(wm_ops.prep_iter, args=(self.n_iter, segments)).get_result()
529
+
530
def finalize_iteration(self):
    '''Clean up after an iteration and prepare for the next.

    Invokes the ``finalize_iteration`` callbacks, runs ``wm_ops.post_iter`` through
    the work manager (blocking on its result), then replaces ``self.segments`` with
    the WE driver's ``next_iter_segments`` keyed by ``seg_id``.
    '''
    log.debug('finalizing iteration {:d}'.format(self.n_iter))

    self.invoke_callbacks(self.finalize_iteration)

    # dispatch and immediately wait on result for post_iter
    log.debug('dispatching propagator post_iter to work manager')
    finished_segments = list(self.segments.values())
    post_iter_task = self.work_manager.submit(wm_ops.post_iter, args=(self.n_iter, finished_segments))
    post_iter_task.get_result()

    # Drop this iteration's segments, then install the next iteration's in their place
    del self.segments
    upcoming = {}
    for segment in self.we_driver.next_iter_segments:
        upcoming[segment.seg_id] = segment
    self.segments = upcoming

    self.rc.pstatus("Iteration completed successfully")
545
+
546
def get_istate_futures(self):
    '''Create the ``we_driver.n_istates_needed`` initial states required for recycling.

    For each one, a basis state is drawn from ``self.next_iter_bstates`` using
    ``self.next_iter_bstate_cprobs`` and a new initial state is created for iteration
    ``self.n_iter + 1``. When initial state generation is enabled, a
    ``wm_ops.gen_istate`` task is submitted per state and the set of those futures is
    returned. Otherwise the basis state's pcoord is copied, the state is marked
    prepared and added to ``we_driver.avail_initial_states``, and the returned set is
    empty. All created states are passed to ``data_manager.update_initial_states``.
    '''

    recycled_count = self.we_driver.n_recycled_segs
    needed_count = self.we_driver.n_istates_needed

    log.debug('{:d} unused initial states available'.format(len(self.we_driver.avail_initial_states)))
    log.debug('{:d} new initial states required for recycling {:d} walkers'.format(needed_count, recycled_count))

    pending = set()
    created = []
    for _ in range(needed_count):
        # Select a basis state according to its weight
        draw = self.rng.random()
        chosen_index = np.digitize([draw], self.next_iter_bstate_cprobs)[0]
        source_state = self.next_iter_bstates[chosen_index]

        new_state = self.data_manager.create_initial_states(1, n_iter=self.n_iter + 1)[0]
        new_state.iter_created = self.n_iter
        new_state.basis_state_id = source_state.state_id
        new_state.istate_status = InitialState.ISTATE_STATUS_PENDING

        if self.do_gen_istates:
            # Generation is asynchronous; this branch leaves the state pending
            log.debug('generating new initial state from basis state {!r}'.format(source_state))
            new_state.istate_type = InitialState.ISTATE_TYPE_GENERATED
            task = self.work_manager.submit(wm_ops.gen_istate, args=(source_state, new_state))
            pending.add(task)
        else:
            log.debug('using basis state {!r} directly'.format(source_state))
            new_state.istate_type = InitialState.ISTATE_TYPE_BASIS
            new_state.pcoord = source_state.pcoord.copy()
            new_state.istate_status = InitialState.ISTATE_STATUS_PREPARED
            self.we_driver.avail_initial_states[new_state.state_id] = new_state
        created.append(new_state)

    self.data_manager.update_initial_states(created, n_iter=self.n_iter + 1)
    return pending
582
+
583
def propagate(self):
    '''Propagate this iteration's incomplete segments through the work manager.

    Segments are submitted in blocks of ``self.propagator_block_size``.  Initial
    state generation tasks (from ``get_istate_futures``) are dispatched alongside
    them, and both kinds of future are consumed from a single wait loop as they
    finish.  Completed segments are merged into ``self.segments`` and
    ``self.completed_segments``, handed to the WE driver via ``assign``, and
    written through the data manager.  Once every future has been handled, bin
    data is saved and the backing store is flushed.

    Raises ``AssertionError`` if the work manager returns a future that is in
    neither tracking set.'''
    segments = list(self.incomplete_segments.values())
    log.debug('iteration {:d}: propagating {:d} segments'.format(self.n_iter, len(segments)))

    # all futures dispatched for this iteration
    futures = set()
    # subset of ``futures`` that are segment propagation tasks
    segment_futures = set()

    # Immediately dispatch any necessary initial state generation
    istate_gen_futures = self.get_istate_futures()
    futures.update(istate_gen_futures)

    # Dispatch propagation tasks using work manager
    for segment_block in grouper(self.propagator_block_size, segments):
        # Drop falsy entries (presumably fill values padding the final block -- confirm against grouper)
        segment_block = [_f for _f in segment_block if _f]
        pbstates, pistates = westpa.core.states.pare_basis_initial_states(
            self.current_iter_bstates, list(self.current_iter_istates.values()), segment_block
        )
        future = self.work_manager.submit(wm_ops.propagate, args=(pbstates, pistates, segment_block))
        futures.add(future)
        segment_futures.add(future)

    while futures:
        # TODO: add capacity for timeout or SIGINT here
        future = self.work_manager.wait_any(futures)
        futures.remove(future)

        if future in segment_futures:
            segment_futures.remove(future)
            incoming = future.get_result()
            # Incremented once per completed propagation task (one block), not once per segment
            self.n_propagated += 1

            self.segments.update({segment.seg_id: segment for segment in incoming})
            self.completed_segments.update({segment.seg_id: segment for segment in incoming})

            # Assigning the new segments may raise the number of initial states needed,
            # so request more and track any resulting generation futures
            self.we_driver.assign(incoming)
            new_istate_futures = self.get_istate_futures()
            istate_gen_futures.update(new_istate_futures)
            futures.update(new_istate_futures)

            with self.data_manager.expiring_flushing_lock():
                self.data_manager.update_segments(self.n_iter, incoming)

        elif future in istate_gen_futures:
            istate_gen_futures.remove(future)
            _basis_state, initial_state = future.get_result()
            log.debug('received newly-prepared initial state {!r}'.format(initial_state))
            initial_state.istate_status = InitialState.ISTATE_STATUS_PREPARED
            with self.data_manager.expiring_flushing_lock():
                self.data_manager.update_initial_states([initial_state], n_iter=self.n_iter + 1)
            # Make the prepared state available to the WE driver
            self.we_driver.avail_initial_states[initial_state.state_id] = initial_state
        else:
            # A future that belongs to neither tracking set indicates a bookkeeping error
            log.error('unknown future {!r} received from work manager'.format(future))
            raise AssertionError('untracked future {!r}'.format(future))

    log.debug('done with propagation')
    self.save_bin_data()
    self.data_manager.flush_backing()
641
+
642
def save_bin_data(self):
    '''Store the WE driver's transition-count and flux matrices in the current iteration group.

    Nothing is written unless ``self.save_transition_matrices`` is true.  Any
    existing ``bin_ntrans`` / ``bin_fluxes`` entries are removed first, then both
    are assigned from ``self.we_driver``.  Population and rate matrices are likely
    useless at the single-tau level and are no longer written.'''
    # former signature: save_bin_data(self, populations, n_trans, fluxes, rates, n_iter=None)

    if not self.save_transition_matrices:
        return

    with self.data_manager.expiring_flushing_lock():
        group = self.data_manager.get_iter_group(self.n_iter)

        # Remove stale entries; a missing entry is not an error
        for name in ('bin_ntrans', 'bin_fluxes'):
            try:
                del group[name]
            except KeyError:
                continue

        group['bin_ntrans'] = self.we_driver.transition_matrix
        group['bin_fluxes'] = self.we_driver.flux_matrix
657
+
658
def check_propagation(self):
    '''Verify that propagation and initial state generation both succeeded.

    Raises ``PropagationError`` if any segment in ``self.segments`` does not have
    status ``Segment.SEG_STATUS_COMPLETE``, or if any of the WE driver's used
    initial states does not have status ``InitialState.ISTATE_STATUS_PREPARED``.
    Segments are checked first.'''

    unfinished = [seg for seg in self.segments.values() if seg.status != Segment.SEG_STATUS_COMPLETE]
    if unfinished:
        n_unfinished = len(unfinished)
        id_listing = ' \n'.join(map(str, (seg.seg_id for seg in unfinished)))
        log.error('propagation failed for {:d} segment(s):\n{}'.format(n_unfinished, id_listing))
        raise PropagationError('propagation failed for {:d} segments'.format(n_unfinished))
    log.debug('propagation complete for iteration {:d}'.format(self.n_iter))

    unprepared = [
        state
        for state in self.we_driver.used_initial_states.values()
        if state.istate_status != InitialState.ISTATE_STATUS_PREPARED
    ]
    log.debug('{!r}'.format(unprepared))
    if unprepared:
        n_unprepared = len(unprepared)
        id_listing = ' \n'.join(map(str, (state.state_id for state in unprepared)))
        log.error('initial state generation failed for {:d} states:\n{}'.format(n_unprepared, id_listing))
        raise PropagationError('initial state generation failed for {:d} states'.format(n_unprepared))
    log.debug('initial state generation complete for iteration {:d}'.format(self.n_iter))
683
+
684
def run_we(self):
    '''Have the WE driver construct the next iteration and record the results.

    The driver's bin mapper is pickled and hashed (both fall back to ``''`` on
    ``PickleError``), ``construct_next`` is called, used initial states are stamped
    with the iteration in which they are used, this iteration's segments are
    updated in storage, and the binning for the next iteration is saved.  Finally
    a status line is printed per target state summarizing recycled probability.'''

    driver = self.we_driver

    # The WE driver does almost everything; here we only record the mapper used
    # for binning and update the initial states that were consumed
    try:
        pickled, hashed = driver.bin_mapper.pickle_and_hash()
    except PickleError:
        pickled = hashed = ''

    self.bin_mapper_hash = hashed
    driver.construct_next()

    next_iter = self.n_iter + 1

    consumed_states = list(driver.used_initial_states.values())
    if consumed_states:
        for state in consumed_states:
            state.iter_used = next_iter
        self.data_manager.update_initial_states(consumed_states)

    current_segments = list(self.segments.values())
    self.data_manager.update_segments(self.n_iter, current_segments)

    self.data_manager.require_iter_group(next_iter)
    self.data_manager.save_iter_binning(next_iter, hashed, pickled, driver.bin_target_counts)

    # Report on recycling: group recycled weights by the target state they came from
    weights_by_target = {}
    for new_weight in driver.new_weights:
        weights_by_target.setdefault(new_weight.target_state_id, []).append(new_weight.weight)

    targets = {}
    for target in driver.target_states.values():
        targets[target.state_id] = target

    for target_id, recycled in weights_by_target.items():
        label = targets[target_id].label
        self.rc.pstatus(
            'Recycled {:g} probability ({:d} walkers) from target state {!r}'.format(sum(recycled), len(recycled), label)
        )
726
+
727
def prepare_new_iteration(self):
    '''Hand the upcoming iteration's segments and new-weight data to the data manager.

    This hook is passed to ``self.invoke_callbacks`` first.  In debug mode each
    generated segment is echoed through ``self.rc.pstatus`` before being stored.'''
    self.invoke_callbacks(self.prepare_new_iteration)

    reporter = self.rc
    driver = self.we_driver

    if reporter.debug_mode:
        reporter.pstatus('\nSegments generated:')
        for seg in driver.next_iter_segments:
            reporter.pstatus('{!r} pcoord[0]={!r}'.format(seg, seg.pcoord[0]))

    upcoming_iter = self.n_iter + 1
    self.data_manager.prepare_iteration(upcoming_iter, list(driver.next_iter_segments))
    self.data_manager.save_new_weight_data(upcoming_iter, driver.new_weights)
738
+
739
def run(self):
    '''Run iterations until the iteration limit is passed or the wallclock budget runs out.

    Starting from the data manager's current iteration, each pass executes the
    per-iteration phases in order (prepare, propagate, check, WE, prepare the new
    iteration, finalize), folds the elapsed wallclock time and summed segment
    CPU time into the iteration summary, and advances the iteration counters.
    The backing store is flushed after every iteration, even one that raised.

    If ``self.max_run_walltime`` is set, the loop ends early once the current time
    plus 1.1 times the previous iteration's duration would reach the deadline.'''
    started_at = time.time()
    walltime_budget = self.max_run_walltime
    deadline = None
    if walltime_budget:
        deadline = started_at + walltime_budget
        self.rc.pstatus('Maximum wallclock time: %s' % timedelta(seconds=walltime_budget))

    self.n_iter = self.data_manager.current_iteration
    last_iter = self.max_total_iterations or self.n_iter + 1

    def _as_timedelta(seconds):
        # The stored value may be NaN (e.g. when starting a truncated simulation);
        # fall back to 0.0 for display in that case
        try:
            return timedelta(seconds=float(seconds))
        except ValueError:
            return 0.0

    previous_elapsed = 0
    while self.n_iter <= last_iter:
        if deadline is not None and time.time() + 1.1 * previous_elapsed >= deadline:
            self.rc.pstatus('Iteration {:d} would require more than the allotted time. Ending run.'.format(self.n_iter))
            return

        try:
            iteration_began = time.time()

            self.rc.pstatus('\n%s' % time.asctime())
            self.rc.pstatus('Iteration %d (%d requested)' % (self.n_iter, last_iter))

            self.prepare_iteration()
            self.rc.pflush()

            # Propagation phase
            self.pre_propagation()
            self.propagate()
            self.rc.pflush()
            self.check_propagation()
            self.rc.pflush()
            self.post_propagation()

            total_cputime = 0
            for seg in self.segments.values():
                total_cputime += seg.cputime

            # Weighted ensemble phase
            self.rc.pflush()
            self.pre_we()
            self.run_we()
            self.post_we()
            self.rc.pflush()

            self.prepare_new_iteration()

            self.finalize_iteration()

            # Record timing for the iteration just completed
            previous_elapsed = time.time() - iteration_began
            summary = self.data_manager.get_iter_summary()
            summary['walltime'] += previous_elapsed
            summary['cputime'] = total_cputime
            self.data_manager.update_iter_summary(summary)

            self.n_iter += 1
            self.data_manager.current_iteration += 1

            shown_walltime = _as_timedelta(summary['walltime'])
            shown_cputime = _as_timedelta(summary['cputime'])

            self.rc.pstatus('Iteration wallclock: {0!s}, cputime: {1!s}\n'.format(shown_walltime, shown_cputime))
            self.rc.pflush()
        finally:
            self.data_manager.flush_backing()

    self.rc.pstatus('\n%s' % time.asctime())
    self.rc.pstatus('WEST run complete.')
812
+
813
def prepare_run(self):
    '''Prepare a new run: the data manager first, then the system, then this hook's callbacks.'''
    for component in (self.data_manager, self.system):
        component.prepare_run()
    hook = self.prepare_run
    self.invoke_callbacks(hook)
818
+
819
def finalize_run(self):
    '''Clean up at the normal end of a run: this hook's callbacks first, then the system, then the data manager.'''
    hook = self.finalize_run
    self.invoke_callbacks(hook)
    for component in (self.system, self.data_manager):
        component.finalize_run()
824
+
825
def pre_propagation(self):
    '''Hand the ``pre_propagation`` hook to ``self.invoke_callbacks``.'''
    hook = self.pre_propagation
    self.invoke_callbacks(hook)
827
+
828
def post_propagation(self):
    '''Hand the ``post_propagation`` hook to ``self.invoke_callbacks``.'''
    hook = self.post_propagation
    self.invoke_callbacks(hook)
830
+
831
def pre_we(self):
    '''Hand the ``pre_we`` hook to ``self.invoke_callbacks``.'''
    hook = self.pre_we
    self.invoke_callbacks(hook)
833
+
834
def post_we(self):
    '''Hand the ``post_we`` hook to ``self.invoke_callbacks``.'''
    hook = self.post_we
    self.invoke_callbacks(hook)