westpa 2022.13__cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +749 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +369 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +597 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +557 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +832 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_timings.py +113 -0
  38. westpa/cli/tools/w_trace.py +599 -0
  39. westpa/core/__init__.py +0 -0
  40. westpa/core/_rc.py +673 -0
  41. westpa/core/binning/__init__.py +55 -0
  42. westpa/core/binning/_assign.c +36018 -0
  43. westpa/core/binning/_assign.cpython-312-aarch64-linux-gnu.so +0 -0
  44. westpa/core/binning/_assign.pyx +370 -0
  45. westpa/core/binning/assign.py +454 -0
  46. westpa/core/binning/binless.py +96 -0
  47. westpa/core/binning/binless_driver.py +54 -0
  48. westpa/core/binning/binless_manager.py +189 -0
  49. westpa/core/binning/bins.py +47 -0
  50. westpa/core/binning/mab.py +506 -0
  51. westpa/core/binning/mab_driver.py +54 -0
  52. westpa/core/binning/mab_manager.py +197 -0
  53. westpa/core/data_manager.py +1761 -0
  54. westpa/core/extloader.py +74 -0
  55. westpa/core/h5io.py +1079 -0
  56. westpa/core/kinetics/__init__.py +24 -0
  57. westpa/core/kinetics/_kinetics.c +45174 -0
  58. westpa/core/kinetics/_kinetics.cpython-312-aarch64-linux-gnu.so +0 -0
  59. westpa/core/kinetics/_kinetics.pyx +815 -0
  60. westpa/core/kinetics/events.py +147 -0
  61. westpa/core/kinetics/matrates.py +156 -0
  62. westpa/core/kinetics/rate_averaging.py +266 -0
  63. westpa/core/progress.py +218 -0
  64. westpa/core/propagators/__init__.py +54 -0
  65. westpa/core/propagators/executable.py +592 -0
  66. westpa/core/propagators/loaders.py +196 -0
  67. westpa/core/reweight/__init__.py +14 -0
  68. westpa/core/reweight/_reweight.c +36899 -0
  69. westpa/core/reweight/_reweight.cpython-312-aarch64-linux-gnu.so +0 -0
  70. westpa/core/reweight/_reweight.pyx +439 -0
  71. westpa/core/reweight/matrix.py +126 -0
  72. westpa/core/segment.py +119 -0
  73. westpa/core/sim_manager.py +839 -0
  74. westpa/core/states.py +359 -0
  75. westpa/core/systems.py +93 -0
  76. westpa/core/textio.py +74 -0
  77. westpa/core/trajectory.py +603 -0
  78. westpa/core/we_driver.py +910 -0
  79. westpa/core/wm_ops.py +43 -0
  80. westpa/core/yamlcfg.py +298 -0
  81. westpa/fasthist/__init__.py +34 -0
  82. westpa/fasthist/_fasthist.c +38755 -0
  83. westpa/fasthist/_fasthist.cpython-312-aarch64-linux-gnu.so +0 -0
  84. westpa/fasthist/_fasthist.pyx +222 -0
  85. westpa/mclib/__init__.py +271 -0
  86. westpa/mclib/__main__.py +28 -0
  87. westpa/mclib/_mclib.c +34610 -0
  88. westpa/mclib/_mclib.cpython-312-aarch64-linux-gnu.so +0 -0
  89. westpa/mclib/_mclib.pyx +226 -0
  90. westpa/oldtools/__init__.py +4 -0
  91. westpa/oldtools/aframe/__init__.py +35 -0
  92. westpa/oldtools/aframe/atool.py +75 -0
  93. westpa/oldtools/aframe/base_mixin.py +26 -0
  94. westpa/oldtools/aframe/binning.py +178 -0
  95. westpa/oldtools/aframe/data_reader.py +560 -0
  96. westpa/oldtools/aframe/iter_range.py +200 -0
  97. westpa/oldtools/aframe/kinetics.py +117 -0
  98. westpa/oldtools/aframe/mcbs.py +153 -0
  99. westpa/oldtools/aframe/output.py +39 -0
  100. westpa/oldtools/aframe/plotting.py +88 -0
  101. westpa/oldtools/aframe/trajwalker.py +126 -0
  102. westpa/oldtools/aframe/transitions.py +469 -0
  103. westpa/oldtools/cmds/__init__.py +0 -0
  104. westpa/oldtools/cmds/w_ttimes.py +361 -0
  105. westpa/oldtools/files.py +34 -0
  106. westpa/oldtools/miscfn.py +23 -0
  107. westpa/oldtools/stats/__init__.py +4 -0
  108. westpa/oldtools/stats/accumulator.py +35 -0
  109. westpa/oldtools/stats/edfs.py +129 -0
  110. westpa/oldtools/stats/mcbs.py +96 -0
  111. westpa/tools/__init__.py +33 -0
  112. westpa/tools/binning.py +472 -0
  113. westpa/tools/core.py +340 -0
  114. westpa/tools/data_reader.py +159 -0
  115. westpa/tools/dtypes.py +31 -0
  116. westpa/tools/iter_range.py +198 -0
  117. westpa/tools/kinetics_tool.py +343 -0
  118. westpa/tools/plot.py +283 -0
  119. westpa/tools/progress.py +17 -0
  120. westpa/tools/selected_segs.py +154 -0
  121. westpa/tools/wipi.py +751 -0
  122. westpa/trajtree/__init__.py +4 -0
  123. westpa/trajtree/_trajtree.c +17829 -0
  124. westpa/trajtree/_trajtree.cpython-312-aarch64-linux-gnu.so +0 -0
  125. westpa/trajtree/_trajtree.pyx +130 -0
  126. westpa/trajtree/trajtree.py +117 -0
  127. westpa/westext/__init__.py +0 -0
  128. westpa/westext/adaptvoronoi/__init__.py +3 -0
  129. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  130. westpa/westext/hamsm_restarting/__init__.py +3 -0
  131. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  132. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  133. westpa/westext/stringmethod/__init__.py +11 -0
  134. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  135. westpa/westext/stringmethod/string_driver.py +253 -0
  136. westpa/westext/stringmethod/string_method.py +306 -0
  137. westpa/westext/weed/BinCluster.py +180 -0
  138. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  139. westpa/westext/weed/UncertMath.py +247 -0
  140. westpa/westext/weed/__init__.py +10 -0
  141. westpa/westext/weed/weed_driver.py +192 -0
  142. westpa/westext/wess/ProbAdjust.py +101 -0
  143. westpa/westext/wess/__init__.py +6 -0
  144. westpa/westext/wess/wess_driver.py +217 -0
  145. westpa/work_managers/__init__.py +57 -0
  146. westpa/work_managers/core.py +396 -0
  147. westpa/work_managers/environment.py +134 -0
  148. westpa/work_managers/mpi.py +318 -0
  149. westpa/work_managers/processes.py +201 -0
  150. westpa/work_managers/serial.py +28 -0
  151. westpa/work_managers/threads.py +79 -0
  152. westpa/work_managers/zeromq/__init__.py +20 -0
  153. westpa/work_managers/zeromq/core.py +635 -0
  154. westpa/work_managers/zeromq/node.py +131 -0
  155. westpa/work_managers/zeromq/work_manager.py +526 -0
  156. westpa/work_managers/zeromq/worker.py +320 -0
  157. westpa-2022.13.dist-info/METADATA +179 -0
  158. westpa-2022.13.dist-info/RECORD +162 -0
  159. westpa-2022.13.dist-info/WHEEL +7 -0
  160. westpa-2022.13.dist-info/entry_points.txt +30 -0
  161. westpa-2022.13.dist-info/licenses/LICENSE +21 -0
  162. westpa-2022.13.dist-info/top_level.txt +1 -0
westpa/core/data_manager.py (new file)
@@ -0,0 +1,1761 @@
1
+ """
2
+ HDF5 data manager for WEST.
3
+
4
+ Original HDF5 implementation: Joseph W. Kaus
5
+ Current implementation: Matthew C. Zwier
6
+
7
+ WEST exclusively uses the cross-platform, self-describing file format HDF5
8
+ for data storage. This ensures that data is stored efficiently and portably
9
+ in a manner that is relatively straightforward for other analysis tools
10
+ (perhaps written in C/C++/Fortran) to access.
11
+
12
+ The data is laid out in HDF5 as follows:
13
+ - summary -- overall summary data for the simulation
14
+ - /iterations/ -- data for individual iterations, one group per iteration under /iterations
15
+ - iter_00000001/ -- data for iteration 1
16
+ - seg_index -- overall information about segments in the iteration, including weight
17
+ - pcoord -- progress coordinate data organized as [seg_id][time][dimension]
18
+ - wtg_parents -- data used to reconstruct the split/merge history of trajectories
19
+ - recycling -- flux and event count for recycled particles, on a per-target-state basis
20
+ - auxdata/ -- auxiliary datasets (data stored on the 'data' field of Segment objects)
21
+
22
+ The file root object has an integer attribute 'west_file_format_version' which can be used to
23
+ determine how to access data even as the file format (i.e. organization of data within HDF5 file)
24
+ evolves.
25
+
26
+ Version history:
27
+ Version 10
28
+ - BinMapper pickle, hash, and bin_target_count are now saved for iteration 1
29
+ Version 9
30
+ - Basis states are now saved as iter_segid instead of just segid as a pointer label.
31
+ - Initial states are also saved in the iteration 0 file, with a negative sign.
32
+ Version 8
33
+ - Added external links to trajectory files in iterations/iter_* groups, if the HDF5
34
+ framework was used.
35
+ - Added an iter group for iteration 0 to store conformations of basis states.
36
+ Version 7
37
+ - Removed bin_assignments, bin_populations, and bin_rates from iteration group.
38
+ - Added new_segments subgroup to iteration group
39
+ Version 6
40
+ - ???
41
+ Version 5
42
+ - moved iter_* groups into a top-level iterations/ group,
43
+ - added in-HDF5 storage for basis states, target states, and generated states
44
+ """
45
+
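For orientation, a minimal sketch (not part of this module) of reading the objects named in the layout above with h5py. The file name 'west.h5' is the module's default; the iteration number and the printed fields are illustrative assumptions only.

    import h5py

    # Minimal sketch: inspect the layout described in the docstring above.
    with h5py.File('west.h5', 'r') as f:
        print(f.attrs['west_file_format_version'])   # integer file format version
        summary = f['summary'][...]                   # per-iteration summary table
        iter_group = f['/iterations/iter_00000001']   # iteration 1, zero-padded to iter_prec digits
        weights = iter_group['seg_index']['weight']   # statistical weight of each segment
        pcoord = iter_group['pcoord'][...]            # indexed as [seg_id][time][dimension]
        print(summary.shape, weights.sum(), pcoord.shape)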
46
+ import logging
47
+ import os
48
+ import pickle
49
+ import posixpath
50
+ import sys
51
+ import threading
52
+ import time
53
+ import re
54
+ import builtins
55
+ from operator import attrgetter
56
+ from os.path import relpath, dirname, exists
57
+ from os import remove
58
+ from shutil import copyfile, move
59
+ from subprocess import run, CalledProcessError
60
+
61
+ import h5py
62
+ from h5py import h5s
63
+ import numpy as np
64
+
65
+ from . import h5io
66
+ from .segment import Segment
67
+ from .states import BasisState, TargetState, InitialState
68
+ from .we_driver import NewWeightEntry
69
+
70
+ import westpa
71
+
72
+
73
+ log = logging.getLogger(__name__)
74
+
75
+ file_format_version = 10
76
+
77
+
78
+ def makepath(template, template_args=None, expanduser=True, expandvars=True, abspath=False, realpath=False):
79
+ """Function for manipulating paths. Used in ExecutablePropagator as well."""
80
+ template_args = template_args or {}
81
+ path = template.format(**template_args)
82
+ if expandvars:
83
+ path = os.path.expandvars(path)
84
+ if expanduser:
85
+ path = os.path.expanduser(path)
86
+ if realpath:
87
+ path = os.path.realpath(path)
88
+ if abspath:
89
+ path = os.path.abspath(path)
90
+ path = os.path.normpath(path)
91
+ return path
92
+
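A small usage sketch for the makepath() function defined above. The template, the WEST_SIM_ROOT variable, and the arguments are hypothetical, chosen only to show the resolution order (format, expandvars, expanduser, normpath):

    import os

    os.environ['WEST_SIM_ROOT'] = '/tmp/we_sim'        # hypothetical value
    path = makepath('$WEST_SIM_ROOT/traj_segs/iter_{n_iter:06d}.h5', {'n_iter': 12})
    print(path)  # -> /tmp/we_sim/traj_segs/iter_000012.h5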
93
+
94
+ class flushing_lock:
95
+ def __init__(self, lock, fileobj):
96
+ self.lock = lock
97
+ self.fileobj = fileobj
98
+
99
+ def __enter__(self):
100
+ self.lock.acquire()
101
+
102
+ def __exit__(self, exc_type, exc_val, exc_tb):
103
+ self.fileobj.flush()
104
+ self.lock.release()
105
+
106
+
107
+ class expiring_flushing_lock:
108
+ def __init__(self, lock, flush_method, nextsync):
109
+ self.lock = lock
110
+ self.flush_method = flush_method
111
+ self.nextsync = nextsync
112
+
113
+ def __enter__(self):
114
+ self.lock.acquire()
115
+
116
+ def __exit__(self, exc_type, exc_val, exc_tb):
117
+ if time.time() > self.nextsync:
118
+ self.flush_method()
119
+ self.lock.release()
120
+
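Both helper classes above are context managers that flush an HDF5 file object on exit before releasing the lock; they are normally obtained through WESTDataManager.flushing_lock() and WESTDataManager.expiring_flushing_lock() below. A standalone sketch of the flush-on-exit behaviour (the dummy file object is hypothetical):

    import threading

    class _DummyFile:
        # Hypothetical stand-in for an h5py.File; only flush() matters here.
        def flush(self):
            print('flushed')

    lock = threading.RLock()
    with flushing_lock(lock, _DummyFile()):
        pass  # ... write data while holding the lock ...
    # leaving the block calls flush() on the file object, then releases the lock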
121
+
122
+ # Data types for use in the HDF5 file
123
+ seg_id_dtype = np.int64 # Up to 9 quintillion segments per iteration; signed so that initial states can be stored negative
124
+ n_iter_dtype = np.uint32 # Up to 4 billion iterations
125
+ weight_dtype = np.float64 # about 15 digits of precision in weights
126
+ utime_dtype = np.float64 # ("u" for Unix time) Up to ~10^300 cpu-seconds
127
+ vstr_dtype = h5py.special_dtype(vlen=str)
128
+ h5ref_dtype = h5py.special_dtype(ref=h5py.Reference)
129
+ binhash_dtype = np.dtype('|S64')
130
+
131
+ # seg_status_dtype = h5py.special_dtype(enum=(np.uint8, Segment.statuses))
132
+ # seg_initpoint_dtype = h5py.special_dtype(enum=(np.uint8, Segment.initpoint_types))
133
+ # seg_endpoint_dtype = h5py.special_dtype(enum=(np.uint8, Segment.endpoint_types))
134
+ # istate_type_dtype = h5py.special_dtype(enum=(np.uint8, InitialState.istate_types))
135
+ # istate_status_dtype = h5py.special_dtype(enum=(np.uint8, InitialState.istate_statuses))
136
+
137
+ seg_status_dtype = np.uint8
138
+ seg_initpoint_dtype = np.uint8
139
+ seg_endpoint_dtype = np.uint8
140
+ istate_type_dtype = np.uint8
141
+ istate_status_dtype = np.uint8
142
+
143
+ summary_table_dtype = np.dtype(
144
+ [
145
+ ('n_particles', seg_id_dtype), # Number of live trajectories in this iteration
146
+ ('norm', weight_dtype), # Norm of probability, to watch for errors or drift
147
+ ('min_bin_prob', weight_dtype), # Per-bin minimum probability
148
+ ('max_bin_prob', weight_dtype), # Per-bin maximum probability
149
+ ('min_seg_prob', weight_dtype), # Per-segment minimum probability
150
+ ('max_seg_prob', weight_dtype), # Per-segment maximum probability
151
+ ('cputime', utime_dtype), # Total CPU time for this iteration
152
+ ('walltime', utime_dtype), # Total wallclock time for this iteration
153
+ ('binhash', binhash_dtype),
154
+ ]
155
+ )
156
+
157
+
158
+ # The HDF5 file tracks two distinct, but related, histories:
159
+ # (1) the evolution of the trajectory, which requires only an identifier
160
+ # of where a segment's initial state comes from (the "history graph");
161
+ # this is stored as the parent_id field of the seg index
162
+ # (2) the flow of probability due to splits, merges, and recycling events,
163
+ # which can be thought of as an adjacency list (the "weight graph")
164
+ # segment ID is implied by the row in the index table, and so is not stored
165
+ # initpoint_type remains implicitly stored as negative IDs (if parent_id < 0, then init_state_id = -(parent_id+1))
166
+ seg_index_dtype = np.dtype(
167
+ [
168
+ ('weight', weight_dtype), # Statistical weight of this segment
169
+ ('parent_id', seg_id_dtype), # ID of parent (for trajectory history)
170
+ ('wtg_n_parents', np.uint), # number of parents this segment has in the weight transfer graph
171
+ ('wtg_offset', np.uint), # offset into the weight transfer graph dataset
172
+ ('cputime', utime_dtype), # CPU time used in propagating this segment
173
+ ('walltime', utime_dtype), # Wallclock time used in propagating this segment
174
+ ('endpoint_type', seg_endpoint_dtype), # Endpoint type (will continue, merged, or recycled)
175
+ ('status', seg_status_dtype), # Status of propagation of this segment
176
+ ]
177
+ )
178
+
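To make the sign convention in the comment above concrete, a tiny illustrative check (values are made up):

    # Recover an initial-state ID from a negative parent_id, per the convention
    # above: init_state_id = -(parent_id + 1).
    parent_id = -3                        # segment started from an initial state
    if parent_id < 0:
        init_state_id = -(parent_id + 1)
        print(init_state_id)              # -> 2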
179
+ # Index to basis/initial states
180
+ ibstate_index_dtype = np.dtype([('iter_valid', np.uint), ('n_bstates', np.uint), ('group_ref', h5ref_dtype)])
181
+
182
+ # Basis state index type
183
+ bstate_dtype = np.dtype(
184
+ [
185
+ ('label', vstr_dtype), # An optional descriptive label
186
+ ('probability', weight_dtype), # Probability that this state will be selected
187
+ ('auxref', vstr_dtype), # An optional auxiliary data reference
188
+ ]
189
+ )
190
+
191
+ # Even when initial state generation is off and basis states are passed through directly, an initial state entry
192
+ # is created, as that allows precise tracing of the history of a given state in the most complex case of
193
+ # a new initial state for every new trajectory.
194
+ istate_dtype = np.dtype(
195
+ [
196
+ ('iter_created', np.uint), # Iteration during which this state was generated (0 for at w_init)
197
+ ('iter_used', np.uint), # When this state was used to start a new trajectory
198
+ ('basis_state_id', seg_id_dtype), # Which basis state this state was generated from
199
+ ('istate_type', istate_type_dtype), # What type this initial state is (generated or basis)
200
+ ('istate_status', istate_status_dtype), # Whether this initial state is ready to go
201
+ ('basis_auxref', vstr_dtype), # for start states to point back to their original location
202
+ ]
203
+ )
204
+
205
+ tstate_index_dtype = np.dtype(
206
+ [('iter_valid', np.uint), ('n_states', np.uint), ('group_ref', h5ref_dtype)] # Iteration when this state list is valid
207
+ ) # Reference to a group containing further data; this will be the
208
+ # null reference if there is no target state for that timeframe.
209
+ tstate_dtype = np.dtype([('label', vstr_dtype)]) # An optional descriptive label for this state
210
+
211
+ # Support for west.we_driver.NewWeightEntry
212
+ nw_source_dtype = np.uint8
213
+ nw_index_dtype = np.dtype(
214
+ [
215
+ ('source_type', nw_source_dtype),
216
+ ('weight', weight_dtype),
217
+ ('prev_seg_id', seg_id_dtype),
218
+ ('target_state_id', seg_id_dtype),
219
+ ('initial_state_id', seg_id_dtype),
220
+ ]
221
+ )
222
+
223
+ # Storage of bin identities
224
+ binning_index_dtype = np.dtype([('hash', binhash_dtype), ('pickle_len', np.uint32)])
225
+
226
+
227
+ class WESTDataManager:
228
+ """Data manager for assisiting the reading and writing of WEST data from/to HDF5 files."""
229
+
230
+ # defaults for various options
231
+ default_iter_prec = 8
232
+ default_we_h5filename = 'west.h5'
233
+ default_we_h5file_driver = None
234
+ default_flush_period = 60
235
+
236
+ # Compress any auxiliary dataset whose total size (across all segments) is more than 1 MiB
237
+ default_aux_compression_threshold = 1048576
238
+
239
+ # Bin data horizontal (second dimension) chunk size
240
+ binning_hchunksize = 4096
241
+
242
+ # Number of rows to retrieve during a table scan
243
+ table_scan_chunksize = 1024
244
+
245
+ def flushing_lock(self):
246
+ return flushing_lock(self.lock, self.we_h5file)
247
+
248
+ def expiring_flushing_lock(self):
249
+ next_flush = self.last_flush + self.flush_period
250
+ return expiring_flushing_lock(self.lock, self.flush_backing, next_flush)
251
+
252
+ def process_config(self):
253
+ config = self.rc.config
254
+
255
+ for entry, type_ in [('iter_prec', int)]:
256
+ config.require_type_if_present(['west', 'data', entry], type_)
257
+
258
+ self.we_h5filename = config.get_path(['west', 'data', 'west_data_file'], default=self.default_we_h5filename)
259
+ self.we_h5file_driver = config.get_choice(
260
+ ['west', 'data', 'west_data_file_driver'],
261
+ [None, 'sec2', 'family'],
262
+ default=self.default_we_h5file_driver,
263
+ value_transform=(lambda x: x.lower() if x else None),
264
+ )
265
+ self.iter_prec = config.get(['west', 'data', 'iter_prec'], self.default_iter_prec)
266
+ self.aux_compression_threshold = config.get(
267
+ ['west', 'data', 'aux_compression_threshold'], self.default_aux_compression_threshold
268
+ )
269
+ self.flush_period = config.get(['west', 'data', 'flush_period'], self.default_flush_period)
270
+
271
+ # Path to per-iter h5 file
272
+ self.iter_h5_path_template = config.get(['west', 'data', 'data_refs', 'iteration'], None)
273
+ try:
274
+ # Generating path to a template file for per-iter h5 file
275
+ self.iter_h5_template_file_path = re.sub(r'\{(.*?)\}', 'template', self.iter_h5_path_template)
276
+ except TypeError:
277
+ self.iter_h5_template_file_path = None
278
+
279
+ # If not provided, turn HDF5 Framework off.
280
+ self.store_h5 = self.iter_h5_path_template is not None
281
+
282
+ # Process dataset options
283
+ dsopts_list = config.get(['west', 'data', 'datasets']) or []
284
+ for dsopts in dsopts_list:
285
+ dsopts = normalize_dataset_options(dsopts, path_prefix='auxdata' if dsopts['name'] != 'pcoord' else '')
286
+ try:
287
+ self.dataset_options[dsopts['name']].update(dsopts)
288
+ except KeyError:
289
+ self.dataset_options[dsopts['name']] = dsopts
290
+
291
+ if 'pcoord' in self.dataset_options:
292
+ if self.dataset_options['pcoord']['h5path'] != 'pcoord':
293
+ raise ValueError('cannot override pcoord storage location')
294
+
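For reference, the keys read by process_config() correspond to a west.cfg fragment roughly like the one below. This is a hedged sketch: the values, the auxiliary dataset name 'dihedrals', and the exact nesting are assumptions inferred from the config paths used above, not configuration shipped with the package.

    west:
      data:
        west_data_file: west.h5            # -> self.we_h5filename
        iter_prec: 8                       # zero padding of iter_* group names
        aux_compression_threshold: 1048576
        flush_period: 60
        data_refs:
          iteration: $WEST_SIM_ROOT/traj_segs/iter_{n_iter:06d}.h5   # enables the HDF5 framework
        datasets:
          - name: pcoord
          - name: dihedrals                # hypothetical auxiliary dataset
            load: true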
295
+ def __init__(self, rc=None):
296
+ self.rc = rc or westpa.rc
297
+
298
+ self.we_h5filename = self.default_we_h5filename
299
+ self.we_h5file_driver = self.default_we_h5file_driver
300
+ self.we_h5file_version = None
301
+ self.h5_access_mode = 'r+'
302
+ self.iter_prec = self.default_iter_prec
303
+ self.aux_compression_threshold = self.default_aux_compression_threshold
304
+
305
+ self.we_h5file = None
306
+
307
+ self.lock = threading.RLock()
308
+ self.flush_period = None
309
+ self.last_flush = 0
310
+
311
+ self._system = None
312
+ self.iter_h5_path_template = None # Template for per-iter H5 file Path
313
+ self.iter_h5_template_file_path = None # Path to per-iter H5 template file
314
+ self.store_h5 = False # Whether the HDF5 framework is active
315
+ self.template_copy_flag = False # Flag indicating the template file was made this iteration
316
+
317
+ self.dataset_options = {}
318
+ self.process_config()
319
+
320
+ @property
321
+ def system(self):
322
+ if self._system is None:
323
+ self._system = self.rc.get_system_driver()
324
+ return self._system
325
+
326
+ @system.setter
327
+ def system(self, system):
328
+ self._system = system
329
+
330
+ @property
331
+ def closed(self):
332
+ return self.we_h5file is None
333
+
334
+ def iter_group_name(self, n_iter, absolute=True):
335
+ if absolute:
336
+ return '/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)
337
+ else:
338
+ return 'iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)
339
+
340
+ def require_iter_group(self, n_iter):
341
+ '''Get the group associated with n_iter, creating it if necessary.'''
342
+ with self.lock:
343
+ iter_group = self.we_h5file.require_group('/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec))
344
+ iter_group.attrs['n_iter'] = n_iter
345
+ return iter_group
346
+
347
+ def del_iter_group(self, n_iter):
348
+ with self.lock:
349
+ del self.we_h5file['/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
350
+
351
+ def get_iter_group(self, n_iter):
352
+ with self.lock:
353
+ try:
354
+ return self.we_h5file['/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
355
+ except KeyError:
356
+ return self.we_h5file['/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
357
+
358
+ def get_seg_index(self, n_iter):
359
+ with self.lock:
360
+ seg_index = self.get_iter_group(n_iter)['seg_index']
361
+ return seg_index
362
+
363
+ @property
364
+ def current_iteration(self):
365
+ with self.lock:
366
+ h5file_attrs = self.we_h5file['/'].attrs
367
+ h5file_attr_keys = list(h5file_attrs.keys())
368
+
369
+ if 'west_current_iteration' in h5file_attr_keys:
370
+ return int(self.we_h5file['/'].attrs['west_current_iteration'])
371
+ else:
372
+ return int(self.we_h5file['/'].attrs['wemd_current_iteration'])
373
+
374
+ @current_iteration.setter
375
+ def current_iteration(self, n_iter):
376
+ with self.lock:
377
+ self.we_h5file['/'].attrs['west_current_iteration'] = n_iter
378
+
379
+ def open_backing(self, mode=None):
380
+ '''Open the (already-created) HDF5 file named in self.we_h5filename.'''
381
+ mode = mode or self.h5_access_mode
382
+ if not self.we_h5file:
383
+ log.debug('attempting to open {} with mode {}'.format(self.we_h5filename, mode))
384
+ self.we_h5file = h5io.WESTPAH5File(self.we_h5filename, mode, driver=self.we_h5file_driver)
385
+
386
+ h5file_attrs = self.we_h5file['/'].attrs
387
+ h5file_attr_keys = list(h5file_attrs.keys())
388
+
389
+ if 'west_iter_prec' in h5file_attr_keys:
390
+ self.iter_prec = int(h5file_attrs['west_iter_prec'])
391
+ elif 'wemd_iter_prec' in h5file_attr_keys:
392
+ self.iter_prec = int(h5file_attrs['wemd_iter_prec'])
393
+ else:
394
+ log.info('iteration precision not stored in HDF5; using {:d}'.format(self.iter_prec))
395
+
396
+ if 'west_file_format_version' in h5file_attr_keys:
397
+ self.we_h5file_version = h5file_attrs['west_file_format_version']
398
+ elif 'wemd_file_format_version' in h5file_attr_keys:
399
+ self.we_h5file_version = h5file_attrs['wemd_file_format_version']
400
+ else:
401
+ log.info('WEST HDF5 file format version not stored, assuming 0')
402
+ self.we_h5file_version = 0
403
+
404
+ log.debug('opened WEST HDF5 file version {:d}'.format(self.we_h5file_version))
405
+
406
+ def prepare_backing(self): # istates):
407
+ '''Create new HDF5 file'''
408
+ self.we_h5file = h5py.File(self.we_h5filename, 'w', driver=self.we_h5file_driver)
409
+
410
+ with self.flushing_lock():
411
+ self.we_h5file['/'].attrs['west_file_format_version'] = file_format_version
412
+ self.we_h5file['/'].attrs['west_iter_prec'] = self.iter_prec
413
+ self.we_h5file['/'].attrs['west_version'] = westpa.__version__
414
+ self.current_iteration = 0
415
+ self.we_h5file['/'].create_dataset('summary', shape=(1,), dtype=summary_table_dtype, maxshape=(None,))
416
+ self.we_h5file.create_group('/iterations')
417
+
418
+ def close_backing(self):
419
+ if self.we_h5file is not None:
420
+ with self.lock:
421
+ self.we_h5file.close()
422
+ self.we_h5file = None
423
+
424
+ def flush_backing(self):
425
+ if self.we_h5file is not None:
426
+ with self.lock:
427
+ self.we_h5file.flush()
428
+ self.last_flush = time.time()
429
+
430
+ def save_target_states(self, tstates, n_iter=None):
431
+ """
432
+ Save the given target states in the HDF5 file; they will be used for the next iteration to
433
+ be propagated. A complete set is required, even if nominally appending to an existing set,
434
+ which simplifies the mapping of IDs to the table.
435
+ """
436
+
437
+ system = self.system
438
+
439
+ n_iter = n_iter or self.current_iteration
440
+
441
+ # Assemble all the important data before we start to modify the HDF5 file
442
+ tstates = list(tstates)
443
+ if tstates:
444
+ state_table = np.empty((len(tstates),), dtype=tstate_dtype)
445
+ state_pcoords = np.empty((len(tstates), system.pcoord_ndim), dtype=system.pcoord_dtype)
446
+ for i, state in enumerate(tstates):
447
+ state.state_id = i
448
+ state_table[i]['label'] = state.label
449
+ state_pcoords[i] = state.pcoord
450
+ else:
451
+ state_table = None
452
+ state_pcoords = None
453
+
454
+ # Commit changes to HDF5
455
+ with self.lock:
456
+ master_group = self.we_h5file.require_group('tstates')
457
+
458
+ try:
459
+ master_index = master_group['index']
460
+ except KeyError:
461
+ master_index = master_group.create_dataset('index', shape=(1,), maxshape=(None,), dtype=tstate_index_dtype)
462
+ n_sets = 1
463
+ else:
464
+ n_sets = len(master_index) + 1
465
+ master_index.resize((n_sets,))
466
+
467
+ set_id = n_sets - 1
468
+ master_index_row = master_index[set_id]
469
+ master_index_row['iter_valid'] = n_iter
470
+ master_index_row['n_states'] = len(tstates)
471
+
472
+ if tstates:
473
+ state_group = master_group.create_group(str(set_id))
474
+ master_index_row['group_ref'] = state_group.ref
475
+ state_group['index'] = state_table
476
+ state_group['pcoord'] = state_pcoords
477
+ else:
478
+ master_index_row['group_ref'] = None
479
+
480
+ master_index[set_id] = master_index_row
481
+
482
+ def _find_multi_iter_group(self, n_iter, master_group_name):
483
+ with self.lock:
484
+ master_group = self.we_h5file[master_group_name]
485
+ master_index = master_group['index'][...]
486
+ set_id = np.digitize([n_iter], master_index['iter_valid']) - 1
487
+ group_ref = master_index[set_id]['group_ref']
488
+
489
+ # Check if reference is Null
490
+ if not bool(group_ref):
491
+ return None
492
+
493
+ # This extra [0] is to work around a bug in h5py
494
+ try:
495
+ group = self.we_h5file[group_ref]
496
+ except (TypeError, AttributeError):
497
+ group = self.we_h5file[group_ref[0]]
498
+ else:
499
+ log.debug('h5py fixed; remove alternate code path')
500
+ log.debug('reference {!r} points to group {!r}'.format(group_ref, group))
501
+ return group
502
+
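The lookup above uses np.digitize to select the most recent state set whose iter_valid does not exceed n_iter. A small illustrative check (the iteration numbers are made up):

    import numpy as np

    iter_valid = np.array([1, 5, 10])             # iterations at which state sets became valid
    n_iter = 7
    set_id = np.digitize([n_iter], iter_valid) - 1
    print(int(set_id[0]))                         # -> 1: the set that became valid at iteration 5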
503
+ def find_tstate_group(self, n_iter):
504
+ return self._find_multi_iter_group(n_iter, 'tstates')
505
+
506
+ def find_ibstate_group(self, n_iter):
507
+ return self._find_multi_iter_group(n_iter, 'ibstates')
508
+
509
+ def get_target_states(self, n_iter):
510
+ '''Return a list of TargetState objects representing the target (sink) states that are in use for iteration n_iter.
511
+ Future iterations are assumed to continue from the most recent set of states.'''
512
+
513
+ with self.lock:
514
+ tstate_group = self.find_tstate_group(n_iter)
515
+
516
+ if tstate_group is not None:
517
+ tstate_index = tstate_group['index'][...]
518
+ tstate_pcoords = tstate_group['pcoord'][...]
519
+
520
+ tstates = [
521
+ TargetState(state_id=i, label=h5io.tostr(row['label']), pcoord=pcoord.copy())
522
+ for (i, (row, pcoord)) in enumerate(zip(tstate_index, tstate_pcoords))
523
+ ]
524
+ else:
525
+ tstates = []
526
+
527
+ return tstates
528
+
529
+ def create_ibstate_group(self, basis_states, n_iter=None):
530
+ '''Create the group used to store basis states and initial states (whose definitions are always
531
+ coupled). This group is hard-linked into all iteration groups that use these basis and
532
+ initial states.'''
533
+
534
+ with self.lock:
535
+ n_iter = n_iter or self.current_iteration
536
+ master_group = self.we_h5file.require_group('ibstates')
537
+
538
+ try:
539
+ master_index = master_group['index']
540
+ except KeyError:
541
+ master_index = master_group.create_dataset('index', dtype=ibstate_index_dtype, shape=(1,), maxshape=(None,))
542
+ n_sets = 1
543
+ else:
544
+ n_sets = len(master_index) + 1
545
+ master_index.resize((n_sets,))
546
+
547
+ set_id = n_sets - 1
548
+ master_index_row = master_index[set_id]
549
+ master_index_row['iter_valid'] = n_iter
550
+ master_index_row['n_bstates'] = len(basis_states)
551
+ state_group = master_group.create_group(str(set_id))
552
+ master_index_row['group_ref'] = state_group.ref
553
+
554
+ if basis_states:
555
+ system = self.system
556
+ state_table = np.empty((len(basis_states),), dtype=bstate_dtype)
557
+ state_pcoords = np.empty((len(basis_states), system.pcoord_ndim), dtype=system.pcoord_dtype)
558
+ for i, state in enumerate(basis_states):
559
+ state.state_id = i
560
+ state_table[i]['label'] = state.label
561
+ state_table[i]['probability'] = state.probability
562
+ state_table[i]['auxref'] = state.auxref or ''
563
+ state_pcoords[i] = state.pcoord
564
+
565
+ state_group['bstate_index'] = state_table
566
+ state_group['bstate_pcoord'] = state_pcoords
567
+
568
+ master_index[set_id] = master_index_row
569
+ return state_group
570
+
571
+ def create_ibstate_iter_h5file(self, basis_states):
572
+ """
573
+ Create the per-iteration HDF5 file for the basis states (i.e., iteration 0).
574
+ This special treatment is needed so that the analysis tools can access basis states
575
+ more easily.
576
+ """
577
+
578
+ if not self.store_h5:
579
+ return
580
+
581
+ segments = []
582
+ for i, state in enumerate(basis_states):
583
+ dummy_segment = Segment(
584
+ n_iter=0,
585
+ seg_id=state.state_id,
586
+ parent_id=-(state.state_id + 1),
587
+ weight=state.probability,
588
+ wtg_parent_ids=None,
589
+ pcoord=state.pcoord,
590
+ status=Segment.SEG_STATUS_UNSET,
591
+ data=state.data,
592
+ )
593
+ segments.append(dummy_segment)
594
+
595
+ # Link the iteration file in west.h5
596
+ self.prepare_iteration(0, segments)
597
+ self.update_iter_h5file(0, segments)
598
+
599
+ def update_iter_h5file(self, n_iter, segments):
600
+ """Write out the per-iteration HDF5 file with given segments and add an external link to it
601
+ in the main HDF5 file (west.h5) if the link is not present."""
602
+
603
+ if not self.store_h5:
604
+ return
605
+
606
+ west_h5_file = makepath(self.we_h5filename)
607
+ iter_ref_h5_file = makepath(self.iter_h5_path_template, {'n_iter': n_iter})
608
+ iter_ref_rel_path = relpath(iter_ref_h5_file, dirname(west_h5_file))
609
+ if self.iter_h5_template_file_path:
610
+ # Make path to per-iter H5 File
611
+ iter_h5_template_file_path_expanded = makepath(self.iter_h5_template_file_path)
612
+
613
+ # Copy the template per-iter H5 file with topology
614
+ if exists(iter_h5_template_file_path_expanded) and not exists(iter_ref_h5_file):
615
+ copyfile(iter_h5_template_file_path_expanded, iter_ref_h5_file)
616
+
617
+ with h5io.WESTIterationFile(iter_ref_h5_file, 'a') as outf:
618
+ for segment in segments:
619
+ outf.write_segment(segment, True)
620
+
621
+ if self.iter_h5_template_file_path and not exists(iter_h5_template_file_path_expanded):
622
+ # If template per-iter H5 file does not exist, copy and scrub out old data
623
+ copyfile(iter_ref_h5_file, iter_h5_template_file_path_expanded)
624
+ with h5io.WESTIterationFile(iter_h5_template_file_path_expanded, 'a') as outf:
625
+ outf.scrub_data()
626
+
627
+ # Launch a subprocess to repack the file to reclaim space, replacing template with smaller file
628
+ try:
629
+ run(
630
+ f'h5repack {iter_h5_template_file_path_expanded} {iter_h5_template_file_path_expanded}_repacked.h5', shell=True
631
+ ).check_returncode()
632
+ move(f'{iter_h5_template_file_path_expanded}_repacked.h5', iter_h5_template_file_path_expanded)
633
+ except CalledProcessError as e: # Unsuccessful in repacking file
634
+ log.warning(f'Unable to repack into {iter_h5_template_file_path_expanded}_repacked.h5: {e}')
635
+ if exists(f'{iter_h5_template_file_path_expanded}_repacked.h5'):
636
+ remove(f'{iter_h5_template_file_path_expanded}_repacked.h5')
637
+
638
+ iter_group = self.get_iter_group(n_iter)
639
+
640
+ if 'trajectories' not in iter_group:
641
+ iter_group['trajectories'] = h5py.ExternalLink(iter_ref_rel_path, '/')
642
+
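For readers of west.h5, the external link created above makes the per-iteration file appear under the iteration group. A hedged sketch of following it with h5py (the file and group names assume the defaults used in this module):

    import h5py

    # Accessing 'trajectories' resolves the external link to the per-iteration file.
    with h5py.File('west.h5', 'r') as f:
        traj_root = f['/iterations/iter_00000001/trajectories']
        print(list(traj_root.keys()))   # top-level contents of the linked file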
643
+ def get_basis_states(self, n_iter=None):
644
+ """Return a list of BasisState objects representing the basis states that are in use for iteration n_iter."""
645
+
646
+ with self.lock:
647
+ n_iter = n_iter or self.current_iteration
648
+ ibstate_group = self.find_ibstate_group(n_iter)
649
+ try:
650
+ bstate_index = ibstate_group['bstate_index'][...]
651
+ except KeyError:
652
+ return []
653
+ bstate_pcoords = ibstate_group['bstate_pcoord'][...]
654
+ bstates = [
655
+ BasisState(
656
+ state_id=i,
657
+ label=h5io.tostr(row['label']),
658
+ probability=row['probability'],
659
+ auxref=h5io.tostr(row['auxref']) or None,
660
+ pcoord=pcoord.copy(),
661
+ )
662
+ for (i, (row, pcoord)) in enumerate(zip(bstate_index, bstate_pcoords))
663
+ ]
664
+
665
+ bstate_total_prob = sum(bstate.probability for bstate in bstates)
666
+
667
+ # This should run once in the second iteration, and only if start-states are specified,
668
+ # but is necessary to re-normalize (i.e. normalize without start-state probabilities included)
669
+ for i, bstate in enumerate(bstates):
670
+ bstate.probability /= bstate_total_prob
671
+ bstates[i] = bstate
672
+ return bstates
673
+
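The renormalization loop above simply rescales the stored probabilities so they sum to one once start-state probability has been excluded. A toy example (the numbers are made up):

    probs = [0.5, 0.3]                   # stored basis-state probabilities, sum 0.8
    total = sum(probs)
    print([p / total for p in probs])    # -> [0.625, 0.375]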
674
+ def create_initial_states(self, n_states, n_iter=None):
675
+ """Create storage for ``n_states`` initial states associated with iteration ``n_iter``, and
676
+ return bare InitialState objects with only state_id set."""
677
+
678
+ system = self.system
679
+ with self.lock:
680
+ n_iter = n_iter or self.current_iteration
681
+ ibstate_group = self.find_ibstate_group(n_iter)
682
+
683
+ try:
684
+ istate_index = ibstate_group['istate_index']
685
+ except KeyError:
686
+ istate_index = ibstate_group.create_dataset('istate_index', dtype=istate_dtype, shape=(n_states,), maxshape=(None,))
687
+ istate_pcoords = ibstate_group.create_dataset(
688
+ 'istate_pcoord',
689
+ dtype=system.pcoord_dtype,
690
+ shape=(n_states, system.pcoord_ndim),
691
+ maxshape=(None, system.pcoord_ndim),
692
+ )
693
+ len_index = len(istate_index)
694
+ first_id = 0
695
+ else:
696
+ first_id = len(istate_index)
697
+ len_index = len(istate_index) + n_states
698
+ istate_index.resize((len_index,))
699
+ istate_pcoords = ibstate_group['istate_pcoord']
700
+ istate_pcoords.resize((len_index, system.pcoord_ndim))
701
+
702
+ index_entries = istate_index[first_id:len_index]
703
+ new_istates = []
704
+ for irow, row in enumerate(index_entries):
705
+ row['iter_created'] = n_iter
706
+ row['istate_status'] = InitialState.ISTATE_STATUS_PENDING
707
+ new_istates.append(
708
+ InitialState(
709
+ state_id=first_id + irow,
710
+ basis_state_id=None,
711
+ iter_created=n_iter,
712
+ istate_status=InitialState.ISTATE_STATUS_PENDING,
713
+ )
714
+ )
715
+ istate_index[first_id:len_index] = index_entries
716
+ return new_istates
717
+
718
+ def update_initial_states(self, initial_states, n_iter=None):
719
+ """Save the given initial states in the HDF5 file"""
720
+
721
+ system = self.system
722
+ initial_states = sorted(initial_states, key=attrgetter('state_id'))
723
+ if not initial_states:
724
+ return
725
+
726
+ with self.lock:
727
+ n_iter = n_iter or self.current_iteration
728
+ ibstate_group = self.find_ibstate_group(n_iter)
729
+ state_ids = [state.state_id for state in initial_states]
730
+ index_entries = ibstate_group['istate_index'][state_ids]
731
+ pcoord_vals = np.empty((len(initial_states), system.pcoord_ndim), dtype=system.pcoord_dtype)
732
+ for i, initial_state in enumerate(initial_states):
733
+ index_entries[i]['iter_created'] = initial_state.iter_created
734
+ index_entries[i]['iter_used'] = initial_state.iter_used or InitialState.ISTATE_UNUSED
735
+ index_entries[i]['basis_state_id'] = (
736
+ initial_state.basis_state_id if initial_state.basis_state_id is not None else -1
737
+ )
738
+ index_entries[i]['istate_type'] = initial_state.istate_type or InitialState.ISTATE_TYPE_UNSET
739
+ index_entries[i]['istate_status'] = initial_state.istate_status or InitialState.ISTATE_STATUS_PENDING
740
+ pcoord_vals[i] = initial_state.pcoord
741
+
742
+ index_entries[i]['basis_auxref'] = initial_state.basis_auxref or ""
743
+
744
+ ibstate_group['istate_index'][state_ids] = index_entries
745
+ ibstate_group['istate_pcoord'][state_ids] = pcoord_vals
746
+
747
+ if self.store_h5:
748
+ segments = []
749
+ for i, state in enumerate(initial_states):
750
+ dummy_segment = Segment(
751
+ n_iter=-state.iter_created,
752
+ seg_id=state.state_id,
753
+ parent_id=state.basis_state_id,
754
+ wtg_parent_ids=None,
755
+ pcoord=state.pcoord,
756
+ status=Segment.SEG_STATUS_PREPARED,
757
+ data=state.data,
758
+ )
759
+ segments.append(dummy_segment)
760
+ self.update_iter_h5file(0, segments)
761
+
762
+ def get_initial_states(self, n_iter=None):
763
+ states = []
764
+ with self.lock:
765
+ n_iter = n_iter or self.current_iteration
766
+ ibstate_group = self.find_ibstate_group(n_iter)
767
+ try:
768
+ istate_index = ibstate_group['istate_index'][...]
769
+ except KeyError:
770
+ return []
771
+ istate_pcoords = ibstate_group['istate_pcoord'][...]
772
+
773
+ for state_id, (state, pcoord) in enumerate(zip(istate_index, istate_pcoords)):
774
+ states.append(
775
+ InitialState(
776
+ state_id=state_id,
777
+ basis_state_id=int(state['basis_state_id']),
778
+ iter_created=int(state['iter_created']),
779
+ iter_used=int(state['iter_used']),
780
+ istate_type=int(state['istate_type']),
781
+ basis_auxref=h5io.tostr(state['basis_auxref']),
782
+ pcoord=pcoord.copy(),
783
+ )
784
+ )
785
+ return states
786
+
787
+ def get_segment_initial_states(self, segments, n_iter=None):
788
+ '''Retrieve all initial states referenced by the given segments.'''
789
+
790
+ with self.lock:
791
+ n_iter = n_iter or self.current_iteration
792
+ ibstate_group = self.get_iter_group(n_iter)['ibstates']
793
+
794
+ istate_ids = {-int(segment.parent_id + 1) for segment in segments if segment.parent_id < 0}
795
+ sorted_istate_ids = sorted(istate_ids)
796
+ if not sorted_istate_ids:
797
+ return []
798
+
799
+ istate_rows = ibstate_group['istate_index'][sorted_istate_ids][...]
800
+ istate_pcoords = ibstate_group['istate_pcoord'][sorted_istate_ids][...]
801
+ istates = []
802
+
803
+ for state_id, state, pcoord in zip(sorted_istate_ids, istate_rows, istate_pcoords):
804
+ try:
805
+ b_auxref = h5io.tostr(state['basis_auxref'])
806
+ except ValueError:
807
+ b_auxref = ''
808
+ istate = InitialState(
809
+ state_id=state_id,
810
+ basis_state_id=int(state['basis_state_id']),
811
+ iter_created=int(state['iter_created']),
812
+ iter_used=int(state['iter_used']),
813
+ istate_type=int(state['istate_type']),
814
+ basis_auxref=b_auxref,
815
+ pcoord=pcoord.copy(),
816
+ )
817
+ istates.append(istate)
818
+ return istates
819
+
820
+ def get_unused_initial_states(self, n_states=None, n_iter=None):
821
+ """Retrieve any prepared but unused initial states applicable to the given iteration.
822
+ Up to ``n_states`` states are returned; if ``n_states`` is None, then all unused states
823
+ are returned."""
824
+
825
+ n_states = n_states or sys.maxsize
826
+ ISTATE_UNUSED = InitialState.ISTATE_UNUSED
827
+ ISTATE_STATUS_PREPARED = InitialState.ISTATE_STATUS_PREPARED
828
+ with self.lock:
829
+ n_iter = n_iter or self.current_iteration
830
+ ibstate_group = self.find_ibstate_group(n_iter)
831
+ istate_index = ibstate_group['istate_index']
832
+ istate_pcoords = ibstate_group['istate_pcoord']
833
+ n_index_entries = istate_index.len()
834
+ chunksize = self.table_scan_chunksize
835
+
836
+ states = []
837
+ istart = 0
838
+ while istart < n_index_entries and len(states) < n_states:
839
+ istop = min(istart + chunksize, n_index_entries)
840
+ istate_chunk = istate_index[istart:istop]
841
+ pcoord_chunk = istate_pcoords[istart:istop]
842
+ # state_ids = np.arange(istart,istop,dtype=np.uint)
843
+
844
+ for ci in range(len(istate_chunk)):
845
+ row = istate_chunk[ci]
846
+ pcoord = pcoord_chunk[ci]
847
+ state_id = istart + ci
848
+ if row['iter_used'] == ISTATE_UNUSED and row['istate_status'] == ISTATE_STATUS_PREPARED:
849
+ istate = InitialState(
850
+ state_id=state_id,
851
+ basis_state_id=int(row['basis_state_id']),
852
+ iter_created=int(row['iter_created']),
853
+ iter_used=0,
854
+ istate_type=int(row['istate_type']),
855
+ pcoord=pcoord.copy(),
856
+ istate_status=ISTATE_STATUS_PREPARED,
857
+ )
858
+ states.append(istate)
859
+ del row, pcoord, state_id
860
+ istart += chunksize
861
+ del istate_chunk, pcoord_chunk # , state_ids, unused, ids_of_unused
862
+ log.debug('found {:d} unused states'.format(len(states)))
863
+ return states[:n_states]
864
+
865
+ def prepare_iteration(self, n_iter, segments):
866
+ """Prepare for a new iteration by creating space to store the new iteration's data.
867
+ The number of segments, their IDs, and their lineage must be determined and included
868
+ in the set of segments passed in."""
869
+
870
+ log.debug('preparing HDF5 group for iteration %d (%d segments)' % (n_iter, len(segments)))
871
+
872
+ # Ensure we have a list for guaranteed ordering
873
+ init = n_iter == 0
874
+ segments = list(segments)
875
+ n_particles = len(segments)
876
+ system = self.system
877
+ pcoord_ndim = system.pcoord_ndim
878
+ pcoord_len = 2 if init else system.pcoord_len
879
+ pcoord_dtype = system.pcoord_dtype
880
+
881
+ with self.lock:
882
+ if not init:
883
+ # Create a table of summary information about each iteration
884
+ summary_table = self.we_h5file['summary']
885
+ if len(summary_table) < n_iter:
886
+ summary_table.resize((n_iter + 1,))
887
+
888
+ iter_group = self.require_iter_group(n_iter)
889
+
890
+ for linkname in ('seg_index', 'pcoord', 'wtgraph'):
891
+ try:
892
+ del iter_group[linkname]
893
+ except KeyError:
894
+ pass
895
+
896
+ # everything indexed by [particle] goes in an index table
897
+ seg_index_table_ds = iter_group.create_dataset('seg_index', shape=(n_particles,), dtype=seg_index_dtype)
898
+ # unfortunately, h5py doesn't like in-place modification of individual fields; it expects
899
+ # tuples. So, construct everything in a numpy array and then dump the whole thing into hdf5
900
+ # In fact, this appears to be an h5py best practice (collect as much in ram as possible and then dump)
901
+ seg_index_table = seg_index_table_ds[...]
902
+
903
+ if not init:
904
+ summary_row = np.zeros((1,), dtype=summary_table_dtype)
905
+ summary_row['n_particles'] = n_particles
906
+ summary_row['norm'] = np.add.reduce(list(map(attrgetter('weight'), segments)))
907
+ summary_table[n_iter - 1] = summary_row
908
+
909
+ # pcoord is indexed as [particle, time, dimension]
910
+ pcoord_opts = self.dataset_options.get('pcoord', {'name': 'pcoord', 'h5path': 'pcoord', 'compression': False})
911
+ shape = (n_particles, pcoord_len, pcoord_ndim)
912
+ pcoord_ds = create_dataset_from_dsopts(iter_group, pcoord_opts, shape, pcoord_dtype)
913
+ pcoord = np.empty((n_particles, pcoord_len, pcoord_ndim), pcoord_dtype)
914
+
915
+ total_parents = 0
916
+ for seg_id, segment in enumerate(segments):
917
+ if segment.seg_id is not None:
918
+ assert segment.seg_id == seg_id
919
+ else:
920
+ segment.seg_id = seg_id
921
+ # Parent must be set, though what it means depends on initpoint_type
922
+ assert segment.parent_id is not None
923
+ segment.seg_id = seg_id
924
+ seg_index_table[seg_id]['status'] = segment.status
925
+ seg_index_table[seg_id]['weight'] = segment.weight
926
+ seg_index_table[seg_id]['parent_id'] = segment.parent_id
927
+ seg_index_table[seg_id]['wtg_n_parents'] = len(segment.wtg_parent_ids)
928
+ seg_index_table[seg_id]['wtg_offset'] = total_parents
929
+ total_parents += len(segment.wtg_parent_ids)
930
+
931
+ # Assign progress coordinate if any exists
932
+ if segment.pcoord is not None:
933
+ if init:
934
+ if segment.pcoord.shape != pcoord.shape[2:]:
935
+ raise ValueError(
936
+ 'basis state pcoord shape [%r] does not match expected shape [%r]'
937
+ % (segment.pcoord.shape, pcoord.shape[2:])
938
+ )
939
+ # Initial pcoord
940
+ pcoord[seg_id, 1, :] = segment.pcoord[:]
941
+ else:
942
+ if len(segment.pcoord) == 1:
943
+ # Initial pcoord
944
+ pcoord[seg_id, 0, :] = segment.pcoord[0, :]
945
+ elif segment.pcoord.shape != pcoord.shape[1:]:
946
+ raise ValueError(
947
+ 'segment pcoord shape [%r] does not match expected shape [%r]'
948
+ % (segment.pcoord.shape, pcoord.shape[1:])
949
+ )
950
+ else:
951
+ pcoord[seg_id, ...] = segment.pcoord
952
+
953
+ if total_parents > 0:
954
+ wtgraph_ds = iter_group.create_dataset('wtgraph', (total_parents,), seg_id_dtype, compression='gzip', shuffle=True)
955
+ parents = np.empty((total_parents,), seg_id_dtype)
956
+
957
+ for seg_id, segment in enumerate(segments):
958
+ offset = seg_index_table[seg_id]['wtg_offset']
959
+ extent = seg_index_table[seg_id]['wtg_n_parents']
960
+ parent_list = list(segment.wtg_parent_ids)
961
+ parents[offset : offset + extent] = parent_list[:]
962
+
963
+ assert set(parents[offset : offset + extent]) == set(segment.wtg_parent_ids)
964
+
965
+ wtgraph_ds[:] = parents
966
+
967
+ # Create convenient hard links
968
+ self.update_iter_group_links(n_iter)
969
+
970
+ # Since we accumulated many of these changes in RAM (and not directly in HDF5), propagate
971
+ # the changes out to HDF5
972
+ seg_index_table_ds[:] = seg_index_table
973
+ pcoord_ds[...] = pcoord
974
+
975
+ def update_iter_group_links(self, n_iter):
976
+ """Update the per-iteration hard links pointing to the tables of target and initial/basis states for the
977
+ given iteration. These links are not used by this class, but are remarkably convenient for third-party
978
+ analysis tools and hdfview."""
979
+
980
+ with self.lock:
981
+ iter_group = self.require_iter_group(n_iter)
982
+
983
+ for linkname in ('ibstates', 'tstates'):
984
+ try:
985
+ del iter_group[linkname]
986
+ except KeyError:
987
+ pass
988
+
989
+ iter_group['ibstates'] = self.find_ibstate_group(n_iter)
990
+
991
+ tstate_group = self.find_tstate_group(n_iter)
992
+ if tstate_group is not None:
993
+ iter_group['tstates'] = tstate_group
994
+
995
+ def get_iter_summary(self, n_iter=None):
996
+ n_iter = n_iter or self.current_iteration
997
+ with self.lock:
998
+ return self.we_h5file['summary'][n_iter - 1]
999
+
1000
+ def update_iter_summary(self, summary, n_iter=None):
1001
+ n_iter = n_iter or self.current_iteration
1002
+ with self.lock:
1003
+ self.we_h5file['summary'][n_iter - 1] = summary
1004
+
1005
+ def del_iter_summary(self, min_iter): # delete the iterations starting at min_iter
1006
+ with self.lock:
1007
+ self.we_h5file['summary'].resize((min_iter - 1,))
1008
+
1009
+ def update_segments(self, n_iter, segments):
1010
+ """Update segment information in the HDF5 file; all prior information for each
1011
+ ``segment`` is overwritten, except for parent and weight transfer information."""
1012
+
1013
+ segments = sorted(segments, key=attrgetter('seg_id'))
1014
+
1015
+ with self.lock:
1016
+ iter_group = self.get_iter_group(n_iter)
1017
+
1018
+ pc_dsid = iter_group['pcoord'].id
1019
+ si_dsid = iter_group['seg_index'].id
1020
+
1021
+ seg_ids = [segment.seg_id for segment in segments]
1022
+ n_segments = len(segments)
1023
+ n_total_segments = si_dsid.shape[0]
1024
+ system = self.system
1025
+ pcoord_ndim = system.pcoord_ndim
1026
+ pcoord_len = system.pcoord_len
1027
+ pcoord_dtype = system.pcoord_dtype
1028
+
1029
+ seg_index_entries = np.empty((n_segments,), dtype=seg_index_dtype)
1030
+ pcoord_entries = np.empty((n_segments, pcoord_len, pcoord_ndim), dtype=pcoord_dtype)
1031
+
1032
+ pc_msel = h5s.create_simple(pcoord_entries.shape, (h5s.UNLIMITED,) * pcoord_entries.ndim)
1033
+ pc_msel.select_all()
1034
+ si_msel = h5s.create_simple(seg_index_entries.shape, (h5s.UNLIMITED,))
1035
+ si_msel.select_all()
1036
+ pc_fsel = pc_dsid.get_space()
1037
+ si_fsel = si_dsid.get_space()
1038
+
1039
+ for iseg in range(n_segments):
1040
+ seg_id = seg_ids[iseg]
1041
+ op = h5s.SELECT_OR if iseg != 0 else h5s.SELECT_SET
1042
+ si_fsel.select_hyperslab((seg_id,), (1,), op=op)
1043
+ pc_fsel.select_hyperslab((seg_id, 0, 0), (1, pcoord_len, pcoord_ndim), op=op)
1044
+
1045
+ # read summary data so that we have value, parent and weight transfer information
1046
+ si_dsid.read(si_msel, si_fsel, seg_index_entries)
1047
+
1048
+ for iseg, (segment, ientry) in enumerate(zip(segments, seg_index_entries)):
1049
+ ientry['status'] = segment.status
1050
+ ientry['endpoint_type'] = segment.endpoint_type or Segment.SEG_ENDPOINT_UNSET
1051
+ ientry['cputime'] = segment.cputime
1052
+ ientry['walltime'] = segment.walltime
1053
+ ientry['weight'] = segment.weight
1054
+
1055
+ pcoord_entries[iseg] = segment.pcoord
1056
+
1057
+ # write progress coordinates and index using low level HDF5 functions for efficiency
1058
+ si_dsid.write(si_msel, si_fsel, seg_index_entries)
1059
+ pc_dsid.write(pc_msel, pc_fsel, pcoord_entries)
1060
+
1061
+ # Now, to deal with auxiliary data
1062
+ # If any segment has any auxiliary data, then the aux dataset must spring into
1063
+ # existence. Each is named according to the name in segment.data, and has shape
1064
+ # (n_total_segs, ...) where the ... is the shape of the data in segment.data (and may be empty
1065
+ # in the case of scalar data) and dtype is taken from the data type of the data entry
1066
+ # compression is on by default for datasets that will be more than 1MiB
1067
+
1068
+ # a mapping of data set name to (per-segment shape, data type) tuples
1069
+ dsets = {}
1070
+
1071
+ # First we scan for presence, shape, and data type of auxiliary data sets
1072
+ for segment in segments:
1073
+ if segment.data:
1074
+ for dsname in segment.data:
1075
+ if dsname.startswith('iterh5/'):
1076
+ continue
1077
+ data = np.asarray(segment.data[dsname], order='C')
1078
+ segment.data[dsname] = data
1079
+ dsets[dsname] = (data.shape, data.dtype)
1080
+
1081
+ # Then we iterate over data sets and store data
1082
+ if dsets:
1083
+ for dsname, (shape, dtype) in dsets.items():
1084
+ # dset = self._require_aux_dataset(iter_group, dsname, n_total_segments, shape, dtype)
1085
+ try:
1086
+ dsopts = self.dataset_options[dsname]
1087
+ except KeyError:
1088
+ dsopts = normalize_dataset_options({'name': dsname}, path_prefix='auxdata')
1089
+
1090
+ shape = (n_total_segments,) + shape
1091
+ dset = require_dataset_from_dsopts(
1092
+ iter_group, dsopts, shape, dtype, autocompress_threshold=self.aux_compression_threshold, n_iter=n_iter
1093
+ )
1094
+ if dset is None:
1095
+ # storage is suppressed
1096
+ continue
1097
+ for segment in segments:
1098
+ try:
1099
+ auxdataset = segment.data[dsname]
1100
+ except KeyError:
1101
+ pass
1102
+ else:
1103
+ source_rank = len(auxdataset.shape)
1104
+ source_sel = h5s.create_simple(auxdataset.shape, (h5s.UNLIMITED,) * source_rank)
1105
+ source_sel.select_all()
1106
+ dest_sel = dset.id.get_space()
1107
+ dest_sel.select_hyperslab((segment.seg_id,) + (0,) * source_rank, (1,) + auxdataset.shape)
1108
+ dset.id.write(source_sel, dest_sel, auxdataset)
1109
+ if 'delram' in list(dsopts.keys()):
1110
+ del dsets[dsname]
1111
+
1112
+ self.update_iter_h5file(n_iter, segments)
1113
+
1114
+ def get_segments(self, n_iter=None, seg_ids=None, load_pcoords=True):
1115
+ """
1116
+ Return the given (or all) segments from a given iteration.
1117
+
1118
+ If ``seg_ids`` is given, only those segments are returned (sorted by seg_id);
1119
+ otherwise, all segments of iteration ``n_iter`` are returned. If the optional
1120
+ parameter ``load_pcoords`` is true (the default), progress coordinate data is
1121
+ read and attached to each returned ``Segment``. Datasets configured with the
1122
+ ``load`` option (under ``west.data.datasets`` in ``west.cfg``) are also loaded
1123
+ and mapped onto the ``data`` dictionary of each segment.
1124
+ """
1125
+
1126
+ n_iter = n_iter or self.current_iteration
1127
+ file_version = self.we_h5file_version
1128
+
1129
+ with self.lock:
1130
+ iter_group = self.get_iter_group(n_iter)
1131
+ seg_index_ds = iter_group['seg_index']
1132
+
1133
+ if file_version < 5:
1134
+ all_parent_ids = iter_group['parents'][...]
1135
+ else:
1136
+ all_parent_ids = iter_group['wtgraph'][...]
1137
+
1138
+ if seg_ids is not None:
1139
+ seg_ids = list(sorted(seg_ids))
1140
+ seg_index_entries = seg_index_ds[seg_ids]
1141
+ if load_pcoords:
1142
+ pcoord_entries = iter_group['pcoord'][seg_ids]
1143
+ else:
1144
+ seg_ids = list(range(len(seg_index_ds)))
1145
+ seg_index_entries = seg_index_ds[...]
1146
+ if load_pcoords:
1147
+ pcoord_entries = iter_group['pcoord'][...]
1148
+
1149
+ segments = []
1150
+
1151
+ for iseg, (seg_id, row) in enumerate(zip(seg_ids, seg_index_entries)):
1152
+ segment = Segment(
1153
+ seg_id=seg_id,
1154
+ n_iter=n_iter,
1155
+ status=int(row['status']),
1156
+ endpoint_type=int(row['endpoint_type']),
1157
+ walltime=float(row['walltime']),
1158
+ cputime=float(row['cputime']),
1159
+ weight=float(row['weight']),
1160
+ )
1161
+
1162
+ if load_pcoords:
1163
+ segment.pcoord = pcoord_entries[iseg]
1164
+
1165
+ if file_version < 5:
1166
+ wtg_n_parents = row['n_parents']
1167
+ wtg_offset = row['parents_offset']
1168
+ wtg_parent_ids = all_parent_ids[wtg_offset : wtg_offset + wtg_n_parents]
1169
+ segment.parent_id = int(wtg_parent_ids[0])
1170
+ else:
1171
+ wtg_n_parents = row['wtg_n_parents']
1172
+ wtg_offset = row['wtg_offset']
1173
+ wtg_parent_ids = all_parent_ids[wtg_offset : wtg_offset + wtg_n_parents]
1174
+ segment.parent_id = int(row['parent_id'])
1175
+ segment.wtg_parent_ids = set(map(int, wtg_parent_ids))
1176
+ assert len(segment.wtg_parent_ids) == wtg_n_parents
1177
+ segments.append(segment)
1178
+ del all_parent_ids
1179
+ if load_pcoords:
1180
+ del pcoord_entries
1181
+
1182
+ # If any other data sets are requested, load them as well
1183
+ for dsinfo in self.dataset_options.values():
1184
+ if dsinfo.get('load', False):
1185
+ dsname = dsinfo['name']
1186
+ try:
1187
+ ds = iter_group[dsinfo['h5path']]
1188
+ except KeyError:
1189
+ ds = None
1190
+
1191
+ if ds is not None:
1192
+ for segment in segments:
1193
+ seg_id = segment.seg_id
1194
+ segment.data[dsname] = ds[seg_id]
1195
+
1196
+ return segments
1197
+
1198
+ def prepare_segment_restarts(self, segments, basis_states=None, initial_states=None):
1199
+ """
1200
+ Prepare the necessary folders and files, using the data stored in the parent per-iteration HDF5 file,
1201
+ that are needed to propagate the simulation. ``basis_states`` and ``initial_states`` should be provided
1202
+ if the segments are newly created.
1203
+ """
1204
+
1205
+ if not self.store_h5:
1206
+ return
1207
+
1208
+ for segment in segments:
1209
+ if segment.parent_id < 0:
1210
+ if initial_states is None or basis_states is None:
1211
+ raise ValueError('initial and basis states required for preparing the segments')
1212
+ initial_state = initial_states[segment.initial_state_id]
1213
+ # Check if it's a start state
1214
+ if initial_state.istate_type == InitialState.ISTATE_TYPE_START:
1215
+ log.debug(
1216
+ f'Skip reading start state file from per-iteration HDF5 file for initial state {segment.initial_state_id}'
1217
+ )
1218
+ continue
1219
+ else:
1220
+ basis_state = basis_states[initial_state.basis_state_id]
1221
+
1222
+ parent = Segment(n_iter=0, seg_id=basis_state.state_id)
1223
+ else:
1224
+ parent = Segment(n_iter=segment.n_iter - 1, seg_id=segment.parent_id)
1225
+
1226
+ try:
1227
+ parent_iter_ref_h5_file = makepath(self.iter_h5_path_template, {'n_iter': parent.n_iter})
1228
+
1229
+ with h5io.WESTIterationFile(parent_iter_ref_h5_file, 'r') as outf:
1230
+ outf.read_restart(parent)
1231
+
1232
+ segment.data['iterh5/restart'] = parent.data['iterh5/restart']
1233
+ except Exception as e:
1234
+ print('could not prepare restart data for segment {}/{}: {}'.format(segment.n_iter, segment.seg_id, str(e)))
1235
+
1236
+ def get_all_parent_ids(self, n_iter):
1237
+ file_version = self.we_h5file_version
1238
+ with self.lock:
1239
+ iter_group = self.get_iter_group(n_iter)
1240
+ seg_index = iter_group['seg_index']
1241
+
1242
+ if file_version < 5:
1243
+ offsets = seg_index['parents_offset']
1244
+ all_parents = iter_group['parents'][...]
1245
+ return all_parents.take(offsets)
1246
+ else:
1247
+ return seg_index['parent_id']
1248
+
1249
+ def get_parent_ids(self, n_iter, seg_ids=None):
1250
+ '''Return a sequence of the parent IDs of the given seg_ids.'''
1251
+
1252
+ file_version = self.we_h5file_version
1253
+
1254
+ with self.lock:
1255
+ iter_group = self.get_iter_group(n_iter)
1256
+ seg_index = iter_group['seg_index']
1257
+
1258
+ if seg_ids is None:
1259
+ seg_ids = range(len(seg_index))
1260
+
1261
+ if file_version < 5:
1262
+ offsets = seg_index['parents_offset']
1263
+ all_parents = iter_group['parents'][...]
1264
+ return [all_parents[offsets[seg_id]] for seg_id in seg_ids]
1265
+ else:
1266
+ all_parents = seg_index['parent_id']
1267
+ return [all_parents[seg_id] for seg_id in seg_ids]
1268
+
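+ # Ancestry-tracing sketch (illustrative; ``dm`` is an already-open instance of this
+ # class). Parent IDs below zero denote initial states, so the walk stops there:
+ #
+ #     n_iter, seg_id = 50, 3
+ #     while n_iter > 1 and seg_id >= 0:
+ #         seg_id = dm.get_parent_ids(n_iter, [seg_id])[0]
+ #         n_iter -= 1
+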
1269
+ def get_weights(self, n_iter, seg_ids):
1270
+ '''Return the weights associated with the given seg_ids'''
1271
+
1272
+ unique_ids = sorted(set(seg_ids))
1273
+ if not unique_ids:
1274
+ return []
1275
+ with self.lock:
1276
+ iter_group = self.get_iter_group(n_iter)
1277
+ index_subset = iter_group['seg_index'][unique_ids]
1278
+ weight_map = dict(zip(unique_ids, index_subset['weight']))
1279
+ return [weight_map[seg_id] for seg_id in seg_ids]
1280
+
1281
+ def get_child_ids(self, n_iter, seg_id):
1282
+ """Return the seg_ids of segments who have the given segment as a parent."""
1283
+
1284
+ with self.lock:
1285
+ if n_iter == self.current_iteration:
1286
+ return []
1287
+
1288
+ iter_group = self.get_iter_group(n_iter + 1)
1289
+ seg_index = iter_group['seg_index']
1290
+ seg_ids = np.arange(len(seg_index), dtype=seg_id_dtype)
1291
+
1292
+ if self.we_h5file_version < 5:
1293
+ offsets = seg_index['parents_offset']
1294
+ all_parent_ids = iter_group['parents'][...]
1295
+ parent_ids = np.array([all_parent_ids[offset] for offset in offsets])
1296
+ else:
1297
+ parent_ids = seg_index['parent_id']
1298
+
1299
+ return seg_ids[parent_ids == seg_id]
1300
+
1301
+ def get_children(self, segment):
1302
+ """Return all segments which have the given segment as a parent"""
1303
+
1304
+ if segment.n_iter == self.current_iteration:
1305
+ return []
1306
+
1307
+ # Examine the segment index from the following iteration to see who has this segment
1308
+ # as a parent. We don't need to worry about the number of parents each segment
1309
+ # has, since each has at least one, and indexing on the offset into the parents array
1310
+ # gives the primary parent ID
1311
+
1312
+ with self.lock:
1313
+ iter_group = self.get_iter_group(segment.n_iter + 1)
1314
+ seg_index = iter_group['seg_index'][...]
1315
+
1316
+ # This is one of the slowest pieces of code I've ever written...
1317
+ # seg_index = iter_group['seg_index'][...]
1318
+ # seg_ids = [seg_id for (seg_id,row) in enumerate(seg_index)
1319
+ # if all_parent_ids[row['parents_offset']] == segment.seg_id]
1320
+ # return self.get_segments_by_id(segment.n_iter+1, seg_ids)
1321
+ if self.we_h5file_version < 5:
1322
+ parents = iter_group['parents'][seg_index['parents_offset']]  # 'parents_offset' matches the pre-v5 seg_index field used elsewhere in this file
1323
+ else:
1324
+ parents = seg_index['parent_id']
1325
+ all_seg_ids = np.arange(seg_index.len(), dtype=np.uintp)
1326
+ seg_ids = all_seg_ids[parents == segment.seg_id]
1327
+ # the above will return a scalar if only one is found, so convert
1328
+ # to a list if necessary
1329
+ try:
1330
+ len(seg_ids)
1331
+ except TypeError:
1332
+ seg_ids = [seg_ids]
1333
+
1334
+ return self.get_segments(segment.n_iter + 1, seg_ids)
1335
+
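+ # Forward-traversal sketch (illustrative; ``dm`` is an already-open instance of this class):
+ #
+ #     root = dm.get_segments(10, seg_ids=[0])[0]
+ #     children = dm.get_children(root)      # Segment objects from iteration 11
+ #     child_ids = dm.get_child_ids(10, 0)   # just the seg_ids, as an array
+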
1336
+ # The following are dictated by the SimManager interface
1337
+ def prepare_run(self):
1338
+ self.open_backing()
1339
+
1340
+ def finalize_run(self):
1341
+ self.flush_backing()
1342
+ self.close_backing()
1343
+
1344
+ def save_new_weight_data(self, n_iter, new_weights):
1345
+ """
1346
+ Save a set of NewWeightEntry objects to HDF5. Note that this should
1347
+ be called for the iteration in which the weights appear in their
1348
+ new locations (e.g. for recycled walkers, the iteration following
1349
+ recycling).
1350
+ """
1351
+
1352
+ if not new_weights:
1353
+ return
1354
+
1355
+ system = self.system
1356
+
1357
+ index = np.empty(len(new_weights), dtype=nw_index_dtype)
1358
+ prev_init_pcoords = system.new_pcoord_array(len(new_weights))
1359
+ prev_final_pcoords = system.new_pcoord_array(len(new_weights))
1360
+ new_init_pcoords = system.new_pcoord_array(len(new_weights))
1361
+
1362
+ for ientry, nwentry in enumerate(new_weights):
1363
+ row = index[ientry]
1364
+ row['source_type'] = nwentry.source_type
1365
+ row['weight'] = nwentry.weight
1366
+ row['prev_seg_id'] = nwentry.prev_seg_id
1367
+ # the following use -1 as a sentinel for a missing value
1368
+ row['target_state_id'] = nwentry.target_state_id if nwentry.target_state_id is not None else -1
1369
+ row['initial_state_id'] = nwentry.initial_state_id if nwentry.initial_state_id is not None else -1
1370
+
1371
+ index[ientry] = row
1372
+
1373
+ if nwentry.prev_init_pcoord is not None:
1374
+ prev_init_pcoords[ientry] = nwentry.prev_init_pcoord
1375
+
1376
+ if nwentry.prev_final_pcoord is not None:
1377
+ prev_final_pcoords[ientry] = nwentry.prev_final_pcoord
1378
+
1379
+ if nwentry.new_init_pcoord is not None:
1380
+ new_init_pcoords[ientry] = nwentry.new_init_pcoord
1381
+
1382
+ with self.lock:
1383
+ iter_group = self.get_iter_group(n_iter)
1384
+ try:
1385
+ del iter_group['new_weights']
1386
+ except KeyError:
1387
+ pass
1388
+
1389
+ nwgroup = iter_group.create_group('new_weights')
1390
+ nwgroup['index'] = index
1391
+ nwgroup['prev_init_pcoord'] = prev_init_pcoords
1392
+ nwgroup['prev_final_pcoord'] = prev_final_pcoords
1393
+ nwgroup['new_init_pcoord'] = new_init_pcoords
1394
+
1395
+ def get_new_weight_data(self, n_iter):
1396
+ with self.lock:
1397
+ iter_group = self.get_iter_group(n_iter)
1398
+
1399
+ try:
1400
+ nwgroup = iter_group['new_weights']
1401
+ except KeyError:
1402
+ return []
1403
+
1404
+ try:
1405
+ index = nwgroup['index'][...]
1406
+ prev_init_pcoords = nwgroup['prev_init_pcoord'][...]
1407
+ prev_final_pcoords = nwgroup['prev_final_pcoord'][...]
1408
+ new_init_pcoords = nwgroup['new_init_pcoord'][...]
1409
+ except (KeyError, ValueError): # zero-length selections raise ValueError
1410
+ return []
1411
+
1412
+ entries = []
1413
+ for i in range(len(index)):
1414
+ irow = index[i]
1415
+
1416
+ prev_seg_id = irow['prev_seg_id']
1417
+ if prev_seg_id == -1:
1418
+ prev_seg_id = None
1419
+
1420
+ initial_state_id = irow['initial_state_id']
1421
+ if initial_state_id == -1:
1422
+ initial_state_id = None
1423
+
1424
+ target_state_id = irow['target_state_id']
1425
+ if target_state_id == -1:
1426
+ target_state_id = None
1427
+
1428
+ entry = NewWeightEntry(
1429
+ source_type=irow['source_type'],
1430
+ weight=irow['weight'],
1431
+ prev_seg_id=prev_seg_id,
1432
+ prev_init_pcoord=prev_init_pcoords[i].copy(),
1433
+ prev_final_pcoord=prev_final_pcoords[i].copy(),
1434
+ new_init_pcoord=new_init_pcoords[i].copy(),
1435
+ target_state_id=target_state_id,
1436
+ initial_state_id=initial_state_id,
1437
+ )
1438
+
1439
+ entries.append(entry)
1440
+ return entries
1441
+
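+ # Recycling-accounting sketch (illustrative; ``dm`` is an already-open instance of this
+ # class): total probability that appeared in new locations at iteration n.
+ #
+ #     entries = dm.get_new_weight_data(n)
+ #     recycled_weight = sum(entry.weight for entry in entries)
+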
1442
+ def find_bin_mapper(self, hashval):
1443
+ """
1444
+ Check to see if the given hash value is in the binning table. Returns the index in the
1445
+ bin data tables if found, or raises KeyError if not.
1446
+ """
1447
+
1448
+ try:
1449
+ hashval = hashval.hexdigest()
1450
+ except AttributeError:
1451
+ pass
1452
+
1453
+ with self.lock:
1454
+ # these will raise KeyError if the group doesn't exist, which also means
1455
+ # that bin data is not available, so no special treatment here
1456
+ try:
1457
+ binning_group = self.we_h5file['/bin_topologies']
1458
+ index = binning_group['index']
1459
+ except KeyError:
1460
+ raise KeyError('hash {} not found'.format(hashval))
1461
+
1462
+ n_entries = len(index)
1463
+ if n_entries == 0:
1464
+ raise KeyError('hash {} not found'.format(hashval))
1465
+
1466
+ chunksize = self.table_scan_chunksize
1467
+ for istart in range(0, n_entries, chunksize):
1468
+ chunk = index[istart : min(istart + chunksize, n_entries)]
1469
+ for i in range(len(chunk)):
1470
+ if chunk[i]['hash'] == bytes(hashval, 'utf-8'):
1471
+ return istart + i
1472
+
1473
+ raise KeyError('hash {} not found'.format(hashval))
1474
+
1475
+ def get_bin_mapper(self, hashval):
1476
+ """Look up the given hash value in the binning table, unpickling and returning the corresponding
1477
+ bin mapper if available, or raising KeyError if not."""
1478
+
1479
+ # Convert to a hex digest if we need to
1480
+ try:
1481
+ hashval = hashval.hexdigest()
1482
+ except AttributeError:
1483
+ pass
1484
+
1485
+ with self.lock:
1486
+ # these will raise KeyError if the group doesn't exist, which also means
1487
+ # that bin data is not available, so no special treatment here
1488
+ try:
1489
+ binning_group = self.we_h5file['/bin_topologies']
1490
+ index = binning_group['index']
1491
+ pkl = binning_group['pickles']
1492
+ except KeyError:
1493
+ raise KeyError('hash {} not found. Could not retrieve binning group'.format(hashval))
1494
+
1495
+ n_entries = len(index)
1496
+ if n_entries == 0:
1497
+ raise KeyError('hash {} not found. No entries in index'.format(hashval))
1498
+
1499
+ chunksize = self.table_scan_chunksize
1500
+
1501
+ for istart in range(0, n_entries, chunksize):
1502
+ chunk = index[istart : min(istart + chunksize, n_entries)]
1503
+ for i in range(len(chunk)):
1504
+ if chunk[i]['hash'] == bytes(hashval, 'utf-8'):
1505
+ pkldat = bytes(pkl[istart + i, 0 : chunk[i]['pickle_len']].data)
1506
+ mapper = pickle.loads(pkldat)
1507
+ log.debug('loaded {!r} from {!r}'.format(mapper, binning_group))
1508
+ log.debug('hash value {!r}'.format(hashval))
1509
+ return mapper
1510
+
1511
+ raise KeyError('hash {} not found'.format(hashval))
1512
+
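+ # Bin-mapper lookup sketch (illustrative; ``dm`` is an already-open instance of this
+ # class). The hash saved by save_iter_binning() can be passed straight back in; depending
+ # on the h5py version, the stored attribute may come back as bytes and need decoding first.
+ #
+ #     binhash = dm.get_iter_group(n_iter).attrs['binhash']
+ #     mapper = dm.get_bin_mapper(binhash)
+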
1513
+ def save_bin_mapper(self, hashval, pickle_data):
1514
+ """Store the given mapper in the table of saved mappers. If the mapper cannot be stored,
1515
+ PickleError will be raised. Returns the index in the bin data tables where the mapper is stored."""
1516
+
1517
+ try:
1518
+ hashval = hashval.hexdigest()
1519
+ except AttributeError:
1520
+ pass
1521
+ pickle_data = bytes(pickle_data)
1522
+
1523
+ # First, scan to see if the mapper already is in the HDF5 file
1524
+ try:
1525
+ return self.find_bin_mapper(hashval)
1526
+ except KeyError:
1527
+ pass
1528
+
1529
+ # At this point, we have a valid pickle and know it's not stored
1530
+ with self.lock:
1531
+ binning_group = self.we_h5file.require_group('/bin_topologies')
1532
+
1533
+ try:
1534
+ index = binning_group['index']
1535
+ pickle_ds = binning_group['pickles']
1536
+ except KeyError:
1537
+ index = binning_group.create_dataset('index', shape=(1,), maxshape=(None,), dtype=binning_index_dtype)
1538
+ pickle_ds = binning_group.create_dataset(
1539
+ 'pickles',
1540
+ dtype=np.uint8,
1541
+ shape=(1, len(pickle_data)),
1542
+ maxshape=(None, None),
1543
+ chunks=(1, 4096),
1544
+ compression='gzip',
1545
+ compression_opts=9,
1546
+ )
1547
+ n_entries = 1
1548
+ else:
1549
+ n_entries = len(index) + 1
1550
+ index.resize((n_entries,))
1551
+ new_hsize = max(pickle_ds.shape[1], len(pickle_data))
1552
+ pickle_ds.resize((n_entries, new_hsize))
1553
+
1554
+ index_row = index[n_entries - 1]
1555
+ index_row['hash'] = hashval
1556
+ index_row['pickle_len'] = len(pickle_data)
1557
+ index[n_entries - 1] = index_row
1558
+ pickle_ds[n_entries - 1, : len(pickle_data)] = memoryview(pickle_data)
1559
+ return n_entries - 1
1560
+
1561
+ def save_iter_binning(self, n_iter, hashval, pickled_mapper, target_counts):
1562
+ """Save information about the binning used to generate segments for iteration n_iter."""
1563
+
1564
+ with self.lock:
1565
+ iter_group = self.get_iter_group(n_iter)
1566
+
1567
+ try:
1568
+ del iter_group['bin_target_counts']
1569
+ except KeyError:
1570
+ pass
1571
+
1572
+ iter_group['bin_target_counts'] = target_counts
1573
+
1574
+ if hashval and pickled_mapper:
1575
+ self.save_bin_mapper(hashval, pickled_mapper)
1576
+ iter_group.attrs['binhash'] = hashval
1577
+ else:
1578
+ iter_group.attrs['binhash'] = ''
1579
+
1580
+
1581
+ def normalize_dataset_options(dsopts, path_prefix='', n_iter=0):
1582
+ dsopts = dict(dsopts)
1583
+
1584
+ ds_name = dsopts['name']
1585
+ if path_prefix:
1586
+ default_h5path = '{}/{}'.format(path_prefix, ds_name)
1587
+ else:
1588
+ default_h5path = ds_name
1589
+
1590
+ dsopts.setdefault('h5path', default_h5path)
1591
+ dtype = dsopts.get('dtype')
1592
+ if dtype:
1593
+ if isinstance(dtype, str):
1594
+ try:
1595
+ dsopts['dtype'] = np.dtype(getattr(np, dtype))
1596
+ except AttributeError:
1597
+ dsopts['dtype'] = np.dtype(getattr(builtins, dtype))
1598
+ else:
1599
+ dsopts['dtype'] = np.dtype(dtype)
1600
+
1601
+ dsopts['store'] = bool(dsopts['store']) if 'store' in dsopts else True
1602
+ dsopts['load'] = bool(dsopts['load']) if 'load' in dsopts else False
1603
+
1604
+ return dsopts
1605
+
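+ # Example of what normalize_dataset_options() yields (a sketch; the 'coords' entry is an
+ # illustrative west.cfg dataset spec, not one defined by this module):
+ #
+ #     normalize_dataset_options({'name': 'coords', 'dtype': 'float32'}, path_prefix='auxdata')
+ #     # -> {'name': 'coords', 'dtype': dtype('float32'), 'h5path': 'auxdata/coords',
+ #     #     'store': True, 'load': False}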
1606
+
1607
+ def create_dataset_from_dsopts(group, dsopts, shape=None, dtype=None, data=None, autocompress_threshold=None, n_iter=None):
1608
+ # log.debug('create_dataset_from_dsopts(group={!r}, dsopts={!r}, shape={!r}, dtype={!r}, data={!r}, autocompress_threshold={!r})'
1609
+ # .format(group,dsopts,shape,dtype,data,autocompress_threshold))
1610
+ if not dsopts.get('store', True):
1611
+ return None
1612
+
1613
+ if 'file' in list(dsopts.keys()):
1614
+ import h5py
1615
+
1616
+ # dsopts['file'] = str(dsopts['file']).format(n_iter=n_iter)
1617
+ h5_auxfile = h5io.WESTPAH5File(dsopts['file'].format(n_iter=n_iter))
1618
+ h5group = group
1619
+ if ("iter_" + str(n_iter).zfill(8)) not in h5_auxfile:
1620
+ h5_auxfile.create_group("iter_" + str(n_iter).zfill(8))
1621
+ group = h5_auxfile[('/' + "iter_" + str(n_iter).zfill(8))]
1622
+
1623
+ h5path = dsopts['h5path']
1624
+ containing_group_name = posixpath.dirname(h5path)
1625
+ h5_dsname = posixpath.basename(h5path)
1626
+
1627
+ # ensure arguments are sane
1628
+ if not shape and data is None:
1629
+ raise ValueError('either shape or data must be provided')
1630
+ elif data is None and (shape and dtype is None):
1631
+ raise ValueError('both shape and dtype must be provided when data is not provided')
1632
+ elif shape and data is not None and not data.shape == shape:
1633
+ raise ValueError('explicit shape {!r} does not match data shape {!r}'.format(shape, data.shape))
1634
+
1635
+ if data is not None:
1636
+ shape = data.shape
1637
+ if dtype is None:
1638
+ dtype = data.dtype
1639
+ # end argument sanity checks
1640
+
1641
+ # figure out where to store this data
1642
+ if containing_group_name:
1643
+ containing_group = group.require_group(containing_group_name)
1644
+ else:
1645
+ containing_group = group
1646
+
1647
+ # has user requested an explicit data type?
1648
+ # the extra np.dtype is an idempotent operation on true dtype
1649
+ # objects, but ensures that things like np.float32, which are
1650
+ # actually NOT dtype objects, become dtype objects
1651
+ h5_dtype = np.dtype(dsopts.get('dtype', dtype))
1652
+
1653
+ compression = None
1654
+ scaleoffset = None
1655
+ shuffle = False
1656
+
1657
+ # compress if 1) explicitly requested, or 2) dataset size exceeds threshold and
1658
+ # compression not explicitly prohibited
1659
+ compression_directive = dsopts.get('compression')
1660
+ if compression_directive is None:
1661
+ # No directive
1662
+ nbytes = np.multiply.reduce(shape) * h5_dtype.itemsize
1663
+ if autocompress_threshold and nbytes > autocompress_threshold:
1664
+ compression = 9
1665
+ elif compression_directive == 0: # includes False
1666
+ # Compression prohibited
1667
+ compression = None
1668
+ else: # compression explicitly requested
1669
+ compression = compression_directive
1670
+
1671
+ # Is scale/offset requested?
1672
+ scaleoffset = dsopts.get('scaleoffset', None)
1673
+ if scaleoffset is not None:
1674
+ scaleoffset = int(scaleoffset)
1675
+
1676
+ # We always shuffle if we compress (losslessly)
1677
+ if compression:
1678
+ shuffle = True
1679
+ else:
1680
+ shuffle = False
1681
+
1682
+ need_chunks = any([compression, scaleoffset is not None, shuffle])
1683
+
1684
+ # We use user-provided chunks if available
1685
+ chunks_directive = dsopts.get('chunks')
1686
+ if chunks_directive is None:
1687
+ chunks = None
1688
+ elif chunks_directive is True:
1689
+ chunks = calc_chunksize(shape, h5_dtype)
1690
+ elif chunks_directive is False:
1691
+ chunks = None
1692
+ else:
1693
+ chunks = tuple(chunks_directive[i] if chunks_directive[i] <= shape[i] else shape[i] for i in range(len(shape)))
1694
+
1695
+ if not chunks and need_chunks:
1696
+ chunks = calc_chunksize(shape, h5_dtype)
1697
+
1698
+ opts = {'shape': shape, 'dtype': h5_dtype, 'compression': compression, 'shuffle': shuffle, 'chunks': chunks}
1699
+
1700
+ try:
1701
+ import h5py._hl.filters
1702
+
1703
+ h5py._hl.filters._COMP_FILTERS['scaleoffset']
1704
+ except (ImportError, KeyError, AttributeError):
1705
+ # filter not available, or an unexpected version of h5py
1706
+ # use lossless compression instead
1707
+ opts['compression'] = True
1708
+ else:
1709
+ opts['scaleoffset'] = scaleoffset
1710
+
1711
+ if log.isEnabledFor(logging.DEBUG):
1712
+ log.debug('requiring aux dataset {!r}, shape={!r}, opts={!r}'.format(h5_dsname, shape, opts))
1713
+
1714
+ dset = containing_group.require_dataset(h5_dsname, **opts)
1715
+
1716
+ if data is not None:
1717
+ dset[...] = data
1718
+
1719
+ if 'file' in list(dsopts.keys()):
1720
+ import h5py
1721
+
1722
+ if dsopts['h5path'] not in h5group:
1723
+ h5group[dsopts['h5path']] = h5py.ExternalLink(
1724
+ dsopts['file'].format(n_iter=n_iter), ("/" + "iter_" + str(n_iter).zfill(8) + "/" + dsopts['h5path'])
1725
+ )
1726
+
1727
+ return dset
1728
+
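+ # Directive summary for create_dataset_from_dsopts(), with an illustrative dsopts
+ # (the dataset name is made up):
+ #
+ #     {'name': 'dihedrals', 'h5path': 'auxdata/dihedrals', 'compression': 9,
+ #      'scaleoffset': 4, 'chunks': True}
+ #
+ # 'compression' omitted -> gzip only past autocompress_threshold; 0/False -> never;
+ # an integer -> that gzip level. 'scaleoffset' enables the (lossy) scale/offset filter
+ # when h5py provides it, and 'chunks': True defers the chunk layout to calc_chunksize().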
1729
+
1730
+ def require_dataset_from_dsopts(group, dsopts, shape=None, dtype=None, data=None, autocompress_threshold=None, n_iter=None):
1731
+ if not dsopts.get('store', True):
1732
+ return None
1733
+ try:
1734
+ return group[dsopts['h5path']]
1735
+ except KeyError:
1736
+ return create_dataset_from_dsopts(
1737
+ group, dsopts, shape=shape, dtype=dtype, data=data, autocompress_threshold=autocompress_threshold, n_iter=n_iter
1738
+ )
1739
+
1740
+
1741
+ def calc_chunksize(shape, dtype, max_chunksize=262144):
1742
+ """Calculate a chunk size for HDF5 data, anticipating that access will slice
1743
+ along lower dimensions sooner than higher dimensions."""
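+ # Worked example (assuming the default max_chunksize of 262144 bytes): a float64
+ # dataset of shape (1000, 50, 3) is 1,200,000 bytes per chunk to start, so the
+ # leading dimension is halved 1000 -> 500 -> 250 -> 125, giving (125, 50, 3) at
+ # 150,000 bytes; trailing dimensions are only reduced if halving the earlier ones
+ # down to 1 still isn't enough.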
1744
+
1745
+ chunk_shape = list(shape)
1746
+ for idim in range(len(shape)):
1747
+ chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
1748
+ while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
1749
+ chunk_shape[idim] >>= 1 # divide by 2
1750
+ chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
1751
+
1752
+ if chunk_nbytes <= max_chunksize:
1753
+ break
1754
+
1755
+ chunk_shape = tuple(chunk_shape)
1756
+ log.debug(
1757
+ 'selected chunk shape {} for data set of type {} shaped {} (chunk size = {} bytes)'.format(
1758
+ chunk_shape, dtype, shape, chunk_nbytes
1759
+ )
1760
+ )
1761
+ return chunk_shape