asyncmd 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,556 @@
1
+ # This file is part of asyncmd.
2
+ #
3
+ # asyncmd is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # asyncmd is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with asyncmd. If not, see <https://www.gnu.org/licenses/>.
15
+ import os
16
+ import abc
17
+ import shlex
18
+ import asyncio
19
+ import inspect
20
+ import logging
21
+ import hashlib
22
+ import functools
23
+ import typing
24
+ import aiofiles
25
+ import aiofiles.os
26
+ import numpy as np
27
+ from concurrent.futures import ThreadPoolExecutor
28
+
29
+
30
+ from .._config import _SEMAPHORES
31
+ from .. import slurm
32
+ from ..tools import ensure_executable_available, remove_file_if_exist_async
33
+ from .trajectory import Trajectory
34
+
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
# TODO: DaskTrajectoryFunctionWrapper?!
class TrajectoryFunctionWrapper(abc.ABC):
    """Abstract base class to define the API and some common methods."""

    def __init__(self, **kwargs) -> None:
        # NOTE: in principle we should set these after the attributes set via
        #       kwargs (otherwise users could overwrite them by passing
        #       _id="blub" to init), but since the subclasses set call_kwargs
        #       again and have to calculate the id according to their own
        #       recipe anyway, it is safe to set them here (this enables us
        #       to use the id property at initialization time as e.g. in the
        #       slurm_jobname of the SlurmTrajectoryFunctionWrapper)
        self._id = None
        # initialize to an empty dict such that iteration always works
        self._call_kwargs = {}
        # make it possible to set any previously defined attribute via kwargs,
        # checking the type of the new value against the attribute's default
        _not_set = object()
        for kwarg, value in kwargs.items():
            current = getattr(self, kwarg, _not_set)
            if current is _not_set:
                # not previously defined, so warn that we ignore it
                logger.warning("Ignoring unknown keyword-argument %s.", kwarg)
                continue
            if not isinstance(value, type(current)):
                raise TypeError(f"Setting attribute {kwarg} with "
                                + f"mismatching type ({type(value)}). "
                                + f" Default type is {type(current)}."
                                )
            # value is of same type as the default, so set it
            setattr(self, kwarg, value)

    @property
    def id(self) -> str:
        """
        Unique and persistent identifier.

        Takes into account the wrapped function and its calling arguments.
        """
        return self._id

    @property
    def call_kwargs(self) -> dict:
        """Additional calling arguments for the wrapped function/executable."""
        # hand out a copy to avoid people modifying entries without us noticing
        # TODO/FIXME: this will make unhappy users if they try to set single
        #             items in the dict!
        return self._call_kwargs.copy()

    @call_kwargs.setter
    def call_kwargs(self, value):
        if not isinstance(value, dict):
            raise TypeError("call_kwargs must be a dictionary.")
        self._call_kwargs = value
        self._id = self._get_id_str()  # get/set ID

    @abc.abstractmethod
    def _get_id_str(self) -> str:
        # expected to return a unique identifying string; it should be
        # portable, i.e. it must ensure that cached values will only be
        # reused for the same function called with the same arguments
        raise NotImplementedError

    @abc.abstractmethod
    async def get_values_for_trajectory(self, traj):
        # will be called by trajectory._apply_wrapped_func()
        # expected to return a numpy array, shape=(n_frames, n_dim_function)
        raise NotImplementedError

    async def __call__(self, value):
        """
        Apply wrapped function asyncronously on given trajectory.

        Parameters
        ----------
        value : asyncmd.Trajectory
            Input trajectory.

        Returns
        -------
        iterable, usually list or np.ndarray
            The values of the wrapped function when applied on the trajectory.
        """
        if not isinstance(value, Trajectory):
            raise TypeError(f"{type(self)} must be called with an "
                            "`asyncmd.Trajectory` but was called with "
                            f"{type(value)}.")
        if self.id is None:
            raise RuntimeError(f"{type(self)} must have a unique id, but "
                               "self.id is None.")
        return await value._apply_wrapped_func(self.id, self)
130
+
131
+
132
class PyTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
    """
    Wrap python functions for use on :class:`asyncmd.Trajectory`.

    Turns every python callable into an asyncronous (awaitable) and cached
    function for application on :class:`asyncmd.Trajectory`. Also works for
    asyncronous (awaitable) functions, they will be cached.

    Attributes
    ----------
    function : callable
        The wrapped callable.
    call_kwargs : dict
        Keyword arguments for wrapped function.
    """

    def __init__(self, function, call_kwargs: typing.Optional[dict] = None,
                 **kwargs):
        """
        Initialize a :class:`PyTrajectoryFunctionWrapper`.

        Parameters
        ----------
        function : callable
            The (synchronous) callable to wrap.
        call_kwargs : dict, optional
            Keyword arguments for `function`,
            the keys will be used as keyword with the corresponding values,
            by default {}
        """
        super().__init__(**kwargs)
        self._func = None
        self._func_src = None
        self._func_is_async = None
        # setting via the properties directly calculates/gets the id
        self.function = function
        self.call_kwargs = {} if call_kwargs is None else call_kwargs

    def __repr__(self) -> str:
        return (f"PyTrajectoryFunctionWrapper(function={self._func}, "
                f"call_kwargs={self.call_kwargs})")

    def _get_id_str(self):
        # calculate a hash over function source and the call_kwargs dict;
        # this should be unique and portable, i.e. it should ensure that
        # cached values will only be used for the same function called with
        # the same arguments
        def _hashed(obj) -> int:
            return int(
                hashlib.blake2b(str(obj).encode("utf-8")).hexdigest(), 16
                )
        # NOTE: addition is commutative, i.e. order does not matter here!
        acc = sum(_hashed(key) + _hashed(val)
                  for key, val in self._call_kwargs.items())
        # and add in the function source
        acc += _hashed(self._func_src)
        # return a str because we want to use it as dict keys
        return str(acc)

    @property
    def function(self):
        """
        The python callable this :class:`PyTrajectoryFunctionWrapper` wrapps.
        """
        return self._func

    @function.setter
    def function(self, value):
        self._func_is_async = (inspect.iscoroutinefunction(value)
                               or inspect.iscoroutinefunction(value.__call__))
        try:
            source = inspect.getsource(value)
        except OSError:
            # OSError is raised if source can not be retrieved
            self._func_src = None
            self._id = None
            logger.warning("Could not retrieve source for %s."
                           " No caching can/will be performed.",
                           value,
                           )
        else:
            self._func_src = source
            self._id = self._get_id_str()  # get/set ID
        finally:
            self._func = value

    async def get_values_for_trajectory(self, traj):
        """
        Apply wrapped function asyncronously on given trajectory.

        Parameters
        ----------
        traj : asyncmd.Trajectory
            Input trajectory.

        Returns
        -------
        iterable, usually list or np.ndarray
            The values of the wrapped function when applied on the trajectory.
        """
        if self._func_is_async:
            return await self._get_values_for_trajectory_async(traj)
        return await self._get_values_for_trajectory_sync(traj)

    async def _get_values_for_trajectory_async(self, traj):
        return await self._func(traj, **self._call_kwargs)

    async def _get_values_for_trajectory_sync(self, traj):
        loop = asyncio.get_running_loop()
        async with _SEMAPHORES["MAX_PROCESS"]:
            # bind the additional kwargs (if any) before handing off
            func = (functools.partial(self.function, **self._call_kwargs)
                    if self._call_kwargs
                    else self.function)
            # NOTE: even though one would expect pythonCVs to be CPU bound
            #       it is actually faster to use a ThreadPoolExecutor because
            #       we then skip the setup + import needed for a second
            #       process. In addition most pythonCVs will actually call
            #       c/cython-code like MDAnalysis/mdtraj/etc and are
            #       therefore not limited by the GIL anyway.
            #       We leave the code for ProcessPool here because this is
            #       the only place where this could make sense to think about
            #       as opposed to concatenation of trajs (which is IO bound)
            # NOTE: make sure we do not fork! (not save with multithreading)
            # see e.g. https://stackoverflow.com/questions/46439740/safe-to-call-multiprocessing-from-a-thread-in-python
            #ctx = multiprocessing.get_context("forkserver")
            #with ProcessPoolExecutor(1, mp_context=ctx) as pool:
            with ThreadPoolExecutor(max_workers=1,
                                    thread_name_prefix="PyTrajFunc_thread",
                                    ) as pool:
                vals = await loop.run_in_executor(pool, func, traj)
            return vals
266
+
267
+
268
class SlurmTrajectoryFunctionWrapper(TrajectoryFunctionWrapper):
    """
    Wrap executables to use on :class:`asyncmd.Trajectory` via SLURM.

    The execution of the job is submited to the queueing system with the
    given sbatch script (template).
    The executable will be called with the following positional arguments:

    - full filepath of the structure file associated with the trajectory

    - full filepath of the trajectory to calculate values for, note that
      multipart trajectories result in multiple files/arguments here.

    - full filepath of the file the results should be written to without
      fileending. Note that if no custom loading function is supplied we
      expect that the written file has 'npy' format and the added ending
      '.npy', i.e. we expect the executable to add the ending '.npy' to
      the passed filepath (as e.g. ``np.save($FILEPATH, data)`` would do)

    - any additional arguments from call_kwargs are added as
      ``" {key} {value}" for key, value in call_kwargs.items()``

    See also the examples for a reference (python) implementation of multiple
    different functions/executables for use with this class.

    Attributes
    ----------
    slurm_jobname : str
        Used as name for the job in slurm and also as part of the filename for
        the submission script that will be written (and deleted if everything
        goes well) for every trajectory.
        NOTE: If you modify it, ensure that each SlurmTrajectoryWrapper has a
        unique slurm_jobname.
    executable : str
        Name of or path to the wrapped executable.
    call_kwargs : dict
        Keyword arguments for wrapped executable.
    """

    def __init__(self, executable, sbatch_script,
                 call_kwargs: typing.Optional[dict] = None,
                 load_results_func=None, **kwargs):
        """
        Initialize :class:`SlurmTrajectoryFunctionWrapper`.

        Note that all attributes can be set via __init__ by passing them as
        keyword arguments.

        Parameters
        ----------
        executable : str
            Absolute or relative path to an executable or name of an executable
            available via the environment (e.g. via the $PATH variable on LINUX)
        sbatch_script : str
            Path to a sbatch submission script file or string with the content
            of a submission script. Note that the submission script must
            contain the following placeholders (also see the examples folder):

            - {cmd_str} : Replaced by the command to call the executable on a given trajectory.

        call_kwargs : dict, optional
            Dictionary of additional arguments to pass to the executable, they
            will be added to the call as pair ' {key} {val}', note that in case
            you want to pass single command line flags (like '-v') this can be
            achieved by setting key='-v' and val='', i.e. to the empty string.
            The values are shell escaped using `shlex.quote()` when writing
            them to the sbatch script.
        load_results_func : None or function (callable)
            Function to call to customize the loading of the results.
            If a function is supplied, it will be called with the full path to
            the results file (as in the call to the executable) and should
            return a numpy array containing the loaded values.
        """
        # property defaults before superclass init to be resettable via kwargs
        self._slurm_jobname = None
        super().__init__(**kwargs)
        self._executable = None
        # we expect sbatch_script to be a str,
        # but it could be either the path to a submit script or the content of
        # the submission script directly
        # we decide what it is by checking for the shebang
        if not sbatch_script.startswith("#!"):
            # probably path to a file, lets try to read it
            with open(sbatch_script, 'r') as f:
                sbatch_script = f.read()
        # (possibly) use properties to calc the id directly
        self.sbatch_script = sbatch_script
        self.executable = executable
        if call_kwargs is None:
            call_kwargs = {}
        self.call_kwargs = call_kwargs
        self.load_results_func = load_results_func

    @property
    def slurm_jobname(self) -> str:
        """
        The jobname of the slurm job used to compute the function results.

        Must be unique for each :class:`SlurmTrajectoryFunctionWrapper`
        instance. Will by default include the persistent unique ID :meth:`id`.
        To (re)set to the default set it to None.
        """
        if self._slurm_jobname is None:
            return f"CVfunc_id_{self.id}"
        return self._slurm_jobname

    @slurm_jobname.setter
    def slurm_jobname(self, val: typing.Optional[str]):
        # typing.Optional for consistency with the rest of this module
        self._slurm_jobname = val

    def __repr__(self) -> str:
        return (f"SlurmTrajectoryFunctionWrapper(executable={self._executable}, "
                + f"call_kwargs={self.call_kwargs})"
                )

    def _get_id_str(self):
        # calculate a hash over executable (content) and call_kwargs dict
        # this should be unique and portable, i.e. it should ensure that
        # cached values will only be used for the same executable called
        # with the same arguments
        # (renamed from `id` to avoid shadowing the builtin)
        id_hash = 0
        # NOTE: addition is commutative, i.e. order does not matter here!
        for k, v in self._call_kwargs.items():
            # hash the value
            id_hash += int(hashlib.blake2b(str(v).encode('utf-8')).hexdigest(), 16)
            # hash the key
            id_hash += int(hashlib.blake2b(str(k).encode('utf-8')).hexdigest(), 16)
        # and add the executable hash
        with open(self.executable, "rb") as exe_file:
            # NOTE: we assume that executable is small enough to read at once
            #       if this crashes because of OOM we should use chunks...
            data = exe_file.read()
        id_hash += int(hashlib.blake2b(data).hexdigest(), 16)
        # return a str because we want to use it as dict keys
        return str(id_hash)

    @property
    def executable(self):
        """The executable used to compute the function results."""
        return self._executable

    @executable.setter
    def executable(self, val):
        exe = ensure_executable_available(val)
        # if we get here it should be save to set, i.e. it exists + has X-bit
        self._executable = exe
        self._id = self._get_id_str()  # get the new hash/id

    def _build_cmd_str(self, traj, result_file: str) -> str:
        # Construct the command line call for the wrapped executable:
        # three positional args (struct, traj(s), outfile) followed by
        # any additional "{key} {value}" pairs from call_kwargs.
        cmd_str = f"{self.executable} {os.path.abspath(traj.structure_file)}"
        cmd_str += f" {' '.join(os.path.abspath(t) for t in traj.trajectory_files)}"
        cmd_str += f" {result_file}"
        for key, val in self.call_kwargs.items():
            # shell escape only the values,
            # the keys (i.e. option names/flags) should be no issue
            if isinstance(val, list):
                # enable lists of arguments for the same key,
                # can then be used e.g. with pythons argparse `nargs="*"` or `nargs="+"`
                cmd_str += f" {key} {' '.join([shlex.quote(str(v)) for v in val])}"
            else:
                cmd_str += f" {key} {shlex.quote(str(val))}"
        return cmd_str

    async def get_values_for_trajectory(self, traj):
        """
        Apply wrapped function asyncronously on given trajectory.

        Parameters
        ----------
        traj : asyncmd.Trajectory
            Input trajectory.

        Returns
        -------
        iterable, usually list or np.ndarray
            The values of the wrapped function when applied on the trajectory.

        Raises
        ------
        RuntimeError
            If the submitted job exits with a non-zero exit code that does
            not indicate a node failure (node failures are retried).
        """
        # first construct the path/name for the numpy npy file in which we
        # expect the results to be written
        tra_dir, tra_name = os.path.split(traj.trajectory_files[0])
        if len(traj.trajectory_files) > 1:
            tra_name += f"_len{len(traj.trajectory_files)}multipart"
        hash_part = str(traj.trajectory_hash)[:5]
        # put in the hash (calculated over all traj parts for multipart)
        # to make sure trajectories with the same first part but different
        # remaining parts dont get mixed up
        result_file = os.path.abspath(os.path.join(
                        tra_dir, f"{tra_name}_{hash_part}_CVfunc_id_{self.id}"
                                                   ))
        cmd_str = self._build_cmd_str(traj=traj, result_file=result_file)
        # now prepare the sbatch script
        script = self.sbatch_script.format(cmd_str=cmd_str)
        # write it out
        sbatch_fname = os.path.join(tra_dir,
                                    tra_name + "_" + self.slurm_jobname + ".slurm")
        if os.path.exists(sbatch_fname):
            logger.error("Overwriting existing submission file (%s)."
                         " Are you sure your `slurm_jobname` is unique?",
                         sbatch_fname,
                         )
        async with _SEMAPHORES["MAX_FILES_OPEN"]:
            async with aiofiles.open(sbatch_fname, 'w') as f:
                await f.write(script)
        # NOTE: we set returncode to 2 (what slurmprocess returns in case of
        #       node failure) and rerun/retry until we either get a completed
        #       job or a non-node-failure error
        returncode = 2
        while returncode == 2:
            # run slurm job
            returncode, slurm_proc, stdout, stderr = await self._run_slurm_job(
                                                sbatch_fname=sbatch_fname,
                                                result_file=result_file,
                                                slurm_workdir=tra_dir,
                                                                               )
            if returncode == 2:
                logger.error("Exit code indicating node fail from CV batch job"
                             " for executable %s on trajectory %s (slurm jobid"
                             " %s). stderr was: %s. stdout was: %s",
                             self.executable, traj, slurm_proc.slurm_jobid,
                             stdout.decode(), stderr.decode(),
                             )
        if returncode != 0:
            # Can not be exitcode 2, because of the while loop above
            raise RuntimeError(
                        "Non-zero exit code from CV batch job for "
                        + f"executable {self.executable} on "
                        + f"trajectory {traj} "
                        + f"(slurm jobid {slurm_proc.slurm_jobid})."
                        + f" Exit code was: {returncode}."
                        + f" stderr was: {stderr.decode()}."
                        + f" and stdout was: {stdout.decode()}"
                               )
        # zero-exitcode: load the results
        if self.load_results_func is None:
            # we do not have '.npy' ending in results_file,
            # numpy.save() adds it if it is not there, so we need it here
            load_func = np.load
            fname_results = result_file + ".npy"
        else:
            # use custom loading function from user
            load_func = self.load_results_func
            fname_results = result_file
        # use a separate thread to load so we dont block with the io
        loop = asyncio.get_running_loop()
        async with _SEMAPHORES["MAX_FILES_OPEN"]:
            async with _SEMAPHORES["MAX_PROCESS"]:
                with ThreadPoolExecutor(
                        max_workers=1,
                        thread_name_prefix="SlurmTrajFunc_load_thread",
                                        ) as pool:
                    vals = await loop.run_in_executor(pool, load_func,
                                                      fname_results)
        # remove the results file and sbatch script
        await asyncio.gather(remove_file_if_exist_async(fname_results),
                             remove_file_if_exist_async(sbatch_fname),
                             )
        return vals

    async def _run_slurm_job(self, sbatch_fname: str, result_file: str,
                             slurm_workdir: str,
                             ) -> tuple[int, slurm.SlurmProcess, bytes, bytes]:
        # Submit and run one slurm-job, returning (returncode, process,
        # stdout, stderr). Cleans up written files and kills the job if this
        # coroutine is cancelled while the job runs.
        if _SEMAPHORES["SLURM_MAX_JOB"] is not None:
            await _SEMAPHORES["SLURM_MAX_JOB"].acquire()
        # BUGFIX: initialize slurm_proc so the CancelledError handler below
        # does not raise NameError if cancellation happens during submission
        slurm_proc = None
        try:  # this try is just to make sure we always release the semaphore
            slurm_proc = await slurm.create_slurmprocess_submit(
                                            jobname=self.slurm_jobname,
                                            sbatch_script=sbatch_fname,
                                            workdir=slurm_workdir,
                                            stdfiles_removal="success",
                                            stdin=None,
                                            # sleep 5 s between checking
                                            sleep_time=5,
                                                                )
            # wait for the slurm job to finish
            # also cancel the job when this future is canceled
            stdout, stderr = await slurm_proc.communicate()
            returncode = slurm_proc.returncode
            return returncode, slurm_proc, stdout, stderr
        except asyncio.CancelledError:
            # only kill if the job was actually submitted already
            if slurm_proc is not None:
                slurm_proc.kill()
            # clean up the sbatch file and potentialy written result file
            res_fname = (result_file + ".npy" if self.load_results_func is None
                         else result_file)
            await asyncio.gather(remove_file_if_exist_async(sbatch_fname),
                                 remove_file_if_exist_async(res_fname),
                                 )
            raise  # reraise CancelledError for encompassing coroutines
        finally:
            if _SEMAPHORES["SLURM_MAX_JOB"] is not None:
                _SEMAPHORES["SLURM_MAX_JOB"].release()