asyncmd 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,25 +12,34 @@
  #
  # You should have received a copy of the GNU General Public License
  # along with asyncmd. If not, see <https://www.gnu.org/licenses/>.
- import os
+ """
+ This module contains classes for propagation of MD in segments and/or until a condition is reached.
+
+ Most notable are the InPartsTrajectoryPropagator and the ConditionalTrajectoryPropagator.
+ Also of interest especially in the context of path sampling could be the function
+ construct_tp_from_plus_and_minus_traj_segments, which can be used directly on the
+ output of a ConditionalTrajectoryPropagator to generate trajectories connecting two
+ fulfilled conditions.
+ """
  import asyncio
- import aiofiles
- import aiofiles.os
+ import collections.abc
+ import copy
  import inspect
  import logging
+ import os
  import typing
+
+ import aiofiles
+ import aiofiles.os
  import numpy as np

- from .trajectory import Trajectory
- from .functionwrapper import TrajectoryFunctionWrapper
+ from ..mdengine import MDEngine
+ from ..tools import FlagChangeList, remove_file_if_exist_async
+ from ..utils import (get_all_file_parts, get_all_traj_parts,
+ nstout_from_mdconfig)
  from .convert import TrajectoryConcatenator
- from ..utils import (get_all_traj_parts,
- get_all_file_parts,
- nstout_from_mdconfig,
- )
- from ..tools import (remove_file_if_exist,
- remove_file_if_exist_async,
- )
+ from .functionwrapper import TrajectoryFunctionWrapper
+ from .trajectory import Trajectory


  logger = logging.getLogger(__name__)
@@ -41,22 +50,20 @@ class MaxStepsReachedError(Exception):
  Error raised when the simulation terminated because the (user-defined)
  maximum number of integration steps/trajectory frames has been reached.
  """
- pass


- # TODO: move to trajectory.convert?
- async def construct_TP_from_plus_and_minus_traj_segments(
+ async def construct_tp_from_plus_and_minus_traj_segments(
+ *,
  minus_trajs: "list[Trajectory]",
  minus_state: int,
  plus_trajs: "list[Trajectory]",
  plus_state: int,
  state_funcs: "list[TrajectoryFunctionWrapper]",
  tra_out: str,
- struct_out: typing.Optional[str] = None,
- overwrite: bool = False
- ) -> Trajectory:
+ **concatenate_kwargs,
+ ) -> Trajectory:
  """
- Construct a continous TP from plus and minus segments until states.
+ Construct a continuous TP from plus and minus segments until states.

  This is used e.g. in TwoWay TPS or if you try to get TPs out of a committor
  simulation. Note, that this inverts all velocities on the minus segments.
@@ -76,11 +83,15 @@ async def construct_TP_from_plus_and_minus_traj_segments(
  the minus and plus state indices!
  tra_out : str
  Absolute or relative path to the output trajectory file.
- struct_out : str, optional
- Absolute or relative path to the output structure file, if None we will
- use the structure file of the first minus_traj, by default None.
- overwrite : bool, optional
- Whether we should overwrite tra_out if it exists, by default False.
+ concatenate_kwargs : dict
+ All (other) keyword arguments will be passed as is to the
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
+
+ - struct_out : str, optional
+ Absolute or relative path to the output structure file, if None we will
+ use the structure file of the first minus_traj, by default None.
+ - overwrite : bool, optional
+ Whether we should overwrite tra_out if it exists, by default False.

  Returns
  -------
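
The two hunks above rename ``construct_TP_from_plus_and_minus_traj_segments`` to ``construct_tp_from_plus_and_minus_traj_segments``, make every argument keyword-only, and route ``struct_out``/``overwrite`` through ``**concatenate_kwargs`` instead of accepting them as named parameters. A minimal call-site sketch of the 0.4.0 signature follows; the import path and all variable values are illustrative assumptions, not taken from this diff.

    # Hypothetical call site (import path assumed from the relative imports above).
    from asyncmd.trajectory.propagate import construct_tp_from_plus_and_minus_traj_segments

    async def build_tp(minus_trajs, plus_trajs, state_funcs):
        # All arguments are keyword-only in 0.4.0; struct_out and overwrite
        # are simply forwarded to TrajectoryConcatenator.concatenate().
        return await construct_tp_from_plus_and_minus_traj_segments(
            minus_trajs=minus_trajs, minus_state=0,
            plus_trajs=plus_trajs, plus_state=1,
            state_funcs=state_funcs,
            tra_out="transition_path.trr",  # illustrative output name
            struct_out=None,                # passed on via **concatenate_kwargs
            overwrite=False,                # passed on via **concatenate_kwargs
        )
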
@@ -168,42 +179,65 @@ async def construct_TP_from_plus_and_minus_traj_segments(
  trajs=trajs,
  slices=slices,
  tra_out=tra_out,
- struct_out=struct_out,
- overwrite=overwrite,
+ **concatenate_kwargs
  )
  return path_traj


+ # pylint: disable-next=too-few-public-methods
  class _TrajectoryPropagator:
  # (private) superclass for InPartsTrajectoryPropagator and
  # ConditionalTrajectoryPropagator,
- # here we keep the common functions shared between them
- async def remove_parts(self, workdir: str, deffnm: str,
- file_endings_to_remove: list[str] = ["trajectories",
- "log"],
- remove_mda_offset_and_lock_files: bool = True,
- remove_asyncmd_npz_caches: bool = True,
+ # here we keep the common functions shared between them, currently
+ # this is only the file removal method and a bit of the init logic.
+ def __init__(self, *,
+ engine_cls: type[MDEngine], engine_kwargs: dict[str, typing.Any],
+ walltime_per_part: float,
+ ) -> None:
+ self.engine_cls = engine_cls
+ self.engine_kwargs = engine_kwargs
+ self.walltime_per_part = walltime_per_part
+
+ async def remove_parts(self, workdir: str, deffnm: str, *,
+ file_endings_to_remove: list[str] | None = None,
+ remove_mda_offset_and_lock_files: bool = True,
+ remove_asyncmd_npz_caches: bool = True,
  ) -> None:
  """
- Remove all `$deffnm.part$num.$file_ending` files for given file_endings
+ Remove all ``$deffnm.part$num.$file_ending`` files for file_endings.

- Can be useful to clean the `workdir` from temporary files if e.g. only
- the concatenate trajectory is of interesst (like in TPS).
+ Can be useful to clean the ``workdir`` from temporary files if e.g. only
+ the concatenate trajectory is of interest (like in TPS).

  Parameters
  ----------
  workdir : str
  The directory to clean.
  deffnm : str
- The `deffnm` that the files we clean must have.
- file_endings_to_remove : list[str], optional
- The strings in the list `file_endings_to_remove` indicate which
+ The ``deffnm`` that the files we clean must have.
+ file_endings_to_remove : list[str] | None, optional
+ The strings in the list ``file_endings_to_remove`` indicate which
  file endings to remove.
- E.g. `file_endings_to_remove=["trajectories", "log"]` will result
- in removal of trajectory parts and the log files. If you add "edr"
- to the list we would also remove the edr files,
- by default ["trajectories", "log"]
+ The 'special' string "trajectories" will be translated to the file
+ ending of the trajectories the engine produces, i.e. ``engine.output_traj_type``.
+ E.g. passing ``file_endings_to_remove=["trajectories", "log"]`` will
+ result in removal of trajectory parts and the log files. If you add
+ "edr" to the list we would also remove the edr files.
+ By default, i.e., if None the list will be ["trajectories", "log"].
+ remove_mda_offset_and_lock_files : bool, optional
+ Whether to remove any (hidden) offset and lock files generated by
+ MDAnalysis associated with the removed trajectory files (if they exist).
+ By default True.
+ remove_asyncmd_npz_caches : bool, optional
+ Whether to remove any (hidden) npz cache files generated by asyncmd
+ associated with the removed trajectory files (if they exist).
+ By default True.
  """
+ if file_endings_to_remove is None:
+ file_endings_to_remove = ["trajectories", "log"]
+ else:
+ # copy the list so we dont mutate what we got passed
+ file_endings_to_remove = copy.copy(file_endings_to_remove)
  if "trajectories" in file_endings_to_remove:
  # replace "trajectories" with the actual output traj type
  try:
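
The hunk above replaces the mutable default ``file_endings_to_remove=["trajectories", "log"]`` with a ``None`` sentinel and copies any list the caller passes in. A stand-alone sketch of that default-handling pattern (generic illustration, not asyncmd code):

    import copy

    def resolve_file_endings(file_endings_to_remove=None):
        # A None sentinel avoids the shared-mutable-default pitfall ...
        if file_endings_to_remove is None:
            file_endings_to_remove = ["trajectories", "log"]
        else:
            # ... and copying keeps the caller's list untouched.
            file_endings_to_remove = copy.copy(file_endings_to_remove)
        return file_endings_to_remove
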
@@ -222,7 +256,7 @@ class _TrajectoryPropagator:
  )
  # make sure we dont miss anything because we have different
  # capitalization
- if len(parts_to_remove) == 0:
+ if not parts_to_remove:
  parts_to_remove = await get_all_file_parts(
  folder=workdir,
  deffnm=deffnm,
@@ -232,7 +266,7 @@ class _TrajectoryPropagator:
  for f in parts_to_remove
  )
  )
- # TODO: the note below?
+ # TODO: address the note below?
  # NOTE: this is a bit hacky: we just try to remove the offset and
  # lock files for every file we remove (since we do not know
  # if the file we remove is a trajectory [and therefore
@@ -286,8 +320,9 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  Useful to make full use of backfilling with short(ish) simulation jobs and
  also to run simulations that are longer than the timelimit.
  """
- def __init__(self, n_steps: int, engine_cls,
- engine_kwargs: dict, walltime_per_part: float,
+ def __init__(self, n_steps: int, *,
+ engine_cls: type[MDEngine], engine_kwargs: dict[str, typing.Any],
+ walltime_per_part: float,
  ) -> None:
  """
  Initialize an `InPartTrajectoryPropagator`.
@@ -303,19 +338,18 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  walltime_per_part : float
  Walltime per trajectory segment, in hours.
  """
+ super().__init__(engine_cls=engine_cls, engine_kwargs=engine_kwargs,
+ walltime_per_part=walltime_per_part)
  self.n_steps = n_steps
- self.engine_cls = engine_cls
- self.engine_kwargs = engine_kwargs
- self.walltime_per_part = walltime_per_part

  async def propagate_and_concatenate(self,
  starting_configuration: Trajectory,
  workdir: str,
- deffnm: str,
+ deffnm: str, *,
  tra_out: str,
- overwrite: bool = False,
- continuation: bool = False
- ) -> tuple[Trajectory, int]:
+ continuation: bool = False,
+ **concatenate_kwargs,
+ ) -> Trajectory | None:
  """
  Chain :meth:`propagate` and :meth:`cut_and_concatenate` methods.

@@ -329,12 +363,18 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  MD engine deffnm for trajectory parts and other files.
  tra_out : str
  Absolute or relative path for the concatenated output trajectory.
- overwrite : bool, optional
- Whether the output trajectory should be overwritten if it exists,
- by default False.
  continuation : bool, optional
  Whether we are continuing a previous MD run (with the same deffnm
  and working directory), by default False.
+ concatenate_kwargs : dict
+ All (other) keyword arguments will be passed as is to the
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
+
+ - struct_out : str, optional
+ Absolute or relative path to the output structure file, if None
+ we will use the structure file of the first traj, by default None.
+ - overwrite : bool, optional
+ Whether we should overwrite tra_out if it exists, by default False.

  Returns
  -------
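
The constructor and ``propagate_and_concatenate`` hunks above make everything after ``n_steps`` (respectively ``deffnm``) keyword-only and move ``overwrite``/``struct_out`` into ``**concatenate_kwargs``. A sketch of what a 0.4.0 call site could look like; the GmxEngine/MDP setup, file names, and numbers are assumptions for illustration only, and any MDEngine subclass with matching kwargs would do.

    from asyncmd.gromacs import GmxEngine, MDP          # assumed engine setup
    from asyncmd.trajectory.propagate import InPartsTrajectoryPropagator

    propagator = InPartsTrajectoryPropagator(
        n_steps=1_000_000,                     # only remaining positional argument
        engine_cls=GmxEngine,                  # keyword-only as of 0.4.0
        engine_kwargs={"mdconfig": MDP("md.mdp")},   # illustrative kwargs
        walltime_per_part=0.25,                # hours per segment
    )
    # overwrite (and struct_out) now travel via **concatenate_kwargs:
    # full_traj = await propagator.propagate_and_concatenate(
    #     starting_configuration=start_conf, workdir=".", deffnm="md",
    #     tra_out="full_traj.trr", overwrite=True,
    # )
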
@@ -348,23 +388,21 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  deffnm=deffnm,
  continuation=continuation
  )
- full_traj = await self.cut_and_concatenate(
- trajs=trajs,
- tra_out=tra_out,
- overwrite=overwrite,
+ full_traj = await self.cut_and_concatenate(trajs=trajs, tra_out=tra_out,
+ **concatenate_kwargs,
  )
  return full_traj

  async def propagate(self,
  starting_configuration: Trajectory,
  workdir: str,
- deffnm: str,
+ deffnm: str, *,
  continuation: bool = False,
  ) -> list[Trajectory]:
  """
  Propagate the trajectory until self.n_steps integration are done.

- Return a list of trajecory segments and the first condition fullfilled.
+ Return a list of trajectory segments and the first condition fulfilled.

  Parameters
  ----------
@@ -385,7 +423,7 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  Returns
  -------
  traj_segments : list[Trajectory]
- List of trajectory (segements), ordered in time.
+ List of trajectory (segments), ordered in time.
  """
  engine = self.engine_cls(**self.engine_kwargs)
  if continuation:
@@ -396,11 +434,10 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  )
  if len(trajs) > 0:
  # can only continue if we find the previous trajs
- step_counter = engine.steps_done
- if step_counter >= self.n_steps:
+ await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
+ if (step_counter := engine.steps_done) >= self.n_steps:
  # already longer than what we want to do, bail out
  return trajs
- await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
  else:
  # no previous trajs, prepare engine from scratch
  continuation = False
@@ -417,11 +454,10 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  trajs = []
  step_counter = 0

- while (step_counter < self.n_steps):
- traj = await engine.run(nsteps=self.n_steps,
- walltime=self.walltime_per_part,
- steps_per_part=False,
- )
+ while step_counter < self.n_steps:
+ traj = await engine.run_walltime(walltime=self.walltime_per_part,
+ max_steps=self.n_steps,
+ )
  step_counter = engine.steps_done
  trajs.append(traj)
  return trajs
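
The loop above now drives the engine through ``run_walltime(walltime=..., max_steps=...)`` so a segment stops at the walltime limit but never overshoots ``n_steps`` in total. A runnable toy mock of that control flow (MockEngine and its step sizes are invented for illustration; this is not asyncmd code):

    import asyncio

    class MockEngine:
        def __init__(self):
            self.steps_done = 0

        async def run_walltime(self, walltime, max_steps):
            # pretend each walltime chunk advances 400 steps, capped at max_steps
            self.steps_done = min(self.steps_done + 400, max_steps)
            return f"segment ending at step {self.steps_done}"

    async def propagate(n_steps=1000, walltime_per_part=0.1):
        engine, trajs, step_counter = MockEngine(), [], 0
        while step_counter < n_steps:
            traj = await engine.run_walltime(walltime=walltime_per_part,
                                             max_steps=n_steps)
            step_counter = engine.steps_done
            trajs.append(traj)
        return trajs

    print(asyncio.run(propagate()))  # three segments: steps 400, 800, 1000
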
@@ -429,15 +465,16 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  async def cut_and_concatenate(self,
  trajs: list[Trajectory],
  tra_out: str,
- overwrite: bool = False,
- ) -> Trajectory:
+ **concatenate_kwargs,
+ ) -> Trajectory | None:
  """
  Cut and concatenate the trajectory until it has length n_steps.

- Take a list of trajectory segments and form one continous trajectory
+ Take a list of trajectory segments and form one continuous trajectory
  containing n_steps integration steps. The expected input
  is a list of trajectories, e.g. the output of the :meth:`propagate`
  method.
+ Returns None if ``self.n_steps`` is zero.

  Parameters
  ----------
@@ -445,9 +482,15 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  Trajectory segments to cut and concatenate.
  tra_out : str
  Absolute or relative path for the concatenated output trajectory.
- overwrite : bool, optional
- Whether the output trajectory should be overwritten if it exists,
- by default False.
+ concatenate_kwargs : dict
+ All (other) keyword arguments will be passed as is to the
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
+
+ - struct_out : str, optional
+ Absolute or relative path to the output structure file, if None
+ we will use the structure file of the first traj, by default None.
+ - overwrite : bool, optional
+ Whether we should overwrite tra_out if it exists, by default False.

  Returns
  -------
@@ -462,23 +505,18 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  """
  # trajs is a list of trajectories, e.g. the return of propagate
  # tra_out and overwrite are passed directly to the Concatenator
- if len(trajs) == 0:
- # no trajectories to concatenate, happens e.g. if self.n_steps=0
- # we return None (TODO: is this what we want?)
+ if not self.n_steps:
+ # no trajectories to concatenate, since self.n_steps=0
+ # we return None
  return None
  if self.n_steps > trajs[-1].last_step:
  # not enough steps in trajectories
- raise ValueError("The given trajectories are to short (< self.n_steps).")
- elif self.n_steps == trajs[-1].last_step:
+ raise ValueError("The given trajectories are too short (< self.n_steps).")
+ if self.n_steps == trajs[-1].last_step:
  # all good, we just take all trajectory parts fully
  slices = [(0, None, 1) for _ in range(len(trajs))]
  last_part_idx = len(trajs) - 1
  else:
- logger.warning("Trajectories do not exactly contain n_steps "
- "integration steps. Using a heuristic to find the "
- "correct last frame to include, note that this "
- "heuristic might fail if n_steps is not a multiple "
- "of the trajectory output frequency.")
  # need to find the subtrajectory that contains the correct number
  # of integration steps
  # first find the part in which we go over n_steps
@@ -490,76 +528,80 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
  last_part_len_steps = (trajs[last_part_idx].last_step
  - trajs[last_part_idx].first_step)
  steps_per_frame = last_part_len_steps / last_part_len_frames
- frames_in_last_part = 0
- while ((trajs[last_part_idx].first_step
- + frames_in_last_part * steps_per_frame) < self.n_steps):
- # I guess we stay with the < (instead of <=) and rather have
- # one frame too much?
- frames_in_last_part += 1
+ frames_in_last_part = ((self.n_steps
+ - trajs[last_part_idx].first_step
+ )
+ / steps_per_frame)
+ log_str = ("Trajectories do not exactly contain n_steps integration steps. "
+ "Using a heuristic to find the correct last frame to include."
+ )
+ if frames_in_last_part != (frames_in_last_part_int := round(frames_in_last_part)):
+ log_str += (" Note that this heuristic might fail because n_steps"
+ " is not a multiple of the trajectory output frequency."
+ )
+ logger.warning(log_str)
+ else:
+ logger.info(log_str)
  # build slices
  slices = [(0, None, 1) for _ in range(last_part_idx)]
- slices += [(0, frames_in_last_part + 1, 1)]
+ slices += [(0, frames_in_last_part_int + 1, 1)]

  # and concatenate
  full_traj = await TrajectoryConcatenator().concatenate_async(
  trajs=trajs[:last_part_idx + 1],
  slices=slices,
- # take the structure file of the traj, as it
- # comes from the engine directly
- tra_out=tra_out, struct_out=None,
- overwrite=overwrite,
- )
+ tra_out=tra_out,
+ **concatenate_kwargs
+ )
  return full_traj


  class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  """
- Propagate a trajectory until any of the given conditions is fullfilled.
+ Propagate a trajectory until any of the given conditions is fulfilled.

  This class propagates the trajectory using a given MD engine (class) in
  small chunks (chunksize is determined by walltime_per_part) and checks
- after every chunk is done if any condition has been fullfilled.
+ after every chunk is done if any condition has been fulfilled.
  It then returns a list of trajectory parts and the index of the condition
- first fullfilled. It can also concatenate the parts into one trajectory,
+ first fulfilled. It can also concatenate the parts into one trajectory,
  which then starts with the starting configuration and ends with the frame
- fullfilling the condition.
-
- Attributes
- ----------
- conditions : list[callable]
- List of (wrapped) condition functions.
+ fulfilling the condition.

  Notes
  -----
  We assume that every condition function returns a list/ a 1d array with
- True or False for each frame, i.e. if we fullfill condition at any given
+ True or False for each frame, i.e. if we fulfill condition at any given
  frame.
- We assume non-overlapping conditions, i.e. a configuration can not fullfill
+ We assume non-overlapping conditions, i.e. a configuration can not fulfill
  two conditions at the same time, **it is the users responsibility to ensure
  that their conditions are sane**.
  """

  # NOTE: we assume that every condition function returns a list/ a 1d array
- # with True/False for each frame, i.e. if we fullfill condition at
+ # with True/False for each frame, i.e. if we fulfill condition at
  # any given frame
  # NOTE: we assume non-overlapping conditions, i.e. a configuration can not
- # fullfill two conditions at the same time, it is the users
+ # fulfill two conditions at the same time, it is the users
  # responsibility to ensure that their conditions are sane

- def __init__(self, conditions, engine_cls,
- engine_kwargs: dict,
+ # Note: max_steps and max_frames are mutually exclusive and this is enforced,
+ # but pylint does not know that, so we tell it to not be mad for one arg more
+ # pylint: disable-next=too-many-arguments
+ def __init__(self, conditions, *,
+ engine_cls: type[MDEngine], engine_kwargs: dict[str, typing.Any],
  walltime_per_part: float,
- max_steps: typing.Optional[int] = None,
- max_frames: typing.Optional[int] = None,
+ max_steps: int | None = None,
+ max_frames: int | None = None,
  ):
  """
- Initialize a `ConditionalTrajectoryPropagator`.
+ Initialize a :class:`ConditionalTrajectoryPropagator`.

  Parameters
  ----------
  conditions : list[callable], usually list[TrajectoryFunctionWrapper]
  List of condition functions, usually wrapped function for
- asyncronous application, but can be any callable that takes a
+ asynchronous application, but can be any callable that takes a
  :class:`asyncmd.Trajectory` and returns an array of True and False
  values (one value per frame).
  engine_cls : :class:`asyncmd.mdengine.MDEngine`
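
The rewritten heuristic in the hunk above computes the number of frames to keep from the last segment arithmetically instead of counting up in a loop, and only warns when the result is not a whole number. A short worked example with illustrative numbers (not taken from the diff):

    # Last segment: starts at step 8000, ends at step 12000, holds 40 frames.
    n_steps = 10_000
    first_step, last_step, n_frames = 8_000, 12_000, 40
    steps_per_frame = (last_step - first_step) / n_frames            # 100.0
    frames_in_last_part = (n_steps - first_step) / steps_per_frame   # 20.0
    frames_in_last_part_int = round(frames_in_last_part)             # 20
    # 20.0 == 20, so the logger.info branch is taken and the last slice
    # becomes (0, 21, 1); with n_steps = 10_050 the division gives 20.5,
    # the warning branch fires and round() still yields 20.
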
@@ -571,7 +613,7 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  max_steps : int, optional
  Maximum number of integration steps to do before stopping the
  simulation because it did not commit to any condition,
- by default None. Takes precendence over max_frames if both given.
+ by default None. Takes precedence over max_frames if both given.
  max_frames : int, optional
  Maximum number of frames to produce before stopping the simulation
  because it did not commit to any condition, by default None.
@@ -582,12 +624,11 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  ``max_steps = max_frames * output_frequency``, if both are given
  max_steps takes precedence.
  """
- self._conditions = None
- self._condition_func_is_coroutine = None
+ super().__init__(engine_cls=engine_cls, engine_kwargs=engine_kwargs,
+ walltime_per_part=walltime_per_part)
+ self._conditions = FlagChangeList([])
+ self._condition_func_is_coroutine: list[bool] = []
  self.conditions = conditions
- self.engine_cls = engine_cls
- self.engine_kwargs = engine_kwargs
- self.walltime_per_part = walltime_per_part
  # find nstout
  try:
  traj_type = engine_kwargs["output_traj_type"]
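
Together with the earlier signature hunk, the constructor now stores its conditions in a FlagChangeList, delegates the engine bookkeeping to the shared superclass, and keeps ``max_steps``/``max_frames`` keyword-only. A construction sketch, reusing the hypothetical GmxEngine/MDP setup from the earlier sketch; the condition functions and all numbers are likewise assumptions for illustration:

    from asyncmd.trajectory.propagate import ConditionalTrajectoryPropagator

    propagator = ConditionalTrajectoryPropagator(
        conditions=[in_state_a, in_state_b],   # wrapped condition functions (assumed)
        engine_cls=GmxEngine,                  # keyword-only as of 0.4.0
        engine_kwargs={"mdconfig": MDP("md.mdp")},
        walltime_per_part=0.25,
        max_steps=50_000_000,   # or max_frames=...; max_steps wins if both are given
    )
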
@@ -599,51 +640,58 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  # sort out if we use max_frames or max_steps
  if max_frames is not None and max_steps is not None:
  logger.warning("Both max_steps and max_frames given. Note that "
- + "max_steps will take precedence.")
+ "max_steps will take precedence.")
  if max_steps is not None:
  self.max_steps = max_steps
  elif max_frames is not None:
  self.max_steps = max_frames * nstout
  else:
  logger.info("Neither max_frames nor max_steps given. "
- + "Setting max_steps to infinity.")
+ "Setting max_steps to infinity.")
  # this is a float but can be compared to ints
  self.max_steps = np.inf

- # TODO/FIXME: self._conditions is a list...that means users can change
- # single elements without using the setter!
- # we could use a list subclass as for the MDconfig?!
  @property
- def conditions(self):
+ def conditions(self) -> FlagChangeList:
+ """List of (wrapped) condition functions."""
  return self._conditions

  @conditions.setter
- def conditions(self, conditions):
- # use asyncio.iscorotinefunction to check the conditions
- self._condition_func_is_coroutine = [
+ def conditions(self, conditions: collections.abc.Sequence):
+ if len(conditions) < 1:
+ raise ValueError("Must supply at least one termination condition.")
+ self._condition_func_is_coroutine = self._check_condition_funcs(
+ conditions=conditions
+ )
+ self._conditions = FlagChangeList(conditions)
+
+ def _check_condition_funcs(self, conditions: collections.abc.Sequence,
+ ) -> list[bool]:
+ # use asyncio.iscoroutinefunction to check the conditions
+ condition_func_is_coroutine = [
  (inspect.iscoroutinefunction(c)
  or inspect.iscoroutinefunction(c.__call__))
  for c in conditions
- ]
- if not all(self._condition_func_is_coroutine):
- # and warn if it is not a corotinefunction
+ ]
+ if not all(condition_func_is_coroutine):
+ # and warn if it is not a coroutinefunction
  logger.warning(
  "It is recommended to use coroutinefunctions for all "
- + "conditions. This can easily be achieved by wrapping any"
- + " function in a TrajectoryFunctionWrapper. All "
- + "non-coroutine condition functions will be blocking when"
- + " applied! ([c is coroutine for c in conditions] = %s)",
- self._condition_func_is_coroutine
+ "conditions. This can easily be achieved by wrapping any "
+ "function in a TrajectoryFunctionWrapper. All "
+ "non-coroutine condition functions will be blocking when "
+ "applied! ([c is coroutine for c in conditions] = %s)",
+ condition_func_is_coroutine
  )
- self._conditions = conditions
+ return condition_func_is_coroutine

  async def propagate_and_concatenate(self,
  starting_configuration: Trajectory,
  workdir: str,
- deffnm: str,
+ deffnm: str, *,
  tra_out: str,
- overwrite: bool = False,
- continuation: bool = False
+ continuation: bool = False,
+ **concatenate_kwargs
  ) -> tuple[Trajectory, int]:
  """
  Chain :meth:`propagate` and :meth:`cut_and_concatenate` methods.
@@ -658,16 +706,22 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  MD engine deffnm for trajectory parts and other files.
  tra_out : str
  Absolute or relative path for the concatenated output trajectory.
- overwrite : bool, optional
- Whether the output trajectory should be overwritten if it exists,
- by default False.
  continuation : bool, optional
  Whether we are continuing a previous MD run (with the same deffnm
  and working directory), by default False.
+ concatenate_kwargs : dict
+ All (other) keyword arguments will be passed as is to the
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
+
+ - struct_out : str, optional
+ Absolute or relative path to the output structure file, if None
+ we will use the structure file of the first traj, by default None.
+ - overwrite : bool, optional
+ Whether we should overwrite tra_out if it exists, by default False.

  Returns
  -------
- (traj_out, idx_of_condition_fullfilled) : (Trajectory, int)
+ (traj_out, idx_of_condition_fulfilled) : (Trajectory, int)
  The concatenated output trajectory from starting configuration
  until the first condition is True and the index to the condition
  function in `conditions`.
@@ -679,34 +733,34 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  frames has been reached in :meth:`propagate`.
  """
  # this just chains propagate and cut_and_concatenate
- # usefull for committor simulations, for e.g. TPS one should try to
+ # useful for committor simulations, for e.g. TPS one should try to
  # directly concatenate both directions to a full TP if possible
- trajs, first_condition_fullfilled = await self.propagate(
+ trajs, first_condition_fulfilled = await self.propagate(
  starting_configuration=starting_configuration,
  workdir=workdir,
  deffnm=deffnm,
  continuation=continuation
- )
+ )
  # NOTE: it should not matter too much speedwise that we recalculate
  # the condition functions, they are expected to be wrapped funcs
  # i.e. the second time we should just get the values from cache
- full_traj, first_condition_fullfilled = await self.cut_and_concatenate(
+ full_traj, first_condition_fulfilled = await self.cut_and_concatenate(
  trajs=trajs,
  tra_out=tra_out,
- overwrite=overwrite,
- )
- return full_traj, first_condition_fullfilled
+ **concatenate_kwargs
+ )
+ return full_traj, first_condition_fulfilled

  async def propagate(self,
  starting_configuration: Trajectory,
  workdir: str,
- deffnm: str,
+ deffnm: str, *,
  continuation: bool = False,
  ) -> tuple[list[Trajectory], int]:
  """
- Propagate the trajectory until any condition is fullfilled.
+ Propagate the trajectory until any condition is fulfilled.

- Return a list of trajecory segments and the first condition fullfilled.
+ Return a list of trajectory segments and the first condition fulfilled.

  Parameters
  ----------
@@ -722,10 +776,10 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):

  Returns
  -------
- (traj_segments, idx_of_condition_fullfilled) : (list[Trajectory], int)
- List of trajectory (segements), the last entry is the one on which
- the first condition is fullfilled at some frame, the ineger is the
- index to the condition function in `conditions`.
+ (traj_segments, idx_of_condition_fulfilled) : (list[Trajectory], int)
+ List of trajectory (segments), the last entry is the one on which
+ the first condition is fulfilled at some frame, the integer is the
+ index to the condition function in ``conditions``.

  Raises
  ------
@@ -733,26 +787,29 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  When the defined maximum number of integration steps/trajectory
  frames has been reached.
  """
- # NOTE: curently this just returns a list of trajs + the condition
- # fullfilled
+ # NOTE: currently this just returns a list of trajs + the condition
+ # fulfilled
  # this feels a bit uncomfortable but avoids that we concatenate
  # everything a quadrillion times when we use the results
- # check first if the start configuration is fullfilling any condition
+ # check first if the start configuration is fulfilling any condition
  cond_vals = await self._condition_vals_for_traj(starting_configuration)
  if np.any(cond_vals):
- conds_fullfilled, frame_nums = np.where(cond_vals)
+ conds_fulfilled, frame_nums = np.where(cond_vals)
  # gets the frame with the lowest idx where any condition is True
  min_idx = np.argmin(frame_nums)
- first_condition_fullfilled = conds_fullfilled[min_idx]
- logger.error(f"Starting configuration ({starting_configuration}) "
- + "is already fullfilling the condition with idx"
- + f" {first_condition_fullfilled}.")
+ first_condition_fulfilled = conds_fulfilled[min_idx]
+ logger.error("Starting configuration (%s) is already fulfilling "
+ "the condition with idx %s.",
+ starting_configuration, first_condition_fulfilled,
+ )
  # we just return the starting configuration/trajectory
  trajs = [starting_configuration]
- return trajs, first_condition_fullfilled
+ return trajs, first_condition_fulfilled

- # starting configuration does not fullfill any condition, lets do MD
+ # starting configuration does not fulfill any condition, lets do MD
  engine = self.engine_cls(**self.engine_kwargs)
+ # Note: we first check for continuation because if we do not find a run
+ # to continue we fallback to no continuation
  if continuation:
  # continuation: get all traj parts already done and continue from
  # there, i.e. append to the last traj part found
@@ -761,32 +818,31 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  trajs = await get_all_traj_parts(folder=workdir, deffnm=deffnm,
  engine=engine,
  )
- if len(trajs) > 0:
- # can only calc CV values if we have trajectories prouced
+ if not trajs:
+ # no trajectories, so we should prepare engine from scratch
+ continuation = False
+ logger.error("continuation=True, but we found no previous "
+ "trajectories. Setting continuation=False and "
+ "preparing the engine from scratch.")
+ else:
+ # (can only) calc CV values if we have found trajectories
  cond_vals = await asyncio.gather(
  *(self._condition_vals_for_traj(t) for t in trajs)
  )
  cond_vals = np.concatenate([np.asarray(s) for s in cond_vals],
  axis=1)
- # see if we already fullfill a condition on the existing traj parts
- any_cond_fullfilled = np.any(cond_vals)
- if any_cond_fullfilled:
- conds_fullfilled, frame_nums = np.where(cond_vals)
+ # see if we already fulfill a condition on the existing traj parts
+ if (any_cond_fulfilled := np.any(cond_vals)):
+ conds_fulfilled, frame_nums = np.where(cond_vals)
  # gets the frame with the lowest idx where any cond is True
  min_idx = np.argmin(frame_nums)
- first_condition_fullfilled = conds_fullfilled[min_idx]
- # already fullfill a condition, get out of here!
- return trajs, first_condition_fullfilled
- # Did not fullfill any condition yet, so prepare the engine to
+ first_condition_fulfilled = conds_fulfilled[min_idx]
+ # already fulfill a condition, get out of here!
+ return trajs, first_condition_fulfilled
+ # Did not fulfill any condition yet, so prepare the engine to
  # continue the simulation until we reach any of the (new) conds
  await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
- step_counter = engine.steps_done
- else:
- # no trajectories, so we should prepare engine from scratch
- continuation = False
- logger.error("continuation=True, but we found no previous "
- "trajectories. Setting continuation=False and "
- "preparing the engine from scratch.")
+
  if not continuation:
  # no continuation, just prepare the engine from scratch
  await engine.prepare(
@@ -794,45 +850,46 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  workdir=workdir,
  deffnm=deffnm,
  )
- any_cond_fullfilled = False
  trajs = []
- step_counter = 0

- while ((not any_cond_fullfilled)
+ step_counter = engine.steps_done
+ any_cond_fulfilled = False
+ while ((not any_cond_fulfilled)
  and (step_counter <= self.max_steps)):
  traj = await engine.run_walltime(self.walltime_per_part)
  cond_vals = await self._condition_vals_for_traj(traj)
- any_cond_fullfilled = np.any(cond_vals)
+ any_cond_fulfilled = np.any(cond_vals)
  step_counter = engine.steps_done
  trajs.append(traj)
- if not any_cond_fullfilled:
+ if not any_cond_fulfilled:
  # left while loop because of max_frames reached
  raise MaxStepsReachedError(
  f"Engine produced {step_counter} steps (>= {self.max_steps})."
  )
  # cond_vals are the ones for the last traj
  # here we get which conditions are True and at which frame
- conds_fullfilled, frame_nums = np.where(cond_vals)
+ conds_fulfilled, frame_nums = np.where(cond_vals)
  # gets the frame with the lowest idx where any condition is True
  min_idx = np.argmin(frame_nums)
  # and now the idx to self.conditions for cond that was first fullfilled
  # NOTE/FIXME: if two conditions are reached simultaneously at min_idx,
  # this will find the condition with the lower idx only
- first_condition_fullfilled = conds_fullfilled[min_idx]
- return trajs, first_condition_fullfilled
+ first_condition_fulfilled = conds_fulfilled[min_idx]
+ return trajs, first_condition_fulfilled

  async def cut_and_concatenate(self,
  trajs: list[Trajectory],
  tra_out: str,
- overwrite: bool = False,
+ **concatenate_kwargs,
  ) -> tuple[Trajectory, int]:
  """
  Cut and concatenate the trajectory until the first condition is True.

- Take a list of trajectory segments and form one continous trajectory
- until the first frame that fullfills any condition. The expected input
- is a list of trajectories, e.g. the output of the :meth:`propagate`
- method.
+ Take a list of trajectory segments and form one continuous trajectory
+ until the first frame that fulfills any condition. The first frame in
+ that fulfills any condition is included in the trajectory.
+ The expected input is a list of trajectories, e.g. the output of the
+ :meth:`propagate` method.

  Parameters
  ----------
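
The propagation loop above raises MaxStepsReachedError when ``max_steps`` is exceeded without any condition becoming True. A hedged call-site sketch (the shoot() helper, its arguments, and the error-handling policy are assumptions, not part of asyncmd):

    from asyncmd.trajectory.propagate import MaxStepsReachedError

    async def shoot(propagator, start_conf, workdir, deffnm):
        try:
            trajs, cond_idx = await propagator.propagate(
                starting_configuration=start_conf,
                workdir=workdir,
                deffnm=deffnm,
                continuation=False,   # keyword-only after deffnm in 0.4.0
            )
        except MaxStepsReachedError:
            return None               # did not commit within max_steps
        return trajs, cond_idx
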
@@ -840,42 +897,42 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  Trajectory segments to cut and concatenate.
  tra_out : str
  Absolute or relative path for the concatenated output trajectory.
- overwrite : bool, optional
- Whether the output trajectory should be overwritten if it exists,
- by default False.
+ concatenate_kwargs : dict
+ All (other) keyword arguments will be passed as is to the
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
+
+ - struct_out : str, optional
+ Absolute or relative path to the output structure file, if None
+ we will use the structure file of the first traj, by default None.
+ - overwrite : bool, optional
+ Whether we should overwrite tra_out if it exists, by default False.

  Returns
  -------
- (traj_out, idx_of_condition_fullfilled) : (Trajectory, int)
+ (traj_out, idx_of_condition_fulfilled) : (Trajectory, int)
  The concatenated output trajectory from starting configuration
  until the first condition is True and the index to the condition
  function in `conditions`.
  """
- # trajs is a list of trajectories, e.g. the return of propagate
- # tra_out and overwrite are passed directly to the Concatenator
- # NOTE: we assume that frame0 of traj0 does not fullfill any condition
- # and return only the subtrajectory from frame0 until any
- # condition is first True (the rest is ignored)
- # get all func values and put them into one big array
+ # get all func values and put them into one big list
  cond_vals = await asyncio.gather(
  *(self._condition_vals_for_traj(t) for t in trajs)
  )
  # cond_vals is a list (trajs) of lists (conditions)
  # take condition 0 (always present) to get the traj part lengths
  part_lens = [len(c[0]) for c in cond_vals] # c[0] is 1d (np)array
- cond_vals = np.concatenate([np.asarray(c) for c in cond_vals],
- axis=1)
- conds_fullfilled, frame_nums = np.where(cond_vals)
- # gets the frame with the lowest idx where any condition is True
+ # get all occurrences where any condition is True
+ conds_fulfilled, frame_nums = np.where(
+ np.concatenate([np.asarray(c) for c in cond_vals], axis=1)
+ )
+ # get the index of the frame with the lowest number where any condition is True
  min_idx = np.argmin(frame_nums)
- first_condition_fullfilled = conds_fullfilled[min_idx]
+ first_condition_fulfilled = conds_fulfilled[min_idx]
  first_frame_in_cond = frame_nums[min_idx]
  # find out in which part it is
- last_part_idx = 0
- frame_sum = part_lens[last_part_idx]
- while first_frame_in_cond >= frame_sum:
- last_part_idx += 1
- frame_sum += part_lens[last_part_idx]
+ # nonzero always returns a tuple (first zero index below)
+ # and we only care for the first occurrence of True/nonzero (second zero index below)
+ last_part_idx = (np.cumsum(part_lens) >= first_frame_in_cond).nonzero()[0][0]
  # find the first frame in cond (counting from start of last part)
  _first_frame_in_cond = (first_frame_in_cond
  - sum(part_lens[:last_part_idx])) # >= 0
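
The while loop that located the segment containing the first fulfilling frame is replaced above by a vectorized cumulative-sum lookup. A short worked check with illustrative numbers (not taken from the diff):

    import numpy as np

    # three segments of 5 frames each; first fulfilling frame at global index 7
    part_lens = [5, 5, 5]
    first_frame_in_cond = 7
    last_part_idx = (np.cumsum(part_lens) >= first_frame_in_cond).nonzero()[0][0]
    # np.cumsum(part_lens) == [5, 10, 15] -> comparison gives [False, True, True],
    # so last_part_idx == 1, and within that segment the frame sits at 7 - 5 == 2.
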
@@ -891,46 +948,37 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
  full_traj = await TrajectoryConcatenator().concatenate_async(
  trajs=trajs[:last_part_idx + 1],
  slices=slices,
- # take the structure file of the traj, as it
- # comes from the engine directly
- tra_out=tra_out, struct_out=None,
- overwrite=overwrite,
- )
- return full_traj, first_condition_fullfilled
-
- async def _condition_vals_for_traj(self, traj):
+ tra_out=tra_out,
+ **concatenate_kwargs
+ )
+ return full_traj, first_condition_fulfilled
+
+ async def _condition_vals_for_traj(self, traj: Trajectory
+ ) -> list[np.ndarray]:
  # return a list of condition_func results,
  # one for each condition func in conditions
- if all(self._condition_func_is_coroutine):
- # easy, all coroutines
- return await asyncio.gather(*(c(traj) for c in self.conditions))
- elif not any(self._condition_func_is_coroutine):
- # also easy (but blocking), none is coroutine
- return [c(traj) for c in self.conditions]
- else:
- # need to piece it together
- # first the coroutines concurrently
- coros = [c(traj) for c, c_is_coro
- in zip(self.conditions, self._condition_func_is_coroutine)
- if c_is_coro
- ]
- coro_res = await asyncio.gather(*coros)
- # now either take the result from coro execution or calculate it
- all_results = []
+ if self.conditions.changed:
+ # first check if the conditions (single entries) have been modified
+ # if yes just reassign to the property so we recheck which of them
+ # are coroutines
+ self.conditions = self.conditions
+ # we wrap the non-coroutines into tasks to schedule all together
+ all_conditions_as_coro = [
+ c(traj) if c_is_coro else asyncio.to_thread(c, traj)
  for c, c_is_coro in zip(self.conditions,
- self._condition_func_is_coroutine):
- if c_is_coro:
- all_results.append(coro_res.pop(0))
- else:
- all_results.append(c(traj))
- return all_results
- # NOTE: this would be elegant, but to_thread() is py v>=3.9
- # we wrap the non-coroutines into tasks to schedule all together
- #all_conditions_as_coro = [
- # c(traj) if c_is_cor else asyncio.to_thread(c, traj)
- # for c, c_is_cor in zip(self.conditions, self._condition_func_is_coroutine)
- # ]
- #return await asyncio.gather(*all_conditions_as_coro)
+ self._condition_func_is_coroutine)
+ ]
+ results = await asyncio.gather(*all_conditions_as_coro)
+ cond_eq_traj_len = [len(traj) == len(r) for r in results]
+ if not all(cond_eq_traj_len):
+ bad_condition_idx_str = ", ".join([f"{idx}" for idx, good
+ in enumerate(cond_eq_traj_len)
+ if not good])
+ raise ValueError("At least one of the conditions does not return "
+ "an array of shape (len(traj), ) when applied to "
+ "the trajectory traj. The conditions in question "
+ "have indexes " + bad_condition_idx_str + " .")
+ return results


  # alias for people coming from the path sampling community :)
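
The rewritten ``_condition_vals_for_traj`` above schedules coroutine conditions and plain callables together: non-coroutines are pushed to a worker thread via asyncio.to_thread and everything is awaited in one asyncio.gather, with a shape check afterwards. A runnable toy reduction of that scheduling idea (the condition functions and the fake trajectory are invented for illustration; this is not asyncmd code):

    import asyncio
    import numpy as np

    async def async_condition(traj):
        return np.array([False, False, True])

    def blocking_condition(traj):
        return np.array([False, True, False])

    async def condition_vals(traj, conditions, is_coro):
        # coroutine conditions are called directly, blocking ones go to a thread
        coros = [c(traj) if coro else asyncio.to_thread(c, traj)
                 for c, coro in zip(conditions, is_coro)]
        return await asyncio.gather(*coros)

    vals = asyncio.run(condition_vals("fake_traj",
                                      [async_condition, blocking_condition],
                                      [True, False]))
    print(np.where(np.asarray(vals)))  # (array([0, 1]), array([2, 1]))
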