westpa 2022.12__cp313-cp313-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of westpa has been flagged as potentially problematic.

Files changed (149)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +746 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +360 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +596 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +547 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +833 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_trace.py +599 -0
  38. westpa/core/__init__.py +0 -0
  39. westpa/core/_rc.py +673 -0
  40. westpa/core/binning/__init__.py +55 -0
  41. westpa/core/binning/_assign.cpython-313-darwin.so +0 -0
  42. westpa/core/binning/assign.py +455 -0
  43. westpa/core/binning/binless.py +96 -0
  44. westpa/core/binning/binless_driver.py +54 -0
  45. westpa/core/binning/binless_manager.py +190 -0
  46. westpa/core/binning/bins.py +47 -0
  47. westpa/core/binning/mab.py +506 -0
  48. westpa/core/binning/mab_driver.py +54 -0
  49. westpa/core/binning/mab_manager.py +198 -0
  50. westpa/core/data_manager.py +1694 -0
  51. westpa/core/extloader.py +74 -0
  52. westpa/core/h5io.py +995 -0
  53. westpa/core/kinetics/__init__.py +24 -0
  54. westpa/core/kinetics/_kinetics.cpython-313-darwin.so +0 -0
  55. westpa/core/kinetics/events.py +147 -0
  56. westpa/core/kinetics/matrates.py +156 -0
  57. westpa/core/kinetics/rate_averaging.py +266 -0
  58. westpa/core/progress.py +218 -0
  59. westpa/core/propagators/__init__.py +54 -0
  60. westpa/core/propagators/executable.py +719 -0
  61. westpa/core/reweight/__init__.py +14 -0
  62. westpa/core/reweight/_reweight.cpython-313-darwin.so +0 -0
  63. westpa/core/reweight/matrix.py +126 -0
  64. westpa/core/segment.py +119 -0
  65. westpa/core/sim_manager.py +835 -0
  66. westpa/core/states.py +359 -0
  67. westpa/core/systems.py +93 -0
  68. westpa/core/textio.py +74 -0
  69. westpa/core/trajectory.py +330 -0
  70. westpa/core/we_driver.py +910 -0
  71. westpa/core/wm_ops.py +43 -0
  72. westpa/core/yamlcfg.py +391 -0
  73. westpa/fasthist/__init__.py +34 -0
  74. westpa/fasthist/_fasthist.cpython-313-darwin.so +0 -0
  75. westpa/mclib/__init__.py +271 -0
  76. westpa/mclib/__main__.py +28 -0
  77. westpa/mclib/_mclib.cpython-313-darwin.so +0 -0
  78. westpa/oldtools/__init__.py +4 -0
  79. westpa/oldtools/aframe/__init__.py +35 -0
  80. westpa/oldtools/aframe/atool.py +75 -0
  81. westpa/oldtools/aframe/base_mixin.py +26 -0
  82. westpa/oldtools/aframe/binning.py +178 -0
  83. westpa/oldtools/aframe/data_reader.py +560 -0
  84. westpa/oldtools/aframe/iter_range.py +200 -0
  85. westpa/oldtools/aframe/kinetics.py +117 -0
  86. westpa/oldtools/aframe/mcbs.py +153 -0
  87. westpa/oldtools/aframe/output.py +39 -0
  88. westpa/oldtools/aframe/plotting.py +90 -0
  89. westpa/oldtools/aframe/trajwalker.py +126 -0
  90. westpa/oldtools/aframe/transitions.py +469 -0
  91. westpa/oldtools/cmds/__init__.py +0 -0
  92. westpa/oldtools/cmds/w_ttimes.py +361 -0
  93. westpa/oldtools/files.py +34 -0
  94. westpa/oldtools/miscfn.py +23 -0
  95. westpa/oldtools/stats/__init__.py +4 -0
  96. westpa/oldtools/stats/accumulator.py +35 -0
  97. westpa/oldtools/stats/edfs.py +129 -0
  98. westpa/oldtools/stats/mcbs.py +96 -0
  99. westpa/tools/__init__.py +33 -0
  100. westpa/tools/binning.py +472 -0
  101. westpa/tools/core.py +340 -0
  102. westpa/tools/data_reader.py +159 -0
  103. westpa/tools/dtypes.py +31 -0
  104. westpa/tools/iter_range.py +198 -0
  105. westpa/tools/kinetics_tool.py +340 -0
  106. westpa/tools/plot.py +283 -0
  107. westpa/tools/progress.py +17 -0
  108. westpa/tools/selected_segs.py +154 -0
  109. westpa/tools/wipi.py +751 -0
  110. westpa/trajtree/__init__.py +4 -0
  111. westpa/trajtree/_trajtree.cpython-313-darwin.so +0 -0
  112. westpa/trajtree/trajtree.py +117 -0
  113. westpa/westext/__init__.py +0 -0
  114. westpa/westext/adaptvoronoi/__init__.py +3 -0
  115. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  116. westpa/westext/hamsm_restarting/__init__.py +3 -0
  117. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  118. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  119. westpa/westext/stringmethod/__init__.py +11 -0
  120. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  121. westpa/westext/stringmethod/string_driver.py +253 -0
  122. westpa/westext/stringmethod/string_method.py +306 -0
  123. westpa/westext/weed/BinCluster.py +180 -0
  124. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  125. westpa/westext/weed/UncertMath.py +247 -0
  126. westpa/westext/weed/__init__.py +10 -0
  127. westpa/westext/weed/weed_driver.py +192 -0
  128. westpa/westext/wess/ProbAdjust.py +101 -0
  129. westpa/westext/wess/__init__.py +6 -0
  130. westpa/westext/wess/wess_driver.py +217 -0
  131. westpa/work_managers/__init__.py +57 -0
  132. westpa/work_managers/core.py +396 -0
  133. westpa/work_managers/environment.py +134 -0
  134. westpa/work_managers/mpi.py +318 -0
  135. westpa/work_managers/processes.py +187 -0
  136. westpa/work_managers/serial.py +28 -0
  137. westpa/work_managers/threads.py +79 -0
  138. westpa/work_managers/zeromq/__init__.py +20 -0
  139. westpa/work_managers/zeromq/core.py +641 -0
  140. westpa/work_managers/zeromq/node.py +131 -0
  141. westpa/work_managers/zeromq/work_manager.py +526 -0
  142. westpa/work_managers/zeromq/worker.py +320 -0
  143. westpa-2022.12.dist-info/AUTHORS +22 -0
  144. westpa-2022.12.dist-info/LICENSE +21 -0
  145. westpa-2022.12.dist-info/METADATA +193 -0
  146. westpa-2022.12.dist-info/RECORD +149 -0
  147. westpa-2022.12.dist-info/WHEEL +6 -0
  148. westpa-2022.12.dist-info/entry_points.txt +29 -0
  149. westpa-2022.12.dist-info/top_level.txt +1 -0
westpa/core/h5io.py ADDED
@@ -0,0 +1,995 @@
+ '''Miscellaneous routines to help with HDF5 input and output of WEST-related data.'''
+
+ import collections
+ import errno
+ import getpass
+ import os
+ import posixpath
+ import socket
+ import sys
+ import time
+ import logging
+ import warnings
+
+ import h5py
+ import numpy as np
+ from numpy import index_exp
+ from tables import NaturalNameWarning
+
+ from mdtraj import Trajectory, join as join_traj
+ from mdtraj.utils import in_units_of, import_, ensure_type
+ from mdtraj.formats import HDF5TrajectoryFile
+ from mdtraj.formats.hdf5 import _check_mode, Frames
+
+ from .trajectory import WESTTrajectory
+
+ try:
+     import psutil
+ except ImportError:
+     psutil = None
+
+ log = logging.getLogger(__name__)
+ warnings.filterwarnings('ignore', category=NaturalNameWarning)
+
+ #
+ # Constants and globals
+ #
+ default_iter_prec = 8
+
+ #
+ # Helper functions
+ #
+
+
+ def resolve_filepath(path, constructor=h5py.File, cargs=None, ckwargs=None, **addtlkwargs):
+     '''Use a combined filesystem and HDF5 path to open an HDF5 file and return the
+     appropriate object. Returns (h5file, h5object). The file is opened using
+     ``constructor(filename, *cargs, **ckwargs)``.'''
+
+     cargs = cargs or ()
+     ckwargs = ckwargs or {}
+     ckwargs.update(addtlkwargs)
+     objpieces = collections.deque()
+     path = posixpath.normpath(path)
+
+     filepieces = path.split('/')
+     while filepieces:
+         testpath = '/'.join(filepieces)
+         if not testpath:
+             filepieces.pop()
+             continue
+         try:
+             h5file = constructor(testpath, *cargs, **ckwargs)
+         except IOError:
+             objpieces.appendleft(filepieces.pop())
+             continue
+         else:
+             return (h5file, h5file['/'.join([''] + list(objpieces)) if objpieces else '/'])
+     else:
+         # We don't provide a filename, because we're not sure where the filename stops
+         # and the HDF5 path begins.
+         raise IOError(errno.ENOENT, os.strerror(errno.ENOENT))
+
+
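resolve_filepath tries successively shorter filesystem prefixes of the combined path until one opens as an HDF5 file, treating the remainder as an in-file object path. A hedged usage sketch (the file and group names here are hypothetical, not taken from this wheel):

from westpa.core.h5io import resolve_filepath

# Opens 'analysis.h5', then resolves '/iterations/iter_00000001' inside it.
h5file, h5object = resolve_filepath('analysis.h5/iterations/iter_00000001')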
+ def calc_chunksize(shape, dtype, max_chunksize=262144):
+     '''Calculate a chunk size for HDF5 data, anticipating that access will slice
+     along lower dimensions sooner than higher dimensions.'''
+
+     chunk_shape = list(shape)
+     dtype = np.dtype(dtype)
+     for idim in range(len(shape)):
+         chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+         while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
+             chunk_shape[idim] >>= 1  # divide by 2
+             chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+
+         if chunk_nbytes <= max_chunksize:
+             break
+
+     chunk_shape = tuple(chunk_shape)
+     return chunk_shape
+
+
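The heuristic halves the leading (slowest-varying) dimension first, so a by-hand check against the default 262144-byte budget looks like this (a sketch; the numbers just trace the loop above):

import numpy as np
from westpa.core.h5io import calc_chunksize

# A (1000, 100, 3) float64 dataset is ~2.4 MB; the leading axis is halved
# 1000 -> 500 -> 250 -> 125 -> 62 until the chunk fits in 262144 bytes.
assert calc_chunksize((1000, 100, 3), np.float64) == (62, 100, 3)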
+ def tostr(b):
+     '''Convert a nonstandard string object ``b`` to str, handling the
+     case where ``b`` is bytes.'''
+
+     if b is None:
+         return None
+     elif isinstance(b, bytes):
+         return b.decode('utf-8')
+     else:
+         return str(b)
+
+
+ def is_within_directory(directory, target):
+     abs_directory = os.path.abspath(directory)
+     abs_target = os.path.abspath(target)
+
+     prefix = os.path.commonprefix([abs_directory, abs_target])
+
+     return prefix == abs_directory
+
+
+ def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+     for member in tar.getmembers():
+         member_path = os.path.join(path, member.name)
+         if not is_within_directory(path, member_path):
+             raise Exception("Attempted Path Traversal in Tar File")
+
+     tar.extractall(path, members, numeric_owner=numeric_owner)
+
+
+ #
+ # Group and dataset manipulation functions
+ #
+
+
+ def create_hdf5_group(parent_group, groupname, replace=False, creating_program=None):
+     '''Create (or delete and recreate) an HDF5 group named ``groupname`` within
+     the enclosing Group (object) ``parent_group``. If ``replace`` is True, then
+     the group is replaced if present; if False, then an error is raised if the
+     group is present. After the group is created, HDF5 attributes are set using
+     `stamp_creator_data`.
+     '''
+
+     if replace:
+         try:
+             del parent_group[groupname]
+         except KeyError:
+             pass
+
+     newgroup = parent_group.create_group(groupname)
+     stamp_creator_data(newgroup, creating_program)
+     return newgroup
+
+
+ #
+ # Group and dataset labeling functions
+ #
+
+
+ def stamp_creator_data(h5group, creating_program=None):
+     '''Mark the following on the HDF5 group ``h5group``:
+
+     :creation_program:   The name of the program that created the group
+     :creation_user:      The username of the user who created the group
+     :creation_hostname:  The hostname of the machine on which the group was created
+     :creation_time:      The date and time at which the group was created, in the
+                          current locale.
+     :creation_unix_time: The Unix time (seconds from the epoch, UTC) at which the
+                          group was created.
+
+     This is meant to facilitate tracking the flow of data, but should not be considered
+     a secure paper trail (after all, anyone with write access to the HDF5 file can modify
+     these attributes).
+     '''
+     now = time.time()
+     attrs = h5group.attrs
+
+     attrs['creation_program'] = creating_program or sys.argv[0] or 'unknown program'
+     attrs['creation_user'] = getpass.getuser()
+     attrs['creation_hostname'] = socket.gethostname()
+     attrs['creation_unix_time'] = now
+     attrs['creation_time'] = time.strftime('%c', time.localtime(now))
+
+
+ def get_creator_data(h5group):
+     '''Read back creator data as written by ``stamp_creator_data``, returning a dictionary with
+     keys as described for ``stamp_creator_data``. Missing fields are denoted with None.
+     The ``creation_time`` field is returned as a string.'''
+     attrs = h5group.attrs
+     d = dict()
+     for attr in ['creation_program', 'creation_user', 'creation_hostname', 'creation_unix_time', 'creation_time']:
+         d[attr] = attrs.get(attr)
+     return d
+
+
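Stamping and reading back provenance attributes round-trip as below (a minimal sketch; 'example.h5' and 'my_tool' are hypothetical names):

import h5py
from westpa.core.h5io import create_hdf5_group, get_creator_data

with h5py.File('example.h5', 'w') as f:
    grp = create_hdf5_group(f, 'analysis', replace=True, creating_program='my_tool')
    print(get_creator_data(grp)['creation_program'])  # 'my_tool'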
+ def load_west(filename):
+     """Load WESTPA trajectory files from disk.
+
+     Parameters
+     ----------
+     filename : str
+         String filename of HDF Trajectory file.
+     """
+
+     with h5py.File(filename, 'r') as f:
+         iter_group_template = 'iter_{1:0{0}d}'
+         iter_prec = f.attrs['west_iter_prec']
+         trajectories = []
+         n = 0
+
+         iter_group_name = iter_group_template.format(iter_prec, n)
+         for iter_group_name in f['iterations']:
+             iter_group = f['iterations/' + iter_group_name]
+
+             if 'trajectories' in iter_group:
+                 traj_link = iter_group['trajectories']
+                 traj_filename = traj_link.file.filename
+
+                 with WESTIterationFile(traj_filename) as traj_file:
+                     traj = traj_file.read_as_traj()
+             else:
+                 # TODO: [HDF5] allow initializing trajectory without coordinates
+                 raise ValueError("Missing trajectories for iteration %d" % n)
+
+             # pcoord is required
+             if 'pcoord' not in iter_group:
+                 raise ValueError("Missing pcoords for iteration %d" % n)
+
+             raw_pcoord = iter_group['pcoord'][:]
+             if raw_pcoord.ndim != 3:
+                 log.warning('pcoord is expected to be a 3-d ndarray instead of {}-d'.format(raw_pcoord.ndim))
+                 continue
+             # ignore the first frame of each segment
+             if raw_pcoord.shape[1] == traj.n_frames + 1:
+                 raw_pcoord = raw_pcoord[:, 1:, :]
+             elif raw_pcoord.shape[1] == traj.n_frames:
+                 raw_pcoord = raw_pcoord[:, :, :]
+             else:
+                 raise ValueError(
+                     "Inconsistent number of pcoords (%d) and frames (%d) for iteration %d" % (raw_pcoord.shape[1], traj.n_frames, n)
+                 )
+
+             pcoords = np.concatenate(raw_pcoord, axis=0)
+             n_frames = raw_pcoord.shape[1]
+
+             if 'seg_index' in iter_group:
+                 raw_pid = iter_group['seg_index']['parent_id'][:]
+
+                 if np.any(raw_pid < 0):
+                     init_basis_ids = iter_group['ibstates']['istate_index']['basis_state_id'][:]
+                     init_ids = -(raw_pid[raw_pid < 0] + 1)
+                     raw_pid[raw_pid < 0] = [init_basis_ids[iid] for iid in init_ids]
+                 parent_ids = raw_pid.repeat(n_frames, axis=0)
+             else:
+                 parent_ids = None
+
+             traj.pcoords = pcoords
+             traj.parent_ids = parent_ids
+             trajectories.append(traj)
+
+             n += 1
+             iter_group_name = iter_group_template.format(iter_prec, n)
+
+     west_traj = join_traj(trajectories)
+
+     return west_traj
+
+
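A hedged sketch of the entry point above; 'west.h5' is the conventional WESTPA data file name (assumed here), and the file must carry per-iteration trajectory links for this to succeed:

from westpa.core.h5io import load_west

traj = load_west('west.h5')  # mdtraj-style trajectory spanning all iterations
print(traj.n_frames)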
+ ###
+ # Iteration range metadata
+ ###
+ def stamp_iter_range(h5object, start_iter, stop_iter):
+     '''Mark that the HDF5 object ``h5object`` (dataset or group) contains data from iterations
+     start_iter <= n_iter < stop_iter.'''
+     h5object.attrs['iter_start'] = start_iter
+     h5object.attrs['iter_stop'] = stop_iter
+
+
+ def get_iter_range(h5object):
+     '''Read back iteration range data written by ``stamp_iter_range``'''
+     return int(h5object.attrs['iter_start']), int(h5object.attrs['iter_stop'])
+
+
+ def stamp_iter_step(h5group, iter_step):
+     '''Mark that the HDF5 object ``h5group`` (dataset or group) contains data with an
+     iteration step (stride) of ``iter_step``.'''
+     h5group.attrs['iter_step'] = iter_step
+
+
+ def get_iter_step(h5group):
+     '''Read back iteration step (stride) written by ``stamp_iter_step``'''
+     return int(h5group.attrs['iter_step'])
+
+
+ def check_iter_range_least(h5object, iter_start, iter_stop):
+     '''Return True if the iteration range [iter_start, iter_stop) is
+     the same as or entirely contained within the iteration range stored
+     on ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     return obj_iter_start <= iter_start and obj_iter_stop >= iter_stop
+
+
+ def check_iter_range_equal(h5object, iter_start, iter_stop):
+     '''Return True if the iteration range [iter_start, iter_stop) is
+     the same as the iteration range stored on ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     return obj_iter_start == iter_start and obj_iter_stop == iter_stop
+
+
+ def get_iteration_entry(h5object, n_iter):
+     '''Create a slice for data corresponding to iteration ``n_iter`` in ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+     if n_iter < obj_iter_start or n_iter >= obj_iter_stop:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(n_iter, h5object))
+     return np.index_exp[n_iter - obj_iter_start]
+
+
+ def get_iteration_slice(h5object, iter_start, iter_stop=None, iter_stride=None):
+     '''Create a slice for data corresponding to iterations [iter_start, iter_stop),
+     with stride ``iter_stride``, in the given ``h5object``.'''
+     obj_iter_start, obj_iter_stop = get_iter_range(h5object)
+
+     if iter_stop is None:
+         iter_stop = iter_start + 1
+     if iter_stride is None:
+         iter_stride = 1
+
+     if iter_start < obj_iter_start:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
+     elif iter_start > obj_iter_stop:
+         raise IndexError('data for iteration {} not available in dataset {!r}'.format(iter_start, h5object))
+
+     start_index = iter_start - obj_iter_start
+     stop_index = iter_stop - obj_iter_start
+     return np.index_exp[start_index:stop_index:iter_stride]
+
+
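The stamp/slice round trip can be checked against an in-memory file (a sketch; the file and dataset names are illustrative):

import h5py
import numpy as np
from westpa.core.h5io import stamp_iter_range, get_iteration_slice

with h5py.File('scratch.h5', 'w', driver='core', backing_store=False) as f:
    ds = f.create_dataset('rates', data=np.arange(10.0))
    stamp_iter_range(ds, 5, 15)  # rows hold iterations 5..14
    assert get_iteration_slice(ds, 7, 10) == np.index_exp[2:5:1]  # iterations 7..9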
+ ###
+ # Axis label metadata
+ ###
+ def label_axes(h5object, labels, units=None):
+     '''Stamp the given HDF5 object with axis labels. This stores the axis labels
+     in an array of strings in an attribute called ``axis_labels`` on the given
+     object. ``units``, if provided, is a corresponding list of units.'''
+
+     if len(labels) != len(h5object.shape):
+         raise ValueError('number of axes and number of labels do not match')
+
+     if units is None:
+         units = []
+
+     if len(units) and len(units) != len(labels):
+         raise ValueError('number of units labels does not match number of axes')
+
+     h5object.attrs['axis_labels'] = np.array([np.bytes_(i) for i in labels])
+
+     if len(units):
+         h5object.attrs['axis_units'] = np.array([np.bytes_(i) for i in units])
+
+
+ NotGiven = object()
+
+
+ def _get_one_attr(h5object, namelist, default=NotGiven):
+     attrs = dict(h5object.attrs)
+     for name in namelist:
+         try:
+             return attrs[name]
+         except KeyError:
+             pass
+     else:
+         if default is NotGiven:
+             raise KeyError('no such key')
+         else:
+             return default
+
+
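A minimal sketch of axis labeling (dataset name and labels are illustrative):

import h5py
from westpa.core.h5io import label_axes

with h5py.File('scratch.h5', 'w', driver='core', backing_store=False) as f:
    ds = f.create_dataset('pdist', shape=(10, 50), dtype='f8')
    label_axes(ds, ['iteration', 'pcoord bin'], units=['', 'angstrom'])
    print(ds.attrs['axis_labels'])  # [b'iteration' b'pcoord bin']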
+ class WESTPAH5File(h5py.File):
+     '''Generalized input/output for WESTPA simulation (or analysis) data.'''
+
+     default_iter_prec = 8
+     _this_fileformat_version = 8
+
+     def __init__(self, *args, **kwargs):
+         # These values are used for creating files or reading files where this
+         # data is not stored. Otherwise, values stored as attributes on the root
+         # group are used instead.
+         arg_iter_prec = kwargs.pop('westpa_iter_prec', self.default_iter_prec)
+         arg_fileformat_version = kwargs.pop('westpa_fileformat_version', self._this_fileformat_version)
+         arg_creating_program = kwargs.pop('creating_program', None)
+
+         # Initialize h5py file
+         super().__init__(*args, **kwargs)
+
+         # Try to get iteration precision and I/O class version
+         h5file_iter_prec = _get_one_attr(self, ['westpa_iter_prec', 'west_iter_prec', 'wemd_iter_prec'], None)
+         h5file_fileformat_version = _get_one_attr(
+             self, ['westpa_fileformat_version', 'west_file_format_version', 'wemd_file_format_version'], None
+         )
+
+         self.iter_prec = h5file_iter_prec if h5file_iter_prec is not None else arg_iter_prec
+         self.fileformat_version = h5file_fileformat_version if h5file_fileformat_version is not None else arg_fileformat_version
+
+         # Ensure that file format attributes are stored, if the file is writable
+         if self.mode == 'r+':
+             self.attrs['westpa_iter_prec'] = self.iter_prec
+             self.attrs['westpa_fileformat_version'] = self.fileformat_version
+             if arg_creating_program:
+                 stamp_creator_data(self, creating_program=arg_creating_program)
+
+     # Helper method to automatically replace a dataset, if it exists.
+     # Should really only be called when one is certain a dataset should be blown away.
+     def replace_dataset(self, *args, **kwargs):
+         try:
+             del self[args[0]]
+         except Exception:
+             pass
+         try:
+             del self[kwargs['name']]
+         except Exception:
+             pass
+
+         return self.create_dataset(*args, **kwargs)
+
+     # Iteration groups
+
+     def iter_object_name(self, n_iter, prefix='', suffix=''):
+         '''Return a properly-formatted per-iteration name for iteration
+         ``n_iter``. (This is used in create/require/get_iter_group, but may
+         also be useful for naming datasets on a per-iteration basis.)'''
+         return '{prefix}iter_{n_iter:0{prec}d}{suffix}'.format(n_iter=n_iter, prefix=prefix, suffix=suffix, prec=self.iter_prec)
+
+     def create_iter_group(self, n_iter, group=None):
+         '''Create a per-iteration data storage group for iteration number ``n_iter``
+         in the group ``group`` (which is '/iterations' by default).'''
+
+         if group is None:
+             group = self.require_group('/iterations')
+         return group.create_group(self.iter_object_name(n_iter))
+
+     def require_iter_group(self, n_iter, group=None):
+         '''Ensure that a per-iteration data storage group for iteration number ``n_iter``
+         is available in the group ``group`` (which is '/iterations' by default).'''
+         if group is None:
+             group = self.require_group('/iterations')
+         return group.require_group(self.iter_object_name(n_iter))
+
+     def get_iter_group(self, n_iter, group=None):
+         '''Get the per-iteration data group for iteration number ``n_iter`` from within
+         the group ``group`` ('/iterations' by default).'''
+         if group is None:
+             group = self['/iterations']
+         return group[self.iter_object_name(n_iter)]
+
+
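With the default iter_prec of 8, per-iteration groups are named as below (a sketch; 'sim.h5' is a hypothetical file name):

from westpa.core.h5io import WESTPAH5File

with WESTPAH5File('sim.h5', 'w', driver='core', backing_store=False) as f:
    grp = f.create_iter_group(23)
    print(grp.name)  # /iterations/iter_00000023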
+ class WESTIterationFile(HDF5TrajectoryFile):
+     def __init__(self, file, mode='r', force_overwrite=True, compression='zlib', link=None):
+         if isinstance(file, str):
+             super(WESTIterationFile, self).__init__(file, mode, force_overwrite, compression)
+         else:
+             try:
+                 self._init_from_handle(file)
+             except AttributeError:
+                 raise ValueError('unknown input type: %s' % str(type(file)))
+
+     def _init_from_handle(self, handle):
+         self._handle = handle
+         self._open = handle.isopen != 0
+         self.mode = mode = handle.mode  # the mode in which the file was opened
+
+         if mode not in ['r', 'w', 'a']:
+             raise ValueError("mode must be one of ['r', 'w', 'a']")
+
+         # import tables
+         self.tables = import_('tables')
+
+         if mode == 'w':
+             # what frame are we currently reading or writing at?
+             self._frame_index = 0
+             # do we need to write the header information?
+             self._needs_initialization = True
+
+         elif mode == 'a':
+             try:
+                 self._frame_index = len(self._handle.root.coordinates)
+                 self._needs_initialization = False
+             except self.tables.NoSuchNodeError:
+                 self._frame_index = 0
+                 self._needs_initialization = True
+         elif mode == 'r':
+             self._frame_index = 0
+             self._needs_initialization = False
+
+     def read(self, frame_indices=None, atom_indices=None):
+         _check_mode(self.mode, ('r',))
+
+         if frame_indices is None:
+             # select every frame; advance the frame index by the full trajectory length
+             frame_slice = slice(None)
+             self._frame_index += len(self._handle.root.coordinates)
+         else:
+             frame_slice = ensure_type(frame_indices, dtype=int, ndim=1, name='frame_indices', warn_on_cast=False)
+             if not np.all(frame_slice < self._handle.root.coordinates.shape[0]):
+                 raise ValueError(
+                     'As a zero-based index, the entries in '
+                     'frame_indices must all be less than the number of frames '
+                     'in the trajectory, %d' % self._handle.root.coordinates.shape[0]
+                 )
+             if not np.all(frame_slice >= 0):
+                 raise ValueError('The entries in frame_indices must be greater than or equal to zero')
+             self._frame_index += frame_slice[-1] - frame_slice[0]
+
+         if atom_indices is None:
+             # get all of the atoms
+             atom_slice = slice(None)
+         else:
+             atom_slice = ensure_type(atom_indices, dtype=int, ndim=1, name='atom_indices', warn_on_cast=False)
+             if not np.all(atom_slice < self._handle.root.coordinates.shape[1]):
+                 raise ValueError(
+                     'As a zero-based index, the entries in '
+                     'atom_indices must all be less than the number of atoms '
+                     'in the trajectory, %d' % self._handle.root.coordinates.shape[1]
+                 )
+             if not np.all(atom_slice >= 0):
+                 raise ValueError('The entries in atom_indices must be greater than or equal to zero')
+
+         def get_item(node, key):
+             if not isinstance(key, tuple):
+                 return node.__getitem__(key)
+
+             n_list_like = 0
+             new_keys = []
+             for item in key:
+                 if not isinstance(item, slice):
+                     try:
+                         d = np.diff(item)
+                         if len(d) == 0:
+                             item = item[0]
+                         elif np.all(d == d[0]):
+                             item = slice(item[0], item[-1] + d[0], d[0])
+                         else:
+                             n_list_like += 1
+                     except Exception:
+                         n_list_like += 1
+                 new_keys.append(item)
+             new_keys = tuple(new_keys)
+
+             if n_list_like <= 1:
+                 return node.__getitem__(new_keys)
+
+             data = node
+             for i, item in enumerate(new_keys):
+                 dkey = [slice(None)] * len(key)
+                 dkey[i] = item
+                 dkey = tuple(dkey)
+                 data = data.__getitem__(dkey)
+
+             return data
+
+         def get_field(name, slice, out_units, can_be_none=True):
+             try:
+                 node = self._get_node(where='/', name=name)
+                 data = get_item(node, slice)
+                 in_units = node.attrs.units
+                 if not isinstance(in_units, str):
+                     in_units = in_units.decode()
+                 data = in_units_of(data, in_units, out_units)
+                 return data
+             except self.tables.NoSuchNodeError:
+                 if can_be_none:
+                     return None
+                 raise
+
+         frames = Frames(
+             coordinates=get_field('coordinates', (frame_slice, atom_slice, slice(None)), out_units='nanometers', can_be_none=False),
+             time=get_field('time', frame_slice, out_units='picoseconds'),
+             cell_lengths=get_field('cell_lengths', (frame_slice, slice(None)), out_units='nanometers'),
+             cell_angles=get_field('cell_angles', (frame_slice, slice(None)), out_units='degrees'),
+             velocities=get_field('velocities', (frame_slice, atom_slice, slice(None)), out_units='nanometers/picosecond'),
+             kineticEnergy=get_field('kineticEnergy', frame_slice, out_units='kilojoules_per_mole'),
+             potentialEnergy=get_field('potentialEnergy', frame_slice, out_units='kilojoules_per_mole'),
+             temperature=get_field('temperature', frame_slice, out_units='kelvin'),
+             alchemicalLambda=get_field('lambda', frame_slice, out_units='dimensionless'),
+         )
+
+         return frames
+
+     def _has_node(self, where, name):
+         try:
+             self._get_node(where, name=name)
+         except self.tables.NoSuchNodeError:
+             return False
+
+         return True
+
+     def has_topology(self):
+         return self._has_node('/', 'topology')
+
+     def has_pointer(self):
+         return self._has_node('/', 'pointer')
+
+     def has_restart(self, segment):
+         return self._has_node('/restart', '%d_%d' % (segment.n_iter, segment.seg_id))
+
+     def write_data(self, where, name, data):
+         node = self._get_node(where=where, name=name)
+         node.append(data)
+
+     def read_data(self, where, name):
+         node = self._get_node(where=where, name=name)
+         return node.read()
+
+     def read_as_traj(self, iteration=None, segment=None, atom_indices=None):
+         _check_mode(self.mode, ('r',))
+
+         pnode = self._get_node(where='/', name='pointer')
+
+         iter_labels = pnode[:, 0]
+         seg_labels = pnode[:, 1]
+
+         if iteration is None and segment is None:
+             # select every frame; an index array (rather than a bare slice) keeps
+             # the emptiness check and fancy indexing below uniform
+             frame_indices = np.arange(len(iter_labels))
+         elif isinstance(iteration, (np.integer, int)) and isinstance(segment, (np.integer, int)):
+             frame_torf = np.logical_and(iter_labels == iteration, seg_labels == segment)
+             frame_indices = np.arange(len(iter_labels))[frame_torf]
+         else:
+             raise ValueError("iteration and segment must be integers and provided at the same time")
+
+         if len(frame_indices) == 0:
+             raise ValueError(f"no frame was selected: iteration={iteration}, segment={segment}, atom_indices={atom_indices}")
+
+         iter_labels = iter_labels[frame_indices]
+         seg_labels = seg_labels[frame_indices]
+
+         topology = self.topology
+         if atom_indices is not None:
+             topology = topology.subset(atom_indices)
+
+         data = self.read(frame_indices=frame_indices, atom_indices=atom_indices)
+         if len(data) == 0:
+             return Trajectory(xyz=np.zeros((0, topology.n_atoms, 3)), topology=topology)
+
+         in_units_of(data.coordinates, self.distance_unit, Trajectory._distance_unit, inplace=True)
+         in_units_of(data.cell_lengths, self.distance_unit, Trajectory._distance_unit, inplace=True)
+
+         return WESTTrajectory(
+             data.coordinates,
+             topology=topology,
+             time=data.time,
+             unitcell_lengths=data.cell_lengths,
+             unitcell_angles=data.cell_angles,
+             iter_labels=iter_labels,
+             seg_labels=seg_labels,
+             pcoords=None,
+         )
+
+     def read_restart(self, segment):
+         if self.has_restart(segment):
+             data = self.read_data('/restart/%d_%d' % (segment.n_iter, segment.seg_id), 'data')
+             segment.data['iterh5/restart'] = data
+         else:
+             raise ValueError('no restart data available for {}'.format(str(segment)))
+
+     def write_segment(self, segment, pop=False):
+         n_iter = segment.n_iter
+
+         self.root._v_attrs['n_iter'] = n_iter
+
+         if pop:
+             get_data = segment.data.pop
+         else:
+             get_data = segment.data.get
+
+         traj = get_data('iterh5/trajectory', None)
+         restart = get_data('iterh5/restart', None)
+         slog = get_data('iterh5/log', None)
+
+         if traj is not None:
+             # create trajectory object
+             traj = WESTTrajectory(traj, iter_labels=n_iter, seg_labels=segment.seg_id)
+             if traj.n_frames == 0:
+                 # we may consider logging warnings instead of throwing errors later;
+                 # right now this is good for debugging purposes
+                 raise ValueError('no trajectory data present for %s' % repr(segment))
+
+             if n_iter == 0:
+                 base_time = 0
+             else:
+                 iter_duration = traj.time[-1] - traj.time[0]
+                 base_time = iter_duration * (n_iter - 1)
+
+             traj.time -= traj.time[0]
+             traj.time += base_time
+
+             # pointers
+             if not self.has_pointer():
+                 self._create_earray('/', name='pointer', atom=self.tables.Int64Atom(), shape=(0, 2))
+
+             iter_idx = traj.iter_labels
+             seg_idx = traj.seg_labels
+
+             pointers = np.stack((iter_idx, seg_idx)).T
+
+             self.write_data('/', 'pointer', pointers)
+
+             # trajectory
+             self.write(
+                 coordinates=in_units_of(traj.xyz, Trajectory._distance_unit, self.distance_unit),
+                 time=traj.time,
+                 cell_lengths=in_units_of(traj.unitcell_lengths, Trajectory._distance_unit, self.distance_unit),
+                 cell_angles=traj.unitcell_angles,
+             )
+
+             # topology
+             if self.mode == 'a':
+                 if not self.has_topology():
+                     self.topology = traj.topology
+             elif self.mode == 'w':
+                 self.topology = traj.topology
+
+         # restart
+         if restart is not None:
+             if self.has_restart(segment):
+                 self._remove_node('/restart', name='%d_%d' % (segment.n_iter, segment.seg_id), recursive=True)
+
+             self._create_array(
+                 '/restart/%d_%d' % (segment.n_iter, segment.seg_id),
+                 name='data',
+                 atom=self.tables.StringAtom(itemsize=len(restart)),
+                 obj=restart,
+                 createparents=True,
+             )
+
+         if slog is not None:
+             if self._has_node('/log', '%d_%d' % (segment.n_iter, segment.seg_id)):
+                 self._remove_node('/log', name='%d_%d' % (segment.n_iter, segment.seg_id), recursive=True)
+
+             self._create_array(
+                 '/log/%d_%d' % (segment.n_iter, segment.seg_id),
+                 name='data',
+                 atom=self.tables.StringAtom(itemsize=len(slog)),
+                 obj=slog,
+                 createparents=True,
+             )
+
+     @property
+     def _create_group(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.create_group
+         return self._handle.createGroup
+
+     @property
+     def _create_array(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.create_array
+         return self._handle.createArray
+
+     @property
+     def _remove_node(self):
+         if self.tables.__version__ >= '3.0.0':
+             return self._handle.remove_node
+         return self._handle.removeNode
+
+
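Reading one segment back out of a per-iteration trajectory file might look like this (hedged sketch; the file name and indices are hypothetical):

from westpa.core.h5io import WESTIterationFile

with WESTIterationFile('traj_segs.h5', 'r') as traj_file:
    seg = traj_file.read_as_traj(iteration=1, segment=0)
    print(seg.n_frames, seg.iter_labels[:5])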
+ ### Generalized WE dataset access classes
+
+
+ class DSSpec:
+     '''Generalized WE dataset access'''
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         raise NotImplementedError
+
+     def get_segment_data(self, n_iter, seg_id):
+         return self.get_iter_data(n_iter)[seg_id]
+
+     def __getstate__(self):
+         d = dict(self.__dict__)
+         if '_h5file' in d:
+             d['_h5file'] = None
+         return d
+
+     def __setstate__(self, state):
+         self.__dict__.update(state)
+
+
+ class FileLinkedDSSpec(DSSpec):
+     '''Provide facilities for accessing WESTPA HDF5 files, including auto-opening and the ability
+     to pickle references to such files for transmission (through, e.g., the work manager), provided
+     that the HDF5 file can be accessed by the same path on both the sender and receiver.'''
+
+     def __init__(self, h5file_or_name):
+         self._h5file = None
+         self._h5filename = None
+
+         try:
+             self._h5filename = os.path.abspath(h5file_or_name.filename)
+         except AttributeError:
+             self._h5filename = h5file_or_name
+             self._h5file = None
+         else:
+             self._h5file = h5file_or_name
+
+     @property
+     def h5file(self):
+         '''Lazily open HDF5 file. This is required because allowing an open HDF5
+         file to cross a fork() boundary generally corrupts the internal state of
+         the HDF5 library.'''
+         if self._h5file is None:
+             self._h5file = WESTPAH5File(self._h5filename, 'r')
+         return self._h5file
+
+
+ class SingleDSSpec(FileLinkedDSSpec):
+     @classmethod
+     def from_string(cls, dsspec_string, default_h5file):
+         alias = None
+
+         h5file = default_h5file
+         fields = dsspec_string.split(',')
+         dsname = fields[0]
+         slice = None
+
+         for field in (field.strip() for field in fields[1:]):
+             k, v = field.split('=')
+             k = k.lower()
+             if k == 'alias':
+                 alias = v
+             elif k == 'slice':
+                 try:
+                     slice = eval('np.index_exp' + v)
+                 except SyntaxError:
+                     raise SyntaxError('invalid index expression {!r}'.format(v))
+             elif k == 'file':
+                 h5file = v
+             else:
+                 raise ValueError('invalid dataset option {!r}'.format(k))
+
+         return cls(h5file, dsname, alias, slice)
+
+     def __init__(self, h5file_or_name, dsname, alias=None, slice=None):
+         FileLinkedDSSpec.__init__(self, h5file_or_name)
+         self.dsname = dsname
+         self.alias = alias or dsname
+         self.slice = np.index_exp[slice] if slice else None
+
+
+ class SingleIterDSSpec(SingleDSSpec):
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]
+
+
+ class SingleSegmentDSSpec(SingleDSSpec):
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice + index_exp[:] + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_slice]
+
+     def get_segment_data(self, n_iter, seg_id):
+         if self.slice:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][np.index_exp[seg_id, :] + self.slice]
+         else:
+             return self.h5file.get_iter_group(n_iter)[self.dsname][seg_id]
+
+
+ class FnDSSpec(FileLinkedDSSpec):
+     def __init__(self, h5file_or_name, fn):
+         FileLinkedDSSpec.__init__(self, h5file_or_name)
+         self.fn = fn
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         return self.fn(n_iter, self.h5file.get_iter_group(n_iter))[seg_slice]
+
+
+ class MultiDSSpec(DSSpec):
+     def __init__(self, dsspecs):
+         self.dsspecs = dsspecs
+
+     def get_iter_data(self, n_iter, seg_slice=index_exp[:]):
+         datasets = [dsspec.get_iter_data(n_iter) for dsspec in self.dsspecs]
+
+         ncols = 0
+         nsegs = None
+         npts = None
+         for iset, dset in enumerate(datasets):
+             if nsegs is None:
+                 nsegs = dset.shape[0]
+             elif dset.shape[0] != nsegs:
+                 raise TypeError('dataset {} has incorrect first dimension (number of segments)'.format(self.dsspecs[iset]))
+             if npts is None:
+                 npts = dset.shape[1]
+             elif dset.shape[1] != npts:
+                 raise TypeError('dataset {} has incorrect second dimension (number of time points)'.format(self.dsspecs[iset]))
+
+             if dset.ndim < 2:
+                 # scalar per segment or scalar per iteration
+                 raise TypeError('dataset {} has too few dimensions'.format(self.dsspecs[iset]))
+             elif dset.ndim > 3:
+                 # array per timepoint
+                 raise TypeError('dataset {} has too many dimensions'.format(self.dsspecs[iset]))
+             elif dset.ndim == 2:
+                 # scalar per timepoint
+                 ncols += 1
+             else:
+                 # vector per timepoint
+                 ncols += dset.shape[-1]
+
+         output_dtype = np.result_type(*[ds.dtype for ds in datasets])
+         output_array = np.empty((nsegs, npts, ncols), dtype=output_dtype)
+
+         ocol = 0
+         for iset, dset in enumerate(datasets):
+             if dset.ndim == 2:
+                 output_array[:, :, ocol] = dset[...]
+                 ocol += 1
+             elif dset.ndim == 3:
+                 output_array[:, :, ocol : (ocol + dset.shape[-1])] = dset[...]
+                 ocol += dset.shape[-1]
+
+         return output_array[seg_slice]
+
+
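The comma-separated dsspec strings accepted by SingleDSSpec.from_string combine a dataset name with key=value options; a hedged sketch (the file and dataset names are hypothetical):

from westpa.core.h5io import SingleIterDSSpec

dsspec = SingleIterDSSpec.from_string('pcoord,alias=progress', 'west.h5')
data = dsspec.get_iter_data(10)  # full pcoord array for iteration 10

Because __getstate__ drops the open file handle and the h5file property reopens it lazily, such specs can be pickled and shipped through a work manager, as the FileLinkedDSSpec docstring notes.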
+ class IterBlockedDataset:
+     @classmethod
+     def empty_like(cls, blocked_dataset):
+         source = blocked_dataset.data if blocked_dataset.data is not None else blocked_dataset.dataset
+
+         newbds = cls(
+             np.empty(source.shape, source.dtype),
+             attrs={'iter_start': blocked_dataset.iter_start, 'iter_stop': blocked_dataset.iter_stop},
+         )
+         return newbds
+
+     def __init__(self, dataset_or_array, attrs=None):
+         try:
+             dataset_or_array.attrs
+         except AttributeError:
+             self.dataset = None
+             self.data = dataset_or_array
+             if attrs is None:
+                 raise ValueError('attribute dictionary containing iteration bounds must be provided')
+             self.iter_shape = self.data.shape[1:]
+             self.dtype = self.data.dtype
+         else:
+             self.dataset = dataset_or_array
+             attrs = self.dataset.attrs
+             self.data = None
+             self.iter_shape = self.dataset.shape[1:]
+             self.dtype = self.dataset.dtype
+
+         self.iter_start = attrs['iter_start']
+         self.iter_stop = attrs['iter_stop']
+
+     def cache_data(self, max_size=None):
+         '''Cache this dataset in RAM. If ``max_size`` is given, then only cache if the entire dataset
+         fits in ``max_size`` bytes. If ``max_size`` is the string 'available', then only cache if
+         the entire dataset fits in available RAM, as defined by the ``psutil`` module.'''
+
+         if max_size is not None:
+             dssize = self.dtype.itemsize * np.multiply.reduce(self.dataset.shape)
+             if max_size == 'available' and psutil is not None:
+                 avail_bytes = psutil.virtual_memory().available
+                 if dssize > avail_bytes:
+                     return
+             elif isinstance(max_size, str):
+                 return
+             else:
+                 if dssize > max_size:
+                     return
+         if self.dataset is not None:
+             if self.data is None:
+                 self.data = self.dataset[...]
+
+     def drop_cache(self):
+         if self.dataset is not None:
+             del self.data
+             self.data = None
+
+     def iter_entry(self, n_iter):
+         if n_iter < self.iter_start:
+             raise IndexError('requested iteration {} less than first stored iteration {}'.format(n_iter, self.iter_start))
+
+         source = self.data if self.data is not None else self.dataset
+         return source[n_iter - self.iter_start]
+
+     def iter_slice(self, start=None, stop=None):
+         start = start or self.iter_start
+         stop = stop or self.iter_stop
+         step = 1  # strided retrieval not implemented yet
+
+         # if step % self.iter_step > 0:
+         #     raise TypeError('dataset {!r} stored with stride {} cannot be accessed with stride {}'
+         #                     .format(self.dataset, self.iter_step, step))
+         if start < self.iter_start:
+             raise IndexError('requested start {} less than stored start {}'.format(start, self.iter_start))
+         elif stop > self.iter_stop:
+             stop = self.iter_stop
+
+         source = self.data if self.data is not None else self.dataset
+         return source[start - self.iter_start : stop - self.iter_start : step]
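IterBlockedDataset also accepts a plain array plus explicit iteration bounds, which makes its indexing easy to check (a sketch with illustrative values):

import numpy as np
from westpa.core.h5io import IterBlockedDataset

rates = np.random.rand(20, 4)  # 20 iterations of a 4-component observable
bds = IterBlockedDataset(rates, attrs={'iter_start': 1, 'iter_stop': 21})
print(bds.iter_entry(1))            # row 0: first stored iteration
print(bds.iter_slice(5, 10).shape)  # (5, 4)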