westpa 2022.13__cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +749 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +369 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +597 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +557 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +832 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_timings.py +113 -0
  38. westpa/cli/tools/w_trace.py +599 -0
  39. westpa/core/__init__.py +0 -0
  40. westpa/core/_rc.py +673 -0
  41. westpa/core/binning/__init__.py +55 -0
  42. westpa/core/binning/_assign.c +36018 -0
  43. westpa/core/binning/_assign.cpython-312-aarch64-linux-gnu.so +0 -0
  44. westpa/core/binning/_assign.pyx +370 -0
  45. westpa/core/binning/assign.py +454 -0
  46. westpa/core/binning/binless.py +96 -0
  47. westpa/core/binning/binless_driver.py +54 -0
  48. westpa/core/binning/binless_manager.py +189 -0
  49. westpa/core/binning/bins.py +47 -0
  50. westpa/core/binning/mab.py +506 -0
  51. westpa/core/binning/mab_driver.py +54 -0
  52. westpa/core/binning/mab_manager.py +197 -0
  53. westpa/core/data_manager.py +1761 -0
  54. westpa/core/extloader.py +74 -0
  55. westpa/core/h5io.py +1079 -0
  56. westpa/core/kinetics/__init__.py +24 -0
  57. westpa/core/kinetics/_kinetics.c +45174 -0
  58. westpa/core/kinetics/_kinetics.cpython-312-aarch64-linux-gnu.so +0 -0
  59. westpa/core/kinetics/_kinetics.pyx +815 -0
  60. westpa/core/kinetics/events.py +147 -0
  61. westpa/core/kinetics/matrates.py +156 -0
  62. westpa/core/kinetics/rate_averaging.py +266 -0
  63. westpa/core/progress.py +218 -0
  64. westpa/core/propagators/__init__.py +54 -0
  65. westpa/core/propagators/executable.py +592 -0
  66. westpa/core/propagators/loaders.py +196 -0
  67. westpa/core/reweight/__init__.py +14 -0
  68. westpa/core/reweight/_reweight.c +36899 -0
  69. westpa/core/reweight/_reweight.cpython-312-aarch64-linux-gnu.so +0 -0
  70. westpa/core/reweight/_reweight.pyx +439 -0
  71. westpa/core/reweight/matrix.py +126 -0
  72. westpa/core/segment.py +119 -0
  73. westpa/core/sim_manager.py +839 -0
  74. westpa/core/states.py +359 -0
  75. westpa/core/systems.py +93 -0
  76. westpa/core/textio.py +74 -0
  77. westpa/core/trajectory.py +603 -0
  78. westpa/core/we_driver.py +910 -0
  79. westpa/core/wm_ops.py +43 -0
  80. westpa/core/yamlcfg.py +298 -0
  81. westpa/fasthist/__init__.py +34 -0
  82. westpa/fasthist/_fasthist.c +38755 -0
  83. westpa/fasthist/_fasthist.cpython-312-aarch64-linux-gnu.so +0 -0
  84. westpa/fasthist/_fasthist.pyx +222 -0
  85. westpa/mclib/__init__.py +271 -0
  86. westpa/mclib/__main__.py +28 -0
  87. westpa/mclib/_mclib.c +34610 -0
  88. westpa/mclib/_mclib.cpython-312-aarch64-linux-gnu.so +0 -0
  89. westpa/mclib/_mclib.pyx +226 -0
  90. westpa/oldtools/__init__.py +4 -0
  91. westpa/oldtools/aframe/__init__.py +35 -0
  92. westpa/oldtools/aframe/atool.py +75 -0
  93. westpa/oldtools/aframe/base_mixin.py +26 -0
  94. westpa/oldtools/aframe/binning.py +178 -0
  95. westpa/oldtools/aframe/data_reader.py +560 -0
  96. westpa/oldtools/aframe/iter_range.py +200 -0
  97. westpa/oldtools/aframe/kinetics.py +117 -0
  98. westpa/oldtools/aframe/mcbs.py +153 -0
  99. westpa/oldtools/aframe/output.py +39 -0
  100. westpa/oldtools/aframe/plotting.py +88 -0
  101. westpa/oldtools/aframe/trajwalker.py +126 -0
  102. westpa/oldtools/aframe/transitions.py +469 -0
  103. westpa/oldtools/cmds/__init__.py +0 -0
  104. westpa/oldtools/cmds/w_ttimes.py +361 -0
  105. westpa/oldtools/files.py +34 -0
  106. westpa/oldtools/miscfn.py +23 -0
  107. westpa/oldtools/stats/__init__.py +4 -0
  108. westpa/oldtools/stats/accumulator.py +35 -0
  109. westpa/oldtools/stats/edfs.py +129 -0
  110. westpa/oldtools/stats/mcbs.py +96 -0
  111. westpa/tools/__init__.py +33 -0
  112. westpa/tools/binning.py +472 -0
  113. westpa/tools/core.py +340 -0
  114. westpa/tools/data_reader.py +159 -0
  115. westpa/tools/dtypes.py +31 -0
  116. westpa/tools/iter_range.py +198 -0
  117. westpa/tools/kinetics_tool.py +343 -0
  118. westpa/tools/plot.py +283 -0
  119. westpa/tools/progress.py +17 -0
  120. westpa/tools/selected_segs.py +154 -0
  121. westpa/tools/wipi.py +751 -0
  122. westpa/trajtree/__init__.py +4 -0
  123. westpa/trajtree/_trajtree.c +17829 -0
  124. westpa/trajtree/_trajtree.cpython-312-aarch64-linux-gnu.so +0 -0
  125. westpa/trajtree/_trajtree.pyx +130 -0
  126. westpa/trajtree/trajtree.py +117 -0
  127. westpa/westext/__init__.py +0 -0
  128. westpa/westext/adaptvoronoi/__init__.py +3 -0
  129. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  130. westpa/westext/hamsm_restarting/__init__.py +3 -0
  131. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  132. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  133. westpa/westext/stringmethod/__init__.py +11 -0
  134. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  135. westpa/westext/stringmethod/string_driver.py +253 -0
  136. westpa/westext/stringmethod/string_method.py +306 -0
  137. westpa/westext/weed/BinCluster.py +180 -0
  138. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  139. westpa/westext/weed/UncertMath.py +247 -0
  140. westpa/westext/weed/__init__.py +10 -0
  141. westpa/westext/weed/weed_driver.py +192 -0
  142. westpa/westext/wess/ProbAdjust.py +101 -0
  143. westpa/westext/wess/__init__.py +6 -0
  144. westpa/westext/wess/wess_driver.py +217 -0
  145. westpa/work_managers/__init__.py +57 -0
  146. westpa/work_managers/core.py +396 -0
  147. westpa/work_managers/environment.py +134 -0
  148. westpa/work_managers/mpi.py +318 -0
  149. westpa/work_managers/processes.py +201 -0
  150. westpa/work_managers/serial.py +28 -0
  151. westpa/work_managers/threads.py +79 -0
  152. westpa/work_managers/zeromq/__init__.py +20 -0
  153. westpa/work_managers/zeromq/core.py +635 -0
  154. westpa/work_managers/zeromq/node.py +131 -0
  155. westpa/work_managers/zeromq/work_manager.py +526 -0
  156. westpa/work_managers/zeromq/worker.py +320 -0
  157. westpa-2022.13.dist-info/METADATA +179 -0
  158. westpa-2022.13.dist-info/RECORD +162 -0
  159. westpa-2022.13.dist-info/WHEEL +7 -0
  160. westpa-2022.13.dist-info/entry_points.txt +30 -0
  161. westpa-2022.13.dist-info/licenses/LICENSE +21 -0
  162. westpa-2022.13.dist-info/top_level.txt +1 -0
westpa/core/h5io.py ADDED
@@ -0,0 +1,1079 @@
+ '''Miscellaneous routines to help with HDF5 input and output of WEST-related data.'''
+
+ import collections
+ import errno
+ import getpass
+ import os
+ import posixpath
+ import socket
+ import sys
+ import time
+ import logging
+ import warnings
+ from pathlib import Path
+
+ import h5py
+ import numpy as np
+ from numpy import index_exp
+ from tables import NaturalNameWarning
+
+ from mdtraj import Trajectory, join as join_traj
+ from mdtraj.utils import in_units_of, import_, ensure_type
+ from mdtraj.formats import HDF5TrajectoryFile
+ from mdtraj.formats.hdf5 import _check_mode, Frames
+
+ from .trajectory import WESTTrajectory
+
+ try:
+     import psutil
+ except ImportError:
+     psutil = None
+
+ log = logging.getLogger(__name__)
+ warnings.filterwarnings('ignore', category=NaturalNameWarning)
+
+ #
+ # Constants and globals
+ #
+ default_iter_prec = 8
+
+ #
+ # Helper functions
+ #
+
+
+ def resolve_filepath(path, constructor=h5py.File, cargs=None, ckwargs=None, **addtlkwargs):
+     '''Use a combined filesystem and HDF5 path to open an HDF5 file and return the
+     appropriate object. Returns (h5file, h5object). The file is opened using
+     ``constructor(filename, *cargs, **ckwargs)``.'''
+
+     cargs = cargs or ()
+     ckwargs = ckwargs or {}
+     ckwargs.update(addtlkwargs)
+     objpieces = collections.deque()
+     path = posixpath.normpath(path)
+
+     filepieces = path.split('/')
+     while filepieces:
+         testpath = '/'.join(filepieces)
+         if not testpath:
+             filepieces.pop()
+             continue
+         try:
+             h5file = constructor(testpath, *cargs, **ckwargs)
+         except IOError:
+             objpieces.appendleft(filepieces.pop())
+             continue
+         else:
+             return (h5file, h5file['/'.join([''] + list(objpieces)) if objpieces else '/'])
+     else:
+         # We don't provide a filename, because we're not sure where the filename stops
+         # and the HDF5 path begins.
+         raise IOError(errno.ENOENT, os.strerror(errno.ENOENT))
+
+
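As an illustration of the combined-path convention above, a minimal sketch (assuming a file named west.h5 that contains the group /iterations/iter_00000001; resolve_filepath walks the path from right to left until a component opens as an HDF5 file, then treats the remainder as an in-file path):

    h5file, h5object = resolve_filepath('west.h5/iterations/iter_00000001', mode='r')
    print(h5object.name)   # '/iterations/iter_00000001'
    h5file.close()
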
+ def calc_chunksize(shape, dtype, max_chunksize=262144):
+     '''Calculate a chunk size for HDF5 data, anticipating that access will slice
+     along lower dimensions sooner than higher dimensions.'''
+
+     chunk_shape = list(shape)
+     dtype = np.dtype(dtype)
+     for idim in range(len(shape)):
+         chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+         while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
+             chunk_shape[idim] >>= 1  # divide by 2
+             chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+
+         if chunk_nbytes <= max_chunksize:
+             break
+
+     chunk_shape = tuple(chunk_shape)
+     return chunk_shape
+
+
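To make the halving strategy concrete, a worked sketch (values computed by hand from the function above, not taken from the package's tests):

    # A float64 array of shape (1000, 100, 3) occupies 1000 * 100 * 3 * 8
    # = 2,400,000 bytes. The leading dimension is halved (1000 -> 500 ->
    # 250 -> 125 -> 62) until the chunk fits under the 262,144-byte default,
    # giving a chunk shape of (62, 100, 3).
    assert calc_chunksize((1000, 100, 3), np.float64) == (62, 100, 3)
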
+ def tostr(b):
+     '''Convert a nonstandard string object ``b`` to str, handling the case where
+     ``b`` is bytes.'''
+
+     if b is None:
+         return None
+     elif isinstance(b, bytes):
+         return b.decode('utf-8')
+     else:
+         return str(b)
+
+
+ def is_within_directory(directory, target):
+     abs_directory = os.path.abspath(directory)
+     abs_target = os.path.abspath(target)
+
+     prefix = os.path.commonprefix([abs_directory, abs_target])
+
+     return prefix == abs_directory
+
+
+ def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+     for member in tar.getmembers():
+         member_path = os.path.join(path, member.name)
+         if not is_within_directory(path, member_path):
+             raise Exception("Attempted Path Traversal in Tar File")
+
+     tar.extractall(path, members, numeric_owner=numeric_owner, filter='fully_trusted')
+
+
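A brief usage sketch for the traversal guard above (the archive name is hypothetical); safe_extract rejects any member whose resolved path would escape the destination directory, e.g. an entry containing '../':

    import tarfile

    with tarfile.open('restart_data.tar.gz', 'r:gz') as tar:
        safe_extract(tar, path='scratch/restart_data')
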
+ #
+ # Group and dataset manipulation functions
+ #
+
+
+ def create_hdf5_group(parent_group, groupname, replace=False, creating_program=None):
+     '''Create (or delete and recreate) an HDF5 group named ``groupname`` within
+     the enclosing Group (object) ``parent_group``. If ``replace`` is True, then
+     the group is replaced if present; if False, then an error is raised if the
+     group is present. After the group is created, HDF5 attributes are set using
+     `stamp_creator_data`.
+     '''
+
+     if replace:
+         try:
+             del parent_group[groupname]
+         except KeyError:
+             pass
+
+     newgroup = parent_group.create_group(groupname)
+     stamp_creator_data(newgroup, creating_program=creating_program)
+     return newgroup
+
+
+ #
+ # Group and dataset labeling functions
+ #
+
+
+ def stamp_creator_data(h5group, creating_program=None):
+     '''Mark the following on the HDF5 group ``h5group``:
+
+     :creation_program: The name of the program that created the group
+     :creation_user: The username of the user who created the group
+     :creation_hostname: The hostname of the machine on which the group was created
+     :creation_time: The date and time at which the group was created, in the
+         current locale.
+     :creation_unix_time: The Unix time (seconds from the epoch, UTC) at which the
+         group was created.
+
+     This is meant to facilitate tracking the flow of data, but should not be considered
+     a secure paper trail (after all, anyone with write access to the HDF5 file can modify
+     these attributes).
+     '''
+     now = time.time()
+     attrs = h5group.attrs
+
+     attrs['creation_program'] = creating_program or sys.argv[0] or 'unknown program'
+     attrs['creation_user'] = getpass.getuser()
+     attrs['creation_hostname'] = socket.gethostname()
+     attrs['creation_unix_time'] = now
+     attrs['creation_time'] = time.strftime('%c', time.localtime(now))
+
+
+ def get_creator_data(h5group):
+     '''Read back creator data as written by ``stamp_creator_data``, returning a dictionary with
+     keys as described for ``stamp_creator_data``. Missing fields are denoted with None.
+     The ``creation_time`` field is returned as a string.'''
+     attrs = h5group.attrs
+     d = dict()
+     for attr in ['creation_program', 'creation_user', 'creation_hostname', 'creation_unix_time', 'creation_time']:
+         d[attr] = attrs.get(attr)
+     return d
+
+
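A short round-trip sketch for the stamping helpers (file and group names are hypothetical):

    with h5py.File('scratch.h5', 'w') as f:
        g = create_hdf5_group(f, 'analysis', replace=True, creating_program='my_tool')
        info = get_creator_data(g)
        print(info['creation_program'], info['creation_user'])
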
+ def load_west(filename):
+     """Load WESTPA trajectory files from disk.
+
+     Parameters
+     ----------
+     filename : str
+         String filename of an HDF5 trajectory file.
+     """
+
+     with h5py.File(filename, 'r') as f:
+         iter_prec = f.attrs['west_iter_prec']
+         trajectories = []
+         n = 0
+
+         for iter_group_name in f['iterations']:
+             iter_group = f['iterations/' + iter_group_name]
+
+             if 'trajectories' in iter_group:
+                 traj_link = iter_group['trajectories']
+                 traj_filename = traj_link.file.filename
+
+                 with WESTIterationFile(traj_filename) as traj_file:
+                     traj = traj_file.read_as_traj()
+             else:
+                 # TODO: [HDF5] allow initializing trajectory without coordinates
+                 raise ValueError("Missing trajectories for iteration %d" % n)
+
+             # pcoord is required
+             if 'pcoord' not in iter_group:
+                 raise ValueError("Missing pcoords for iteration %d" % n)
+
+             raw_pcoord = iter_group['pcoord'][:]
+             if raw_pcoord.ndim != 3:
+                 log.warning('expected pcoord to be a 3-d ndarray but got {}-d; skipping iteration'.format(raw_pcoord.ndim))
+                 continue
+             # ignore the first frame of each segment
+             if raw_pcoord.shape[1] == traj.n_frames + 1:
+                 raw_pcoord = raw_pcoord[:, 1:, :]
+             elif raw_pcoord.shape[1] == traj.n_frames:
+                 raw_pcoord = raw_pcoord[:, :, :]
+             else:
+                 raise ValueError(
+                     "Inconsistent number of pcoords (%d) and frames (%d) for iteration %d" % (raw_pcoord.shape[1], traj.n_frames, n)
+                 )
+
+             pcoords = np.concatenate(raw_pcoord, axis=0)
+             n_frames = raw_pcoord.shape[1]
+
+             if 'seg_index' in iter_group:
+                 raw_pid = iter_group['seg_index']['parent_id'][:]
+
+                 if np.any(raw_pid < 0):
+                     init_basis_ids = iter_group['ibstates']['istate_index']['basis_state_id'][:]
+                     init_ids = -(raw_pid[raw_pid < 0] + 1)
+                     raw_pid[raw_pid < 0] = [init_basis_ids[iid] for iid in init_ids]
+                 parent_ids = raw_pid.repeat(n_frames, axis=0)
+             else:
+                 parent_ids = None
+
+             traj.pcoords = pcoords
+             traj.parent_ids = parent_ids
+             trajectories.append(traj)
+
+             n += 1
+
+     west_traj = join_traj(trajectories)
+
+     return west_traj
+
+
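A minimal usage sketch for load_west (assumes a hypothetical west.h5 written with HDF5 trajectory storage enabled, so that each iteration group links to its per-iteration trajectory file):

    from westpa.core.h5io import load_west

    traj = load_west('west.h5')       # an mdtraj-compatible WESTTrajectory
    print(traj.n_frames, traj.pcoords.shape)
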
+ ###
+ # Iteration range metadata
+ ###
+ def stamp_iter_range(h5object, start_iter, stop_iter):
+     '''Mark that the HDF5 object ``h5object`` (dataset or group) contains data from iterations
+     start_iter <= n_iter < stop_iter.'''
+     h5object.attrs['iter_start'] = start_iter
+     h5object.attrs['iter_stop'] = stop_iter
+
+
+ def get_iter_range(h5object):
+     '''Read back iteration range data written by ``stamp_iter_range``'''
+     return int(h5object.attrs['iter_start']), int(h5object.attrs['iter_stop'])
+
+
+ def stamp_iter_step(h5group, iter_step):
+     '''Mark that the HDF5 object ``h5group`` (dataset or group) contains data with an
+     iteration step (stride) of ``iter_step``.'''
+     h5group.attrs['iter_step'] = iter_step
+
+
+ def get_iter_step(h5group):
+     '''Read back iteration step (stride) written by ``stamp_iter_step``'''
+     return int(h5group.attrs['iter_step'])
+
+
+ def check_iter_range_least(h5object, iter_start, iter_stop):
+     '''Return True if the iteration range [iter_start, iter_stop) is
+     the same as or entirely contained within the iteration range stored
+     on ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     return obj_iter_start <= iter_start and obj_iter_stop >= iter_stop
+
+
+ def check_iter_range_equal(h5object, iter_start, iter_stop):
+     '''Return True if the iteration range [iter_start, iter_stop) is
+     the same as the iteration range stored on ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     return obj_iter_start == iter_start and obj_iter_stop == iter_stop
+
+
+ def get_iteration_entry(h5object, n_iter):
+     '''Create a slice for data corresponding to iteration ``n_iter`` in ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     if n_iter < obj_iter_start or n_iter >= obj_iter_stop:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(n_iter, h5object))
+     return np.index_exp[n_iter - obj_iter_start]
+
+
+ def get_iteration_slice(h5object, iter_start, iter_stop=None, iter_stride=None):
+     '''Create a slice for data corresponding to iterations [iter_start, iter_stop),
+     with stride ``iter_stride``, in the given ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+
+     if iter_stop is None:
+         iter_stop = iter_start + 1
+     if iter_stride is None:
+         iter_stride = 1
+
+     if iter_start < obj_iter_start:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
+     elif iter_start > obj_iter_stop:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
+
+     start_index = iter_start - obj_iter_start
+     stop_index = iter_stop - obj_iter_start
+     return np.index_exp[start_index:stop_index:iter_stride]
+
+
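A small sketch tying the stamping and slicing helpers together (hypothetical scratch file; a dataset covering iterations 1-50):

    with h5py.File('ranges.h5', 'w') as f:
        ds = f.create_dataset('rates', shape=(50, 10))
        stamp_iter_range(ds, 1, 51)                # iterations 1 <= n_iter < 51
        sel = get_iteration_slice(ds, 10, 20)      # iterations 10..19
        block = ds[sel]                            # shape (10, 10)
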
+ ###
+ # Axis label metadata
+ ###
+ def label_axes(h5object, labels, units=None):
+     '''Stamp the given HDF5 object with axis labels. This stores the axis labels
+     in an array of strings in an attribute called ``axis_labels`` on the given
+     object. ``units``, if provided, is a corresponding list of units.'''
+
+     if len(labels) != len(h5object.shape):
+         raise ValueError('number of axes and number of labels do not match')
+
+     if units is None:
+         units = []
+
+     if len(units) and len(units) != len(labels):
+         raise ValueError('number of units labels does not match number of axes')
+
+     h5object.attrs['axis_labels'] = np.array([np.bytes_(i) for i in labels])
+
+     if len(units):
+         h5object.attrs['axis_units'] = np.array([np.bytes_(i) for i in units])
+
+
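For instance, a sketch labeling a two-dimensional histogram dataset (names and units are hypothetical):

    with h5py.File('pdist.h5', 'w') as f:
        ds = f.create_dataset('histograms', shape=(100, 60))
        label_axes(ds, ['n_iter', 'pcoord_bin'], units=['', 'angstrom'])
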
+ NotGiven = object()
+
+
+ def _get_one_attr(h5object, namelist, default=NotGiven):
+     attrs = dict(h5object.attrs)
+     for name in namelist:
+         try:
+             return attrs[name]
+         except KeyError:
+             pass
+     else:
+         if default is NotGiven:
+             raise KeyError('no attribute named any of {!r}'.format(namelist))
+         else:
+             return default
+
+
+ class WESTPAH5File(h5py.File):
+     '''Generalized input/output for WESTPA simulation (or analysis) data.'''
+
+     default_iter_prec = 8
+     _this_fileformat_version = 8
+
+     def __init__(self, *args, **kwargs):
+         # These values are used for creating files or reading files where this
+         # data is not stored. Otherwise, values stored as attributes on the root
+         # group are used instead.
+         arg_iter_prec = kwargs.pop('westpa_iter_prec', self.default_iter_prec)
+         arg_fileformat_version = kwargs.pop('westpa_fileformat_version', self._this_fileformat_version)
+         arg_creating_program = kwargs.pop('creating_program', None)
+
+         # Initialize h5py file
+         super().__init__(*args, **kwargs)
+
+         # Try to get iteration precision and I/O class version
+         h5file_iter_prec = _get_one_attr(self, ['westpa_iter_prec', 'west_iter_prec', 'wemd_iter_prec'], None)
+         h5file_fileformat_version = _get_one_attr(
+             self, ['westpa_fileformat_version', 'west_file_format_version', 'wemd_file_format_version'], None
+         )
+
+         self.iter_prec = h5file_iter_prec if h5file_iter_prec is not None else arg_iter_prec
+         self.fileformat_version = h5file_fileformat_version if h5file_fileformat_version is not None else arg_fileformat_version
+
+         # Ensure that file format attributes are stored, if the file is writable
+         if self.mode == 'r+':
+             self.attrs['westpa_iter_prec'] = self.iter_prec
+             self.attrs['westpa_fileformat_version'] = self.fileformat_version
+             if arg_creating_program:
+                 stamp_creator_data(self, creating_program=arg_creating_program)
+
+     # Helper method to automatically replace a dataset, if it exists.
+     # Should really only be called when one is certain the dataset should be blown away.
+     def replace_dataset(self, *args, **kwargs):
+         try:
+             del self[args[0]]
+         except Exception:
+             pass
+         try:
+             del self[kwargs['name']]
+         except Exception:
+             pass
+
+         return self.create_dataset(*args, **kwargs)
+
+     # Iteration groups
+
+     def iter_object_name(self, n_iter, prefix='', suffix=''):
+         '''Return a properly-formatted per-iteration name for iteration
+         ``n_iter``. (This is used in create/require/get_iter_group, but may
+         also be useful for naming datasets on a per-iteration basis.)'''
+         return '{prefix}iter_{n_iter:0{prec}d}{suffix}'.format(n_iter=n_iter, prefix=prefix, suffix=suffix, prec=self.iter_prec)
+
+     def create_iter_group(self, n_iter, group=None):
+         '''Create a per-iteration data storage group for iteration number ``n_iter``
+         in the group ``group`` (which is '/iterations' by default).'''
+
+         if group is None:
+             group = self.require_group('/iterations')
+         return group.create_group(self.iter_object_name(n_iter))
+
+     def require_iter_group(self, n_iter, group=None):
+         '''Ensure that a per-iteration data storage group for iteration number ``n_iter``
+         is available in the group ``group`` (which is '/iterations' by default).'''
+         if group is None:
+             group = self.require_group('/iterations')
+         return group.require_group(self.iter_object_name(n_iter))
+
+     def get_iter_group(self, n_iter, group=None):
+         '''Get the per-iteration data group for iteration number ``n_iter`` from within
+         the group ``group`` ('/iterations' by default).'''
+         if group is None:
+             group = self['/iterations']
+         return group[self.iter_object_name(n_iter)]
+
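A compact sketch of the class above (hypothetical filename; the default iter_prec of 8 controls the zero-padding):

    f = WESTPAH5File('analysis.h5', 'w', creating_program='example')
    print(f.iter_object_name(7))    # 'iter_00000007'
    g = f.require_iter_group(7)     # creates /iterations/iter_00000007
    f.close()
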
+ class WESTIterationFile(HDF5TrajectoryFile):
+     def __init__(self, file, mode='r', force_overwrite=True, compression='zlib', link=None):
+         if isinstance(file, (str, Path)):
+             super(WESTIterationFile, self).__init__(file, mode, force_overwrite, compression)
+         else:
+             try:
+                 self._init_from_handle(file)  # if given an already-open file handle, just make sure it's usable
+             except AttributeError:
+                 raise ValueError('unknown input type: %s' % str(type(file)))
+
+     def _init_from_handle(self, handle: HDF5TrajectoryFile):
+         self._handle = handle
+         self._open = handle.isopen != 0
+         self.mode = mode = handle.mode  # the mode in which the file was opened
+
+         if mode not in ['r', 'w', 'a']:
+             raise ValueError("mode must be one of ['r', 'w', 'a']")
+
+         # import tables
+         self.tables = import_('tables')
+
+         if mode == 'w':
+             # what frame are we currently reading or writing at?
+             self._frame_index = 0
+             # do we need to write the header information?
+             self._needs_initialization = True
+
+         elif mode == 'a':
+             try:
+                 self._frame_index = len(self._handle.root.coordinates)
+                 self._needs_initialization = False
+             except self.tables.NoSuchNodeError:
+                 self._frame_index = 0
+                 self._needs_initialization = True
+         elif mode == 'r':
+             self._frame_index = 0
+             self._needs_initialization = False
+
+     def __contains__(self, path):
+         try:
+             self._get_node('/', path)
+         except self.tables.NoSuchNodeError:
+             return False
+
+         return True
+
+     def read(self, frame_indices=None, atom_indices=None):
+         _check_mode(self.mode, ('r',))
+
+         if frame_indices is None:
+             # read all frames
+             frame_slice = slice(None)
+             self._frame_index += self._handle.root.coordinates.shape[0]
+         else:
+             frame_slice = ensure_type(frame_indices, dtype=int, ndim=1, name='frame_indices', warn_on_cast=False)
+             if not np.all(frame_slice < self._handle.root.coordinates.shape[0]):
+                 raise ValueError(
+                     'As a zero-based index, the entries in '
+                     'frame_indices must all be less than the number of frames '
+                     'in the trajectory, %d' % self._handle.root.coordinates.shape[0]
+                 )
+             if not np.all(frame_slice >= 0):
+                 raise ValueError('The entries in frame_indices must be greater than or equal to zero')
+             self._frame_index += frame_slice[-1] - frame_slice[0]
+
+         if atom_indices is None:
+             # get all of the atoms
+             atom_slice = slice(None)
+         else:
+             atom_slice = ensure_type(atom_indices, dtype=int, ndim=1, name='atom_indices', warn_on_cast=False)
+             if not np.all(atom_slice < self._handle.root.coordinates.shape[1]):
+                 raise ValueError(
+                     'As a zero-based index, the entries in '
+                     'atom_indices must all be less than the number of atoms '
+                     'in the trajectory, %d' % self._handle.root.coordinates.shape[1]
+                 )
+             if not np.all(atom_slice >= 0):
+                 raise ValueError('The entries in atom_indices must be greater than or equal to zero')
+
+         def get_item(node, key):
+             if not isinstance(key, tuple):
+                 return node.__getitem__(key)
+
+             n_list_like = 0
+             new_keys = []
+             for item in key:
+                 if not isinstance(item, slice):
+                     try:
+                         d = np.diff(item)
+                         if len(d) == 0:
+                             item = item[0]
+                         elif np.all(d == d[0]):
+                             item = slice(item[0], item[-1] + d[0], d[0])
+                         else:
+                             n_list_like += 1
+                     except Exception:
+                         n_list_like += 1
+                 new_keys.append(item)
+             new_keys = tuple(new_keys)
+
+             if n_list_like <= 1:
+                 return node.__getitem__(new_keys)
+
+             data = node
+             for i, item in enumerate(new_keys):
+                 dkey = [slice(None)] * len(key)
+                 dkey[i] = item
+                 dkey = tuple(dkey)
+                 data = data.__getitem__(dkey)
+
+             return data
+
+         def get_field(name, slice, out_units, can_be_none=True):
+             try:
+                 node = self._get_node(where='/', name=name)
+                 data = get_item(node, slice)
+                 in_units = node.attrs.units
+                 if not isinstance(in_units, str):
+                     in_units = in_units.decode()
+                 data = in_units_of(data, in_units, out_units)
+                 return data
+             except self.tables.NoSuchNodeError:
+                 if can_be_none:
+                     return None
+                 raise
+
+         frames = Frames(
+             coordinates=get_field('coordinates', (frame_slice, atom_slice, slice(None)), out_units='nanometers', can_be_none=False),
+             time=get_field('time', frame_slice, out_units='picoseconds'),
+             cell_lengths=get_field('cell_lengths', (frame_slice, slice(None)), out_units='nanometers'),
+             cell_angles=get_field('cell_angles', (frame_slice, slice(None)), out_units='degrees'),
+             velocities=get_field('velocities', (frame_slice, atom_slice, slice(None)), out_units='nanometers/picosecond'),
+             kineticEnergy=get_field('kineticEnergy', frame_slice, out_units='kilojoules_per_mole'),
+             potentialEnergy=get_field('potentialEnergy', frame_slice, out_units='kilojoules_per_mole'),
+             temperature=get_field('temperature', frame_slice, out_units='kelvin'),
+             alchemicalLambda=get_field('lambda', frame_slice, out_units='dimensionless'),
+         )
+
+         return frames
+
+     def _has_node(self, where, name):
+         try:
+             self._get_node(where, name=name)
+         except self.tables.NoSuchNodeError:
+             return False
+
+         return True
+
+     def has_topology(self):
+         return self._has_node('/', 'topology')
+
+     def has_pointer(self):
+         return self._has_node('/', 'pointer')
+
+     def has_restart(self, segment):
+         return self._has_node('/restart', '%d_%d' % (segment.n_iter, segment.seg_id))
+
+     def has_seglog(self, segment):
+         return self._has_node('/log', '%d_%d' % (segment.n_iter, segment.seg_id))
+
+     def write_data(self, where, name, data):
+         node = self._get_node(where=where, name=name)
+         node.append(data)
+
+     def read_data(self, where, name):
+         node = self._get_node(where=where, name=name)
+         return node.read()
+
+     def read_as_traj(self, iteration=None, segment=None, atom_indices=None):
+         _check_mode(self.mode, ('r',))
+
+         pnode = self._get_node(where='/', name='pointer')
+
+         iter_labels = pnode[:, 0]
+         seg_labels = pnode[:, 1]
+
+         if iteration is None and segment is None:
+             # select every frame (an explicit index array, so that len() and
+             # the fancy indexing below work uniformly)
+             frame_indices = np.arange(len(iter_labels))
+         elif isinstance(iteration, (np.integer, int)) and isinstance(segment, (np.integer, int)):
+             frame_torf = np.logical_and(iter_labels == iteration, seg_labels == segment)
+             frame_indices = np.arange(len(iter_labels))[frame_torf]
+         else:
+             raise ValueError("iteration and segment must be integers and provided at the same time")
+
+         if len(frame_indices) == 0:
+             raise ValueError(f"no frame was selected: iteration={iteration}, segment={segment}, atom_indices={atom_indices}")
+
+         iter_labels = iter_labels[frame_indices]
+         seg_labels = seg_labels[frame_indices]
+
+         topology = self.topology
+         if atom_indices is not None:
+             topology = topology.subset(atom_indices)
+
+         data = self.read(frame_indices=frame_indices, atom_indices=atom_indices)
+         if len(data) == 0:
+             return Trajectory(xyz=np.zeros((0, topology.n_atoms, 3)), topology=topology)
+
+         in_units_of(data.coordinates, self.distance_unit, Trajectory._distance_unit, inplace=True)
+         in_units_of(data.cell_lengths, self.distance_unit, Trajectory._distance_unit, inplace=True)
+
+         return WESTTrajectory(
+             data.coordinates,
+             topology=topology,
+             time=data.time,
+             unitcell_lengths=data.cell_lengths,
+             unitcell_angles=data.cell_angles,
+             iter_labels=iter_labels,
+             seg_labels=seg_labels,
+             pcoords=None,
+         )
+
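A reading sketch for a single segment (the per-iteration filename is hypothetical):

    with WESTIterationFile('iter_000001.h5', 'r') as itf:
        seg_traj = itf.read_as_traj(iteration=1, segment=0)
        print(seg_traj.n_frames, seg_traj.iter_labels[0])
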
+     def read_restart(self, segment):
+         if self.has_restart(segment):
+             data = self.read_data('/restart/%d_%d' % (segment.n_iter, segment.seg_id), 'data')
+             segment.data['iterh5/restart'] = data
+         else:
+             raise ValueError('no restart data available for {}'.format(str(segment)))
+
+     def read_seglog(self, segment):
+         if self.has_seglog(segment):
+             data = self.read_data('/log/%d_%d' % (segment.n_iter, segment.seg_id), 'data')
+             segment.data['iterh5/log'] = data
+         else:
+             raise ValueError('no log data available for {}'.format(str(segment)))
+
+     def write_segment(self, segment, pop=False):
+         n_iter = segment.n_iter
+
+         self.root._v_attrs['n_iter'] = n_iter
+
+         if pop:
+             get_data = segment.data.pop
+         else:
+             get_data = segment.data.get
+
+         traj = get_data('iterh5/trajectory', None)
+         restart = get_data('iterh5/restart', None)
+         slog = get_data('iterh5/log', None)
+
+         if traj is not None:
+             # create a trajectory object, or skip if it already is one
+             if not isinstance(traj, WESTTrajectory):
+                 traj = WESTTrajectory(traj, iter_labels=n_iter, seg_labels=segment.seg_id)
+             if traj.n_frames == 0:
+                 # we may consider logging warnings instead of throwing errors later;
+                 # right now this is good for debugging purposes
+                 raise ValueError('no trajectory data present for %s' % repr(segment))
+
+             if n_iter == 0:
+                 base_time = 0
+             else:
+                 iter_duration = traj.time[-1] - traj.time[0]
+                 base_time = iter_duration * (n_iter - 1)
+
+             traj.time -= traj.time[0]
+             traj.time += base_time
+
+             # pointers
+             if not self.has_pointer():
+                 self._create_earray('/', name='pointer', atom=self.tables.Int64Atom(), shape=(0, 2))
+
+             try:
+                 existing_labels = np.where(self.root['pointer'][:, 1] == traj.seg_labels[0])[0]
+             except ValueError:
+                 existing_labels = []
+
+             iter_idx = traj.iter_labels
+             seg_idx = traj.seg_labels
+
+             pointers = np.stack((iter_idx, seg_idx)).T
+             # This is to deal with the case where the simulation ended mid-iteration and some segment data were saved but the segment itself was not yet marked as complete.
+             if [segment.seg_id] in self.root['pointer'][:, 1]:
+                 needed_extra = len(pointers) - len(existing_labels)
+                 # If the previous run did not save all frames
+                 if needed_extra > 0:
+                     # Write extra rows for extra frames
+                     self.write_data('/', 'pointer', pointers[-needed_extra:])
+
+                     # write trajectory for the extra rows
+                     output_dict = {
+                         'coordinates': in_units_of(traj.xyz[-needed_extra:], Trajectory._distance_unit, self.distance_unit),
+                         'time': traj.time[-needed_extra:],
+                     }
+                     if traj.unitcell_lengths is not None:
+                         output_dict['cell_lengths'] = in_units_of(
+                             traj.unitcell_lengths[-needed_extra:], Trajectory._distance_unit, self.distance_unit
+                         )
+                         output_dict['cell_angles'] = traj.unitcell_angles[-needed_extra:]
+
+                     self.write(**output_dict)
+                     needed_extra = len(existing_labels)
+                 elif needed_extra < 0:
+                     # Extra frames found; point those rows at a sentinel value instead
+                     log.warning(
+                         f'Extra frames for segment {n_iter}_{segment.seg_id} found in WESTIterationFile. Overwriting extra frame pointers with sentinel [-n_iter, seg_id].'
+                     )
+                     for row_idx in range(needed_extra, 0):
+                         self.root['pointer'][existing_labels[row_idx]] = [-n_iter, segment.seg_id]
+                     needed_extra = len(pointers)
+                 else:
+                     # Numbers of frames match; None will select all frames from both pointers and traj.
+                     needed_extra = None
+
+                 # Replace existing rows with corresponding data, up until specified
+                 self.replace_frames(existing_labels[:needed_extra], traj[:needed_extra])
+
+             else:
+                 # Write pointers
+                 self.write_data('/', 'pointer', pointers)
+
+                 # trajectory
+                 self.write(
+                     coordinates=in_units_of(traj.xyz, Trajectory._distance_unit, self.distance_unit),
+                     time=traj.time,
+                     cell_lengths=in_units_of(traj.unitcell_lengths, Trajectory._distance_unit, self.distance_unit),
+                     cell_angles=traj.unitcell_angles,
+                 )
+
+             # topology
+             try:
+                 if self.mode == 'a':
+                     if not self.has_topology():
+                         self.topology = traj.topology
+                 elif self.mode == 'w':
+                     self.topology = traj.topology
+             except (ModuleNotFoundError, ImportError):
+                 pass
+
+         # restart
+         if restart is not None:
+             if self.has_restart(segment):
+                 self._remove_node('/restart', name='%d_%d' % (segment.n_iter, segment.seg_id), recursive=True)
+
+             self._create_array(
+                 '/restart/%d_%d' % (segment.n_iter, segment.seg_id),
+                 name='data',
+                 atom=self.tables.StringAtom(itemsize=len(restart)),
+                 obj=restart,
+                 createparents=True,
+             )
+
+         if slog is not None:
+             if self._has_node('/log', name='%d_%d' % (segment.n_iter, segment.seg_id)):
+                 self._remove_node('/log', name='%d_%d' % (segment.n_iter, segment.seg_id), recursive=True)
+
+             self._create_array(
+                 '/log/%d_%d' % (segment.n_iter, segment.seg_id),
+                 name='data',
+                 atom=self.tables.StringAtom(itemsize=len(slog)),
+                 obj=slog,
+                 createparents=True,
+             )
+
+     def scrub_data(self):
+         '''Method to remove existing coordinates, pointers, etc. while preserving topology'''
+         for node in ['log', 'restart', 'time', 'coordinates', 'pointer', 'cell_angles', 'cell_lengths']:
+             try:
+                 self._remove_node('/', node, recursive=True)
+             except self.tables.exceptions.NoSuchNodeError:
+                 pass
+         self._frame_index = 0
+         self.root._v_attrs.n_iter = 0
+         self.flush()
+
+     def replace_frames(self, rows, traj):
+         datasets = {'coordinates': 'xyz', 'time': 'time', 'cell_angles': 'unitcell_angles', 'cell_lengths': 'unitcell_lengths'}
+
+         for ptkey, mdkey in datasets.items():
+             if self._has_node('/', ptkey) and getattr(traj, mdkey) is not None:
+                 for frame_idx, row in enumerate(rows):
+                     self.root[ptkey][row] = getattr(traj, mdkey)[frame_idx]
+
+     @property
+     def _create_group(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.create_group
+         return self._handle.createGroup
+
+     @property
+     def _create_array(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.create_array
+         return self._handle.createArray
+
+     @property
+     def _remove_node(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.remove_node
+         return self._handle.removeNode
+
+
+ ### Generalized WE dataset access classes
+
+
+ class DSSpec:
+     '''Generalized WE dataset access'''
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         raise NotImplementedError
+
+     def get_segment_data(self, n_iter, seg_id):
+         return self.get_iter_data(n_iter)[seg_id]
+
+     def __getstate__(self):
+         d = dict(self.__dict__)
+         if '_h5file' in d:
+             d['_h5file'] = None
+         return d
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+
+
+ class FileLinkedDSSpec(DSSpec):
+     '''Provide facilities for accessing WESTPA HDF5 files, including auto-opening and the ability
+     to pickle references to such files for transmission (through, e.g., the work manager), provided
+     that the HDF5 file can be accessed by the same path on both the sender and receiver.'''
+
+     def __init__(self, h5file_or_name):
+         self._h5file = None
+         self._h5filename = None
+
+         try:
+             self._h5filename = os.path.abspath(h5file_or_name.filename)
+         except AttributeError:
+             self._h5filename = h5file_or_name
+             self._h5file = None
+         else:
+             self._h5file = h5file_or_name
+
+     @property
+     def h5file(self):
+         '''Lazily open HDF5 file. This is required because allowing an open HDF5
+         file to cross a fork() boundary generally corrupts the internal state of
+         the HDF5 library.'''
+         if self._h5file is None:
+             self._h5file = WESTPAH5File(self._h5filename, 'r')
+         return self._h5file
+
+
+ class SingleDSSpec(FileLinkedDSSpec):
+     @classmethod
+     def from_string(cls, dsspec_string, default_h5file):
+         alias = None
+
+         h5file = default_h5file
+         fields = dsspec_string.split(',')
+         dsname = fields[0]
+         slice = None
+
+         for field in (field.strip() for field in fields[1:]):
+             k, v = field.split('=')
+             k = k.lower()
+             if k == 'alias':
+                 alias = v
+             elif k == 'slice':
+                 try:
+                     slice = eval('np.index_exp' + v)
+                 except SyntaxError:
+                     raise SyntaxError('invalid index expression {!r}'.format(v))
+             elif k == 'file':
+                 h5file = v
+             else:
+                 raise ValueError('invalid dataset option {!r}'.format(k))
+
+         return cls(h5file, dsname, alias, slice)
+
+
+     def __init__(self, h5file_or_name, dsname, alias=None, slice=None):
+         FileLinkedDSSpec.__init__(self, h5file_or_name)
+         self.dsname = dsname
+         self.alias = alias or dsname
+         self.slice = np.index_exp[slice] if slice else None
+
+ class SingleIterDSSpec(SingleDSSpec):
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]
+
+
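A quick sketch of the dsspec string format parsed by from_string above (dataset and file names are hypothetical):

    # Select column 0 of each segment's pcoord for a given iteration.
    spec = SingleIterDSSpec.from_string('pcoord,slice=[:, 0]', 'west.h5')
    data = spec.get_iter_data(10)    # equivalent to pcoord[:, :, 0] for iteration 10
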
+ class SingleSegmentDSSpec(SingleDSSpec):
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + index_exp[:] + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]
+
+     def get_segment_data(self, n_iter, seg_id):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][np.index_exp[seg_id, :] + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_id]
+
+
+ class FnDSSpec(FileLinkedDSSpec):
+     def __init__(self, h5file_or_name, fn):
+         FileLinkedDSSpec.__init__(self, h5file_or_name)
+         self.fn = fn
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         return self.fn(n_iter, self.h5file.get_iter_group(n_iter))[seg_slice]
+
+
+ class MultiDSSpec(DSSpec):
+     def __init__(self, dsspecs):
+         self.dsspecs = dsspecs
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         datasets = [dsspec.get_iter_data(n_iter) for dsspec in self.dsspecs]
+
+         ncols = 0
+         nsegs = None
+         npts = None
+         for iset, dset in enumerate(datasets):
+             if nsegs is None:
+                 nsegs = dset.shape[0]
+             elif dset.shape[0] != nsegs:
+                 raise TypeError('dataset {} has incorrect first dimension (number of segments)'.format(self.dsspecs[iset]))
+             if npts is None:
+                 npts = dset.shape[1]
+             elif dset.shape[1] != npts:
+                 raise TypeError('dataset {} has incorrect second dimension (number of time points)'.format(self.dsspecs[iset]))
+
+             if dset.ndim < 2:
+                 # scalar per segment or scalar per iteration
+                 raise TypeError('dataset {} has too few dimensions'.format(self.dsspecs[iset]))
+             elif dset.ndim > 3:
+                 # array per timepoint
+                 raise TypeError('dataset {} has too many dimensions'.format(self.dsspecs[iset]))
+             elif dset.ndim == 2:
+                 # scalar per timepoint
+                 ncols += 1
+             else:
+                 # vector per timepoint
+                 ncols += dset.shape[-1]
+
+         output_dtype = np.result_type(*[ds.dtype for ds in datasets])
+         output_array = np.empty((nsegs, npts, ncols), dtype=output_dtype)
+
+         ocol = 0
+         for iset, dset in enumerate(datasets):
+             if dset.ndim == 2:
+                 output_array[:, :, ocol] = dset[...]
+                 ocol += 1
+             elif dset.ndim == 3:
+                 output_array[:, :, ocol : (ocol + dset.shape[-1])] = dset[...]
+                 ocol += dset.shape[-1]
+
+         return output_array[seg_slice]
+
+
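A combination sketch: MultiDSSpec stacks per-timepoint scalars and vectors into one (nsegs, npts, ncols) array (dataset names are hypothetical; 'pcoord' is (nsegs, npts, ndim) and 'auxdata/rmsd' is (nsegs, npts)):

    multi = MultiDSSpec([
        SingleSegmentDSSpec('west.h5', 'pcoord'),
        SingleSegmentDSSpec('west.h5', 'auxdata/rmsd'),
    ])
    combined = multi.get_iter_data(5)    # shape (nsegs, npts, ndim + 1)
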
+ class IterBlockedDataset:
+     @classmethod
+     def empty_like(cls, blocked_dataset):
+         source = blocked_dataset.data if blocked_dataset.data is not None else blocked_dataset.dataset
+
+         newbds = cls(
+             np.empty(source.shape, source.dtype),
+             attrs={'iter_start': blocked_dataset.iter_start, 'iter_stop': blocked_dataset.iter_stop},
+         )
+         return newbds
+
+     def __init__(self, dataset_or_array, attrs=None):
+         try:
+             dataset_or_array.attrs
+         except AttributeError:
+             self.dataset = None
+             self.data = dataset_or_array
+             if attrs is None:
+                 raise ValueError('attribute dictionary containing iteration bounds must be provided')
+             self.iter_shape = self.data.shape[1:]
+             self.dtype = self.data.dtype
+         else:
+             self.dataset = dataset_or_array
+             attrs = self.dataset.attrs
+             self.data = None
+             self.iter_shape = self.dataset.shape[1:]
+             self.dtype = self.dataset.dtype
+
+         self.iter_start = attrs['iter_start']
+         self.iter_stop = attrs['iter_stop']
+
+     def cache_data(self, max_size=None):
+         '''Cache this dataset in RAM. If ``max_size`` is given, then only cache if the entire dataset
+         fits in ``max_size`` bytes. If ``max_size`` is the string 'available', then only cache if
+         the entire dataset fits in available RAM, as defined by the ``psutil`` module.'''
+
+         if max_size is not None:
+             dssize = self.dtype.itemsize * np.multiply.reduce(self.dataset.shape)
+             if max_size == 'available' and psutil is not None:
+                 avail_bytes = psutil.virtual_memory().available
+                 if dssize > avail_bytes:
+                     return
+             elif isinstance(max_size, str):
+                 # 'available' was requested but psutil is missing; be conservative and skip caching
+                 return
+             else:
+                 if dssize > max_size:
+                     return
+         if self.dataset is not None:
+             if self.data is None:
+                 self.data = self.dataset[...]
+
+     def drop_cache(self):
+         if self.dataset is not None:
+             del self.data
+             self.data = None
+
+     def iter_entry(self, n_iter):
+         if n_iter < self.iter_start:
+             raise IndexError('requested iteration {} less than first stored iteration {}'.format(n_iter, self.iter_start))
+
+         source = self.data if self.data is not None else self.dataset
+         return source[n_iter - self.iter_start]
+
+     def iter_slice(self, start=None, stop=None):
+         start = start or self.iter_start
+         stop = stop or self.iter_stop
+         step = 1  # strided retrieval not implemented yet
+
+         # if step % self.iter_step > 0:
+         #     raise TypeError('dataset {!r} stored with stride {} cannot be accessed with stride {}'
+         #                     .format(self.dataset, self.iter_step, step))
+         if start < self.iter_start:
+             raise IndexError('requested start {} less than stored start {}'.format(start, self.iter_start))
+         elif stop > self.iter_stop:
+             stop = self.iter_stop
+
+         source = self.data if self.data is not None else self.dataset
+         return source[start - self.iter_start : stop - self.iter_start : step]
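Closing with a usage sketch for IterBlockedDataset (file and dataset names are hypothetical; the dataset is assumed to carry iter_start/iter_stop attributes as written by stamp_iter_range):

    f = WESTPAH5File('analysis.h5', 'r')
    bds = IterBlockedDataset(f['avg_rates'])
    bds.cache_data(max_size='available')    # cache only if it fits in free RAM
    rates = bds.iter_slice(10, 20)          # iterations 10..19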