asyncmd 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,25 +12,34 @@
12
12
  #
13
13
  # You should have received a copy of the GNU General Public License
14
14
  # along with asyncmd. If not, see <https://www.gnu.org/licenses/>.
15
- import os
15
+ """
16
+ This module contains classes for propagation of MD in segments and/or until a condition is reached.
17
+
18
+ Most notable are the InPartsTrajectoryPropagator and the ConditionalTrajectoryPropagator.
19
+ Also of interest especially in the context of path sampling could be the function
20
+ construct_tp_from_plus_and_minus_traj_segments, which can be used directly on the
21
+ output of a ConditionalTrajectoryPropagator to generate trajectories connecting two
22
+ fulfilled conditions.
23
+ """
16
24
  import asyncio
17
- import aiofiles
18
- import aiofiles.os
25
+ import collections.abc
26
+ import copy
19
27
  import inspect
20
28
  import logging
29
+ import os
21
30
  import typing
31
+
32
+ import aiofiles
33
+ import aiofiles.os
22
34
  import numpy as np
23
35
 
24
- from .trajectory import Trajectory
25
- from .functionwrapper import TrajectoryFunctionWrapper
36
+ from ..mdengine import MDEngine
37
+ from ..tools import FlagChangeList, remove_file_if_exist_async
38
+ from ..utils import (get_all_file_parts, get_all_traj_parts,
39
+ nstout_from_mdconfig)
26
40
  from .convert import TrajectoryConcatenator
27
- from ..utils import (get_all_traj_parts,
28
- get_all_file_parts,
29
- nstout_from_mdconfig,
30
- )
31
- from ..tools import (remove_file_if_exist,
32
- remove_file_if_exist_async,
33
- )
41
+ from .functionwrapper import TrajectoryFunctionWrapper
42
+ from .trajectory import Trajectory
34
43
 
35
44
 
36
45
  logger = logging.getLogger(__name__)
@@ -41,22 +50,20 @@ class MaxStepsReachedError(Exception):
41
50
  Error raised when the simulation terminated because the (user-defined)
42
51
  maximum number of integration steps/trajectory frames has been reached.
43
52
  """
44
- pass
45
53
 
46
54
 
47
- # TODO: move to trajectory.convert?
48
- async def construct_TP_from_plus_and_minus_traj_segments(
55
+ async def construct_tp_from_plus_and_minus_traj_segments(
56
+ *,
49
57
  minus_trajs: "list[Trajectory]",
50
58
  minus_state: int,
51
59
  plus_trajs: "list[Trajectory]",
52
60
  plus_state: int,
53
61
  state_funcs: "list[TrajectoryFunctionWrapper]",
54
62
  tra_out: str,
55
- struct_out: typing.Optional[str] = None,
56
- overwrite: bool = False
57
- ) -> Trajectory:
63
+ **concatenate_kwargs,
64
+ ) -> Trajectory:
58
65
  """
59
- Construct a continous TP from plus and minus segments until states.
66
+ Construct a continuous TP from plus and minus segments until states.
60
67
 
61
68
  This is used e.g. in TwoWay TPS or if you try to get TPs out of a committor
62
69
  simulation. Note, that this inverts all velocities on the minus segments.
@@ -76,11 +83,15 @@ async def construct_TP_from_plus_and_minus_traj_segments(
76
83
  the minus and plus state indices!
77
84
  tra_out : str
78
85
  Absolute or relative path to the output trajectory file.
79
- struct_out : str, optional
80
- Absolute or relative path to the output structure file, if None we will
81
- use the structure file of the first minus_traj, by default None.
82
- overwrite : bool, optional
83
- Whether we should overwrite tra_out if it exists, by default False.
86
+ concatenate_kwargs : dict
87
+ All (other) keyword arguments will be passed as is to the
88
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
89
+
90
+ - struct_out : str, optional
91
+ Absolute or relative path to the output structure file, if None we will
92
+ use the structure file of the first minus_traj, by default None.
93
+ - overwrite : bool, optional
94
+ Whether we should overwrite tra_out if it exists, by default False.
84
95
 
85
96
  Returns
86
97
  -------
@@ -168,42 +179,65 @@ async def construct_TP_from_plus_and_minus_traj_segments(
168
179
  trajs=trajs,
169
180
  slices=slices,
170
181
  tra_out=tra_out,
171
- struct_out=struct_out,
172
- overwrite=overwrite,
182
+ **concatenate_kwargs
173
183
  )
174
184
  return path_traj
175
185
 
176
186
 
187
+ # pylint: disable-next=too-few-public-methods
177
188
  class _TrajectoryPropagator:
178
189
  # (private) superclass for InPartsTrajectoryPropagator and
179
190
  # ConditionalTrajectoryPropagator,
180
- # here we keep the common functions shared between them
181
- async def remove_parts(self, workdir: str, deffnm: str,
182
- file_endings_to_remove: list[str] = ["trajectories",
183
- "log"],
184
- remove_mda_offset_and_lock_files: bool = True,
185
- remove_asyncmd_npz_caches: bool = True,
191
+ # here we keep the common functions shared between them, currently
192
+ # this is only the file removal method and a bit of the init logic.
193
+ def __init__(self, *,
194
+ engine_cls: type[MDEngine], engine_kwargs: dict[str, typing.Any],
195
+ walltime_per_part: float,
196
+ ) -> None:
197
+ self.engine_cls = engine_cls
198
+ self.engine_kwargs = engine_kwargs
199
+ self.walltime_per_part = walltime_per_part
200
+
201
+ async def remove_parts(self, workdir: str, deffnm: str, *,
202
+ file_endings_to_remove: list[str] | None = None,
203
+ remove_mda_offset_and_lock_files: bool = True,
204
+ remove_asyncmd_npz_caches: bool = True,
186
205
  ) -> None:
187
206
  """
188
- Remove all `$deffnm.part$num.$file_ending` files for given file_endings
207
+ Remove all ``$deffnm.part$num.$file_ending`` files for file_endings.
189
208
 
190
- Can be useful to clean the `workdir` from temporary files if e.g. only
191
- the concatenate trajectory is of interesst (like in TPS).
209
+ Can be useful to clean the ``workdir`` from temporary files if e.g. only
210
+ the concatenate trajectory is of interest (like in TPS).
192
211
 
193
212
  Parameters
194
213
  ----------
195
214
  workdir : str
196
215
  The directory to clean.
197
216
  deffnm : str
198
- The `deffnm` that the files we clean must have.
199
- file_endings_to_remove : list[str], optional
200
- The strings in the list `file_endings_to_remove` indicate which
217
+ The ``deffnm`` that the files we clean must have.
218
+ file_endings_to_remove : list[str] | None, optional
219
+ The strings in the list ``file_endings_to_remove`` indicate which
201
220
  file endings to remove.
202
- E.g. `file_endings_to_remove=["trajectories", "log"]` will result
203
- in removal of trajectory parts and the log files. If you add "edr"
204
- to the list we would also remove the edr files,
205
- by default ["trajectories", "log"]
221
+ The 'special' string "trajectories" will be translated to the file
222
+ ending of the trajectories the engine produces, i.e. ``engine.output_traj_type``.
223
+ E.g. passing ``file_endings_to_remove=["trajectories", "log"]`` will
224
+ result in removal of trajectory parts and the log files. If you add
225
+ "edr" to the list we would also remove the edr files.
226
+ By default, i.e., if None the list will be ["trajectories", "log"].
227
+ remove_mda_offset_and_lock_files : bool, optional
228
+ Whether to remove any (hidden) offset and lock files generated by
229
+ MDAnalysis associated with the removed trajectory files (if they exist).
230
+ By default True.
231
+ remove_asyncmd_npz_caches : bool, optional
232
+ Whether to remove any (hidden) npz cache files generated by asyncmd
233
+ associated with the removed trajectory files (if they exist).
234
+ By default True.
206
235
  """
236
+ if file_endings_to_remove is None:
237
+ file_endings_to_remove = ["trajectories", "log"]
238
+ else:
239
+ # copy the list so we dont mutate what we got passed
240
+ file_endings_to_remove = copy.copy(file_endings_to_remove)
207
241
  if "trajectories" in file_endings_to_remove:
208
242
  # replace "trajectories" with the actual output traj type
209
243
  try:
@@ -222,7 +256,7 @@ class _TrajectoryPropagator:
222
256
  )
223
257
  # make sure we dont miss anything because we have different
224
258
  # capitalization
225
- if len(parts_to_remove) == 0:
259
+ if not parts_to_remove:
226
260
  parts_to_remove = await get_all_file_parts(
227
261
  folder=workdir,
228
262
  deffnm=deffnm,
@@ -232,7 +266,7 @@ class _TrajectoryPropagator:
232
266
  for f in parts_to_remove
233
267
  )
234
268
  )
235
- # TODO: the note below?
269
+ # TODO: address the note below?
236
270
  # NOTE: this is a bit hacky: we just try to remove the offset and
237
271
  # lock files for every file we remove (since we do not know
238
272
  # if the file we remove is a trajectory [and therefore
@@ -286,8 +320,9 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
286
320
  Useful to make full use of backfilling with short(ish) simulation jobs and
287
321
  also to run simulations that are longer than the timelimit.
288
322
  """
289
- def __init__(self, n_steps: int, engine_cls,
290
- engine_kwargs: dict, walltime_per_part: float,
323
+ def __init__(self, n_steps: int, *,
324
+ engine_cls: type[MDEngine], engine_kwargs: dict[str, typing.Any],
325
+ walltime_per_part: float,
291
326
  ) -> None:
292
327
  """
293
328
  Initialize an `InPartTrajectoryPropagator`.
@@ -303,19 +338,18 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
303
338
  walltime_per_part : float
304
339
  Walltime per trajectory segment, in hours.
305
340
  """
341
+ super().__init__(engine_cls=engine_cls, engine_kwargs=engine_kwargs,
342
+ walltime_per_part=walltime_per_part)
306
343
  self.n_steps = n_steps
307
- self.engine_cls = engine_cls
308
- self.engine_kwargs = engine_kwargs
309
- self.walltime_per_part = walltime_per_part
310
344
 
311
345
  async def propagate_and_concatenate(self,
312
346
  starting_configuration: Trajectory,
313
347
  workdir: str,
314
- deffnm: str,
348
+ deffnm: str, *,
315
349
  tra_out: str,
316
- overwrite: bool = False,
317
- continuation: bool = False
318
- ) -> tuple[Trajectory, int]:
350
+ continuation: bool = False,
351
+ **concatenate_kwargs,
352
+ ) -> Trajectory | None:
319
353
  """
320
354
  Chain :meth:`propagate` and :meth:`cut_and_concatenate` methods.
321
355
 
@@ -329,12 +363,18 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
329
363
  MD engine deffnm for trajectory parts and other files.
330
364
  tra_out : str
331
365
  Absolute or relative path for the concatenated output trajectory.
332
- overwrite : bool, optional
333
- Whether the output trajectory should be overwritten if it exists,
334
- by default False.
335
366
  continuation : bool, optional
336
367
  Whether we are continuing a previous MD run (with the same deffnm
337
368
  and working directory), by default False.
369
+ concatenate_kwargs : dict
370
+ All (other) keyword arguments will be passed as is to the
371
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
372
+
373
+ - struct_out : str, optional
374
+ Absolute or relative path to the output structure file, if None
375
+ we will use the structure file of the first traj, by default None.
376
+ - overwrite : bool, optional
377
+ Whether we should overwrite tra_out if it exists, by default False.
338
378
 
339
379
  Returns
340
380
  -------
@@ -348,23 +388,21 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
348
388
  deffnm=deffnm,
349
389
  continuation=continuation
350
390
  )
351
- full_traj = await self.cut_and_concatenate(
352
- trajs=trajs,
353
- tra_out=tra_out,
354
- overwrite=overwrite,
391
+ full_traj = await self.cut_and_concatenate(trajs=trajs, tra_out=tra_out,
392
+ **concatenate_kwargs,
355
393
  )
356
394
  return full_traj
357
395
 
358
396
  async def propagate(self,
359
397
  starting_configuration: Trajectory,
360
398
  workdir: str,
361
- deffnm: str,
399
+ deffnm: str, *,
362
400
  continuation: bool = False,
363
401
  ) -> list[Trajectory]:
364
402
  """
365
403
  Propagate the trajectory until self.n_steps integration are done.
366
404
 
367
- Return a list of trajecory segments and the first condition fullfilled.
405
+ Return a list of trajectory segments and the first condition fulfilled.
368
406
 
369
407
  Parameters
370
408
  ----------
@@ -385,7 +423,7 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
385
423
  Returns
386
424
  -------
387
425
  traj_segments : list[Trajectory]
388
- List of trajectory (segements), ordered in time.
426
+ List of trajectory (segments), ordered in time.
389
427
  """
390
428
  engine = self.engine_cls(**self.engine_kwargs)
391
429
  if continuation:
@@ -396,11 +434,10 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
396
434
  )
397
435
  if len(trajs) > 0:
398
436
  # can only continue if we find the previous trajs
399
- step_counter = engine.steps_done
400
- if step_counter >= self.n_steps:
437
+ await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
438
+ if (step_counter := engine.steps_done) >= self.n_steps:
401
439
  # already longer than what we want to do, bail out
402
440
  return trajs
403
- await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
404
441
  else:
405
442
  # no previous trajs, prepare engine from scratch
406
443
  continuation = False
@@ -417,11 +454,10 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
417
454
  trajs = []
418
455
  step_counter = 0
419
456
 
420
- while (step_counter < self.n_steps):
421
- traj = await engine.run(nsteps=self.n_steps,
422
- walltime=self.walltime_per_part,
423
- steps_per_part=False,
424
- )
457
+ while step_counter < self.n_steps:
458
+ traj = await engine.run_walltime(walltime=self.walltime_per_part,
459
+ max_steps=self.n_steps,
460
+ )
425
461
  step_counter = engine.steps_done
426
462
  trajs.append(traj)
427
463
  return trajs
@@ -429,15 +465,16 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
429
465
  async def cut_and_concatenate(self,
430
466
  trajs: list[Trajectory],
431
467
  tra_out: str,
432
- overwrite: bool = False,
433
- ) -> Trajectory:
468
+ **concatenate_kwargs,
469
+ ) -> Trajectory | None:
434
470
  """
435
471
  Cut and concatenate the trajectory until it has length n_steps.
436
472
 
437
- Take a list of trajectory segments and form one continous trajectory
473
+ Take a list of trajectory segments and form one continuous trajectory
438
474
  containing n_steps integration steps. The expected input
439
475
  is a list of trajectories, e.g. the output of the :meth:`propagate`
440
476
  method.
477
+ Returns None if ``self.n_steps`` is zero.
441
478
 
442
479
  Parameters
443
480
  ----------
@@ -445,9 +482,15 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
445
482
  Trajectory segments to cut and concatenate.
446
483
  tra_out : str
447
484
  Absolute or relative path for the concatenated output trajectory.
448
- overwrite : bool, optional
449
- Whether the output trajectory should be overwritten if it exists,
450
- by default False.
485
+ concatenate_kwargs : dict
486
+ All (other) keyword arguments will be passed as is to the
487
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
488
+
489
+ - struct_out : str, optional
490
+ Absolute or relative path to the output structure file, if None
491
+ we will use the structure file of the first traj, by default None.
492
+ - overwrite : bool, optional
493
+ Whether we should overwrite tra_out if it exists, by default False.
451
494
 
452
495
  Returns
453
496
  -------
@@ -462,23 +505,18 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
462
505
  """
463
506
  # trajs is a list of trajectories, e.g. the return of propagate
464
507
  # tra_out and overwrite are passed directly to the Concatenator
465
- if len(trajs) == 0:
466
- # no trajectories to concatenate, happens e.g. if self.n_steps=0
467
- # we return None (TODO: is this what we want?)
508
+ if not self.n_steps:
509
+ # no trajectories to concatenate, since self.n_steps=0
510
+ # we return None
468
511
  return None
469
512
  if self.n_steps > trajs[-1].last_step:
470
513
  # not enough steps in trajectories
471
- raise ValueError("The given trajectories are to short (< self.n_steps).")
472
- elif self.n_steps == trajs[-1].last_step:
514
+ raise ValueError("The given trajectories are too short (< self.n_steps).")
515
+ if self.n_steps == trajs[-1].last_step:
473
516
  # all good, we just take all trajectory parts fully
474
517
  slices = [(0, None, 1) for _ in range(len(trajs))]
475
518
  last_part_idx = len(trajs) - 1
476
519
  else:
477
- logger.warning("Trajectories do not exactly contain n_steps "
478
- "integration steps. Using a heuristic to find the "
479
- "correct last frame to include, note that this "
480
- "heuristic might fail if n_steps is not a multiple "
481
- "of the trajectory output frequency.")
482
520
  # need to find the subtrajectory that contains the correct number
483
521
  # of integration steps
484
522
  # first find the part in which we go over n_steps
@@ -490,76 +528,80 @@ class InPartsTrajectoryPropagator(_TrajectoryPropagator):
490
528
  last_part_len_steps = (trajs[last_part_idx].last_step
491
529
  - trajs[last_part_idx].first_step)
492
530
  steps_per_frame = last_part_len_steps / last_part_len_frames
493
- frames_in_last_part = 0
494
- while ((trajs[last_part_idx].first_step
495
- + frames_in_last_part * steps_per_frame) < self.n_steps):
496
- # I guess we stay with the < (instead of <=) and rather have
497
- # one frame too much?
498
- frames_in_last_part += 1
531
+ frames_in_last_part = ((self.n_steps
532
+ - trajs[last_part_idx].first_step
533
+ )
534
+ / steps_per_frame)
535
+ log_str = ("Trajectories do not exactly contain n_steps integration steps. "
536
+ "Using a heuristic to find the correct last frame to include."
537
+ )
538
+ if frames_in_last_part != (frames_in_last_part_int := round(frames_in_last_part)):
539
+ log_str += (" Note that this heuristic might fail because n_steps"
540
+ " is not a multiple of the trajectory output frequency."
541
+ )
542
+ logger.warning(log_str)
543
+ else:
544
+ logger.info(log_str)
499
545
  # build slices
500
546
  slices = [(0, None, 1) for _ in range(last_part_idx)]
501
- slices += [(0, frames_in_last_part + 1, 1)]
547
+ slices += [(0, frames_in_last_part_int + 1, 1)]
502
548
 
503
549
  # and concatenate
504
550
  full_traj = await TrajectoryConcatenator().concatenate_async(
505
551
  trajs=trajs[:last_part_idx + 1],
506
552
  slices=slices,
507
- # take the structure file of the traj, as it
508
- # comes from the engine directly
509
- tra_out=tra_out, struct_out=None,
510
- overwrite=overwrite,
511
- )
553
+ tra_out=tra_out,
554
+ **concatenate_kwargs
555
+ )
512
556
  return full_traj
513
557
 
514
558
 
515
559
  class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
516
560
  """
517
- Propagate a trajectory until any of the given conditions is fullfilled.
561
+ Propagate a trajectory until any of the given conditions is fulfilled.
518
562
 
519
563
  This class propagates the trajectory using a given MD engine (class) in
520
564
  small chunks (chunksize is determined by walltime_per_part) and checks
521
- after every chunk is done if any condition has been fullfilled.
565
+ after every chunk is done if any condition has been fulfilled.
522
566
  It then returns a list of trajectory parts and the index of the condition
523
- first fullfilled. It can also concatenate the parts into one trajectory,
567
+ first fulfilled. It can also concatenate the parts into one trajectory,
524
568
  which then starts with the starting configuration and ends with the frame
525
- fullfilling the condition.
526
-
527
- Attributes
528
- ----------
529
- conditions : list[callable]
530
- List of (wrapped) condition functions.
569
+ fulfilling the condition.
531
570
 
532
571
  Notes
533
572
  -----
534
573
  We assume that every condition function returns a list/ a 1d array with
535
- True or False for each frame, i.e. if we fullfill condition at any given
574
+ True or False for each frame, i.e. if we fulfill condition at any given
536
575
  frame.
537
- We assume non-overlapping conditions, i.e. a configuration can not fullfill
576
+ We assume non-overlapping conditions, i.e. a configuration can not fulfill
538
577
  two conditions at the same time, **it is the users responsibility to ensure
539
578
  that their conditions are sane**.
540
579
  """
541
580
 
542
581
  # NOTE: we assume that every condition function returns a list/ a 1d array
543
- # with True/False for each frame, i.e. if we fullfill condition at
582
+ # with True/False for each frame, i.e. if we fulfill condition at
544
583
  # any given frame
545
584
  # NOTE: we assume non-overlapping conditions, i.e. a configuration can not
546
- # fullfill two conditions at the same time, it is the users
585
+ # fulfill two conditions at the same time, it is the users
547
586
  # responsibility to ensure that their conditions are sane
548
587
 
549
- def __init__(self, conditions, engine_cls,
550
- engine_kwargs: dict,
588
+ # Note: max_steps and max_frames are mutually exclusive and this is enforced,
589
+ # but pylint does not know that, so we tell it to not be mad for one arg more
590
+ # pylint: disable-next=too-many-arguments
591
+ def __init__(self, conditions, *,
592
+ engine_cls: type[MDEngine], engine_kwargs: dict[str, typing.Any],
551
593
  walltime_per_part: float,
552
- max_steps: typing.Optional[int] = None,
553
- max_frames: typing.Optional[int] = None,
594
+ max_steps: int | None = None,
595
+ max_frames: int | None = None,
554
596
  ):
555
597
  """
556
- Initialize a `ConditionalTrajectoryPropagator`.
598
+ Initialize a :class:`ConditionalTrajectoryPropagator`.
557
599
 
558
600
  Parameters
559
601
  ----------
560
602
  conditions : list[callable], usually list[TrajectoryFunctionWrapper]
561
603
  List of condition functions, usually wrapped function for
562
- asyncronous application, but can be any callable that takes a
604
+ asynchronous application, but can be any callable that takes a
563
605
  :class:`asyncmd.Trajectory` and returns an array of True and False
564
606
  values (one value per frame).
565
607
  engine_cls : :class:`asyncmd.mdengine.MDEngine`
@@ -571,7 +613,7 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
571
613
  max_steps : int, optional
572
614
  Maximum number of integration steps to do before stopping the
573
615
  simulation because it did not commit to any condition,
574
- by default None. Takes precendence over max_frames if both given.
616
+ by default None. Takes precedence over max_frames if both given.
575
617
  max_frames : int, optional
576
618
  Maximum number of frames to produce before stopping the simulation
577
619
  because it did not commit to any condition, by default None.
@@ -582,12 +624,11 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
582
624
  ``max_steps = max_frames * output_frequency``, if both are given
583
625
  max_steps takes precedence.
584
626
  """
585
- self._conditions = None
586
- self._condition_func_is_coroutine = None
627
+ super().__init__(engine_cls=engine_cls, engine_kwargs=engine_kwargs,
628
+ walltime_per_part=walltime_per_part)
629
+ self._conditions = FlagChangeList([])
630
+ self._condition_func_is_coroutine: list[bool] = []
587
631
  self.conditions = conditions
588
- self.engine_cls = engine_cls
589
- self.engine_kwargs = engine_kwargs
590
- self.walltime_per_part = walltime_per_part
591
632
  # find nstout
592
633
  try:
593
634
  traj_type = engine_kwargs["output_traj_type"]
@@ -610,40 +651,47 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
610
651
  # this is a float but can be compared to ints
611
652
  self.max_steps = np.inf
612
653
 
613
- # TODO/FIXME: self._conditions is a list...that means users can change
614
- # single elements without using the setter!
615
- # we could use a list subclass as for the MDconfig?!
616
654
  @property
617
- def conditions(self):
655
+ def conditions(self) -> FlagChangeList:
656
+ """List of (wrapped) condition functions."""
618
657
  return self._conditions
619
658
 
620
659
  @conditions.setter
621
- def conditions(self, conditions):
622
- # use asyncio.iscorotinefunction to check the conditions
623
- self._condition_func_is_coroutine = [
660
+ def conditions(self, conditions: collections.abc.Sequence):
661
+ if len(conditions) < 1:
662
+ raise ValueError("Must supply at least one termination condition.")
663
+ self._condition_func_is_coroutine = self._check_condition_funcs(
664
+ conditions=conditions
665
+ )
666
+ self._conditions = FlagChangeList(conditions)
667
+
668
+ def _check_condition_funcs(self, conditions: collections.abc.Sequence,
669
+ ) -> list[bool]:
670
+ # use asyncio.iscoroutinefunction to check the conditions
671
+ condition_func_is_coroutine = [
624
672
  (inspect.iscoroutinefunction(c)
625
673
  or inspect.iscoroutinefunction(c.__call__))
626
674
  for c in conditions
627
- ]
628
- if not all(self._condition_func_is_coroutine):
629
- # and warn if it is not a corotinefunction
675
+ ]
676
+ if not all(condition_func_is_coroutine):
677
+ # and warn if it is not a coroutinefunction
630
678
  logger.warning(
631
679
  "It is recommended to use coroutinefunctions for all "
632
680
  "conditions. This can easily be achieved by wrapping any "
633
681
  "function in a TrajectoryFunctionWrapper. All "
634
682
  "non-coroutine condition functions will be blocking when "
635
683
  "applied! ([c is coroutine for c in conditions] = %s)",
636
- self._condition_func_is_coroutine
684
+ condition_func_is_coroutine
637
685
  )
638
- self._conditions = conditions
686
+ return condition_func_is_coroutine
639
687
 
640
688
  async def propagate_and_concatenate(self,
641
689
  starting_configuration: Trajectory,
642
690
  workdir: str,
643
- deffnm: str,
691
+ deffnm: str, *,
644
692
  tra_out: str,
645
- overwrite: bool = False,
646
- continuation: bool = False
693
+ continuation: bool = False,
694
+ **concatenate_kwargs
647
695
  ) -> tuple[Trajectory, int]:
648
696
  """
649
697
  Chain :meth:`propagate` and :meth:`cut_and_concatenate` methods.
@@ -658,16 +706,22 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
658
706
  MD engine deffnm for trajectory parts and other files.
659
707
  tra_out : str
660
708
  Absolute or relative path for the concatenated output trajectory.
661
- overwrite : bool, optional
662
- Whether the output trajectory should be overwritten if it exists,
663
- by default False.
664
709
  continuation : bool, optional
665
710
  Whether we are continuing a previous MD run (with the same deffnm
666
711
  and working directory), by default False.
712
+ concatenate_kwargs : dict
713
+ All (other) keyword arguments will be passed as is to the
714
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
715
+
716
+ - struct_out : str, optional
717
+ Absolute or relative path to the output structure file, if None
718
+ we will use the structure file of the first traj, by default None.
719
+ - overwrite : bool, optional
720
+ Whether we should overwrite tra_out if it exists, by default False.
667
721
 
668
722
  Returns
669
723
  -------
670
- (traj_out, idx_of_condition_fullfilled) : (Trajectory, int)
724
+ (traj_out, idx_of_condition_fulfilled) : (Trajectory, int)
671
725
  The concatenated output trajectory from starting configuration
672
726
  until the first condition is True and the index to the condition
673
727
  function in `conditions`.
@@ -679,34 +733,34 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
679
733
  frames has been reached in :meth:`propagate`.
680
734
  """
681
735
  # this just chains propagate and cut_and_concatenate
682
- # usefull for committor simulations, for e.g. TPS one should try to
736
+ # useful for committor simulations, for e.g. TPS one should try to
683
737
  # directly concatenate both directions to a full TP if possible
684
- trajs, first_condition_fullfilled = await self.propagate(
738
+ trajs, first_condition_fulfilled = await self.propagate(
685
739
  starting_configuration=starting_configuration,
686
740
  workdir=workdir,
687
741
  deffnm=deffnm,
688
742
  continuation=continuation
689
- )
743
+ )
690
744
  # NOTE: it should not matter too much speedwise that we recalculate
691
745
  # the condition functions, they are expected to be wrapped funcs
692
746
  # i.e. the second time we should just get the values from cache
693
- full_traj, first_condition_fullfilled = await self.cut_and_concatenate(
747
+ full_traj, first_condition_fulfilled = await self.cut_and_concatenate(
694
748
  trajs=trajs,
695
749
  tra_out=tra_out,
696
- overwrite=overwrite,
697
- )
698
- return full_traj, first_condition_fullfilled
750
+ **concatenate_kwargs
751
+ )
752
+ return full_traj, first_condition_fulfilled
699
753
 
700
754
  async def propagate(self,
701
755
  starting_configuration: Trajectory,
702
756
  workdir: str,
703
- deffnm: str,
757
+ deffnm: str, *,
704
758
  continuation: bool = False,
705
759
  ) -> tuple[list[Trajectory], int]:
706
760
  """
707
- Propagate the trajectory until any condition is fullfilled.
761
+ Propagate the trajectory until any condition is fulfilled.
708
762
 
709
- Return a list of trajecory segments and the first condition fullfilled.
763
+ Return a list of trajectory segments and the first condition fulfilled.
710
764
 
711
765
  Parameters
712
766
  ----------
@@ -722,10 +776,10 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
722
776
 
723
777
  Returns
724
778
  -------
725
- (traj_segments, idx_of_condition_fullfilled) : (list[Trajectory], int)
726
- List of trajectory (segements), the last entry is the one on which
727
- the first condition is fullfilled at some frame, the ineger is the
728
- index to the condition function in `conditions`.
779
+ (traj_segments, idx_of_condition_fulfilled) : (list[Trajectory], int)
780
+ List of trajectory (segments), the last entry is the one on which
781
+ the first condition is fulfilled at some frame, the integer is the
782
+ index to the condition function in ``conditions``.
729
783
 
730
784
  Raises
731
785
  ------
@@ -733,27 +787,29 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
733
787
  When the defined maximum number of integration steps/trajectory
734
788
  frames has been reached.
735
789
  """
736
- # NOTE: curently this just returns a list of trajs + the condition
737
- # fullfilled
790
+ # NOTE: currently this just returns a list of trajs + the condition
791
+ # fulfilled
738
792
  # this feels a bit uncomfortable but avoids that we concatenate
739
793
  # everything a quadrillion times when we use the results
740
- # check first if the start configuration is fullfilling any condition
794
+ # check first if the start configuration is fulfilling any condition
741
795
  cond_vals = await self._condition_vals_for_traj(starting_configuration)
742
796
  if np.any(cond_vals):
743
- conds_fullfilled, frame_nums = np.where(cond_vals)
797
+ conds_fulfilled, frame_nums = np.where(cond_vals)
744
798
  # gets the frame with the lowest idx where any condition is True
745
799
  min_idx = np.argmin(frame_nums)
746
- first_condition_fullfilled = conds_fullfilled[min_idx]
747
- logger.error("Starting configuration (%s) is already fullfilling "
800
+ first_condition_fulfilled = conds_fulfilled[min_idx]
801
+ logger.error("Starting configuration (%s) is already fulfilling "
748
802
  "the condition with idx %s.",
749
- starting_configuration, first_condition_fullfilled,
750
- )
803
+ starting_configuration, first_condition_fulfilled,
804
+ )
751
805
  # we just return the starting configuration/trajectory
752
806
  trajs = [starting_configuration]
753
- return trajs, first_condition_fullfilled
807
+ return trajs, first_condition_fulfilled
754
808
 
755
- # starting configuration does not fullfill any condition, lets do MD
809
+ # starting configuration does not fulfill any condition, lets do MD
756
810
  engine = self.engine_cls(**self.engine_kwargs)
811
+ # Note: we first check for continuation because if we do not find a run
812
+ # to continue we fallback to no continuation
757
813
  if continuation:
758
814
  # continuation: get all traj parts already done and continue from
759
815
  # there, i.e. append to the last traj part found
@@ -762,32 +818,31 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
762
818
  trajs = await get_all_traj_parts(folder=workdir, deffnm=deffnm,
763
819
  engine=engine,
764
820
  )
765
- if len(trajs) > 0:
766
- # can only calc CV values if we have trajectories prouced
821
+ if not trajs:
822
+ # no trajectories, so we should prepare engine from scratch
823
+ continuation = False
824
+ logger.error("continuation=True, but we found no previous "
825
+ "trajectories. Setting continuation=False and "
826
+ "preparing the engine from scratch.")
827
+ else:
828
+ # (can only) calc CV values if we have found trajectories
767
829
  cond_vals = await asyncio.gather(
768
830
  *(self._condition_vals_for_traj(t) for t in trajs)
769
831
  )
770
832
  cond_vals = np.concatenate([np.asarray(s) for s in cond_vals],
771
833
  axis=1)
772
- # see if we already fullfill a condition on the existing traj parts
773
- any_cond_fullfilled = np.any(cond_vals)
774
- if any_cond_fullfilled:
775
- conds_fullfilled, frame_nums = np.where(cond_vals)
834
+ # see if we already fulfill a condition on the existing traj parts
835
+ if (any_cond_fulfilled := np.any(cond_vals)):
836
+ conds_fulfilled, frame_nums = np.where(cond_vals)
776
837
  # gets the frame with the lowest idx where any cond is True
777
838
  min_idx = np.argmin(frame_nums)
778
- first_condition_fullfilled = conds_fullfilled[min_idx]
779
- # already fullfill a condition, get out of here!
780
- return trajs, first_condition_fullfilled
781
- # Did not fullfill any condition yet, so prepare the engine to
839
+ first_condition_fulfilled = conds_fulfilled[min_idx]
840
+ # already fulfill a condition, get out of here!
841
+ return trajs, first_condition_fulfilled
842
+ # Did not fulfill any condition yet, so prepare the engine to
782
843
  # continue the simulation until we reach any of the (new) conds
783
844
  await engine.prepare_from_files(workdir=workdir, deffnm=deffnm)
784
- step_counter = engine.steps_done
785
- else:
786
- # no trajectories, so we should prepare engine from scratch
787
- continuation = False
788
- logger.error("continuation=True, but we found no previous "
789
- "trajectories. Setting continuation=False and "
790
- "preparing the engine from scratch.")
845
+
791
846
  if not continuation:
792
847
  # no continuation, just prepare the engine from scratch
793
848
  await engine.prepare(
@@ -795,45 +850,46 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
795
850
  workdir=workdir,
796
851
  deffnm=deffnm,
797
852
  )
798
- any_cond_fullfilled = False
799
853
  trajs = []
800
- step_counter = 0
801
854
 
802
- while ((not any_cond_fullfilled)
855
+ step_counter = engine.steps_done
856
+ any_cond_fulfilled = False
857
+ while ((not any_cond_fulfilled)
803
858
  and (step_counter <= self.max_steps)):
804
859
  traj = await engine.run_walltime(self.walltime_per_part)
805
860
  cond_vals = await self._condition_vals_for_traj(traj)
806
- any_cond_fullfilled = np.any(cond_vals)
861
+ any_cond_fulfilled = np.any(cond_vals)
807
862
  step_counter = engine.steps_done
808
863
  trajs.append(traj)
809
- if not any_cond_fullfilled:
864
+ if not any_cond_fulfilled:
810
865
  # left while loop because of max_steps reached
811
866
  raise MaxStepsReachedError(
812
867
  f"Engine produced {step_counter} steps (>= {self.max_steps})."
813
868
  )
814
869
  # cond_vals are the ones for the last traj
815
870
  # here we get which conditions are True and at which frame
816
- conds_fullfilled, frame_nums = np.where(cond_vals)
871
+ conds_fulfilled, frame_nums = np.where(cond_vals)
817
872
  # gets the frame with the lowest idx where any condition is True
818
873
  min_idx = np.argmin(frame_nums)
819
874
  # and now the idx to self.conditions for cond that was first fulfilled
820
875
  # NOTE/FIXME: if two conditions are reached simultaneously at min_idx,
821
876
  # this will find the condition with the lower idx only
822
- first_condition_fullfilled = conds_fullfilled[min_idx]
823
- return trajs, first_condition_fullfilled
877
+ first_condition_fulfilled = conds_fulfilled[min_idx]
878
+ return trajs, first_condition_fulfilled
824
879
 
825
880
  async def cut_and_concatenate(self,
826
881
  trajs: list[Trajectory],
827
882
  tra_out: str,
828
- overwrite: bool = False,
883
+ **concatenate_kwargs,
829
884
  ) -> tuple[Trajectory, int]:
830
885
  """
831
886
  Cut and concatenate the trajectory until the first condition is True.
832
887
 
833
- Take a list of trajectory segments and form one continous trajectory
834
- until the first frame that fullfills any condition. The expected input
835
- is a list of trajectories, e.g. the output of the :meth:`propagate`
836
- method.
888
+ Take a list of trajectory segments and form one continuous trajectory
889
+ until the first frame that fulfills any condition. The first frame
890
+ that fulfills any condition is included in the trajectory.
891
+ The expected input is a list of trajectories, e.g. the output of the
892
+ :meth:`propagate` method.
837
893
 
838
894
  Parameters
839
895
  ----------
@@ -841,42 +897,42 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
841
897
  Trajectory segments to cut and concatenate.
842
898
  tra_out : str
843
899
  Absolute or relative path for the concatenated output trajectory.
844
- overwrite : bool, optional
845
- Whether the output trajectory should be overwritten if it exists,
846
- by default False.
900
+ concatenate_kwargs : dict
901
+ All (other) keyword arguments will be passed as is to the
902
+ :meth:`TrajectoryConcatenator.concatenate` method. These include, e.g.,
903
+
904
+ - struct_out : str, optional
905
+ Absolute or relative path to the output structure file, if None
906
+ we will use the structure file of the first traj, by default None.
907
+ - overwrite : bool, optional
908
+ Whether we should overwrite tra_out if it exists, by default False.
847
909
 
848
910
  Returns
849
911
  -------
850
- (traj_out, idx_of_condition_fullfilled) : (Trajectory, int)
912
+ (traj_out, idx_of_condition_fulfilled) : (Trajectory, int)
851
913
  The concatenated output trajectory from starting configuration
852
914
  until the first condition is True and the index to the condition
853
915
  function in `conditions`.
854
916
  """
855
- # trajs is a list of trajectories, e.g. the return of propagate
856
- # tra_out and overwrite are passed directly to the Concatenator
857
- # NOTE: we assume that frame0 of traj0 does not fullfill any condition
858
- # and return only the subtrajectory from frame0 until any
859
- # condition is first True (the rest is ignored)
860
- # get all func values and put them into one big array
917
+ # get all func values and put them into one big list
861
918
  cond_vals = await asyncio.gather(
862
919
  *(self._condition_vals_for_traj(t) for t in trajs)
863
920
  )
864
921
  # cond_vals is a list (trajs) of lists (conditions)
865
922
  # take condition 0 (always present) to get the traj part lengths
866
923
  part_lens = [len(c[0]) for c in cond_vals] # c[0] is 1d (np)array
867
- cond_vals = np.concatenate([np.asarray(c) for c in cond_vals],
868
- axis=1)
869
- conds_fullfilled, frame_nums = np.where(cond_vals)
870
- # gets the frame with the lowest idx where any condition is True
924
+ # get all occurrences where any condition is True
925
+ conds_fulfilled, frame_nums = np.where(
926
+ np.concatenate([np.asarray(c) for c in cond_vals], axis=1)
927
+ )
928
+ # get the index of the frame with the lowest number where any condition is True
871
929
  min_idx = np.argmin(frame_nums)
872
- first_condition_fullfilled = conds_fullfilled[min_idx]
930
+ first_condition_fulfilled = conds_fulfilled[min_idx]
873
931
  first_frame_in_cond = frame_nums[min_idx]
874
932
  # find out in which part it is
875
- last_part_idx = 0
876
- frame_sum = part_lens[last_part_idx]
877
- while first_frame_in_cond >= frame_sum:
878
- last_part_idx += 1
879
- frame_sum += part_lens[last_part_idx]
933
+ # nonzero always returns a tuple (first zero index below)
934
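+ # and we only care for the first occurrence of True/nonzero (second zero index below); NOTE(review): the removed while loop is equivalent to `np.cumsum(part_lens) > first_frame_in_cond`, with `>=` a fulfilling frame that is the first frame of a part gets attributed to the previous part (and dropped from the output) - confirm and use `>`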
935
+ last_part_idx = (np.cumsum(part_lens) >= first_frame_in_cond).nonzero()[0][0]
880
936
  # find the first frame in cond (counting from start of last part)
881
937
  _first_frame_in_cond = (first_frame_in_cond
882
938
  - sum(part_lens[:last_part_idx])) # >= 0
@@ -892,46 +948,37 @@ class ConditionalTrajectoryPropagator(_TrajectoryPropagator):
892
948
  full_traj = await TrajectoryConcatenator().concatenate_async(
893
949
  trajs=trajs[:last_part_idx + 1],
894
950
  slices=slices,
895
- # take the structure file of the traj, as it
896
- # comes from the engine directly
897
- tra_out=tra_out, struct_out=None,
898
- overwrite=overwrite,
899
- )
900
- return full_traj, first_condition_fullfilled
901
-
902
- async def _condition_vals_for_traj(self, traj):
951
+ tra_out=tra_out,
952
+ **concatenate_kwargs
953
+ )
954
+ return full_traj, first_condition_fulfilled
955
+
956
+ async def _condition_vals_for_traj(self, traj: Trajectory
957
+ ) -> list[np.ndarray]:
903
958
  # return a list of condition_func results,
904
959
  # one for each condition func in conditions
905
- if all(self._condition_func_is_coroutine):
906
- # easy, all coroutines
907
- return await asyncio.gather(*(c(traj) for c in self.conditions))
908
- elif not any(self._condition_func_is_coroutine):
909
- # also easy (but blocking), none is coroutine
910
- return [c(traj) for c in self.conditions]
911
- else:
912
- # need to piece it together
913
- # first the coroutines concurrently
914
- coros = [c(traj) for c, c_is_coro
915
- in zip(self.conditions, self._condition_func_is_coroutine)
916
- if c_is_coro
917
- ]
918
- coro_res = await asyncio.gather(*coros)
919
- # now either take the result from coro execution or calculate it
920
- all_results = []
960
+ if self.conditions.changed:
961
+ # first check if the conditions (single entries) have been modified
962
+ # if yes just reassign to the property so we recheck which of them
963
+ # are coroutines
964
+ self.conditions = self.conditions
965
+ # we wrap the non-coroutines into tasks to schedule all together
966
+ all_conditions_as_coro = [
967
+ c(traj) if c_is_coro else asyncio.to_thread(c, traj)
921
968
  for c, c_is_coro in zip(self.conditions,
922
- self._condition_func_is_coroutine):
923
- if c_is_coro:
924
- all_results.append(coro_res.pop(0))
925
- else:
926
- all_results.append(c(traj))
927
- return all_results
928
- # NOTE: this would be elegant, but to_thread() is py v>=3.9
929
- # we wrap the non-coroutines into tasks to schedule all together
930
- #all_conditions_as_coro = [
931
- # c(traj) if c_is_cor else asyncio.to_thread(c, traj)
932
- # for c, c_is_cor in zip(self.conditions, self._condition_func_is_coroutine)
933
- # ]
934
- #return await asyncio.gather(*all_conditions_as_coro)
969
+ self._condition_func_is_coroutine)
970
+ ]
971
+ results = await asyncio.gather(*all_conditions_as_coro)
972
+ cond_eq_traj_len = [len(traj) == len(r) for r in results]
973
+ if not all(cond_eq_traj_len):
974
+ bad_condition_idx_str = ", ".join([f"{idx}" for idx, good
975
+ in enumerate(cond_eq_traj_len)
976
+ if not good])
977
+ raise ValueError("At least one of the conditions does not return "
978
+ "an array of shape (len(traj), ) when applied to "
979
+ "the trajectory traj. The conditions in question "
980
+ "have indexes " + bad_condition_idx_str + " .")
981
+ return results
935
982
 
936
983
 
937
984
  # alias for people coming from the path sampling community :)