asyncmd 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,12 @@
12
12
  #
13
13
  # You should have received a copy of the GNU General Public License
14
14
  # along with asyncmd. If not, see <https://www.gnu.org/licenses/>.
15
+ """
16
+ This module contains the implementation of all TrajectoryFunctionWrapper classes.
17
+
18
+ It contains the abstract base class TrajectoryFunctionWrapper,
19
+ and the functional classes PyTrajectoryFunctionWrapper and SlurmTrajectoryFunctionWrapper.
20
+ """
15
21
  import os
16
22
  import abc
17
23
  import shlex
@@ -20,16 +26,18 @@ import inspect
20
26
  import logging
21
27
  import hashlib
22
28
  import functools
29
+ import collections
30
+ import dataclasses
23
31
  import typing
32
+ from concurrent.futures import ThreadPoolExecutor
24
33
  import aiofiles
25
- import aiofiles.os
26
34
  import numpy as np
27
- from concurrent.futures import ThreadPoolExecutor
28
35
 
29
36
 
30
- from .._config import _SEMAPHORES
37
+ from .._config import _SEMAPHORES, _OPT_SEMAPHORES
31
38
  from .. import slurm
32
39
  from ..tools import ensure_executable_available, remove_file_if_exist_async
40
+ from ..tools import attach_kwargs_to_object as _attach_kwargs_to_object
33
41
  from .trajectory import Trajectory
34
42
 
35
43
 
@@ -44,28 +52,15 @@ class TrajectoryFunctionWrapper(abc.ABC):
44
52
  # (otherwise users could overwrite them by passing _id="blub" to
45
53
  # init), but since the subclasses set call_kwargs again and
46
54
  # have to calculate the id according to their own recipe anyway
47
- # we can savely set them here (this enables us to use the id
55
+ # we can safely set them here (this enables us to use the id
48
56
  # property at initialization time as e.g. in the slurm_jobname
49
57
  # of the SlurmTrajectoryFunctionWrapper)
50
- self._id = None
51
- self._call_kwargs = {} # init to empty dict such that iteration works
58
+ self._id = ""
59
+ # initialize to an empty dict such that iteration works
60
+ self._call_kwargs: dict[str, typing.Any] = {}
52
61
  # make it possible to set any attribute via kwargs
53
62
  # check the type for attributes with default values
54
- dval = object()
55
- for kwarg, value in kwargs.items():
56
- cval = getattr(self, kwarg, dval)
57
- if cval is not dval:
58
- if isinstance(value, type(cval)):
59
- # value is of same type as default so set it
60
- setattr(self, kwarg, value)
61
- else:
62
- raise TypeError(f"Setting attribute {kwarg} with "
63
- + f"mismatching type ({type(value)}). "
64
- + f" Default type is {type(cval)}."
65
- )
66
- else:
67
- # not previously defined, so warn that we ignore it
68
- logger.warning("Ignoring unknown keyword-argument %s.", kwarg)
63
+ _attach_kwargs_to_object(obj=self, logger=logger, **kwargs)
69
64
 
70
65
  @property
71
66
  def id(self) -> str:
@@ -77,15 +72,17 @@ class TrajectoryFunctionWrapper(abc.ABC):
77
72
  return self._id
78
73
 
79
74
  @property
80
- def call_kwargs(self) -> dict:
81
- """Additional calling arguments for the wrapped function/executable."""
75
+ def call_kwargs(self) -> dict[str, typing.Any]:
76
+ """
77
+ Additional calling arguments for the wrapped function/executable.
78
+
79
+ **NOTE:** You can only (re)set the complete dict and not single keys!
80
+ """
82
81
  # return a copy to avoid people modifying entries without us noticing
83
- # TODO/FIXME: this will make unhappy users if they try to set single
84
- # items in the dict!
85
82
  return self._call_kwargs.copy()
86
83
 
87
84
  @call_kwargs.setter
88
- def call_kwargs(self, value):
85
+ def call_kwargs(self, value: dict[str, typing.Any]) -> None:
89
86
  if not isinstance(value, dict):
90
87
  raise TypeError("call_kwargs must be a dictionary.")
91
88
  self._call_kwargs = value
@@ -93,21 +90,24 @@ class TrajectoryFunctionWrapper(abc.ABC):
93
90
 
94
91
  @abc.abstractmethod
95
92
  def _get_id_str(self) -> str:
96
- # this is expected to return an unique idetifying string
97
- # this should be unique and portable, i.e. it should enable us to make
98
- # ensure that the cached values will only be used for the same function
99
- # called with the same arguments
100
- pass
93
+ """
94
+ This function is expected to return a unique identifying string.
95
+
96
+ It should be unique and portable, i.e. it should enable us to make
97
+ sure that the cached values will only be used for the same function
98
+ called with the same arguments.
99
+ """
101
100
 
102
101
  @abc.abstractmethod
103
- async def get_values_for_trajectory(self, traj):
104
- # will be called by trajectory._apply_wrapped_func()
105
- # is expected to return a numpy array, shape=(n_frames, n_dim_function)
106
- pass
102
+ async def _get_values_for_trajectory(self, traj: Trajectory) -> np.ndarray:
103
+ """
104
+ Will be called by self.__call__ to actually perform the function calculation.
105
+ Is expected to return a numpy array, shape=(n_frames, n_dim_function).
106
+ """
107
107
 
108
- async def __call__(self, value):
108
+ async def __call__(self, value: Trajectory) -> np.ndarray:
109
109
  """
110
- Apply wrapped function asyncronously on given trajectory.
110
+ Apply wrapped function asynchronously on given trajectory.
111
111
 
112
112
  Parameters
113
113
  ----------
@@ -119,40 +119,37 @@ class TrajectoryFunctionWrapper(abc.ABC):
119
119
  iterable, usually list or np.ndarray
120
120
  The values of the wrapped function when applied on the trajectory.
121
121
  """
122
- if isinstance(value, Trajectory):
123
- if self.id is not None:
124
- return await value._apply_wrapped_func(self.id, self)
125
- raise RuntimeError(f"{type(self)} must have a unique id, but "
126
- "self.id is None.")
127
- raise TypeError(f"{type(self)} must be called with an "
128
- "`asyncmd.Trajectory` but was called with "
129
- f"{type(value)}.")
122
+ if not isinstance(value, Trajectory):
123
+ raise TypeError(f"{type(self)} must be called with an "
124
+ "`asyncmd.Trajectory` but was called with "
125
+ f"{type(value)}.")
126
+ async with value._semaphores_by_func_id[self.id]:
127
+ if ( # see if we have already values cached
128
+ func_values := value._retrieve_cached_values(func_wrapper=self)
129
+ ) is None:
130
+ # no values cached yet, so calc them and cache them
131
+ func_values = await self._get_values_for_trajectory(value)
132
+ value._register_cached_values(values=func_values, func_wrapper=self)
133
+ return func_values
130
134
 
131
135
 
132
136
  class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
133
137
  """
134
138
  Wrap python functions for use on :class:`asyncmd.Trajectory`.
135
139
 
136
- Turns every python callable into an asyncronous (awaitable) and cached
140
+ Turns every python callable into an asynchronous (awaitable) and cached
137
141
  function for application on :class:`asyncmd.Trajectory`. Also works for
138
- asyncronous (awaitable) functions, they will be cached.
139
-
140
- Attributes
141
- ----------
142
- function : callable
143
- The wrapped callable.
144
- call_kwargs : dict
145
- Keyword arguments for wrapped function.
142
+ asynchronous (awaitable) functions, they will be cached.
146
143
  """
147
- def __init__(self, function, call_kwargs: typing.Optional[dict] = None,
148
- **kwargs):
144
+ def __init__(self, function, call_kwargs: dict[str, typing.Any] | None = None,
145
+ **kwargs) -> None:
149
146
  """
150
147
  Initialize a :class:`PyTrajectoryFunctionWrapper`.
151
148
 
152
149
  Parameters
153
150
  ----------
154
151
  function : callable
155
- The (synchronous) callable to wrap.
152
+ The (synchronous or asynchronous) callable to wrap.
156
153
  call_kwargs : dict, optional
157
154
  Keyword arguments for `function`,
158
155
  the keys will be used as keyword with the corresponding values,
@@ -173,26 +170,26 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
173
170
  + f"call_kwargs={self.call_kwargs})"
174
171
  )
175
172
 
176
- def _get_id_str(self):
173
+ def _get_id_str(self) -> str:
177
174
  # calculate a hash over function src and call_kwargs dict
178
175
  # this should be unique and portable, i.e. it should enable us to make
179
176
  # sure that the cached values will only be used for the same function
180
177
  # called with the same arguments
181
- id = 0
178
+ _id = 0
182
179
  # NOTE: addition is commutative, i.e. order does not matter here!
183
180
  for k, v in self._call_kwargs.items():
184
181
  # hash the value
185
- id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
182
+ _id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
186
183
  # hash the key
187
- id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
184
+ _id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
188
185
  # and add the func_src
189
- id += int(hashlib.blake2b(str(self._func_src).encode('utf-8')).hexdigest(), 16)
190
- return str(id) # return a str because we want to use it as dict keys
186
+ _id += int(hashlib.blake2b(str(self._func_src).encode('utf-8')).hexdigest(), 16)
187
+ return str(_id) # return a str because we want to use it as dict keys
191
188
 
192
189
  @property
193
190
  def function(self):
194
191
  """
195
- The python callable this :class:`PyTrajectoryFunctionWrapper` wrapps.
192
+ The python callable this :class:`PyTrajectoryFunctionWrapper` wraps.
196
193
  """
197
194
  return self._func
198
195
 
@@ -202,23 +199,20 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
202
199
  or inspect.iscoroutinefunction(value.__call__))
203
200
  try:
204
201
  src = inspect.getsource(value)
205
- except OSError:
202
+ except OSError as e:
206
203
  # OSError is raised if source can not be retrieved
207
- self._func_src = None
208
- self._id = None
209
- logger.warning("Could not retrieve source for %s."
210
- " No caching can/will be performed.",
211
- value,
212
- )
204
+ raise OSError(f"Could not retrieve source for {value}."
205
+ " No hash can be calculated and no caching performed."
206
+ ) from e
213
207
  else:
214
208
  self._func_src = src
215
209
  self._id = self._get_id_str() # get/set ID
216
210
  finally:
217
211
  self._func = value
218
212
 
219
- async def get_values_for_trajectory(self, traj):
213
+ async def _get_values_for_trajectory(self, traj: Trajectory) -> np.ndarray:
220
214
  """
221
- Apply wrapped function asyncronously on given trajectory.
215
+ Apply wrapped function asynchronously on given trajectory.
222
216
 
223
217
  Parameters
224
218
  ----------
@@ -234,10 +228,10 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
234
228
  return await self._get_values_for_trajectory_async(traj)
235
229
  return await self._get_values_for_trajectory_sync(traj)
236
230
 
237
- async def _get_values_for_trajectory_async(self, traj):
231
+ async def _get_values_for_trajectory_async(self, traj: Trajectory) -> np.ndarray:
238
232
  return await self._func(traj, **self._call_kwargs)
239
233
 
240
- async def _get_values_for_trajectory_sync(self, traj):
234
+ async def _get_values_for_trajectory_sync(self, traj: Trajectory) -> np.ndarray:
241
235
  loop = asyncio.get_running_loop()
242
236
  async with _SEMAPHORES["MAX_PROCESS"]:
243
237
  # fill in additional kwargs (if any)
@@ -255,7 +249,7 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
255
249
  # only place where this could make sense to think about as
256
250
  # opposed to concatenation of trajs (which is IO bound)
257
251
  # NOTE: make sure we do not fork! (not safe with multithreading)
258
- # see e.g. https://stackoverflow.com/questions/46439740/safe-to-call-multiprocessing-from-a-thread-in-python
252
+ # see e.g. https://stackoverflow.com/a/46440564
259
253
  #ctx = multiprocessing.get_context("forkserver")
260
254
  #with ProcessPoolExecutor(1, mp_context=ctx) as pool:
261
255
  with ThreadPoolExecutor(max_workers=1,
@@ -265,11 +259,21 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
265
259
  return vals
266
260
 
267
261
 
262
+ @dataclasses.dataclass
263
+ class _SlurmDataForTrajFuncWrapper:
264
+ """
265
+ Bundle/store all data related to slurm (submission) for :class:`SlurmTrajectoryFunctionWrapper`.
266
+ """
267
+ sbatch_script: str = ""
268
+ jobname: str | None = None
269
+ sbatch_options: dict[str, str] | None = None
270
+
271
+
268
272
  class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
269
273
  """
270
274
  Wrap executables to use on :class:`asyncmd.Trajectory` via SLURM.
271
275
 
272
- The execution of the job is submited to the queueing system with the
276
+ The execution of the job is submitted to the queueing system with the
273
277
  given sbatch script (template).
274
278
  The executable will be called with the following positional arguments:
275
279
 
@@ -289,25 +293,13 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
289
293
 
290
294
  See also the examples for a reference (python) implementation of multiple
291
295
  different functions/executables for use with this class.
292
-
293
- Attributes
294
- ----------
295
- slurm_jobname : str
296
- Used as name for the job in slurm and also as part of the filename for
297
- the submission script that will be written (and deleted if everything
298
- goes well) for every trajectory.
299
- NOTE: If you modify it, ensure that each SlurmTrajectoryWrapper has a
300
- unique slurm_jobname.
301
- executable : str
302
- Name of or path to the wrapped executable.
303
- call_kwargs : dict
304
- Keyword arguments for wrapped executable.
305
296
  """
306
297
 
307
- def __init__(self, executable, sbatch_script,
298
+ def __init__(self, executable, sbatch_script, *,
308
299
  sbatch_options: dict | None = None,
309
- call_kwargs: typing.Optional[dict] = None,
310
- load_results_func=None, **kwargs):
300
+ call_kwargs: dict | None = None,
301
+ load_results_func: collections.abc.Callable | None = None,
302
+ **kwargs) -> None:
311
303
  """
312
304
  Initialize :class:`SlurmTrajectoryFunctionWrapper`.
313
305
 
@@ -328,23 +320,26 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
328
320
 
329
321
  sbatch_options : dict or None
330
322
  Dictionary of sbatch options, keys are long names for options,
331
- values are the correponding values. The keys/long names are given
332
- without the dashes, e.g. to specify "--mem=1024" the dictionary
333
- needs to be {"mem": "1024"}. To specify options without values use
334
- keys with empty strings as values, e.g. to specify "--contiguous"
335
- the dictionary needs to be {"contiguous": ""}.
323
+ values are the corresponding values. The keys/long names are given
324
+ without the dashes, e.g. to specify ``--mem=1024`` the dictionary
325
+ needs to be ``{"mem": "1024"}``. To specify options without values
326
+ use keys with empty strings as values, e.g. to specify
327
+ ``--contiguous`` the dictionary needs to be ``{"contiguous": ""}``.
336
328
  See the SLURM documentation for a full list of sbatch options
337
329
  (https://slurm.schedmd.com/sbatch.html).
338
- Note: This argument is passed as is to the `SlurmProcess` in which
330
+ Note: This argument is passed as is to the ``SlurmProcess`` in which
339
331
  the computation is performed. Each call of the TrajectoryFunction
340
- triggers the creation of a new `SlurmProcess` and will use the then
341
- current `sbatch_options`.
332
+ triggers the creation of a new :class:`asyncmd.slurm.SlurmProcess`
333
+ and will use the then current ``sbatch_options``.
342
334
  call_kwargs : dict, optional
343
335
  Dictionary of additional arguments to pass to the executable, they
344
- will be added to the call as pair ' {key} {val}', note that in case
345
- you want to pass single command line flags (like '-v') this can be
346
- achieved by setting key='-v' and val='', i.e. to the empty string.
347
- The values are shell escaped using `shlex.quote()` when writing
336
+ will be added to the call as pair `` {key} {val}``, note that in
337
+ case you want to pass single command line flags (like ``-v``) this
338
+ can be achieved by setting ``key="-v"`` and ``val=""``, i.e. to the
339
+ empty string.
340
+ Lists as values will be unpacked and added as (for a list with n
341
+ entries): `` {key} {val1} {val2} ... {valn}``.
342
+ The values are shell escaped using :func:`shlex.quote` when writing
348
343
  them to the sbatch script.
349
344
  load_results_func : None or function (callable)
350
345
  Function to call to customize the loading of the results.
@@ -352,21 +347,13 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
352
347
  the results file (as in the call to the executable) and should
353
348
  return a numpy array containing the loaded values.
354
349
  """
355
- # property defaults before superclass init to be resettable via kwargs
356
- self._slurm_jobname = None
350
+ # init property defaults before superclass init to be resettable via kwargs
351
+ self._slurm_data = _SlurmDataForTrajFuncWrapper(sbatch_options=sbatch_options,
352
+ )
357
353
  super().__init__(**kwargs)
358
- self._executable = None
359
- # we expect sbatch_script to be a str,
360
- # but it could be either the path to a submit script or the content of
361
- # the submission script directly
362
- # we decide what it is by checking for the shebang
363
- if not sbatch_script.startswith("#!"):
364
- # probably path to a file, lets try to read it
365
- with open(sbatch_script, 'r') as f:
366
- sbatch_script = f.read()
367
- # (possibly) use properties to calc the id directly
354
+ self._executable = ""
355
+ # use property setters to (possibly) calc the id directly and/or preprocess values
368
356
  self.sbatch_script = sbatch_script
369
- self.sbatch_options = sbatch_options
370
357
  self.executable = executable
371
358
  if call_kwargs is None:
372
359
  call_kwargs = {}
@@ -378,58 +365,95 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
378
365
  """
379
366
  The jobname of the slurm job used to compute the function results.
380
367
 
381
- Must be unique for each :class:`SlurmTrajectoryFunctionWrapper`
368
+ Also used as part of the filename for the submission script that will
369
+ be written (and deleted if everything goes well) for every trajectory.
370
+
371
+ **NOTE:** Must be unique for each :class:`SlurmTrajectoryFunctionWrapper`
382
372
  instance. Will by default include the persistent unique ID :meth:`id`.
383
373
  To (re)set to the default set it to None.
384
374
  """
385
- if self._slurm_jobname is None:
375
+ if self._slurm_data.jobname is None:
386
376
  return f"CVfunc_id_{self.id}"
387
- return self._slurm_jobname
377
+ return self._slurm_data.jobname
388
378
 
389
379
  @slurm_jobname.setter
390
- def slurm_jobname(self, val: str | None):
391
- self._slurm_jobname = val
380
+ def slurm_jobname(self, val: str | None) -> None:
381
+ self._slurm_data.jobname = val
382
+
383
+ @property
384
+ def sbatch_script(self) -> str:
385
+ """
386
+ Content of the sbatch script (see the corresponding __init__ argument).
387
+
388
+ Can also be set with the path to a file, in this case the script will be read.
389
+ """
390
+ return self._slurm_data.sbatch_script
391
+
392
+ @sbatch_script.setter
393
+ def sbatch_script(self, value: str) -> None:
394
+ # we expect sbatch_script to be a str,
395
+ # but it could be either the path to a submit script or the content of
396
+ # the submission script directly
397
+ # we decide what it is by checking for the shebang
398
+ if not value.startswith("#!"):
399
+ # probably path to a file, let's try to read it
400
+ with open(value, 'r', encoding="locale") as f:
401
+ value = f.read()
402
+ self._slurm_data.sbatch_script = value
403
+
404
+ @property
405
+ def sbatch_options(self) -> dict[str, str] | None:
406
+ """
407
+ Dictionary of sbatch_options or None (see the corresponding __init__ argument).
408
+
409
+ **NOTE:** You can only (re)set the complete dict and not single keys!
410
+ """
411
+ return self._slurm_data.sbatch_options
412
+
413
+ @sbatch_options.setter
414
+ def sbatch_options(self, value: dict[str, str] | None) -> None:
415
+ self._slurm_data.sbatch_options = value
392
416
 
393
417
  def __repr__(self) -> str:
394
418
  return (f"SlurmTrajectoryFunctionWrapper(executable={self._executable}, "
395
419
  + f"call_kwargs={self.call_kwargs})"
396
420
  )
397
421
 
398
- def _get_id_str(self):
422
+ def _get_id_str(self) -> str:
399
423
  # calculate a hash over executable and call_kwargs dict
400
424
  # this should be unique and portable, i.e. it should enable us to make
401
425
  # sure that the cached values will only be used for the same function
402
426
  # called with the same arguments
403
- id = 0
427
+ _id = 0
404
428
  # NOTE: addition is commutative, i.e. order does not matter here!
405
429
  for k, v in self._call_kwargs.items():
406
430
  # hash the value
407
- id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
431
+ _id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
408
432
  # hash the key
409
- id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
433
+ _id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
410
434
  # and add the executable hash
411
435
  with open(self.executable, "rb") as exe_file:
412
436
  # NOTE: we assume that executable is small enough to read at once
413
437
  # if this crashes because of OOM we should use chunks...
414
438
  data = exe_file.read()
415
- id += int(hashlib.blake2b(data).hexdigest(), 16)
416
- return str(id) # return a str because we want to use it as dict keys
439
+ _id += int(hashlib.blake2b(data).hexdigest(), 16)
440
+ return str(_id) # return a str because we want to use it as dict keys
417
441
 
418
442
  @property
419
- def executable(self):
443
+ def executable(self) -> str:
420
444
  """The executable used to compute the function results."""
421
445
  return self._executable
422
446
 
423
447
  @executable.setter
424
- def executable(self, val):
448
+ def executable(self, val: str) -> None:
425
449
  exe = ensure_executable_available(val)
426
450
  # if we get here it should be safe to set, i.e. it exists + has X-bit
427
451
  self._executable = exe
428
452
  self._id = self._get_id_str() # get the new hash/id
429
453
 
430
- async def get_values_for_trajectory(self, traj):
454
+ async def _get_values_for_trajectory(self, traj: Trajectory) -> np.ndarray:
431
455
  """
432
- Apply wrapped function asyncronously on given trajectory.
456
+ Apply wrapped function asynchronously on given trajectory.
433
457
 
434
458
  Parameters
435
459
  ----------
@@ -438,7 +462,7 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
438
462
 
439
463
  Returns
440
464
  -------
441
- iterable, usually list or np.ndarray
465
+ np.ndarray
442
466
  The values of the wrapped function when applied on the trajectory.
443
467
  """
444
468
  # first construct the path/name for the numpy npy file in which we expect
@@ -453,22 +477,8 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
453
477
  result_file = os.path.abspath(os.path.join(
454
478
  tra_dir, f"{tra_name}_{hash_part}_CVfunc_id_{self.id}"
455
479
  ))
456
- # we expect executable to take 3 postional args:
457
- # struct traj outfile
458
- cmd_str = f"{self.executable} {os.path.abspath(traj.structure_file)}"
459
- cmd_str += f" {' '.join(os.path.abspath(t) for t in traj.trajectory_files)}"
460
- cmd_str += f" {result_file}"
461
- if len(self.call_kwargs) > 0:
462
- for key, val in self.call_kwargs.items():
463
- # shell escape only the values,
464
- # the keys (i.e. option names/flags) should be no issue
465
- if isinstance(val, list):
466
- # enable lists of arguments for the same key,
467
- # can then be used e.g. with pythons argparse `nargs="*"` or `nargs="+"`
468
- cmd_str += f" {key} {' '.join([shlex.quote(str(v)) for v in val])}"
469
- else:
470
- cmd_str += f" {key} {shlex.quote(str(val))}"
471
480
  # now prepare the sbatch script
481
+ cmd_str = self._build_cmd_str(traj=traj, result_file=result_file)
472
482
  script = self.sbatch_script.format(cmd_str=cmd_str)
473
483
  # write it out
474
484
  sbatch_fname = os.path.join(tra_dir,
@@ -479,7 +489,7 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
479
489
  sbatch_fname,
480
490
  )
481
491
  async with _SEMAPHORES["MAX_FILES_OPEN"]:
482
- async with aiofiles.open(sbatch_fname, 'w') as f:
492
+ async with aiofiles.open(sbatch_fname, 'w', encoding="locale") as f:
483
493
  await f.write(script)
484
494
  # NOTE: we set returncode to 2 (what slurmprocess returns in case of
485
495
  # node failure) and rerun/retry until we either get a completed job
@@ -499,7 +509,8 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
499
509
  self.executable, traj, slurm_proc.slurm_jobid,
500
510
  stdout.decode(), stderr.decode(),
501
511
  )
502
- if returncode != 0:
512
+ if returncode:
513
+ # Non-zero return code,
503
514
  # Can not be exitcode 2, because of the while loop above
504
515
  raise RuntimeError(
505
516
  "Non-zero exit code from CV batch job for "
@@ -511,6 +522,29 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
511
522
  + f" and stdout was: {stdout.decode()}"
512
523
  )
513
524
  # zero-exitcode: load the results
525
+ return await self._load_results_remove_files(result_file=result_file,
526
+ sbatch_fname=sbatch_fname)
527
+
528
+ def _build_cmd_str(self, traj: Trajectory, result_file: str) -> str:
529
+ # we expect executable to take 3 positional args:
530
+ # struct traj outfile
531
+ cmd_str = f"{self.executable} {os.path.abspath(traj.structure_file)}"
532
+ cmd_str += f" {' '.join(os.path.abspath(t) for t in traj.trajectory_files)}"
533
+ cmd_str += f" {result_file}"
534
+ if len(self.call_kwargs) > 0:
535
+ for key, val in self.call_kwargs.items():
536
+ # shell escape only the values,
537
+ # the keys (i.e. option names/flags) should be no issue
538
+ if isinstance(val, list):
539
+ # enable lists of arguments for the same key,
540
+ # can then be used e.g. with Python's argparse `nargs="*"` or `nargs="+"`
541
+ cmd_str += f" {key} {' '.join([shlex.quote(str(v)) for v in val])}"
542
+ else:
543
+ cmd_str += f" {key} {shlex.quote(str(val))}"
544
+ return cmd_str
545
+
546
+ async def _load_results_remove_files(self, result_file: str,
547
+ sbatch_fname: str) -> np.ndarray:
514
548
  if self.load_results_func is None:
515
549
  # we do not have '.npy' ending in results_file,
516
550
  # numpy.save() adds it if it is not there, so we need it here
@@ -538,12 +572,11 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
538
572
 
539
573
  async def _run_slurm_job(self, sbatch_fname: str, result_file: str,
540
574
  slurm_workdir: str,
541
- ) -> tuple[int,slurm.SlurmProcess,bytes,bytes]:
575
+ ) -> tuple[int, slurm.SlurmProcess, bytes, bytes]:
542
576
  # submit and run slurm-job
543
- if _SEMAPHORES["SLURM_MAX_JOB"] is not None:
544
- await _SEMAPHORES["SLURM_MAX_JOB"].acquire()
545
- try: # this try is just to make sure we always release the semaphore
546
- slurm_proc = await slurm.create_slurmprocess_submit(
577
+ if _OPT_SEMAPHORES["SLURM_MAX_JOB"] is not None:
578
+ await _OPT_SEMAPHORES["SLURM_MAX_JOB"].acquire()
579
+ slurm_proc = await slurm.create_slurmprocess_submit(
547
580
  jobname=self.slurm_jobname,
548
581
  sbatch_script=sbatch_fname,
549
582
  workdir=slurm_workdir,
@@ -552,12 +585,13 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
552
585
  stdin=None,
553
586
  # sleep 5 s between checking
554
587
  sleep_time=5,
555
- )
588
+ )
589
+ # this try is just to make sure we always release the semaphore
590
+ # and to possibly clean up any left-over files from the fail
591
+ try:
556
592
  # wait for the slurm job to finish
557
593
  # also cancel the job when this future is canceled
558
594
  stdout, stderr = await slurm_proc.communicate()
559
- returncode = slurm_proc.returncode
560
- return returncode, slurm_proc, stdout, stderr
561
595
  except asyncio.CancelledError:
562
596
  slurm_proc.kill()
563
597
  # clean up the sbatch file and potentially written result file
@@ -567,6 +601,8 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
567
601
  remove_file_if_exist_async(res_fname),
568
602
  )
569
603
  raise # reraise CancelledError for encompassing coroutines
604
+ else:
605
+ return slurm_proc.returncode, slurm_proc, stdout, stderr
570
606
  finally:
571
- if _SEMAPHORES["SLURM_MAX_JOB"] is not None:
572
- _SEMAPHORES["SLURM_MAX_JOB"].release()
607
+ if _OPT_SEMAPHORES["SLURM_MAX_JOB"] is not None:
608
+ _OPT_SEMAPHORES["SLURM_MAX_JOB"].release()