westpa 2022.10__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of westpa might be problematic. Click here for more details.
- westpa/__init__.py +14 -0
- westpa/_version.py +21 -0
- westpa/analysis/__init__.py +5 -0
- westpa/analysis/core.py +746 -0
- westpa/analysis/statistics.py +27 -0
- westpa/analysis/trajectories.py +360 -0
- westpa/cli/__init__.py +0 -0
- westpa/cli/core/__init__.py +0 -0
- westpa/cli/core/w_fork.py +152 -0
- westpa/cli/core/w_init.py +230 -0
- westpa/cli/core/w_run.py +77 -0
- westpa/cli/core/w_states.py +212 -0
- westpa/cli/core/w_succ.py +99 -0
- westpa/cli/core/w_truncate.py +59 -0
- westpa/cli/tools/__init__.py +0 -0
- westpa/cli/tools/ploterr.py +506 -0
- westpa/cli/tools/plothist.py +706 -0
- westpa/cli/tools/w_assign.py +596 -0
- westpa/cli/tools/w_bins.py +166 -0
- westpa/cli/tools/w_crawl.py +119 -0
- westpa/cli/tools/w_direct.py +547 -0
- westpa/cli/tools/w_dumpsegs.py +94 -0
- westpa/cli/tools/w_eddist.py +506 -0
- westpa/cli/tools/w_fluxanl.py +378 -0
- westpa/cli/tools/w_ipa.py +833 -0
- westpa/cli/tools/w_kinavg.py +127 -0
- westpa/cli/tools/w_kinetics.py +96 -0
- westpa/cli/tools/w_multi_west.py +414 -0
- westpa/cli/tools/w_ntop.py +213 -0
- westpa/cli/tools/w_pdist.py +515 -0
- westpa/cli/tools/w_postanalysis_matrix.py +82 -0
- westpa/cli/tools/w_postanalysis_reweight.py +53 -0
- westpa/cli/tools/w_red.py +486 -0
- westpa/cli/tools/w_reweight.py +780 -0
- westpa/cli/tools/w_select.py +226 -0
- westpa/cli/tools/w_stateprobs.py +111 -0
- westpa/cli/tools/w_trace.py +599 -0
- westpa/core/__init__.py +0 -0
- westpa/core/_rc.py +673 -0
- westpa/core/binning/__init__.py +55 -0
- westpa/core/binning/_assign.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/core/binning/assign.py +449 -0
- westpa/core/binning/binless.py +96 -0
- westpa/core/binning/binless_driver.py +54 -0
- westpa/core/binning/binless_manager.py +190 -0
- westpa/core/binning/bins.py +47 -0
- westpa/core/binning/mab.py +427 -0
- westpa/core/binning/mab_driver.py +54 -0
- westpa/core/binning/mab_manager.py +198 -0
- westpa/core/data_manager.py +1694 -0
- westpa/core/extloader.py +74 -0
- westpa/core/h5io.py +995 -0
- westpa/core/kinetics/__init__.py +24 -0
- westpa/core/kinetics/_kinetics.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/core/kinetics/events.py +147 -0
- westpa/core/kinetics/matrates.py +156 -0
- westpa/core/kinetics/rate_averaging.py +266 -0
- westpa/core/progress.py +218 -0
- westpa/core/propagators/__init__.py +54 -0
- westpa/core/propagators/executable.py +715 -0
- westpa/core/reweight/__init__.py +14 -0
- westpa/core/reweight/_reweight.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/core/reweight/matrix.py +126 -0
- westpa/core/segment.py +119 -0
- westpa/core/sim_manager.py +830 -0
- westpa/core/states.py +359 -0
- westpa/core/systems.py +93 -0
- westpa/core/textio.py +74 -0
- westpa/core/trajectory.py +330 -0
- westpa/core/we_driver.py +908 -0
- westpa/core/wm_ops.py +43 -0
- westpa/core/yamlcfg.py +391 -0
- westpa/fasthist/__init__.py +34 -0
- westpa/fasthist/__main__.py +110 -0
- westpa/fasthist/_fasthist.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/mclib/__init__.py +264 -0
- westpa/mclib/__main__.py +28 -0
- westpa/mclib/_mclib.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/oldtools/__init__.py +4 -0
- westpa/oldtools/aframe/__init__.py +35 -0
- westpa/oldtools/aframe/atool.py +75 -0
- westpa/oldtools/aframe/base_mixin.py +26 -0
- westpa/oldtools/aframe/binning.py +178 -0
- westpa/oldtools/aframe/data_reader.py +560 -0
- westpa/oldtools/aframe/iter_range.py +200 -0
- westpa/oldtools/aframe/kinetics.py +117 -0
- westpa/oldtools/aframe/mcbs.py +146 -0
- westpa/oldtools/aframe/output.py +39 -0
- westpa/oldtools/aframe/plotting.py +90 -0
- westpa/oldtools/aframe/trajwalker.py +126 -0
- westpa/oldtools/aframe/transitions.py +469 -0
- westpa/oldtools/cmds/__init__.py +0 -0
- westpa/oldtools/cmds/w_ttimes.py +358 -0
- westpa/oldtools/files.py +34 -0
- westpa/oldtools/miscfn.py +23 -0
- westpa/oldtools/stats/__init__.py +4 -0
- westpa/oldtools/stats/accumulator.py +35 -0
- westpa/oldtools/stats/edfs.py +129 -0
- westpa/oldtools/stats/mcbs.py +89 -0
- westpa/tools/__init__.py +33 -0
- westpa/tools/binning.py +472 -0
- westpa/tools/core.py +340 -0
- westpa/tools/data_reader.py +159 -0
- westpa/tools/dtypes.py +31 -0
- westpa/tools/iter_range.py +198 -0
- westpa/tools/kinetics_tool.py +340 -0
- westpa/tools/plot.py +283 -0
- westpa/tools/progress.py +17 -0
- westpa/tools/selected_segs.py +154 -0
- westpa/tools/wipi.py +751 -0
- westpa/trajtree/__init__.py +4 -0
- westpa/trajtree/_trajtree.cpython-312-x86_64-linux-gnu.so +0 -0
- westpa/trajtree/trajtree.py +117 -0
- westpa/westext/__init__.py +0 -0
- westpa/westext/adaptvoronoi/__init__.py +3 -0
- westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
- westpa/westext/hamsm_restarting/__init__.py +3 -0
- westpa/westext/hamsm_restarting/example_overrides.py +35 -0
- westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
- westpa/westext/stringmethod/__init__.py +11 -0
- westpa/westext/stringmethod/fourier_fitting.py +69 -0
- westpa/westext/stringmethod/string_driver.py +253 -0
- westpa/westext/stringmethod/string_method.py +306 -0
- westpa/westext/weed/BinCluster.py +180 -0
- westpa/westext/weed/ProbAdjustEquil.py +100 -0
- westpa/westext/weed/UncertMath.py +247 -0
- westpa/westext/weed/__init__.py +10 -0
- westpa/westext/weed/weed_driver.py +182 -0
- westpa/westext/wess/ProbAdjust.py +101 -0
- westpa/westext/wess/__init__.py +6 -0
- westpa/westext/wess/wess_driver.py +207 -0
- westpa/work_managers/__init__.py +57 -0
- westpa/work_managers/core.py +396 -0
- westpa/work_managers/environment.py +134 -0
- westpa/work_managers/mpi.py +318 -0
- westpa/work_managers/processes.py +187 -0
- westpa/work_managers/serial.py +28 -0
- westpa/work_managers/threads.py +79 -0
- westpa/work_managers/zeromq/__init__.py +20 -0
- westpa/work_managers/zeromq/core.py +641 -0
- westpa/work_managers/zeromq/node.py +131 -0
- westpa/work_managers/zeromq/work_manager.py +526 -0
- westpa/work_managers/zeromq/worker.py +320 -0
- westpa-2022.10.dist-info/AUTHORS +22 -0
- westpa-2022.10.dist-info/LICENSE +21 -0
- westpa-2022.10.dist-info/METADATA +183 -0
- westpa-2022.10.dist-info/RECORD +150 -0
- westpa-2022.10.dist-info/WHEEL +6 -0
- westpa-2022.10.dist-info/entry_points.txt +29 -0
- westpa-2022.10.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1694 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HDF5 data manager for WEST.
|
|
3
|
+
|
|
4
|
+
Original HDF5 implementation: Joseph W. Kaus
|
|
5
|
+
Current implementation: Matthew C. Zwier
|
|
6
|
+
|
|
7
|
+
WEST exclusively uses the cross-platform, self-describing file format HDF5
|
|
8
|
+
for data storage. This ensures that data is stored efficiently and portably
|
|
9
|
+
in a manner that is relatively straightforward for other analysis tools
|
|
10
|
+
(perhaps written in C/C++/Fortran) to access.
|
|
11
|
+
|
|
12
|
+
The data is laid out in HDF5 as follows:
|
|
13
|
+
- summary -- overall summary data for the simulation
|
|
14
|
+
- /iterations/ -- data for individual iterations, one group per iteration under /iterations
|
|
15
|
+
- iter_00000001/ -- data for iteration 1
|
|
16
|
+
- seg_index -- overall information about segments in the iteration, including weight
|
|
17
|
+
- pcoord -- progress coordinate data organized as [seg_id][time][dimension]
|
|
18
|
+
- wtg_parents -- data used to reconstruct the split/merge history of trajectories
|
|
19
|
+
- recycling -- flux and event count for recycled particles, on a per-target-state basis
|
|
20
|
+
- auxdata/ -- auxiliary datasets (data stored on the 'data' field of Segment objects)
|
|
21
|
+
|
|
22
|
+
The file root object has an integer attribute 'west_file_format_version' which can be used to
|
|
23
|
+
determine how to access data even as the file format (i.e. organization of data within HDF5 file)
|
|
24
|
+
evolves.
|
|
25
|
+
|
|
26
|
+
Version history:
|
|
27
|
+
Version 9
|
|
28
|
+
- Basis states are now saved as iter_segid instead of just segid as a pointer label.
|
|
29
|
+
- Initial states are also saved in the iteration 0 file, with a negative sign.
|
|
30
|
+
Version 8
|
|
31
|
+
- Added external links to trajectory files in iterations/iter_* groups, if the HDF5
|
|
32
|
+
framework was used.
|
|
33
|
+
- Added an iter group for the iteration 0 to store conformations of basis states.
|
|
34
|
+
Version 7
|
|
35
|
+
- Removed bin_assignments, bin_populations, and bin_rates from iteration group.
|
|
36
|
+
- Added new_segments subgroup to iteration group
|
|
37
|
+
Version 6
|
|
38
|
+
- ???
|
|
39
|
+
Version 5
|
|
40
|
+
- moved iter_* groups into a top-level iterations/ group,
|
|
41
|
+
- added in-HDF5 storage for basis states, target states, and generated states
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
import logging
|
|
45
|
+
import pickle
|
|
46
|
+
import posixpath
|
|
47
|
+
import sys
|
|
48
|
+
import threading
|
|
49
|
+
import time
|
|
50
|
+
import builtins
|
|
51
|
+
from operator import attrgetter
|
|
52
|
+
from os.path import relpath, dirname
|
|
53
|
+
|
|
54
|
+
import h5py
|
|
55
|
+
from h5py import h5s
|
|
56
|
+
import numpy as np
|
|
57
|
+
|
|
58
|
+
from . import h5io
|
|
59
|
+
from .segment import Segment
|
|
60
|
+
from .states import BasisState, TargetState, InitialState
|
|
61
|
+
from .we_driver import NewWeightEntry
|
|
62
|
+
from .propagators.executable import ExecutablePropagator
|
|
63
|
+
|
|
64
|
+
import westpa
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
log = logging.getLogger(__name__)
|
|
68
|
+
|
|
69
|
+
file_format_version = 9
|
|
70
|
+
|
|
71
|
+
makepath = ExecutablePropagator.makepath
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class flushing_lock:
|
|
75
|
+
def __init__(self, lock, fileobj):
|
|
76
|
+
self.lock = lock
|
|
77
|
+
self.fileobj = fileobj
|
|
78
|
+
|
|
79
|
+
def __enter__(self):
|
|
80
|
+
self.lock.acquire()
|
|
81
|
+
|
|
82
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
83
|
+
self.fileobj.flush()
|
|
84
|
+
self.lock.release()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class expiring_flushing_lock:
|
|
88
|
+
def __init__(self, lock, flush_method, nextsync):
|
|
89
|
+
self.lock = lock
|
|
90
|
+
self.flush_method = flush_method
|
|
91
|
+
self.nextsync = nextsync
|
|
92
|
+
|
|
93
|
+
def __enter__(self):
|
|
94
|
+
self.lock.acquire()
|
|
95
|
+
|
|
96
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
97
|
+
if time.time() > self.nextsync:
|
|
98
|
+
self.flush_method()
|
|
99
|
+
self.lock.release()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Data types for use in the HDF5 file
|
|
103
|
+
seg_id_dtype = np.int64 # Up to 9 quintillion segments per iteration; signed so that initial states can be stored negative
|
|
104
|
+
n_iter_dtype = np.uint32 # Up to 4 billion iterations
|
|
105
|
+
weight_dtype = np.float64 # about 15 digits of precision in weights
|
|
106
|
+
utime_dtype = np.float64 # ("u" for Unix time) Up to ~10^300 cpu-seconds
|
|
107
|
+
vstr_dtype = h5py.special_dtype(vlen=str)
|
|
108
|
+
h5ref_dtype = h5py.special_dtype(ref=h5py.Reference)
|
|
109
|
+
binhash_dtype = np.dtype('|S64')
|
|
110
|
+
|
|
111
|
+
# seg_status_dtype = h5py.special_dtype(enum=(np.uint8, Segment.statuses))
|
|
112
|
+
# seg_initpoint_dtype = h5py.special_dtype(enum=(np.uint8, Segment.initpoint_types))
|
|
113
|
+
# seg_endpoint_dtype = h5py.special_dtype(enum=(np.uint8, Segment.endpoint_types))
|
|
114
|
+
# istate_type_dtype = h5py.special_dtype(enum=(np.uint8, InitialState.istate_types))
|
|
115
|
+
# istate_status_dtype = h5py.special_dtype(enum=(np.uint8, InitialState.istate_statuses))
|
|
116
|
+
|
|
117
|
+
seg_status_dtype = np.uint8
|
|
118
|
+
seg_initpoint_dtype = np.uint8
|
|
119
|
+
seg_endpoint_dtype = np.uint8
|
|
120
|
+
istate_type_dtype = np.uint8
|
|
121
|
+
istate_status_dtype = np.uint8
|
|
122
|
+
|
|
123
|
+
summary_table_dtype = np.dtype(
|
|
124
|
+
[
|
|
125
|
+
('n_particles', seg_id_dtype), # Number of live trajectories in this iteration
|
|
126
|
+
('norm', weight_dtype), # Norm of probability, to watch for errors or drift
|
|
127
|
+
('min_bin_prob', weight_dtype), # Per-bin minimum probability
|
|
128
|
+
('max_bin_prob', weight_dtype), # Per-bin maximum probability
|
|
129
|
+
('min_seg_prob', weight_dtype), # Per-segment minimum probability
|
|
130
|
+
('max_seg_prob', weight_dtype), # Per-segment maximum probability
|
|
131
|
+
('cputime', utime_dtype), # Total CPU time for this iteration
|
|
132
|
+
('walltime', utime_dtype), # Total wallclock time for this iteration
|
|
133
|
+
('binhash', binhash_dtype),
|
|
134
|
+
]
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# The HDF5 file tracks two distinct, but related, histories:
|
|
139
|
+
# (1) the evolution of the trajectory, which requires only an identifier
|
|
140
|
+
# of where a segment's initial state comes from (the "history graph");
|
|
141
|
+
# this is stored as the parent_id field of the seg index
|
|
142
|
+
# (2) the flow of probability due to splits, merges, and recycling events,
|
|
143
|
+
# which can be thought of as an adjacency list (the "weight graph")
|
|
144
|
+
# segment ID is implied by the row in the index table, and so is not stored
|
|
145
|
+
# initpoint_type remains implicitly stored as negative IDs (if parent_id < 0, then init_state_id = -(parent_id+1)
|
|
146
|
+
seg_index_dtype = np.dtype(
|
|
147
|
+
[
|
|
148
|
+
('weight', weight_dtype), # Statistical weight of this segment
|
|
149
|
+
('parent_id', seg_id_dtype), # ID of parent (for trajectory history)
|
|
150
|
+
('wtg_n_parents', np.uint), # number of parents this segment has in the weight transfer graph
|
|
151
|
+
('wtg_offset', np.uint), # offset into the weight transfer graph dataset
|
|
152
|
+
('cputime', utime_dtype), # CPU time used in propagating this segment
|
|
153
|
+
('walltime', utime_dtype), # Wallclock time used in propagating this segment
|
|
154
|
+
('endpoint_type', seg_endpoint_dtype), # Endpoint type (will continue, merged, or recycled)
|
|
155
|
+
('status', seg_status_dtype), # Status of propagation of this segment
|
|
156
|
+
]
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
# Index to basis/initial states
|
|
160
|
+
ibstate_index_dtype = np.dtype([('iter_valid', np.uint), ('n_bstates', np.uint), ('group_ref', h5ref_dtype)])
|
|
161
|
+
|
|
162
|
+
# Basis state index type
|
|
163
|
+
bstate_dtype = np.dtype(
|
|
164
|
+
[
|
|
165
|
+
('label', vstr_dtype), # An optional descriptive label
|
|
166
|
+
('probability', weight_dtype), # Probability that this state will be selected
|
|
167
|
+
('auxref', vstr_dtype), # An optional auxiliar data reference
|
|
168
|
+
]
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
# Even when initial state generation is off and basis states are passed through directly, an initial state entry
|
|
172
|
+
# is created, as that allows precise tracing of the history of a given state in the most complex case of
|
|
173
|
+
# a new initial state for every new trajectory.
|
|
174
|
+
istate_dtype = np.dtype(
|
|
175
|
+
[
|
|
176
|
+
('iter_created', np.uint), # Iteration during which this state was generated (0 for at w_init)
|
|
177
|
+
('iter_used', np.uint), # When this state was used to start a new trajectory
|
|
178
|
+
('basis_state_id', seg_id_dtype), # Which basis state this state was generated from
|
|
179
|
+
('istate_type', istate_type_dtype), # What type this initial state is (generated or basis)
|
|
180
|
+
('istate_status', istate_status_dtype), # Whether this initial state is ready to go
|
|
181
|
+
('basis_auxref', vstr_dtype),
|
|
182
|
+
]
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
tstate_index_dtype = np.dtype(
|
|
186
|
+
[('iter_valid', np.uint), ('n_states', np.uint), ('group_ref', h5ref_dtype)] # Iteration when this state list is valid
|
|
187
|
+
) # Reference to a group containing further data; this will be the
|
|
188
|
+
# null reference if there is no target state for that timeframe.
|
|
189
|
+
tstate_dtype = np.dtype([('label', vstr_dtype)]) # An optional descriptive label for this state
|
|
190
|
+
|
|
191
|
+
# Support for west.we_driver.NewWeightEntry
|
|
192
|
+
nw_source_dtype = np.uint8
|
|
193
|
+
nw_index_dtype = np.dtype(
|
|
194
|
+
[
|
|
195
|
+
('source_type', nw_source_dtype),
|
|
196
|
+
('weight', weight_dtype),
|
|
197
|
+
('prev_seg_id', seg_id_dtype),
|
|
198
|
+
('target_state_id', seg_id_dtype),
|
|
199
|
+
('initial_state_id', seg_id_dtype),
|
|
200
|
+
]
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
# Storage of bin identities
|
|
204
|
+
binning_index_dtype = np.dtype([('hash', binhash_dtype), ('pickle_len', np.uint32)])
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class WESTDataManager:
|
|
208
|
+
"""Data manager for assisiting the reading and writing of WEST data from/to HDF5 files."""
|
|
209
|
+
|
|
210
|
+
# defaults for various options
|
|
211
|
+
default_iter_prec = 8
|
|
212
|
+
default_we_h5filename = 'west.h5'
|
|
213
|
+
default_we_h5file_driver = None
|
|
214
|
+
default_flush_period = 60
|
|
215
|
+
|
|
216
|
+
# Compress any auxiliary dataset whose total size (across all segments) is more than 1MB
|
|
217
|
+
default_aux_compression_threshold = 1048576
|
|
218
|
+
|
|
219
|
+
# Bin data horizontal (second dimension) chunk size
|
|
220
|
+
binning_hchunksize = 4096
|
|
221
|
+
|
|
222
|
+
# Number of rows to retrieve during a table scan
|
|
223
|
+
table_scan_chunksize = 1024
|
|
224
|
+
|
|
225
|
+
def flushing_lock(self):
|
|
226
|
+
return flushing_lock(self.lock, self.we_h5file)
|
|
227
|
+
|
|
228
|
+
def expiring_flushing_lock(self):
|
|
229
|
+
next_flush = self.last_flush + self.flush_period
|
|
230
|
+
return expiring_flushing_lock(self.lock, self.flush_backing, next_flush)
|
|
231
|
+
|
|
232
|
+
def process_config(self):
|
|
233
|
+
config = self.rc.config
|
|
234
|
+
|
|
235
|
+
for entry, type_ in [('iter_prec', int)]:
|
|
236
|
+
config.require_type_if_present(['west', 'data', entry], type_)
|
|
237
|
+
|
|
238
|
+
self.we_h5filename = config.get_path(['west', 'data', 'west_data_file'], default=self.default_we_h5filename)
|
|
239
|
+
self.we_h5file_driver = config.get_choice(
|
|
240
|
+
['west', 'data', 'west_data_file_driver'],
|
|
241
|
+
[None, 'sec2', 'family'],
|
|
242
|
+
default=self.default_we_h5file_driver,
|
|
243
|
+
value_transform=(lambda x: x.lower() if x else None),
|
|
244
|
+
)
|
|
245
|
+
self.iter_prec = config.get(['west', 'data', 'iter_prec'], self.default_iter_prec)
|
|
246
|
+
self.aux_compression_threshold = config.get(
|
|
247
|
+
['west', 'data', 'aux_compression_threshold'], self.default_aux_compression_threshold
|
|
248
|
+
)
|
|
249
|
+
self.flush_period = config.get(['west', 'data', 'flush_period'], self.default_flush_period)
|
|
250
|
+
self.iter_ref_h5_template = config.get(['west', 'data', 'data_refs', 'iteration'], None)
|
|
251
|
+
self.store_h5 = self.iter_ref_h5_template is not None
|
|
252
|
+
|
|
253
|
+
# Process dataset options
|
|
254
|
+
dsopts_list = config.get(['west', 'data', 'datasets']) or []
|
|
255
|
+
for dsopts in dsopts_list:
|
|
256
|
+
dsopts = normalize_dataset_options(dsopts, path_prefix='auxdata' if dsopts['name'] != 'pcoord' else '')
|
|
257
|
+
try:
|
|
258
|
+
self.dataset_options[dsopts['name']].update(dsopts)
|
|
259
|
+
except KeyError:
|
|
260
|
+
self.dataset_options[dsopts['name']] = dsopts
|
|
261
|
+
|
|
262
|
+
if 'pcoord' in self.dataset_options:
|
|
263
|
+
if self.dataset_options['pcoord']['h5path'] != 'pcoord':
|
|
264
|
+
raise ValueError('cannot override pcoord storage location')
|
|
265
|
+
|
|
266
|
+
def __init__(self, rc=None):
|
|
267
|
+
self.rc = rc or westpa.rc
|
|
268
|
+
|
|
269
|
+
self.we_h5filename = self.default_we_h5filename
|
|
270
|
+
self.we_h5file_driver = self.default_we_h5file_driver
|
|
271
|
+
self.we_h5file_version = None
|
|
272
|
+
self.h5_access_mode = 'r+'
|
|
273
|
+
self.iter_prec = self.default_iter_prec
|
|
274
|
+
self.aux_compression_threshold = self.default_aux_compression_threshold
|
|
275
|
+
|
|
276
|
+
self.we_h5file = None
|
|
277
|
+
|
|
278
|
+
self.lock = threading.RLock()
|
|
279
|
+
self.flush_period = None
|
|
280
|
+
self.last_flush = 0
|
|
281
|
+
|
|
282
|
+
self._system = None
|
|
283
|
+
self.iter_ref_h5_template = None
|
|
284
|
+
self.store_h5 = False
|
|
285
|
+
|
|
286
|
+
self.dataset_options = {}
|
|
287
|
+
self.process_config()
|
|
288
|
+
|
|
289
|
+
@property
|
|
290
|
+
def system(self):
|
|
291
|
+
if self._system is None:
|
|
292
|
+
self._system = self.rc.get_system_driver()
|
|
293
|
+
return self._system
|
|
294
|
+
|
|
295
|
+
@system.setter
|
|
296
|
+
def system(self, system):
|
|
297
|
+
self._system = system
|
|
298
|
+
|
|
299
|
+
@property
|
|
300
|
+
def closed(self):
|
|
301
|
+
return self.we_h5file is None
|
|
302
|
+
|
|
303
|
+
def iter_group_name(self, n_iter, absolute=True):
|
|
304
|
+
if absolute:
|
|
305
|
+
return '/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)
|
|
306
|
+
else:
|
|
307
|
+
return 'iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)
|
|
308
|
+
|
|
309
|
+
def require_iter_group(self, n_iter):
|
|
310
|
+
'''Get the group associated with n_iter, creating it if necessary.'''
|
|
311
|
+
with self.lock:
|
|
312
|
+
iter_group = self.we_h5file.require_group('/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec))
|
|
313
|
+
iter_group.attrs['n_iter'] = n_iter
|
|
314
|
+
return iter_group
|
|
315
|
+
|
|
316
|
+
def del_iter_group(self, n_iter):
|
|
317
|
+
with self.lock:
|
|
318
|
+
del self.we_h5file['/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
|
|
319
|
+
|
|
320
|
+
def get_iter_group(self, n_iter):
|
|
321
|
+
with self.lock:
|
|
322
|
+
try:
|
|
323
|
+
return self.we_h5file['/iterations/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
|
|
324
|
+
except KeyError:
|
|
325
|
+
return self.we_h5file['/iter_{:0{prec}d}'.format(int(n_iter), prec=self.iter_prec)]
|
|
326
|
+
|
|
327
|
+
def get_seg_index(self, n_iter):
|
|
328
|
+
with self.lock:
|
|
329
|
+
seg_index = self.get_iter_group(n_iter)['seg_index']
|
|
330
|
+
return seg_index
|
|
331
|
+
|
|
332
|
+
@property
|
|
333
|
+
def current_iteration(self):
|
|
334
|
+
with self.lock:
|
|
335
|
+
h5file_attrs = self.we_h5file['/'].attrs
|
|
336
|
+
h5file_attr_keys = list(h5file_attrs.keys())
|
|
337
|
+
|
|
338
|
+
if 'west_current_iteration' in h5file_attr_keys:
|
|
339
|
+
return int(self.we_h5file['/'].attrs['west_current_iteration'])
|
|
340
|
+
else:
|
|
341
|
+
return int(self.we_h5file['/'].attrs['wemd_current_iteration'])
|
|
342
|
+
|
|
343
|
+
@current_iteration.setter
|
|
344
|
+
def current_iteration(self, n_iter):
|
|
345
|
+
with self.lock:
|
|
346
|
+
self.we_h5file['/'].attrs['west_current_iteration'] = n_iter
|
|
347
|
+
|
|
348
|
+
def open_backing(self, mode=None):
|
|
349
|
+
'''Open the (already-created) HDF5 file named in self.west_h5filename.'''
|
|
350
|
+
mode = mode or self.h5_access_mode
|
|
351
|
+
if not self.we_h5file:
|
|
352
|
+
log.debug('attempting to open {} with mode {}'.format(self.we_h5filename, mode))
|
|
353
|
+
self.we_h5file = h5io.WESTPAH5File(self.we_h5filename, mode, driver=self.we_h5file_driver)
|
|
354
|
+
|
|
355
|
+
h5file_attrs = self.we_h5file['/'].attrs
|
|
356
|
+
h5file_attr_keys = list(h5file_attrs.keys())
|
|
357
|
+
|
|
358
|
+
if 'west_iter_prec' in h5file_attr_keys:
|
|
359
|
+
self.iter_prec = int(h5file_attrs['west_iter_prec'])
|
|
360
|
+
elif 'wemd_iter_prec' in h5file_attr_keys:
|
|
361
|
+
self.iter_prec = int(h5file_attrs['wemd_iter_prec'])
|
|
362
|
+
else:
|
|
363
|
+
log.info('iteration precision not stored in HDF5; using {:d}'.format(self.iter_prec))
|
|
364
|
+
|
|
365
|
+
if 'west_file_format_version' in h5file_attr_keys:
|
|
366
|
+
self.we_h5file_version = h5file_attrs['west_file_format_version']
|
|
367
|
+
elif 'wemd_file_format_version' in h5file_attr_keys:
|
|
368
|
+
self.we_h5file_version = h5file_attrs['wemd_file_format_version']
|
|
369
|
+
else:
|
|
370
|
+
log.info('WEST HDF5 file format version not stored, assuming 0')
|
|
371
|
+
self.we_h5file_version = 0
|
|
372
|
+
|
|
373
|
+
log.debug('opened WEST HDF5 file version {:d}'.format(self.we_h5file_version))
|
|
374
|
+
|
|
375
|
+
def prepare_backing(self): # istates):
|
|
376
|
+
'''Create new HDF5 file'''
|
|
377
|
+
self.we_h5file = h5py.File(self.we_h5filename, 'w', driver=self.we_h5file_driver)
|
|
378
|
+
|
|
379
|
+
with self.flushing_lock():
|
|
380
|
+
self.we_h5file['/'].attrs['west_file_format_version'] = file_format_version
|
|
381
|
+
self.we_h5file['/'].attrs['west_iter_prec'] = self.iter_prec
|
|
382
|
+
self.we_h5file['/'].attrs['west_version'] = westpa.__version__
|
|
383
|
+
self.current_iteration = 0
|
|
384
|
+
self.we_h5file['/'].create_dataset('summary', shape=(1,), dtype=summary_table_dtype, maxshape=(None,))
|
|
385
|
+
self.we_h5file.create_group('/iterations')
|
|
386
|
+
|
|
387
|
+
def close_backing(self):
|
|
388
|
+
if self.we_h5file is not None:
|
|
389
|
+
with self.lock:
|
|
390
|
+
self.we_h5file.close()
|
|
391
|
+
self.we_h5file = None
|
|
392
|
+
|
|
393
|
+
def flush_backing(self):
|
|
394
|
+
if self.we_h5file is not None:
|
|
395
|
+
with self.lock:
|
|
396
|
+
self.we_h5file.flush()
|
|
397
|
+
self.last_flush = time.time()
|
|
398
|
+
|
|
399
|
+
def save_target_states(self, tstates, n_iter=None):
|
|
400
|
+
'''Save the given target states in the HDF5 file; they will be used for the next iteration to
|
|
401
|
+
be propagated. A complete set is required, even if nominally appending to an existing set,
|
|
402
|
+
which simplifies the mapping of IDs to the table.'''
|
|
403
|
+
|
|
404
|
+
system = self.system
|
|
405
|
+
|
|
406
|
+
n_iter = n_iter or self.current_iteration
|
|
407
|
+
|
|
408
|
+
# Assemble all the important data before we start to modify the HDF5 file
|
|
409
|
+
tstates = list(tstates)
|
|
410
|
+
if tstates:
|
|
411
|
+
state_table = np.empty((len(tstates),), dtype=tstate_dtype)
|
|
412
|
+
state_pcoords = np.empty((len(tstates), system.pcoord_ndim), dtype=system.pcoord_dtype)
|
|
413
|
+
for i, state in enumerate(tstates):
|
|
414
|
+
state.state_id = i
|
|
415
|
+
state_table[i]['label'] = state.label
|
|
416
|
+
state_pcoords[i] = state.pcoord
|
|
417
|
+
else:
|
|
418
|
+
state_table = None
|
|
419
|
+
state_pcoords = None
|
|
420
|
+
|
|
421
|
+
# Commit changes to HDF5
|
|
422
|
+
with self.lock:
|
|
423
|
+
master_group = self.we_h5file.require_group('tstates')
|
|
424
|
+
|
|
425
|
+
try:
|
|
426
|
+
master_index = master_group['index']
|
|
427
|
+
except KeyError:
|
|
428
|
+
master_index = master_group.create_dataset('index', shape=(1,), maxshape=(None,), dtype=tstate_index_dtype)
|
|
429
|
+
n_sets = 1
|
|
430
|
+
else:
|
|
431
|
+
n_sets = len(master_index) + 1
|
|
432
|
+
master_index.resize((n_sets,))
|
|
433
|
+
|
|
434
|
+
set_id = n_sets - 1
|
|
435
|
+
master_index_row = master_index[set_id]
|
|
436
|
+
master_index_row['iter_valid'] = n_iter
|
|
437
|
+
master_index_row['n_states'] = len(tstates)
|
|
438
|
+
|
|
439
|
+
if tstates:
|
|
440
|
+
state_group = master_group.create_group(str(set_id))
|
|
441
|
+
master_index_row['group_ref'] = state_group.ref
|
|
442
|
+
state_group['index'] = state_table
|
|
443
|
+
state_group['pcoord'] = state_pcoords
|
|
444
|
+
else:
|
|
445
|
+
master_index_row['group_ref'] = None
|
|
446
|
+
|
|
447
|
+
master_index[set_id] = master_index_row
|
|
448
|
+
|
|
449
|
+
def _find_multi_iter_group(self, n_iter, master_group_name):
|
|
450
|
+
with self.lock:
|
|
451
|
+
master_group = self.we_h5file[master_group_name]
|
|
452
|
+
master_index = master_group['index'][...]
|
|
453
|
+
set_id = np.digitize([n_iter], master_index['iter_valid']) - 1
|
|
454
|
+
group_ref = master_index[set_id]['group_ref']
|
|
455
|
+
|
|
456
|
+
# Check if reference is Null
|
|
457
|
+
if not bool(group_ref):
|
|
458
|
+
return None
|
|
459
|
+
|
|
460
|
+
# This extra [0] is to work around a bug in h5py
|
|
461
|
+
try:
|
|
462
|
+
group = self.we_h5file[group_ref]
|
|
463
|
+
except (TypeError, AttributeError):
|
|
464
|
+
group = self.we_h5file[group_ref[0]]
|
|
465
|
+
else:
|
|
466
|
+
log.debug('h5py fixed; remove alternate code path')
|
|
467
|
+
log.debug('reference {!r} points to group {!r}'.format(group_ref, group))
|
|
468
|
+
return group
|
|
469
|
+
|
|
470
|
+
def find_tstate_group(self, n_iter):
|
|
471
|
+
return self._find_multi_iter_group(n_iter, 'tstates')
|
|
472
|
+
|
|
473
|
+
def find_ibstate_group(self, n_iter):
|
|
474
|
+
return self._find_multi_iter_group(n_iter, 'ibstates')
|
|
475
|
+
|
|
476
|
+
def get_target_states(self, n_iter):
|
|
477
|
+
'''Return a list of Target objects representing the target (sink) states that are in use for iteration n_iter.
|
|
478
|
+
Future iterations are assumed to continue from the most recent set of states.'''
|
|
479
|
+
|
|
480
|
+
with self.lock:
|
|
481
|
+
tstate_group = self.find_tstate_group(n_iter)
|
|
482
|
+
|
|
483
|
+
if tstate_group is not None:
|
|
484
|
+
tstate_index = tstate_group['index'][...]
|
|
485
|
+
tstate_pcoords = tstate_group['pcoord'][...]
|
|
486
|
+
|
|
487
|
+
tstates = [
|
|
488
|
+
TargetState(state_id=i, label=h5io.tostr(row['label']), pcoord=pcoord.copy())
|
|
489
|
+
for (i, (row, pcoord)) in enumerate(zip(tstate_index, tstate_pcoords))
|
|
490
|
+
]
|
|
491
|
+
else:
|
|
492
|
+
tstates = []
|
|
493
|
+
|
|
494
|
+
return tstates
|
|
495
|
+
|
|
496
|
+
def create_ibstate_group(self, basis_states, n_iter=None):
|
|
497
|
+
'''Create the group used to store basis states and initial states (whose definitions are always
|
|
498
|
+
coupled). This group is hard-linked into all iteration groups that use these basis and
|
|
499
|
+
initial states.'''
|
|
500
|
+
|
|
501
|
+
with self.lock:
|
|
502
|
+
n_iter = n_iter or self.current_iteration
|
|
503
|
+
master_group = self.we_h5file.require_group('ibstates')
|
|
504
|
+
|
|
505
|
+
try:
|
|
506
|
+
master_index = master_group['index']
|
|
507
|
+
except KeyError:
|
|
508
|
+
master_index = master_group.create_dataset('index', dtype=ibstate_index_dtype, shape=(1,), maxshape=(None,))
|
|
509
|
+
n_sets = 1
|
|
510
|
+
else:
|
|
511
|
+
n_sets = len(master_index) + 1
|
|
512
|
+
master_index.resize((n_sets,))
|
|
513
|
+
|
|
514
|
+
set_id = n_sets - 1
|
|
515
|
+
master_index_row = master_index[set_id]
|
|
516
|
+
master_index_row['iter_valid'] = n_iter
|
|
517
|
+
master_index_row['n_bstates'] = len(basis_states)
|
|
518
|
+
state_group = master_group.create_group(str(set_id))
|
|
519
|
+
master_index_row['group_ref'] = state_group.ref
|
|
520
|
+
|
|
521
|
+
if basis_states:
|
|
522
|
+
system = self.system
|
|
523
|
+
state_table = np.empty((len(basis_states),), dtype=bstate_dtype)
|
|
524
|
+
state_pcoords = np.empty((len(basis_states), system.pcoord_ndim), dtype=system.pcoord_dtype)
|
|
525
|
+
for i, state in enumerate(basis_states):
|
|
526
|
+
state.state_id = i
|
|
527
|
+
state_table[i]['label'] = state.label
|
|
528
|
+
state_table[i]['probability'] = state.probability
|
|
529
|
+
state_table[i]['auxref'] = state.auxref or ''
|
|
530
|
+
state_pcoords[i] = state.pcoord
|
|
531
|
+
|
|
532
|
+
state_group['bstate_index'] = state_table
|
|
533
|
+
state_group['bstate_pcoord'] = state_pcoords
|
|
534
|
+
|
|
535
|
+
master_index[set_id] = master_index_row
|
|
536
|
+
return state_group
|
|
537
|
+
|
|
538
|
+
def create_ibstate_iter_h5file(self, basis_states):
|
|
539
|
+
'''Create the per-iteration HDF5 file for the basis states (i.e., iteration 0).
|
|
540
|
+
This special treatment is needed so that the analysis tools can access basis states
|
|
541
|
+
more easily.'''
|
|
542
|
+
|
|
543
|
+
if not self.store_h5:
|
|
544
|
+
return
|
|
545
|
+
|
|
546
|
+
segments = []
|
|
547
|
+
for i, state in enumerate(basis_states):
|
|
548
|
+
dummy_segment = Segment(
|
|
549
|
+
n_iter=0,
|
|
550
|
+
seg_id=state.state_id,
|
|
551
|
+
parent_id=-(state.state_id + 1),
|
|
552
|
+
weight=state.probability,
|
|
553
|
+
wtg_parent_ids=None,
|
|
554
|
+
pcoord=state.pcoord,
|
|
555
|
+
status=Segment.SEG_STATUS_UNSET,
|
|
556
|
+
data=state.data,
|
|
557
|
+
)
|
|
558
|
+
segments.append(dummy_segment)
|
|
559
|
+
|
|
560
|
+
# # link the iteration file in west.h5
|
|
561
|
+
self.prepare_iteration(0, segments)
|
|
562
|
+
self.update_iter_h5file(0, segments)
|
|
563
|
+
|
|
564
|
+
def update_iter_h5file(self, n_iter, segments):
|
|
565
|
+
'''Write out the per-iteration HDF5 file with given segments and add an external link to it
|
|
566
|
+
in the main HDF5 file (west.h5) if the link is not present.'''
|
|
567
|
+
|
|
568
|
+
if not self.store_h5:
|
|
569
|
+
return
|
|
570
|
+
|
|
571
|
+
west_h5_file = makepath(self.we_h5filename)
|
|
572
|
+
iter_ref_h5_file = makepath(self.iter_ref_h5_template, {'n_iter': n_iter})
|
|
573
|
+
iter_ref_rel_path = relpath(iter_ref_h5_file, dirname(west_h5_file))
|
|
574
|
+
|
|
575
|
+
with h5io.WESTIterationFile(iter_ref_h5_file, 'a') as outf:
|
|
576
|
+
for segment in segments:
|
|
577
|
+
outf.write_segment(segment, True)
|
|
578
|
+
|
|
579
|
+
iter_group = self.get_iter_group(n_iter)
|
|
580
|
+
|
|
581
|
+
if 'trajectories' not in iter_group:
|
|
582
|
+
iter_group['trajectories'] = h5py.ExternalLink(iter_ref_rel_path, '/')
|
|
583
|
+
|
|
584
|
+
def get_basis_states(self, n_iter=None):
|
|
585
|
+
'''Return a list of BasisState objects representing the basis states that are in use for iteration n_iter.'''
|
|
586
|
+
|
|
587
|
+
with self.lock:
|
|
588
|
+
n_iter = n_iter or self.current_iteration
|
|
589
|
+
ibstate_group = self.find_ibstate_group(n_iter)
|
|
590
|
+
try:
|
|
591
|
+
bstate_index = ibstate_group['bstate_index'][...]
|
|
592
|
+
except KeyError:
|
|
593
|
+
return []
|
|
594
|
+
bstate_pcoords = ibstate_group['bstate_pcoord'][...]
|
|
595
|
+
bstates = [
|
|
596
|
+
BasisState(
|
|
597
|
+
state_id=i,
|
|
598
|
+
label=h5io.tostr(row['label']),
|
|
599
|
+
probability=row['probability'],
|
|
600
|
+
auxref=h5io.tostr(row['auxref']) or None,
|
|
601
|
+
pcoord=pcoord.copy(),
|
|
602
|
+
)
|
|
603
|
+
for (i, (row, pcoord)) in enumerate(zip(bstate_index, bstate_pcoords))
|
|
604
|
+
]
|
|
605
|
+
|
|
606
|
+
bstate_total_prob = sum(bstate.probability for bstate in bstates)
|
|
607
|
+
|
|
608
|
+
# This should run once in the second iteration, and only if start-states are specified,
|
|
609
|
+
# but is necessary to re-normalize (i.e. normalize without start-state probabilities included)
|
|
610
|
+
for i, bstate in enumerate(bstates):
|
|
611
|
+
bstate.probability /= bstate_total_prob
|
|
612
|
+
bstates[i] = bstate
|
|
613
|
+
return bstates
|
|
614
|
+
|
|
615
|
+
def create_initial_states(self, n_states, n_iter=None):
|
|
616
|
+
'''Create storage for ``n_states`` initial states associated with iteration ``n_iter``, and
|
|
617
|
+
return bare InitialState objects with only state_id set.'''
|
|
618
|
+
|
|
619
|
+
system = self.system
|
|
620
|
+
with self.lock:
|
|
621
|
+
n_iter = n_iter or self.current_iteration
|
|
622
|
+
ibstate_group = self.find_ibstate_group(n_iter)
|
|
623
|
+
|
|
624
|
+
try:
|
|
625
|
+
istate_index = ibstate_group['istate_index']
|
|
626
|
+
except KeyError:
|
|
627
|
+
istate_index = ibstate_group.create_dataset('istate_index', dtype=istate_dtype, shape=(n_states,), maxshape=(None,))
|
|
628
|
+
istate_pcoords = ibstate_group.create_dataset(
|
|
629
|
+
'istate_pcoord',
|
|
630
|
+
dtype=system.pcoord_dtype,
|
|
631
|
+
shape=(n_states, system.pcoord_ndim),
|
|
632
|
+
maxshape=(None, system.pcoord_ndim),
|
|
633
|
+
)
|
|
634
|
+
len_index = len(istate_index)
|
|
635
|
+
first_id = 0
|
|
636
|
+
else:
|
|
637
|
+
first_id = len(istate_index)
|
|
638
|
+
len_index = len(istate_index) + n_states
|
|
639
|
+
istate_index.resize((len_index,))
|
|
640
|
+
istate_pcoords = ibstate_group['istate_pcoord']
|
|
641
|
+
istate_pcoords.resize((len_index, system.pcoord_ndim))
|
|
642
|
+
|
|
643
|
+
index_entries = istate_index[first_id:len_index]
|
|
644
|
+
new_istates = []
|
|
645
|
+
for irow, row in enumerate(index_entries):
|
|
646
|
+
row['iter_created'] = n_iter
|
|
647
|
+
row['istate_status'] = InitialState.ISTATE_STATUS_PENDING
|
|
648
|
+
new_istates.append(
|
|
649
|
+
InitialState(
|
|
650
|
+
state_id=first_id + irow,
|
|
651
|
+
basis_state_id=None,
|
|
652
|
+
iter_created=n_iter,
|
|
653
|
+
istate_status=InitialState.ISTATE_STATUS_PENDING,
|
|
654
|
+
)
|
|
655
|
+
)
|
|
656
|
+
istate_index[first_id:len_index] = index_entries
|
|
657
|
+
return new_istates
|
|
658
|
+
|
|
659
|
+
def update_initial_states(self, initial_states, n_iter=None):
|
|
660
|
+
'''Save the given initial states in the HDF5 file'''
|
|
661
|
+
|
|
662
|
+
system = self.system
|
|
663
|
+
initial_states = sorted(initial_states, key=attrgetter('state_id'))
|
|
664
|
+
if not initial_states:
|
|
665
|
+
return
|
|
666
|
+
|
|
667
|
+
with self.lock:
|
|
668
|
+
n_iter = n_iter or self.current_iteration
|
|
669
|
+
ibstate_group = self.find_ibstate_group(n_iter)
|
|
670
|
+
state_ids = [state.state_id for state in initial_states]
|
|
671
|
+
index_entries = ibstate_group['istate_index'][state_ids]
|
|
672
|
+
pcoord_vals = np.empty((len(initial_states), system.pcoord_ndim), dtype=system.pcoord_dtype)
|
|
673
|
+
for i, initial_state in enumerate(initial_states):
|
|
674
|
+
index_entries[i]['iter_created'] = initial_state.iter_created
|
|
675
|
+
index_entries[i]['iter_used'] = initial_state.iter_used or InitialState.ISTATE_UNUSED
|
|
676
|
+
index_entries[i]['basis_state_id'] = (
|
|
677
|
+
initial_state.basis_state_id if initial_state.basis_state_id is not None else -1
|
|
678
|
+
)
|
|
679
|
+
index_entries[i]['istate_type'] = initial_state.istate_type or InitialState.ISTATE_TYPE_UNSET
|
|
680
|
+
index_entries[i]['istate_status'] = initial_state.istate_status or InitialState.ISTATE_STATUS_PENDING
|
|
681
|
+
pcoord_vals[i] = initial_state.pcoord
|
|
682
|
+
|
|
683
|
+
index_entries[i]['basis_auxref'] = initial_state.basis_auxref or ""
|
|
684
|
+
|
|
685
|
+
ibstate_group['istate_index'][state_ids] = index_entries
|
|
686
|
+
ibstate_group['istate_pcoord'][state_ids] = pcoord_vals
|
|
687
|
+
|
|
688
|
+
if self.store_h5:
|
|
689
|
+
segments = []
|
|
690
|
+
for i, state in enumerate(initial_states):
|
|
691
|
+
dummy_segment = Segment(
|
|
692
|
+
n_iter=-state.iter_created,
|
|
693
|
+
seg_id=state.state_id,
|
|
694
|
+
parent_id=state.basis_state_id,
|
|
695
|
+
wtg_parent_ids=None,
|
|
696
|
+
pcoord=state.pcoord,
|
|
697
|
+
status=Segment.SEG_STATUS_PREPARED,
|
|
698
|
+
data=state.data,
|
|
699
|
+
)
|
|
700
|
+
segments.append(dummy_segment)
|
|
701
|
+
self.update_iter_h5file(0, segments)
|
|
702
|
+
|
|
703
|
+
def get_initial_states(self, n_iter=None):
|
|
704
|
+
states = []
|
|
705
|
+
with self.lock:
|
|
706
|
+
n_iter = n_iter or self.current_iteration
|
|
707
|
+
ibstate_group = self.find_ibstate_group(n_iter)
|
|
708
|
+
try:
|
|
709
|
+
istate_index = ibstate_group['istate_index'][...]
|
|
710
|
+
except KeyError:
|
|
711
|
+
return []
|
|
712
|
+
istate_pcoords = ibstate_group['istate_pcoord'][...]
|
|
713
|
+
|
|
714
|
+
for state_id, (state, pcoord) in enumerate(zip(istate_index, istate_pcoords)):
|
|
715
|
+
states.append(
|
|
716
|
+
InitialState(
|
|
717
|
+
state_id=state_id,
|
|
718
|
+
basis_state_id=int(state['basis_state_id']),
|
|
719
|
+
iter_created=int(state['iter_created']),
|
|
720
|
+
iter_used=int(state['iter_used']),
|
|
721
|
+
istate_type=int(state['istate_type']),
|
|
722
|
+
basis_auxref=h5io.tostr(state['basis_auxref']),
|
|
723
|
+
pcoord=pcoord.copy(),
|
|
724
|
+
)
|
|
725
|
+
)
|
|
726
|
+
return states
|
|
727
|
+
|
|
728
|
+
def get_segment_initial_states(self, segments, n_iter=None):
|
|
729
|
+
'''Retrieve all initial states referenced by the given segments.'''
|
|
730
|
+
|
|
731
|
+
with self.lock:
|
|
732
|
+
n_iter = n_iter or self.current_iteration
|
|
733
|
+
ibstate_group = self.get_iter_group(n_iter)['ibstates']
|
|
734
|
+
|
|
735
|
+
istate_ids = {-int(segment.parent_id + 1) for segment in segments if segment.parent_id < 0}
|
|
736
|
+
sorted_istate_ids = sorted(istate_ids)
|
|
737
|
+
if not sorted_istate_ids:
|
|
738
|
+
return []
|
|
739
|
+
|
|
740
|
+
istate_rows = ibstate_group['istate_index'][sorted_istate_ids][...]
|
|
741
|
+
istate_pcoords = ibstate_group['istate_pcoord'][sorted_istate_ids][...]
|
|
742
|
+
istates = []
|
|
743
|
+
|
|
744
|
+
for state_id, state, pcoord in zip(sorted_istate_ids, istate_rows, istate_pcoords):
|
|
745
|
+
try:
|
|
746
|
+
b_auxref = h5io.tostr(state['basis_auxref'])
|
|
747
|
+
except ValueError:
|
|
748
|
+
b_auxref = ''
|
|
749
|
+
istate = InitialState(
|
|
750
|
+
state_id=state_id,
|
|
751
|
+
basis_state_id=int(state['basis_state_id']),
|
|
752
|
+
iter_created=int(state['iter_created']),
|
|
753
|
+
iter_used=int(state['iter_used']),
|
|
754
|
+
istate_type=int(state['istate_type']),
|
|
755
|
+
basis_auxref=b_auxref,
|
|
756
|
+
pcoord=pcoord.copy(),
|
|
757
|
+
)
|
|
758
|
+
istates.append(istate)
|
|
759
|
+
return istates
|
|
760
|
+
|
|
761
|
+
def get_unused_initial_states(self, n_states=None, n_iter=None):
|
|
762
|
+
'''Retrieve any prepared but unused initial states applicable to the given iteration.
|
|
763
|
+
Up to ``n_states`` states are returned; if ``n_states`` is None, then all unused states
|
|
764
|
+
are returned.'''
|
|
765
|
+
|
|
766
|
+
n_states = n_states or sys.maxsize
|
|
767
|
+
ISTATE_UNUSED = InitialState.ISTATE_UNUSED
|
|
768
|
+
ISTATE_STATUS_PREPARED = InitialState.ISTATE_STATUS_PREPARED
|
|
769
|
+
with self.lock:
|
|
770
|
+
n_iter = n_iter or self.current_iteration
|
|
771
|
+
ibstate_group = self.find_ibstate_group(n_iter)
|
|
772
|
+
istate_index = ibstate_group['istate_index']
|
|
773
|
+
istate_pcoords = ibstate_group['istate_pcoord']
|
|
774
|
+
n_index_entries = istate_index.len()
|
|
775
|
+
chunksize = self.table_scan_chunksize
|
|
776
|
+
|
|
777
|
+
states = []
|
|
778
|
+
istart = 0
|
|
779
|
+
while istart < n_index_entries and len(states) < n_states:
|
|
780
|
+
istop = min(istart + chunksize, n_index_entries)
|
|
781
|
+
istate_chunk = istate_index[istart:istop]
|
|
782
|
+
pcoord_chunk = istate_pcoords[istart:istop]
|
|
783
|
+
# state_ids = np.arange(istart,istop,dtype=np.uint)
|
|
784
|
+
|
|
785
|
+
for ci in range(len(istate_chunk)):
|
|
786
|
+
row = istate_chunk[ci]
|
|
787
|
+
pcoord = pcoord_chunk[ci]
|
|
788
|
+
state_id = istart + ci
|
|
789
|
+
if row['iter_used'] == ISTATE_UNUSED and row['istate_status'] == ISTATE_STATUS_PREPARED:
|
|
790
|
+
istate = InitialState(
|
|
791
|
+
state_id=state_id,
|
|
792
|
+
basis_state_id=int(row['basis_state_id']),
|
|
793
|
+
iter_created=int(row['iter_created']),
|
|
794
|
+
iter_used=0,
|
|
795
|
+
istate_type=int(row['istate_type']),
|
|
796
|
+
pcoord=pcoord.copy(),
|
|
797
|
+
istate_status=ISTATE_STATUS_PREPARED,
|
|
798
|
+
)
|
|
799
|
+
states.append(istate)
|
|
800
|
+
del row, pcoord, state_id
|
|
801
|
+
istart += chunksize
|
|
802
|
+
del istate_chunk, pcoord_chunk # , state_ids, unused, ids_of_unused
|
|
803
|
+
log.debug('found {:d} unused states'.format(len(states)))
|
|
804
|
+
return states[:n_states]
|
|
805
|
+
|
|
806
|
+
def prepare_iteration(self, n_iter, segments):
|
|
807
|
+
"""Prepare for a new iteration by creating space to store the new iteration's data.
|
|
808
|
+
The number of segments, their IDs, and their lineage must be determined and included
|
|
809
|
+
in the set of segments passed in."""
|
|
810
|
+
|
|
811
|
+
log.debug('preparing HDF5 group for iteration %d (%d segments)' % (n_iter, len(segments)))
|
|
812
|
+
|
|
813
|
+
# Ensure we have a list for guaranteed ordering
|
|
814
|
+
init = n_iter == 0
|
|
815
|
+
segments = list(segments)
|
|
816
|
+
n_particles = len(segments)
|
|
817
|
+
system = self.system
|
|
818
|
+
pcoord_ndim = system.pcoord_ndim
|
|
819
|
+
pcoord_len = 2 if init else system.pcoord_len
|
|
820
|
+
pcoord_dtype = system.pcoord_dtype
|
|
821
|
+
|
|
822
|
+
with self.lock:
|
|
823
|
+
if not init:
|
|
824
|
+
# Create a table of summary information about each iteration
|
|
825
|
+
summary_table = self.we_h5file['summary']
|
|
826
|
+
if len(summary_table) < n_iter:
|
|
827
|
+
summary_table.resize((n_iter + 1,))
|
|
828
|
+
|
|
829
|
+
iter_group = self.require_iter_group(n_iter)
|
|
830
|
+
|
|
831
|
+
for linkname in ('seg_index', 'pcoord', 'wtgraph'):
|
|
832
|
+
try:
|
|
833
|
+
del iter_group[linkname]
|
|
834
|
+
except KeyError:
|
|
835
|
+
pass
|
|
836
|
+
|
|
837
|
+
# everything indexed by [particle] goes in an index table
|
|
838
|
+
seg_index_table_ds = iter_group.create_dataset('seg_index', shape=(n_particles,), dtype=seg_index_dtype)
|
|
839
|
+
# unfortunately, h5py doesn't like in-place modification of individual fields; it expects
|
|
840
|
+
# tuples. So, construct everything in a numpy array and then dump the whole thing into hdf5
|
|
841
|
+
# In fact, this appears to be an h5py best practice (collect as much in ram as possible and then dump)
|
|
842
|
+
seg_index_table = seg_index_table_ds[...]
|
|
843
|
+
|
|
844
|
+
if not init:
|
|
845
|
+
summary_row = np.zeros((1,), dtype=summary_table_dtype)
|
|
846
|
+
summary_row['n_particles'] = n_particles
|
|
847
|
+
summary_row['norm'] = np.add.reduce(list(map(attrgetter('weight'), segments)))
|
|
848
|
+
summary_table[n_iter - 1] = summary_row
|
|
849
|
+
|
|
850
|
+
# pcoord is indexed as [particle, time, dimension]
|
|
851
|
+
pcoord_opts = self.dataset_options.get('pcoord', {'name': 'pcoord', 'h5path': 'pcoord', 'compression': False})
|
|
852
|
+
shape = (n_particles, pcoord_len, pcoord_ndim)
|
|
853
|
+
pcoord_ds = create_dataset_from_dsopts(iter_group, pcoord_opts, shape, pcoord_dtype)
|
|
854
|
+
pcoord = np.empty((n_particles, pcoord_len, pcoord_ndim), pcoord_dtype)
|
|
855
|
+
|
|
856
|
+
total_parents = 0
|
|
857
|
+
for seg_id, segment in enumerate(segments):
|
|
858
|
+
if segment.seg_id is not None:
|
|
859
|
+
assert segment.seg_id == seg_id
|
|
860
|
+
else:
|
|
861
|
+
segment.seg_id = seg_id
|
|
862
|
+
# Parent must be set, though what it means depends on initpoint_type
|
|
863
|
+
assert segment.parent_id is not None
|
|
864
|
+
segment.seg_id = seg_id
|
|
865
|
+
seg_index_table[seg_id]['status'] = segment.status
|
|
866
|
+
seg_index_table[seg_id]['weight'] = segment.weight
|
|
867
|
+
seg_index_table[seg_id]['parent_id'] = segment.parent_id
|
|
868
|
+
seg_index_table[seg_id]['wtg_n_parents'] = len(segment.wtg_parent_ids)
|
|
869
|
+
seg_index_table[seg_id]['wtg_offset'] = total_parents
|
|
870
|
+
total_parents += len(segment.wtg_parent_ids)
|
|
871
|
+
|
|
872
|
+
# Assign progress coordinate if any exists
|
|
873
|
+
if segment.pcoord is not None:
|
|
874
|
+
if init:
|
|
875
|
+
if segment.pcoord.shape != pcoord.shape[2:]:
|
|
876
|
+
raise ValueError(
|
|
877
|
+
'basis state pcoord shape [%r] does not match expected shape [%r]'
|
|
878
|
+
% (segment.pcoord.shape, pcoord.shape[2:])
|
|
879
|
+
)
|
|
880
|
+
# Initial pcoord
|
|
881
|
+
pcoord[seg_id, 1, :] = segment.pcoord[:]
|
|
882
|
+
else:
|
|
883
|
+
if len(segment.pcoord) == 1:
|
|
884
|
+
# Initial pcoord
|
|
885
|
+
pcoord[seg_id, 0, :] = segment.pcoord[0, :]
|
|
886
|
+
elif segment.pcoord.shape != pcoord.shape[1:]:
|
|
887
|
+
raise ValueError(
|
|
888
|
+
'segment pcoord shape [%r] does not match expected shape [%r]'
|
|
889
|
+
% (segment.pcoord.shape, pcoord.shape[1:])
|
|
890
|
+
)
|
|
891
|
+
else:
|
|
892
|
+
pcoord[seg_id, ...] = segment.pcoord
|
|
893
|
+
|
|
894
|
+
if total_parents > 0:
|
|
895
|
+
wtgraph_ds = iter_group.create_dataset('wtgraph', (total_parents,), seg_id_dtype, compression='gzip', shuffle=True)
|
|
896
|
+
parents = np.empty((total_parents,), seg_id_dtype)
|
|
897
|
+
|
|
898
|
+
for seg_id, segment in enumerate(segments):
|
|
899
|
+
offset = seg_index_table[seg_id]['wtg_offset']
|
|
900
|
+
extent = seg_index_table[seg_id]['wtg_n_parents']
|
|
901
|
+
parent_list = list(segment.wtg_parent_ids)
|
|
902
|
+
parents[offset : offset + extent] = parent_list[:]
|
|
903
|
+
|
|
904
|
+
assert set(parents[offset : offset + extent]) == set(segment.wtg_parent_ids)
|
|
905
|
+
|
|
906
|
+
wtgraph_ds[:] = parents
|
|
907
|
+
|
|
908
|
+
# Create convenient hard links
|
|
909
|
+
self.update_iter_group_links(n_iter)
|
|
910
|
+
|
|
911
|
+
# Since we accumulated many of these changes in RAM (and not directly in HDF5), propagate
|
|
912
|
+
# the changes out to HDF5
|
|
913
|
+
seg_index_table_ds[:] = seg_index_table
|
|
914
|
+
pcoord_ds[...] = pcoord
|
|
915
|
+
|
|
916
|
+
def update_iter_group_links(self, n_iter):
|
|
917
|
+
'''Update the per-iteration hard links pointing to the tables of target and initial/basis states for the
|
|
918
|
+
given iteration. These links are not used by this class, but are remarkably convenient for third-party
|
|
919
|
+
analysis tools and hdfview.'''
|
|
920
|
+
|
|
921
|
+
with self.lock:
|
|
922
|
+
iter_group = self.require_iter_group(n_iter)
|
|
923
|
+
|
|
924
|
+
for linkname in ('ibstates', 'tstates'):
|
|
925
|
+
try:
|
|
926
|
+
del iter_group[linkname]
|
|
927
|
+
except KeyError:
|
|
928
|
+
pass
|
|
929
|
+
|
|
930
|
+
iter_group['ibstates'] = self.find_ibstate_group(n_iter)
|
|
931
|
+
|
|
932
|
+
tstate_group = self.find_tstate_group(n_iter)
|
|
933
|
+
if tstate_group is not None:
|
|
934
|
+
iter_group['tstates'] = tstate_group
|
|
935
|
+
|
|
936
|
+
def get_iter_summary(self, n_iter=None):
|
|
937
|
+
n_iter = n_iter or self.current_iteration
|
|
938
|
+
with self.lock:
|
|
939
|
+
return self.we_h5file['summary'][n_iter - 1]
|
|
940
|
+
|
|
941
|
+
def update_iter_summary(self, summary, n_iter=None):
|
|
942
|
+
n_iter = n_iter or self.current_iteration
|
|
943
|
+
with self.lock:
|
|
944
|
+
self.we_h5file['summary'][n_iter - 1] = summary
|
|
945
|
+
|
|
946
|
+
def del_iter_summary(self, min_iter): # delete the iterations starting at min_iter
|
|
947
|
+
with self.lock:
|
|
948
|
+
self.we_h5file['summary'].resize((min_iter - 1,))
|
|
949
|
+
|
|
950
|
+
def update_segments(self, n_iter, segments):
|
|
951
|
+
'''Update segment information in the HDF5 file; all prior information for each
|
|
952
|
+
``segment`` is overwritten, except for parent and weight transfer information.'''
|
|
953
|
+
|
|
954
|
+
segments = sorted(segments, key=attrgetter('seg_id'))
|
|
955
|
+
|
|
956
|
+
with self.lock:
|
|
957
|
+
iter_group = self.get_iter_group(n_iter)
|
|
958
|
+
|
|
959
|
+
pc_dsid = iter_group['pcoord'].id
|
|
960
|
+
si_dsid = iter_group['seg_index'].id
|
|
961
|
+
|
|
962
|
+
seg_ids = [segment.seg_id for segment in segments]
|
|
963
|
+
n_segments = len(segments)
|
|
964
|
+
n_total_segments = si_dsid.shape[0]
|
|
965
|
+
system = self.system
|
|
966
|
+
pcoord_ndim = system.pcoord_ndim
|
|
967
|
+
pcoord_len = system.pcoord_len
|
|
968
|
+
pcoord_dtype = system.pcoord_dtype
|
|
969
|
+
|
|
970
|
+
seg_index_entries = np.empty((n_segments,), dtype=seg_index_dtype)
|
|
971
|
+
pcoord_entries = np.empty((n_segments, pcoord_len, pcoord_ndim), dtype=pcoord_dtype)
|
|
972
|
+
|
|
973
|
+
pc_msel = h5s.create_simple(pcoord_entries.shape, (h5s.UNLIMITED,) * pcoord_entries.ndim)
|
|
974
|
+
pc_msel.select_all()
|
|
975
|
+
si_msel = h5s.create_simple(seg_index_entries.shape, (h5s.UNLIMITED,))
|
|
976
|
+
si_msel.select_all()
|
|
977
|
+
pc_fsel = pc_dsid.get_space()
|
|
978
|
+
si_fsel = si_dsid.get_space()
|
|
979
|
+
|
|
980
|
+
for iseg in range(n_segments):
|
|
981
|
+
seg_id = seg_ids[iseg]
|
|
982
|
+
op = h5s.SELECT_OR if iseg != 0 else h5s.SELECT_SET
|
|
983
|
+
si_fsel.select_hyperslab((seg_id,), (1,), op=op)
|
|
984
|
+
pc_fsel.select_hyperslab((seg_id, 0, 0), (1, pcoord_len, pcoord_ndim), op=op)
|
|
985
|
+
|
|
986
|
+
# read summary data so that we have valud parent and weight transfer information
|
|
987
|
+
si_dsid.read(si_msel, si_fsel, seg_index_entries)
|
|
988
|
+
|
|
989
|
+
for iseg, (segment, ientry) in enumerate(zip(segments, seg_index_entries)):
|
|
990
|
+
ientry['status'] = segment.status
|
|
991
|
+
ientry['endpoint_type'] = segment.endpoint_type or Segment.SEG_ENDPOINT_UNSET
|
|
992
|
+
ientry['cputime'] = segment.cputime
|
|
993
|
+
ientry['walltime'] = segment.walltime
|
|
994
|
+
ientry['weight'] = segment.weight
|
|
995
|
+
|
|
996
|
+
pcoord_entries[iseg] = segment.pcoord
|
|
997
|
+
|
|
998
|
+
# write progress coordinates and index using low level HDF5 functions for efficiency
|
|
999
|
+
si_dsid.write(si_msel, si_fsel, seg_index_entries)
|
|
1000
|
+
pc_dsid.write(pc_msel, pc_fsel, pcoord_entries)
|
|
1001
|
+
|
|
1002
|
+
# Now, to deal with auxiliary data
|
|
1003
|
+
# If any segment has any auxiliary data, then the aux dataset must spring into
|
|
1004
|
+
# existence. Each is named according to the name in segment.data, and has shape
|
|
1005
|
+
# (n_total_segs, ...) where the ... is the shape of the data in segment.data (and may be empty
|
|
1006
|
+
# in the case of scalar data) and dtype is taken from the data type of the data entry
|
|
1007
|
+
# compression is on by default for datasets that will be more than 1MiB
|
|
1008
|
+
|
|
1009
|
+
# a mapping of data set name to (per-segment shape, data type) tuples
|
|
1010
|
+
dsets = {}
|
|
1011
|
+
|
|
1012
|
+
# First we scan for presence, shape, and data type of auxiliary data sets
|
|
1013
|
+
for segment in segments:
|
|
1014
|
+
if segment.data:
|
|
1015
|
+
for dsname in segment.data:
|
|
1016
|
+
if dsname.startswith('iterh5/'):
|
|
1017
|
+
continue
|
|
1018
|
+
data = np.asarray(segment.data[dsname], order='C')
|
|
1019
|
+
segment.data[dsname] = data
|
|
1020
|
+
dsets[dsname] = (data.shape, data.dtype)
|
|
1021
|
+
|
|
1022
|
+
# Then we iterate over data sets and store data
|
|
1023
|
+
if dsets:
|
|
1024
|
+
for dsname, (shape, dtype) in dsets.items():
|
|
1025
|
+
# dset = self._require_aux_dataset(iter_group, dsname, n_total_segments, shape, dtype)
|
|
1026
|
+
try:
|
|
1027
|
+
dsopts = self.dataset_options[dsname]
|
|
1028
|
+
except KeyError:
|
|
1029
|
+
dsopts = normalize_dataset_options({'name': dsname}, path_prefix='auxdata')
|
|
1030
|
+
|
|
1031
|
+
shape = (n_total_segments,) + shape
|
|
1032
|
+
dset = require_dataset_from_dsopts(
|
|
1033
|
+
iter_group, dsopts, shape, dtype, autocompress_threshold=self.aux_compression_threshold, n_iter=n_iter
|
|
1034
|
+
)
|
|
1035
|
+
if dset is None:
|
|
1036
|
+
# storage is suppressed
|
|
1037
|
+
continue
|
|
1038
|
+
for segment in segments:
|
|
1039
|
+
try:
|
|
1040
|
+
auxdataset = segment.data[dsname]
|
|
1041
|
+
except KeyError:
|
|
1042
|
+
pass
|
|
1043
|
+
else:
|
|
1044
|
+
source_rank = len(auxdataset.shape)
|
|
1045
|
+
source_sel = h5s.create_simple(auxdataset.shape, (h5s.UNLIMITED,) * source_rank)
|
|
1046
|
+
source_sel.select_all()
|
|
1047
|
+
dest_sel = dset.id.get_space()
|
|
1048
|
+
dest_sel.select_hyperslab((segment.seg_id,) + (0,) * source_rank, (1,) + auxdataset.shape)
|
|
1049
|
+
dset.id.write(source_sel, dest_sel, auxdataset)
|
|
1050
|
+
if 'delram' in list(dsopts.keys()):
|
|
1051
|
+
del dsets[dsname]
|
|
1052
|
+
|
|
1053
|
+
self.update_iter_h5file(n_iter, segments)
|
|
1054
|
+
|
|
1055
|
+
    def get_segments(self, n_iter=None, seg_ids=None, load_pcoords=True):
        '''Return the given (or all) segments from a given iteration.

        If the optional parameter ``load_auxdata`` is true, then all auxiliary datasets
        available are loaded and mapped onto the ``data`` dictionary of each segment. If
        ``load_auxdata`` is None, then use the default ``self.auto_load_auxdata``, which can
        be set by the option ``load_auxdata`` in the ``[data]`` section of ``west.cfg``. This
        essentially requires as much RAM as there is per-iteration auxiliary data, so this
        behavior is not on by default.'''

        n_iter = n_iter or self.current_iteration
        file_version = self.we_h5file_version

        with self.lock:
            iter_group = self.get_iter_group(n_iter)
            seg_index_ds = iter_group['seg_index']

            if file_version < 5:
                all_parent_ids = iter_group['parents'][...]
            else:
                all_parent_ids = iter_group['wtgraph'][...]

            if seg_ids is not None:
                seg_ids = list(sorted(seg_ids))
                seg_index_entries = seg_index_ds[seg_ids]
                if load_pcoords:
                    pcoord_entries = iter_group['pcoord'][seg_ids]
            else:
                seg_ids = list(range(len(seg_index_ds)))
                seg_index_entries = seg_index_ds[...]
                if load_pcoords:
                    pcoord_entries = iter_group['pcoord'][...]

            segments = []

            for iseg, (seg_id, row) in enumerate(zip(seg_ids, seg_index_entries)):
                segment = Segment(
                    seg_id=seg_id,
                    n_iter=n_iter,
                    status=int(row['status']),
                    endpoint_type=int(row['endpoint_type']),
                    walltime=float(row['walltime']),
                    cputime=float(row['cputime']),
                    weight=float(row['weight']),
                )

                if load_pcoords:
                    segment.pcoord = pcoord_entries[iseg]

                if file_version < 5:
                    wtg_n_parents = row['n_parents']
                    wtg_offset = row['parents_offset']
                    wtg_parent_ids = all_parent_ids[wtg_offset : wtg_offset + wtg_n_parents]
                    segment.parent_id = int(wtg_parent_ids[0])
                else:
                    wtg_n_parents = row['wtg_n_parents']
                    wtg_offset = row['wtg_offset']
                    wtg_parent_ids = all_parent_ids[wtg_offset : wtg_offset + wtg_n_parents]
                    segment.parent_id = int(row['parent_id'])
                segment.wtg_parent_ids = set(map(int, wtg_parent_ids))
                assert len(segment.wtg_parent_ids) == wtg_n_parents
                segments.append(segment)
            del all_parent_ids
            if load_pcoords:
                del pcoord_entries

            # If any other data sets are requested, load them as well
            for dsinfo in self.dataset_options.values():
                if dsinfo.get('load', False):
                    dsname = dsinfo['name']
                    try:
                        ds = iter_group[dsinfo['h5path']]
                    except KeyError:
                        ds = None

                    if ds is not None:
                        for segment in segments:
                            seg_id = segment.seg_id
                            segment.data[dsname] = ds[seg_id]

        return segments

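    # --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
    # Typical read-side use of get_segments(): pull a few segments from one iteration and
    # inspect their weights and progress coordinates. ``data_manager`` is assumed to be an
    # already-opened data manager instance; the iteration number is arbitrary.
    #
    #     segs = data_manager.get_segments(n_iter=10, seg_ids=[0, 1, 2], load_pcoords=True)
    #     for seg in segs:
    #         print(seg.seg_id, seg.parent_id, seg.weight, seg.pcoord.shape)
    #
    # Passing seg_ids=None returns every segment of the iteration; load_pcoords=False skips
    # the (potentially large) pcoord read and leaves segment.pcoord unset.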
    def prepare_segment_restarts(self, segments, basis_states=None, initial_states=None):
        '''Prepare the necessary folder and files given the data stored in parent per-iteration HDF5 file
        for propagating the simulation. ``basis_states`` and ``initial_states`` should be provided if the
        segments are newly created'''

        if not self.store_h5:
            return

        for segment in segments:
            if segment.parent_id < 0:
                if initial_states is None or basis_states is None:
                    raise ValueError('initial and basis states required for preparing the segments')
                initial_state = initial_states[segment.initial_state_id]
                # Check if it's a start state
                if initial_state.istate_type == InitialState.ISTATE_TYPE_START:
                    log.debug(
                        f'Skip reading start state file from per-iteration HDF5 file for initial state {segment.initial_state_id}'
                    )
                    continue
                else:
                    basis_state = basis_states[initial_state.basis_state_id]

                parent = Segment(n_iter=0, seg_id=basis_state.state_id)
            else:
                parent = Segment(n_iter=segment.n_iter - 1, seg_id=segment.parent_id)

            try:
                parent_iter_ref_h5_file = makepath(self.iter_ref_h5_template, {'n_iter': parent.n_iter})

                with h5io.WESTIterationFile(parent_iter_ref_h5_file, 'r') as outf:
                    outf.read_restart(parent)

                segment.data['iterh5/restart'] = parent.data['iterh5/restart']
            except Exception as e:
                print('could not prepare restart data for segment {}/{}: {}'.format(segment.n_iter, segment.seg_id, str(e)))

    def get_all_parent_ids(self, n_iter):
        file_version = self.we_h5file_version
        with self.lock:
            iter_group = self.get_iter_group(n_iter)
            seg_index = iter_group['seg_index']

            if file_version < 5:
                offsets = seg_index['parents_offset']
                all_parents = iter_group['parents'][...]
                return all_parents.take(offsets)
            else:
                return seg_index['parent_id']

    def get_parent_ids(self, n_iter, seg_ids=None):
        '''Return a sequence of the parent IDs of the given seg_ids.'''

        file_version = self.we_h5file_version

        with self.lock:
            iter_group = self.get_iter_group(n_iter)
            seg_index = iter_group['seg_index']

            if seg_ids is None:
                seg_ids = range(len(seg_index))

            if file_version < 5:
                offsets = seg_index['parents_offset']
                all_parents = iter_group['parents'][...]
                return [all_parents[offsets[seg_id]] for seg_id in seg_ids]
            else:
                all_parents = seg_index['parent_id']
                return [all_parents[seg_id] for seg_id in seg_ids]

    def get_weights(self, n_iter, seg_ids):
        '''Return the weights associated with the given seg_ids'''

        unique_ids = sorted(set(seg_ids))
        if not unique_ids:
            return []
        with self.lock:
            iter_group = self.get_iter_group(n_iter)
            index_subset = iter_group['seg_index'][unique_ids]
            weight_map = dict(zip(unique_ids, index_subset['weight']))
            return [weight_map[seg_id] for seg_id in seg_ids]

    def get_child_ids(self, n_iter, seg_id):
        '''Return the seg_ids of segments who have the given segment as a parent.'''

        with self.lock:
            if n_iter == self.current_iteration:
                return []

            iter_group = self.get_iter_group(n_iter + 1)
            seg_index = iter_group['seg_index']
            seg_ids = np.arange(len(seg_index), dtype=seg_id_dtype)

            if self.we_h5file_version < 5:
                offsets = seg_index['parents_offset']
                all_parent_ids = iter_group['parents'][...]
                parent_ids = np.array([all_parent_ids[offset] for offset in offsets])
            else:
                parent_ids = seg_index['parent_id']

            return seg_ids[parent_ids == seg_id]

    def get_children(self, segment):
        '''Return all segments which have the given segment as a parent'''

        if segment.n_iter == self.current_iteration:
            return []

        # Examine the segment index from the following iteration to see who has this segment
        # as a parent. We don't need to worry about the number of parents each segment
        # has, since each has at least one, and indexing on the offset into the parents array
        # gives the primary parent ID

        with self.lock:
            iter_group = self.get_iter_group(segment.n_iter + 1)
            seg_index = iter_group['seg_index'][...]

            # This is one of the slowest pieces of code I've ever written...
            # seg_index = iter_group['seg_index'][...]
            # seg_ids = [seg_id for (seg_id,row) in enumerate(seg_index)
            #            if all_parent_ids[row['parents_offset']] == segment.seg_id]
            # return self.get_segments_by_id(segment.n_iter+1, seg_ids)
            if self.we_h5file_version < 5:
                parents = iter_group['parents'][seg_index['parent_offsets']]
            else:
                parents = seg_index['parent_id']
            all_seg_ids = np.arange(seg_index.len(), dtype=np.uintp)
            seg_ids = all_seg_ids[parents == segment.seg_id]
            # the above will return a scalar if only one is found, so convert
            # to a list if necessary
            try:
                len(seg_ids)
            except TypeError:
                seg_ids = [seg_ids]

            return self.get_segments(segment.n_iter + 1, seg_ids)

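    # --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
    # The accessors above make it straightforward to walk the resampling tree stored in the
    # per-iteration 'seg_index' tables. A sketch of tracing one walker back through its
    # parents and then finding its children, assuming ``data_manager`` is an open data manager:
    #
    #     seg_id, n_iter = 3, 20
    #     while n_iter > 1:
    #         (parent_id,) = data_manager.get_parent_ids(n_iter, [seg_id])
    #         (weight,) = data_manager.get_weights(n_iter, [seg_id])
    #         print(n_iter, seg_id, parent_id, weight)
    #         if parent_id < 0:  # negative IDs mark initial/basis-state parents
    #             break
    #         seg_id, n_iter = int(parent_id), n_iter - 1
    #
    #     children = data_manager.get_child_ids(20, 3)  # seg_ids in iteration 21 descended from it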
    # The following are dictated by the SimManager interface
    def prepare_run(self):
        self.open_backing()

    def finalize_run(self):
        self.flush_backing()
        self.close_backing()

    def save_new_weight_data(self, n_iter, new_weights):
        '''Save a set of NewWeightEntry objects to HDF5. Note that this should
        be called for the iteration in which the weights appear in their
        new locations (e.g. for recycled walkers, the iteration following
        recycling).'''

        if not new_weights:
            return

        system = self.system

        index = np.empty(len(new_weights), dtype=nw_index_dtype)
        prev_init_pcoords = system.new_pcoord_array(len(new_weights))
        prev_final_pcoords = system.new_pcoord_array(len(new_weights))
        new_init_pcoords = system.new_pcoord_array(len(new_weights))

        for ientry, nwentry in enumerate(new_weights):
            row = index[ientry]
            row['source_type'] = nwentry.source_type
            row['weight'] = nwentry.weight
            row['prev_seg_id'] = nwentry.prev_seg_id
            # the following use -1 as a sentinel for a missing value
            row['target_state_id'] = nwentry.target_state_id if nwentry.target_state_id is not None else -1
            row['initial_state_id'] = nwentry.initial_state_id if nwentry.initial_state_id is not None else -1

            index[ientry] = row

            if nwentry.prev_init_pcoord is not None:
                prev_init_pcoords[ientry] = nwentry.prev_init_pcoord

            if nwentry.prev_final_pcoord is not None:
                prev_final_pcoords[ientry] = nwentry.prev_final_pcoord

            if nwentry.new_init_pcoord is not None:
                new_init_pcoords[ientry] = nwentry.new_init_pcoord

        with self.lock:
            iter_group = self.get_iter_group(n_iter)
            try:
                del iter_group['new_weights']
            except KeyError:
                pass

            nwgroup = iter_group.create_group('new_weights')
            nwgroup['index'] = index
            nwgroup['prev_init_pcoord'] = prev_init_pcoords
            nwgroup['prev_final_pcoord'] = prev_final_pcoords
            nwgroup['new_init_pcoord'] = new_init_pcoords

    def get_new_weight_data(self, n_iter):
        with self.lock:
            iter_group = self.get_iter_group(n_iter)

            try:
                nwgroup = iter_group['new_weights']
            except KeyError:
                return []

            try:
                index = nwgroup['index'][...]
                prev_init_pcoords = nwgroup['prev_init_pcoord'][...]
                prev_final_pcoords = nwgroup['prev_final_pcoord'][...]
                new_init_pcoords = nwgroup['new_init_pcoord'][...]
            except (KeyError, ValueError):  # zero-length selections raise ValueError
                return []

            entries = []
            for i in range(len(index)):
                irow = index[i]

                prev_seg_id = irow['prev_seg_id']
                if prev_seg_id == -1:
                    prev_seg_id = None

                initial_state_id = irow['initial_state_id']
                if initial_state_id == -1:
                    initial_state_id = None

                target_state_id = irow['target_state_id']
                if target_state_id == -1:
                    target_state_id = None

                entry = NewWeightEntry(
                    source_type=irow['source_type'],
                    weight=irow['weight'],
                    prev_seg_id=prev_seg_id,
                    prev_init_pcoord=prev_init_pcoords[i].copy(),
                    prev_final_pcoord=prev_final_pcoords[i].copy(),
                    new_init_pcoord=new_init_pcoords[i].copy(),
                    target_state_id=target_state_id,
                    initial_state_id=initial_state_id,
                )

                entries.append(entry)
            return entries

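    # --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
    # save_new_weight_data() and get_new_weight_data() round-trip NewWeightEntry records
    # through the 'new_weights' group, with -1 in the HDF5 index acting as the sentinel for
    # "no such state"; the reader converts those sentinels back to None. A sketch of
    # consuming the records, with ``data_manager`` assumed to be an open data manager:
    #
    #     for entry in data_manager.get_new_weight_data(n_iter=21):
    #         if entry.target_state_id is not None:  # -1 sentinel became None on read
    #             print('recycled walker weight:', entry.weight)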
    def find_bin_mapper(self, hashval):
        '''Check to see if the given hash value is in the binning table. Returns the index in the
        bin data tables if found, or raises KeyError if not.'''

        try:
            hashval = hashval.hexdigest()
        except AttributeError:
            pass

        with self.lock:
            # these will raise KeyError if the group doesn't exist, which also means
            # that bin data is not available, so no special treatment here
            try:
                binning_group = self.we_h5file['/bin_topologies']
                index = binning_group['index']
            except KeyError:
                raise KeyError('hash {} not found'.format(hashval))

            n_entries = len(index)
            if n_entries == 0:
                raise KeyError('hash {} not found'.format(hashval))

            chunksize = self.table_scan_chunksize
            for istart in range(0, n_entries, chunksize):
                chunk = index[istart : min(istart + chunksize, n_entries)]
                for i in range(len(chunk)):
                    if chunk[i]['hash'] == bytes(hashval, 'utf-8'):
                        return istart + i

            raise KeyError('hash {} not found'.format(hashval))

    def get_bin_mapper(self, hashval):
        '''Look up the given hash value in the binning table, unpickling and returning the corresponding
        bin mapper if available, or raising KeyError if not.'''

        # Convert to a hex digest if we need to
        try:
            hashval = hashval.hexdigest()
        except AttributeError:
            pass

        with self.lock:
            # these will raise KeyError if the group doesn't exist, which also means
            # that bin data is not available, so no special treatment here
            try:
                binning_group = self.we_h5file['/bin_topologies']
                index = binning_group['index']
                pkl = binning_group['pickles']
            except KeyError:
                raise KeyError('hash {} not found. Could not retrieve binning group'.format(hashval))

            n_entries = len(index)
            if n_entries == 0:
                raise KeyError('hash {} not found. No entries in index'.format(hashval))

            chunksize = self.table_scan_chunksize

            for istart in range(0, n_entries, chunksize):
                chunk = index[istart : min(istart + chunksize, n_entries)]
                for i in range(len(chunk)):
                    if chunk[i]['hash'] == bytes(hashval, 'utf-8'):
                        pkldat = bytes(pkl[istart + i, 0 : chunk[i]['pickle_len']].data)
                        mapper = pickle.loads(pkldat)
                        log.debug('loaded {!r} from {!r}'.format(mapper, binning_group))
                        log.debug('hash value {!r}'.format(hashval))
                        return mapper

            raise KeyError('hash {} not found'.format(hashval))

    def save_bin_mapper(self, hashval, pickle_data):
        '''Store the given mapper in the table of saved mappers. If the mapper cannot be stored,
        PickleError will be raised. Returns the index in the bin data tables where the mapper is stored.'''

        try:
            hashval = hashval.hexdigest()
        except AttributeError:
            pass
        pickle_data = bytes(pickle_data)

        # First, scan to see if the mapper already is in the HDF5 file
        try:
            return self.find_bin_mapper(hashval)
        except KeyError:
            pass

        # At this point, we have a valid pickle and know it's not stored
        with self.lock:
            binning_group = self.we_h5file.require_group('/bin_topologies')

            try:
                index = binning_group['index']
                pickle_ds = binning_group['pickles']
            except KeyError:
                index = binning_group.create_dataset('index', shape=(1,), maxshape=(None,), dtype=binning_index_dtype)
                pickle_ds = binning_group.create_dataset(
                    'pickles',
                    dtype=np.uint8,
                    shape=(1, len(pickle_data)),
                    maxshape=(None, None),
                    chunks=(1, 4096),
                    compression='gzip',
                    compression_opts=9,
                )
                n_entries = 1
            else:
                n_entries = len(index) + 1
                index.resize((n_entries,))
                new_hsize = max(pickle_ds.shape[1], len(pickle_data))
                pickle_ds.resize((n_entries, new_hsize))

            index_row = index[n_entries - 1]
            index_row['hash'] = hashval
            index_row['pickle_len'] = len(pickle_data)
            index[n_entries - 1] = index_row
            pickle_ds[n_entries - 1, : len(pickle_data)] = memoryview(pickle_data)
            return n_entries - 1

    def save_iter_binning(self, n_iter, hashval, pickled_mapper, target_counts):
        '''Save information about the binning used to generate segments for iteration n_iter.'''

        with self.lock:
            iter_group = self.get_iter_group(n_iter)

            try:
                del iter_group['bin_target_counts']
            except KeyError:
                pass

            iter_group['bin_target_counts'] = target_counts

            if hashval and pickled_mapper:
                self.save_bin_mapper(hashval, pickled_mapper)
                iter_group.attrs['binhash'] = hashval
            else:
                iter_group.attrs['binhash'] = ''


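    # --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
    # Bin mappers are stored in /bin_topologies as (hash, pickle) pairs: save_bin_mapper()
    # appends a pickled mapper keyed by its hash, find_bin_mapper() returns its row index,
    # and get_bin_mapper() unpickles it again; save_iter_binning() records the hash used for
    # an iteration in the 'binhash' attribute. A sketch of recovering an iteration's mapper,
    # assuming ``data_manager`` is open and the attribute reads back as a string:
    #
    #     iter_group = data_manager.get_iter_group(20)
    #     hashval = iter_group.attrs['binhash']          # written by save_iter_binning()
    #     mapper = data_manager.get_bin_mapper(hashval)  # unpickles the stored mapper
    #     idx = data_manager.find_bin_mapper(hashval)    # row in /bin_topologies/index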
def normalize_dataset_options(dsopts, path_prefix='', n_iter=0):
    dsopts = dict(dsopts)

    ds_name = dsopts['name']
    if path_prefix:
        default_h5path = '{}/{}'.format(path_prefix, ds_name)
    else:
        default_h5path = ds_name

    dsopts.setdefault('h5path', default_h5path)
    dtype = dsopts.get('dtype')
    if dtype:
        if isinstance(dtype, str):
            try:
                dsopts['dtype'] = np.dtype(getattr(np, dtype))
            except AttributeError:
                dsopts['dtype'] = np.dtype(getattr(builtins, dtype))
        else:
            dsopts['dtype'] = np.dtype(dtype)

    dsopts['store'] = bool(dsopts['store']) if 'store' in dsopts else True
    dsopts['load'] = bool(dsopts['load']) if 'load' in dsopts else False

    return dsopts


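# --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
# normalize_dataset_options() fills in defaults for a dataset-options block such as the ones
# read from west.cfg: it derives 'h5path' from the name (optionally under a prefix), coerces
# a string 'dtype' through numpy or builtins, and defaults 'store' to True and 'load' to False.
# A self-contained sketch:
#
#     opts = normalize_dataset_options({'name': 'rmsd', 'dtype': 'float32'}, path_prefix='auxdata')
#     # opts == {'name': 'rmsd', 'dtype': np.dtype('float32'), 'h5path': 'auxdata/rmsd',
#     #          'store': True, 'load': False}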
def create_dataset_from_dsopts(group, dsopts, shape=None, dtype=None, data=None, autocompress_threshold=None, n_iter=None):
    # log.debug('create_dataset_from_dsopts(group={!r}, dsopts={!r}, shape={!r}, dtype={!r}, data={!r}, autocompress_threshold={!r})'
    #           .format(group,dsopts,shape,dtype,data,autocompress_threshold))
    if not dsopts.get('store', True):
        return None

    if 'file' in list(dsopts.keys()):
        import h5py

        # dsopts['file'] = str(dsopts['file']).format(n_iter=n_iter)
        h5_auxfile = h5io.WESTPAH5File(dsopts['file'].format(n_iter=n_iter))
        h5group = group
        if not ("iter_" + str(n_iter).zfill(8)) in h5_auxfile:
            h5_auxfile.create_group("iter_" + str(n_iter).zfill(8))
        group = h5_auxfile[('/' + "iter_" + str(n_iter).zfill(8))]

    h5path = dsopts['h5path']
    containing_group_name = posixpath.dirname(h5path)
    h5_dsname = posixpath.basename(h5path)

    # ensure arguments are sane
    if not shape and data is None:
        raise ValueError('either shape or data must be provided')
    elif data is None and (shape and dtype is None):
        raise ValueError('both shape and dtype must be provided when data is not provided')
    elif shape and data is not None and not data.shape == shape:
        raise ValueError('explicit shape {!r} does not match data shape {!r}'.format(shape, data.shape))

    if data is not None:
        shape = data.shape
        if dtype is None:
            dtype = data.dtype
    # end argument sanity checks

    # figure out where to store this data
    if containing_group_name:
        containing_group = group.require_group(containing_group_name)
    else:
        containing_group = group

    # has user requested an explicit data type?
    # the extra np.dtype is an idempotent operation on true dtype
    # objects, but ensures that things like np.float32, which are
    # actually NOT dtype objects, become dtype objects
    h5_dtype = np.dtype(dsopts.get('dtype', dtype))

    compression = None
    scaleoffset = None
    shuffle = False

    # compress if 1) explicitly requested, or 2) dataset size exceeds threshold and
    # compression not explicitly prohibited
    compression_directive = dsopts.get('compression')
    if compression_directive is None:
        # No directive
        nbytes = np.multiply.reduce(shape) * h5_dtype.itemsize
        if autocompress_threshold and nbytes > autocompress_threshold:
            compression = 9
    elif compression_directive == 0:  # includes False
        # Compression prohibited
        compression = None
    else:  # compression explicitly requested
        compression = compression_directive

    # Is scale/offset requested?
    scaleoffset = dsopts.get('scaleoffset', None)
    if scaleoffset is not None:
        scaleoffset = int(scaleoffset)

    # We always shuffle if we compress (losslessly)
    if compression:
        shuffle = True
    else:
        shuffle = False

    need_chunks = any([compression, scaleoffset is not None, shuffle])

    # We use user-provided chunks if available
    chunks_directive = dsopts.get('chunks')
    if chunks_directive is None:
        chunks = None
    elif chunks_directive is True:
        chunks = calc_chunksize(shape, h5_dtype)
    elif chunks_directive is False:
        chunks = None
    else:
        chunks = tuple(chunks_directive[i] if chunks_directive[i] <= shape[i] else shape[i] for i in range(len(shape)))

    if not chunks and need_chunks:
        chunks = calc_chunksize(shape, h5_dtype)

    opts = {'shape': shape, 'dtype': h5_dtype, 'compression': compression, 'shuffle': shuffle, 'chunks': chunks}

    try:
        import h5py._hl.filters

        h5py._hl.filters._COMP_FILTERS['scaleoffset']
    except (ImportError, KeyError, AttributeError):
        # filter not available, or an unexpected version of h5py
        # use lossless compression instead
        opts['compression'] = True
    else:
        opts['scaleoffset'] = scaleoffset

    if log.isEnabledFor(logging.DEBUG):
        log.debug('requiring aux dataset {!r}, shape={!r}, opts={!r}'.format(h5_dsname, shape, opts))

    dset = containing_group.require_dataset(h5_dsname, **opts)

    if data is not None:
        dset[...] = data

    if 'file' in list(dsopts.keys()):
        import h5py

        if not dsopts['h5path'] in h5group:
            h5group[dsopts['h5path']] = h5py.ExternalLink(
                dsopts['file'].format(n_iter=n_iter), ("/" + "iter_" + str(n_iter).zfill(8) + "/" + dsopts['h5path'])
            )

    return dset


def require_dataset_from_dsopts(group, dsopts, shape=None, dtype=None, data=None, autocompress_threshold=None, n_iter=None):
    if not dsopts.get('store', True):
        return None
    try:
        return group[dsopts['h5path']]
    except KeyError:
        return create_dataset_from_dsopts(
            group, dsopts, shape=shape, dtype=dtype, data=data, autocompress_threshold=autocompress_threshold, n_iter=n_iter
        )


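# --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
# create_dataset_from_dsopts() maps a normalized options dict onto h5py dataset keywords:
# compression is turned on automatically once the dataset would exceed autocompress_threshold
# bytes (unless 'compression' is explicitly 0/False), shuffling accompanies compression, and
# chunking is computed via calc_chunksize() whenever any filter requires chunked storage.
# A minimal sketch against a throwaway file; the file name and options are assumptions:
#
#     import h5py, numpy as np
#     dsopts = normalize_dataset_options({'name': 'rmsd'}, path_prefix='auxdata')
#     with h5py.File('scratch.h5', 'w') as f:
#         dset = require_dataset_from_dsopts(
#             f, dsopts, shape=(1000, 50), dtype=np.float64, autocompress_threshold=1048576
#         )
#         # 1000*50*8 bytes < 1 MiB, so this dataset is left uncompressed and unchunked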
def calc_chunksize(shape, dtype, max_chunksize=262144):
    '''Calculate a chunk size for HDF5 data, anticipating that access will slice
    along lower dimensions sooner than higher dimensions.'''

    chunk_shape = list(shape)
    for idim in range(len(shape)):
        chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize
        while chunk_shape[idim] > 1 and chunk_nbytes > max_chunksize:
            chunk_shape[idim] >>= 1  # divide by 2
            chunk_nbytes = np.multiply.reduce(chunk_shape) * dtype.itemsize

        if chunk_nbytes <= max_chunksize:
            break

    chunk_shape = tuple(chunk_shape)
    log.debug(
        'selected chunk shape {} for data set of type {} shaped {} (chunk size = {} bytes)'.format(
            chunk_shape, dtype, shape, chunk_nbytes
        )
    )
    return chunk_shape
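# --- Editor's note: illustrative sketch, not part of the packaged data_manager.py. ---
# calc_chunksize() keeps trailing dimensions whole and repeatedly halves the leading
# dimension(s) until a chunk fits under max_chunksize bytes. A worked example with the
# default 256 KiB cap:
#
#     import numpy as np
#     calc_chunksize((1000, 50, 3), np.dtype(np.float64))
#     # 1000*50*3*8 = 1,200,000 bytes > 262,144, so the first axis is halved repeatedly:
#     # 1000 -> 500 -> 250 -> 125, giving (125, 50, 3) at 150,000 bytes, which is returned.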