asyncmd 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- asyncmd/__init__.py +18 -0
- asyncmd/_config.py +26 -0
- asyncmd/_version.py +75 -0
- asyncmd/config.py +203 -0
- asyncmd/gromacs/__init__.py +16 -0
- asyncmd/gromacs/mdconfig.py +351 -0
- asyncmd/gromacs/mdengine.py +1127 -0
- asyncmd/gromacs/utils.py +197 -0
- asyncmd/mdconfig.py +440 -0
- asyncmd/mdengine.py +100 -0
- asyncmd/slurm.py +1199 -0
- asyncmd/tools.py +86 -0
- asyncmd/trajectory/__init__.py +25 -0
- asyncmd/trajectory/convert.py +577 -0
- asyncmd/trajectory/functionwrapper.py +556 -0
- asyncmd/trajectory/propagate.py +937 -0
- asyncmd/trajectory/trajectory.py +1103 -0
- asyncmd/utils.py +148 -0
- asyncmd-0.3.2.dist-info/LICENSE +232 -0
- asyncmd-0.3.2.dist-info/METADATA +179 -0
- asyncmd-0.3.2.dist-info/RECORD +23 -0
- asyncmd-0.3.2.dist-info/WHEEL +5 -0
- asyncmd-0.3.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,937 @@
|
|
1
|
+
# This file is part of asyncmd.
|
2
|
+
#
|
3
|
+
# asyncmd is free software: you can redistribute it and/or modify
|
4
|
+
# it under the terms of the GNU General Public License as published by
|
5
|
+
# the Free Software Foundation, either version 3 of the License, or
|
6
|
+
# (at your option) any later version.
|
7
|
+
#
|
8
|
+
# asyncmd is distributed in the hope that it will be useful,
|
9
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
10
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
11
|
+
# GNU General Public License for more details.
|
12
|
+
#
|
13
|
+
# You should have received a copy of the GNU General Public License
|
14
|
+
# along with asyncmd. If not, see <https://www.gnu.org/licenses/>.
|
15
|
+
import os
|
16
|
+
import asyncio
|
17
|
+
import aiofiles
|
18
|
+
import aiofiles.os
|
19
|
+
import inspect
|
20
|
+
import logging
|
21
|
+
import typing
|
22
|
+
import numpy as np
|
23
|
+
|
24
|
+
from .trajectory import Trajectory
|
25
|
+
from .functionwrapper import TrajectoryFunctionWrapper
|
26
|
+
from .convert import TrajectoryConcatenator
|
27
|
+
from ..utils import (get_all_traj_parts,
|
28
|
+
get_all_file_parts,
|
29
|
+
nstout_from_mdconfig,
|
30
|
+
)
|
31
|
+
from ..tools import (remove_file_if_exist,
|
32
|
+
remove_file_if_exist_async,
|
33
|
+
)
|
34
|
+
|
35
|
+
|
36
|
+
# module-level logger, named after this module
logger = logging.getLogger(name=__name__)
|
37
|
+
|
38
|
+
|
39
|
+
class MaxStepsReachedError(Exception):
    """
    Signals that a simulation stopped early because it hit the (user-defined)
    maximum number of integration steps/trajectory frames.
    """
|
45
|
+
|
46
|
+
|
47
|
+
def _first_frame_in_state(state_vals: list, direction: str,
                          state_idx: int) -> tuple[int, int]:
    """
    Locate the first frame (over all trajectory parts) that is in a state.

    Parameters
    ----------
    state_vals : list
        State function values, one sequence of True/False values per
        trajectory part, with the parts in the order they were given.
    direction : str
        "minus" or "plus", only used to build a helpful error message.
    state_idx : int
        Index of the state function, only used in the error message.

    Returns
    -------
    (part_idx, frame_idx) : (int, int)
        Index of the first part that contains a frame in state and the index
        of the first such frame, counted from the start of that part.

    Raises
    ------
    ValueError
        If no frame of any part is in the state (including no parts at all).
    """
    for part_idx, vals in enumerate(state_vals):
        frames_in_state = np.nonzero(vals)[0]
        if len(frames_in_state) > 0:
            return part_idx, int(frames_in_state[0])
    raise ValueError(f"None of the {direction} trajectories reaches the "
                     f"{direction} state (state_funcs index {state_idx}).")


# TODO: move to trajectory.convert?
async def construct_TP_from_plus_and_minus_traj_segments(
                            minus_trajs: "list[Trajectory]",
                            minus_state: int,
                            plus_trajs: "list[Trajectory]",
                            plus_state: int,
                            state_funcs: "list[TrajectoryFunctionWrapper]",
                            tra_out: str,
                            struct_out: typing.Optional[str] = None,
                            overwrite: bool = False
                                                         ) -> "Trajectory":
    """
    Construct a continuous TP from plus and minus segments until states.

    This is used e.g. in TwoWay TPS or if you try to get TPs out of a committor
    simulation. Note that this inverts all velocities on the minus segments.

    Parameters
    ----------
    minus_trajs : list[Trajectory]
        Trajectories that go "backward in time", these are going to be inverted
    minus_state : int
        Index (in ``state_funcs``) of the first state reached on minus trajs.
    plus_trajs : list[Trajectory]
        Trajectories that go "forward in time", these are taken as is.
    plus_state : int
        Index (in ``state_funcs``) of the first state reached on plus trajs.
    state_funcs : list[TrajectoryFunctionWrapper]
        List of wrapped state functions, the indices to the states must match
        the minus and plus state indices!
    tra_out : str
        Absolute or relative path to the output trajectory file.
    struct_out : str, optional
        Absolute or relative path to the output structure file, if None we will
        use the structure file of the first minus_traj, by default None.
    overwrite : bool, optional
        Whether we should overwrite tra_out if it exists, by default False.

    Returns
    -------
    Trajectory
        The constructed transition.

    Raises
    ------
    ValueError
        If the minus (plus) trajectories never reach the minus (plus) state.
    """
    # minus segments: locate the first frame in minus_state
    minus_state_vals = await asyncio.gather(*(state_funcs[minus_state](t)
                                              for t in minus_trajs)
                                            )
    minus_part_idx, minus_frame_idx = _first_frame_in_state(
                                        state_vals=minus_state_vals,
                                        direction="minus",
                                        state_idx=minus_state,
                                                            )
    # The minus part is inverted (negative stride), so the first frame in the
    # minus state becomes the first frame of the TP. Build the slices
    # backwards: start with the part containing that frame ...
    slices = [(minus_frame_idx, None, -1)]
    trajs = [minus_trajs[minus_part_idx]]
    # ... followed by all earlier minus parts, each taken completely
    for part_idx in range(minus_part_idx - 1, -1, -1):
        slices.append((-1, None, -1))
        trajs.append(minus_trajs[part_idx])

    # plus segments: locate the first frame in plus_state
    plus_state_vals = await asyncio.gather(*(state_funcs[plus_state](t)
                                             for t in plus_trajs)
                                           )
    plus_part_idx, plus_frame_idx = _first_frame_in_state(
                                        state_vals=plus_state_vals,
                                        direction="plus",
                                        state_idx=plus_state,
                                                          )
    # NOTE: frame 0 of the first plus part is the starting configuration
    #       (gromacs has the first frame in the trajectory). We skip it so the
    #       SP is in the concatenated trajectory only once.
    if plus_part_idx == 0:
        # first and last plus part coincide: skip the SP and stop at the
        # first frame in the plus state (inclusive)
        slices.append((1, plus_frame_idx + 1, 1))
        trajs.append(plus_trajs[0])
    else:
        # first plus part without the SP
        slices.append((1, None, 1))
        trajs.append(plus_trajs[0])
        # intermediate parts are taken completely (none if plus_part_idx < 2)
        for part_idx in range(1, plus_part_idx):
            slices.append((0, None, 1))
            trajs.append(plus_trajs[part_idx])
        # last part up to and including the first frame in the plus state
        slices.append((0, plus_frame_idx + 1, 1))
        trajs.append(plus_trajs[plus_part_idx])
    # finally produce the concatenated path
    path_traj = await TrajectoryConcatenator().concatenate_async(
                                                    trajs=trajs,
                                                    slices=slices,
                                                    tra_out=tra_out,
                                                    struct_out=struct_out,
                                                    overwrite=overwrite,
                                                                 )
    return path_traj
|
175
|
+
|
176
|
+
|
177
|
+
class _TrajectoryPropagator:
    # (private) superclass for InPartsTrajectoryPropagator and
    # ConditionalTrajectoryPropagator,
    # here we keep the common functions shared between them
    # NOTE: subclasses must set `self.engine_cls` and `self.engine_kwargs`,
    #       `remove_parts` reads them to resolve the "trajectories" wildcard.
    async def remove_parts(self, workdir: str, deffnm: str,
                           file_endings_to_remove: typing.Optional[list[str]] = None,
                           remove_mda_offset_and_lock_files: bool = True,
                           remove_asyncmd_npz_caches: bool = True,
                           ) -> None:
        """
        Remove all `$deffnm.part$num.$file_ending` files for given file_endings

        Can be useful to clean the `workdir` from temporary files if e.g. only
        the concatenated trajectory is of interest (like in TPS).

        Parameters
        ----------
        workdir : str
            The directory to clean.
        deffnm : str
            The `deffnm` that the files we clean must have.
        file_endings_to_remove : list[str], optional
            The strings in the list `file_endings_to_remove` indicate which
            file endings to remove.
            E.g. `file_endings_to_remove=["trajectories", "log"]` will result
            in removal of trajectory parts and the log files. If you add "edr"
            to the list we would also remove the edr files,
            by default (None) ["trajectories", "log"].
            The given list is not modified.
        remove_mda_offset_and_lock_files : bool, optional
            Whether to also remove the hidden `.$filename_offsets.npz` and
            `.$filename_offsets.lock` files next to every removed file,
            by default True.
        remove_asyncmd_npz_caches : bool, optional
            Whether to also remove the hidden
            `.$filename_asyncmd_cv_cache.npz` files next to every removed
            file, by default True.
        """
        # always work on a fresh list: we replace "trajectories" below and
        # must neither mutate the caller's list nor a shared default
        if file_endings_to_remove is None:
            file_endings_to_remove = ["trajectories", "log"]
        else:
            file_endings_to_remove = list(file_endings_to_remove)
        if "trajectories" in file_endings_to_remove:
            # replace "trajectories" with the actual output traj type
            try:
                traj_type = self.engine_kwargs["output_traj_type"]
            except KeyError:
                # not in there so it will be the engine default
                traj_type = self.engine_cls.output_traj_type
            file_endings_to_remove.remove("trajectories")
            file_endings_to_remove.append(traj_type)
        # now find and remove the files
        for ending in file_endings_to_remove:
            parts_to_remove = await get_all_file_parts(
                                        folder=workdir,
                                        deffnm=deffnm,
                                        file_ending=ending.lower(),
                                                       )
            # make sure we dont miss anything because we have different
            # capitalization
            if len(parts_to_remove) == 0:
                parts_to_remove = await get_all_file_parts(
                                                folder=workdir,
                                                deffnm=deffnm,
                                                file_ending=ending.upper(),
                                                           )
            await asyncio.gather(*(aiofiles.os.unlink(f)
                                   for f in parts_to_remove
                                   )
                                 )
            # NOTE: this is a bit hacky: we just try to remove the offset and
            #       lock files for every file we remove (since we do not know
            #       if the file we remove is a trajectory [and therefore
            #       potentially has corresponding offset and lock files] or if
            #       the file we remove is e.g. an edr which has no lock/offset)
            #       Comparing with `traj_type` is not enough, since users can
            #       replace the wildcard "trajectories" by their specific
            #       trajectory file ending.
            # NOTE: we do not try to remove the multipart traj caches since
            #       the Propagators only return non-multipart Trajectories
            aux_files_to_remove = []
            for f in parts_to_remove:
                f_head, f_tail = os.path.split(f)
                if remove_mda_offset_and_lock_files:
                    aux_files_to_remove += [
                        os.path.join(f_head, f".{f_tail}_offsets.npz"),
                        os.path.join(f_head, f".{f_tail}_offsets.lock"),
                    ]
                if remove_asyncmd_npz_caches:
                    aux_files_to_remove.append(
                        os.path.join(f_head, f".{f_tail}_asyncmd_cv_cache.npz")
                    )
            await asyncio.gather(*(remove_file_if_exist_async(f)
                                   for f in aux_files_to_remove
                                   )
                                 )


class InPartsTrajectoryPropagator(_TrajectoryPropagator):
    """
    Propagate a trajectory in parts of walltime until given number of steps.

    Useful to make full use of backfilling with short(ish) simulation jobs and
    also to run simulations that are longer than the timelimit.
    """
    def __init__(self, n_steps: int, engine_cls,
                 engine_kwargs: dict, walltime_per_part: float,
                 ) -> None:
        """
        Initialize an `InPartTrajectoryPropagator`.

        Parameters
        ----------
        n_steps : int
            Number of integration steps to do in total.
        engine_cls : :class:`asyncmd.mdengine.MDEngine`
            Class of the MD engine to use, **uninitialized!**
        engine_kwargs : dict
            Dictionary of key word arguments to initialize the MD engine.
        walltime_per_part : float
            Walltime per trajectory segment, in hours.
        """
        self.n_steps = n_steps
        self.engine_cls = engine_cls
        self.engine_kwargs = engine_kwargs
        self.walltime_per_part = walltime_per_part

    async def propagate_and_concatenate(self,
                                        starting_configuration: "Trajectory",
                                        workdir: str,
                                        deffnm: str,
                                        tra_out: str,
                                        overwrite: bool = False,
                                        continuation: bool = False
                                        ) -> "typing.Optional[Trajectory]":
        """
        Chain :meth:`propagate` and :meth:`cut_and_concatenate` methods.

        Parameters
        ----------
        starting_configuration : :class:`asyncmd.Trajectory`
            The configuration (including momenta) to start MD from.
        workdir : str
            Absolute or relative path to the working directory.
        deffnm : str
            MD engine deffnm for trajectory parts and other files.
        tra_out : str
            Absolute or relative path for the concatenated output trajectory.
        overwrite : bool, optional
            Whether the output trajectory should be overwritten if it exists,
            by default False.
        continuation : bool, optional
            Whether we are continuing a previous MD run (with the same deffnm
            and working directory), by default False.

        Returns
        -------
        traj_out : :class:`asyncmd.Trajectory` or None
            The concatenated output trajectory, None if there was nothing to
            concatenate (e.g. for n_steps=0).
        """
        # this just chains propagate and cut_and_concatenate
        trajs = await self.propagate(
                                starting_configuration=starting_configuration,
                                workdir=workdir,
                                deffnm=deffnm,
                                continuation=continuation
                                     )
        full_traj = await self.cut_and_concatenate(
                                                   trajs=trajs,
                                                   tra_out=tra_out,
                                                   overwrite=overwrite,
                                                   )
        return full_traj

    async def propagate(self,
                        starting_configuration: "Trajectory",
                        workdir: str,
                        deffnm: str,
                        continuation: bool = False,
                        ) -> "list[Trajectory]":
        """
        Propagate the trajectory until self.n_steps integration are done.

        Return a list of trajectory segments, ordered in time.

        Parameters
        ----------
        starting_configuration : :class:`asyncmd.Trajectory`
            The configuration (including momenta) to start MD from.
        workdir : str
            Absolute or relative path to the working directory.
        deffnm : str
            MD engine deffnm for trajectory parts and other files.
        continuation : bool, optional
            Whether we are continuing a previous MD run (with the same deffnm
            and working directory), by default False.
            Note that when doing continuations and n_steps is lower than the
            number of steps done already found in the directory, we still
            return all trajectory parts (i.e. potentially too much).
            :meth:`cut_and_concatenate` can return a trimmed subtrajectory.

        Returns
        -------
        traj_segments : list[Trajectory]
            List of trajectory (segments), ordered in time.
        """
        engine = self.engine_cls(**self.engine_kwargs)
        if continuation:
            # continuation: get all traj parts already done and continue from
            # there, i.e. append to the last traj part found
            trajs = await get_all_traj_parts(folder=workdir, deffnm=deffnm,
                                             engine=engine,
                                             )
            if len(trajs) > 0:
                # can only continue if we find the previous trajs
                # NOTE(review): steps_done is read before prepare_from_files();
                #               confirm the freshly created engine already
                #               reports the steps of the existing parts here.
                step_counter = engine.steps_done
                if step_counter >= self.n_steps:
                    # already longer than what we want to do, bail out
                    return trajs
                await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
            else:
                # no previous trajs, prepare engine from scratch
                continuation = False
                logger.error("continuation=True, but we found no previous "
                             "trajectories. Setting continuation=False and "
                             "preparing the engine from scratch.")
        if not continuation:
            # no continuation, just prepare the engine from scratch
            await engine.prepare(
                        starting_configuration=starting_configuration,
                        workdir=workdir,
                        deffnm=deffnm,
                                )
            trajs = []
            step_counter = 0

        while (step_counter < self.n_steps):
            traj = await engine.run(nsteps=self.n_steps,
                                    walltime=self.walltime_per_part,
                                    steps_per_part=False,
                                    )
            step_counter = engine.steps_done
            trajs.append(traj)
        return trajs

    @staticmethod
    def _first_frame_at_or_after_step(first_step, last_step, n_frames,
                                      target_step) -> int:
        """
        Index of the first frame at or after `target_step` in one segment.

        Assumes the `n_frames` frames are evenly spaced between `first_step`
        and `last_step` (both inclusive), i.e. the spacing between frames is
        ``(last_step - first_step) / (n_frames - 1)``.
        """
        if (n_frames < 2 or last_step <= first_step
                or target_step <= first_step):
            # no spacing to infer or the first frame already qualifies
            return 0
        steps_per_frame = (last_step - first_step) / (n_frames - 1)
        frame_idx = 0
        # stay with < (instead of <=): rather one frame too much than too few
        while first_step + frame_idx * steps_per_frame < target_step:
            frame_idx += 1
        # never point past the last frame (guards against float rounding)
        return min(frame_idx, n_frames - 1)

    async def cut_and_concatenate(self,
                                  trajs: "list[Trajectory]",
                                  tra_out: str,
                                  overwrite: bool = False,
                                  ) -> "typing.Optional[Trajectory]":
        """
        Cut and concatenate the trajectory until it has length n_steps.

        Take a list of trajectory segments and form one continuous trajectory
        containing n_steps integration steps. The expected input
        is a list of trajectories, e.g. the output of the :meth:`propagate`
        method.

        Parameters
        ----------
        trajs : list[Trajectory]
            Trajectory segments to cut and concatenate.
        tra_out : str
            Absolute or relative path for the concatenated output trajectory.
        overwrite : bool, optional
            Whether the output trajectory should be overwritten if it exists,
            by default False.

        Returns
        -------
        traj_out : :class:`asyncmd.Trajectory` or None
            The concatenated output trajectory, None if `trajs` is empty.

        Raises
        ------
        ValueError
            If the given trajectories are to short to create a trajectory
            containing n_steps integration steps
        """
        # trajs is a list of trajectories, e.g. the return of propagate
        # tra_out and overwrite are passed directly to the Concatenator
        if len(trajs) == 0:
            # no trajectories to concatenate, happens e.g. if self.n_steps=0
            # we return None (TODO: is this what we want?)
            return None
        if self.n_steps > trajs[-1].last_step:
            # not enough steps in trajectories
            raise ValueError("The given trajectories are to short (< self.n_steps).")
        elif self.n_steps == trajs[-1].last_step:
            # all good, we just take all trajectory parts fully
            slices = [(0, None, 1) for _ in range(len(trajs))]
            last_part_idx = len(trajs) - 1
        else:
            logger.warning("Trajectories do not exactly contain n_steps "
                           "integration steps. Using a heuristic to find the "
                           "correct last frame to include, note that this "
                           "heuristic might fail if n_steps is not a multiple "
                           "of the trajectory output frequency.")
            # first find the part in which we go over n_steps
            last_part_idx = 0
            while self.n_steps > trajs[last_part_idx].last_step:
                last_part_idx += 1
            last_part = trajs[last_part_idx]
            # then the first frame (at or) after n_steps in that part
            last_frame_idx = self._first_frame_at_or_after_step(
                                            first_step=last_part.first_step,
                                            last_step=last_part.last_step,
                                            n_frames=len(last_part),
                                            target_step=self.n_steps,
                                                                )
            # build slices: all earlier parts fully, last part up to (and
            # including) last_frame_idx
            slices = [(0, None, 1) for _ in range(last_part_idx)]
            slices += [(0, last_frame_idx + 1, 1)]

        # and concatenate
        full_traj = await TrajectoryConcatenator().concatenate_async(
                                        trajs=trajs[:last_part_idx + 1],
                                        slices=slices,
                                        # take the structure file of the traj, as it
                                        # comes from the engine directly
                                        tra_out=tra_out, struct_out=None,
                                        overwrite=overwrite,
                                                                    )
        return full_traj
|
513
|
+
|
514
|
+
|
515
|
+
class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
|
516
|
+
"""
|
517
|
+
Propagate a trajectory until any of the given conditions is fullfilled.
|
518
|
+
|
519
|
+
This class propagates the trajectory using a given MD engine (class) in
|
520
|
+
small chunks (chunksize is determined by walltime_per_part) and checks
|
521
|
+
after every chunk is done if any condition has been fullfilled.
|
522
|
+
It then returns a list of trajectory parts and the index of the condition
|
523
|
+
first fullfilled. It can also concatenate the parts into one trajectory,
|
524
|
+
which then starts with the starting configuration and ends with the frame
|
525
|
+
fullfilling the condition.
|
526
|
+
|
527
|
+
Attributes
|
528
|
+
----------
|
529
|
+
conditions : list[callable]
|
530
|
+
List of (wrapped) condition functions.
|
531
|
+
|
532
|
+
Notes
|
533
|
+
-----
|
534
|
+
We assume that every condition function returns a list/ a 1d array with
|
535
|
+
True or False for each frame, i.e. if we fullfill condition at any given
|
536
|
+
frame.
|
537
|
+
We assume non-overlapping conditions, i.e. a configuration can not fullfill
|
538
|
+
two conditions at the same time, **it is the users responsibility to ensure
|
539
|
+
that their conditions are sane**.
|
540
|
+
"""
|
541
|
+
|
542
|
+
# NOTE: we assume that every condition function returns a list/ a 1d array
|
543
|
+
# with True/False for each frame, i.e. if we fullfill condition at
|
544
|
+
# any given frame
|
545
|
+
# NOTE: we assume non-overlapping conditions, i.e. a configuration can not
|
546
|
+
# fullfill two conditions at the same time, it is the users
|
547
|
+
# responsibility to ensure that their conditions are sane
|
548
|
+
|
549
|
+
def __init__(self, conditions, engine_cls,
             engine_kwargs: dict,
             walltime_per_part: float,
             max_steps: typing.Optional[int] = None,
             max_frames: typing.Optional[int] = None,
             ):
    """
    Set up a `ConditionalTrajectoryPropagator`.

    Parameters
    ----------
    conditions : list[callable], usually list[TrajectoryFunctionWrapper]
        Condition functions. Each takes a :class:`asyncmd.Trajectory` and
        returns one True/False value per frame. Coroutine functions (e.g.
        wrapped functions) are recommended, plain callables are accepted.
    engine_cls : :class:`asyncmd.mdengine.MDEngine`
        The MD engine class itself, **not** an instance.
    engine_kwargs : dict
        Keyword arguments used to instantiate ``engine_cls``, must contain
        the key "mdconfig".
    walltime_per_part : float
        Walltime of every trajectory segment, in hours.
    max_steps : int, optional
        Give up after this many integration steps if no condition has been
        fulfilled, by default None. Wins over ``max_frames`` if both given.
    max_frames : int, optional
        Give up after this many trajectory frames if no condition has been
        fulfilled, by default None.

    Notes
    -----
    ``max_frames`` is translated into steps using the trajectory output
    frequency from the mdconfig (``max_steps = max_frames * nstout``). If
    neither limit is given, ``max_steps`` is ``np.inf``.
    """
    # placeholders, the `conditions` property setter fills both of them
    self._conditions = None
    self._condition_func_is_coroutine = None
    self.conditions = conditions
    self.engine_cls = engine_cls
    self.engine_kwargs = engine_kwargs
    self.walltime_per_part = walltime_per_part
    # output trajectory type: explicit engine kwarg or engine class default
    try:
        traj_type = engine_kwargs["output_traj_type"]
    except KeyError:
        traj_type = engine_cls.output_traj_type
    nstout = nstout_from_mdconfig(mdconfig=engine_kwargs["mdconfig"],
                                  output_traj_type=traj_type)
    # decide on the step limit, max_steps wins over max_frames
    if max_steps is not None:
        if max_frames is not None:
            logger.warning("Both max_steps and max_frames given. Note that "
                           "max_steps will take precedence.")
        self.max_steps = max_steps
    elif max_frames is not None:
        self.max_steps = max_frames * nstout
    else:
        logger.info("Neither max_frames nor max_steps given. "
                    "Setting max_steps to infinity.")
        # np.inf is a float, but compares fine against integer step counts
        self.max_steps = np.inf
|
612
|
+
|
613
|
+
# TODO/FIXME: self._conditions is a list...that means users can change
#             single elements without using the setter (which would leave
#             `_condition_func_is_coroutine` out of sync)!
#             we could use a list subclass as for the MDconfig?!
@property
def conditions(self):
    """List of (wrapped) condition functions."""
    return self._conditions

@conditions.setter
def conditions(self, conditions):
    # materialize once: an iterator/generator would otherwise be consumed by
    # the coroutine check below and then be stored exhausted
    conditions = list(conditions)
    # use inspect.iscoroutinefunction to check the conditions (also checks
    # `__call__` so that callable objects like wrapped functions are found)
    self._condition_func_is_coroutine = [
                (inspect.iscoroutinefunction(c)
                 or inspect.iscoroutinefunction(c.__call__))
                for c in conditions
                                         ]
    if not all(self._condition_func_is_coroutine):
        # and warn if it is not a coroutinefunction
        logger.warning(
                "It is recommended to use coroutinefunctions for all "
                + "conditions. This can easily be achieved by wrapping any"
                + " function in a TrajectoryFunctionWrapper. All "
                + "non-coroutine condition functions will be blocking when"
                + " applied! ([c is coroutine for c in conditions] = %s)",
                self._condition_func_is_coroutine
                       )
    self._conditions = conditions
|
639
|
+
|
640
|
+
async def propagate_and_concatenate(self,
                                    starting_configuration: Trajectory,
                                    workdir: str,
                                    deffnm: str,
                                    tra_out: str,
                                    overwrite: bool = False,
                                    continuation: bool = False
                                    ) -> tuple[Trajectory, int]:
    """
    Run :meth:`propagate` and feed its segments to :meth:`cut_and_concatenate`.

    Parameters
    ----------
    starting_configuration : Trajectory
        Configuration (including momenta) the MD starts from.
    workdir : str
        Absolute or relative path of the working directory.
    deffnm : str
        MD engine deffnm used for trajectory parts and other files.
    tra_out : str
        Absolute or relative path of the concatenated output trajectory.
    overwrite : bool, optional
        Overwrite the output trajectory if it already exists, by default False.
    continuation : bool, optional
        Continue a previous MD run (same deffnm and working directory),
        by default False.

    Returns
    -------
    (traj_out, idx_of_condition_fullfilled) : (Trajectory, int)
        The concatenated trajectory from the starting configuration up to the
        first frame for which a condition is True, and the index of that
        condition function in `conditions`.

    Raises
    ------
    MaxStepsReachedError
        If :meth:`propagate` hits the maximum number of integration
        steps/trajectory frames.
    """
    # handy for committor simulations; for e.g. TPS it is better to directly
    # concatenate both directions into a full TP if possible
    segments, _ = await self.propagate(
        starting_configuration=starting_configuration,
        workdir=workdir,
        deffnm=deffnm,
        continuation=continuation,
    )
    # the conditions are evaluated a second time inside cut_and_concatenate,
    # wrapped condition functions are expected to serve them from cache
    full_traj, cond_idx = await self.cut_and_concatenate(
        trajs=segments,
        tra_out=tra_out,
        overwrite=overwrite,
    )
    return full_traj, cond_idx
|
699
|
+
|
700
|
+
async def propagate(self,
                    starting_configuration: Trajectory,
                    workdir: str,
                    deffnm: str,
                    continuation: bool = False,
                    ) -> tuple[list[Trajectory], int]:
    """
    Propagate the trajectory until any condition is fullfilled.

    Return a list of trajectory segments and the first condition fullfilled.

    Parameters
    ----------
    starting_configuration : Trajectory
        The configuration (including momenta) to start MD from.
    workdir : str
        Absolute or relative path to the working directory.
    deffnm : str
        MD engine deffnm for trajectory parts and other files.
    continuation : bool, optional
        Whether we are continuing a previous MD run (with the same deffnm
        and working directory), by default False.

    Returns
    -------
    (traj_segments, idx_of_condition_fullfilled) : (list[Trajectory], int)
        List of trajectory (segments), the last entry is the one on which
        the first condition is fullfilled at some frame, the integer is the
        index to the condition function in `conditions`.
        If the starting configuration already fullfills a condition, the
        list contains only the starting configuration.

    Raises
    ------
    MaxStepsReachedError
        When the defined maximum number of integration steps/trajectory
        frames has been reached.
    """
    def first_fullfilled(vals) -> int:
        # `vals` holds one boolean array per condition function, so for
        # np.where rows are conditions and columns are frames.
        # Return the index of the condition that is True at the lowest frame.
        # NOTE/FIXME: if two conditions become True at the same (lowest)
        #             frame, np.where's row-major ordering means argmin picks
        #             the condition with the lower index only
        conds_fullfilled, frame_nums = np.where(vals)
        return int(conds_fullfilled[np.argmin(frame_nums)])

    # NOTE: currently this just returns a list of trajs + the condition
    #       fullfilled; this feels a bit uncomfortable but avoids that we
    #       concatenate everything a quadrillion times when we use the results
    # check first if the start configuration is fullfilling any condition
    cond_vals = await self._condition_vals_for_traj(starting_configuration)
    if np.any(cond_vals):
        first_condition_fullfilled = first_fullfilled(cond_vals)
        logger.error(f"Starting configuration ({starting_configuration}) "
                     + "is already fullfilling the condition with idx"
                     + f" {first_condition_fullfilled}.")
        # we just return the starting configuration/trajectory
        trajs = [starting_configuration]
        return trajs, first_condition_fullfilled

    # starting configuration does not fullfill any condition, lets do MD
    engine = self.engine_cls(**self.engine_kwargs)
    if continuation:
        # continuation: get all traj parts already done and continue from
        # there, i.e. append to the last traj part found
        # NOTE: we assume that the condition functions could be different
        #       so get all traj parts and calculate the condition funcs on them
        trajs = await get_all_traj_parts(folder=workdir, deffnm=deffnm,
                                         engine=engine,
                                         )
        if len(trajs) > 0:
            # can only calc condition values if we have trajectories produced
            cond_vals = await asyncio.gather(
                *(self._condition_vals_for_traj(t) for t in trajs)
            )
            # one (n_conditions, n_frames_total) array over all parts
            cond_vals = np.concatenate([np.asarray(s) for s in cond_vals],
                                       axis=1)
            # see if we already fullfill a condition on the existing parts
            if np.any(cond_vals):
                # already fullfill a condition, get out of here!
                return trajs, first_fullfilled(cond_vals)
            # Did not fullfill any condition yet, so prepare the engine to
            # continue the simulation until we reach any of the (new) conds
            await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
            step_counter = engine.steps_done
        else:
            # no trajectories, so we should prepare engine from scratch
            continuation = False
            logger.error("continuation=True, but we found no previous "
                         "trajectories. Setting continuation=False and "
                         "preparing the engine from scratch.")
    if not continuation:
        # no continuation, just prepare the engine from scratch
        await engine.prepare(
                    starting_configuration=starting_configuration,
                    workdir=workdir,
                    deffnm=deffnm,
                    )
        trajs = []
        step_counter = 0

    any_cond_fullfilled = False
    # Run MD in walltime-limited parts until a condition is fullfilled or
    # the maximum number of steps has been reached. Using `<` (not `<=`)
    # makes sure we do not start yet another part once max_steps is reached.
    while (not any_cond_fullfilled) and (step_counter < self.max_steps):
        traj = await engine.run_walltime(self.walltime_per_part)
        cond_vals = await self._condition_vals_for_traj(traj)
        any_cond_fullfilled = np.any(cond_vals)
        step_counter = engine.steps_done
        trajs.append(traj)
    if not any_cond_fullfilled:
        # left while loop because max_steps has been reached
        raise MaxStepsReachedError(
            f"Engine produced {step_counter} steps (>= {self.max_steps})."
        )
    # cond_vals are the ones for the last traj part
    return trajs, first_fullfilled(cond_vals)
|
823
|
+
|
824
|
+
async def cut_and_concatenate(self,
                              trajs: list[Trajectory],
                              tra_out: str,
                              overwrite: bool = False,
                              ) -> tuple[Trajectory, int]:
    """
    Cut and concatenate the trajectory until the first condition is True.

    Take a list of trajectory segments and form one continuous trajectory
    until the first frame that fullfills any condition. The expected input
    is a list of trajectories, e.g. the output of the :meth:`propagate`
    method.

    Parameters
    ----------
    trajs : list[Trajectory]
        Trajectory segments to cut and concatenate.
    tra_out : str
        Absolute or relative path for the concatenated output trajectory.
    overwrite : bool, optional
        Whether the output trajectory should be overwritten if it exists,
        by default False.

    Returns
    -------
    (traj_out, idx_of_condition_fullfilled) : (Trajectory, int)
        The concatenated output trajectory from starting configuration
        until the first condition is True and the index to the condition
        function in `conditions`.

    Raises
    ------
    ValueError
        If `trajs` is empty or if no condition is fullfilled on any frame
        of the given trajectory segments.
    """
    # NOTE: we assume that frame0 of traj0 does not fullfill any condition
    #       and return only the subtrajectory from frame0 until any
    #       condition is first True (the rest is ignored)
    # tra_out and overwrite are passed directly to the Concatenator
    if len(trajs) == 0:
        raise ValueError("Need at least one trajectory segment to cut and "
                         "concatenate.")
    # get all func values, cond_vals is a list (trajs) of lists (conditions)
    cond_vals = await asyncio.gather(
        *(self._condition_vals_for_traj(t) for t in trajs)
    )
    # take condition 0 (always present) to get the traj part lengths
    part_lens = [len(c[0]) for c in cond_vals]  # c[0] is 1d (np)array
    # put everything into one big (n_conditions, n_frames_total) array
    cond_vals = np.concatenate([np.asarray(c) for c in cond_vals],
                               axis=1)
    conds_fullfilled, frame_nums = np.where(cond_vals)
    if len(frame_nums) == 0:
        raise ValueError("No condition is fullfilled on any frame of the "
                         "given trajectory segments.")
    # gets the frame with the lowest idx where any condition is True
    min_idx = np.argmin(frame_nums)
    first_condition_fullfilled = int(conds_fullfilled[min_idx])
    first_frame_in_cond = int(frame_nums[min_idx])
    # find out in which part it is: part i covers the (global) frame indices
    # [part_ends[i] - part_lens[i], part_ends[i]), so the part containing
    # first_frame_in_cond is the first one whose end lies beyond it
    part_ends = np.cumsum(part_lens)
    last_part_idx = int(np.searchsorted(part_ends, first_frame_in_cond,
                                        side="right"))
    # the first frame in cond, counted from start of the last part (>= 0)
    frame_in_last_part = first_frame_in_cond - sum(part_lens[:last_part_idx])
    # all parts before the last one are taken fully, the last part until
    # (and including) the first frame in which any condition is True
    slices = [(0, None, 1) for _ in range(last_part_idx)]
    slices.append((0, frame_in_last_part + 1, 1))
    # we fill in all args as kwargs because there are so many
    full_traj = await TrajectoryConcatenator().concatenate_async(
                                        trajs=trajs[:last_part_idx + 1],
                                        slices=slices,
                                        # take the structure file of the traj,
                                        # as it comes from the engine directly
                                        tra_out=tra_out, struct_out=None,
                                        overwrite=overwrite,
                                                                 )
    return full_traj, first_condition_fullfilled
|
900
|
+
|
901
|
+
async def _condition_vals_for_traj(self, traj):
|
902
|
+
# return a list of condition_func results,
|
903
|
+
# one for each condition func in conditions
|
904
|
+
if all(self._condition_func_is_coroutine):
|
905
|
+
# easy, all coroutines
|
906
|
+
return await asyncio.gather(*(c(traj) for c in self.conditions))
|
907
|
+
elif not any(self._condition_func_is_coroutine):
|
908
|
+
# also easy (but blocking), none is coroutine
|
909
|
+
return [c(traj) for c in self.conditions]
|
910
|
+
else:
|
911
|
+
# need to piece it together
|
912
|
+
# first the coroutines concurrently
|
913
|
+
coros = [c(traj) for c, c_is_coro
|
914
|
+
in zip(self.conditions, self._condition_func_is_coroutine)
|
915
|
+
if c_is_coro
|
916
|
+
]
|
917
|
+
coro_res = await asyncio.gather(*coros)
|
918
|
+
# now either take the result from coro execution or calculate it
|
919
|
+
all_results = []
|
920
|
+
for c, c_is_coro in zip(self.conditions,
|
921
|
+
self._condition_func_is_coroutine):
|
922
|
+
if c_is_coro:
|
923
|
+
all_results.append(coro_res.pop(0))
|
924
|
+
else:
|
925
|
+
all_results.append(c(traj))
|
926
|
+
return all_results
|
927
|
+
# NOTE: this would be elegant, but to_thread() is py v>=3.9
|
928
|
+
# we wrap the non-coroutines into tasks to schedule all together
|
929
|
+
#all_conditions_as_coro = [
|
930
|
+
# c(traj) if c_is_cor else asyncio.to_thread(c, traj)
|
931
|
+
# for c, c_is_cor in zip(self.conditions, self._condition_func_is_coroutine)
|
932
|
+
# ]
|
933
|
+
#return await asyncio.gather(*all_conditions_as_coro)
|
934
|
+
|
935
|
+
|
936
|
+
# alias for people coming from the path sampling community :)
# This is the very same class object as ConditionalTrajectoryPropagator,
# only available under a second name.
TrajectoryPropagatorUntilAnyState = ConditionalTrajectoryPropagator
|