westpa 2022.9__cp38-cp38-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (150)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +746 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +360 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +59 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +596 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +547 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +378 -0
  25. westpa/cli/tools/w_ipa.py +833 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +486 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_trace.py +599 -0
  38. westpa/core/__init__.py +0 -0
  39. westpa/core/_rc.py +673 -0
  40. westpa/core/binning/__init__.py +55 -0
  41. westpa/core/binning/_assign.cpython-38-darwin.so +0 -0
  42. westpa/core/binning/assign.py +449 -0
  43. westpa/core/binning/binless.py +96 -0
  44. westpa/core/binning/binless_driver.py +54 -0
  45. westpa/core/binning/binless_manager.py +190 -0
  46. westpa/core/binning/bins.py +47 -0
  47. westpa/core/binning/mab.py +427 -0
  48. westpa/core/binning/mab_driver.py +54 -0
  49. westpa/core/binning/mab_manager.py +198 -0
  50. westpa/core/data_manager.py +1694 -0
  51. westpa/core/extloader.py +74 -0
  52. westpa/core/h5io.py +996 -0
  53. westpa/core/kinetics/__init__.py +24 -0
  54. westpa/core/kinetics/_kinetics.cpython-38-darwin.so +0 -0
  55. westpa/core/kinetics/events.py +147 -0
  56. westpa/core/kinetics/matrates.py +156 -0
  57. westpa/core/kinetics/rate_averaging.py +266 -0
  58. westpa/core/progress.py +218 -0
  59. westpa/core/propagators/__init__.py +54 -0
  60. westpa/core/propagators/executable.py +715 -0
  61. westpa/core/reweight/__init__.py +14 -0
  62. westpa/core/reweight/_reweight.cpython-38-darwin.so +0 -0
  63. westpa/core/reweight/matrix.py +126 -0
  64. westpa/core/segment.py +119 -0
  65. westpa/core/sim_manager.py +830 -0
  66. westpa/core/states.py +359 -0
  67. westpa/core/systems.py +93 -0
  68. westpa/core/textio.py +74 -0
  69. westpa/core/trajectory.py +330 -0
  70. westpa/core/we_driver.py +908 -0
  71. westpa/core/wm_ops.py +43 -0
  72. westpa/core/yamlcfg.py +391 -0
  73. westpa/fasthist/__init__.py +34 -0
  74. westpa/fasthist/__main__.py +110 -0
  75. westpa/fasthist/_fasthist.cpython-38-darwin.so +0 -0
  76. westpa/mclib/__init__.py +264 -0
  77. westpa/mclib/__main__.py +28 -0
  78. westpa/mclib/_mclib.cpython-38-darwin.so +0 -0
  79. westpa/oldtools/__init__.py +4 -0
  80. westpa/oldtools/aframe/__init__.py +35 -0
  81. westpa/oldtools/aframe/atool.py +75 -0
  82. westpa/oldtools/aframe/base_mixin.py +26 -0
  83. westpa/oldtools/aframe/binning.py +178 -0
  84. westpa/oldtools/aframe/data_reader.py +560 -0
  85. westpa/oldtools/aframe/iter_range.py +200 -0
  86. westpa/oldtools/aframe/kinetics.py +117 -0
  87. westpa/oldtools/aframe/mcbs.py +146 -0
  88. westpa/oldtools/aframe/output.py +39 -0
  89. westpa/oldtools/aframe/plotting.py +90 -0
  90. westpa/oldtools/aframe/trajwalker.py +126 -0
  91. westpa/oldtools/aframe/transitions.py +469 -0
  92. westpa/oldtools/cmds/__init__.py +0 -0
  93. westpa/oldtools/cmds/w_ttimes.py +358 -0
  94. westpa/oldtools/files.py +34 -0
  95. westpa/oldtools/miscfn.py +23 -0
  96. westpa/oldtools/stats/__init__.py +4 -0
  97. westpa/oldtools/stats/accumulator.py +35 -0
  98. westpa/oldtools/stats/edfs.py +129 -0
  99. westpa/oldtools/stats/mcbs.py +89 -0
  100. westpa/tools/__init__.py +33 -0
  101. westpa/tools/binning.py +472 -0
  102. westpa/tools/core.py +340 -0
  103. westpa/tools/data_reader.py +159 -0
  104. westpa/tools/dtypes.py +31 -0
  105. westpa/tools/iter_range.py +198 -0
  106. westpa/tools/kinetics_tool.py +340 -0
  107. westpa/tools/plot.py +283 -0
  108. westpa/tools/progress.py +17 -0
  109. westpa/tools/selected_segs.py +154 -0
  110. westpa/tools/wipi.py +751 -0
  111. westpa/trajtree/__init__.py +4 -0
  112. westpa/trajtree/_trajtree.cpython-38-darwin.so +0 -0
  113. westpa/trajtree/trajtree.py +117 -0
  114. westpa/westext/__init__.py +0 -0
  115. westpa/westext/adaptvoronoi/__init__.py +3 -0
  116. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  117. westpa/westext/hamsm_restarting/__init__.py +3 -0
  118. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  119. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  120. westpa/westext/stringmethod/__init__.py +11 -0
  121. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  122. westpa/westext/stringmethod/string_driver.py +253 -0
  123. westpa/westext/stringmethod/string_method.py +306 -0
  124. westpa/westext/weed/BinCluster.py +180 -0
  125. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  126. westpa/westext/weed/UncertMath.py +247 -0
  127. westpa/westext/weed/__init__.py +10 -0
  128. westpa/westext/weed/weed_driver.py +182 -0
  129. westpa/westext/wess/ProbAdjust.py +101 -0
  130. westpa/westext/wess/__init__.py +6 -0
  131. westpa/westext/wess/wess_driver.py +207 -0
  132. westpa/work_managers/__init__.py +57 -0
  133. westpa/work_managers/core.py +396 -0
  134. westpa/work_managers/environment.py +134 -0
  135. westpa/work_managers/mpi.py +318 -0
  136. westpa/work_managers/processes.py +187 -0
  137. westpa/work_managers/serial.py +28 -0
  138. westpa/work_managers/threads.py +79 -0
  139. westpa/work_managers/zeromq/__init__.py +20 -0
  140. westpa/work_managers/zeromq/core.py +641 -0
  141. westpa/work_managers/zeromq/node.py +131 -0
  142. westpa/work_managers/zeromq/work_manager.py +526 -0
  143. westpa/work_managers/zeromq/worker.py +320 -0
  144. westpa-2022.9.dist-info/AUTHORS +22 -0
  145. westpa-2022.9.dist-info/LICENSE +21 -0
  146. westpa-2022.9.dist-info/METADATA +183 -0
  147. westpa-2022.9.dist-info/RECORD +150 -0
  148. westpa-2022.9.dist-info/WHEEL +5 -0
  149. westpa-2022.9.dist-info/entry_points.txt +29 -0
  150. westpa-2022.9.dist-info/top_level.txt +1 -0
westpa/core/h5io.py ADDED
@@ -0,0 +1,996 @@
+ '''Miscellaneous routines to help with HDF5 input and output of WEST-related data.'''
+
+ import collections
+ import errno
+ import getpass
+ import os
+ import posixpath
+ import socket
+ import sys
+ import time
+ import logging
+ import warnings
+
+ import h5py
+ import numpy as np
+ from numpy import index_exp
+ from tables import NaturalNameWarning
+
+ from mdtraj import Trajectory, join as join_traj
+ from mdtraj.utils import in_units_of, import_, ensure_type
+ from mdtraj.utils.six import string_types
+ from mdtraj.formats import HDF5TrajectoryFile
+ from mdtraj.formats.hdf5 import _check_mode, Frames
+
+ from .trajectory import WESTTrajectory
+
+ try:
+     import psutil
+ except ImportError:
+     psutil = None
+
+ log = logging.getLogger(__name__)
+ warnings.filterwarnings('ignore', category=NaturalNameWarning)
+
+ #
+ # Constants and globals
+ #
+ default_iter_prec = 8
+
+ #
+ # Helper functions
+ #
+
+
+ def resolve_filepath(path, constructor=h5py.File, cargs=None, ckwargs=None, **addtlkwargs):
+     '''Use a combined filesystem and HDF5 path to open an HDF5 file and return the
+     appropriate object. Returns (h5file, h5object). The file is opened using
+     ``constructor(filename, *cargs, **ckwargs)``.'''
+
+     cargs = cargs or ()
+     ckwargs = ckwargs or {}
+     ckwargs.update(addtlkwargs)
+     objpieces = collections.deque()
+     path = posixpath.normpath(path)
+
+     filepieces = path.split('/')
+     while filepieces:
+         testpath = '/'.join(filepieces)
+         if not testpath:
+             filepieces.pop()
+             continue
+         try:
+             h5file = constructor(testpath, *cargs, **ckwargs)
+         except IOError:
+             objpieces.appendleft(filepieces.pop())
+             continue
+         else:
+             return (h5file, h5file['/'.join([''] + list(objpieces)) if objpieces else '/'])
+     else:
+         # We don't provide a filename, because we're not sure where the filename stops
+         # and the HDF5 path begins.
+         raise IOError(errno.ENOENT, os.strerror(errno.ENOENT))
+
+
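resolve_filepath tries the full string as a filename first, then pops components off the right on IOError until a prefix opens, treating the remainder as the in-file HDF5 path. A minimal usage sketch (not part of this diff; the file and group names are hypothetical):

    # Opens ./analysis.h5, then resolves the remaining HDF5 path inside it;
    # returns (open file, the /iterations/iter_00000001 group object)
    h5file, h5object = resolve_filepath('analysis.h5/iterations/iter_00000001')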
+ def calc_chunksize(shape, dtype, max_chunksize=262144):
+     '''Calculate a chunk size for HDF5 data, anticipating that access will slice
+     along lower dimensions sooner than higher dimensions.'''
+
+     chunk_shape = list(shape)
+     dtype = np.dtype(dtype)
+     for idim in range(len(shape)):
+         chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+         while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
+             chunk_shape[idim] >>= 1  # divide by 2
+             chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+
+         if chunk_nbytes <= max_chunksize:
+             break
+
+     chunk_shape = tuple(chunk_shape)
+     return chunk_shape
+
+
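The loop halves the leading axis first, so each chunk keeps whole "rows" of the trailing axes contiguous. A worked sketch (not part of this diff):

    # 1000 segments x 100 time points x 3 dims of float64 is 2,400,000 bytes;
    # halving the first axis 1000 -> 500 -> 250 -> 125 -> 62 brings one chunk
    # down to 62*100*3*8 = 148,800 bytes <= 262144
    assert calc_chunksize((1000, 100, 3), np.float64) == (62, 100, 3)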
+ def tostr(b):
+     '''Convert a nonstandard string object ``b`` to str, handling the case
+     where ``b`` is bytes.'''
+
+     if b is None:
+         return None
+     elif isinstance(b, bytes):
+         return b.decode('utf-8')
+     else:
+         return str(b)
+
+
+ def is_within_directory(directory, target):
+     abs_directory = os.path.abspath(directory)
+     abs_target = os.path.abspath(target)
+
+     prefix = os.path.commonprefix([abs_directory, abs_target])
+
+     return prefix == abs_directory
+
+
+ def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+     for member in tar.getmembers():
+         member_path = os.path.join(path, member.name)
+         if not is_within_directory(path, member_path):
+             raise Exception("Attempted Path Traversal in Tar File")
+
+     tar.extractall(path, members, numeric_owner=numeric_owner)
+
+
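is_within_directory/safe_extract are the standard guard against tar path traversal (archive members like ../../etc/passwd escaping the extraction directory). A usage sketch (not part of this diff; the archive name is hypothetical):

    import tarfile

    # Raises instead of writing outside ./restart_data
    with tarfile.open('restart_data.tar.gz') as tar:
        safe_extract(tar, path='restart_data')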
+ #
+ # Group and dataset manipulation functions
+ #
+
+
+ def create_hdf5_group(parent_group, groupname, replace=False, creating_program=None):
+     '''Create (or delete and recreate) an HDF5 group named ``groupname`` within
+     the enclosing Group (object) ``parent_group``. If ``replace`` is True, then
+     the group is replaced if present; if False, then an error is raised if the
+     group is present. After the group is created, HDF5 attributes are set using
+     `stamp_creator_data`.
+     '''
+
+     if replace:
+         try:
+             del parent_group[groupname]
+         except KeyError:
+             pass
+
+     newgroup = parent_group.create_group(groupname)
+     stamp_creator_data(newgroup, creating_program=creating_program)
+     return newgroup
+
+
+ #
+ # Group and dataset labeling functions
+ #
+
+
+ def stamp_creator_data(h5group, creating_program=None):
+     '''Mark the following on the HDF5 group ``h5group``:
+
+     :creation_program: The name of the program that created the group
+     :creation_user: The username of the user who created the group
+     :creation_hostname: The hostname of the machine on which the group was created
+     :creation_time: The date and time at which the group was created, in the
+       current locale.
+     :creation_unix_time: The Unix time (seconds from the epoch, UTC) at which the
+       group was created.
+
+     This is meant to facilitate tracking the flow of data, but should not be considered
+     a secure paper trail (after all, anyone with write access to the HDF5 file can modify
+     these attributes).
+     '''
+     now = time.time()
+     attrs = h5group.attrs
+
+     attrs['creation_program'] = creating_program or sys.argv[0] or 'unknown program'
+     attrs['creation_user'] = getpass.getuser()
+     attrs['creation_hostname'] = socket.gethostname()
+     attrs['creation_unix_time'] = now
+     attrs['creation_time'] = time.strftime('%c', time.localtime(now))
+
+
+ def get_creator_data(h5group):
+     '''Read back creator data as written by ``stamp_creator_data``, returning a dictionary with
+     keys as described for ``stamp_creator_data``. Missing fields are denoted with None.
+     The ``creation_time`` field is returned as a string.'''
+     attrs = h5group.attrs
+     d = dict()
+     for attr in ['creation_program', 'creation_user', 'creation_hostname', 'creation_unix_time', 'creation_time']:
+         d[attr] = attrs.get(attr)
+     return d
+
+
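A round-trip sketch for the creator stamps (not part of this diff; file and group names are hypothetical):

    with h5py.File('scratch.h5', 'w') as f:
        grp = create_hdf5_group(f, 'analysis', replace=True, creating_program='w_example')
        info = get_creator_data(grp)
        # info['creation_program'] == 'w_example'; absent fields come back as None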
+ def load_west(filename):
+     """Load WESTPA trajectory files from disk.
+
+     Parameters
+     ----------
+     filename : str
+         String filename of HDF Trajectory file.
+     """
+
+     with h5py.File(filename, 'r') as f:
+         trajectories = []
+         n = 0
+
+         for iter_group_name in f['iterations']:
+             iter_group = f['iterations/' + iter_group_name]
+
+             if 'trajectories' in iter_group:
+                 traj_link = iter_group['trajectories']
+                 traj_filename = traj_link.file.filename
+
+                 with WESTIterationFile(traj_filename) as traj_file:
+                     traj = traj_file.read_as_traj()
+             else:
+                 # TODO: [HDF5] allow initializing trajectory without coordinates
+                 raise ValueError("Missing trajectories for iteration %d" % n)
+
+             # pcoord is required
+             if 'pcoord' not in iter_group:
+                 raise ValueError("Missing pcoords for iteration %d" % n)
+
+             raw_pcoord = iter_group['pcoord'][:]
+             if raw_pcoord.ndim != 3:
+                 log.warning('pcoord is expected to be a 3-d ndarray instead of {}-d'.format(raw_pcoord.ndim))
+                 continue
+             # ignore the first frame of each segment
+             if raw_pcoord.shape[1] == traj.n_frames + 1:
+                 raw_pcoord = raw_pcoord[:, 1:, :]
+             elif raw_pcoord.shape[1] == traj.n_frames:
+                 raw_pcoord = raw_pcoord[:, :, :]
+             else:
+                 raise ValueError(
+                     "Inconsistent number of pcoords (%d) and frames (%d) for iteration %d" % (raw_pcoord.shape[1], traj.n_frames, n)
+                 )
+
+             pcoords = np.concatenate(raw_pcoord, axis=0)
+             n_frames = raw_pcoord.shape[1]
+
+             if 'seg_index' in iter_group:
+                 raw_pid = iter_group['seg_index']['parent_id'][:]
+
+                 if np.any(raw_pid < 0):
+                     init_basis_ids = iter_group['ibstates']['istate_index']['basis_state_id'][:]
+                     init_ids = -(raw_pid[raw_pid < 0] + 1)
+                     raw_pid[raw_pid < 0] = [init_basis_ids[iid] for iid in init_ids]
+                 parent_ids = raw_pid.repeat(n_frames, axis=0)
+             else:
+                 parent_ids = None
+
+             traj.pcoords = pcoords
+             traj.parent_ids = parent_ids
+             trajectories.append(traj)
+
+             n += 1
+
+     west_traj = join_traj(trajectories)
+
+     return west_traj
+
+
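A call sketch (not part of this diff; west.h5 is WESTPA's conventional main data file name):

    # Joins every iteration's trajectory into one mdtraj-compatible object
    traj = load_west('west.h5')
    print(traj.n_frames)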
+ ###
+ # Iteration range metadata
+ ###
+ def stamp_iter_range(h5object, start_iter, stop_iter):
+     '''Mark that the HDF5 object ``h5object`` (dataset or group) contains data from iterations
+     start_iter <= n_iter < stop_iter.'''
+     h5object.attrs['iter_start'] = start_iter
+     h5object.attrs['iter_stop'] = stop_iter
+
+
+ def get_iter_range(h5object):
+     '''Read back iteration range data written by ``stamp_iter_range``'''
+     return int(h5object.attrs['iter_start']), int(h5object.attrs['iter_stop'])
+
+
+ def stamp_iter_step(h5group, iter_step):
+     '''Mark that the HDF5 object ``h5group`` (dataset or group) contains data with an
+     iteration step (stride) of ``iter_step``.'''
+     h5group.attrs['iter_step'] = iter_step
+
+
+ def get_iter_step(h5group):
+     '''Read back iteration step (stride) written by ``stamp_iter_step``'''
+     return int(h5group.attrs['iter_step'])
+
+
+ def check_iter_range_least(h5object, iter_start, iter_stop):
+     '''Return True if the iteration range [iter_start, iter_stop) is
+     the same as or entirely contained within the iteration range stored
+     on ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     return obj_iter_start <= iter_start and obj_iter_stop >= iter_stop
+
+
+ def check_iter_range_equal(h5object, iter_start, iter_stop):
+     '''Return True if the iteration range [iter_start, iter_stop) is
+     the same as the iteration range stored on ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     return obj_iter_start == iter_start and obj_iter_stop == iter_stop
+
+
+ def get_iteration_entry(h5object, n_iter):
+     '''Create a slice for data corresponding to iteration ``n_iter`` in ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     if n_iter < obj_iter_start or n_iter >= obj_iter_stop:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(n_iter, h5object))
+     return np.index_exp[n_iter - obj_iter_start]
+
+
+ def get_iteration_slice(h5object, iter_start, iter_stop=None, iter_stride=None):
+     '''Create a slice for data corresponding to iterations [iter_start, iter_stop),
+     with stride ``iter_stride``, in the given ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+
+     if iter_stop is None:
+         iter_stop = iter_start + 1
+     if iter_stride is None:
+         iter_stride = 1
+
+     if iter_start < obj_iter_start:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
+     elif iter_start > obj_iter_stop:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
+
+     start_index = iter_start - obj_iter_start
+     stop_index = iter_stop - obj_iter_start
+     return np.index_exp[start_index:stop_index:iter_stride]
+
+
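These stamps make downstream slicing self-describing. A worked sketch (not part of this diff; rate_ds is a hypothetical dataset holding one row per iteration for 10 <= n_iter < 61):

    stamp_iter_range(rate_ds, 10, 61)
    idx = get_iteration_slice(rate_ds, 20, 40)  # -> np.index_exp[10:30:1]
    block = rate_ds[idx]                        # rows for iterations 20..39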
+ ###
+ # Axis label metadata
+ ###
+ def label_axes(h5object, labels, units=None):
+     '''Stamp the given HDF5 object with axis labels. This stores the axis labels
+     in an array of strings in an attribute called ``axis_labels`` on the given
+     object. ``units`` if provided is a corresponding list of units.'''
+
+     if len(labels) != len(h5object.shape):
+         raise ValueError('number of axes and number of labels do not match')
+
+     if units is None:
+         units = []
+
+     if len(units) and len(units) != len(labels):
+         raise ValueError('number of units labels does not match number of axes')
+
+     h5object.attrs['axis_labels'] = np.array([np.string_(i) for i in labels])
+
+     if len(units):
+         h5object.attrs['axis_units'] = np.array([np.string_(i) for i in units])
+
+
+ NotGiven = object()
+
+
+ def _get_one_attr(h5object, namelist, default=NotGiven):
+     attrs = dict(h5object.attrs)
+     for name in namelist:
+         try:
+             return attrs[name]
+         except KeyError:
+             pass
+     else:
+         if default is NotGiven:
+             raise KeyError('no such key')
+         else:
+             return default
+
+
+ class WESTPAH5File(h5py.File):
+     '''Generalized input/output for WESTPA simulation (or analysis) data.'''
+
+     default_iter_prec = 8
+     _this_fileformat_version = 8
+
+     def __init__(self, *args, **kwargs):
+         # These values are used for creating files or reading files where this
+         # data is not stored. Otherwise, values stored as attributes on the root
+         # group are used instead.
+         arg_iter_prec = kwargs.pop('westpa_iter_prec', self.default_iter_prec)
+         arg_fileformat_version = kwargs.pop('westpa_fileformat_version', self._this_fileformat_version)
+         arg_creating_program = kwargs.pop('creating_program', None)
+
+         # Initialize h5py file
+         super().__init__(*args, **kwargs)
+
+         # Try to get iteration precision and I/O class version
+         h5file_iter_prec = _get_one_attr(self, ['westpa_iter_prec', 'west_iter_prec', 'wemd_iter_prec'], None)
+         h5file_fileformat_version = _get_one_attr(
+             self, ['westpa_fileformat_version', 'west_file_format_version', 'wemd_file_format_version'], None
+         )
+
+         self.iter_prec = h5file_iter_prec if h5file_iter_prec is not None else arg_iter_prec
+         self.fileformat_version = h5file_fileformat_version if h5file_fileformat_version is not None else arg_fileformat_version
+
+         # Ensure that file format attributes are stored, if the file is writable
+         if self.mode == 'r+':
+             self.attrs['westpa_iter_prec'] = self.iter_prec
+             self.attrs['westpa_fileformat_version'] = self.fileformat_version
+             if arg_creating_program:
+                 stamp_creator_data(self, creating_program=arg_creating_program)
+
+     # Helper function to automatically replace a dataset, if it exists.
+     # Should really only be called when one is certain a dataset should be blown away.
+     def replace_dataset(self, *args, **kwargs):
+         try:
+             del self[args[0]]
+         except Exception:
+             pass
+         try:
+             del self[kwargs['name']]
+         except Exception:
+             pass
+
+         return self.create_dataset(*args, **kwargs)
+
+     # Iteration groups
+
+     def iter_object_name(self, n_iter, prefix='', suffix=''):
+         '''Return a properly-formatted per-iteration name for iteration
+         ``n_iter``. (This is used in create/require/get_iter_group, but may
+         also be useful for naming datasets on a per-iteration basis.)'''
+         return '{prefix}iter_{n_iter:0{prec}d}{suffix}'.format(n_iter=n_iter, prefix=prefix, suffix=suffix, prec=self.iter_prec)
+
+     def create_iter_group(self, n_iter, group=None):
+         '''Create a per-iteration data storage group for iteration number ``n_iter``
+         in the group ``group`` (which is '/iterations' by default).'''
+
+         if group is None:
+             group = self.require_group('/iterations')
+         return group.create_group(self.iter_object_name(n_iter))
+
+     def require_iter_group(self, n_iter, group=None):
+         '''Ensure that a per-iteration data storage group for iteration number ``n_iter``
+         is available in the group ``group`` (which is '/iterations' by default).'''
+         if group is None:
+             group = self.require_group('/iterations')
+         return group.require_group(self.iter_object_name(n_iter))
+
+     def get_iter_group(self, n_iter, group=None):
+         '''Get the per-iteration data group for iteration number ``n_iter`` from within
+         the group ``group`` ('/iterations' by default).'''
+         if group is None:
+             group = self['/iterations']
+         return group[self.iter_object_name(n_iter)]
+
+
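With the default iter_prec of 8, iteration 5 maps to the group name iter_00000005. A short sketch (not part of this diff; the file name is hypothetical):

    f = WESTPAH5File('scratch_west.h5', 'w', creating_program='example')
    assert f.iter_object_name(5) == 'iter_00000005'
    grp = f.require_iter_group(5)  # creates /iterations/iter_00000005
    f.close()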
+ class WESTIterationFile(HDF5TrajectoryFile):
+     def __init__(self, file, mode='r', force_overwrite=True, compression='zlib', link=None):
+         if isinstance(file, str):
+             super(WESTIterationFile, self).__init__(file, mode, force_overwrite, compression)
+         else:
+             try:
+                 self._init_from_handle(file)
+             except AttributeError:
+                 raise ValueError('unknown input type: %s' % str(type(file)))
+
+     def _init_from_handle(self, handle):
+         self._handle = handle
+         self._open = handle.isopen != 0
+         self.mode = mode = handle.mode  # the mode in which the file was opened
+
+         if mode not in ['r', 'w', 'a']:
+             raise ValueError("mode must be one of ['r', 'w', 'a']")
+
+         # import tables
+         self.tables = import_('tables')
+
+         if mode == 'w':
+             # what frame are we currently reading or writing at?
+             self._frame_index = 0
+             # do we need to write the header information?
+             self._needs_initialization = True
+
+         elif mode == 'a':
+             try:
+                 self._frame_index = len(self._handle.root.coordinates)
+                 self._needs_initialization = False
+             except self.tables.NoSuchNodeError:
+                 self._frame_index = 0
+                 self._needs_initialization = True
+         elif mode == 'r':
+             self._frame_index = 0
+             self._needs_initialization = False
+
+     def read(self, frame_indices=None, atom_indices=None):
+         _check_mode(self.mode, ('r',))
+
+         if frame_indices is None:
+             frame_indices = slice(None)
+         if isinstance(frame_indices, slice):
+             # an unbounded slice covers all frames; the frame cursor is not advanced
+             frame_slice = frame_indices
+         else:
+             frame_slice = ensure_type(frame_indices, dtype=int, ndim=1, name='frame_indices', warn_on_cast=False)
+             if not np.all(frame_slice < self._handle.root.coordinates.shape[0]):
+                 raise ValueError(
+                     'As a zero-based index, the entries in '
+                     'frame_slice must all be less than the number of frames '
+                     'in the trajectory, %d' % self._handle.root.coordinates.shape[0]
+                 )
+             if not np.all(frame_slice >= 0):
+                 raise ValueError('The entries in frame_indices must be greater than or equal to zero')
+             self._frame_index += frame_slice[-1] - frame_slice[0]
+
+         if atom_indices is None:
+             # get all of the atoms
+             atom_slice = slice(None)
+         else:
+             atom_slice = ensure_type(atom_indices, dtype=int, ndim=1, name='atom_indices', warn_on_cast=False)
+             if not np.all(atom_slice < self._handle.root.coordinates.shape[1]):
+                 raise ValueError(
+                     'As a zero-based index, the entries in '
+                     'atom_indices must all be less than the number of atoms '
+                     'in the trajectory, %d' % self._handle.root.coordinates.shape[1]
+                 )
+             if not np.all(atom_slice >= 0):
+                 raise ValueError('The entries in atom_indices must be greater than or equal to zero')
+
+         def get_item(node, key):
+             if not isinstance(key, tuple):
+                 return node.__getitem__(key)
+
+             n_list_like = 0
+             new_keys = []
+             for item in key:
+                 if not isinstance(item, slice):
+                     try:
+                         d = np.diff(item)
+                         if len(d) == 0:
+                             item = item[0]
+                         elif np.all(d == d[0]):
+                             item = slice(item[0], item[-1] + d[0], d[0])
+                         else:
+                             n_list_like += 1
+                     except Exception:
+                         n_list_like += 1
+                 new_keys.append(item)
+             new_keys = tuple(new_keys)
+
+             if n_list_like <= 1:
+                 return node.__getitem__(new_keys)
+
+             data = node
+             for i, item in enumerate(new_keys):
+                 dkey = [slice(None)] * len(key)
+                 dkey[i] = item
+                 dkey = tuple(dkey)
+                 data = data.__getitem__(dkey)
+
+             return data
+
+         def get_field(name, slice, out_units, can_be_none=True):
+             try:
+                 node = self._get_node(where='/', name=name)
+                 data = get_item(node, slice)
+                 in_units = node.attrs.units
+                 if not isinstance(in_units, string_types):
+                     in_units = in_units.decode()
+                 data = in_units_of(data, in_units, out_units)
+                 return data
+             except self.tables.NoSuchNodeError:
+                 if can_be_none:
+                     return None
+                 raise
+
+         frames = Frames(
+             coordinates=get_field('coordinates', (frame_slice, atom_slice, slice(None)), out_units='nanometers', can_be_none=False),
+             time=get_field('time', frame_slice, out_units='picoseconds'),
+             cell_lengths=get_field('cell_lengths', (frame_slice, slice(None)), out_units='nanometers'),
+             cell_angles=get_field('cell_angles', (frame_slice, slice(None)), out_units='degrees'),
+             velocities=get_field('velocities', (frame_slice, atom_slice, slice(None)), out_units='nanometers/picosecond'),
+             kineticEnergy=get_field('kineticEnergy', frame_slice, out_units='kilojoules_per_mole'),
+             potentialEnergy=get_field('potentialEnergy', frame_slice, out_units='kilojoules_per_mole'),
+             temperature=get_field('temperature', frame_slice, out_units='kelvin'),
+             alchemicalLambda=get_field('lambda', frame_slice, out_units='dimensionless'),
+         )
+
+         return frames
+
+     def _has_node(self, where, name):
+         try:
+             self._get_node(where, name=name)
+         except self.tables.NoSuchNodeError:
+             return False
+
+         return True
+
+     def has_topology(self):
+         return self._has_node('/', 'topology')
+
+     def has_pointer(self):
+         return self._has_node('/', 'pointer')
+
+     def has_restart(self, segment):
+         return self._has_node('/restart', '%d_%d' % (segment.n_iter, segment.seg_id))
+
+     def write_data(self, where, name, data):
+         node = self._get_node(where=where, name=name)
+         node.append(data)
+
+     def read_data(self, where, name):
+         node = self._get_node(where=where, name=name)
+         return node.read()
+
+     def read_as_traj(self, iteration=None, segment=None, atom_indices=None):
+         _check_mode(self.mode, ('r',))
+
+         pnode = self._get_node(where='/', name='pointer')
+
+         iter_labels = pnode[:, 0]
+         seg_labels = pnode[:, 1]
+
+         if iteration is None and segment is None:
+             frame_indices = slice(None)
+         elif isinstance(iteration, (np.integer, int)) and isinstance(segment, (np.integer, int)):
+             frame_torf = np.logical_and(iter_labels == iteration, seg_labels == segment)
+             frame_indices = np.arange(len(iter_labels))[frame_torf]
+         else:
+             raise ValueError("iteration and segment must be integers and provided at the same time")
+
+         if not isinstance(frame_indices, slice) and len(frame_indices) == 0:
+             raise ValueError(f"no frame was selected: iteration={iteration}, segment={segment}, atom_indices={atom_indices}")
+
+         iter_labels = iter_labels[frame_indices]
+         seg_labels = seg_labels[frame_indices]
+
+         topology = self.topology
+         if atom_indices is not None:
+             topology = topology.subset(atom_indices)
+
+         data = self.read(frame_indices=frame_indices, atom_indices=atom_indices)
+         if len(data) == 0:
+             return Trajectory(xyz=np.zeros((0, topology.n_atoms, 3)), topology=topology)
+
+         in_units_of(data.coordinates, self.distance_unit, Trajectory._distance_unit, inplace=True)
+         in_units_of(data.cell_lengths, self.distance_unit, Trajectory._distance_unit, inplace=True)
+
+         return WESTTrajectory(
+             data.coordinates,
+             topology=topology,
+             time=data.time,
+             unitcell_lengths=data.cell_lengths,
+             unitcell_angles=data.cell_angles,
+             iter_labels=iter_labels,
+             seg_labels=seg_labels,
+             pcoords=None,
+         )
+
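A reading sketch (not part of this diff; the per-iteration file name is hypothetical):

    with WESTIterationFile('iter_000001.h5', 'r') as itf:
        traj = itf.read_as_traj()                       # all frames, labeled via /pointer
        seg = itf.read_as_traj(iteration=1, segment=3)  # one segment's frames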
+     def read_restart(self, segment):
+         if self.has_restart(segment):
+             data = self.read_data('/restart/%d_%d' % (segment.n_iter, segment.seg_id), 'data')
+             segment.data['iterh5/restart'] = data
+         else:
+             raise ValueError('no restart data available for {}'.format(str(segment)))
+
+     def write_segment(self, segment, pop=False):
+         n_iter = segment.n_iter
+
+         self.root._v_attrs['n_iter'] = n_iter
+
+         if pop:
+             get_data = segment.data.pop
+         else:
+             get_data = segment.data.get
+
+         traj = get_data('iterh5/trajectory', None)
+         restart = get_data('iterh5/restart', None)
+         slog = get_data('iterh5/log', None)
+
+         if traj is not None:
+             # create trajectory object
+             traj = WESTTrajectory(traj, iter_labels=n_iter, seg_labels=segment.seg_id)
+             if traj.n_frames == 0:
+                 # we may consider logging warnings instead of throwing errors later;
+                 # right now this is good for debugging purposes
+                 raise ValueError('no trajectory data present for %s' % repr(segment))
+
+             if n_iter == 0:
+                 base_time = 0
+             else:
+                 iter_duration = traj.time[-1] - traj.time[0]
+                 base_time = iter_duration * (n_iter - 1)
+
+             traj.time -= traj.time[0]
+             traj.time += base_time
+
+             # pointers
+             if not self.has_pointer():
+                 self._create_earray('/', name='pointer', atom=self.tables.Int64Atom(), shape=(0, 2))
+
+             iter_idx = traj.iter_labels
+             seg_idx = traj.seg_labels
+
+             pointers = np.stack((iter_idx, seg_idx)).T
+
+             self.write_data('/', 'pointer', pointers)
+
+             # trajectory
+             self.write(
+                 coordinates=in_units_of(traj.xyz, Trajectory._distance_unit, self.distance_unit),
+                 time=traj.time,
+                 cell_lengths=in_units_of(traj.unitcell_lengths, Trajectory._distance_unit, self.distance_unit),
+                 cell_angles=traj.unitcell_angles,
+             )
+
+             # topology
+             if self.mode == 'a':
+                 if not self.has_topology():
+                     self.topology = traj.topology
+             elif self.mode == 'w':
+                 self.topology = traj.topology
+
+         # restart
+         if restart is not None:
+             if self.has_restart(segment):
+                 self._remove_node('/restart', name='%d_%d' % (segment.n_iter, segment.seg_id), recursive=True)
+
+             self._create_array(
+                 '/restart/%d_%d' % (segment.n_iter, segment.seg_id),
+                 name='data',
+                 atom=self.tables.StringAtom(itemsize=len(restart)),
+                 obj=restart,
+                 createparents=True,
+             )
+
+         if slog is not None:
+             if self._has_node('/log/%d_%d' % (segment.n_iter, segment.seg_id), 'data'):
+                 self._remove_node('/log/%d_%d' % (segment.n_iter, segment.seg_id), name='data', recursive=True)
+
+             self._create_array(
+                 '/log/%d_%d' % (segment.n_iter, segment.seg_id),
+                 name='data',
+                 atom=self.tables.StringAtom(itemsize=len(slog)),
+                 obj=slog,
+                 createparents=True,
+             )
+
+     @property
+     def _create_group(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.create_group
+         return self._handle.createGroup
+
+     @property
+     def _create_array(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.create_array
+         return self._handle.createArray
+
+     @property
+     def _remove_node(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.remove_node
+         return self._handle.removeNode
+
+
+ ### Generalized WE dataset access classes
+
+
+ class DSSpec:
+     '''Generalized WE dataset access'''
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         raise NotImplementedError
+
+     def get_segment_data(self, n_iter, seg_id):
+         return self.get_iter_data(n_iter)[seg_id]
+
+     def __getstate__(self):
+         d = dict(self.__dict__)
+         if '_h5file' in d:
+             d['_h5file'] = None
+         return d
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+
+
+ class FileLinkedDSSpec(DSSpec):
+     '''Provide facilities for accessing WESTPA HDF5 files, including auto-opening and the ability
+     to pickle references to such files for transmission (through, e.g., the work manager), provided
+     that the HDF5 file can be accessed by the same path on both the sender and receiver.'''
+
+     def __init__(self, h5file_or_name):
+         self._h5file = None
+         self._h5filename = None
+
+         try:
+             self._h5filename = os.path.abspath(h5file_or_name.filename)
+         except AttributeError:
+             self._h5filename = h5file_or_name
+             self._h5file = None
+         else:
+             self._h5file = h5file_or_name
+
+     @property
+     def h5file(self):
+         '''Lazily open HDF5 file. This is required because allowing an open HDF5
+         file to cross a fork() boundary generally corrupts the internal state of
+         the HDF5 library.'''
+         if self._h5file is None:
+             self._h5file = WESTPAH5File(self._h5filename, 'r')
+         return self._h5file
+
+
+ class SingleDSSpec(FileLinkedDSSpec):
+     @classmethod
+     def from_string(cls, dsspec_string, default_h5file):
+         alias = None
+
+         h5file = default_h5file
+         fields = dsspec_string.split(',')
+         dsname = fields[0]
+         slice = None
+
+         for field in (field.strip() for field in fields[1:]):
+             k, v = field.split('=')
+             k = k.lower()
+             if k == 'alias':
+                 alias = v
+             elif k == 'slice':
+                 try:
+                     slice = eval('np.index_exp' + v)
+                 except SyntaxError:
+                     raise SyntaxError('invalid index expression {!r}'.format(v))
+             elif k == 'file':
+                 h5file = v
+             else:
+                 raise ValueError('invalid dataset option {!r}'.format(k))
+
+         return cls(h5file, dsname, alias, slice)
+
+     def __init__(self, h5file_or_name, dsname, alias=None, slice=None):
+         FileLinkedDSSpec.__init__(self, h5file_or_name)
+         self.dsname = dsname
+         self.alias = alias or dsname
+         self.slice = np.index_exp[slice] if slice else None
+
+
+ class SingleIterDSSpec(SingleDSSpec):
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]
+
+
+ class SingleSegmentDSSpec(SingleDSSpec):
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + index_exp[:] + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]
+
+     def get_segment_data(self, n_iter, seg_id):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][np.index_exp[seg_id, :] + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_id]
+
+
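Dataset specifier strings combine a dataset name with optional file=, slice=, and alias= fields. A parsing sketch (not part of this diff; file and dataset names are hypothetical):

    # first pcoord dimension of each segment, read from an auxiliary file
    spec = SingleSegmentDSSpec.from_string('pcoord,file=aux.h5,slice=[...,0]', 'west.h5')
    data = spec.get_segment_data(n_iter=10, seg_id=0)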
+ class FnDSSpec(FileLinkedDSSpec):
+     def __init__(self, h5file_or_name, fn):
+         FileLinkedDSSpec.__init__(self, h5file_or_name)
+         self.fn = fn
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         return self.fn(n_iter, self.h5file.get_iter_group(n_iter))[seg_slice]
+
+
+ class MultiDSSpec(DSSpec):
+     def __init__(self, dsspecs):
+         self.dsspecs = dsspecs
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         datasets = [dsspec.get_iter_data(n_iter) for dsspec in self.dsspecs]
+
+         ncols = 0
+         nsegs = None
+         npts = None
+         for iset, dset in enumerate(datasets):
+             if nsegs is None:
+                 nsegs = dset.shape[0]
+             elif dset.shape[0] != nsegs:
+                 raise TypeError('dataset {} has incorrect first dimension (number of segments)'.format(self.dsspecs[iset]))
+             if npts is None:
+                 npts = dset.shape[1]
+             elif dset.shape[1] != npts:
+                 raise TypeError('dataset {} has incorrect second dimension (number of time points)'.format(self.dsspecs[iset]))
+
+             if dset.ndim < 2:
+                 # scalar per segment or scalar per iteration
+                 raise TypeError('dataset {} has too few dimensions'.format(self.dsspecs[iset]))
+             elif dset.ndim > 3:
+                 # array per timepoint
+                 raise TypeError('dataset {} has too many dimensions'.format(self.dsspecs[iset]))
+             elif dset.ndim == 2:
+                 # scalar per timepoint
+                 ncols += 1
+             else:
+                 # vector per timepoint
+                 ncols += dset.shape[-1]
+
+         output_dtype = np.result_type(*[ds.dtype for ds in datasets])
+         output_array = np.empty((nsegs, npts, ncols), dtype=output_dtype)
+
+         ocol = 0
+         for iset, dset in enumerate(datasets):
+             if dset.ndim == 2:
+                 output_array[:, :, ocol] = dset[...]
+                 ocol += 1
+             elif dset.ndim == 3:
+                 output_array[:, :, ocol : (ocol + dset.shape[-1])] = dset[...]
+                 ocol += dset.shape[-1]
+
+         return output_array[seg_slice]
+
+
+ class IterBlockedDataset:
+     @classmethod
+     def empty_like(cls, blocked_dataset):
+         source = blocked_dataset.data if blocked_dataset.data is not None else blocked_dataset.dataset
+
+         newbds = cls(
+             np.empty(source.shape, source.dtype),
+             attrs={'iter_start': blocked_dataset.iter_start, 'iter_stop': blocked_dataset.iter_stop},
+         )
+         return newbds
+
+     def __init__(self, dataset_or_array, attrs=None):
+         try:
+             dataset_or_array.attrs
+         except AttributeError:
+             self.dataset = None
+             self.data = dataset_or_array
+             if attrs is None:
+                 raise ValueError('attribute dictionary containing iteration bounds must be provided')
+             self.iter_shape = self.data.shape[1:]
+             self.dtype = self.data.dtype
+         else:
+             self.dataset = dataset_or_array
+             attrs = self.dataset.attrs
+             self.data = None
+             self.iter_shape = self.dataset.shape[1:]
+             self.dtype = self.dataset.dtype
+
+         self.iter_start = attrs['iter_start']
+         self.iter_stop = attrs['iter_stop']
+
+     def cache_data(self, max_size=None):
+         '''Cache this dataset in RAM. If ``max_size`` is given, then only cache if the entire dataset
+         fits in ``max_size`` bytes. If ``max_size`` is the string 'available', then only cache if
+         the entire dataset fits in available RAM, as defined by the ``psutil`` module.'''
+
+         if max_size is not None:
+             dssize = self.dtype.itemsize * np.multiply.reduce(self.dataset.shape)
+             if max_size == 'available' and psutil is not None:
+                 avail_bytes = psutil.virtual_memory().available
+                 if dssize > avail_bytes:
+                     return
+             elif isinstance(max_size, str):
+                 return
+             else:
+                 if dssize > max_size:
+                     return
+         if self.dataset is not None:
+             if self.data is None:
+                 self.data = self.dataset[...]
+
+     def drop_cache(self):
+         if self.dataset is not None:
+             del self.data
+             self.data = None
+
+     def iter_entry(self, n_iter):
+         if n_iter < self.iter_start:
+             raise IndexError('requested iteration {} less than first stored iteration {}'.format(n_iter, self.iter_start))
+
+         source = self.data if self.data is not None else self.dataset
+         return source[n_iter - self.iter_start]
+
+     def iter_slice(self, start=None, stop=None):
+         start = start or self.iter_start
+         stop = stop or self.iter_stop
+         step = 1  # strided retrieval not implemented yet
+
+         # if step % self.iter_step > 0:
+         #     raise TypeError('dataset {!r} stored with stride {} cannot be accessed with stride {}'
+         #                     .format(self.dataset, self.iter_step, step))
+         if start < self.iter_start:
+             raise IndexError('requested start {} less than stored start {}'.format(start, self.iter_start))
+         elif stop > self.iter_stop:
+             stop = self.iter_stop
+
+         source = self.data if self.data is not None else self.dataset
+         return source[start - self.iter_start : stop - self.iter_start : step]
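IterBlockedDataset maps absolute iteration numbers onto the zero-based first axis of the underlying array or dataset. A wrapping sketch (not part of this diff; the array and bounds are hypothetical):

    # one row per iteration for 1 <= n_iter < 101
    rates = IterBlockedDataset(np.random.rand(100, 4), attrs={'iter_start': 1, 'iter_stop': 101})
    row = rates.iter_entry(50)        # row index 49
    block = rates.iter_slice(10, 20)  # iterations 10..19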