westpa 2022.13__cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162):
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +749 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +369 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +597 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +557 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +832 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_timings.py +113 -0
  38. westpa/cli/tools/w_trace.py +599 -0
  39. westpa/core/__init__.py +0 -0
  40. westpa/core/_rc.py +673 -0
  41. westpa/core/binning/__init__.py +55 -0
  42. westpa/core/binning/_assign.c +36018 -0
  43. westpa/core/binning/_assign.cpython-312-aarch64-linux-gnu.so +0 -0
  44. westpa/core/binning/_assign.pyx +370 -0
  45. westpa/core/binning/assign.py +454 -0
  46. westpa/core/binning/binless.py +96 -0
  47. westpa/core/binning/binless_driver.py +54 -0
  48. westpa/core/binning/binless_manager.py +189 -0
  49. westpa/core/binning/bins.py +47 -0
  50. westpa/core/binning/mab.py +506 -0
  51. westpa/core/binning/mab_driver.py +54 -0
  52. westpa/core/binning/mab_manager.py +197 -0
  53. westpa/core/data_manager.py +1761 -0
  54. westpa/core/extloader.py +74 -0
  55. westpa/core/h5io.py +1079 -0
  56. westpa/core/kinetics/__init__.py +24 -0
  57. westpa/core/kinetics/_kinetics.c +45174 -0
  58. westpa/core/kinetics/_kinetics.cpython-312-aarch64-linux-gnu.so +0 -0
  59. westpa/core/kinetics/_kinetics.pyx +815 -0
  60. westpa/core/kinetics/events.py +147 -0
  61. westpa/core/kinetics/matrates.py +156 -0
  62. westpa/core/kinetics/rate_averaging.py +266 -0
  63. westpa/core/progress.py +218 -0
  64. westpa/core/propagators/__init__.py +54 -0
  65. westpa/core/propagators/executable.py +592 -0
  66. westpa/core/propagators/loaders.py +196 -0
  67. westpa/core/reweight/__init__.py +14 -0
  68. westpa/core/reweight/_reweight.c +36899 -0
  69. westpa/core/reweight/_reweight.cpython-312-aarch64-linux-gnu.so +0 -0
  70. westpa/core/reweight/_reweight.pyx +439 -0
  71. westpa/core/reweight/matrix.py +126 -0
  72. westpa/core/segment.py +119 -0
  73. westpa/core/sim_manager.py +839 -0
  74. westpa/core/states.py +359 -0
  75. westpa/core/systems.py +93 -0
  76. westpa/core/textio.py +74 -0
  77. westpa/core/trajectory.py +603 -0
  78. westpa/core/we_driver.py +910 -0
  79. westpa/core/wm_ops.py +43 -0
  80. westpa/core/yamlcfg.py +298 -0
  81. westpa/fasthist/__init__.py +34 -0
  82. westpa/fasthist/_fasthist.c +38755 -0
  83. westpa/fasthist/_fasthist.cpython-312-aarch64-linux-gnu.so +0 -0
  84. westpa/fasthist/_fasthist.pyx +222 -0
  85. westpa/mclib/__init__.py +271 -0
  86. westpa/mclib/__main__.py +28 -0
  87. westpa/mclib/_mclib.c +34610 -0
  88. westpa/mclib/_mclib.cpython-312-aarch64-linux-gnu.so +0 -0
  89. westpa/mclib/_mclib.pyx +226 -0
  90. westpa/oldtools/__init__.py +4 -0
  91. westpa/oldtools/aframe/__init__.py +35 -0
  92. westpa/oldtools/aframe/atool.py +75 -0
  93. westpa/oldtools/aframe/base_mixin.py +26 -0
  94. westpa/oldtools/aframe/binning.py +178 -0
  95. westpa/oldtools/aframe/data_reader.py +560 -0
  96. westpa/oldtools/aframe/iter_range.py +200 -0
  97. westpa/oldtools/aframe/kinetics.py +117 -0
  98. westpa/oldtools/aframe/mcbs.py +153 -0
  99. westpa/oldtools/aframe/output.py +39 -0
  100. westpa/oldtools/aframe/plotting.py +88 -0
  101. westpa/oldtools/aframe/trajwalker.py +126 -0
  102. westpa/oldtools/aframe/transitions.py +469 -0
  103. westpa/oldtools/cmds/__init__.py +0 -0
  104. westpa/oldtools/cmds/w_ttimes.py +361 -0
  105. westpa/oldtools/files.py +34 -0
  106. westpa/oldtools/miscfn.py +23 -0
  107. westpa/oldtools/stats/__init__.py +4 -0
  108. westpa/oldtools/stats/accumulator.py +35 -0
  109. westpa/oldtools/stats/edfs.py +129 -0
  110. westpa/oldtools/stats/mcbs.py +96 -0
  111. westpa/tools/__init__.py +33 -0
  112. westpa/tools/binning.py +472 -0
  113. westpa/tools/core.py +340 -0
  114. westpa/tools/data_reader.py +159 -0
  115. westpa/tools/dtypes.py +31 -0
  116. westpa/tools/iter_range.py +198 -0
  117. westpa/tools/kinetics_tool.py +343 -0
  118. westpa/tools/plot.py +283 -0
  119. westpa/tools/progress.py +17 -0
  120. westpa/tools/selected_segs.py +154 -0
  121. westpa/tools/wipi.py +751 -0
  122. westpa/trajtree/__init__.py +4 -0
  123. westpa/trajtree/_trajtree.c +17829 -0
  124. westpa/trajtree/_trajtree.cpython-312-aarch64-linux-gnu.so +0 -0
  125. westpa/trajtree/_trajtree.pyx +130 -0
  126. westpa/trajtree/trajtree.py +117 -0
  127. westpa/westext/__init__.py +0 -0
  128. westpa/westext/adaptvoronoi/__init__.py +3 -0
  129. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  130. westpa/westext/hamsm_restarting/__init__.py +3 -0
  131. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  132. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  133. westpa/westext/stringmethod/__init__.py +11 -0
  134. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  135. westpa/westext/stringmethod/string_driver.py +253 -0
  136. westpa/westext/stringmethod/string_method.py +306 -0
  137. westpa/westext/weed/BinCluster.py +180 -0
  138. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  139. westpa/westext/weed/UncertMath.py +247 -0
  140. westpa/westext/weed/__init__.py +10 -0
  141. westpa/westext/weed/weed_driver.py +192 -0
  142. westpa/westext/wess/ProbAdjust.py +101 -0
  143. westpa/westext/wess/__init__.py +6 -0
  144. westpa/westext/wess/wess_driver.py +217 -0
  145. westpa/work_managers/__init__.py +57 -0
  146. westpa/work_managers/core.py +396 -0
  147. westpa/work_managers/environment.py +134 -0
  148. westpa/work_managers/mpi.py +318 -0
  149. westpa/work_managers/processes.py +201 -0
  150. westpa/work_managers/serial.py +28 -0
  151. westpa/work_managers/threads.py +79 -0
  152. westpa/work_managers/zeromq/__init__.py +20 -0
  153. westpa/work_managers/zeromq/core.py +635 -0
  154. westpa/work_managers/zeromq/node.py +131 -0
  155. westpa/work_managers/zeromq/work_manager.py +526 -0
  156. westpa/work_managers/zeromq/worker.py +320 -0
  157. westpa-2022.13.dist-info/METADATA +179 -0
  158. westpa-2022.13.dist-info/RECORD +162 -0
  159. westpa-2022.13.dist-info/WHEEL +7 -0
  160. westpa-2022.13.dist-info/entry_points.txt +30 -0
  161. westpa-2022.13.dist-info/licenses/LICENSE +21 -0
  162. westpa-2022.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,27 @@
1
+ import numpy as np
2
+
3
+
4
def time_average(observable, iterations):
    """Compute the time average of an observable.

    Parameters
    ----------
    observable : Callable[[Walker], ArrayLike]
        Function that takes a walker as input and returns a number or
        a fixed-size array of numbers.
    iterations : Iterable[Iteration]
        Iterations over which to compute the average. Each iteration must
        be iterable over its walkers and expose a ``weights`` attribute.
        Any iterable is accepted (it is traversed exactly once).

    Returns
    -------
    ArrayLike
        The time average of `observable` over `iterations`.

    Raises
    ------
    ValueError
        If `iterations` is empty.

    """
    total = None
    n_iterations = 0

    for iteration in iterations:
        values = [observable(walker) for walker in iteration]
        # Weighted ensemble average of the observable within this iteration.
        ensemble_average = np.dot(iteration.weights, values)
        total = ensemble_average if total is None else total + ensemble_average
        n_iterations += 1

    if n_iterations == 0:
        raise ValueError('iterations must contain at least one iteration')

    return total / n_iterations
@@ -0,0 +1,369 @@
1
+ import concurrent.futures as cf
2
+ import functools
3
+ import inspect
4
+ import operator
5
+ import os
6
+ from typing import Callable
7
+
8
+ try:
9
+ import mdtraj
10
+ except ImportError:
11
+ mdtraj = None
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+ from westpa.analysis.core import Walker, Trace
16
+ from westpa.core.states import InitialState
17
+ from westpa.core.h5io import WESTIterationFile
18
+
19
+
20
class Trajectory:
    """A callable that returns the trajectory of a walker or trace.

    May be used directly as a decorator (``@Trajectory``) or, when a custom
    concatenation function is needed, as a decorator factory
    (``@Trajectory(fconcat=...)``).

    Parameters
    ----------
    fget : callable
        Function for retrieving a single trajectory segment. Must take a
        :class:`Walker` instance as its first argument and accept a boolean
        keyword argument `include_initpoint`. The function should return a
        sequence (e.g., a list or ndarray) representing the trajectory of
        the walker. If `include_initpoint` is True, the trajectory segment
        should include its initial point. Otherwise, the trajectory segment
        should exclude its initial point.
    fconcat : callable, optional
        Function for concatenating trajectory segments. Must take a sequence
        of trajectory segments as input and return their concatenation. The
        default concatenation function is :func:`concatenate`.

    Raises
    ------
    TypeError
        If `fget` is missing, or `fconcat` is neither None nor callable.
    ValueError
        If `fget` does not accept a parameter named `include_initpoint`.

    """

    def __new__(cls, *args, **kwargs):
        # Decorator-factory form: ``Trajectory(fconcat=...)`` without ``fget``
        # returns a partial that builds the instance once ``fget`` is given.
        # ``__init__`` cannot do this because it must return None.
        # The check on ``cls.__init__`` keeps subclasses that define their own
        # constructor signature (and keyword arguments) out of this path.
        if cls.__init__ is Trajectory.__init__ and not args and kwargs and kwargs.get('fget') is None:
            kwargs.pop('fget', None)
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(self, fget=None, *, fconcat=None):
        if fget is None:
            raise TypeError("'fget' is required; use Trajectory(fconcat=...) only as a decorator factory")

        if 'include_initpoint' not in inspect.signature(fget).parameters:
            raise ValueError("'fget' must accept a parameter 'include_initpoint'")

        self._fget = fget
        self.fconcat = fconcat

        self._segment_collector = SegmentCollector(self)

    @property
    def segment_collector(self):
        """SegmentCollector: Segment retrieval manager."""
        return self._segment_collector

    @property
    def fget(self):
        """callable: Function for getting trajectory segments."""
        return self._fget

    @property
    def fconcat(self):
        """callable: Function for concatenating trajectory segments."""
        return self._fconcat

    @fconcat.setter
    def fconcat(self, value):
        if value is None:
            value = concatenate
        elif not callable(value):
            raise TypeError("'fconcat' must be a callable object")
        self._fconcat = value

    def __call__(self, obj, include_initpoint=True, **kwargs):
        """Return the trajectory of a walker or the concatenated trajectory of a trace.

        Extra keyword arguments are forwarded to `fget`.
        """
        if isinstance(obj, Walker):
            value = self.fget(obj, include_initpoint=include_initpoint, **kwargs)
            self._validate_segment(value)
            return value
        if isinstance(obj, Trace):
            # Only the first segment of a trace may carry its initial point;
            # every later segment is requested without it.
            initpoint_mask = np.full(len(obj), False)
            initpoint_mask[0] = include_initpoint
            segments = self.segment_collector.get_segments(obj, initpoint_mask, **kwargs)
            return self.fconcat(segments)
        raise TypeError('argument must be a Walker or Trace instance')

    def _validate_segment(self, value):
        """Raise TypeError unless `value` supports indexing (``__getitem__``)."""
        if not hasattr(value, '__getitem__'):
            msg = f"{type(value).__name__!r} object can't be concatenated"
            raise TypeError(msg)
91
+
92
+
93
class SegmentCollector:
    """Retrieval manager for the trajectory segments of many walkers.

    Parameters
    ----------
    trajectory : Trajectory
        The trajectory to which the segment collector is attached.
    use_threads : bool, default False
        Whether to use a pool of threads to retrieve trajectory segments
        asynchronously. Setting this parameter to True may be
        useful when segment retrieval is an I/O bound task.
    max_workers : int, optional
        Maximum number of threads to use. The default value is specified in the
        `ThreadPoolExecutor <https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor>`_
        documentation.
    show_progress : bool, default False
        Whether to show a progress bar when retrieving multiple segments.

    """

    def __init__(self, trajectory, use_threads=False, max_workers=None, show_progress=False):
        # Every assignment goes through the validating property setter below.
        self.trajectory = trajectory
        self.use_threads = use_threads
        self.max_workers = max_workers
        self.show_progress = show_progress

    @property
    def trajectory(self):
        return self._trajectory

    @trajectory.setter
    def trajectory(self, value):
        if isinstance(value, Trajectory):
            self._trajectory = value
            return
        msg = f'trajectory must be an instance of {Trajectory}'
        raise TypeError(msg)

    @property
    def use_threads(self):
        return self._use_threads

    @use_threads.setter
    def use_threads(self, value):
        if isinstance(value, bool):
            self._use_threads = value
            return
        raise TypeError('use_threads must be True or False')

    @property
    def max_workers(self):
        return self._max_workers

    @max_workers.setter
    def max_workers(self, value):
        # None is stored as-is; it is later handed unchanged to ThreadPoolExecutor.
        if value is not None and value <= 0:
            raise ValueError('max_workers must be greater than 0')
        self._max_workers = value

    @property
    def show_progress(self):
        return self._show_progress

    @show_progress.setter
    def show_progress(self, value):
        if isinstance(value, bool):
            self._show_progress = value
            return
        raise ValueError('show_progress must be True or False')

    def get_segments(self, walkers, initpoint_mask=None, **kwargs):
        """Retrieve the trajectories of multiple walkers.

        Parameters
        ----------
        walkers : sequence of Walker
            The walkers for which to retrieve trajectories.
        initpoint_mask : sequence of bool, optional
            A Boolean mask indicating whether each trajectory segment should
            include (True) or exclude (False) its initial point. Default is
            all True.
        **kwargs
            Forwarded to the attached trajectory for every walker.

        Returns
        -------
        list of sequences
            The trajectory of each walker, in the order of `walkers`.

        """
        if initpoint_mask is None:
            mask = np.full(len(walkers), True)
        else:
            mask = np.asarray(initpoint_mask, dtype=bool)

        fetch = functools.partial(self.trajectory, **kwargs)

        progress_options = {
            'desc': 'Retrieving segments',
            'disable': (not self.show_progress),
            'position': 0,
            'total': len(walkers),
        }

        if not self.use_threads:
            lazy_segments = (fetch(walker, include_initpoint=flag) for walker, flag in zip(walkers, mask))
            return list(tqdm(lazy_segments, **progress_options))

        with cf.ThreadPoolExecutor(self.max_workers) as pool:
            # Remember each future's submission position so the results can
            # be put back in walker order after completing out of order.
            position_of = {}
            for position, (walker, flag) in enumerate(zip(walkers, mask)):
                position_of[pool.submit(fetch, walker, include_initpoint=flag)] = position
            finished = list(tqdm(cf.as_completed(position_of), **progress_options))

        finished.sort(key=position_of.get)
        return [future.result() for future in finished]
209
+
210
+
211
class BasicMDTrajectory(Trajectory):
    """Trajectory reader for MD trajectories stored as in the
    `Basic Tutorial <https://github.com/westpa/westpa_tutorials/tree/main/tutorial7.1-basic-nacl>`_.

    Segment files are looked up under ``<sim_root>/traj_segs/<iteration>/<segment>/``,
    with both numbers zero-padded to six digits.

    Parameters
    ----------
    top : str or mdtraj.Topology, default 'bstate.pdb'
        Topology. A string is resolved relative to ``<sim_root>/common_files``;
        any other value is passed to MDTraj unchanged.
    traj_ext : str, default '.dcd'
        Extension of the per-segment trajectory file (``seg<traj_ext>``).
    state_ext : str, default '.xml'
        Extension of the per-segment state file (``seg<state_ext>``) and of
        generated initial-state files.
    sim_root : str, default '.'
        Root directory of the simulation.

    Raises
    ------
    ImportError
        If MDTraj is not installed.

    """

    def __init__(self, top='bstate.pdb', traj_ext='.dcd', state_ext='.xml', sim_root='.'):
        if mdtraj is None:
            raise ImportError('MDTraj must be installed to use the BasicMDTrajectory reader')

        self.top = top
        self.traj_ext = traj_ext
        self.state_ext = state_ext
        self.sim_root = sim_root

        def segment_file(root, n_iter, seg_index, ext):
            # <root>/traj_segs/<n_iter:06d>/<seg_index:06d>/seg<ext>
            seg_dir = os.path.join(root, 'traj_segs', format(n_iter, '06d'), format(seg_index, '06d'))
            return os.path.join(seg_dir, 'seg' + ext)

        def initpoint_file(root, walker):
            # Locate the file holding the point the walker started from.
            parent = walker.parent
            if not isinstance(parent, InitialState):
                # Ordinary parent walker: its state file from the previous iteration.
                return segment_file(root, walker.iteration.number - 1, parent.index, self.state_ext)
            if parent.istate_type == InitialState.ISTATE_TYPE_BASIS:
                return os.path.join(root, 'bstates', parent.basis_state.auxref)
            return os.path.join(root, 'istates', str(parent.iter_created), str(parent.state_id) + self.state_ext)

        def fget(walker, include_initpoint=True, atom_indices=None, sim_root=None):
            root = sim_root or self.sim_root

            topology = self.top
            if isinstance(topology, str):
                topology = os.path.join(root, 'common_files', topology)

            seg_path = segment_file(root, walker.iteration.number, walker.index, self.traj_ext)
            if topology is None:
                traj = mdtraj.load(seg_path)
            else:
                traj = mdtraj.load(seg_path, top=topology)

            if include_initpoint:
                # Prepend the starting frame, read with the segment's own topology.
                frame = mdtraj.load(initpoint_file(root, walker), top=traj.top)
                traj = frame.join(traj, check_topology=False)

            if atom_indices is not None:
                traj.atom_slice(atom_indices, inplace=True)

            return traj

        super().__init__(fget)

        self.segment_collector.use_threads = True
        self.segment_collector.show_progress = True
291
+
292
+
293
class HDF5MDTrajectory(Trajectory):
    """Trajectory reader for MD trajectories stored by the HDF5 framework.

    Segments are read through :class:`WESTIterationFile` from the file behind
    each iteration's ``'trajectories'`` entry.

    Raises
    ------
    ImportError
        If MDTraj is not installed.

    """

    def __init__(self):
        if mdtraj is None:
            raise ImportError('MDTraj must be installed to use the HDF5MDTrajectory reader')

        def fget(walker, include_initpoint=True, atom_indices=None):
            iteration = walker.iteration

            try:
                link = iteration.h5group['trajectories']
            except KeyError:
                msg = 'the HDF5 framework does not appear to have been used to store trajectories for this run'
                raise ValueError(msg)

            with WESTIterationFile(link.file.filename) as traj_file:
                traj = traj_file.read_as_traj(
                    iteration=iteration.number,
                    segment=walker.index,
                    atom_indices=atom_indices,
                )

            if not include_initpoint:
                return traj

            parent = walker.parent

            if not isinstance(parent, InitialState):
                # Ordinary parent walker: its last frame is this walker's starting point.
                frame = fget(parent, include_initpoint=False, atom_indices=atom_indices)[-1]
            else:
                state_type = parent.istate_type
                is_basis = state_type == InitialState.ISTATE_TYPE_BASIS
                if not is_basis and state_type != InitialState.ISTATE_TYPE_GENERATED:
                    raise ValueError('unsupported initial state type: %d' % state_type)

                link = walker.run.h5file.get_iter_group(0)['trajectories']
                if is_basis:
                    frame_iter = 0
                else:
                    # the conversion to int is because iter_created is uint
                    frame_iter = -int(parent.iter_created)

                with WESTIterationFile(link.file.filename) as traj_file:
                    frame = traj_file.read_as_traj(
                        iteration=frame_iter,
                        segment=parent.basis_state_id if is_basis else parent.state_id,
                        atom_indices=atom_indices,
                    )

            return frame.join(traj, check_topology=False)

        super().__init__(fget)

        self.segment_collector.use_threads = False
        self.segment_collector.show_progress = True
349
+
350
+
351
def concatenate(segments):
    """Return the concatenation of a sequence of trajectory segments.

    Parameters
    ----------
    segments : iterable of sequences
        The trajectory segments to concatenate. Any iterable is accepted;
        it is materialized into a list once.

    Returns
    -------
    sequence
        The concatenation of `segments`. The type of the first segment
        selects the strategy: ``numpy.concatenate`` for ndarrays, MDTraj's
        ``join`` for ``mdtraj.Trajectory`` objects (when MDTraj is
        available), and ``operator.concat`` otherwise.

    Raises
    ------
    ValueError
        If `segments` is empty.

    """
    segments = list(segments)
    if not segments:
        raise ValueError('need at least one trajectory segment to concatenate')

    first = segments[0]
    if isinstance(first, np.ndarray):
        return np.concatenate(segments)
    if mdtraj is not None and isinstance(first, mdtraj.Trajectory):
        return first.join(segments[1:], check_topology=False)
    return functools.reduce(operator.concat, segments)
westpa/cli/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,152 @@
1
+ import argparse
2
+ import logging
3
+
4
+ import numpy as np
5
+
6
+ import westpa
7
+ from westpa.core.segment import Segment
8
+ from westpa.core.states import InitialState
9
+ from westpa.core.data_manager import n_iter_dtype, seg_id_dtype
10
+
11
+ log = logging.getLogger('w_fork')
12
+
13
+
14
def _write_istate_map(path, state_map, write_headers=True):
    """Write the old-segment to new-initial-state mapping as a flat text file.

    Each row of `state_map` must provide the fields ``old_n_iter``,
    ``old_seg_id`` and ``new_istate_id``. The file is closed even if
    writing fails part-way through.
    """
    with open(path, 'wt') as istate_map_file:
        if write_headers:
            istate_map_file.write('# mapping from previous segment IDs to new initial states\n')
            istate_map_file.write('# generated by w_fork\n')
            istate_map_file.write('# column 0: old simulation n_iter\n')
            istate_map_file.write('# column 1: old simulation seg_id\n')
            istate_map_file.write('# column 2: new simulation initial state ID\n')

        for row in state_map:
            istate_map_file.write(
                '{old_n_iter:20d} {old_seg_id:20d} {new_istate_id:20d}\n'.format(
                    old_n_iter=int(row['old_n_iter']), old_seg_id=int(row['old_seg_id']), new_istate_id=int(row['new_istate_id'])
                )
            )


def entry_point():
    """Command-line entry point for ``w_fork``.

    Parses the command line, reads one iteration of an existing simulation,
    and writes a new HDF5 file whose first iteration has one initial state
    and one segment for every segment of that old iteration. The mapping
    from old segments to new initial states is also written as a text file.
    """
    parser = argparse.ArgumentParser(
        'w_fork',
        description='''\
Prepare a new weighted ensemble simulation from an existing one at a particular
point. A new HDF5 file is generated. In the case of executable propagation,
it is the user's responsibility to prepare the new simulation directory
appropriately, particularly making the old simulation's restart data from the
appropriate iteration available as the new simulations initial state data; a
mapping of old simulation segment to new simulation initial states is
created, both in the new HDF5 file and as a flat text file, to aid in this.
Target states and basis states for the new simulation are taken from those
in the original simulation.
''',
    )

    westpa.rc.add_args(parser)
    parser.add_argument(
        '-i',
        '--input',
        dest='input_h5file',
        help='''Create simulation from the given INPUT_H5FILE (default: read from
        configuration file.''',
    )
    parser.add_argument(
        '-I',
        '--iteration',
        dest='n_iter',
        type=int,
        help='''Take initial distribution for new simulation from iteration N_ITER
        (default: last complete iteration).''',
    )
    parser.add_argument(
        '-o',
        '--output',
        dest='output_h5file',
        default='forked.h5',
        help='''Save new simulation HDF5 file as OUTPUT (default: %(default)s).''',
    )
    parser.add_argument(
        '--istate-map',
        default='istate_map.txt',
        help='''Write text file describing mapping of existing segments to new initial
        states in ISTATE_MAP (default: %(default)s).''',
    )
    parser.add_argument('--no-headers', action='store_true', help='''Do not write header to ISTATE_MAP''')
    args = parser.parse_args()
    westpa.rc.process_args(args)

    # Open old HDF5 file
    dm_old = westpa.rc.new_data_manager()
    if args.input_h5file:
        dm_old.we_h5filename = args.input_h5file
    dm_old.open_backing(mode='r')

    # Get iteration if necessary.
    # NOTE: because of the ``or``, an explicit N_ITER of 0 also selects the default.
    n_iter = args.n_iter or dm_old.current_iteration - 1

    # Create and open new HDF5 file
    dm_new = westpa.rc.new_data_manager()
    dm_new.we_h5filename = args.output_h5file
    dm_new.prepare_backing()
    dm_new.open_backing()

    # Copy target states
    target_states = dm_old.get_target_states(n_iter)
    dm_new.save_target_states(target_states, n_iter)

    # Copy basis states
    basis_states = dm_old.get_basis_states(n_iter)
    dm_new.create_ibstate_group(basis_states, n_iter=1)

    # Transform old segments into initial states and new segments
    # We produce one initial state and one corresponding
    # new segment for each old segment. Further adjustment
    # can be accomplished by using w_binning.
    old_iter_group = dm_old.get_iter_group(n_iter)
    old_index = old_iter_group['seg_index'][...]
    old_pcoord_ds = old_iter_group['pcoord']
    n_segments = old_pcoord_ds.shape[0]
    pcoord_len = old_pcoord_ds.shape[1]
    pcoord_ndim = old_pcoord_ds.shape[2]
    # Last point along the second axis of each old segment's pcoord.
    old_final_pcoords = old_pcoord_ds[:, pcoord_len - 1, :]

    istates = dm_new.create_initial_states(n_segments, n_iter=1)
    segments = []
    state_map_dtype = np.dtype([('old_n_iter', n_iter_dtype), ('old_seg_id', seg_id_dtype), ('new_istate_id', seg_id_dtype)])
    state_map = np.empty((n_segments,), dtype=state_map_dtype)
    state_map['old_n_iter'] = n_iter

    for iseg, (index_row, pcoord) in enumerate(zip(old_index, old_final_pcoords)):
        istate = istates[iseg]
        istate.iter_created = 0
        istate.iter_used = 1
        istate.istate_type = InitialState.ISTATE_TYPE_RESTART
        istate.istate_status = InitialState.ISTATE_STATUS_PREPARED
        istate.pcoord = pcoord

        # The parent is recorded as -(state_id + 1), i.e. a negative ID
        # derived from the initial state rather than from a segment.
        segment = Segment(
            n_iter=1,
            seg_id=iseg,
            weight=index_row['weight'],
            parent_id=-(istate.state_id + 1),
            wtg_parent_ids=[-(istate.state_id + 1)],
            status=Segment.SEG_STATUS_PREPARED,
        )
        segment.pcoord = np.zeros((pcoord_len, pcoord_ndim), dtype=pcoord.dtype)
        segment.pcoord[0] = pcoord
        segments.append(segment)
        state_map[iseg]['old_seg_id'] = iseg
        state_map[iseg]['new_istate_id'] = istate.state_id

    dm_new.update_initial_states(istates, n_iter=0)
    dm_new.prepare_iteration(n_iter=1, segments=segments)

    # Update current iteration and close both files
    dm_new.current_iteration = 1
    dm_new.close_backing()
    dm_old.close_backing()

    # Write state map
    _write_istate_map(args.istate_map, state_map, write_headers=not args.no_headers)
149
+
150
+
151
+ if __name__ == '__main__':
152
+ entry_point()