imdclient-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imdclient/IMDClient.py +896 -0
- imdclient/IMDProtocol.py +164 -0
- imdclient/IMDREADER.py +129 -0
- imdclient/__init__.py +14 -0
- imdclient/backends.py +352 -0
- imdclient/data/__init__.py +0 -0
- imdclient/data/gromacs/md/gromacs_struct.gro +21151 -0
- imdclient/data/gromacs/md/gromacs_v3.top +11764 -0
- imdclient/data/gromacs/md/gromacs_v3_nst1.mdp +58 -0
- imdclient/data/gromacs/md/gromacs_v3_nst1.tpr +0 -0
- imdclient/data/gromacs/md/gromacs_v3_nst1.trr +0 -0
- imdclient/data/lammps/md/lammps_topol.data +8022 -0
- imdclient/data/lammps/md/lammps_trj.h5md +0 -0
- imdclient/data/lammps/md/lammps_v3.in +71 -0
- imdclient/data/namd/md/alanin.dcd +0 -0
- imdclient/data/namd/md/alanin.params +402 -0
- imdclient/data/namd/md/alanin.pdb +77 -0
- imdclient/data/namd/md/alanin.psf +206 -0
- imdclient/data/namd/md/namd_v3.namd +47 -0
- imdclient/results.py +332 -0
- imdclient/streamanalysis.py +1056 -0
- imdclient/streambase.py +199 -0
- imdclient/tests/__init__.py +0 -0
- imdclient/tests/base.py +122 -0
- imdclient/tests/conftest.py +38 -0
- imdclient/tests/datafiles.py +34 -0
- imdclient/tests/server.py +212 -0
- imdclient/tests/test_gromacs.py +33 -0
- imdclient/tests/test_imdclient.py +150 -0
- imdclient/tests/test_imdreader.py +644 -0
- imdclient/tests/test_lammps.py +38 -0
- imdclient/tests/test_manual.py +70 -0
- imdclient/tests/test_namd.py +38 -0
- imdclient/tests/test_stream_analysis.py +61 -0
- imdclient/tests/utils.py +41 -0
- imdclient/utils.py +118 -0
- imdclient-0.1.0.dist-info/AUTHORS.md +23 -0
- imdclient-0.1.0.dist-info/LICENSE +21 -0
- imdclient-0.1.0.dist-info/METADATA +143 -0
- imdclient-0.1.0.dist-info/RECORD +42 -0
- imdclient-0.1.0.dist-info/WHEEL +5 -0
- imdclient-0.1.0.dist-info/top_level.txt +1 -0
imdclient/IMDProtocol.py
ADDED
@@ -0,0 +1,164 @@
import struct
import logging
from enum import Enum, auto
from typing import Union
from dataclasses import dataclass
import numpy as np


IMDHEADERSIZE = 8
IMDENERGYPACKETLENGTH = 40
IMDBOXPACKETLENGTH = 36
IMDTIMEPACKETLENGTH = 24
IMDVERSIONS = {2, 3}
IMDAWAITGOTIME = 1

logger = logging.getLogger("imdclient.IMDClient")


class IMDHeaderType(Enum):
    IMD_DISCONNECT = 0
    IMD_ENERGIES = 1
    IMD_FCOORDS = 2
    IMD_GO = 3
    IMD_HANDSHAKE = 4
    IMD_KILL = 5
    IMD_MDCOMM = 6
    IMD_PAUSE = 7
    IMD_TRATE = 8
    IMD_IOERROR = 9
    # New in IMD v3
    IMD_SESSIONINFO = 10
    IMD_RESUME = 11
    IMD_TIME = 12
    IMD_BOX = 13
    IMD_VELOCITIES = 14
    IMD_FORCES = 15


def parse_energy_bytes(data, endianness):
    keys = [
        "step",
        "temperature",
        "total_energy",
        "potential_energy",
        "van_der_walls_energy",
        "coulomb_energy",
        "bonds_energy",
        "angles_energy",
        "dihedrals_energy",
        "improper_dihedrals_energy",
    ]
    values = struct.unpack(f"{endianness}ifffffffff", data)
    return dict(zip(keys, values))


def create_energy_bytes(
    step,
    temperature,
    total_energy,
    potential_energy,
    van_der_walls_energy,
    coulomb_energy,
    bonds_energy,
    angles_energy,
    dihedrals_energy,
    improper_dihedrals_energy,
    endianness,
):
    return struct.pack(
        f"{endianness}ifffffffff",
        step,
        temperature,
        total_energy,
        potential_energy,
        van_der_walls_energy,
        coulomb_energy,
        bonds_energy,
        angles_energy,
        dihedrals_energy,
        improper_dihedrals_energy,
    )


def parse_box_bytes(data, endianness):
    """Box is a 3x3 matrix of floats"""

    vals = struct.unpack(f"{endianness}fffffffff", data)
    return np.array(vals).reshape(3, 3)


class IMDHeader:
    """Convenience class to represent the header of an IMD packet"""

    def __init__(self, data):
        msg_type, length = struct.unpack("!ii", data)
        h_type = IMDHeaderType(msg_type)

        self.type = h_type
        self.length = length


class IMDTime:
    """Convenience class to represent the body of time packet"""

    def __init__(self, data, endianness):
        self.dt, self.time, self.step = struct.unpack(f"{endianness}ddq", data)


@dataclass
class IMDSessionInfo:
    """Convenience class to represent the session information of an IMD connection

    '<' represents little endian and '>' represents big endian
    """

    version: int
    endianness: str
    # In IMDv2, we don't know if the server sends wrapped coordinates until
    # we receive the first packet.
    wrapped_coords: bool
    # In IMDv2, we don't know if the server sends energies until
    # we receive the first packet.
    energies: bool
    time: bool
    box: bool
    positions: bool
    velocities: bool
    forces: bool


def parse_imdv3_session_info(data, end):
    """Parses the session information packet of an IMD v3 connection"""
    logger.debug(f"parse_imdv3_session_info: {data}")
    time, energies, box, positions, wrapped_coords, velocties, forces = (
        struct.unpack(f"{end}BBBBBBB", data)
    )
    logger.debug(f"parse_imdv3_session_info2 : {data}")
    imdsinfo = IMDSessionInfo(
        version=3,
        endianness=end,
        time=(time != 0),
        box=(box != 0),
        positions=(positions != 0),
        wrapped_coords=(wrapped_coords != 0),
        velocities=(velocties != 0),
        forces=(forces != 0),
        energies=(energies != 0),
    )
    logger.debug(f"parse_imdv3_session_info3: {data}")
    return imdsinfo


def create_header_bytes(msg_type: IMDHeaderType, length: int):
    # NOTE: add error checking for invalid packet msg_type here

    type = msg_type.value
    return struct.pack("!ii", type, length)


def parse_header_bytes(data):
    msg_type, length = struct.unpack("!ii", data)
    type = IMDHeaderType(msg_type)
    # NOTE: add error checking for invalid packet msg_type here
    return IMDHeader(type, length)
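
For orientation, the following is a minimal sketch (not part of the package) of how these helpers round-trip one packet: the sender builds an IMD_ENERGIES header with `create_header_bytes` and a 40-byte energy block with `create_energy_bytes`; the receiver reads them back with `IMDHeader` and `parse_energy_bytes`. The numeric values and the little-endian choice are purely illustrative.

```python
from imdclient.IMDProtocol import (
    IMDHeaderType,
    IMDHeader,
    IMDENERGYPACKETLENGTH,
    create_header_bytes,
    create_energy_bytes,
    parse_energy_bytes,
)

endianness = "<"  # little-endian; in practice this comes from IMDSessionInfo

# Sender side: an IMD_ENERGIES header followed by the 40-byte energy block.
header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, IMDENERGYPACKETLENGTH)
body = create_energy_bytes(
    100, 300.0, -1000.0, -1200.0, 50.0, -300.0, 20.0, 15.0, 10.0, 5.0, endianness
)

# Receiver side: the 8-byte header is always in network (big-endian) order,
# the body uses the session's endianness.
parsed_header = IMDHeader(header)
assert parsed_header.type is IMDHeaderType.IMD_ENERGIES
assert parsed_header.length == IMDENERGYPACKETLENGTH

energies = parse_energy_bytes(body, endianness)
print(energies["step"], energies["temperature"])  # 100 300.0
```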
imdclient/IMDREADER.py
ADDED
@@ -0,0 +1,129 @@
"""
MDAnalysis IMDReader
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: IMDReader
   :members:
   :inherited-members:

"""

from MDAnalysis.coordinates import core
from MDAnalysis.lib.util import store_init_arguments

# NOTE: changeme
from .IMDClient import IMDClient
from .utils import *
import logging

from .streambase import StreamReaderBase

logger = logging.getLogger("imdclient.IMDClient")


class IMDReader(StreamReaderBase):
    """
    Reader for IMD protocol packets.
    """

    format = "IMD"
    one_pass = True

    @store_init_arguments
    def __init__(
        self,
        filename,
        convert_units=True,
        n_atoms=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        filename : a string of the form "host:port" where host is the hostname
            or IP address of the listening GROMACS server and port
            is the port number.
        n_atoms : int (optional)
            number of atoms in the system. defaults to number of atoms
            in the topology. don't set this unless you know what you're doing.
        """

        super(IMDReader, self).__init__(filename, **kwargs)

        logger.debug("IMDReader initializing")

        if n_atoms is None:
            raise ValueError("IMDReader: n_atoms must be specified")
        self.n_atoms = n_atoms

        host, port = parse_host_port(filename)

        # This starts the simulation
        self._imdclient = IMDClient(host, port, n_atoms, **kwargs)

        imdsinfo = self._imdclient.get_imdsessioninfo()
        # NOTE: after testing phase, fail out on IMDv2

        self.ts = self._Timestep(
            self.n_atoms,
            positions=imdsinfo.positions,
            velocities=imdsinfo.velocities,
            forces=imdsinfo.forces,
            **self._ts_kwargs,
        )

        self._frame = -1

        try:
            self._read_next_timestep()
        except StopIteration:
            raise RuntimeError("IMDReader: No data found in stream")

    def _read_frame(self, frame):

        try:
            imdf = self._imdclient.get_imdframe()
        except EOFError:
            # Not strictly necessary, but for clarity
            raise StopIteration

        self._frame = frame
        self._load_imdframe_into_ts(imdf)

        logger.debug(f"IMDReader: Loaded frame {self._frame}")
        return self.ts

    def _load_imdframe_into_ts(self, imdf):
        self.ts.frame = self._frame
        if imdf.time is not None:
            self.ts.time = imdf.time
            # NOTE: timestep.pyx "dt" method is suspicious bc it uses "new" keyword for a float
            self.ts.data["dt"] = imdf.dt
            self.ts.data["step"] = imdf.step
        if imdf.energies is not None:
            self.ts.data.update(imdf.energies)
        if imdf.box is not None:
            self.ts.dimensions = core.triclinic_box(*imdf.box)
        if imdf.positions is not None:
            # must call copy because reference is expected to reset
            # see 'test_frame_collect_all_same' in MDAnalysisTests.coordinates.base
            self.ts.positions = imdf.positions
        if imdf.velocities is not None:
            self.ts.velocities = imdf.velocities
        if imdf.forces is not None:
            self.ts.forces = imdf.forces

    @staticmethod
    def _format_hint(thing):
        try:
            parse_host_port(thing)
        except:
            return False
        return True

    def close(self):
        """Gracefully shut down the reader. Stops the producer thread."""
        logger.debug("IMDReader close() called")
        self._imdclient.stop()
        # NOTE: removeme after testing
        logger.debug("IMDReader shut down gracefully.")
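
A minimal usage sketch for this reader, under the assumptions that MDAnalysis is installed, that a simulation engine with IMD enabled is already listening, and that `topol.tpr` and `localhost:8888` are placeholders rather than names taken from the package. MDAnalysis normally forwards the topology's atom count to the reader; if it does not in your setup, pass `n_atoms=...` explicitly, since the reader raises `ValueError` without it.

```python
import MDAnalysis as mda
import imdclient  # noqa: F401  # package import; the reader lives in imdclient.IMDREADER
from imdclient.IMDREADER import IMDReader

# Placeholder topology file and address; the engine must already be waiting
# for an IMD client on this port before the Universe is created.
u = mda.Universe("topol.tpr", "localhost:8888", format=IMDReader)

# One-pass stream: each iteration pulls the next frame from the IMDClient.
for ts in u.trajectory:
    print(ts.frame, ts.time, u.atoms.positions.mean(axis=0))
```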
imdclient/__init__.py
ADDED
@@ -0,0 +1,14 @@
"""
IMDClient
"""

# Don't import IMDReader here, eventually it may be moved to a separate package
from .IMDClient import IMDClient
from importlib.metadata import version

from .streamanalysis import AnalysisBase, StackableAnalysis
from MDAnalysis.analysis import base

base.AnalysisBase = AnalysisBase

__version__ = version("imdclient")
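
Importing the package has a deliberate side effect: it replaces `MDAnalysis.analysis.base.AnalysisBase` with the stream-aware class from `imdclient.streamanalysis`. A short sketch (illustrative only) of what that module-level assignment means in practice:

```python
import imdclient
from MDAnalysis.analysis import base
from imdclient.streamanalysis import AnalysisBase

# After "import imdclient", MDAnalysis' analysis base class points at the
# stream-aware replacement, and the distribution version is exposed.
assert base.AnalysisBase is AnalysisBase
print(imdclient.__version__)  # "0.1.0" for this wheel
```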
imdclient/backends.py
ADDED
@@ -0,0 +1,352 @@
# Copy of backends from MDA 2.8.0
"""Analysis backends --- :mod:`MDAnalysis.analysis.backends`
============================================================

.. versionadded:: 2.8.0


The :mod:`backends` module provides :class:`BackendBase` base class to
implement custom execution backends for
:meth:`MDAnalysis.analysis.base.AnalysisBase.run` and its
subclasses.

.. SeeAlso:: :ref:`parallel-analysis`

.. _backends:

Backends
--------

Three built-in backend classes are provided:

* *serial*: :class:`BackendSerial`, that is equivalent to using no
  parallelization and is the default

* *multiprocessing*: :class:`BackendMultiprocessing` that supports
  parallelization via standard Python :mod:`multiprocessing` module
  and uses default :mod:`pickle` serialization

* *dask*: :class:`BackendDask`, that uses the same process-based
  parallelization as :class:`BackendMultiprocessing`, but different
  serialization algorithm via `dask <https://dask.org/>`_ (see `dask
  serialization algorithms
  <https://distributed.dask.org/en/latest/serialization.html>`_ for details)

Classes
-------

"""
import warnings
from typing import Callable
import importlib.util


def is_installed(modulename: str):
    """Checks if module is installed

    Parameters
    ----------
    modulename : str
        name of the module to be tested


    .. versionadded:: 2.8.0
    """
    return importlib.util.find_spec(modulename) is not None


class BackendBase:
    """Base class for backend implementation.

    Initializes an instance and performs checks for its validity, such as
    ``n_workers`` and possibly other ones.

    Parameters
    ----------
    n_workers : int
        number of workers (usually, processes) over which the work is split

    Examples
    --------
    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendBase

        class ThreadsBackend(BackendBase):
            def apply(self, func, computations):
                from multiprocessing.dummy import Pool

                with Pool(processes=self.n_workers) as pool:
                    results = pool.map(func, computations)
                return results

        import MDAnalysis as mda
        from MDAnalysis.tests.datafiles import PSF, DCD
        from MDAnalysis.analysis.rms import RMSD

        u = mda.Universe(PSF, DCD)
        ref = mda.Universe(PSF, DCD)

        R = RMSD(u, ref)

        n_workers = 2
        backend = ThreadsBackend(n_workers=n_workers)
        R.run(backend=backend, unsupported_backend=True)

    .. warning::
        Using `ThreadsBackend` above will lead to erroneous results, since it
        is an educational example. Do not use it for real analysis.


    .. versionadded:: 2.8.0

    """

    def __init__(self, n_workers: int):
        self.n_workers = n_workers
        self._validate()

    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure the
        validity of the backend instance

        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return {
            isinstance(self.n_workers, int)
            and self.n_workers
            > 0: f"n_workers should be positive integer, got {self.n_workers=}",
        }

    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance

        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return dict()

    def _validate(self):
        """Check correctness (e.g. ``dask`` is installed if using ``backend='dask'``)
        and good usage (e.g. ``n_workers=1`` if backend is serial) of the backend

        Raises
        ------
        ValueError
            if one of the conditions in :meth:`_get_checks` is ``True``
        """
        for check, msg in self._get_checks().items():
            if not check:
                raise ValueError(msg)
        for check, msg in self._get_warnings().items():
            if not check:
                warnings.warn(msg)

    def apply(self, func: Callable, computations: list) -> list:
        """map function `func` to all tasks in the `computations` list

        Main method that will get called when using an instance of
        ``BackendBase``. It is equivalent to running ``[func(item) for item in
        computations]`` while using the parallel backend capabilities.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function

        """
        raise NotImplementedError


class BackendSerial(BackendBase):
    """A built-in backend that does serial execution of the function, without any
    parallelization.

    Parameters
    ----------
    n_workers : int
        Is ignored in this class, and if ``n_workers`` > 1, a warning will be
        given.


    .. versionadded:: 2.8.0
    """

    def _get_warnings(self):
        """Get dictionary with ``condition: warning_message`` pairs that ensure
        the good usage of the backend instance. Here, it checks if the number
        of workers is not 1, otherwise gives warning.

        Returns
        -------
        dict
            dictionary with ``condition: warning_message`` pairs that will get
            checked during ``_validate()`` run
        """
        return {
            self.n_workers
            == 1: "n_workers is ignored when executing with backend='serial'"
        }

    def apply(self, func: Callable, computations: list) -> list:
        """
        Serially applies `func` to each task object in ``computations``.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        return [func(task) for task in computations]


class BackendMultiprocessing(BackendBase):
    """A built-in backend that executes a given function using the
    :meth:`multiprocessing.Pool.map <multiprocessing.pool.Pool.map>` method.

    Parameters
    ----------
    n_workers : int
        number of processes in :class:`multiprocessing.Pool
        <multiprocessing.pool.Pool>` to distribute the workload
        between. Must be a positive integer.

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendMultiprocessing
        import multiprocessing as mp

        backend_obj = BackendMultiprocessing(n_workers=mp.cpu_count())


    .. versionadded:: 2.8.0

    """

    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations`` using `multiprocessing`'s `Pool.map`.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        from multiprocessing import Pool

        with Pool(processes=self.n_workers) as pool:
            results = pool.map(func, computations)
        return results


class BackendDask(BackendBase):
    """A built-in backend that executes a given function with *dask*.

    Execution is performed with the :func:`dask.compute` function of
    :class:`dask.delayed.Delayed` object (created with
    :func:`dask.delayed.delayed`) with ``scheduler='processes'`` and
    ``chunksize=1`` (this ensures uniform distribution of tasks among
    processes). Requires the `dask package <https://docs.dask.org/en/stable/>`_
    to be `installed <https://docs.dask.org/en/stable/install.html>`_.

    Parameters
    ----------
    n_workers : int
        number of processes in to distribute the workload
        between. Must be a positive integer. Workers are actually
        :class:`multiprocessing.pool.Pool` processes, but they use a different and
        more flexible `serialization protocol
        <https://docs.dask.org/en/stable/phases-of-computation.html#graph-serialization>`_.

    Examples
    --------

    .. code-block:: python

        from MDAnalysis.analysis.backends import BackendDask
        import multiprocessing as mp

        backend_obj = BackendDask(n_workers=mp.cpu_count())


    .. versionadded:: 2.8.0

    """

    def apply(self, func: Callable, computations: list) -> list:
        """Applies `func` to each object in ``computations``.

        Parameters
        ----------
        func : Callable
            function to be called on each of the tasks in computations list
        computations : list
            computation tasks to apply function to

        Returns
        -------
        list
            list of results of the function
        """
        from dask.delayed import delayed
        import dask

        computations = [delayed(func)(task) for task in computations]
        results = dask.compute(
            computations,
            scheduler="processes",
            chunksize=1,
            num_workers=self.n_workers,
        )[0]
        return results

    def _get_checks(self):
        """Get dictionary with ``condition: error_message`` pairs that ensure the
        validity of the backend instance. Here checks if ``dask`` module is
        installed in the environment.

        Returns
        -------
        dict
            dictionary with ``condition: error_message`` pairs that will get
            checked during ``_validate()`` run
        """
        base_checks = super()._get_checks()
        checks = {
            is_installed("dask"): (
                "module 'dask' is missing. Please install 'dask': "
                "https://docs.dask.org/en/stable/install.html"
            )
        }
        return base_checks | checks
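
Because this file is a verbatim copy of the MDAnalysis 2.8.0 backends module, its `apply` contract can be exercised on its own. A minimal sketch with a throwaway `square` function (illustrative only; the `__main__` guard is needed for the process-based backend on spawn-based platforms):

```python
from imdclient.backends import BackendSerial, BackendMultiprocessing, is_installed


def square(x):
    return x * x


if __name__ == "__main__":
    tasks = [1, 2, 3, 4]

    # Serial backend: equivalent to [square(t) for t in tasks].
    print(BackendSerial(n_workers=1).apply(square, tasks))          # [1, 4, 9, 16]

    # Multiprocessing backend: same result, split over two worker processes.
    print(BackendMultiprocessing(n_workers=2).apply(square, tasks))  # [1, 4, 9, 16]

    # BackendDask has the same interface but only validates if dask is installed.
    print(is_installed("dask"))
```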