westpa 2022.10__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of westpa might be problematic.

Files changed (150)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +746 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +360 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +59 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +596 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +547 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +378 -0
  25. westpa/cli/tools/w_ipa.py +833 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +486 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_trace.py +599 -0
  38. westpa/core/__init__.py +0 -0
  39. westpa/core/_rc.py +673 -0
  40. westpa/core/binning/__init__.py +55 -0
  41. westpa/core/binning/_assign.cpython-312-darwin.so +0 -0
  42. westpa/core/binning/assign.py +449 -0
  43. westpa/core/binning/binless.py +96 -0
  44. westpa/core/binning/binless_driver.py +54 -0
  45. westpa/core/binning/binless_manager.py +190 -0
  46. westpa/core/binning/bins.py +47 -0
  47. westpa/core/binning/mab.py +427 -0
  48. westpa/core/binning/mab_driver.py +54 -0
  49. westpa/core/binning/mab_manager.py +198 -0
  50. westpa/core/data_manager.py +1694 -0
  51. westpa/core/extloader.py +74 -0
  52. westpa/core/h5io.py +995 -0
  53. westpa/core/kinetics/__init__.py +24 -0
  54. westpa/core/kinetics/_kinetics.cpython-312-darwin.so +0 -0
  55. westpa/core/kinetics/events.py +147 -0
  56. westpa/core/kinetics/matrates.py +156 -0
  57. westpa/core/kinetics/rate_averaging.py +266 -0
  58. westpa/core/progress.py +218 -0
  59. westpa/core/propagators/__init__.py +54 -0
  60. westpa/core/propagators/executable.py +715 -0
  61. westpa/core/reweight/__init__.py +14 -0
  62. westpa/core/reweight/_reweight.cpython-312-darwin.so +0 -0
  63. westpa/core/reweight/matrix.py +126 -0
  64. westpa/core/segment.py +119 -0
  65. westpa/core/sim_manager.py +830 -0
  66. westpa/core/states.py +359 -0
  67. westpa/core/systems.py +93 -0
  68. westpa/core/textio.py +74 -0
  69. westpa/core/trajectory.py +330 -0
  70. westpa/core/we_driver.py +908 -0
  71. westpa/core/wm_ops.py +43 -0
  72. westpa/core/yamlcfg.py +391 -0
  73. westpa/fasthist/__init__.py +34 -0
  74. westpa/fasthist/__main__.py +110 -0
  75. westpa/fasthist/_fasthist.cpython-312-darwin.so +0 -0
  76. westpa/mclib/__init__.py +264 -0
  77. westpa/mclib/__main__.py +28 -0
  78. westpa/mclib/_mclib.cpython-312-darwin.so +0 -0
  79. westpa/oldtools/__init__.py +4 -0
  80. westpa/oldtools/aframe/__init__.py +35 -0
  81. westpa/oldtools/aframe/atool.py +75 -0
  82. westpa/oldtools/aframe/base_mixin.py +26 -0
  83. westpa/oldtools/aframe/binning.py +178 -0
  84. westpa/oldtools/aframe/data_reader.py +560 -0
  85. westpa/oldtools/aframe/iter_range.py +200 -0
  86. westpa/oldtools/aframe/kinetics.py +117 -0
  87. westpa/oldtools/aframe/mcbs.py +146 -0
  88. westpa/oldtools/aframe/output.py +39 -0
  89. westpa/oldtools/aframe/plotting.py +90 -0
  90. westpa/oldtools/aframe/trajwalker.py +126 -0
  91. westpa/oldtools/aframe/transitions.py +469 -0
  92. westpa/oldtools/cmds/__init__.py +0 -0
  93. westpa/oldtools/cmds/w_ttimes.py +358 -0
  94. westpa/oldtools/files.py +34 -0
  95. westpa/oldtools/miscfn.py +23 -0
  96. westpa/oldtools/stats/__init__.py +4 -0
  97. westpa/oldtools/stats/accumulator.py +35 -0
  98. westpa/oldtools/stats/edfs.py +129 -0
  99. westpa/oldtools/stats/mcbs.py +89 -0
  100. westpa/tools/__init__.py +33 -0
  101. westpa/tools/binning.py +472 -0
  102. westpa/tools/core.py +340 -0
  103. westpa/tools/data_reader.py +159 -0
  104. westpa/tools/dtypes.py +31 -0
  105. westpa/tools/iter_range.py +198 -0
  106. westpa/tools/kinetics_tool.py +340 -0
  107. westpa/tools/plot.py +283 -0
  108. westpa/tools/progress.py +17 -0
  109. westpa/tools/selected_segs.py +154 -0
  110. westpa/tools/wipi.py +751 -0
  111. westpa/trajtree/__init__.py +4 -0
  112. westpa/trajtree/_trajtree.cpython-312-darwin.so +0 -0
  113. westpa/trajtree/trajtree.py +117 -0
  114. westpa/westext/__init__.py +0 -0
  115. westpa/westext/adaptvoronoi/__init__.py +3 -0
  116. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  117. westpa/westext/hamsm_restarting/__init__.py +3 -0
  118. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  119. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  120. westpa/westext/stringmethod/__init__.py +11 -0
  121. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  122. westpa/westext/stringmethod/string_driver.py +253 -0
  123. westpa/westext/stringmethod/string_method.py +306 -0
  124. westpa/westext/weed/BinCluster.py +180 -0
  125. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  126. westpa/westext/weed/UncertMath.py +247 -0
  127. westpa/westext/weed/__init__.py +10 -0
  128. westpa/westext/weed/weed_driver.py +182 -0
  129. westpa/westext/wess/ProbAdjust.py +101 -0
  130. westpa/westext/wess/__init__.py +6 -0
  131. westpa/westext/wess/wess_driver.py +207 -0
  132. westpa/work_managers/__init__.py +57 -0
  133. westpa/work_managers/core.py +396 -0
  134. westpa/work_managers/environment.py +134 -0
  135. westpa/work_managers/mpi.py +318 -0
  136. westpa/work_managers/processes.py +187 -0
  137. westpa/work_managers/serial.py +28 -0
  138. westpa/work_managers/threads.py +79 -0
  139. westpa/work_managers/zeromq/__init__.py +20 -0
  140. westpa/work_managers/zeromq/core.py +641 -0
  141. westpa/work_managers/zeromq/node.py +131 -0
  142. westpa/work_managers/zeromq/work_manager.py +526 -0
  143. westpa/work_managers/zeromq/worker.py +320 -0
  144. westpa-2022.10.dist-info/AUTHORS +22 -0
  145. westpa-2022.10.dist-info/LICENSE +21 -0
  146. westpa-2022.10.dist-info/METADATA +183 -0
  147. westpa-2022.10.dist-info/RECORD +150 -0
  148. westpa-2022.10.dist-info/WHEEL +5 -0
  149. westpa-2022.10.dist-info/entry_points.txt +29 -0
  150. westpa-2022.10.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1694 @@
1
+ """
2
+ HDF5 data manager for WEST.
3
+
4
+ Original HDF5 implementation: Joseph W. Kaus
5
+ Current implementation: Matthew C. Zwier
6
+
7
+ WEST exclusively uses the cross-platform, self-describing file format HDF5
8
+ for data storage. This ensures that data is stored efficiently and portably
9
+ in a manner that is relatively straightforward for other analysis tools
10
+ (perhaps written in C/C++/Fortran) to access.
11
+
12
+ The data is laid out in HDF5 as follows:
13
+ - summary -- overall summary data for the simulation
14
+ - /iterations/ -- data for individual iterations, one group per iteration under /iterations
15
+ - iter_00000001/ -- data for iteration 1
16
+ - seg_index -- overall information about segments in the iteration, including weight
17
+ - pcoord -- progress coordinate data organized as [seg_id][time][dimension]
18
+ - wtg_parents -- data used to reconstruct the split/merge history of trajectories
19
+ - recycling -- flux and event count for recycled particles, on a per-target-state basis
20
+ - auxdata/ -- auxiliary datasets (data stored on the 'data' field of Segment objects)
21
+
22
+ The file root object has an integer attribute 'west_file_format_version' which can be used to
23
+ determine how to access data even as the file format (i.e. organization of data within HDF5 file)
24
+ evolves.
25
+
26
+ Version history:
27
+ Version 9
28
+ - Basis states are now saved as iter_segid instead of just segid as a pointer label.
29
+ - Initial states are also saved in the iteration 0 file, with a negative sign.
30
+ Version 8
31
+ - Added external links to trajectory files in iterations/iter_* groups, if the HDF5
32
+ framework was used.
33
+ - Added an iter group for the iteration 0 to store conformations of basis states.
34
+ Version 7
35
+ - Removed bin_assignments, bin_populations, and bin_rates from iteration group.
36
+ - Added new_segments subgroup to iteration group
37
+ Version 6
38
+ - ???
39
+ Version 5
40
+ - moved iter_* groups into a top-level iterations/ group,
41
+ - added in-HDF5 storage for basis states, target states, and generated states
42
+ """
43
+
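As a quick orientation to the layout described in the docstring above, the following is a minimal sketch (not part of this wheel) of inspecting a WEST HDF5 file directly with h5py; the file name west.h5 and the eight-digit iteration group names are the defaults assumed by this module:

    import h5py

    # Open the main WEST data file read-only (default name used by WESTDataManager)
    with h5py.File('west.h5', 'r') as f:
        # Root attributes record the file format version and simulation progress
        print('format version:', f.attrs['west_file_format_version'])
        print('current iteration:', f.attrs['west_current_iteration'])

        # One group per iteration lives under /iterations (format version >= 5)
        iter_group = f['iterations/iter_00000001']
        seg_index = iter_group['seg_index'][...]  # weights, parent IDs, status, timings
        pcoord = iter_group['pcoord'][...]        # shape: (n_segments, n_timepoints, pcoord_ndim)
        print('total weight:', seg_index['weight'].sum(), 'pcoord shape:', pcoord.shape)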
44
+ import logging
45
+ import pickle
46
+ import posixpath
47
+ import sys
48
+ import threading
49
+ import time
50
+ import builtins
51
+ from operator import attrgetter
52
+ from os.path import relpath, dirname
53
+
54
+ import h5py
55
+ from h5py import h5s
56
+ import numpy as np
57
+
58
+ from . import h5io
59
+ from .segment import Segment
60
+ from .states import BasisState, TargetState, InitialState
61
+ from .we_driver import NewWeightEntry
62
+ from .propagators.executable import ExecutablePropagator
63
+
64
+ import westpa
65
+
66
+
67
+ log = logging.getLogger(__name__)
68
+
69
+ file_format_version = 9
70
+
71
+ makepath = ExecutablePropagator.makepath
72
+
73
+
74
+ class flushing_lock:
75
+ def __init__(self, lock, fileobj):
76
+ self.lock = lock
77
+ self.fileobj = fileobj
78
+
79
+ def __enter__(self):
80
+ self.lock.acquire()
81
+
82
+ def __exit__(self, exc_type, exc_val, exc_tb):
83
+ self.fileobj.flush()
84
+ self.lock.release()
85
+
86
+
87
+ class expiring_flushing_lock:
88
+ def __init__(self, lock, flush_method, nextsync):
89
+ self.lock = lock
90
+ self.flush_method = flush_method
91
+ self.nextsync = nextsync
92
+
93
+ def __enter__(self):
94
+ self.lock.acquire()
95
+
96
+ def __exit__(self, exc_type, exc_val, exc_tb):
97
+ if time.time() > self.nextsync:
98
+ self.flush_method()
99
+ self.lock.release()
100
+
101
+
102
+ # Data types for use in the HDF5 file
103
+ seg_id_dtype = np.int64 # Up to 9 quintillion segments per iteration; signed so that initial states can be stored negative
104
+ n_iter_dtype = np.uint32 # Up to 4 billion iterations
105
+ weight_dtype = np.float64 # about 15 digits of precision in weights
106
+ utime_dtype = np.float64 # ("u" for Unix time) Up to ~10^300 cpu-seconds
107
+ vstr_dtype = h5py.special_dtype(vlen=str)
108
+ h5ref_dtype = h5py.special_dtype(ref=h5py.Reference)
109
+ binhash_dtype = np.dtype('|S64')
110
+
111
+ # seg_status_dtype = h5py.special_dtype(enum=(np.uint8, Segment.statuses))
112
+ # seg_initpoint_dtype = h5py.special_dtype(enum=(np.uint8, Segment.initpoint_types))
113
+ # seg_endpoint_dtype = h5py.special_dtype(enum=(np.uint8, Segment.endpoint_types))
114
+ # istate_type_dtype = h5py.special_dtype(enum=(np.uint8, InitialState.istate_types))
115
+ # istate_status_dtype = h5py.special_dtype(enum=(np.uint8, InitialState.istate_statuses))
116
+
117
+ seg_status_dtype = np.uint8
118
+ seg_initpoint_dtype = np.uint8
119
+ seg_endpoint_dtype = np.uint8
120
+ istate_type_dtype = np.uint8
121
+ istate_status_dtype = np.uint8
122
+
123
+ summary_table_dtype = np.dtype(
124
+ [
125
+ ('n_particles', seg_id_dtype), # Number of live trajectories in this iteration
126
+ ('norm', weight_dtype), # Norm of probability, to watch for errors or drift
127
+ ('min_bin_prob', weight_dtype), # Per-bin minimum probability
128
+ ('max_bin_prob', weight_dtype), # Per-bin maximum probability
129
+ ('min_seg_prob', weight_dtype), # Per-segment minimum probability
130
+ ('max_seg_prob', weight_dtype), # Per-segment maximum probability
131
+ ('cputime', utime_dtype), # Total CPU time for this iteration
132
+ ('walltime', utime_dtype), # Total wallclock time for this iteration
133
+ ('binhash', binhash_dtype),
134
+ ]
135
+ )
136
+
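A small sketch (not shipped in the package) of scanning the 'summary' dataset, which uses the dtype defined above, to check probability conservation and accumulated wallclock time:

    import h5py
    import numpy as np

    with h5py.File('west.h5', 'r') as f:
        summary = f['summary'][...]  # one row per iteration, dtype == summary_table_dtype
        # 'norm' should stay very close to 1.0; rows for not-yet-completed
        # iterations may still be zero-filled, so inspect only finished rows
        worst_drift = np.abs(summary['norm'] - 1.0).max()
        print('worst norm drift:', worst_drift)
        print('total walltime (s):', summary['walltime'].sum())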
137
+
138
+ # The HDF5 file tracks two distinct, but related, histories:
139
+ # (1) the evolution of the trajectory, which requires only an identifier
140
+ # of where a segment's initial state comes from (the "history graph");
141
+ # this is stored as the parent_id field of the seg index
142
+ # (2) the flow of probability due to splits, merges, and recycling events,
143
+ # which can be thought of as an adjacency list (the "weight graph")
144
+ # segment ID is implied by the row in the index table, and so is not stored
145
+ # initpoint_type remains implicitly stored as negative IDs (if parent_id < 0, then init_state_id = -(parent_id+1))
146
+ seg_index_dtype = np.dtype(
147
+ [
148
+ ('weight', weight_dtype), # Statistical weight of this segment
149
+ ('parent_id', seg_id_dtype), # ID of parent (for trajectory history)
150
+ ('wtg_n_parents', np.uint), # number of parents this segment has in the weight transfer graph
151
+ ('wtg_offset', np.uint), # offset into the weight transfer graph dataset
152
+ ('cputime', utime_dtype), # CPU time used in propagating this segment
153
+ ('walltime', utime_dtype), # Wallclock time used in propagating this segment
154
+ ('endpoint_type', seg_endpoint_dtype), # Endpoint type (will continue, merged, or recycled)
155
+ ('status', seg_status_dtype), # Status of propagation of this segment
156
+ ]
157
+ )
158
+
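The negative-ID convention described in the comments above can be decoded as in this small illustrative helper (not part of the module):

    def decode_parent(parent_id):
        '''Interpret a seg_index parent_id as either a parent segment or an initial state.'''
        if parent_id >= 0:
            return ('segment', parent_id)           # continuation of an existing trajectory
        return ('initial_state', -(parent_id + 1))  # trajectory started from an initial state

    # parent_id -1 refers to initial state 0, -3 to initial state 2, and so on
    assert decode_parent(-1) == ('initial_state', 0)
    assert decode_parent(5) == ('segment', 5)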
159
+ # Index to basis/initial states
160
+ ibstate_index_dtype = np.dtype([('iter_valid', np.uint), ('n_bstates', np.uint), ('group_ref', h5ref_dtype)])
161
+
162
+ # Basis state index type
163
+ bstate_dtype = np.dtype(
164
+ [
165
+ ('label', vstr_dtype), # An optional descriptive label
166
+ ('probability', weight_dtype), # Probability that this state will be selected
167
+ ('auxref', vstr_dtype), # An optional auxiliary data reference
168
+ ]
169
+ )
170
+
171
+ # Even when initial state generation is off and basis states are passed through directly, an initial state entry
172
+ # is created, as that allows precise tracing of the history of a given state in the most complex case of
173
+ # a new initial state for every new trajectory.
174
+ istate_dtype = np.dtype(
175
+ [
176
+ ('iter_created', np.uint), # Iteration during which this state was generated (0 for at w_init)
177
+ ('iter_used', np.uint), # When this state was used to start a new trajectory
178
+ ('basis_state_id', seg_id_dtype), # Which basis state this state was generated from
179
+ ('istate_type', istate_type_dtype), # What type this initial state is (generated or basis)
180
+ ('istate_status', istate_status_dtype), # Whether this initial state is ready to go
181
+ ('basis_auxref', vstr_dtype),
182
+ ]
183
+ )
184
+
185
+ tstate_index_dtype = np.dtype(
186
+ [('iter_valid', np.uint), ('n_states', np.uint), ('group_ref', h5ref_dtype)] # iter_valid: iteration when this state list becomes valid
187
+ ) # group_ref: reference to a group containing further data; this is the
188
+ # null reference if there is no target state for that timeframe.
189
+ tstate_dtype = np.dtype([('label', vstr_dtype)]) # An optional descriptive label for this state
190
+
191
+ # Support for west.we_driver.NewWeightEntry
192
+ nw_source_dtype = np.uint8
193
+ nw_index_dtype = np.dtype(
194
+ [
195
+ ('source_type', nw_source_dtype),
196
+ ('weight', weight_dtype),
197
+ ('prev_seg_id', seg_id_dtype),
198
+ ('target_state_id', seg_id_dtype),
199
+ ('initial_state_id', seg_id_dtype),
200
+ ]
201
+ )
202
+
203
+ # Storage of bin identities
204
+ binning_index_dtype = np.dtype([('hash', binhash_dtype), ('pickle_len', np.uint32)])
205
+
206
+
207
+ class WESTDataManager:
208
+ """Data manager for assisiting the reading and writing of WEST data from/to HDF5 files."""
209
+
210
+ # defaults for various options
211
+ default_iter_prec = 8
212
+ default_we_h5filename = 'west.h5'
213
+ default_we_h5file_driver = None
214
+ default_flush_period = 60
215
+
216
+ # Compress any auxiliary dataset whose total size (across all segments) is more than 1MB
217
+ default_aux_compression_threshold = 1048576
218
+
219
+ # Bin data horizontal (second dimension) chunk size
220
+ binning_hchunksize = 4096
221
+
222
+ # Number of rows to retrieve during a table scan
223
+ table_scan_chunksize = 1024
224
+
225
+ def flushing_lock(self):
226
+ return flushing_lock(self.lock, self.we_h5file)
227
+
228
+ def expiring_flushing_lock(self):
229
+ next_flush = self.last_flush + self.flush_period
230
+ return expiring_flushing_lock(self.lock, self.flush_backing, next_flush)
231
+
232
+ def process_config(self):
233
+ config = self.rc.config
234
+
235
+ for entry, type_ in [('iter_prec', int)]:
236
+ config.require_type_if_present(['west', 'data', entry], type_)
237
+
238
+ self.we_h5filename = config.get_path(['west', 'data', 'west_data_file'], default=self.default_we_h5filename)
239
+ self.we_h5file_driver = config.get_choice(
240
+ ['west', 'data', 'west_data_file_driver'],
241
+ [None, 'sec2', 'family'],
242
+ default=self.default_we_h5file_driver,
243
+ value_transform=(lambda x: x.lower() if x else None),
244
+ )
245
+ self.iter_prec = config.get(['west', 'data', 'iter_prec'], self.default_iter_prec)
246
+ self.aux_compression_threshold = config.get(
247
+ ['west', 'data', 'aux_compression_threshold'], self.default_aux_compression_threshold
248
+ )
249
+ self.flush_period = config.get(['west', 'data', 'flush_period'], self.default_flush_period)
250
+ self.iter_ref_h5_template = config.get(['west', 'data', 'data_refs', 'iteration'], None)
251
+ self.store_h5 = self.iter_ref_h5_template is not None
252
+
253
+ # Process dataset options
254
+ dsopts_list = config.get(['west', 'data', 'datasets']) or []
255
+ for dsopts in dsopts_list:
256
+ dsopts = normalize_dataset_options(dsopts, path_prefix='auxdata' if dsopts['name'] != 'pcoord' else '')
257
+ try:
258
+ self.dataset_options[dsopts['name']].update(dsopts)
259
+ except KeyError:
260
+ self.dataset_options[dsopts['name']] = dsopts
261
+
262
+ if 'pcoord' in self.dataset_options:
263
+ if self.dataset_options['pcoord']['h5path'] != 'pcoord':
264
+ raise ValueError('cannot override pcoord storage location')
265
+
266
+ def __init__(self, rc=None):
267
+ self.rc = rc or westpa.rc
268
+
269
+ self.we_h5filename = self.default_we_h5filename
270
+ self.we_h5file_driver = self.default_we_h5file_driver
271
+ self.we_h5file_version = None
272
+ self.h5_access_mode = 'r+'
273
+ self.iter_prec = self.default_iter_prec
274
+ self.aux_compression_threshold = self.default_aux_compression_threshold
275
+
276
+ self.we_h5file = None
277
+
278
+ self.lock = threading.RLock()
279
+ self.flush_period = None
280
+ self.last_flush = 0
281
+
282
+ self._system = None
283
+ self.iter_ref_h5_template = None
284
+ self.store_h5 = False
285
+
286
+ self.dataset_options = {}
287
+ self.process_config()
288
+
289
+ @property
290
+ def system(self):
291
+ if self._system is None:
292
+ self._system = self.rc.get_system_driver()
293
+ return self._system
294
+
295
+ @system.setter
296
+ def system(self, system):
297
+ self._system = system
298
+
299
+ @property
300
+ def closed(self):
301
+ return self.we_h5file is None
302
+
303
+ def iter_group_name(self, n_iter, absolute=True):
304
+ if absolute:
305
+ return '/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)
306
+ else:
307
+ return 'iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)
308
+
309
+ def require_iter_group(self, n_iter):
310
+ '''Get the group associated with n_iter, creating it if necessary.'''
311
+ with self.lock:
312
+ iter_group = self.we_h5file.require_group('/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec))
313
+ iter_group.attrs['n_iter'] = n_iter
314
+ return iter_group
315
+
316
+ def del_iter_group(self, n_iter):
317
+ with self.lock:
318
+ del self.we_h5file['/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
319
+
320
+ def get_iter_group(self, n_iter):
321
+ with self.lock:
322
+ try:
323
+ return self.we_h5file['/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
324
+ except KeyError:
325
+ return self.we_h5file['/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
326
+
327
+ def get_seg_index(self, n_iter):
328
+ with self.lock:
329
+ seg_index = self.get_iter_group(n_iter)['seg_index']
330
+ return seg_index
331
+
332
+ @property
333
+ def current_iteration(self):
334
+ with self.lock:
335
+ h5file_attrs = self.we_h5file['/'].attrs
336
+ h5file_attr_keys = list(h5file_attrs.keys())
337
+
338
+ if 'west_current_iteration' in h5file_attr_keys:
339
+ return int(self.we_h5file['/'].attrs['west_current_iteration'])
340
+ else:
341
+ return int(self.we_h5file['/'].attrs['wemd_current_iteration'])
342
+
343
+ @current_iteration.setter
344
+ def current_iteration(self, n_iter):
345
+ with self.lock:
346
+ self.we_h5file['/'].attrs['west_current_iteration'] = n_iter
347
+
348
+ def open_backing(self, mode=None):
349
+ '''Open the (already-created) HDF5 file named in self.west_h5filename.'''
350
+ mode = mode or self.h5_access_mode
351
+ if not self.we_h5file:
352
+ log.debug('attempting to open {} with mode {}'.format(self.we_h5filename, mode))
353
+ self.we_h5file = h5io.WESTPAH5File(self.we_h5filename, mode, driver=self.we_h5file_driver)
354
+
355
+ h5file_attrs = self.we_h5file['/'].attrs
356
+ h5file_attr_keys = list(h5file_attrs.keys())
357
+
358
+ if 'west_iter_prec' in h5file_attr_keys:
359
+ self.iter_prec = int(h5file_attrs['west_iter_prec'])
360
+ elif 'wemd_iter_prec' in h5file_attr_keys:
361
+ self.iter_prec = int(h5file_attrs['wemd_iter_prec'])
362
+ else:
363
+ log.info('iteration precision not stored in HDF5; using {:d}'.format(self.iter_prec))
364
+
365
+ if 'west_file_format_version' in h5file_attr_keys:
366
+ self.we_h5file_version = h5file_attrs['west_file_format_version']
367
+ elif 'wemd_file_format_version' in h5file_attr_keys:
368
+ self.we_h5file_version = h5file_attrs['wemd_file_format_version']
369
+ else:
370
+ log.info('WEST HDF5 file format version not stored, assuming 0')
371
+ self.we_h5file_version = 0
372
+
373
+ log.debug('opened WEST HDF5 file version {:d}'.format(self.we_h5file_version))
374
+
375
+ def prepare_backing(self): # istates):
376
+ '''Create new HDF5 file'''
377
+ self.we_h5file = h5py.File(self.we_h5filename, 'w', driver=self.we_h5file_driver)
378
+
379
+ with self.flushing_lock():
380
+ self.we_h5file['/'].attrs['west_file_format_version'] = file_format_version
381
+ self.we_h5file['/'].attrs['west_iter_prec'] = self.iter_prec
382
+ self.we_h5file['/'].attrs['west_version'] = westpa.__version__
383
+ self.current_iteration = 0
384
+ self.we_h5file['/'].create_dataset('summary', shape=(1,), dtype=summary_table_dtype, maxshape=(None,))
385
+ self.we_h5file.create_group('/iterations')
386
+
387
+ def close_backing(self):
388
+ if self.we_h5file is not None:
389
+ with self.lock:
390
+ self.we_h5file.close()
391
+ self.we_h5file = None
392
+
393
+ def flush_backing(self):
394
+ if self.we_h5file is not None:
395
+ with self.lock:
396
+ self.we_h5file.flush()
397
+ self.last_flush = time.time()
398
+
399
+ def save_target_states(self, tstates, n_iter=None):
400
+ '''Save the given target states in the HDF5 file; they will be used for the next iteration to
401
+ be propagated. A complete set is required, even if nominally appending to an existing set,
402
+ which simplifies the mapping of IDs to the table.'''
403
+
404
+ system = self.system
405
+
406
+ n_iter = n_iter or self.current_iteration
407
+
408
+ # Assemble all the important data before we start to modify the HDF5 file
409
+ tstates = list(tstates)
410
+ if tstates:
411
+ state_table = np.empty((len(tstates),), dtype=tstate_dtype)
412
+ state_pcoords = np.empty((len(tstates), system.pcoord_ndim), dtype=system.pcoord_dtype)
413
+ for i, state in enumerate(tstates):
414
+ state.state_id = i
415
+ state_table[i]['label'] = state.label
416
+ state_pcoords[i] = state.pcoord
417
+ else:
418
+ state_table = None
419
+ state_pcoords = None
420
+
421
+ # Commit changes to HDF5
422
+ with self.lock:
423
+ master_group = self.we_h5file.require_group('tstates')
424
+
425
+ try:
426
+ master_index = master_group['index']
427
+ except KeyError:
428
+ master_index = master_group.create_dataset('index', shape=(1,), maxshape=(None,), dtype=tstate_index_dtype)
429
+ n_sets = 1
430
+ else:
431
+ n_sets = len(master_index) + 1
432
+ master_index.resize((n_sets,))
433
+
434
+ set_id = n_sets - 1
435
+ master_index_row = master_index[set_id]
436
+ master_index_row['iter_valid'] = n_iter
437
+ master_index_row['n_states'] = len(tstates)
438
+
439
+ if tstates:
440
+ state_group = master_group.create_group(str(set_id))
441
+ master_index_row['group_ref'] = state_group.ref
442
+ state_group['index'] = state_table
443
+ state_group['pcoord'] = state_pcoords
444
+ else:
445
+ master_index_row['group_ref'] = None
446
+
447
+ master_index[set_id] = master_index_row
448
+
449
+ def _find_multi_iter_group(self, n_iter, master_group_name):
450
+ with self.lock:
451
+ master_group = self.we_h5file[master_group_name]
452
+ master_index = master_group['index'][...]
453
+ set_id = np.digitize([n_iter], master_index['iter_valid']) - 1
454
+ group_ref = master_index[set_id]['group_ref']
455
+
456
+ # Check if reference is Null
457
+ if not bool(group_ref):
458
+ return None
459
+
460
+ # This extra [0] is to work around a bug in h5py
461
+ try:
462
+ group = self.we_h5file[group_ref]
463
+ except (TypeError, AttributeError):
464
+ group = self.we_h5file[group_ref[0]]
465
+ else:
466
+ log.debug('h5py fixed; remove alternate code path')
467
+ log.debug('reference {!r} points to group {!r}'.format(group_ref, group))
468
+ return group
469
+
470
+ def find_tstate_group(self, n_iter):
471
+ return self._find_multi_iter_group(n_iter, 'tstates')
472
+
473
+ def find_ibstate_group(self, n_iter):
474
+ return self._find_multi_iter_group(n_iter, 'ibstates')
475
+
476
+ def get_target_states(self, n_iter):
477
+ '''Return a list of Target objects representing the target (sink) states that are in use for iteration n_iter.
478
+ Future iterations are assumed to continue from the most recent set of states.'''
479
+
480
+ with self.lock:
481
+ tstate_group = self.find_tstate_group(n_iter)
482
+
483
+ if tstate_group is not None:
484
+ tstate_index = tstate_group['index'][...]
485
+ tstate_pcoords = tstate_group['pcoord'][...]
486
+
487
+ tstates = [
488
+ TargetState(state_id=i, label=h5io.tostr(row['label']), pcoord=pcoord.copy())
489
+ for (i, (row, pcoord)) in enumerate(zip(tstate_index, tstate_pcoords))
490
+ ]
491
+ else:
492
+ tstates = []
493
+
494
+ return tstates
495
+
496
+ def create_ibstate_group(self, basis_states, n_iter=None):
497
+ '''Create the group used to store basis states and initial states (whose definitions are always
498
+ coupled). This group is hard-linked into all iteration groups that use these basis and
499
+ initial states.'''
500
+
501
+ with self.lock:
502
+ n_iter = n_iter or self.current_iteration
503
+ master_group = self.we_h5file.require_group('ibstates')
504
+
505
+ try:
506
+ master_index = master_group['index']
507
+ except KeyError:
508
+ master_index = master_group.create_dataset('index', dtype=ibstate_index_dtype, shape=(1,), maxshape=(None,))
509
+ n_sets = 1
510
+ else:
511
+ n_sets = len(master_index) + 1
512
+ master_index.resize((n_sets,))
513
+
514
+ set_id = n_sets - 1
515
+ master_index_row = master_index[set_id]
516
+ master_index_row['iter_valid'] = n_iter
517
+ master_index_row['n_bstates'] = len(basis_states)
518
+ state_group = master_group.create_group(str(set_id))
519
+ master_index_row['group_ref'] = state_group.ref
520
+
521
+ if basis_states:
522
+ system = self.system
523
+ state_table = np.empty((len(basis_states),), dtype=bstate_dtype)
524
+ state_pcoords = np.empty((len(basis_states), system.pcoord_ndim), dtype=system.pcoord_dtype)
525
+ for i, state in enumerate(basis_states):
526
+ state.state_id = i
527
+ state_table[i]['label'] = state.label
528
+ state_table[i]['probability'] = state.probability
529
+ state_table[i]['auxref'] = state.auxref or ''
530
+ state_pcoords[i] = state.pcoord
531
+
532
+ state_group['bstate_index'] = state_table
533
+ state_group['bstate_pcoord'] = state_pcoords
534
+
535
+ master_index[set_id] = master_index_row
536
+ return state_group
537
+
538
+ def create_ibstate_iter_h5file(self, basis_states):
539
+ '''Create the per-iteration HDF5 file for the basis states (i.e., iteration 0).
540
+ This special treatment is needed so that the analysis tools can access basis states
541
+ more easily.'''
542
+
543
+ if not self.store_h5:
544
+ return
545
+
546
+ segments = []
547
+ for i, state in enumerate(basis_states):
548
+ dummy_segment = Segment(
549
+ n_iter=0,
550
+ seg_id=state.state_id,
551
+ parent_id=-(state.state_id + 1),
552
+ weight=state.probability,
553
+ wtg_parent_ids=None,
554
+ pcoord=state.pcoord,
555
+ status=Segment.SEG_STATUS_UNSET,
556
+ data=state.data,
557
+ )
558
+ segments.append(dummy_segment)
559
+
560
+ # link the iteration file in west.h5
561
+ self.prepare_iteration(0, segments)
562
+ self.update_iter_h5file(0, segments)
563
+
564
+ def update_iter_h5file(self, n_iter, segments):
565
+ '''Write out the per-iteration HDF5 file with given segments and add an external link to it
566
+ in the main HDF5 file (west.h5) if the link is not present.'''
567
+
568
+ if not self.store_h5:
569
+ return
570
+
571
+ west_h5_file = makepath(self.we_h5filename)
572
+ iter_ref_h5_file = makepath(self.iter_ref_h5_template, {'n_iter': n_iter})
573
+ iter_ref_rel_path = relpath(iter_ref_h5_file, dirname(west_h5_file))
574
+
575
+ with h5io.WESTIterationFile(iter_ref_h5_file, 'a') as outf:
576
+ for segment in segments:
577
+ outf.write_segment(segment, True)
578
+
579
+ iter_group = self.get_iter_group(n_iter)
580
+
581
+ if 'trajectories' not in iter_group:
582
+ iter_group['trajectories'] = h5py.ExternalLink(iter_ref_rel_path, '/')
583
+
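To illustrate the link written above (a sketch, assuming per-iteration HDF5 storage is enabled via the west.data.data_refs.iteration template), the 'trajectories' entry can be followed transparently from the main file:

    import h5py

    with h5py.File('west.h5', 'r') as f:
        iter_group = f['iterations/iter_00000002']
        if 'trajectories' in iter_group:
            # h5py resolves the ExternalLink on access, opening the per-iteration file
            traj_root = iter_group['trajectories']
            print(list(traj_root.keys()))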
584
+ def get_basis_states(self, n_iter=None):
585
+ '''Return a list of BasisState objects representing the basis states that are in use for iteration n_iter.'''
586
+
587
+ with self.lock:
588
+ n_iter = n_iter or self.current_iteration
589
+ ibstate_group = self.find_ibstate_group(n_iter)
590
+ try:
591
+ bstate_index = ibstate_group['bstate_index'][...]
592
+ except KeyError:
593
+ return []
594
+ bstate_pcoords = ibstate_group['bstate_pcoord'][...]
595
+ bstates = [
596
+ BasisState(
597
+ state_id=i,
598
+ label=h5io.tostr(row['label']),
599
+ probability=row['probability'],
600
+ auxref=h5io.tostr(row['auxref']) or None,
601
+ pcoord=pcoord.copy(),
602
+ )
603
+ for (i, (row, pcoord)) in enumerate(zip(bstate_index, bstate_pcoords))
604
+ ]
605
+
606
+ bstate_total_prob = sum(bstate.probability for bstate in bstates)
607
+
608
+ # This should run once in the second iteration, and only if start-states are specified,
609
+ # but is necessary to re-normalize (i.e. normalize without start-state probabilities included)
610
+ for i, bstate in enumerate(bstates):
611
+ bstate.probability /= bstate_total_prob
612
+ bstates[i] = bstate
613
+ return bstates
614
+
615
+ def create_initial_states(self, n_states, n_iter=None):
616
+ '''Create storage for ``n_states`` initial states associated with iteration ``n_iter``, and
617
+ return bare InitialState objects with only state_id set.'''
618
+
619
+ system = self.system
620
+ with self.lock:
621
+ n_iter = n_iter or self.current_iteration
622
+ ibstate_group = self.find_ibstate_group(n_iter)
623
+
624
+ try:
625
+ istate_index = ibstate_group['istate_index']
626
+ except KeyError:
627
+ istate_index = ibstate_group.create_dataset('istate_index', dtype=istate_dtype, shape=(n_states,), maxshape=(None,))
628
+ istate_pcoords = ibstate_group.create_dataset(
629
+ 'istate_pcoord',
630
+ dtype=system.pcoord_dtype,
631
+ shape=(n_states, system.pcoord_ndim),
632
+ maxshape=(None, system.pcoord_ndim),
633
+ )
634
+ len_index = len(istate_index)
635
+ first_id = 0
636
+ else:
637
+ first_id = len(istate_index)
638
+ len_index = len(istate_index) + n_states
639
+ istate_index.resize((len_index,))
640
+ istate_pcoords = ibstate_group['istate_pcoord']
641
+ istate_pcoords.resize((len_index, system.pcoord_ndim))
642
+
643
+ index_entries = istate_index[first_id:len_index]
644
+ new_istates = []
645
+ for irow, row in enumerate(index_entries):
646
+ row['iter_created'] = n_iter
647
+ row['istate_status'] = InitialState.ISTATE_STATUS_PENDING
648
+ new_istates.append(
649
+ InitialState(
650
+ state_id=first_id + irow,
651
+ basis_state_id=None,
652
+ iter_created=n_iter,
653
+ istate_status=InitialState.ISTATE_STATUS_PENDING,
654
+ )
655
+ )
656
+ istate_index[first_id:len_index] = index_entries
657
+ return new_istates
658
+
659
+ def update_initial_states(self, initial_states, n_iter=None):
660
+ '''Save the given initial states in the HDF5 file'''
661
+
662
+ system = self.system
663
+ initial_states = sorted(initial_states, key=attrgetter('state_id'))
664
+ if not initial_states:
665
+ return
666
+
667
+ with self.lock:
668
+ n_iter = n_iter or self.current_iteration
669
+ ibstate_group = self.find_ibstate_group(n_iter)
670
+ state_ids = [state.state_id for state in initial_states]
671
+ index_entries = ibstate_group['istate_index'][state_ids]
672
+ pcoord_vals = np.empty((len(initial_states), system.pcoord_ndim), dtype=system.pcoord_dtype)
673
+ for i, initial_state in enumerate(initial_states):
674
+ index_entries[i]['iter_created'] = initial_state.iter_created
675
+ index_entries[i]['iter_used'] = initial_state.iter_used or InitialState.ISTATE_UNUSED
676
+ index_entries[i]['basis_state_id'] = (
677
+ initial_state.basis_state_id if initial_state.basis_state_id is not None else -1
678
+ )
679
+ index_entries[i]['istate_type'] = initial_state.istate_type or InitialState.ISTATE_TYPE_UNSET
680
+ index_entries[i]['istate_status'] = initial_state.istate_status or InitialState.ISTATE_STATUS_PENDING
681
+ pcoord_vals[i] = initial_state.pcoord
682
+
683
+ index_entries[i]['basis_auxref'] = initial_state.basis_auxref or ""
684
+
685
+ ibstate_group['istate_index'][state_ids] = index_entries
686
+ ibstate_group['istate_pcoord'][state_ids] = pcoord_vals
687
+
688
+ if self.store_h5:
689
+ segments = []
690
+ for i, state in enumerate(initial_states):
691
+ dummy_segment = Segment(
692
+ n_iter=-state.iter_created,
693
+ seg_id=state.state_id,
694
+ parent_id=state.basis_state_id,
695
+ wtg_parent_ids=None,
696
+ pcoord=state.pcoord,
697
+ status=Segment.SEG_STATUS_PREPARED,
698
+ data=state.data,
699
+ )
700
+ segments.append(dummy_segment)
701
+ self.update_iter_h5file(0, segments)
702
+
703
+ def get_initial_states(self, n_iter=None):
704
+ states = []
705
+ with self.lock:
706
+ n_iter = n_iter or self.current_iteration
707
+ ibstate_group = self.find_ibstate_group(n_iter)
708
+ try:
709
+ istate_index = ibstate_group['istate_index'][...]
710
+ except KeyError:
711
+ return []
712
+ istate_pcoords = ibstate_group['istate_pcoord'][...]
713
+
714
+ for state_id, (state, pcoord) in enumerate(zip(istate_index, istate_pcoords)):
715
+ states.append(
716
+ InitialState(
717
+ state_id=state_id,
718
+ basis_state_id=int(state['basis_state_id']),
719
+ iter_created=int(state['iter_created']),
720
+ iter_used=int(state['iter_used']),
721
+ istate_type=int(state['istate_type']),
722
+ basis_auxref=h5io.tostr(state['basis_auxref']),
723
+ pcoord=pcoord.copy(),
724
+ )
725
+ )
726
+ return states
727
+
728
+ def get_segment_initial_states(self, segments, n_iter=None):
729
+ '''Retrieve all initial states referenced by the given segments.'''
730
+
731
+ with self.lock:
732
+ n_iter = n_iter or self.current_iteration
733
+ ibstate_group = self.get_iter_group(n_iter)['ibstates']
734
+
735
+ istate_ids = {-int(segment.parent_id + 1) for segment in segments if segment.parent_id < 0}
736
+ sorted_istate_ids = sorted(istate_ids)
737
+ if not sorted_istate_ids:
738
+ return []
739
+
740
+ istate_rows = ibstate_group['istate_index'][sorted_istate_ids][...]
741
+ istate_pcoords = ibstate_group['istate_pcoord'][sorted_istate_ids][...]
742
+ istates = []
743
+
744
+ for state_id, state, pcoord in zip(sorted_istate_ids, istate_rows, istate_pcoords):
745
+ try:
746
+ b_auxref = h5io.tostr(state['basis_auxref'])
747
+ except ValueError:
748
+ b_auxref = ''
749
+ istate = InitialState(
750
+ state_id=state_id,
751
+ basis_state_id=int(state['basis_state_id']),
752
+ iter_created=int(state['iter_created']),
753
+ iter_used=int(state['iter_used']),
754
+ istate_type=int(state['istate_type']),
755
+ basis_auxref=b_auxref,
756
+ pcoord=pcoord.copy(),
757
+ )
758
+ istates.append(istate)
759
+ return istates
760
+
761
+ def get_unused_initial_states(self, n_states=None, n_iter=None):
762
+ '''Retrieve any prepared but unused initial states applicable to the given iteration.
763
+ Up to ``n_states`` states are returned; if ``n_states`` is None, then all unused states
764
+ are returned.'''
765
+
766
+ n_states = n_states or sys.maxsize
767
+ ISTATE_UNUSED = InitialState.ISTATE_UNUSED
768
+ ISTATE_STATUS_PREPARED = InitialState.ISTATE_STATUS_PREPARED
769
+ with self.lock:
770
+ n_iter = n_iter or self.current_iteration
771
+ ibstate_group = self.find_ibstate_group(n_iter)
772
+ istate_index = ibstate_group['istate_index']
773
+ istate_pcoords = ibstate_group['istate_pcoord']
774
+ n_index_entries = istate_index.len()
775
+ chunksize = self.table_scan_chunksize
776
+
777
+ states = []
778
+ istart = 0
779
+ while istart < n_index_entries and len(states) < n_states:
780
+ istop = min(istart + chunksize, n_index_entries)
781
+ istate_chunk = istate_index[istart:istop]
782
+ pcoord_chunk = istate_pcoords[istart:istop]
783
+ # state_ids = np.arange(istart,istop,dtype=np.uint)
784
+
785
+ for ci in range(len(istate_chunk)):
786
+ row = istate_chunk[ci]
787
+ pcoord = pcoord_chunk[ci]
788
+ state_id = istart + ci
789
+ if row['iter_used'] == ISTATE_UNUSED and row['istate_status'] == ISTATE_STATUS_PREPARED:
790
+ istate = InitialState(
791
+ state_id=state_id,
792
+ basis_state_id=int(row['basis_state_id']),
793
+ iter_created=int(row['iter_created']),
794
+ iter_used=0,
795
+ istate_type=int(row['istate_type']),
796
+ pcoord=pcoord.copy(),
797
+ istate_status=ISTATE_STATUS_PREPARED,
798
+ )
799
+ states.append(istate)
800
+ del row, pcoord, state_id
801
+ istart += chunksize
802
+ del istate_chunk, pcoord_chunk # , state_ids, unused, ids_of_unused
803
+ log.debug('found {:d} unused states'.format(len(states)))
804
+ return states[:n_states]
805
+
806
+ def prepare_iteration(self, n_iter, segments):
807
+ """Prepare for a new iteration by creating space to store the new iteration's data.
808
+ The number of segments, their IDs, and their lineage must be determined and included
809
+ in the set of segments passed in."""
810
+
811
+ log.debug('preparing HDF5 group for iteration %d (%d segments)' % (n_iter, len(segments)))
812
+
813
+ # Ensure we have a list for guaranteed ordering
814
+ init = n_iter == 0
815
+ segments = list(segments)
816
+ n_particles = len(segments)
817
+ system = self.system
818
+ pcoord_ndim = system.pcoord_ndim
819
+ pcoord_len = 2 if init else system.pcoord_len
820
+ pcoord_dtype = system.pcoord_dtype
821
+
822
+ with self.lock:
823
+ if not init:
824
+ # Create a table of summary information about each iteration
825
+ summary_table = self.we_h5file['summary']
826
+ if len(summary_table) < n_iter:
827
+ summary_table.resize((n_iter + 1,))
828
+
829
+ iter_group = self.require_iter_group(n_iter)
830
+
831
+ for linkname in ('seg_index', 'pcoord', 'wtgraph'):
832
+ try:
833
+ del iter_group[linkname]
834
+ except KeyError:
835
+ pass
836
+
837
+ # everything indexed by [particle] goes in an index table
838
+ seg_index_table_ds = iter_group.create_dataset('seg_index', shape=(n_particles,), dtype=seg_index_dtype)
839
+ # unfortunately, h5py doesn't like in-place modification of individual fields; it expects
840
+ # tuples. So, construct everything in a numpy array and then dump the whole thing into hdf5
841
+ # In fact, this appears to be an h5py best practice (collect as much in ram as possible and then dump)
842
+ seg_index_table = seg_index_table_ds[...]
843
+
844
+ if not init:
845
+ summary_row = np.zeros((1,), dtype=summary_table_dtype)
846
+ summary_row['n_particles'] = n_particles
847
+ summary_row['norm'] = np.add.reduce(list(map(attrgetter('weight'), segments)))
848
+ summary_table[n_iter - 1] = summary_row
849
+
850
+ # pcoord is indexed as [particle, time, dimension]
851
+ pcoord_opts = self.dataset_options.get('pcoord', {'name': 'pcoord', 'h5path': 'pcoord', 'compression': False})
852
+ shape = (n_particles, pcoord_len, pcoord_ndim)
853
+ pcoord_ds = create_dataset_from_dsopts(iter_group, pcoord_opts, shape, pcoord_dtype)
854
+ pcoord = np.empty((n_particles, pcoord_len, pcoord_ndim), pcoord_dtype)
855
+
856
+ total_parents = 0
857
+ for seg_id, segment in enumerate(segments):
858
+ if segment.seg_id is not None:
859
+ assert segment.seg_id == seg_id
860
+ else:
861
+ segment.seg_id = seg_id
862
+ # Parent must be set, though what it means depends on initpoint_type
863
+ assert segment.parent_id is not None
864
+ segment.seg_id = seg_id
865
+ seg_index_table[seg_id]['status'] = segment.status
866
+ seg_index_table[seg_id]['weight'] = segment.weight
867
+ seg_index_table[seg_id]['parent_id'] = segment.parent_id
868
+ seg_index_table[seg_id]['wtg_n_parents'] = len(segment.wtg_parent_ids)
869
+ seg_index_table[seg_id]['wtg_offset'] = total_parents
870
+ total_parents += len(segment.wtg_parent_ids)
871
+
872
+ # Assign progress coordinate if any exists
873
+ if segment.pcoord is not None:
874
+ if init:
875
+ if segment.pcoord.shape != pcoord.shape[2:]:
876
+ raise ValueError(
877
+ 'basis state pcoord shape [%r] does not match expected shape [%r]'
878
+ % (segment.pcoord.shape, pcoord.shape[2:])
879
+ )
880
+ # Initial pcoord
881
+ pcoord[seg_id, 1, :] = segment.pcoord[:]
882
+ else:
883
+ if len(segment.pcoord) == 1:
884
+ # Initial pcoord
885
+ pcoord[seg_id, 0, :] = segment.pcoord[0, :]
886
+ elif segment.pcoord.shape != pcoord.shape[1:]:
887
+ raise ValueError(
888
+ 'segment pcoord shape [%r] does not match expected shape [%r]'
889
+ % (segment.pcoord.shape, pcoord.shape[1:])
890
+ )
891
+ else:
892
+ pcoord[seg_id, ...] = segment.pcoord
893
+
894
+ if total_parents > 0:
895
+ wtgraph_ds = iter_group.create_dataset('wtgraph', (total_parents,), seg_id_dtype, compression='gzip', shuffle=True)
896
+ parents = np.empty((total_parents,), seg_id_dtype)
897
+
898
+ for seg_id, segment in enumerate(segments):
899
+ offset = seg_index_table[seg_id]['wtg_offset']
900
+ extent = seg_index_table[seg_id]['wtg_n_parents']
901
+ parent_list = list(segment.wtg_parent_ids)
902
+ parents[offset : offset + extent] = parent_list[:]
903
+
904
+ assert set(parents[offset : offset + extent]) == set(segment.wtg_parent_ids)
905
+
906
+ wtgraph_ds[:] = parents
907
+
908
+ # Create convenient hard links
909
+ self.update_iter_group_links(n_iter)
910
+
911
+ # Since we accumulated many of these changes in RAM (and not directly in HDF5), propagate
912
+ # the changes out to HDF5
913
+ seg_index_table_ds[:] = seg_index_table
914
+ pcoord_ds[...] = pcoord
915
+
916
+ def update_iter_group_links(self, n_iter):
917
+ '''Update the per-iteration hard links pointing to the tables of target and initial/basis states for the
918
+ given iteration. These links are not used by this class, but are remarkably convenient for third-party
919
+ analysis tools and hdfview.'''
920
+
921
+ with self.lock:
922
+ iter_group = self.require_iter_group(n_iter)
923
+
924
+ for linkname in ('ibstates', 'tstates'):
925
+ try:
926
+ del iter_group[linkname]
927
+ except KeyError:
928
+ pass
929
+
930
+ iter_group['ibstates'] = self.find_ibstate_group(n_iter)
931
+
932
+ tstate_group = self.find_tstate_group(n_iter)
933
+ if tstate_group is not None:
934
+ iter_group['tstates'] = tstate_group
935
+
936
+ def get_iter_summary(self, n_iter=None):
937
+ n_iter = n_iter or self.current_iteration
938
+ with self.lock:
939
+ return self.we_h5file['summary'][n_iter - 1]
940
+
941
+ def update_iter_summary(self, summary, n_iter=None):
942
+ n_iter = n_iter or self.current_iteration
943
+ with self.lock:
944
+ self.we_h5file['summary'][n_iter - 1] = summary
945
+
946
+ def del_iter_summary(self, min_iter): # delete the iterations starting at min_iter
947
+ with self.lock:
948
+ self.we_h5file['summary'].resize((min_iter - 1,))
949
+
950
+ def update_segments(self, n_iter, segments):
951
+ '''Update segment information in the HDF5 file; all prior information for each
952
+ ``segment`` is overwritten, except for parent and weight transfer information.'''
953
+
954
+ segments = sorted(segments, key=attrgetter('seg_id'))
955
+
956
+ with self.lock:
957
+ iter_group = self.get_iter_group(n_iter)
958
+
959
+ pc_dsid = iter_group['pcoord'].id
960
+ si_dsid = iter_group['seg_index'].id
961
+
962
+ seg_ids = [segment.seg_id for segment in segments]
963
+ n_segments = len(segments)
964
+ n_total_segments = si_dsid.shape[0]
965
+ system = self.system
966
+ pcoord_ndim = system.pcoord_ndim
967
+ pcoord_len = system.pcoord_len
968
+ pcoord_dtype = system.pcoord_dtype
969
+
970
+ seg_index_entries = np.empty((n_segments,), dtype=seg_index_dtype)
971
+ pcoord_entries = np.empty((n_segments, pcoord_len, pcoord_ndim), dtype=pcoord_dtype)
972
+
973
+ pc_msel = h5s.create_simple(pcoord_entries.shape, (h5s.UNLIMITED,) * pcoord_entries.ndim)
974
+ pc_msel.select_all()
975
+ si_msel = h5s.create_simple(seg_index_entries.shape, (h5s.UNLIMITED,))
976
+ si_msel.select_all()
977
+ pc_fsel = pc_dsid.get_space()
978
+ si_fsel = si_dsid.get_space()
979
+
980
+ for iseg in range(n_segments):
981
+ seg_id = seg_ids[iseg]
982
+ op = h5s.SELECT_OR if iseg != 0 else h5s.SELECT_SET
983
+ si_fsel.select_hyperslab((seg_id,), (1,), op=op)
984
+ pc_fsel.select_hyperslab((seg_id, 0, 0), (1, pcoord_len, pcoord_ndim), op=op)
985
+
986
+ # read the existing seg_index data so that we have valid parent and weight transfer information
987
+ si_dsid.read(si_msel, si_fsel, seg_index_entries)
988
+
989
+ for iseg, (segment, ientry) in enumerate(zip(segments, seg_index_entries)):
990
+ ientry['status'] = segment.status
991
+ ientry['endpoint_type'] = segment.endpoint_type or Segment.SEG_ENDPOINT_UNSET
992
+ ientry['cputime'] = segment.cputime
993
+ ientry['walltime'] = segment.walltime
994
+ ientry['weight'] = segment.weight
995
+
996
+ pcoord_entries[iseg] = segment.pcoord
997
+
998
+ # write progress coordinates and index using low level HDF5 functions for efficiency
999
+ si_dsid.write(si_msel, si_fsel, seg_index_entries)
1000
+ pc_dsid.write(pc_msel, pc_fsel, pcoord_entries)
1001
+
1002
+ # Now, to deal with auxiliary data
1003
+ # If any segment has any auxiliary data, then the aux dataset must spring into
1004
+ # existence. Each is named according to the name in segment.data, and has shape
1005
+ # (n_total_segs, ...) where the ... is the shape of the data in segment.data (and may be empty
1006
+ # in the case of scalar data) and dtype is taken from the data type of the data entry
1007
+ # compression is on by default for datasets that will be more than 1MiB
1008
+
1009
+ # a mapping of data set name to (per-segment shape, data type) tuples
1010
+ dsets = {}
1011
+
1012
+ # First we scan for presence, shape, and data type of auxiliary data sets
1013
+ for segment in segments:
1014
+ if segment.data:
1015
+ for dsname in segment.data:
1016
+ if dsname.startswith('iterh5/'):
1017
+ continue
1018
+ data = np.asarray(segment.data[dsname], order='C')
1019
+ segment.data[dsname] = data
1020
+ dsets[dsname] = (data.shape, data.dtype)
1021
+
1022
+ # Then we iterate over data sets and store data
1023
+ if dsets:
1024
+ for dsname, (shape, dtype) in dsets.items():
1025
+ # dset = self._require_aux_dataset(iter_group, dsname, n_total_segments, shape, dtype)
1026
+ try:
1027
+ dsopts = self.dataset_options[dsname]
1028
+ except KeyError:
1029
+ dsopts = normalize_dataset_options({'name': dsname}, path_prefix='auxdata')
1030
+
1031
+ shape = (n_total_segments,) + shape
1032
+ dset = require_dataset_from_dsopts(
1033
+ iter_group, dsopts, shape, dtype, autocompress_threshold=self.aux_compression_threshold, n_iter=n_iter
1034
+ )
1035
+ if dset is None:
1036
+ # storage is suppressed
1037
+ continue
1038
+ for segment in segments:
1039
+ try:
1040
+ auxdataset = segment.data[dsname]
1041
+ except KeyError:
1042
+ pass
1043
+ else:
1044
+ source_rank = len(auxdataset.shape)
1045
+ source_sel = h5s.create_simple(auxdataset.shape, (h5s.UNLIMITED,) * source_rank)
1046
+ source_sel.select_all()
1047
+ dest_sel = dset.id.get_space()
1048
+ dest_sel.select_hyperslab((segment.seg_id,) + (0,) * source_rank, (1,) + auxdataset.shape)
1049
+ dset.id.write(source_sel, dest_sel, auxdataset)
1050
+ if 'delram' in list(dsopts.keys()):
1051
+ del dsets[dsname]
1052
+
1053
+ self.update_iter_h5file(n_iter, segments)
1054
+
1055
+ def get_segments(self, n_iter=None, seg_ids=None, load_pcoords=True):
1056
+ '''Return the given (or all) segments from a given iteration.
1057
+
1058
+ If the optional parameter ``load_pcoords`` is true (the default), progress coordinate
1059
+ data are loaded and stored on the ``pcoord`` attribute of each segment. Additional
1060
+ per-segment datasets are loaded onto the ``data`` dictionary of each segment for any
1061
+ dataset whose options include a true ``load`` flag (see ``west.data.datasets`` in
1062
+ ``west.cfg``). Loading such datasets can require as much RAM as there is per-iteration
1063
+ auxiliary data, so it is not enabled by default.'''
1064
+
1065
+ n_iter = n_iter or self.current_iteration
1066
+ file_version = self.we_h5file_version
1067
+
1068
+ with self.lock:
1069
+ iter_group = self.get_iter_group(n_iter)
1070
+ seg_index_ds = iter_group['seg_index']
1071
+
1072
+ if file_version < 5:
1073
+ all_parent_ids = iter_group['parents'][...]
1074
+ else:
1075
+ all_parent_ids = iter_group['wtgraph'][...]
1076
+
1077
+ if seg_ids is not None:
1078
+ seg_ids = list(sorted(seg_ids))
1079
+ seg_index_entries = seg_index_ds[seg_ids]
1080
+ if load_pcoords:
1081
+ pcoord_entries = iter_group['pcoord'][seg_ids]
1082
+ else:
1083
+ seg_ids = list(range(len(seg_index_ds)))
1084
+ seg_index_entries = seg_index_ds[...]
1085
+ if load_pcoords:
1086
+ pcoord_entries = iter_group['pcoord'][...]
1087
+
1088
+ segments = []
1089
+
1090
+ for iseg, (seg_id, row) in enumerate(zip(seg_ids, seg_index_entries)):
1091
+ segment = Segment(
1092
+ seg_id=seg_id,
1093
+ n_iter=n_iter,
1094
+ status=int(row['status']),
1095
+ endpoint_type=int(row['endpoint_type']),
1096
+ walltime=float(row['walltime']),
1097
+ cputime=float(row['cputime']),
1098
+ weight=float(row['weight']),
1099
+ )
1100
+
1101
+ if load_pcoords:
1102
+ segment.pcoord = pcoord_entries[iseg]
1103
+
1104
+ if file_version < 5:
1105
+ wtg_n_parents = row['n_parents']
1106
+ wtg_offset = row['parents_offset']
1107
+ wtg_parent_ids = all_parent_ids[wtg_offset : wtg_offset + wtg_n_parents]
1108
+ segment.parent_id = int(wtg_parent_ids[0])
1109
+ else:
1110
+ wtg_n_parents = row['wtg_n_parents']
1111
+ wtg_offset = row['wtg_offset']
1112
+ wtg_parent_ids = all_parent_ids[wtg_offset : wtg_offset + wtg_n_parents]
1113
+ segment.parent_id = int(row['parent_id'])
1114
+ segment.wtg_parent_ids = set(map(int, wtg_parent_ids))
1115
+ assert len(segment.wtg_parent_ids) == wtg_n_parents
1116
+ segments.append(segment)
1117
+ del all_parent_ids
1118
+ if load_pcoords:
1119
+ del pcoord_entries
1120
+
1121
+ # If any other data sets are requested, load them as well
1122
+ for dsinfo in self.dataset_options.values():
1123
+ if dsinfo.get('load', False):
1124
+ dsname = dsinfo['name']
1125
+ try:
1126
+ ds = iter_group[dsinfo['h5path']]
1127
+ except KeyError:
1128
+ ds = None
1129
+
1130
+ if ds is not None:
1131
+ for segment in segments:
1132
+ seg_id = segment.seg_id
1133
+ segment.data[dsname] = ds[seg_id]
1134
+
1135
+ return segments
1136
+
1137
+ def prepare_segment_restarts(self, segments, basis_states=None, initial_states=None):
1138
+ '''Prepare the necessary folder and files given the data stored in parent per-iteration HDF5 file
1139
+ for propagating the simulation. ``basis_states`` and ``initial_states`` should be provided if the
1140
+ segments are newly created'''
1141
+
1142
+ if not self.store_h5:
1143
+ return
1144
+
1145
+ for segment in segments:
1146
+ if segment.parent_id < 0:
1147
+ if initial_states is None or basis_states is None:
1148
+ raise ValueError('initial and basis states required for preparing the segments')
1149
+ initial_state = initial_states[segment.initial_state_id]
1150
+ # Check if it's a start state
1151
+ if initial_state.istate_type == InitialState.ISTATE_TYPE_START:
1152
+ log.debug(
1153
+ f'Skip reading start state file from per-iteration HDF5 file for initial state {segment.initial_state_id}'
1154
+ )
1155
+ continue
1156
+ else:
1157
+ basis_state = basis_states[initial_state.basis_state_id]
1158
+
1159
+ parent = Segment(n_iter=0, seg_id=basis_state.state_id)
1160
+ else:
1161
+ parent = Segment(n_iter=segment.n_iter - 1, seg_id=segment.parent_id)
1162
+
1163
+ try:
1164
+ parent_iter_ref_h5_file = makepath(self.iter_ref_h5_template, {'n_iter': parent.n_iter})
1165
+
1166
+ with h5io.WESTIterationFile(parent_iter_ref_h5_file, 'r') as outf:
1167
+ outf.read_restart(parent)
1168
+
1169
+ segment.data['iterh5/restart'] = parent.data['iterh5/restart']
1170
+ except Exception as e:
1171
+ print('could not prepare restart data for segment {}/{}: {}'.format(segment.n_iter, segment.seg_id, str(e)))
1172
+
1173
+ def get_all_parent_ids(self, n_iter):
1174
+ file_version = self.we_h5file_version
1175
+ with self.lock:
1176
+ iter_group = self.get_iter_group(n_iter)
1177
+ seg_index = iter_group['seg_index']
1178
+
1179
+ if file_version < 5:
1180
+ offsets = seg_index['parents_offset']
1181
+ all_parents = iter_group['parents'][...]
1182
+ return all_parents.take(offsets)
1183
+ else:
1184
+ return seg_index['parent_id']
1185
+
1186
+ def get_parent_ids(self, n_iter, seg_ids=None):
1187
+ '''Return a sequence of the parent IDs of the given seg_ids.'''
1188
+
1189
+ file_version = self.we_h5file_version
1190
+
1191
+ with self.lock:
1192
+ iter_group = self.get_iter_group(n_iter)
1193
+ seg_index = iter_group['seg_index']
1194
+
1195
+ if seg_ids is None:
1196
+ seg_ids = range(len(seg_index))
1197
+
1198
+ if file_version < 5:
1199
+ offsets = seg_index['parents_offset']
1200
+ all_parents = iter_group['parents'][...]
1201
+ return [all_parents[offsets[seg_id]] for seg_id in seg_ids]
1202
+ else:
1203
+ all_parents = seg_index['parent_id']
1204
+ return [all_parents[seg_id] for seg_id in seg_ids]
1205
+
1206
+ def get_weights(self, n_iter, seg_ids):
1207
+ '''Return the weights associated with the given seg_ids'''
1208
+
1209
+ unique_ids = sorted(set(seg_ids))
1210
+ if not unique_ids:
1211
+ return []
1212
+ with self.lock:
1213
+ iter_group = self.get_iter_group(n_iter)
1214
+ index_subset = iter_group['seg_index'][unique_ids]
1215
+ weight_map = dict(zip(unique_ids, index_subset['weight']))
1216
+ return [weight_map[seg_id] for seg_id in seg_ids]
1217
+
1218
+ def get_child_ids(self, n_iter, seg_id):
+ '''Return the seg_ids of segments who have the given segment as a parent.'''
+
+ with self.lock:
+ if n_iter == self.current_iteration:
+ return []
+
+ iter_group = self.get_iter_group(n_iter + 1)
+ seg_index = iter_group['seg_index']
+ seg_ids = np.arange(len(seg_index), dtype=seg_id_dtype)
+
+ if self.we_h5file_version < 5:
+ offsets = seg_index['parents_offset']
+ all_parent_ids = iter_group['parents'][...]
+ parent_ids = np.array([all_parent_ids[offset] for offset in offsets])
+ else:
+ parent_ids = seg_index['parent_id']
+
+ return seg_ids[parent_ids == seg_id]
+
+ def get_children(self, segment):
+ '''Return all segments which have the given segment as a parent'''
+
+ if segment.n_iter == self.current_iteration:
+ return []
+
+ # Examine the segment index from the following iteration to see who has this segment
+ # as a parent. We don't need to worry about the number of parents each segment
+ # has, since each has at least one, and indexing on the offset into the parents array
+ # gives the primary parent ID
+
+ with self.lock:
+ iter_group = self.get_iter_group(segment.n_iter + 1)
+ seg_index = iter_group['seg_index'][...]
+
+ # This is one of the slowest pieces of code I've ever written...
+ # seg_index = iter_group['seg_index'][...]
+ # seg_ids = [seg_id for (seg_id,row) in enumerate(seg_index)
+ # if all_parent_ids[row['parents_offset']] == segment.seg_id]
+ # return self.get_segments_by_id(segment.n_iter+1, seg_ids)
+ if self.we_h5file_version < 5:
+ parents = iter_group['parents'][seg_index['parents_offset']]
+ else:
+ parents = seg_index['parent_id']
+ all_seg_ids = np.arange(len(seg_index), dtype=np.uintp)
+ seg_ids = all_seg_ids[parents == segment.seg_id]
+ # the above will return a scalar if only one is found, so convert
+ # to a list if necessary
+ try:
+ len(seg_ids)
+ except TypeError:
+ seg_ids = [seg_ids]
+
+ return self.get_segments(segment.n_iter + 1, seg_ids)
+
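A short sketch contrasting the two child lookups (hypothetical `dm`, illustrative IDs): get_child_ids touches only the segment index, while get_children builds full Segment objects for the following iteration:

    child_ids = dm.get_child_ids(20, 7)   # seg_ids in iteration 21 whose parent is segment 7
    parent = dm.get_segments(20, [7])[0]
    children = dm.get_children(parent)    # Segment objects, loaded via get_segments()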
1273
+ # The following are dictated by the SimManager interface
+ def prepare_run(self):
+ self.open_backing()
+
+ def finalize_run(self):
+ self.flush_backing()
+ self.close_backing()
+
+ def save_new_weight_data(self, n_iter, new_weights):
+ '''Save a set of NewWeightEntry objects to HDF5. Note that this should
+ be called for the iteration in which the weights appear in their
+ new locations (e.g. for recycled walkers, the iteration following
+ recycling).'''
+
+ if not new_weights:
+ return
+
+ system = self.system
+
+ index = np.empty(len(new_weights), dtype=nw_index_dtype)
+ prev_init_pcoords = system.new_pcoord_array(len(new_weights))
+ prev_final_pcoords = system.new_pcoord_array(len(new_weights))
+ new_init_pcoords = system.new_pcoord_array(len(new_weights))
+
+ for ientry, nwentry in enumerate(new_weights):
+ row = index[ientry]
+ row['source_type'] = nwentry.source_type
+ row['weight'] = nwentry.weight
+ row['prev_seg_id'] = nwentry.prev_seg_id
+ # the following use -1 as a sentinel for a missing value
+ row['target_state_id'] = nwentry.target_state_id if nwentry.target_state_id is not None else -1
+ row['initial_state_id'] = nwentry.initial_state_id if nwentry.initial_state_id is not None else -1
+
+ index[ientry] = row
+
+ if nwentry.prev_init_pcoord is not None:
+ prev_init_pcoords[ientry] = nwentry.prev_init_pcoord
+
+ if nwentry.prev_final_pcoord is not None:
+ prev_final_pcoords[ientry] = nwentry.prev_final_pcoord
+
+ if nwentry.new_init_pcoord is not None:
+ new_init_pcoords[ientry] = nwentry.new_init_pcoord
+
+ with self.lock:
+ iter_group = self.get_iter_group(n_iter)
+ try:
+ del iter_group['new_weights']
+ except KeyError:
+ pass
+
+ nwgroup = iter_group.create_group('new_weights')
+ nwgroup['index'] = index
+ nwgroup['prev_init_pcoord'] = prev_init_pcoords
+ nwgroup['prev_final_pcoord'] = prev_final_pcoords
+ nwgroup['new_init_pcoord'] = new_init_pcoords
+
+ def get_new_weight_data(self, n_iter):
+ with self.lock:
+ iter_group = self.get_iter_group(n_iter)
+
+ try:
+ nwgroup = iter_group['new_weights']
+ except KeyError:
+ return []
+
+ try:
+ index = nwgroup['index'][...]
+ prev_init_pcoords = nwgroup['prev_init_pcoord'][...]
+ prev_final_pcoords = nwgroup['prev_final_pcoord'][...]
+ new_init_pcoords = nwgroup['new_init_pcoord'][...]
+ except (KeyError, ValueError): # zero-length selections raise ValueError
+ return []
+
+ entries = []
+ for i in range(len(index)):
+ irow = index[i]
+
+ prev_seg_id = irow['prev_seg_id']
+ if prev_seg_id == -1:
+ prev_seg_id = None
+
+ initial_state_id = irow['initial_state_id']
+ if initial_state_id == -1:
+ initial_state_id = None
+
+ target_state_id = irow['target_state_id']
+ if target_state_id == -1:
+ target_state_id = None
+
+ entry = NewWeightEntry(
+ source_type=irow['source_type'],
+ weight=irow['weight'],
+ prev_seg_id=prev_seg_id,
+ prev_init_pcoord=prev_init_pcoords[i].copy(),
+ prev_final_pcoord=prev_final_pcoords[i].copy(),
+ new_init_pcoord=new_init_pcoords[i].copy(),
+ target_state_id=target_state_id,
+ initial_state_id=initial_state_id,
+ )
+
+ entries.append(entry)
+ return entries
+
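A sketch of reading these records back (hypothetical `dm`, illustrative iteration): summing the stored weights gives the total probability that was re-placed going into that iteration, for example the weight carried by recycled walkers:

    entries = dm.get_new_weight_data(21)
    recycled_weight = sum(entry.weight for entry in entries)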
1377
+ def find_bin_mapper(self, hashval):
+ '''Check to see if the given hash value is in the binning table. Returns the index in the
+ bin data tables if found, or raises KeyError if not.'''
+
+ try:
+ hashval = hashval.hexdigest()
+ except AttributeError:
+ pass
+
+ with self.lock:
+ # these will raise KeyError if the group doesn't exist, which also means
+ # that bin data is not available, so no special treatment here
+ try:
+ binning_group = self.we_h5file['/bin_topologies']
+ index = binning_group['index']
+ except KeyError:
+ raise KeyError('hash {} not found'.format(hashval))
+
+ n_entries = len(index)
+ if n_entries == 0:
+ raise KeyError('hash {} not found'.format(hashval))
+
+ chunksize = self.table_scan_chunksize
+ for istart in range(0, n_entries, chunksize):
+ chunk = index[istart : min(istart + chunksize, n_entries)]
+ for i in range(len(chunk)):
+ if chunk[i]['hash'] == bytes(hashval, 'utf-8'):
+ return istart + i
+
+ raise KeyError('hash {} not found'.format(hashval))
+
1408
+ def get_bin_mapper(self, hashval):
+ '''Look up the given hash value in the binning table, unpickling and returning the corresponding
+ bin mapper if available, or raising KeyError if not.'''
+
+ # Convert to a hex digest if we need to
+ try:
+ hashval = hashval.hexdigest()
+ except AttributeError:
+ pass
+
+ with self.lock:
+ # these will raise KeyError if the group doesn't exist, which also means
+ # that bin data is not available, so no special treatment here
+ try:
+ binning_group = self.we_h5file['/bin_topologies']
+ index = binning_group['index']
+ pkl = binning_group['pickles']
+ except KeyError:
+ raise KeyError('hash {} not found. Could not retrieve binning group'.format(hashval))
+
+ n_entries = len(index)
+ if n_entries == 0:
+ raise KeyError('hash {} not found. No entries in index'.format(hashval))
+
+ chunksize = self.table_scan_chunksize
+
+ for istart in range(0, n_entries, chunksize):
+ chunk = index[istart : min(istart + chunksize, n_entries)]
+ for i in range(len(chunk)):
+ if chunk[i]['hash'] == bytes(hashval, 'utf-8'):
+ pkldat = bytes(pkl[istart + i, 0 : chunk[i]['pickle_len']].data)
+ mapper = pickle.loads(pkldat)
+ log.debug('loaded {!r} from {!r}'.format(mapper, binning_group))
+ log.debug('hash value {!r}'.format(hashval))
+ return mapper
+
+ raise KeyError('hash {} not found'.format(hashval))
+
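Since save_iter_binning() (below) records the mapper's hash on each iteration group, the mapper in force for a given iteration can be recovered roughly as follows (hypothetical `dm`, illustrative iteration):

    binhash = dm.get_iter_group(20).attrs['binhash']   # may come back as bytes; decode if needed
    mapper = dm.get_bin_mapper(binhash)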
1446
+ def save_bin_mapper(self, hashval, pickle_data):
+ '''Store the given mapper in the table of saved mappers. If the mapper cannot be stored,
+ PickleError will be raised. Returns the index in the bin data tables where the mapper is stored.'''
+
+ try:
+ hashval = hashval.hexdigest()
+ except AttributeError:
+ pass
+ pickle_data = bytes(pickle_data)
+
+ # First, scan to see if the mapper already is in the HDF5 file
+ try:
+ return self.find_bin_mapper(hashval)
+ except KeyError:
+ pass
+
+ # At this point, we have a valid pickle and know it's not stored
+ with self.lock:
+ binning_group = self.we_h5file.require_group('/bin_topologies')
+
+ try:
+ index = binning_group['index']
+ pickle_ds = binning_group['pickles']
+ except KeyError:
+ index = binning_group.create_dataset('index', shape=(1,), maxshape=(None,), dtype=binning_index_dtype)
+ pickle_ds = binning_group.create_dataset(
+ 'pickles',
+ dtype=np.uint8,
+ shape=(1, len(pickle_data)),
+ maxshape=(None, None),
+ chunks=(1, 4096),
+ compression='gzip',
+ compression_opts=9,
+ )
+ n_entries = 1
+ else:
+ n_entries = len(index) + 1
+ index.resize((n_entries,))
+ new_hsize = max(pickle_ds.shape[1], len(pickle_data))
+ pickle_ds.resize((n_entries, new_hsize))
+
+ index_row = index[n_entries - 1]
+ index_row['hash'] = hashval
+ index_row['pickle_len'] = len(pickle_data)
+ index[n_entries - 1] = index_row
+ pickle_ds[n_entries - 1, : len(pickle_data)] = memoryview(pickle_data)
+ return n_entries - 1
+
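A sketch of the accepted argument types only; in practice WESTPA computes and stores these hashes itself when the binning changes, and `mapper` here is a hypothetical, already-constructed bin mapper:

    import hashlib
    import pickle

    pkl = pickle.dumps(mapper, pickle.HIGHEST_PROTOCOL)
    idx = dm.save_bin_mapper(hashlib.sha256(pkl), pkl)   # accepts a hash object or its hex digest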
1494
+ def save_iter_binning(self, n_iter, hashval, pickled_mapper, target_counts):
+ '''Save information about the binning used to generate segments for iteration n_iter.'''
+
+ with self.lock:
+ iter_group = self.get_iter_group(n_iter)
+
+ try:
+ del iter_group['bin_target_counts']
+ except KeyError:
+ pass
+
+ iter_group['bin_target_counts'] = target_counts
+
+ if hashval and pickled_mapper:
+ self.save_bin_mapper(hashval, pickled_mapper)
+ iter_group.attrs['binhash'] = hashval
+ else:
+ iter_group.attrs['binhash'] = ''
+
+
+ def normalize_dataset_options(dsopts, path_prefix='', n_iter=0):
+ dsopts = dict(dsopts)
+
+ ds_name = dsopts['name']
+ if path_prefix:
+ default_h5path = '{}/{}'.format(path_prefix, ds_name)
+ else:
+ default_h5path = ds_name
+
+ dsopts.setdefault('h5path', default_h5path)
+ dtype = dsopts.get('dtype')
+ if dtype:
+ if isinstance(dtype, str):
+ try:
+ dsopts['dtype'] = np.dtype(getattr(np, dtype))
+ except AttributeError:
+ dsopts['dtype'] = np.dtype(getattr(builtins, dtype))
+ else:
+ dsopts['dtype'] = np.dtype(dtype)
+
+ dsopts['store'] = bool(dsopts['store']) if 'store' in dsopts else True
+ dsopts['load'] = bool(dsopts['load']) if 'load' in dsopts else False
+
+ return dsopts
+
+
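To make the defaulting concrete, a hypothetical auxdata spec before and after normalization (values are illustrative):

    opts = normalize_dataset_options({'name': 'rmsd', 'dtype': 'float32'}, path_prefix='auxdata')
    # -> {'name': 'rmsd', 'dtype': dtype('float32'), 'h5path': 'auxdata/rmsd', 'store': True, 'load': False}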
1540
+ def create_dataset_from_dsopts(group, dsopts, shape=None, dtype=None, data=None, autocompress_threshold=None, n_iter=None):
+ # log.debug('create_dataset_from_dsopts(group={!r}, dsopts={!r}, shape={!r}, dtype={!r}, data={!r}, autocompress_threshold={!r})'
+ # .format(group,dsopts,shape,dtype,data,autocompress_threshold))
+ if not dsopts.get('store', True):
+ return None
+
+ if 'file' in list(dsopts.keys()):
+ import h5py
+
+ # dsopts['file'] = str(dsopts['file']).format(n_iter=n_iter)
+ h5_auxfile = h5io.WESTPAH5File(dsopts['file'].format(n_iter=n_iter))
+ h5group = group
+ if not ("iter_" + str(n_iter).zfill(8)) in h5_auxfile:
+ h5_auxfile.create_group("iter_" + str(n_iter).zfill(8))
+ group = h5_auxfile[('/' + "iter_" + str(n_iter).zfill(8))]
+
+ h5path = dsopts['h5path']
+ containing_group_name = posixpath.dirname(h5path)
+ h5_dsname = posixpath.basename(h5path)
+
+ # ensure arguments are sane
+ if not shape and data is None:
+ raise ValueError('either shape or data must be provided')
+ elif data is None and (shape and dtype is None):
+ raise ValueError('both shape and dtype must be provided when data is not provided')
+ elif shape and data is not None and not data.shape == shape:
+ raise ValueError('explicit shape {!r} does not match data shape {!r}'.format(shape, data.shape))
+
+ if data is not None:
+ shape = data.shape
+ if dtype is None:
+ dtype = data.dtype
+ # end argument sanity checks
+
+ # figure out where to store this data
+ if containing_group_name:
+ containing_group = group.require_group(containing_group_name)
+ else:
+ containing_group = group
+
+ # has user requested an explicit data type?
+ # the extra np.dtype is an idempotent operation on true dtype
+ # objects, but ensures that things like np.float32, which are
+ # actually NOT dtype objects, become dtype objects
+ h5_dtype = np.dtype(dsopts.get('dtype', dtype))
+
+ compression = None
+ scaleoffset = None
+ shuffle = False
+
+ # compress if 1) explicitly requested, or 2) dataset size exceeds threshold and
+ # compression not explicitly prohibited
+ compression_directive = dsopts.get('compression')
+ if compression_directive is None:
+ # No directive
+ nbytes = np.multiply.reduce(shape) * h5_dtype.itemsize
+ if autocompress_threshold and nbytes > autocompress_threshold:
+ compression = 9
+ elif compression_directive == 0: # includes False
+ # Compression prohibited
+ compression = None
+ else: # compression explicitly requested
+ compression = compression_directive
+
+ # Is scale/offset requested?
+ scaleoffset = dsopts.get('scaleoffset', None)
+ if scaleoffset is not None:
+ scaleoffset = int(scaleoffset)
+
+ # We always shuffle if we compress (losslessly)
+ if compression:
+ shuffle = True
+ else:
+ shuffle = False
+
+ need_chunks = any([compression, scaleoffset is not None, shuffle])
+
+ # We use user-provided chunks if available
+ chunks_directive = dsopts.get('chunks')
+ if chunks_directive is None:
+ chunks = None
+ elif chunks_directive is True:
+ chunks = calc_chunksize(shape, h5_dtype)
+ elif chunks_directive is False:
+ chunks = None
+ else:
+ chunks = tuple(chunks_directive[i] if chunks_directive[i] <= shape[i] else shape[i] for i in range(len(shape)))
+
+ if not chunks and need_chunks:
+ chunks = calc_chunksize(shape, h5_dtype)
+
+ opts = {'shape': shape, 'dtype': h5_dtype, 'compression': compression, 'shuffle': shuffle, 'chunks': chunks}
+
+ try:
+ import h5py._hl.filters
+
+ h5py._hl.filters._COMP_FILTERS['scaleoffset']
+ except (ImportError, KeyError, AttributeError):
+ # filter not available, or an unexpected version of h5py
+ # use lossless compression instead
+ opts['compression'] = True
+ else:
+ opts['scaleoffset'] = scaleoffset
+
+ if log.isEnabledFor(logging.DEBUG):
+ log.debug('requiring aux dataset {!r}, shape={!r}, opts={!r}'.format(h5_dsname, shape, opts))
+
+ dset = containing_group.require_dataset(h5_dsname, **opts)
+
+ if data is not None:
+ dset[...] = data
+
+ if 'file' in list(dsopts.keys()):
+ import h5py
+
+ if not dsopts['h5path'] in h5group:
+ h5group[dsopts['h5path']] = h5py.ExternalLink(
+ dsopts['file'].format(n_iter=n_iter), ("/" + "iter_" + str(n_iter).zfill(8) + "/" + dsopts['h5path'])
+ )
+
+ return dset
+
+
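A sketch of a dsopts dict that exercises the paths above (hypothetical names: `iter_group`, `n_segs`, `n_tp`; the values are examples, not defaults):

    dsopts = {
        'h5path': 'auxdata/rmsd',   # where the dataset is created inside the target group
        'compression': 9,           # gzip level; 0/False disables it, None defers to autocompress_threshold
        'chunks': True,             # let calc_chunksize() choose a chunk shape
    }
    dset = create_dataset_from_dsopts(iter_group, dsopts, shape=(n_segs, n_tp), dtype=np.dtype(np.float32), n_iter=20)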
1663
+ def require_dataset_from_dsopts(group, dsopts, shape=None, dtype=None, data=None, autocompress_threshold=None, n_iter=None):
+ if not dsopts.get('store', True):
+ return None
+ try:
+ return group[dsopts['h5path']]
+ except KeyError:
+ return create_dataset_from_dsopts(
+ group, dsopts, shape=shape, dtype=dtype, data=data, autocompress_threshold=autocompress_threshold, n_iter=n_iter
+ )
+
+
+ def calc_chunksize(shape, dtype, max_chunksize=262144):
+ '''Calculate a chunk size for HDF5 data, anticipating that access will slice
+ along lower dimensions sooner than higher dimensions.'''
+
+ chunk_shape = list(shape)
+ for idim in range(len(shape)):
+ chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+ while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
+ chunk_shape[idim] >>= 1 # divide by 2
+ chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
+
+ if chunk_nbytes <= max_chunksize:
+ break
+
+ chunk_shape = tuple(chunk_shape)
+ log.debug(
+ 'selected chunk shape {} for data set of type {} shaped {} (chunk size = {} bytes)'.format(
+ chunk_shape, dtype, shape, chunk_nbytes
+ )
+ )
+ return chunk_shape
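A worked example of the halving loop, assuming the default 262144-byte budget: for a float64 array of shape (10000, 100, 3) (~24 MB), only the leading axis is halved until the chunk fits:

    chunks = calc_chunksize((10000, 100, 3), np.dtype(np.float64))
    # chunks == (78, 100, 3); 78 * 100 * 3 * 8 bytes = 187,200 bytes <= 262,144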