asyncmd 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,12 @@
12
12
  #
13
13
  # You should have received a copy of the GNU General Public License
14
14
  # along with asyncmd. If not, see <https://www.gnu.org/licenses/>.
15
+ """
16
+ This module contains the Implementation of all TrajectoryFunctionWrapper classes.
17
+
18
+ It contains the abstract base class TrajectoryFunctionWrapper,
19
+ and the functional classes PyTrajectoryFunctionWrapper and SlurmTrajectoryFunctionWrapper.
20
+ """
15
21
  import os
16
22
  import abc
17
23
  import shlex
@@ -20,16 +26,18 @@ import inspect
20
26
  import logging
21
27
  import hashlib
22
28
  import functools
29
+ import collections
30
+ import dataclasses
23
31
  import typing
32
+ from concurrent.futures import ThreadPoolExecutor
24
33
  import aiofiles
25
- import aiofiles.os
26
34
  import numpy as np
27
- from concurrent.futures import ThreadPoolExecutor
28
35
 
29
36
 
30
- from .._config import _SEMAPHORES
37
+ from .._config import _SEMAPHORES, _OPT_SEMAPHORES
31
38
  from .. import slurm
32
39
  from ..tools import ensure_executable_available, remove_file_if_exist_async
40
+ from ..tools import attach_kwargs_to_object as _attach_kwargs_to_object
33
41
  from .trajectory import Trajectory
34
42
 
35
43
 
@@ -44,28 +52,15 @@ class TrajectoryFunctionWrapper(abc.ABC):
44
52
  # (otherwise users could overwrite them by passing _id="blub" to
45
53
  # init), but since the subclasses sets call_kwargs again and
46
54
  # have to calculate the id according to their own recipe anyway
47
- # we can savely set them here (this enables us to use the id
55
+ # we can safely set them here (this enables us to use the id
48
56
  # property at initialization time as e.g. in the slurm_jobname
49
57
  # of the SlurmTrajectoryFunctionWrapper)
50
- self._id = None
51
- self._call_kwargs = {} # init to empty dict such that iteration works
58
+ self._id = ""
59
+ # initialize to an empty dict such that iteration works
60
+ self._call_kwargs: dict[str, typing.Any] = {}
52
61
  # make it possible to set any attribute via kwargs
53
62
  # check the type for attributes with default values
54
- dval = object()
55
- for kwarg, value in kwargs.items():
56
- cval = getattr(self, kwarg, dval)
57
- if cval is not dval:
58
- if isinstance(value, type(cval)):
59
- # value is of same type as default so set it
60
- setattr(self, kwarg, value)
61
- else:
62
- raise TypeError(f"Setting attribute {kwarg} with "
63
- + f"mismatching type ({type(value)}). "
64
- + f" Default type is {type(cval)}."
65
- )
66
- else:
67
- # not previously defined, so warn that we ignore it
68
- logger.warning("Ignoring unknown keyword-argument %s.", kwarg)
63
+ _attach_kwargs_to_object(obj=self, logger=logger, **kwargs)
69
64
 
70
65
  @property
71
66
  def id(self) -> str:
@@ -77,15 +72,17 @@ class TrajectoryFunctionWrapper(abc.ABC):
77
72
  return self._id
78
73
 
79
74
  @property
80
- def call_kwargs(self) -> dict:
81
- """Additional calling arguments for the wrapped function/executable."""
75
+ def call_kwargs(self) -> dict[str, typing.Any]:
76
+ """
77
+ Additional calling arguments for the wrapped function/executable.
78
+
79
+ **NOTE:** You can only (re)set the complete dict and not single keys!
80
+ """
82
81
  # return a copy to avoid people modifying entries without us noticing
83
- # TODO/FIXME: this will make unhappy users if they try to set single
84
- # items in the dict!
85
82
  return self._call_kwargs.copy()
86
83
 
87
84
  @call_kwargs.setter
88
- def call_kwargs(self, value):
85
+ def call_kwargs(self, value: dict[str, typing.Any]) -> None:
89
86
  if not isinstance(value, dict):
90
87
  raise TypeError("call_kwargs must be a dictionary.")
91
88
  self._call_kwargs = value
@@ -93,21 +90,24 @@ class TrajectoryFunctionWrapper(abc.ABC):
93
90
 
94
91
  @abc.abstractmethod
95
92
  def _get_id_str(self) -> str:
96
- # this is expected to return an unique idetifying string
97
- # this should be unique and portable, i.e. it should enable us to make
98
- # ensure that the cached values will only be used for the same function
99
- # called with the same arguments
100
- pass
93
+ """
94
+ This function is expected to return an unique identifying string.
95
+
96
+ It should be unique and portable, i.e. it should enable us to make
97
+ sure that the cached values will only be used for the same function
98
+ called with the same arguments.
99
+ """
101
100
 
102
101
  @abc.abstractmethod
103
- async def get_values_for_trajectory(self, traj):
104
- # will be called by trajectory._apply_wrapped_func()
105
- # is expected to return a numpy array, shape=(n_frames, n_dim_function)
106
- pass
102
+ async def _get_values_for_trajectory(self, traj: Trajectory) -> np.ndarray:
103
+ """
104
+ Will be called by self.__call__ to actually perform the function calculation.
105
+ Is expected to return a numpy array, shape=(n_frames, n_dim_function).
106
+ """
107
107
 
108
- async def __call__(self, value):
108
+ async def __call__(self, value: Trajectory) -> np.ndarray:
109
109
  """
110
- Apply wrapped function asyncronously on given trajectory.
110
+ Apply wrapped function asynchronously on given trajectory.
111
111
 
112
112
  Parameters
113
113
  ----------
@@ -119,40 +119,37 @@ class TrajectoryFunctionWrapper(abc.ABC):
119
119
  iterable, usually list or np.ndarray
120
120
  The values of the wrapped function when applied on the trajectory.
121
121
  """
122
- if isinstance(value, Trajectory):
123
- if self.id is not None:
124
- return await value._apply_wrapped_func(self.id, self)
125
- raise RuntimeError(f"{type(self)} must have a unique id, but "
126
- "self.id is None.")
127
- raise TypeError(f"{type(self)} must be called with an "
128
- "`asyncmd.Trajectory` but was called with "
129
- f"{type(value)}.")
122
+ if not isinstance(value, Trajectory):
123
+ raise TypeError(f"{type(self)} must be called with an "
124
+ "`asyncmd.Trajectory` but was called with "
125
+ f"{type(value)}.")
126
+ async with value._semaphores_by_func_id[self.id]:
127
+ if ( # see if we have already values cached
128
+ func_values := value._retrieve_cached_values(func_wrapper=self)
129
+ ) is None:
130
+ # no values cached yet, so calc them and cache them
131
+ func_values = await self._get_values_for_trajectory(value)
132
+ value._register_cached_values(values=func_values, func_wrapper=self)
133
+ return func_values
130
134
 
131
135
 
132
136
  class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
133
137
  """
134
138
  Wrap python functions for use on :class:`asyncmd.Trajectory`.
135
139
 
136
- Turns every python callable into an asyncronous (awaitable) and cached
140
+ Turns every python callable into an asynchronous (awaitable) and cached
137
141
  function for application on :class:`asyncmd.Trajectory`. Also works for
138
- asyncronous (awaitable) functions, they will be cached.
139
-
140
- Attributes
141
- ----------
142
- function : callable
143
- The wrapped callable.
144
- call_kwargs : dict
145
- Keyword arguments for wrapped function.
142
+ asynchronous (awaitable) functions, they will be cached.
146
143
  """
147
- def __init__(self, function, call_kwargs: typing.Optional[dict] = None,
148
- **kwargs):
144
+ def __init__(self, function, call_kwargs: dict[str, typing.Any] | None = None,
145
+ **kwargs) -> None:
149
146
  """
150
147
  Initialize a :class:`PyTrajectoryFunctionWrapper`.
151
148
 
152
149
  Parameters
153
150
  ----------
154
151
  function : callable
155
- The (synchronous) callable to wrap.
152
+ The (synchronous or asynchronous) callable to wrap.
156
153
  call_kwargs : dict, optional
157
154
  Keyword arguments for `function`,
158
155
  the keys will be used as keyword with the corresponding values,
@@ -173,26 +170,26 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
173
170
  + f"call_kwargs={self.call_kwargs})"
174
171
  )
175
172
 
176
- def _get_id_str(self):
173
+ def _get_id_str(self) -> str:
177
174
  # calculate a hash over function src and call_kwargs dict
178
175
  # this should be unique and portable, i.e. it should enable us to make
179
176
  # ensure that the cached values will only be used for the same function
180
177
  # called with the same arguments
181
- id = 0
178
+ _id = 0
182
179
  # NOTE: addition is commutative, i.e. order does not matter here!
183
180
  for k, v in self._call_kwargs.items():
184
181
  # hash the value
185
- id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
182
+ _id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
186
183
  # hash the key
187
- id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
184
+ _id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
188
185
  # and add the func_src
189
- id += int(hashlib.blake2b(str(self._func_src).encode('utf-8')).hexdigest(), 16)
190
- return str(id) # return a str because we want to use it as dict keys
186
+ _id += int(hashlib.blake2b(str(self._func_src).encode('utf-8')).hexdigest(), 16)
187
+ return str(_id) # return a str because we want to use it as dict keys
191
188
 
192
189
  @property
193
190
  def function(self):
194
191
  """
195
- The python callable this :class:`PyTrajectoryFunctionWrapper` wrapps.
192
+ The python callable this :class:`PyTrajectoryFunctionWrapper` wraps.
196
193
  """
197
194
  return self._func
198
195
 
@@ -202,23 +199,20 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
202
199
  or inspect.iscoroutinefunction(value.__call__))
203
200
  try:
204
201
  src = inspect.getsource(value)
205
- except OSError:
202
+ except OSError as e:
206
203
  # OSError is raised if source can not be retrieved
207
- self._func_src = None
208
- self._id = None
209
- logger.warning("Could not retrieve source for %s."
210
- " No caching can/will be performed.",
211
- value,
212
- )
204
+ raise OSError(f"Could not retrieve source for {value}."
205
+ " No hash can be calculated and no caching performed."
206
+ ) from e
213
207
  else:
214
208
  self._func_src = src
215
209
  self._id = self._get_id_str() # get/set ID
216
210
  finally:
217
211
  self._func = value
218
212
 
219
- async def get_values_for_trajectory(self, traj):
213
+ async def _get_values_for_trajectory(self, traj: Trajectory) -> np.ndarray:
220
214
  """
221
- Apply wrapped function asyncronously on given trajectory.
215
+ Apply wrapped function asynchronously on given trajectory.
222
216
 
223
217
  Parameters
224
218
  ----------
@@ -234,10 +228,10 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
234
228
  return await self._get_values_for_trajectory_async(traj)
235
229
  return await self._get_values_for_trajectory_sync(traj)
236
230
 
237
- async def _get_values_for_trajectory_async(self, traj):
231
+ async def _get_values_for_trajectory_async(self, traj: Trajectory) -> np.ndarray:
238
232
  return await self._func(traj, **self._call_kwargs)
239
233
 
240
- async def _get_values_for_trajectory_sync(self, traj):
234
+ async def _get_values_for_trajectory_sync(self, traj: Trajectory) -> np.ndarray:
241
235
  loop = asyncio.get_running_loop()
242
236
  async with _SEMAPHORES["MAX_PROCESS"]:
243
237
  # fill in additional kwargs (if any)
@@ -255,7 +249,7 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
255
249
  # only place where this could make sense to think about as
256
250
  # opposed to concatenation of trajs (which is IO bound)
257
251
  # NOTE: make sure we do not fork! (not save with multithreading)
258
- # see e.g. https://stackoverflow.com/questions/46439740/safe-to-call-multiprocessing-from-a-thread-in-python
252
+ # see e.g. https://stackoverflow.com/a/46440564
259
253
  #ctx = multiprocessing.get_context("forkserver")
260
254
  #with ProcessPoolExecutor(1, mp_context=ctx) as pool:
261
255
  with ThreadPoolExecutor(max_workers=1,
@@ -265,11 +259,21 @@ class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
265
259
  return vals
266
260
 
267
261
 
262
+ @dataclasses.dataclass
263
+ class _SlurmDataForTrajFuncWrapper:
264
+ """
265
+ Bundle/store all data related to slurm (submission) for :class:`SlurmTrajectoryFunctionWrapper`.
266
+ """
267
+ sbatch_script: str = ""
268
+ jobname: str | None = None
269
+ sbatch_options: dict[str, str] | None = None
270
+
271
+
268
272
  class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
269
273
  """
270
274
  Wrap executables to use on :class:`asyncmd.Trajectory` via SLURM.
271
275
 
272
- The execution of the job is submited to the queueing system with the
276
+ The execution of the job is submitted to the queueing system with the
273
277
  given sbatch script (template).
274
278
  The executable will be called with the following positional arguments:
275
279
 
@@ -289,24 +293,13 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
289
293
 
290
294
  See also the examples for a reference (python) implementation of multiple
291
295
  different functions/executables for use with this class.
292
-
293
- Attributes
294
- ----------
295
- slurm_jobname : str
296
- Used as name for the job in slurm and also as part of the filename for
297
- the submission script that will be written (and deleted if everything
298
- goes well) for every trajectory.
299
- NOTE: If you modify it, ensure that each SlurmTrajectoryWrapper has a
300
- unique slurm_jobname.
301
- executable : str
302
- Name of or path to the wrapped executable.
303
- call_kwargs : dict
304
- Keyword arguments for wrapped executable.
305
296
  """
306
297
 
307
- def __init__(self, executable, sbatch_script,
308
- call_kwargs: typing.Optional[dict] = None,
309
- load_results_func=None, **kwargs):
298
+ def __init__(self, executable, sbatch_script, *,
299
+ sbatch_options: dict | None = None,
300
+ call_kwargs: dict | None = None,
301
+ load_results_func: collections.abc.Callable | None = None,
302
+ **kwargs) -> None:
310
303
  """
311
304
  Initialize :class:`SlurmTrajectoryFunctionWrapper`.
312
305
 
@@ -325,12 +318,28 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
325
318
 
326
319
  - {cmd_str} : Replaced by the command to call the executable on a given trajectory.
327
320
 
321
+ sbatch_options : dict or None
322
+ Dictionary of sbatch options, keys are long names for options,
323
+ values are the corresponding values. The keys/long names are given
324
+ without the dashes, e.g. to specify ``--mem=1024`` the dictionary
325
+ needs to be ``{"mem": "1024"}``. To specify options without values
326
+ use keys with empty strings as values, e.g. to specify
327
+ ``--contiguous`` the dictionary needs to be ``{"contiguous": ""}``.
328
+ See the SLURM documentation for a full list of sbatch options
329
+ (https://slurm.schedmd.com/sbatch.html).
330
+ Note: This argument is passed as is to the ``SlurmProcess`` in which
331
+ the computation is performed. Each call of the TrajectoryFunction
332
+ triggers the creation of a new :class:`asyncmd.slurm.SlurmProcess`
333
+ and will use the then current ``sbatch_options``.
328
334
  call_kwargs : dict, optional
329
335
  Dictionary of additional arguments to pass to the executable, they
330
- will be added to the call as pair ' {key} {val}', note that in case
331
- you want to pass single command line flags (like '-v') this can be
332
- achieved by setting key='-v' and val='', i.e. to the empty string.
333
- The values are shell escaped using `shlex.quote()` when writing
336
+ will be added to the call as pair `` {key} {val}``, note that in
337
+ case you want to pass single command line flags (like ``-v``) this
338
+ can be achieved by setting ``key="-v"`` and ``val=""``, i.e. to the
339
+ empty string.
340
+ Lists as values will be unpacked and added as (for a list with n
341
+ entries): `` {key} {val1} {val2} ... {valn}``.
342
+ The values are shell escaped using :func:`shlex.quote` when writing
334
343
  them to the sbatch script.
335
344
  load_results_func : None or function (callable)
336
345
  Function to call to customize the loading of the results.
@@ -338,19 +347,12 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
338
347
  the results file (as in the call to the executable) and should
339
348
  return a numpy array containing the loaded values.
340
349
  """
341
- # property defaults before superclass init to be resettable via kwargs
342
- self._slurm_jobname = None
350
+ # init property defaults before superclass init to be resettable via kwargs
351
+ self._slurm_data = _SlurmDataForTrajFuncWrapper(sbatch_options=sbatch_options,
352
+ )
343
353
  super().__init__(**kwargs)
344
- self._executable = None
345
- # we expect sbatch_script to be a str,
346
- # but it could be either the path to a submit script or the content of
347
- # the submission script directly
348
- # we decide what it is by checking for the shebang
349
- if not sbatch_script.startswith("#!"):
350
- # probably path to a file, lets try to read it
351
- with open(sbatch_script, 'r') as f:
352
- sbatch_script = f.read()
353
- # (possibly) use properties to calc the id directly
354
+ self._executable = ""
355
+ # use property setters to (possibly) calc the id directly and/or preprocess values
354
356
  self.sbatch_script = sbatch_script
355
357
  self.executable = executable
356
358
  if call_kwargs is None:
@@ -363,58 +365,95 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
363
365
  """
364
366
  The jobname of the slurm job used to compute the function results.
365
367
 
366
- Must be unique for each :class:`SlurmTrajectoryFunctionWrapper`
368
+ Also used as part of the filename for the submission script that will
369
+ be written (and deleted if everything goes well) for every trajectory.
370
+
371
+ **NOTE:** Must be unique for each :class:`SlurmTrajectoryFunctionWrapper`
367
372
  instance. Will by default include the persistent unique ID :meth:`id`.
368
373
  To (re)set to the default set it to None.
369
374
  """
370
- if self._slurm_jobname is None:
375
+ if self._slurm_data.jobname is None:
371
376
  return f"CVfunc_id_{self.id}"
372
- return self._slurm_jobname
377
+ return self._slurm_data.jobname
373
378
 
374
379
  @slurm_jobname.setter
375
- def slurm_jobname(self, val: str | None):
376
- self._slurm_jobname = val
380
+ def slurm_jobname(self, val: str | None) -> None:
381
+ self._slurm_data.jobname = val
382
+
383
+ @property
384
+ def sbatch_script(self) -> str:
385
+ """
386
+ Content of the sbatch script (see the corresponding __init__ argument).
387
+
388
+ Can also be set with the path to a file, in this case the script will be read.
389
+ """
390
+ return self._slurm_data.sbatch_script
391
+
392
+ @sbatch_script.setter
393
+ def sbatch_script(self, value: str) -> None:
394
+ # we expect sbatch_script to be a str,
395
+ # but it could be either the path to a submit script or the content of
396
+ # the submission script directly
397
+ # we decide what it is by checking for the shebang
398
+ if not value.startswith("#!"):
399
+ # probably path to a file, lets try to read it
400
+ with open(value, 'r', encoding="locale") as f:
401
+ value = f.read()
402
+ self._slurm_data.sbatch_script = value
403
+
404
+ @property
405
+ def sbatch_options(self) -> dict[str, str] | None:
406
+ """
407
+ Dictionary of sbatch_options or None (see the corresponding __init__ argument).
408
+
409
+ **NOTE:** You can only (re)set the complete dict and not single keys!
410
+ """
411
+ return self._slurm_data.sbatch_options
412
+
413
+ @sbatch_options.setter
414
+ def sbatch_options(self, value: dict[str, str] | None) -> None:
415
+ self._slurm_data.sbatch_options = value
377
416
 
378
417
  def __repr__(self) -> str:
379
418
  return (f"SlurmTrajectoryFunctionWrapper(executable={self._executable}, "
380
419
  + f"call_kwargs={self.call_kwargs})"
381
420
  )
382
421
 
383
- def _get_id_str(self):
422
+ def _get_id_str(self) -> str:
384
423
  # calculate a hash over executable and call_kwargs dict
385
424
  # this should be unique and portable, i.e. it should enable us to make
386
425
  # ensure that the cached values will only be used for the same function
387
426
  # called with the same arguments
388
- id = 0
427
+ _id = 0
389
428
  # NOTE: addition is commutative, i.e. order does not matter here!
390
429
  for k, v in self._call_kwargs.items():
391
430
  # hash the value
392
- id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
431
+ _id += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
393
432
  # hash the key
394
- id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
433
+ _id += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
395
434
  # and add the executable hash
396
435
  with open(self.executable, "rb") as exe_file:
397
436
  # NOTE: we assume that executable is small enough to read at once
398
437
  # if this crashes becasue of OOM we should use chunks...
399
438
  data = exe_file.read()
400
- id += int(hashlib.blake2b(data).hexdigest(), 16)
401
- return str(id) # return a str because we want to use it as dict keys
439
+ _id += int(hashlib.blake2b(data).hexdigest(), 16)
440
+ return str(_id) # return a str because we want to use it as dict keys
402
441
 
403
442
  @property
404
- def executable(self):
443
+ def executable(self) -> str:
405
444
  """The executable used to compute the function results."""
406
445
  return self._executable
407
446
 
408
447
  @executable.setter
409
- def executable(self, val):
448
+ def executable(self, val: str) -> None:
410
449
  exe = ensure_executable_available(val)
411
450
  # if we get here it should be save to set, i.e. it exists + has X-bit
412
451
  self._executable = exe
413
452
  self._id = self._get_id_str() # get the new hash/id
414
453
 
415
- async def get_values_for_trajectory(self, traj):
454
+ async def _get_values_for_trajectory(self, traj: Trajectory) -> np.ndarray:
416
455
  """
417
- Apply wrapped function asyncronously on given trajectory.
456
+ Apply wrapped function asynchronously on given trajectory.
418
457
 
419
458
  Parameters
420
459
  ----------
@@ -423,7 +462,7 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
423
462
 
424
463
  Returns
425
464
  -------
426
- iterable, usually list or np.ndarray
465
+ np.ndarray
427
466
  The values of the wrapped function when applied on the trajectory.
428
467
  """
429
468
  # first construct the path/name for the numpy npy file in which we expect
@@ -438,22 +477,8 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
438
477
  result_file = os.path.abspath(os.path.join(
439
478
  tra_dir, f"{tra_name}_{hash_part}_CVfunc_id_{self.id}"
440
479
  ))
441
- # we expect executable to take 3 postional args:
442
- # struct traj outfile
443
- cmd_str = f"{self.executable} {os.path.abspath(traj.structure_file)}"
444
- cmd_str += f" {' '.join(os.path.abspath(t) for t in traj.trajectory_files)}"
445
- cmd_str += f" {result_file}"
446
- if len(self.call_kwargs) > 0:
447
- for key, val in self.call_kwargs.items():
448
- # shell escape only the values,
449
- # the keys (i.e. option names/flags) should be no issue
450
- if isinstance(val, list):
451
- # enable lists of arguments for the same key,
452
- # can then be used e.g. with pythons argparse `nargs="*"` or `nargs="+"`
453
- cmd_str += f" {key} {' '.join([shlex.quote(str(v)) for v in val])}"
454
- else:
455
- cmd_str += f" {key} {shlex.quote(str(val))}"
456
480
  # now prepare the sbatch script
481
+ cmd_str = self._build_cmd_str(traj=traj, result_file=result_file)
457
482
  script = self.sbatch_script.format(cmd_str=cmd_str)
458
483
  # write it out
459
484
  sbatch_fname = os.path.join(tra_dir,
@@ -464,7 +489,7 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
464
489
  sbatch_fname,
465
490
  )
466
491
  async with _SEMAPHORES["MAX_FILES_OPEN"]:
467
- async with aiofiles.open(sbatch_fname, 'w') as f:
492
+ async with aiofiles.open(sbatch_fname, 'w', encoding="locale") as f:
468
493
  await f.write(script)
469
494
  # NOTE: we set returncode to 2 (what slurmprocess returns in case of
470
495
  # node failure) and rerun/retry until we either get a completed job
@@ -484,7 +509,8 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
484
509
  self.executable, traj, slurm_proc.slurm_jobid,
485
510
  stdout.decode(), stderr.decode(),
486
511
  )
487
- if returncode != 0:
512
+ if returncode:
513
+ # Non-zero return code,
488
514
  # Can not be exitcode 2, because of the while loop above
489
515
  raise RuntimeError(
490
516
  "Non-zero exit code from CV batch job for "
@@ -496,6 +522,29 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
496
522
  + f" and stdout was: {stdout.decode()}"
497
523
  )
498
524
  # zero-exitcode: load the results
525
+ return await self._load_results_remove_files(result_file=result_file,
526
+ sbatch_fname=sbatch_fname)
527
+
528
+ def _build_cmd_str(self, traj: Trajectory, result_file: str) -> str:
529
+ # we expect executable to take 3 positional args:
530
+ # struct traj outfile
531
+ cmd_str = f"{self.executable} {os.path.abspath(traj.structure_file)}"
532
+ cmd_str += f" {' '.join(os.path.abspath(t) for t in traj.trajectory_files)}"
533
+ cmd_str += f" {result_file}"
534
+ if len(self.call_kwargs) > 0:
535
+ for key, val in self.call_kwargs.items():
536
+ # shell escape only the values,
537
+ # the keys (i.e. option names/flags) should be no issue
538
+ if isinstance(val, list):
539
+ # enable lists of arguments for the same key,
540
+ # can then be used e.g. with pythons argparse `nargs="*"` or `nargs="+"`
541
+ cmd_str += f" {key} {' '.join([shlex.quote(str(v)) for v in val])}"
542
+ else:
543
+ cmd_str += f" {key} {shlex.quote(str(val))}"
544
+ return cmd_str
545
+
546
+ async def _load_results_remove_files(self, result_file: str,
547
+ sbatch_fname: str) -> np.ndarray:
499
548
  if self.load_results_func is None:
500
549
  # we do not have '.npy' ending in results_file,
501
550
  # numpy.save() adds it if it is not there, so we need it here
@@ -523,25 +572,26 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
523
572
 
524
573
  async def _run_slurm_job(self, sbatch_fname: str, result_file: str,
525
574
  slurm_workdir: str,
526
- ) -> tuple[int,slurm.SlurmProcess,bytes,bytes]:
575
+ ) -> tuple[int, slurm.SlurmProcess, bytes, bytes]:
527
576
  # submit and run slurm-job
528
- if _SEMAPHORES["SLURM_MAX_JOB"] is not None:
529
- await _SEMAPHORES["SLURM_MAX_JOB"].acquire()
530
- try: # this try is just to make sure we always release the semaphore
531
- slurm_proc = await slurm.create_slurmprocess_submit(
532
- jobname=self.slurm_jobname,
533
- sbatch_script=sbatch_fname,
534
- workdir=slurm_workdir,
535
- stdfiles_removal="success",
536
- stdin=None,
537
- # sleep 5 s between checking
538
- sleep_time=5,
539
- )
577
+ if _OPT_SEMAPHORES["SLURM_MAX_JOB"] is not None:
578
+ await _OPT_SEMAPHORES["SLURM_MAX_JOB"].acquire()
579
+ slurm_proc = await slurm.create_slurmprocess_submit(
580
+ jobname=self.slurm_jobname,
581
+ sbatch_script=sbatch_fname,
582
+ workdir=slurm_workdir,
583
+ sbatch_options=self.sbatch_options,
584
+ stdfiles_removal="success",
585
+ stdin=None,
586
+ # sleep 5 s between checking
587
+ sleep_time=5,
588
+ )
589
+ # this try is just to make sure we always release the semaphore
590
+ # and to possibly clean up any left-over files from the fail
591
+ try:
540
592
  # wait for the slurm job to finish
541
593
  # also cancel the job when this future is canceled
542
594
  stdout, stderr = await slurm_proc.communicate()
543
- returncode = slurm_proc.returncode
544
- return returncode, slurm_proc, stdout, stderr
545
595
  except asyncio.CancelledError:
546
596
  slurm_proc.kill()
547
597
  # clean up the sbatch file and potentialy written result file
@@ -551,6 +601,8 @@ class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
551
601
  remove_file_if_exist_async(res_fname),
552
602
  )
553
603
  raise # reraise CancelledError for encompassing coroutines
604
+ else:
605
+ return slurm_proc.returncode, slurm_proc, stdout, stderr
554
606
  finally:
555
- if _SEMAPHORES["SLURM_MAX_JOB"] is not None:
556
- _SEMAPHORES["SLURM_MAX_JOB"].release()
607
+ if _OPT_SEMAPHORES["SLURM_MAX_JOB"] is not None:
608
+ _OPT_SEMAPHORES["SLURM_MAX_JOB"].release()