westpa 2022.10 (cp312-cp312-macosx_11_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- westpa/__init__.py +14 -0
- westpa/_version.py +21 -0
- westpa/analysis/__init__.py +5 -0
- westpa/analysis/core.py +746 -0
- westpa/analysis/statistics.py +27 -0
- westpa/analysis/trajectories.py +360 -0
- westpa/cli/__init__.py +0 -0
- westpa/cli/core/__init__.py +0 -0
- westpa/cli/core/w_fork.py +152 -0
- westpa/cli/core/w_init.py +230 -0
- westpa/cli/core/w_run.py +77 -0
- westpa/cli/core/w_states.py +212 -0
- westpa/cli/core/w_succ.py +99 -0
- westpa/cli/core/w_truncate.py +59 -0
- westpa/cli/tools/__init__.py +0 -0
- westpa/cli/tools/ploterr.py +506 -0
- westpa/cli/tools/plothist.py +706 -0
- westpa/cli/tools/w_assign.py +596 -0
- westpa/cli/tools/w_bins.py +166 -0
- westpa/cli/tools/w_crawl.py +119 -0
- westpa/cli/tools/w_direct.py +547 -0
- westpa/cli/tools/w_dumpsegs.py +94 -0
- westpa/cli/tools/w_eddist.py +506 -0
- westpa/cli/tools/w_fluxanl.py +378 -0
- westpa/cli/tools/w_ipa.py +833 -0
- westpa/cli/tools/w_kinavg.py +127 -0
- westpa/cli/tools/w_kinetics.py +96 -0
- westpa/cli/tools/w_multi_west.py +414 -0
- westpa/cli/tools/w_ntop.py +213 -0
- westpa/cli/tools/w_pdist.py +515 -0
- westpa/cli/tools/w_postanalysis_matrix.py +82 -0
- westpa/cli/tools/w_postanalysis_reweight.py +53 -0
- westpa/cli/tools/w_red.py +486 -0
- westpa/cli/tools/w_reweight.py +780 -0
- westpa/cli/tools/w_select.py +226 -0
- westpa/cli/tools/w_stateprobs.py +111 -0
- westpa/cli/tools/w_trace.py +599 -0
- westpa/core/__init__.py +0 -0
- westpa/core/_rc.py +673 -0
- westpa/core/binning/__init__.py +55 -0
- westpa/core/binning/_assign.cpython-312-darwin.so +0 -0
- westpa/core/binning/assign.py +449 -0
- westpa/core/binning/binless.py +96 -0
- westpa/core/binning/binless_driver.py +54 -0
- westpa/core/binning/binless_manager.py +190 -0
- westpa/core/binning/bins.py +47 -0
- westpa/core/binning/mab.py +427 -0
- westpa/core/binning/mab_driver.py +54 -0
- westpa/core/binning/mab_manager.py +198 -0
- westpa/core/data_manager.py +1694 -0
- westpa/core/extloader.py +74 -0
- westpa/core/h5io.py +995 -0
- westpa/core/kinetics/__init__.py +24 -0
- westpa/core/kinetics/_kinetics.cpython-312-darwin.so +0 -0
- westpa/core/kinetics/events.py +147 -0
- westpa/core/kinetics/matrates.py +156 -0
- westpa/core/kinetics/rate_averaging.py +266 -0
- westpa/core/progress.py +218 -0
- westpa/core/propagators/__init__.py +54 -0
- westpa/core/propagators/executable.py +715 -0
- westpa/core/reweight/__init__.py +14 -0
- westpa/core/reweight/_reweight.cpython-312-darwin.so +0 -0
- westpa/core/reweight/matrix.py +126 -0
- westpa/core/segment.py +119 -0
- westpa/core/sim_manager.py +830 -0
- westpa/core/states.py +359 -0
- westpa/core/systems.py +93 -0
- westpa/core/textio.py +74 -0
- westpa/core/trajectory.py +330 -0
- westpa/core/we_driver.py +908 -0
- westpa/core/wm_ops.py +43 -0
- westpa/core/yamlcfg.py +391 -0
- westpa/fasthist/__init__.py +34 -0
- westpa/fasthist/__main__.py +110 -0
- westpa/fasthist/_fasthist.cpython-312-darwin.so +0 -0
- westpa/mclib/__init__.py +264 -0
- westpa/mclib/__main__.py +28 -0
- westpa/mclib/_mclib.cpython-312-darwin.so +0 -0
- westpa/oldtools/__init__.py +4 -0
- westpa/oldtools/aframe/__init__.py +35 -0
- westpa/oldtools/aframe/atool.py +75 -0
- westpa/oldtools/aframe/base_mixin.py +26 -0
- westpa/oldtools/aframe/binning.py +178 -0
- westpa/oldtools/aframe/data_reader.py +560 -0
- westpa/oldtools/aframe/iter_range.py +200 -0
- westpa/oldtools/aframe/kinetics.py +117 -0
- westpa/oldtools/aframe/mcbs.py +146 -0
- westpa/oldtools/aframe/output.py +39 -0
- westpa/oldtools/aframe/plotting.py +90 -0
- westpa/oldtools/aframe/trajwalker.py +126 -0
- westpa/oldtools/aframe/transitions.py +469 -0
- westpa/oldtools/cmds/__init__.py +0 -0
- westpa/oldtools/cmds/w_ttimes.py +358 -0
- westpa/oldtools/files.py +34 -0
- westpa/oldtools/miscfn.py +23 -0
- westpa/oldtools/stats/__init__.py +4 -0
- westpa/oldtools/stats/accumulator.py +35 -0
- westpa/oldtools/stats/edfs.py +129 -0
- westpa/oldtools/stats/mcbs.py +89 -0
- westpa/tools/__init__.py +33 -0
- westpa/tools/binning.py +472 -0
- westpa/tools/core.py +340 -0
- westpa/tools/data_reader.py +159 -0
- westpa/tools/dtypes.py +31 -0
- westpa/tools/iter_range.py +198 -0
- westpa/tools/kinetics_tool.py +340 -0
- westpa/tools/plot.py +283 -0
- westpa/tools/progress.py +17 -0
- westpa/tools/selected_segs.py +154 -0
- westpa/tools/wipi.py +751 -0
- westpa/trajtree/__init__.py +4 -0
- westpa/trajtree/_trajtree.cpython-312-darwin.so +0 -0
- westpa/trajtree/trajtree.py +117 -0
- westpa/westext/__init__.py +0 -0
- westpa/westext/adaptvoronoi/__init__.py +3 -0
- westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
- westpa/westext/hamsm_restarting/__init__.py +3 -0
- westpa/westext/hamsm_restarting/example_overrides.py +35 -0
- westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
- westpa/westext/stringmethod/__init__.py +11 -0
- westpa/westext/stringmethod/fourier_fitting.py +69 -0
- westpa/westext/stringmethod/string_driver.py +253 -0
- westpa/westext/stringmethod/string_method.py +306 -0
- westpa/westext/weed/BinCluster.py +180 -0
- westpa/westext/weed/ProbAdjustEquil.py +100 -0
- westpa/westext/weed/UncertMath.py +247 -0
- westpa/westext/weed/__init__.py +10 -0
- westpa/westext/weed/weed_driver.py +182 -0
- westpa/westext/wess/ProbAdjust.py +101 -0
- westpa/westext/wess/__init__.py +6 -0
- westpa/westext/wess/wess_driver.py +207 -0
- westpa/work_managers/__init__.py +57 -0
- westpa/work_managers/core.py +396 -0
- westpa/work_managers/environment.py +134 -0
- westpa/work_managers/mpi.py +318 -0
- westpa/work_managers/processes.py +187 -0
- westpa/work_managers/serial.py +28 -0
- westpa/work_managers/threads.py +79 -0
- westpa/work_managers/zeromq/__init__.py +20 -0
- westpa/work_managers/zeromq/core.py +641 -0
- westpa/work_managers/zeromq/node.py +131 -0
- westpa/work_managers/zeromq/work_manager.py +526 -0
- westpa/work_managers/zeromq/worker.py +320 -0
- westpa-2022.10.dist-info/AUTHORS +22 -0
- westpa-2022.10.dist-info/LICENSE +21 -0
- westpa-2022.10.dist-info/METADATA +183 -0
- westpa-2022.10.dist-info/RECORD +150 -0
- westpa-2022.10.dist-info/WHEEL +5 -0
- westpa-2022.10.dist-info/entry_points.txt +29 -0
- westpa-2022.10.dist-info/top_level.txt +1 -0
westpa/core/h5io.py
ADDED
@@ -0,0 +1,995 @@
'''Miscellaneous routines to help with HDF5 input and output of WEST-related data.'''

import collections
import errno
import getpass
import os
import posixpath
import socket
import sys
import time
import logging
import warnings

import h5py
import numpy as np
from numpy import index_exp
from tables import NaturalNameWarning

from mdtraj import Trajectory, join as join_traj
from mdtraj.utils import in_units_of, import_, ensure_type
from mdtraj.formats import HDF5TrajectoryFile
from mdtraj.formats.hdf5 import _check_mode, Frames

from .trajectory import WESTTrajectory

try:
    import psutil
except ImportError:
    psutil = None

log = logging.getLogger(__name__)
warnings.filterwarnings('ignore', category=NaturalNameWarning)

#
# Constants and globals
#
default_iter_prec = 8

#
# Helper functions
#


def resolve_filepath(path, constructor=h5py.File, cargs=None, ckwargs=None, **addtlkwargs):
    '''Use a combined filesystem and HDF5 path to open an HDF5 file and return the
    appropriate object. Returns (h5file, h5object). The file is opened using
    ``constructor(filename, *cargs, **ckwargs)``.'''

    cargs = cargs or ()
    ckwargs = ckwargs or {}
    ckwargs.update(addtlkwargs)
    objpieces = collections.deque()
    path = posixpath.normpath(path)

    filepieces = path.split('/')
    while filepieces:
        testpath = '/'.join(filepieces)
        if not testpath:
            filepieces.pop()
            continue
        try:
            h5file = constructor(testpath, *cargs, **ckwargs)
        except IOError:
            objpieces.appendleft(filepieces.pop())
            continue
        else:
            return (h5file, h5file['/'.join([''] + list(objpieces)) if objpieces else '/'])
    else:
        # We don't provide a filename, because we're not sure where the filename stops
        # and the HDF5 path begins.
        raise IOError(errno.ENOENT, os.strerror(errno.ENOENT))
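

# Illustrative sketch: ``resolve_filepath`` tries successively shorter prefixes of
# the combined path as a filename, treating the remainder as an HDF5 path. The
# filename 'analysis.h5' and the group 'iterations' below are assumptions for
# demonstration, not fixed WESTPA names.
def _example_resolve_filepath():
    h5file, h5object = resolve_filepath('analysis.h5/iterations', mode='r')
    try:
        print(h5object.name)  # -> '/iterations'
    finally:
        h5file.close()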


def calc_chunksize(shape, dtype, max_chunksize=262144):
    '''Calculate a chunk size for HDF5 data, anticipating that access will slice
    along lower dimensions sooner than higher dimensions.'''

    chunk_shape = list(shape)
    dtype = np.dtype(dtype)
    for idim in range(len(shape)):
        chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
        while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
            chunk_shape[idim] >>= 1  # divide by 2
            chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize

        if chunk_nbytes <= max_chunksize:
            break

    chunk_shape = tuple(chunk_shape)
    return chunk_shape
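

# A minimal sketch of ``calc_chunksize`` in action: the leading dimension is
# halved until the chunk fits under the (default) 256 KiB limit.
def _example_calc_chunksize():
    chunks = calc_chunksize((1000, 50, 3), np.float64)
    # (1000, 50, 3) float64 is ~1.2 MB; halving dim 0 three times gives (125, 50, 3)
    assert np.multiply.reduce(chunks) * np.dtype(np.float64).itemsize <= 262144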


def tostr(b):
    '''Convert a nonstandard string object ``b`` to str with the handling of the
    case where ``b`` is bytes.'''

    if b is None:
        return None
    elif isinstance(b, bytes):
        return b.decode('utf-8')
    else:
        return str(b)


def is_within_directory(directory, target):
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)

    prefix = os.path.commonprefix([abs_directory, abs_target])

    return prefix == abs_directory


def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

    tar.extractall(path, members, numeric_owner=numeric_owner)


#
# Group and dataset manipulation functions
#


def create_hdf5_group(parent_group, groupname, replace=False, creating_program=None):
    '''Create (or delete and recreate) an HDF5 group named ``groupname`` within
    the enclosing Group (object) ``parent_group``. If ``replace`` is True, then
    the group is replaced if present; if False, then an error is raised if the
    group is present. After the group is created, HDF5 attributes are set using
    `stamp_creator_data`.
    '''

    if replace:
        try:
            del parent_group[groupname]
        except KeyError:
            pass

    newgroup = parent_group.create_group(groupname)
    # forward creating_program so the stamp records the caller's program name
    stamp_creator_data(newgroup, creating_program=creating_program)
    return newgroup


#
# Group and dataset labeling functions
#


def stamp_creator_data(h5group, creating_program=None):
    '''Mark the following on the HDF5 group ``h5group``:

    :creation_program: The name of the program that created the group
    :creation_user: The username of the user who created the group
    :creation_hostname: The hostname of the machine on which the group was created
    :creation_time: The date and time at which the group was created, in the
      current locale.
    :creation_unix_time: The Unix time (seconds from the epoch, UTC) at which the
      group was created.

    This is meant to facilitate tracking the flow of data, but should not be considered
    a secure paper trail (after all, anyone with write access to the HDF5 file can modify
    these attributes).
    '''
    now = time.time()
    attrs = h5group.attrs

    attrs['creation_program'] = creating_program or sys.argv[0] or 'unknown program'
    attrs['creation_user'] = getpass.getuser()
    attrs['creation_hostname'] = socket.gethostname()
    attrs['creation_unix_time'] = now
    attrs['creation_time'] = time.strftime('%c', time.localtime(now))


def get_creator_data(h5group):
    '''Read back creator data as written by ``stamp_creator_data``, returning a dictionary with
    keys as described for ``stamp_creator_data``. Missing fields are denoted with None.
    The ``creation_time`` field is returned as a string.'''
    attrs = h5group.attrs
    d = dict()
    for attr in ['creation_program', 'creation_user', 'creation_hostname', 'creation_unix_time', 'creation_time']:
        d[attr] = attrs.get(attr)
    return d
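

# Illustrative sketch: stamp creator metadata on a fresh group and read it back.
# The in-memory 'core' driver keeps the demo from touching disk; 'demo.h5' is a
# placeholder name.
def _example_creator_data():
    with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
        group = create_hdf5_group(f, 'my_group', creating_program='example')
        info = get_creator_data(group)
        print(tostr(info['creation_user']), tostr(info['creation_time']))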


def load_west(filename):
    """Load WESTPA trajectory files from disk.

    Parameters
    ----------
    filename : str
        String filename of HDF Trajectory file.
    """

    with h5py.File(filename, 'r') as f:
        iter_group_template = 'iter_{1:0{0}d}'
        iter_prec = f.attrs['west_iter_prec']
        trajectories = []
        n = 0

        iter_group_name = iter_group_template.format(iter_prec, n)
        for iter_group_name in f['iterations']:
            iter_group = f['iterations/' + iter_group_name]

            if 'trajectories' in iter_group:
                traj_link = iter_group['trajectories']
                traj_filename = traj_link.file.filename

                with WESTIterationFile(traj_filename) as traj_file:
                    traj = traj_file.read_as_traj()
            else:
                # TODO: [HDF5] allow initializing trajectory without coordinates
                raise ValueError("Missing trajectories for iteration %d" % n)

            # pcoord is required
            if 'pcoord' not in iter_group:
                raise ValueError("Missing pcoords for iteration %d" % n)

            raw_pcoord = iter_group['pcoord'][:]
            if raw_pcoord.ndim != 3:
                log.warning('pcoord is expected to be a 3-d ndarray instead of {}-d'.format(raw_pcoord.ndim))
                continue
            # ignore the first frame of each segment
            if raw_pcoord.shape[1] == traj.n_frames + 1:
                raw_pcoord = raw_pcoord[:, 1:, :]
            elif raw_pcoord.shape[1] == traj.n_frames:
                raw_pcoord = raw_pcoord[:, :, :]
            else:
                raise ValueError(
                    "Inconsistent number of pcoords (%d) and frames (%d) for iteration %d" % (raw_pcoord.shape[1], traj.n_frames, n)
                )

            pcoords = np.concatenate(raw_pcoord, axis=0)
            n_frames = raw_pcoord.shape[1]

            if 'seg_index' in iter_group:
                raw_pid = iter_group['seg_index']['parent_id'][:]

                if np.any(raw_pid < 0):
                    init_basis_ids = iter_group['ibstates']['istate_index']['basis_state_id'][:]
                    init_ids = -(raw_pid[raw_pid < 0] + 1)
                    raw_pid[raw_pid < 0] = [init_basis_ids[iid] for iid in init_ids]
                parent_ids = raw_pid.repeat(n_frames, axis=0)
            else:
                parent_ids = None

            traj.pcoords = pcoords
            traj.parent_ids = parent_ids
            trajectories.append(traj)

            n += 1
            iter_group_name = iter_group_template.format(iter_prec, n)

    west_traj = join_traj(trajectories)

    return west_traj
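

# Illustrative usage sketch for ``load_west``; 'west.h5' stands in for a WESTPA
# data file whose per-iteration groups link to trajectory files.
def _example_load_west():
    traj = load_west('west.h5')
    print(traj.n_frames)  # all iterations joined into one mdtraj Trajectory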


###
# Iteration range metadata
###
def stamp_iter_range(h5object, start_iter, stop_iter):
    '''Mark that the HDF5 object ``h5object`` (dataset or group) contains data from iterations
    start_iter <= n_iter < stop_iter.'''
    h5object.attrs['iter_start'] = start_iter
    h5object.attrs['iter_stop'] = stop_iter


def get_iter_range(h5object):
    '''Read back iteration range data written by ``stamp_iter_range``'''
    return int(h5object.attrs['iter_start']), int(h5object.attrs['iter_stop'])


def stamp_iter_step(h5group, iter_step):
    '''Mark that the HDF5 object ``h5group`` (dataset or group) contains data with an
    iteration step (stride) of ``iter_step``.'''
    h5group.attrs['iter_step'] = iter_step


def get_iter_step(h5group):
    '''Read back iteration step (stride) written by ``stamp_iter_step``'''
    return int(h5group.attrs['iter_step'])
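

# Minimal sketch: stamp an iteration range on a dataset and read it back.
def _example_iter_range():
    with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
        ds = f.create_dataset('rates', shape=(100, 10))
        stamp_iter_range(ds, 1, 101)  # data covers 1 <= n_iter < 101
        assert get_iter_range(ds) == (1, 101)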


def check_iter_range_least(h5object, iter_start, iter_stop):
    '''Return True if the iteration range [iter_start, iter_stop) is
    the same as or entirely contained within the iteration range stored
    on ``h5object``.'''
    obj_iter_start, obj_iter_stop = get_iter_range(h5object)
    return obj_iter_start <= iter_start and obj_iter_stop >= iter_stop


def check_iter_range_equal(h5object, iter_start, iter_stop):
    '''Return True if the iteration range [iter_start, iter_stop) is
    the same as the iteration range stored on ``h5object``.'''
    obj_iter_start, obj_iter_stop = get_iter_range(h5object)
    return obj_iter_start == iter_start and obj_iter_stop == iter_stop


def get_iteration_entry(h5object, n_iter):
    '''Create a slice for data corresponding to iteration ``n_iter`` in ``h5object``.'''
    obj_iter_start, obj_iter_stop = get_iter_range(h5object)
    if n_iter < obj_iter_start or n_iter >= obj_iter_stop:
        raise IndexError('data for iteration {} not available in dataset {!r}'.format(n_iter, h5object))
    return np.index_exp[n_iter - obj_iter_start]


def get_iteration_slice(h5object, iter_start, iter_stop=None, iter_stride=None):
    '''Create a slice for data corresponding to iterations [iter_start,iter_stop),
    with stride ``iter_stride``, in the given ``h5object``.'''
    obj_iter_start, obj_iter_stop = get_iter_range(h5object)

    if iter_stop is None:
        iter_stop = iter_start + 1
    if iter_stride is None:
        iter_stride = 1

    if iter_start < obj_iter_start:
        raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
    elif iter_start > obj_iter_stop:
        raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))

    start_index = iter_start - obj_iter_start
    stop_index = iter_stop - obj_iter_start
    return np.index_exp[start_index:stop_index:iter_stride]
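

# Minimal sketch: translate an iteration range into dataset indices.
def _example_iteration_slice():
    with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
        ds = f.create_dataset('flux', data=np.arange(100.0))
        stamp_iter_range(ds, 1, 101)
        sel = get_iteration_slice(ds, 5, 16)  # iterations 5 through 15
        assert np.array_equal(ds[sel], np.arange(4.0, 15.0))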


###
# Axis label metadata
###
def label_axes(h5object, labels, units=None):
    '''Stamp the given HDF5 object with axis labels. This stores the axis labels
    in an array of strings in an attribute called ``axis_labels`` on the given
    object. ``units`` if provided is a corresponding list of units.'''

    if len(labels) != len(h5object.shape):
        raise ValueError('number of axes and number of labels do not match')

    if units is None:
        units = []

    if len(units) and len(units) != len(labels):
        raise ValueError('number of units labels does not match number of axes')

    h5object.attrs['axis_labels'] = np.array([np.string_(i) for i in labels])

    if len(units):
        h5object.attrs['axis_units'] = np.array([np.string_(i) for i in units])
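

# Minimal sketch: label the axes of a 3-d progress-coordinate dataset.
def _example_label_axes():
    with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
        ds = f.create_dataset('pcoord', shape=(10, 21, 2))
        label_axes(ds, ['segment', 'timepoint', 'dimension'], units=['', '', 'angstrom'])
        print([tostr(label) for label in ds.attrs['axis_labels']])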


NotGiven = object()


def _get_one_attr(h5object, namelist, default=NotGiven):
    attrs = dict(h5object.attrs)
    for name in namelist:
        try:
            return attrs[name]
        except KeyError:
            pass
    else:
        if default is NotGiven:
            raise KeyError('no such key')
        else:
            return default


class WESTPAH5File(h5py.File):
    '''Generalized input/output for WESTPA simulation (or analysis) data.'''

    default_iter_prec = 8
    _this_fileformat_version = 8

    def __init__(self, *args, **kwargs):
        # These values are used for creating files or reading files where this
        # data is not stored. Otherwise, values stored as attributes on the root
        # group are used instead.
        arg_iter_prec = kwargs.pop('westpa_iter_prec', self.default_iter_prec)
        arg_fileformat_version = kwargs.pop('westpa_fileformat_version', self._this_fileformat_version)
        arg_creating_program = kwargs.pop('creating_program', None)

        # Initialize h5py file
        super().__init__(*args, **kwargs)

        # Try to get iteration precision and I/O class version
        h5file_iter_prec = _get_one_attr(self, ['westpa_iter_prec', 'west_iter_prec', 'wemd_iter_prec'], None)
        h5file_fileformat_version = _get_one_attr(
            self, ['westpa_fileformat_version', 'west_file_format_version', 'wemd_file_format_version'], None
        )

        self.iter_prec = h5file_iter_prec if h5file_iter_prec is not None else arg_iter_prec
        self.fileformat_version = h5file_fileformat_version if h5file_fileformat_version is not None else arg_fileformat_version

        # Ensure that file format attributes are stored, if the file is writable
        if self.mode == 'r+':
            self.attrs['westpa_iter_prec'] = self.iter_prec
            self.attrs['westpa_fileformat_version'] = self.fileformat_version
            if arg_creating_program:
                stamp_creator_data(self, creating_program=arg_creating_program)

    # Helper function to automatically replace a group, if it exists.
    # Should really only be called when one is certain a dataset should be blown away.
    def replace_dataset(self, *args, **kwargs):
        try:
            del self[args[0]]
        except Exception:
            pass
        try:
            del self[kwargs['name']]
        except Exception:
            pass

        return self.create_dataset(*args, **kwargs)

    # Iteration groups

    def iter_object_name(self, n_iter, prefix='', suffix=''):
        '''Return a properly-formatted per-iteration name for iteration
        ``n_iter``. (This is used in create/require/get_iter_group, but may
        also be useful for naming datasets on a per-iteration basis.)'''
        return '{prefix}iter_{n_iter:0{prec}d}{suffix}'.format(n_iter=n_iter, prefix=prefix, suffix=suffix, prec=self.iter_prec)

    def create_iter_group(self, n_iter, group=None):
        '''Create a per-iteration data storage group for iteration number ``n_iter``
        in the group ``group`` (which is '/iterations' by default).'''

        if group is None:
            group = self.require_group('/iterations')
        return group.create_group(self.iter_object_name(n_iter))

    def require_iter_group(self, n_iter, group=None):
        '''Ensure that a per-iteration data storage group for iteration number ``n_iter``
        is available in the group ``group`` (which is '/iterations' by default).'''
        if group is None:
            group = self.require_group('/iterations')
        return group.require_group(self.iter_object_name(n_iter))

    def get_iter_group(self, n_iter, group=None):
        '''Get the per-iteration data group for iteration number ``n_iter`` from within
        the group ``group`` ('/iterations' by default).'''
        if group is None:
            group = self['/iterations']
        return group[self.iter_object_name(n_iter)]
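

# Minimal sketch: create a WESTPA HDF5 file and address data by iteration.
# 'demo_west.h5' is a placeholder; the in-memory driver avoids disk I/O.
def _example_westpa_h5file():
    f = WESTPAH5File('demo_west.h5', 'w', driver='core', backing_store=False, creating_program='example')
    group = f.require_iter_group(1)
    assert group.name == '/iterations/iter_00000001'  # default_iter_prec = 8
    f.close()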


class WESTIterationFile(HDF5TrajectoryFile):
    def __init__(self, file, mode='r', force_overwrite=True, compression='zlib', link=None):
        if isinstance(file, str):
            super(WESTIterationFile, self).__init__(file, mode, force_overwrite, compression)
        else:
            try:
                self._init_from_handle(file)
            except AttributeError:
                raise ValueError('unknown input type: %s' % str(type(file)))

    def _init_from_handle(self, handle):
        self._handle = handle
        self._open = handle.isopen != 0
        self.mode = mode = handle.mode  # the mode in which the file was opened

        if mode not in ['r', 'w', 'a']:
            raise ValueError("mode must be one of ['r', 'w', 'a']")

        # import tables
        self.tables = import_('tables')

        if mode == 'w':
            # what frame are we currently reading or writing at?
            self._frame_index = 0
            # do we need to write the header information?
            self._needs_initialization = True

        elif mode == 'a':
            try:
                self._frame_index = len(self._handle.root.coordinates)
                self._needs_initialization = False
            except self.tables.NoSuchNodeError:
                self._frame_index = 0
                self._needs_initialization = True
        elif mode == 'r':
            self._frame_index = 0
            self._needs_initialization = False

    def read(self, frame_indices=None, atom_indices=None):
        _check_mode(self.mode, ('r',))

        if frame_indices is None:
            # read all frames; advance the cursor past the end of the dataset
            # (slice(None) has no start/stop to subtract)
            frame_slice = slice(None)
            self._frame_index += self._handle.root.coordinates.shape[0]
        else:
            frame_slice = ensure_type(frame_indices, dtype=int, ndim=1, name='frame_indices', warn_on_cast=False)
            if not np.all(frame_slice < self._handle.root.coordinates.shape[0]):
                raise ValueError(
                    'As a zero-based index, the entries in '
                    'frame_slice must all be less than the number of frames '
                    'in the trajectory, %d' % self._handle.root.coordinates.shape[0]
                )
            if not np.all(frame_slice >= 0):
                raise ValueError('The entries in frame_indices must be greater than or equal to zero')
            self._frame_index += frame_slice[-1] - frame_slice[0]

        if atom_indices is None:
            # get all of the atoms
            atom_slice = slice(None)
        else:
            atom_slice = ensure_type(atom_indices, dtype=int, ndim=1, name='atom_indices', warn_on_cast=False)
            if not np.all(atom_slice < self._handle.root.coordinates.shape[1]):
                raise ValueError(
                    'As a zero-based index, the entries in '
                    'atom_indices must all be less than the number of atoms '
                    'in the trajectory, %d' % self._handle.root.coordinates.shape[1]
                )
            if not np.all(atom_slice >= 0):
                raise ValueError('The entries in atom_indices must be greater than or equal to zero')

        def get_item(node, key):
            if not isinstance(key, tuple):
                return node.__getitem__(key)

            n_list_like = 0
            new_keys = []
            for item in key:
                if not isinstance(item, slice):
                    try:
                        d = np.diff(item)
                        if len(d) == 0:
                            item = item[0]
                        elif np.all(d == d[0]):
                            item = slice(item[0], item[-1] + d[0], d[0])
                        else:
                            n_list_like += 1
                    except Exception:
                        n_list_like += 1
                new_keys.append(item)
            new_keys = tuple(new_keys)

            if n_list_like <= 1:
                return node.__getitem__(new_keys)

            data = node
            for i, item in enumerate(new_keys):
                dkey = [slice(None)] * len(key)
                dkey[i] = item
                dkey = tuple(dkey)
                data = data.__getitem__(dkey)

            return data

        def get_field(name, slice, out_units, can_be_none=True):
            try:
                node = self._get_node(where='/', name=name)
                data = get_item(node, slice)
                in_units = node.attrs.units
                if not isinstance(in_units, str):
                    in_units = in_units.decode()
                data = in_units_of(data, in_units, out_units)
                return data
            except self.tables.NoSuchNodeError:
                if can_be_none:
                    return None
                raise

        frames = Frames(
            coordinates=get_field('coordinates', (frame_slice, atom_slice, slice(None)), out_units='nanometers', can_be_none=False),
            time=get_field('time', frame_slice, out_units='picoseconds'),
            cell_lengths=get_field('cell_lengths', (frame_slice, slice(None)), out_units='nanometers'),
            cell_angles=get_field('cell_angles', (frame_slice, slice(None)), out_units='degrees'),
            velocities=get_field('velocities', (frame_slice, atom_slice, slice(None)), out_units='nanometers/picosecond'),
            kineticEnergy=get_field('kineticEnergy', frame_slice, out_units='kilojoules_per_mole'),
            potentialEnergy=get_field('potentialEnergy', frame_slice, out_units='kilojoules_per_mole'),
            temperature=get_field('temperature', frame_slice, out_units='kelvin'),
            alchemicalLambda=get_field('lambda', frame_slice, out_units='dimensionless'),
        )

        return frames

    def _has_node(self, where, name):
        try:
            self._get_node(where, name=name)
        except self.tables.NoSuchNodeError:
            return False

        return True

    def has_topology(self):
        return self._has_node('/', 'topology')

    def has_pointer(self):
        return self._has_node('/', 'pointer')

    def has_restart(self, segment):
        return self._has_node('/restart', '%d_%d' % (segment.n_iter, segment.seg_id))

    def write_data(self, where, name, data):
        node = self._get_node(where=where, name=name)
        node.append(data)

    def read_data(self, where, name):
        node = self._get_node(where=where, name=name)
        return node.read()

    def read_as_traj(self, iteration=None, segment=None, atom_indices=None):
        _check_mode(self.mode, ('r',))

        pnode = self._get_node(where='/', name='pointer')

        iter_labels = pnode[:, 0]
        seg_labels = pnode[:, 1]

        if iteration is None and segment is None:
            frame_indices = slice(None)
        elif isinstance(iteration, (np.integer, int)) and isinstance(segment, (np.integer, int)):
            frame_torf = np.logical_and(iter_labels == iteration, seg_labels == segment)
            frame_indices = np.arange(len(iter_labels))[frame_torf]
        else:
            raise ValueError("iteration and segment must be integers and provided at the same time")

        if len(frame_indices) == 0:
            raise ValueError(f"no frame was selected: iteration={iteration}, segment={segment}, atom_indices={atom_indices}")

        iter_labels = iter_labels[frame_indices]
        seg_labels = seg_labels[frame_indices]

        topology = self.topology
        if atom_indices is not None:
            topology = topology.subset(atom_indices)

        data = self.read(frame_indices=frame_indices, atom_indices=atom_indices)
        if len(data) == 0:
            return Trajectory(xyz=np.zeros((0, topology.n_atoms, 3)), topology=topology)

        in_units_of(data.coordinates, self.distance_unit, Trajectory._distance_unit, inplace=True)
        in_units_of(data.cell_lengths, self.distance_unit, Trajectory._distance_unit, inplace=True)

        return WESTTrajectory(
            data.coordinates,
            topology=topology,
            time=data.time,
            unitcell_lengths=data.cell_lengths,
            unitcell_angles=data.cell_angles,
            iter_labels=iter_labels,
            seg_labels=seg_labels,
            pcoords=None,
        )

    def read_restart(self, segment):
        if self.has_restart(segment):
            data = self.read_data('/restart/%d_%d' % (segment.n_iter, segment.seg_id), 'data')
            segment.data['iterh5/restart'] = data
        else:
            raise ValueError('no restart data available for {}'.format(str(segment)))

    def write_segment(self, segment, pop=False):
        n_iter = segment.n_iter

        self.root._v_attrs['n_iter'] = n_iter

        if pop:
            get_data = segment.data.pop
        else:
            get_data = segment.data.get

        traj = get_data('iterh5/trajectory', None)
        restart = get_data('iterh5/restart', None)
        slog = get_data('iterh5/log', None)

        if traj is not None:
            # create trajectory object
            traj = WESTTrajectory(traj, iter_labels=n_iter, seg_labels=segment.seg_id)
            if traj.n_frames == 0:
                # we may consider logging warnings instead of throwing errors later.
                # right now this is good for debugging purposes
                raise ValueError('no trajectory data present for %s' % repr(segment))

            if n_iter == 0:
                base_time = 0
            else:
                iter_duration = traj.time[-1] - traj.time[0]
                base_time = iter_duration * (n_iter - 1)

            traj.time -= traj.time[0]
            traj.time += base_time

            # pointers
            if not self.has_pointer():
                self._create_earray('/', name='pointer', atom=self.tables.Int64Atom(), shape=(0, 2))

            iter_idx = traj.iter_labels
            seg_idx = traj.seg_labels

            pointers = np.stack((iter_idx, seg_idx)).T

            self.write_data('/', 'pointer', pointers)

            # trajectory
            self.write(
                coordinates=in_units_of(traj.xyz, Trajectory._distance_unit, self.distance_unit),
                time=traj.time,
                cell_lengths=in_units_of(traj.unitcell_lengths, Trajectory._distance_unit, self.distance_unit),
                cell_angles=traj.unitcell_angles,
            )

            # topology
            if self.mode == 'a':
                if not self.has_topology():
                    self.topology = traj.topology
            elif self.mode == 'w':
                self.topology = traj.topology

        # restart
        if restart is not None:
            if self.has_restart(segment):
                self._remove_node('/restart', name='%d_%d' % (segment.n_iter, segment.seg_id), recursive=True)

            self._create_array(
                '/restart/%d_%d' % (segment.n_iter, segment.seg_id),
                name='data',
                atom=self.tables.StringAtom(itemsize=len(restart)),
                obj=restart,
                createparents=True,
            )

        if slog is not None:
            if self._has_node('/log', str(segment.seg_id)):
                self._remove_node('/log', name=str(segment.seg_id), recursive=True)

            self._create_array(
                '/log/%d_%d' % (segment.n_iter, segment.seg_id),
                name='data',
                atom=self.tables.StringAtom(itemsize=len(slog)),
                obj=slog,
                createparents=True,
            )

    @property
    def _create_group(self):
        if self.tables.__version__ >= '3.0.0':
            return self._handle.create_group
        return self._handle.createGroup

    @property
    def _create_array(self):
        if self.tables.__version__ >= '3.0.0':
            return self._handle.create_array
        return self._handle.createArray

    @property
    def _remove_node(self):
        if self.tables.__version__ >= '3.0.0':
            return self._handle.remove_node
        return self._handle.removeNode
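

# Illustrative sketch: read one segment's frames from a per-iteration trajectory
# file as a WESTTrajectory; 'traj_segs/iter_000001.h5' is an assumed path.
def _example_read_iteration_file():
    with WESTIterationFile('traj_segs/iter_000001.h5', 'r') as itf:
        traj = itf.read_as_traj(iteration=1, segment=0)
        print(traj.n_frames, traj.iter_labels[:3], traj.seg_labels[:3])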


### Generalized WE dataset access classes


class DSSpec:
    '''Generalized WE dataset access'''

    def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
        raise NotImplementedError

    def get_segment_data(self, n_iter, seg_id):
        return self.get_iter_data(n_iter)[seg_id]

    def __getstate__(self):
        d = dict(self.__dict__)
        if '_h5file' in d:
            d['_h5file'] = None
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)


class FileLinkedDSSpec(DSSpec):
    '''Provide facilities for accessing WESTPA HDF5 files, including auto-opening and the ability
    to pickle references to such files for transmission (through, e.g., the work manager), provided
    that the HDF5 file can be accessed by the same path on both the sender and receiver.'''

    def __init__(self, h5file_or_name):
        self._h5file = None
        self._h5filename = None

        try:
            self._h5filename = os.path.abspath(h5file_or_name.filename)
        except AttributeError:
            self._h5filename = h5file_or_name
            self._h5file = None
        else:
            self._h5file = h5file_or_name

    @property
    def h5file(self):
        '''Lazily open HDF5 file. This is required because allowing an open HDF5
        file to cross a fork() boundary generally corrupts the internal state of
        the HDF5 library.'''
        if self._h5file is None:
            self._h5file = WESTPAH5File(self._h5filename, 'r')
        return self._h5file


class SingleDSSpec(FileLinkedDSSpec):
    @classmethod
    def from_string(cls, dsspec_string, default_h5file):
        alias = None

        h5file = default_h5file
        fields = dsspec_string.split(',')
        dsname = fields[0]
        slice = None

        for field in (field.strip() for field in fields[1:]):
            k, v = field.split('=')
            k = k.lower()
            if k == 'alias':
                alias = v
            elif k == 'slice':
                try:
                    slice = eval('np.index_exp' + v)
                except SyntaxError:
                    raise SyntaxError('invalid index expression {!r}'.format(v))
            elif k == 'file':
                h5file = v
            else:
                raise ValueError('invalid dataset option {!r}'.format(k))

        return cls(h5file, dsname, alias, slice)

    def __init__(self, h5file_or_name, dsname, alias=None, slice=None):
        FileLinkedDSSpec.__init__(self, h5file_or_name)
        self.dsname = dsname
        self.alias = alias or dsname
        self.slice = np.index_exp[slice] if slice else None


class SingleIterDSSpec(SingleDSSpec):
    def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
        if self.slice:
            return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + self.slice]
        else:
            return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]


class SingleSegmentDSSpec(SingleDSSpec):
    def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
        if self.slice:
            return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + index_exp[:] + self.slice]
        else:
            return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]

    def get_segment_data(self, n_iter, seg_id):
        # index the named dataset, not the iteration group itself
        if self.slice:
            return self.h5file.get_iter_group(n_iter)[self.dsname][np.index_exp[seg_id, :] + self.slice]
        else:
            return self.h5file.get_iter_group(n_iter)[self.dsname][seg_id]
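

# Illustrative sketch of the dsspec string syntax parsed by ``from_string``:
# 'dsname[,alias=NAME][,slice=INDEXEXPR][,file=PATH]'. The dataset 'pcoord'
# and the file 'west.h5' are placeholders.
def _example_dsspec_from_string():
    dsspec = SingleIterDSSpec.from_string('pcoord,slice=[:,0],alias=progress', 'west.h5')
    data = dsspec.get_iter_data(10)  # dimension 0 of pcoord for every segment of iteration 10
    print(dsspec.alias, data.shape)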


class FnDSSpec(FileLinkedDSSpec):
    def __init__(self, h5file_or_name, fn):
        FileLinkedDSSpec.__init__(self, h5file_or_name)
        self.fn = fn

    def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
        return self.fn(n_iter, self.h5file.get_iter_group(n_iter))[seg_slice]


class MultiDSSpec(DSSpec):
    def __init__(self, dsspecs):
        self.dsspecs = dsspecs

    def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
        datasets = [dsspec.get_iter_data(n_iter) for dsspec in self.dsspecs]

        ncols = 0
        nsegs = None
        npts = None
        for iset, dset in enumerate(datasets):
            if nsegs is None:
                nsegs = dset.shape[0]
            elif dset.shape[0] != nsegs:
                raise TypeError('dataset {} has incorrect first dimension (number of segments)'.format(self.dsspecs[iset]))
            if npts is None:
                npts = dset.shape[1]
            elif dset.shape[1] != npts:
                raise TypeError('dataset {} has incorrect second dimension (number of time points)'.format(self.dsspecs[iset]))

            if dset.ndim < 2:
                # scalar per segment or scalar per iteration
                raise TypeError('dataset {} has too few dimensions'.format(self.dsspecs[iset]))
            elif dset.ndim > 3:
                # array per timepoint
                raise TypeError('dataset {} has too many dimensions'.format(self.dsspecs[iset]))
            elif dset.ndim == 2:
                # scalar per timepoint
                ncols += 1
            else:
                # vector per timepoint
                ncols += dset.shape[-1]

        output_dtype = np.result_type(*[ds.dtype for ds in datasets])
        output_array = np.empty((nsegs, npts, ncols), dtype=output_dtype)

        ocol = 0
        for iset, dset in enumerate(datasets):
            if dset.ndim == 2:
                output_array[:, :, ocol] = dset[...]
                ocol += 1
            elif dset.ndim == 3:
                output_array[:, :, ocol : (ocol + dset.shape[-1])] = dset[...]
                ocol += dset.shape[-1]

        return output_array[seg_slice]
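

# Minimal sketch: combine a 3-d pcoord dataset and a 2-d auxiliary dataset into
# one (nsegs, npts, ncols) array; 'auxdata/rmsd' is an assumed dataset name.
def _example_multi_dsspec():
    combined = MultiDSSpec([
        SingleIterDSSpec('west.h5', 'pcoord'),
        SingleIterDSSpec('west.h5', 'auxdata/rmsd'),
    ])
    data = combined.get_iter_data(10)
    print(data.shape)  # (nsegs, npts, pcoord_ndim + 1) when rmsd is scalar per timepoint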


class IterBlockedDataset:
    @classmethod
    def empty_like(cls, blocked_dataset):
        source = blocked_dataset.data if blocked_dataset.data is not None else blocked_dataset.dataset

        newbds = cls(
            np.empty(source.shape, source.dtype),
            attrs={'iter_start': blocked_dataset.iter_start, 'iter_stop': blocked_dataset.iter_stop},
        )
        return newbds

    def __init__(self, dataset_or_array, attrs=None):
        try:
            dataset_or_array.attrs
        except AttributeError:
            self.dataset = None
            self.data = dataset_or_array
            if attrs is None:
                raise ValueError('attribute dictionary containing iteration bounds must be provided')
            self.iter_shape = self.data.shape[1:]
            self.dtype = self.data.dtype
        else:
            self.dataset = dataset_or_array
            attrs = self.dataset.attrs
            self.data = None
            self.iter_shape = self.dataset.shape[1:]
            self.dtype = self.dataset.dtype

        self.iter_start = attrs['iter_start']
        self.iter_stop = attrs['iter_stop']

    def cache_data(self, max_size=None):
        '''Cache this dataset in RAM. If ``max_size`` is given, then only cache if the entire dataset
        fits in ``max_size`` bytes. If ``max_size`` is the string 'available', then only cache if
        the entire dataset fits in available RAM, as defined by the ``psutil`` module.'''

        if max_size is not None:
            dssize = self.dtype.itemsize * np.multiply.reduce(self.dataset.shape)
            if max_size == 'available' and psutil is not None:
                avail_bytes = psutil.virtual_memory().available
                if dssize > avail_bytes:
                    return
            elif isinstance(max_size, str):
                return
            else:
                if dssize > max_size:
                    return
        if self.dataset is not None:
            if self.data is None:
                self.data = self.dataset[...]

    def drop_cache(self):
        if self.dataset is not None:
            del self.data
            self.data = None

    def iter_entry(self, n_iter):
        if n_iter < self.iter_start:
            raise IndexError('requested iteration {} less than first stored iteration {}'.format(n_iter, self.iter_start))

        source = self.data if self.data is not None else self.dataset
        return source[n_iter - self.iter_start]

    def iter_slice(self, start=None, stop=None):
        start = start or self.iter_start
        stop = stop or self.iter_stop
        step = 1  # strided retrieval not implemented yet

        # if step % self.iter_step > 0:
        #     raise TypeError('dataset {!r} stored with stride {} cannot be accessed with stride {}'
        #                     .format(self.dataset, self.iter_step, step))
        if start < self.iter_start:
            raise IndexError('requested start {} less than stored start {}'.format(start, self.iter_start))
        elif stop > self.iter_stop:
            stop = self.iter_stop

        source = self.data if self.data is not None else self.dataset
        return source[start - self.iter_start : stop - self.iter_start : step]
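

# Minimal sketch: blocked per-iteration access to a stamped dataset, with
# optional in-RAM caching.
def _example_iter_blocked_dataset():
    with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
        ds = f.create_dataset('rates', data=np.arange(50.0).reshape(10, 5))
        stamp_iter_range(ds, 1, 11)
        blocked = IterBlockedDataset(ds)
        blocked.cache_data(max_size='available')  # caches only if it fits in RAM (needs psutil)
        print(blocked.iter_entry(3))            # the row stored for iteration 3
        print(blocked.iter_slice(2, 5).shape)   # iterations 2 through 4 -> (3, 5)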