ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CHAP/TaskManager.py +216 -0
- CHAP/__init__.py +27 -0
- CHAP/common/__init__.py +57 -0
- CHAP/common/models/__init__.py +8 -0
- CHAP/common/models/common.py +124 -0
- CHAP/common/models/integration.py +659 -0
- CHAP/common/models/map.py +1291 -0
- CHAP/common/processor.py +2869 -0
- CHAP/common/reader.py +658 -0
- CHAP/common/utils.py +110 -0
- CHAP/common/writer.py +730 -0
- CHAP/edd/__init__.py +23 -0
- CHAP/edd/models.py +876 -0
- CHAP/edd/processor.py +3069 -0
- CHAP/edd/reader.py +1023 -0
- CHAP/edd/select_material_params_gui.py +348 -0
- CHAP/edd/utils.py +1572 -0
- CHAP/edd/writer.py +26 -0
- CHAP/foxden/__init__.py +19 -0
- CHAP/foxden/models.py +71 -0
- CHAP/foxden/processor.py +124 -0
- CHAP/foxden/reader.py +224 -0
- CHAP/foxden/utils.py +80 -0
- CHAP/foxden/writer.py +168 -0
- CHAP/giwaxs/__init__.py +11 -0
- CHAP/giwaxs/models.py +491 -0
- CHAP/giwaxs/processor.py +776 -0
- CHAP/giwaxs/reader.py +8 -0
- CHAP/giwaxs/writer.py +8 -0
- CHAP/inference/__init__.py +7 -0
- CHAP/inference/processor.py +69 -0
- CHAP/inference/reader.py +8 -0
- CHAP/inference/writer.py +8 -0
- CHAP/models.py +227 -0
- CHAP/pipeline.py +479 -0
- CHAP/processor.py +125 -0
- CHAP/reader.py +124 -0
- CHAP/runner.py +277 -0
- CHAP/saxswaxs/__init__.py +7 -0
- CHAP/saxswaxs/processor.py +8 -0
- CHAP/saxswaxs/reader.py +8 -0
- CHAP/saxswaxs/writer.py +8 -0
- CHAP/server.py +125 -0
- CHAP/sin2psi/__init__.py +7 -0
- CHAP/sin2psi/processor.py +8 -0
- CHAP/sin2psi/reader.py +8 -0
- CHAP/sin2psi/writer.py +8 -0
- CHAP/tomo/__init__.py +15 -0
- CHAP/tomo/models.py +210 -0
- CHAP/tomo/processor.py +3862 -0
- CHAP/tomo/reader.py +9 -0
- CHAP/tomo/writer.py +59 -0
- CHAP/utils/__init__.py +6 -0
- CHAP/utils/converters.py +188 -0
- CHAP/utils/fit.py +2947 -0
- CHAP/utils/general.py +2655 -0
- CHAP/utils/material.py +274 -0
- CHAP/utils/models.py +595 -0
- CHAP/utils/parfile.py +224 -0
- CHAP/writer.py +122 -0
- MLaaS/__init__.py +0 -0
- MLaaS/ktrain.py +205 -0
- MLaaS/mnist_img.py +83 -0
- MLaaS/tfaas_client.py +371 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
CHAP/common/processor.py
ADDED
@@ -0,0 +1,2869 @@

```python
#!/usr/bin/env python
#-*- coding: utf-8 -*-
"""
File       : processor.py
Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
Description: Module for Processors used in multiple experiment-specific
             workflows.
"""

# System modules
from copy import deepcopy
import os
from typing import Optional

# Third party modules
import numpy as np
from pydantic import (
    Field,
    conint,
    field_validator,
)

# Local modules
from CHAP import Processor
from CHAP.common.models.map import (
    DetectorConfig,
    MapConfig,
)
```
```python
class AsyncProcessor(Processor):
    """A Processor to process multiple sets of input data via the
    asyncio module.

    :ivar mgr: The `Processor` used to process every set of input data.
    :type mgr: Processor
    """
    def __init__(self, mgr):
        super().__init__()
        self.mgr = mgr

    def process(self, data):
        """Asynchronously process the input documents with the
        `self.mgr` `Processor`.

        :param data: Input data documents to process.
        :type data: iterable
        """
        # System modules
        import asyncio

        async def task(mgr, doc):
            """Process given data using provided `Processor`.

            :param mgr: The object that will process given data.
            :type mgr: Processor
            :param doc: The data to process.
            :type doc: object
            :return: The processed data.
            :rtype: object
            """
            return mgr.process(doc)

        async def execute_tasks(mgr, docs):
            """Process given set of documents using provided task
            manager.

            :param mgr: The object that will process all documents.
            :type mgr: Processor
            :param docs: The set of data documents to process.
            :type docs: iterable
            """
            coroutines = [task(mgr, d) for d in docs]
            await asyncio.gather(*coroutines)

        asyncio.run(execute_tasks(self.mgr, data))
```
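Note that `task` calls `mgr.process(doc)` synchronously, so `asyncio.gather` here provides structured fan-out rather than true concurrency unless `process` itself yields. For orientation, a minimal usage sketch, assuming the usual re-export from `CHAP.common`; the `Echo` worker and document list are stand-ins invented for this example, not part of the package:

```python
# Hypothetical usage sketch for AsyncProcessor; `Echo` is illustrative only.
from CHAP import Processor
from CHAP.common import AsyncProcessor

class Echo(Processor):
    def process(self, data):
        return data

docs = [{'n': i} for i in range(4)]      # made-up input documents
AsyncProcessor(Echo()).process(docs)     # each doc goes through Echo.process
```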
```python
class BinarizeProcessor(Processor):
    """A Processor to binarize a dataset."""
    def process(self, data, config=None):
        """Plot and return a binarized dataset from a dataset contained
        in `data`. The dataset must either be `array-like` or a NeXus
        NXobject object with a default plottable data path or a
        specified path to a NeXus NXdata or NXfield object.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            CHAP.common.models.BinarizeProcessorConfig.
        :type config: dict, optional
        :return: The binarized dataset for an `array-like` input or
            a return type equal to that of the input object with the
            binarized dataset added.
        :rtype: Union[numpy.ndarray, nexusformat.nexus.NXobject]
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
            NXlink,
            nxsetconfig,
        )

        # Local modules
        from CHAP.utils.general import nxcopy

        nxsetconfig(memory=100000)

        # Load the validated binarize processor configuration
        if config is None:
            # Local modules
            from CHAP.common.models.common import BinarizeProcessorConfig

            config = BinarizeProcessorConfig()
        else:
            config = self.get_config(
                data, config=config,
                schema='common.models.BinarizeProcessorConfig')

        # Load the default data
        try:
            nxobject = self.get_data(data)
            if config.nxpath is None:
                dataset = nxobject.get_default()
            else:
                dataset = nxobject[config.nxpath]
            if isinstance(dataset, NXdata):
                nxsignal = dataset.nxsignal
                data = nxsignal.nxdata
            else:
                data = dataset.nxdata
            assert isinstance(data, np.ndarray)
        except Exception:
            try:
                dataset = self.unwrap_pipelinedata(data)[-1]
                assert isinstance(dataset, np.ndarray)
                data = dataset
            except Exception as exc:
                raise ValueError(
                    'Unable to load a valid input data object') from exc

        if config.method == 'yen':
            min_ = data.min()
            max_ = data.max()
            data = 1 + (config.num_bin - 1) * (data - min_) / (max_ - min_)

        # Get a histogram of the data
        counts, edges = np.histogram(data, bins=config.num_bin)
        centers = edges[:-1] + 0.5 * np.diff(edges)

        # Calculate the data cutoff threshold
        # pylint: disable=no-name-in-module
        if config.method == 'CHAP':
            weights = np.cumsum(counts)
            means = np.cumsum(counts * centers)
            weights = weights[0:-1] / weights[-1]
            means = means[0:-1] / means[-1]
            variances = (means-weights)**2 / (weights * (1. - weights))
            threshold = centers[np.argmax(variances)]
        elif config.method == 'otsu':
            # Third party modules
            from skimage.filters import threshold_otsu

            threshold = threshold_otsu(hist=(counts, centers))
        elif config.method == 'yen':
            # Third party modules
            from skimage.filters import threshold_yen

            threshold = threshold_yen(hist=(counts, centers))
        elif config.method == 'isodata':
            # Third party modules
            from skimage.filters import threshold_isodata

            threshold = threshold_isodata(hist=(counts, centers))
        else:
            # Third party modules
            from skimage.filters import threshold_minimum

            threshold = threshold_minimum(hist=(counts, centers))
        # pylint: enable=no-name-in-module

        # Apply the data cutoff threshold
        data = np.where(data < threshold, 0, 1).astype(np.ubyte)

        # Return the output for array-like or NeXus NXfield inputs
        if isinstance(dataset, np.ndarray):
            return data
        if isinstance(dataset, NXfield):
            attrs = dataset.attrs
            attrs.pop('target', None)
            nxfield = NXfield(
                value=data, name=f'{dataset.nxname}_binarized', attrs=attrs)
            return nxfield

        # Otherwise create a copy of the input NeXus object, add the
        # binarized data to the copied original dataset, and remove the
        # original dataset if config.remove_original_data is set
        name = f'{nxsignal.nxname}_binarized'
        nxdefault = nxobject.get_default()
        if isinstance(nxsignal, NXlink):
            link = dataset.nxpath
            path = os.path.split(nxsignal.nxtarget)[0]
        else:
            link = nxdefault.nxpath
            path = os.path.split(nxsignal.nxpath)[0]
        exclude_nxpaths = []
        if config.remove_original_data:
            if link is not None:
                exclude_nxpaths.append(os.path.relpath(
                    f'{link}/{nxsignal.nxname}', nxobject.nxpath))
            exclude_nxpaths.append(os.path.relpath(
                f'{path}/{nxsignal.nxname}', nxobject.nxpath))
        nxobject = nxcopy(nxobject, exclude_nxpaths=exclude_nxpaths)
        attrs = nxsignal.attrs
        attrs.pop('target', None)
        nxobject[f'{path}/{name}'] = NXfield(
            value=data, name=name, attrs=attrs)
        nxobject[path].attrs['signal'] = name
        if link is not None:
            nxobject[f'{link}/{name}'] = NXlink(f'{path}/{name}')
            nxobject[link].attrs['signal'] = name

        return nxobject
```
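The `'CHAP'` branch above maximizes the between-class variance over candidate thresholds, the same criterion Otsu's method uses (up to a constant factor), computed directly from cumulative histogram sums. A self-contained numeric sketch on a synthetic bimodal sample; all names here are local to the example:

```python
import numpy as np

# Synthetic bimodal sample: two Gaussian modes, illustration only
rng = np.random.default_rng(0)
sample = np.concatenate([rng.normal(2, 0.5, 5000), rng.normal(8, 0.5, 5000)])

counts, edges = np.histogram(sample, bins=256)
centers = edges[:-1] + 0.5 * np.diff(edges)

# Same arithmetic as the 'CHAP' branch
weights = np.cumsum(counts)
means = np.cumsum(counts * centers)
weights = weights[:-1] / weights[-1]          # class weight omega(t)
means = means[:-1] / means[-1]                # normalized class mean mu(t)
variances = (means - weights)**2 / (weights * (1. - weights))
threshold = centers[np.argmax(variances)]
print(threshold)                              # lands between the modes, near 5
```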
```python
class ConstructBaseline(Processor):
    """A Processor to construct a baseline for a dataset."""
    def process(
            self, data, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
            save_figures=False):
        """Construct and return the baseline for a dataset.

        :param data: Input data.
        :type data: list[PipelineData]
        :param x: Independent dimension (only used when running
            interactively or when filename is set).
        :type x: array-like, optional
        :param mask: A mask to apply to the spectrum before baseline
            construction.
        :type mask: array-like, optional
        :param tol: The convergence tolerance, defaults to `1.e-6`.
        :type tol: float, optional
        :param lam: The λ (smoothness) parameter (the balance
            between the residual of the data and the baseline and the
            smoothness of the baseline). The suggested range is between
            100 and 10^8, defaults to `10^6`.
        :type lam: float, optional
        :param max_iter: The maximum number of iterations,
            defaults to `20`.
        :type max_iter: int, optional
        :param save_figures: Save .pngs of plots for checking inputs &
            outputs of this Processor, defaults to `False`.
        :type save_figures: bool, optional
        :return: The smoothed baseline, the configuration, and a byte
            stream representation of the Matplotlib figure if
            save_figures is `True` (`None` otherwise).
        :rtype: numpy.array, dict, Union[io.BytesIO, None]
        """
        try:
            data = np.asarray(self.unwrap_pipelinedata(data)[0])
        except Exception as exc:
            raise ValueError(
                f'The structure of {data} contains no valid data') from exc

        return self.construct_baseline(
            data, x=x, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
            return_buf=save_figures)

    @staticmethod
    def construct_baseline(
            y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20, title=None,
            xlabel=None, ylabel=None, interactive=False, return_buf=False):
        """Construct and return the baseline for a dataset.

        :param y: Input data.
        :type y: numpy.array
        :param x: Independent dimension (only used when interactive is
            `True` or when filename is set).
        :type x: array-like, optional
        :param mask: A mask to apply to the spectrum before baseline
            construction.
        :type mask: array-like, optional
        :param tol: The convergence tolerance, defaults to `1.e-6`.
        :type tol: float, optional
        :param lam: The λ (smoothness) parameter (the balance
            between the residual of the data and the baseline and the
            smoothness of the baseline). The suggested range is between
            100 and 10^8, defaults to `10^6`.
        :type lam: float, optional
        :param max_iter: The maximum number of iterations,
            defaults to `20`.
        :type max_iter: int, optional
        :param title: Title for the displayed figure.
        :type title: str, optional
        :param xlabel: Label for the x-axis of the displayed figure.
        :type xlabel: str, optional
        :param ylabel: Label for the y-axis of the displayed figure.
        :type ylabel: str, optional
        :param interactive: Allows for user interactions,
            defaults to `False`.
        :type interactive: bool, optional
        :param return_buf: Return an in-memory object as a byte stream
            representation of the Matplotlib figure, defaults to `False`.
        :type return_buf: bool, optional
        :return: The smoothed baseline and the configuration and a
            byte stream representation of the Matplotlib figure if
            return_buf is `True` (`None` otherwise).
        :rtype: numpy.array, dict, Union[io.BytesIO, None]
        """
        # Third party modules
        from matplotlib.widgets import TextBox, Button
        import matplotlib.pyplot as plt

        # Local modules
        from CHAP.utils.general import (
            baseline_arPLS,
            fig_to_iobuf,
        )

        def change_fig_subtitle(maxed_out=False, subtitle=None):
            """Change the figure's subtitle."""
            if fig_subtitles:
                fig_subtitles[0].remove()
                fig_subtitles.pop()
            if subtitle is None:
                subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
                if maxed_out:
                    subtitle += f'# iter = {num_iters[-1]} (maxed out) '
                else:
                    subtitle += f'# iter = {num_iters[-1]} '
                subtitle += f'error = {errors[-1]:.2e}'
            fig_subtitles.append(
                plt.figtext(*subtitle_pos, subtitle, **subtitle_props))

        def select_lambda(expression):
            """Callback function for the "Select lambda" TextBox."""
            if not expression:
                return
            try:
                lam = float(expression)
                if lam < 0:
                    raise ValueError
            except ValueError:
                change_fig_subtitle(
                    subtitle='Invalid lambda, enter a positive number')
            else:
                lambdas.pop()
                lambdas.append(10**lam)
                baseline, _, _, num_iter, error = get_baseline(
                    y, mask=mask, tol=tol, lam=lambdas[-1], max_iter=max_iter)
                num_iters.pop()
                num_iters.append(num_iter)
                errors.pop()
                errors.append(error)
                if num_iter < max_iter:
                    change_fig_subtitle()
                else:
                    change_fig_subtitle(maxed_out=True)
                baseline_handle.set_ydata(baseline)
            lambda_box.set_val('')
            plt.draw()

        def continue_iter(event):
            """Callback function for the "Continue" button."""
            baseline, _, w, n_iter, error = get_baseline(
                y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
                max_iter=max_iter)
            num_iters[-1] += n_iter
            errors.pop()
            errors.append(error)
            if n_iter < max_iter:
                change_fig_subtitle()
            else:
                change_fig_subtitle(maxed_out=True)
            baseline_handle.set_ydata(baseline)
            plt.draw()
            weights.pop()
            weights.append(w)

        def confirm(event):
            """Callback function for the "Confirm" button."""
            plt.close()

        def get_baseline(
                y, mask=None, w=None, tol=1.e-6, lam=1.e6, max_iter=20):
            """Get a baseline."""
            return baseline_arPLS(
                y, mask=mask, w=w, tol=tol, lam=lam, max_iter=max_iter,
                full_output=True)

        baseline, _, w, num_iter, error = get_baseline(
            y, mask=mask, tol=tol, lam=lam, max_iter=max_iter)

        if not interactive and not return_buf:
            config = {
                'tol': tol, 'lambda': lam, 'max_iter': max_iter,
                'num_iter': num_iter, 'error': error, 'mask': mask}
            return baseline, config, None

        lambdas = [lam]
        weights = [w]
        num_iters = [num_iter]
        errors = [error]
        fig_subtitles = []

        # Check inputs
        if x is None:
            x = np.arange(y.size)

        # Setup the Matplotlib figure
        title_pos = (0.5, 0.95)
        title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                       'verticalalignment': 'bottom'}
        subtitle_pos = (0.5, 0.90)
        subtitle_props = {'fontsize': 'x-large',
                          'horizontalalignment': 'center',
                          'verticalalignment': 'bottom'}
        fig, ax = plt.subplots(figsize=(11, 8.5))
        if mask is None:
            ax.plot(x, y, label='input data')
        else:
            ax.plot(
                x[mask.astype(bool)], y[mask.astype(bool)], label='input data')
        baseline_handle = ax.plot(x, baseline, label='baseline')[0]
        # ax.plot(x, y-baseline, label='baseline corrected data')
        ax.legend()
        ax.set_xlabel(xlabel, fontsize='x-large')
        ax.set_ylabel(ylabel, fontsize='x-large')
        ax.set_xlim(x[0], x[-1])
        if title is None:
            fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
        else:
            fig_title = plt.figtext(*title_pos, title, **title_props)
        if num_iter < max_iter:
            change_fig_subtitle()
        else:
            change_fig_subtitle(maxed_out=True)
        fig.subplots_adjust(bottom=0.0, top=0.85)

        lambda_box = None
        if interactive:

            fig.subplots_adjust(bottom=0.2)

            # Setup TextBox
            lambda_box = TextBox(
                plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
            lambda_cid = lambda_box.on_submit(select_lambda)

            # Setup "Continue" button
            continue_btn = Button(
                plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
            continue_cid = continue_btn.on_clicked(continue_iter)

            # Setup "Confirm" button
            confirm_btn = Button(
                plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
            confirm_cid = confirm_btn.on_clicked(confirm)

            # Show figure for user interaction
            plt.show()

            # Disconnect all widget callbacks when figure is closed
            lambda_box.disconnect(lambda_cid)
            continue_btn.disconnect(continue_cid)
            confirm_btn.disconnect(confirm_cid)

            # ... and remove the buttons before returning the figure
            lambda_box.ax.remove()
            continue_btn.ax.remove()
            confirm_btn.ax.remove()

        if return_buf:
            fig_title.set_in_layout(True)
            fig_subtitles[-1].set_in_layout(True)
            fig.tight_layout(rect=(0, 0, 1, 0.90))
            buf = fig_to_iobuf(fig)
        else:
            buf = None
        plt.close()

        config = {
            'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
            'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
        return baseline, config, buf
```
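`baseline_arPLS` lives in `CHAP.utils.general` and is not shown in this diff. Below is a minimal sketch of the arPLS iteration (Baek et al., Analyst 140, 2015) that such a routine typically implements, assuming the standard algorithm; `arpls` here is illustrative, not the package API:

```python
import numpy as np
from scipy import sparse
from scipy.sparse.linalg import spsolve

def arpls(y, lam=1.e6, tol=1.e-6, max_iter=20):
    """Minimal arPLS baseline sketch (assumes some residuals fall below
    the smooth fit, as is generic for spectra with peaks)."""
    n = y.size
    # Second-order difference operator: penalizes baseline curvature
    D = sparse.diags([1., -2., 1.], [0, 1, 2], shape=(n - 2, n))
    H = lam * (D.T @ D)
    w = np.ones(n)
    for _ in range(max_iter):
        W = sparse.diags(w)
        z = spsolve(sparse.csc_matrix(W + H), w * y)  # penalized weighted LS
        d = y - z
        dn = d[d < 0]                 # residuals below the baseline
        m, s = dn.mean(), dn.std()
        w_new = 1. / (1. + np.exp(2. * (d - (2.*s - m)) / s))  # logistic reweight
        if np.linalg.norm(w - w_new) / np.linalg.norm(w) < tol:
            break
        w = w_new
    return z
```

The λ textbox in the interactive figure takes log10(λ), which is why `select_lambda` applies `10**lam` before refitting.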
```python
class ConvertStructuredProcessor(Processor):
    """Processor for converting map data between structured /
    unstructured formats.
    """
    def process(self, data):
        # Local modules
        from CHAP.utils.converters import convert_structured_unstructured

        data = self.unwrap_pipelinedata(data)[0]
        return convert_structured_unstructured(data)
```
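`convert_structured_unstructured` is defined in `CHAP.utils.converters` (not shown in this diff). As background, the round trip it suggests exists in NumPy itself:

```python
import numpy as np
from numpy.lib.recfunctions import (
    structured_to_unstructured,
    unstructured_to_structured,
)

# Structured (named fields) <-> plain 2D array
pts = np.array([(1.0, 2.0), (3.0, 4.0)], dtype=[('x', 'f8'), ('y', 'f8')])
plain = structured_to_unstructured(pts)              # [[1. 2.] [3. 4.]]
back = unstructured_to_structured(plain, pts.dtype)  # named fields restored
```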
```python
class ImageProcessor(Processor):
    """A Processor to perform various visualization operations on
    images (slices) selected from a NeXus object."""
    def __init__(self):
        super().__init__()
        self._figconfig = None

    def process(self, data, config=None, save_figures=True):
        """Plot and/or return image slices from a NeXus NXobject
        object with a default plottable data path.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            CHAP.common.models.ImageProcessorConfig.
        :type config: dict, optional
        :param save_figures: Return the plottable image(s) to be
            written to file downstream in the pipeline,
            defaults to `True`.
        :type save_figures: bool, optional
        :return: The plottable image(s) (for save_figures = `True`)
            or the input default NeXus NXdata object
            (for save_figures = `False`).
        :rtype: Union[bytes, nexusformat.nexus.NXdata, numpy.ndarray]
        """
        if not save_figures and not self.interactive:
            return None

        # Third party modules
        from nexusformat.nexus import nxsetconfig

        nxsetconfig(memory=100000)

        # Load the default data
        try:
            nxdata = self.get_data(data).get_default()
        except Exception as exc:
            raise ValueError(
                'Unable to load the default NXdata object from the input '
                f'pipeline ({data})') from exc

        # Load the validated image processor configuration
        if config is None:
            # Local modules
            from CHAP.common.models.common import ImageProcessorConfig

            config = ImageProcessorConfig()
        else:
            config = self.get_config(
                data, config=config,
                schema='common.models.ImageProcessorConfig')

        # Get the axes info and image slice(s)
        try:
            data = nxdata.nxsignal
        except Exception as exc:
            raise ValueError('Unable to find the default signal in:\n'
                             f'({nxdata.tree})') from exc
        axis = config.axis
        axes = nxdata.attrs.get('axes', None)
        if axes is not None:
            axes = list(axes.nxdata)
        if nxdata.nxsignal.ndim == 2:
            exit('ImageProcessor not tested yet for a 2D dataset')
            if axis is not None:
                axis = None
                self.logger.warning('Ignoring parameter axis')
            if index is not None:
                index = None
                self.logger.warning('Ignoring parameter index')
            if coord is not None:
                coord = None
                self.logger.warning('Ignoring parameter coord')
        elif nxdata.nxsignal.ndim == 3:
            if isinstance(axis, int):
                if not 0 <= axis < nxdata.nxsignal.ndim:
                    raise ValueError(f'axis index out of range ({axis} not in '
                                     f'[0, {nxdata.nxsignal.ndim-1}])')
            elif isinstance(axis, str):
                if axes is None or axis not in axes:
                    raise ValueError(
                        f'Unable to match axis = {axis} in {nxdata.tree}')
                axis = axes.index(axis)
            else:
                raise ValueError(f'Invalid parameter axis ({axis})')
            if axis:
                data = np.moveaxis(data, axis, 0)
            if axes is not None and hasattr(nxdata, axes[axis]):
                if axis == 1:
                    axes = [axes[1], axes[0], axes[2]]
                elif axis:
                    axes = [axes[2], axes[0], axes[1]]
                axis_name = axes[0]
                if 'units' in nxdata[axis_name].attrs:
                    axis_unit = f' ({nxdata[axis_name].units})'
                else:
                    axis_unit = ''
                row_label = axes[2]
                row_coords = nxdata[row_label].nxdata
                column_label = axes[1]
                column_coords = nxdata[column_label].nxdata
                if 'units' in nxdata[row_label].attrs:
                    row_label += f' ({nxdata[row_label].units})'
                if 'units' in nxdata[column_label].attrs:
                    column_label += f' ({nxdata[column_label].units})'
            else:
                exit('No axes attribute not tested yet')
                axes = [0, 1, 2]
                axes.pop(axis)
                axis_name = f'axis {axis}'
                axis_unit = ''
                # row_label = f'axis {axis[1]}'
                # row_coords = None
                # column_label = f'axis {axis[0]}'
                # column_coords = None
            axis_coords = nxdata[axis_name].nxdata
        else:
            raise ValueError('Invalid data dimension (must be 2D or 3D)')
        if config.coord_range is None:
            index_range = config.index_range
        else:
            # Local modules
            from CHAP.utils.general import (
                index_nearest_down,
                index_nearest_up,
            )

            if config.index_range is not None:
                self.logger.warning('Ignoring parameter index_range')
            if isinstance(config.coord_range, (int, float)):
                index_range = index_nearest_up(
                    axis_coords, config.coord_range)
            elif len(config.coord_range) == 2:
                index_range = [
                    index_nearest_up(axis_coords, config.coord_range[0]),
                    index_nearest_down(axis_coords, config.coord_range[1])]
            else:
                index_range = [
                    index_nearest_up(axis_coords, config.coord_range[0]),
                    index_nearest_down(axis_coords, config.coord_range[1]),
                    int(max(1, config.coord_range[2] /
                        ((axis_coords[-1]-axis_coords[0])/data.shape[0])))]
        if index_range == -1:
            index_range = nxdata.nxsignal.shape[axis] // 2
        if isinstance(index_range, int):
            data = data[index_range]
            axis_coords = [axis_coords[index_range]]
        elif index_range is not None:
            slice_ = slice(*tuple(index_range))
            data = data[slice_]
            axis_coords = axis_coords[slice_]
        if config.vrange is None:
            vrange = (data.min(), data.max())
        else:
            vrange = config.vrange

        # Create the figure configuration
        self._figconfig = {
            'title': f'{nxdata.nxpath}/{nxdata.signal}',
            'axis_name': axis_name,
            'axis_unit': axis_unit,
            'axis_coords': axis_coords,
            'row_label': row_label,
            'column_label': column_label,
            'extent': (row_coords[0], row_coords[-1],
                       column_coords[-1], column_coords[0]),
            'vrange': vrange,
        }
        self.logger.debug(f'figure configuration:\n{self._figconfig}')

        if len(axis_coords) == 1:
            # Create a figure for a single image slice
            if config.animation:
                self.logger.warning(
                    'Ignoring animation parameter for a single image')
            if config.fileformat is None:
                fileformat = 'png'
            else:
                fileformat = config.fileformat
            fig, plt = self._create_figure(np.squeeze(data))
            if self.interactive:
                plt.show()
            if save_figures:
                # Local modules
                from CHAP.utils.general import fig_to_iobuf

                # Return a binary image of the figure
                buf, fileformat = fig_to_iobuf(fig, fileformat=fileformat)
            else:
                buf = None
            plt.close()
            if save_figures:
                return {'image_data': buf, 'fileformat': fileformat}
            return nxdata

        # Create an animation for a set of image slices
        if self.interactive or config.animation:
            ani = self._create_animation(data)
        else:
            ani = None

        if save_figures:
            if config.animation:
                # Return the animation object
                if (config.fileformat is not None
                        and config.fileformat != 'gif'):
                    self.logger.warning(
                        'Ignoring inconsistent file extension')
                fileformat = 'gif'
                image_data = ani
            else:
                # Return the set of image slices as a tif stack
                if (config.fileformat is not None
                        and config.fileformat != 'tif'):
                    self.logger.warning(
                        'Ignoring inconsistent file extension')
                fileformat = 'tif'
                data = 255.0*((data - vrange[0])/
                              (vrange[1] - vrange[0]))
                image_data = data.astype(np.uint8)
            return {'image_data': image_data, 'fileformat': fileformat}
        return nxdata

    def _create_animation(self, data):
        # System modules
        from functools import partial

        # Third party modules
        from matplotlib import animation

        def animate(i, plt, title):
            im.set_array(data[i])
            title.set_text(self._set_title(i))
            plt.draw()
            return im,

        fig, im, plt, title = self._create_figure(data[0], animated=True)
        ani = animation.FuncAnimation(
            fig, partial(animate, plt=plt, title=title), frames=data.shape[0],
            interval=50, blit=True)
        if self.interactive:
            plt.show()
        plt.close()

        return ani

    def _create_figure(self, image, animated=False):
        # Third party modules
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        im = plt.imshow(
            image, extent=self._figconfig['extent'], origin='lower',
            vmin=self._figconfig['vrange'][0],
            vmax=self._figconfig['vrange'][1], cmap='gray', animated=animated)
        fig.suptitle(self._figconfig['title'], fontsize='x-large')
        title = ax.set_title(self._set_title(0), fontsize='x-large', pad=10)
        ax.set_xlabel(self._figconfig['row_label'], fontsize='x-large')
        ax.set_ylabel(self._figconfig['column_label'], fontsize='x-large')
        plt.colorbar()
        fig.tight_layout()
        if animated:
            return fig, im, plt, title
        return fig, plt

    def _set_title(self, i):
        return self._figconfig['axis_name'] +\
            f' = {self._figconfig["axis_coords"][i]:.3f}' +\
            self._figconfig['axis_unit']
```
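Worth noting from the tif-stack branch: the frames are rescaled linearly so `vrange` spans the full 8-bit range before the cast. The same arithmetic in isolation:

```python
import numpy as np

vmin, vmax = 10.0, 2000.0                       # example display range
frames = np.random.default_rng(1).uniform(vmin, vmax, (5, 64, 64))
scaled = (255.0 * (frames - vmin) / (vmax - vmin)).astype(np.uint8)
assert scaled.min() >= 0 and scaled.max() <= 255   # vmin -> 0, vmax -> 255
```

Values outside `vrange` would wrap on the uint8 cast; upstream, `vrange` defaults to the data's own min/max when unset.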
```python
class MapProcessor(Processor):
    """A Processor that takes a map configuration and returns a NeXus
    NXentry object representing that map's metadata and any
    scalar-valued raw data requested by the supplied map configuration.

    :ivar config: Map configuration parameters to initialize an
        instance of common.models.map.MapConfig. Any values in
        `'config'` supplant their corresponding values obtained from
        the pipeline data configuration.
    :type config: Union[dict, common.models.map.MapConfig]
    :ivar detector_config: Detector configurations of the detectors to
        include raw data for in the returned NeXus NXentry object
        (overruling detector info in the pipeline data, if present).
    :type detector_config: Union[
        dict, common.models.map.DetectorConfig]
    :ivar num_proc: Number of processors used to read map,
        defaults to `1`.
    :type num_proc: int, optional
    """
    pipeline_fields: dict = Field(
        default={
            'config': 'common.models.map.MapConfig',
            'detector_config': 'common.models.map.DetectorConfig'},
        init_var=True)
    config: MapConfig
    detector_config: DetectorConfig
    num_proc: Optional[conint(gt=0)] = 1

    @field_validator('num_proc')
    @classmethod
    def validate_num_proc(cls, num_proc, info):
        """Validate the number of processors.

        :param num_proc: Number of processors used to read map,
            defaults to `1`.
        :type num_proc: int, optional
        :return: Number of processors.
        :rtype: int
        """
        if num_proc > 1:
            logger = info['logger']
            try:
                # Third party modules
                from mpi4py import MPI

                if num_proc > os.cpu_count():
                    logger.warning(
                        f'The requested number of processors ({num_proc}) '
                        'exceeds the maximum number of processors '
                        f'({os.cpu_count()}): reset it to {os.cpu_count()}')
                    num_proc = os.cpu_count()
            except Exception:
                logger.warning('Unable to load mpi4py, running serially')
                num_proc = 1
            logger.debug(f'Number of processors: {num_proc}')
        return num_proc
```
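MapProcessor declares its fields as a pydantic model and clamps `num_proc` in a `field_validator`. A standalone toy with the same pydantic v2 pattern (not the CHAP class):

```python
import os
from pydantic import BaseModel, conint, field_validator

class Toy(BaseModel):
    num_proc: conint(gt=0) = 1

    @field_validator('num_proc')
    @classmethod
    def clamp(cls, v):
        # Same idea as validate_num_proc: never exceed available cores
        return min(v, os.cpu_count() or 1)

print(Toy(num_proc=10**6).num_proc)   # prints the machine's core count
```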
```python
    def process(
            self, data, placeholder_data=False, comm=None):

        return self._process(data, placeholder_data, comm)

    # @profile
    def _process(self, data, placeholder_data=False, comm=None):
        """Process that takes a map configuration and returns a NeXus
        NXentry object representing the map.

        :param data: Pipeline data list with an optional item for the
            map configuration parameters with
            `'common.models.map.MapConfig'` as its `'schema'` key.
        :type data: list[PipelineData]
        :param placeholder_data: For SMB EDD maps only. Value to use
            for missing detector data frames, or `False` if missing
            data should raise an error, defaults to `False`.
        :type placeholder_data: object, optional
        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :return: Map data and metadata.
        :rtype: nexusformat.nexus.NXentry
        """
        # System modules
        import logging

        # Third party modules
        import yaml

        # Check for available metadata
        metadata = {}
        if data:
            try:
                for d in data:
                    if d.get('schema') == 'metadata':
                        metadata = d.get('data')
                        break
            except Exception:
                pass
        if len(metadata) > 1:
            raise ValueError(
                'Unable to find unique data for schema "metadata"')
        if metadata:
            metadata = self._get_metadata_config(metadata[0])

        # Create the sub-pipeline configuration for each processor
        # FIX: catered to EDD with one spec scan
        assert len(self.config.spec_scans) == 1
        spec_scans = self.config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        if num_scan < self.num_proc:
            self.logger.warning(
                f'Requested number of processors ({self.num_proc}) exceeds '
                f'the number of scans ({num_scan}): reset it to {num_scan}')
            self.num_proc = num_scan
        if self.num_proc == 1:
            common_comm = comm
            offsets = [0]
        else:
            # System modules
            from tempfile import NamedTemporaryFile

            # Local modules
            from CHAP.models import RunConfig

            raise NotImplementedError(
                'MapProcessor needs testing for num_proc>1')
            scans_per_proc = num_scan//self.num_proc
            num = scans_per_proc
            if num_scan - scans_per_proc*self.num_proc > 0:
                num += 1
            spec_scans.scan_numbers = scan_numbers[:num]
            n_scan = num
            pipeline_config = []
            offsets = [0]
            for n_proc in range(1, self.num_proc):
                num = scans_per_proc
                if n_proc < num_scan - scans_per_proc*self.num_proc:
                    num += 1
                config = self.config.model_dump()
                config['spec_scans'][0]['scan_numbers'] = \
                    scan_numbers[n_scan:n_scan+num]
                pipeline_config.append(
                    [{'common.MapProcessor': {
                        'config': config,
                        'detector_config': self.detector_config.model_dump(),
                    }}])
                offsets.append(n_scan)
                n_scan += num

            # Spawn the workers to run the sub-pipeline
            run_config = RunConfig(
                log_level=logging.getLevelName(self.logger.level), spawn=1)
            tmp_names = []
            with NamedTemporaryFile(delete=False) as fp:
                # pylint: disable=c-extension-no-member
                fp_name = fp.name
            tmp_names.append(fp_name)
            with open(fp_name, 'w') as f:
                yaml.dump({'config': {'spawn': 1}}, f, sort_keys=False)
            for n_proc in range(1, self.num_proc):
                f_name = f'{fp_name}_{n_proc}'
                tmp_names.append(f_name)
                with open(f_name, 'w') as f:
                    yaml.dump(
                        # FIX once comm is a field of RunConfig
                        # {'config': run_config.model_dump(exclude='comm'),
                        {'config': run_config.model_dump(),
                         'pipeline': pipeline_config[n_proc-1]},
                        f, sort_keys=False)
            # pylint: disable=used-before-assignment
            sub_comm = MPI.COMM_SELF.Spawn(
                'CHAP', args=[fp_name], maxprocs=self.num_proc-1)
            common_comm = sub_comm.Merge(False)
            # Align with the barrier in RunConfig() on common_comm
            # called from the spawned main() in common_comm
            common_comm.barrier()
            # Align with the barrier in run() on common_comm
            # called from the spawned main()
            common_comm.barrier()

        if common_comm is None:
            self.num_proc = 1
            rank = 0
        else:
            self.num_proc = common_comm.Get_size()
            rank = common_comm.Get_rank()
        if self.num_proc == 1:
            offset = 0
        else:
            num_scan = common_comm.bcast(num_scan, root=0)
            offset = common_comm.scatter(offsets, root=0)

        # Read the raw data
        if self.config.experiment_type == 'EDD':
            data, independent_dimensions, all_scalar_data = \
                self._read_raw_data_edd(
                    common_comm, num_scan, offset, placeholder_data)
        else:
            data, independent_dimensions, all_scalar_data = \
                self._read_raw_data(common_comm, num_scan, offset)
        if not rank:
            self.logger.debug(f'Data shape: {data.shape}')
            if independent_dimensions is not None:
                self.logger.debug('Independent dimensions shape: '
                                  f'{independent_dimensions.shape}')
            if all_scalar_data is not None:
                self.logger.debug('Scalar data shape: '
                                  f'{all_scalar_data.shape}')

        if rank:
            return None

        if self.num_proc > 1:
            # Reset the scan_numbers to the original full set
            spec_scans.scan_numbers = scan_numbers
            # Align with the barrier in main() on common_comm
            # when disconnecting the spawned worker
            common_comm.barrier()
            # Disconnect spawned workers and cleanup temporary files
            sub_comm.Disconnect()
            for tmp_name in tmp_names:
                os.remove(tmp_name)

        # Construct and return the NeXus NXroot object
        return self._get_nxroot(
            data, independent_dimensions, all_scalar_data, placeholder_data)
```
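The `num_proc > 1` path writes per-worker pipeline files and spawns them over MPI. The spawn-and-merge idiom in isolation (requires mpi4py and an MPI runtime; `worker.py` is a hypothetical script, not part of CHAP):

```python
from mpi4py import MPI

# Parent side: spawn 3 workers, then merge the intercommunicator so
# parent and children share ranks in one intracommunicator.
sub_comm = MPI.COMM_SELF.Spawn('python', args=['worker.py'], maxprocs=3)
common_comm = sub_comm.Merge(False)   # False: parent group is ordered first
common_comm.barrier()                 # rendezvous, as _process does twice
# ... exchange data, then tear down as in _process:
sub_comm.Disconnect()
```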
```python
    def _get_metadata_config(self, metadata):
        """Get experiment-specific configuration data from the
        FOXDEN metadata record.

        :param metadata: FOXDEN metadata record.
        :type metadata: dict
        :return: Experiment-specific configuration data.
        :rtype: dict
        """
        config = {'did': metadata.get('did')}
        experiment_type = metadata.get('technique')
        if 'tomography' in experiment_type:
            config['title'] = metadata.get('sample_name')
            station = metadata.get('beamline')[0]
            if station == '3A':
                station = 'id3a'
            else:
                raise ValueError(f'Invalid beamline parameter ({station})')
            config['station'] = station
            config['experiment_type'] = 'TOMO'
            config['sample'] = {'name': config['title'],
                                'description': metadata.get('description')}
            if station == 'id3a':
                config['spec_file'] = os.path.join(
                    metadata.get('data_location_raw'), 'spec.log')
        else:
            raise ValueError(
                f'Experiment type {experiment_type} not implemented yet')
        return config
```
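For concreteness, a FOXDEN-style record with the keys this method reads, and the mapping it produces; all values are invented for illustration:

```python
# Invented example record; only the keys matter here.
record = {
    'did': 'did:/beamline=3a/sample=sample-1',
    'technique': 'tomography',
    'sample_name': 'sample-1',
    'beamline': ['3A'],
    'description': 'test scan',
    'data_location_raw': '/nfs/chess/raw/2024/sample-1',
}
# _get_metadata_config(record) would yield:
# {'did': 'did:/beamline=3a/sample=sample-1', 'title': 'sample-1',
#  'station': 'id3a', 'experiment_type': 'TOMO',
#  'sample': {'name': 'sample-1', 'description': 'test scan'},
#  'spec_file': '/nfs/chess/raw/2024/sample-1/spec.log'}
```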
```python
    def _get_nxroot(
            self, data, independent_dimensions, all_scalar_data,
            placeholder_data):
        """Use a `MapConfig` to construct a NeXus NXroot object.

        :param data: The map's raw data.
        :type data: numpy.ndarray
        :param independent_dimensions: The map's independent
            coordinates.
        :type independent_dimensions: numpy.ndarray
        :param all_scalar_data: The map's scalar data.
        :type all_scalar_data: numpy.ndarray
        :param placeholder_data: For SMB EDD maps only. Value to use
            for missing detector data frames, or `False` if missing
            data should raise an error.
        :type placeholder_data: object
        :return: The map's data and metadata contained in a NeXus
            structure.
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        # pylint: disable=no-name-in-module
        from nexusformat.nexus import (
            NXcollection,
            NXdata,
            NXentry,
            NXfield,
            NXlinkfield,
            NXsample,
            NXroot,
        )
        # pylint: enable=no-name-in-module

        # Local modules:
        from CHAP.common.models.map import PointByPointScanData

        def linkdims(nxgroup, nxdata_source):
            """Link the dimensions for an NXgroup."""
            source_axes = [k for k in nxdata_source.keys()]
            if isinstance(source_axes, str):
                source_axes = [source_axes]
            axes = []
            for dim in source_axes:
                axes.append(dim)
                if isinstance(nxdata_source[dim], NXlinkfield):
                    nxgroup[dim] = nxdata_source[dim]
                else:
                    nxgroup.makelink(nxdata_source[dim])
                if f'{dim}_indices' in nxdata_source.attrs:
                    nxgroup.attrs[f'{dim}_indices'] = \
                        nxdata_source.attrs[f'{dim}_indices']
            if len(axes) == 1:
                nxgroup.attrs['axes'] = axes
            else:
                nxgroup.attrs['unstructured_axes'] = axes

        # Set up NeXus NXroot/NXentry and add CHESS-specific metadata
        nxroot = NXroot()
        nxentry = NXentry(name=self.config.title)
        nxroot[nxentry.nxname] = nxentry
        nxentry.set_default()
        nxentry.map_config = self.config.model_dump_json()
        nxentry.attrs['station'] = self.config.station
        for k, v in self.config.attrs.items():
            nxentry.attrs[k] = v
        nxentry.spec_scans = NXcollection()
        for scans in self.config.spec_scans:
            nxentry.spec_scans[scans.scanparsers[0].scan_name] = \
                NXfield(value=scans.scan_numbers,
                        dtype='int8',
                        attrs={'spec_file': str(scans.spec_file)})

        # Add sample metadata
        nxentry[self.config.sample.name] = NXsample(
            **self.config.sample.model_dump())

        # Set up independent dimensions NeXus NXdata group
        # (squeeze out constant dimensions)
        constant_dim = []
        for i, dim in enumerate(self.config.independent_dimensions):
            unique = np.unique(independent_dimensions[i])
            if unique.size == 1:
                constant_dim.append(i)
        nxentry.independent_dimensions = NXdata()
        for i, dim in enumerate(self.config.independent_dimensions):
            if i not in constant_dim:
                nxentry.independent_dimensions[dim.label] = NXfield(
                    independent_dimensions[i], dim.label,
                    attrs={'units': dim.units,
                           'long_name': f'{dim.label} ({dim.units})',
                           'data_type': dim.data_type,
                           'local_name': dim.name})

        # Set up scalar data NeXus NXdata group
        # (add the constant independent dimensions)
        if all_scalar_data is not None:
            self.logger.debug(
                f'all_scalar_data.shape = {all_scalar_data.shape}\n\n')
        scalar_signals = []
        scalar_data = []
        for i, dim in enumerate(self.config.all_scalar_data):
            scalar_signals.append(dim.label)
            scalar_data.append(NXfield(
                value=all_scalar_data[i],
                units=dim.units,
                attrs={'long_name': f'{dim.label} ({dim.units})',
                       'data_type': dim.data_type,
                       'local_name': dim.name}))
        if (self.config.experiment_type == 'EDD'
                and placeholder_data is not False):
            scalar_signals.append('placeholder_data_used')
            scalar_data.append(NXfield(
                value=all_scalar_data[-1],
                attrs={'description':
                       'Indicates whether placeholder data may be present '
                       'for the corresponding frames of detector data.'}))
        for i, dim in enumerate(deepcopy(self.config.independent_dimensions)):
            if i in constant_dim:
                scalar_signals.append(dim.label)
                scalar_data.append(NXfield(
                    independent_dimensions[i], dim.label,
                    attrs={'units': dim.units,
                           'long_name': f'{dim.label} ({dim.units})',
                           'data_type': dim.data_type,
                           'local_name': dim.name}))
                self.config.all_scalar_data.append(
                    PointByPointScanData(**dim.model_dump()))
                self.config.independent_dimensions.remove(dim)
        if scalar_signals:
            nxentry.scalar_data = NXdata()
            for k, v in zip(scalar_signals, scalar_data):
                nxentry.scalar_data[k] = v
            if 'SCAN_N' in scalar_signals:
                nxentry.scalar_data.attrs['signal'] = 'SCAN_N'
            else:
                nxentry.scalar_data.attrs['signal'] = scalar_signals[0]
            scalar_signals.remove(nxentry.scalar_data.attrs['signal'])
            nxentry.scalar_data.attrs['auxiliary_signals'] = scalar_signals

        # Add detector data
        nxdata = NXdata()
        nxentry.data = nxdata
        nxentry.data.set_default()
        detector_ids = []
        for k, v in self.config.attrs.items():
            nxdata.attrs[k] = v
        min_ = np.min(data, axis=tuple(range(1, data.ndim)))
        max_ = np.max(data, axis=tuple(range(1, data.ndim)))
        for i, detector in enumerate(self.detector_config.detectors):
            nxdata[detector.get_id()] = NXfield(
                value=data[i],
                attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]})
            detector_ids.append(detector.get_id())
        linkdims(nxdata, nxentry.independent_dimensions)
        if len(self.detector_config.detectors) == 1:
            nxdata.attrs['signal'] = self.detector_config.detectors[0].get_id()
        nxentry.detector_ids = detector_ids

        return nxroot
```
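As a reference point for the structure built above, a minimal nexusformat tree with a default entry, a signal, and an axis field (toy names, not the map layout itself):

```python
import numpy as np
from nexusformat.nexus import NXdata, NXentry, NXfield, NXroot

nxroot = NXroot()
nxroot['map'] = NXentry()
x = NXfield(np.linspace(0., 1., 5), name='x', attrs={'units': 'mm'})
signal = NXfield(np.arange(50.).reshape(5, 10), name='det1')
nxroot['map/data'] = NXdata(signal, (x,))   # signal plus axis in one group
nxroot['map'].set_default()
print(nxroot.tree)
```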
```python
    def _read_raw_data_edd(
            self, comm, num_scan, offset, placeholder_data):
        """Read the raw EDD data for a given map configuration.

        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :param num_scan: Number of scans in the map.
        :type num_scan: int
        :param offset: Offset scan number of current processor.
        :type offset: int
        :param placeholder_data: Value to use for missing detector
            data frames, or `False` if missing data should raise an
            error.
        :type placeholder_data: object
        :return: The map's raw data, independent dimensions and scalar
            data.
        :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
        """
        # Third party modules
        try:
            from mpi4py import MPI
            from mpi4py.util import dtlib
        except Exception:
            pass

        # Local modules
        from CHAP.utils.general import list_to_string

        if comm is None:
            self.num_proc = 1
            rank = 0
        else:
            self.num_proc = comm.Get_size()
            rank = comm.Get_rank()
        if not rank:
            self.logger.debug(f'Number of processors: {self.num_proc}')
            self.logger.debug(f'Number of scans: {num_scan}')

        # Create the shared data buffers
        # FIX: just one spec scan at this point
        assert len(self.config.spec_scans) == 1
        scan = self.config.spec_scans[0]
        scan_numbers = scan.scan_numbers
        scanparser = scan.get_scanparser(scan_numbers[0])
        detector_ids = [
            int(d.get_id()) for d in self.detector_config.detectors]
        ddata, placeholder_used = scanparser.get_detector_data(
            detector_ids, placeholder_data=placeholder_data)
        spec_scan_shape = scanparser.spec_scan_shape
        num_dim = np.prod(spec_scan_shape)
        num_id = len(self.config.independent_dimensions)
        num_sd = len(self.config.all_scalar_data)
        if placeholder_data is not False:
            num_sd += 1
        if self.num_proc == 1:
            assert num_scan == len(scan_numbers)
            data = np.empty((num_scan, *ddata.shape), dtype=ddata.dtype)
            independent_dimensions = np.empty(
                (num_id, num_scan*num_dim), dtype=np.float64)
            all_scalar_data = np.empty(
                (num_sd, num_scan*num_dim), dtype=np.float64)
        else:
            self.logger.debug(f'Scan offset on processor {rank}: {offset}')
            self.logger.debug(f'Scan numbers on processor {rank}: '
                              f'{list_to_string(scan_numbers)}')
            datatype = dtlib.from_numpy_dtype(ddata.dtype)
            itemsize = datatype.Get_size()
            if not rank:
                nbytes = num_scan * np.prod(ddata.shape) * itemsize
            else:
                nbytes = 0
            win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
            buf, itemsize = win.Shared_query(0)
            assert itemsize == datatype.Get_size()
            data = np.ndarray(
                buffer=buf, dtype=ddata.dtype, shape=(num_scan, *ddata.shape))
            datatype = dtlib.from_numpy_dtype(np.float64)
            itemsize = datatype.Get_size()
            if not rank:
                nbytes = num_id * num_scan * num_dim * itemsize
            win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
            buf_id, _ = win_id.Shared_query(0)
            independent_dimensions = np.ndarray(
                buffer=buf_id, dtype=np.float64,
                shape=(num_id, num_scan*num_dim))
            if not rank:
                nbytes = num_sd * num_scan * num_dim * itemsize
            win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
            buf_sd, _ = win_sd.Shared_query(0)
            all_scalar_data = np.ndarray(
                buffer=buf_sd, dtype=np.float64,
                shape=(num_sd, num_scan*num_dim))

        # Read the raw data
        init = True
        for scan in self.config.spec_scans:
            for scan_number in scan.scan_numbers:
                if init:
                    init = False
                else:
                    scanparser = scan.get_scanparser(scan_number)
                    assert spec_scan_shape == scanparser.spec_scan_shape
                    ddata, placeholder_used = scanparser.get_detector_data(
                        detector_ids, placeholder_data=placeholder_data)
                data[offset] = ddata
                start_dim = offset * num_dim
                end_dim = start_dim + num_dim
                for i, dim in enumerate(self.config.independent_dimensions):
                    independent_dimensions[i][start_dim:end_dim] = \
                        dim.get_value(
                            scan, scan_number, scan_step_index=-1,
                            relative=False)
                for i, dim in enumerate(self.config.all_scalar_data):
                    all_scalar_data[i][start_dim:end_dim] = dim.get_value(
                        scan, scan_number, scan_step_index=-1,
                        relative=False)
                if placeholder_data is not False:
                    all_scalar_data[-1][start_dim:end_dim] = \
                        placeholder_used
                offset += 1

        return (
            np.swapaxes(
                data.reshape((np.prod(data.shape[:2]), *data.shape[2:])),
                0, 1),
            independent_dimensions, all_scalar_data)
```
    # @profile
    def _read_raw_data(self, comm, num_scan, offset):
        """Read the raw data for a given map configuration.

        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :param num_scan: Number of scans in the map.
        :type num_scan: int
        :param offset: Offset scan number of current processor.
        :type offset: int
        :return: The map's raw data, independent dimensions and scalar
            data.
        :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
        """
        # Third party modules
        try:
            from mpi4py import MPI
            from mpi4py.util import dtlib
        except Exception:
            pass

        # Local modules
        from CHAP.utils.general import list_to_string

        if comm is None:
            self.num_proc = 1
            rank = 0
        else:
            self.num_proc = comm.Get_size()
            rank = comm.Get_rank()
        if not rank:
            self.logger.debug(f'Number of processors: {self.num_proc}')
            self.logger.debug(f'Number of scans: {num_scan}')

        # Create the shared data buffers
        assert len(self.config.spec_scans) == 1
        scans = self.config.spec_scans[0]
        scan_numbers = scans.scan_numbers
        scanparser = scans.get_scanparser(scan_numbers[0])
        #RV only correct for multiple detectors if the same image sizes
        if len(self.detector_config.detectors) != 1:
            raise ValueError('Multiple detectors not tested yet')
        if self.config.experiment_type == 'TOMO':
            dtype = np.float32
            ddata = scanparser.get_detector_data(
                self.detector_config.detectors[0].get_id(), dtype=dtype)
        else:
            dtype = None
            ddata = scanparser.get_detector_data(
                self.detector_config.detectors[0].get_id())
        num_det = len(self.detector_config.detectors)
        num_dim = ddata.shape[0]
        num_id = len(self.config.independent_dimensions)
        num_sd = len(self.config.all_scalar_data)
        if self.num_proc == 1:
            assert num_scan == len(scan_numbers)
            data = num_det*[num_scan*[None]]
            independent_dimensions = np.empty(
                (num_scan, num_id, num_dim), dtype=np.float64)
            if num_sd:
                all_scalar_data = np.empty(
                    (num_scan, num_sd, num_dim), dtype=np.float64)
        else:
            self.logger.debug(f'Scan offset on processor {rank}: {offset}')
            self.logger.debug(f'Scan numbers on processor {rank}: '
                              f'{list_to_string(scan_numbers)}')
            datatype = dtlib.from_numpy_dtype(dtype)
            itemsize = datatype.Get_size()
            if not rank:
                nbytes = num_scan * np.prod(ddata.shape) * itemsize
            else:
                nbytes = 0
            win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
            buf, _ = win.Shared_query(0)
            #RV improve memory requirements ala single processor case?
            data = np.ndarray(
                buffer=buf, dtype=dtype,
                shape=(num_det, num_scan, *ddata.shape))
            datatype = dtlib.from_numpy_dtype(np.float64)
            itemsize = datatype.Get_size()
            if not rank:
                nbytes = num_scan * num_id * num_dim * itemsize
            else:
                nbytes = 0
            win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
            buf_id, _ = win_id.Shared_query(0)
            independent_dimensions = np.ndarray(
                buffer=buf_id, dtype=np.float64,
                shape=(num_scan, num_id, num_dim))
            if num_sd:
                if not rank:
                    nbytes = num_scan * num_sd * num_dim * itemsize
                win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
                buf_sd, _ = win_sd.Shared_query(0)
                all_scalar_data = np.ndarray(
                    buffer=buf_sd, dtype=np.float64,
                    shape=(num_scan, num_sd, num_dim))
            else:
                all_scalar_data = None

        # Read the raw data
        init = True
        for scans in self.config.spec_scans:
            for scan_number in scans.scan_numbers:
                for i in range(len(self.detector_config.detectors)):
                    if init:
                        init = False
                        data[i][offset] = ddata
                        del ddata
                    else:
                        scanparser = scans.get_scanparser(scan_number)
                        data[i][offset] = scanparser.get_detector_data(
                            self.detector_config.detectors[i].get_id(),
                            dtype=dtype)
                for i, dim in enumerate(self.config.independent_dimensions):
                    if dim.data_type in ['scan_column',
                                         'detector_log_timestamps']:
                        independent_dimensions[offset,i] = dim.get_value(
                            scans, scan_number, scan_step_index=-1,
                            relative=False)[:num_dim]
                    elif dim.data_type in ['smb_par', 'spec_motor',
                                           'expression']:
                        independent_dimensions[offset,i] = dim.get_value(
                            scans, scan_number, scan_step_index=-1,
                            relative=False,
                            scalar_data=self.config.scalar_data)
                    else:
                        raise RuntimeError(
                            f'{dim.data_type} in data_type not tested')
                for i, dim in enumerate(self.config.all_scalar_data):
                    all_scalar_data[offset,i] = dim.get_value(
                        scans, scan_number, scan_step_index=-1,
                        relative=False)
                offset += 1
        if self.num_proc == 1:
            data = np.asarray(data)
        if num_sd:
            return (
                data.reshape(
                    (data.shape[0], np.prod(data.shape[1:3]),
                     *data.shape[3:])),
                np.stack(tuple([independent_dimensions[:,i].flatten()
                                for i in range(num_id)])),
                np.stack(tuple([all_scalar_data[:,i].flatten()
                                for i in range(num_sd)])))
        return (
            data.reshape(
                (data.shape[0], np.prod(data.shape[1:3]), *data.shape[3:])),
            np.stack(tuple([independent_dimensions[:,i].flatten()
                            for i in range(num_id)])),
            None)


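The shared-memory pattern used in `_read_raw_data` above (allocate on the root rank only, `Shared_query` on every rank, wrap the buffer in a `numpy.ndarray`) can be exercised in isolation. A minimal sketch, assuming mpi4py is installed and all ranks share one node (run with e.g. `mpiexec -n 4 python demo.py`; the array shape and fill values are illustrative only):

import numpy as np
from mpi4py import MPI
from mpi4py.util import dtlib

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

itemsize = dtlib.from_numpy_dtype(np.float64).Get_size()
shape = (comm.Get_size(), 8)  # one row per rank
# Only the root rank allocates the actual shared block
nbytes = int(np.prod(shape)) * itemsize if not rank else 0
win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
buf, _ = win.Shared_query(0)  # every rank maps the root's block
data = np.ndarray(buffer=buf, dtype=np.float64, shape=shape)
data[rank] = rank             # each rank fills its own slice in place
comm.barrier()
if not rank:
    print(data)               # root sees every rank's contribution
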
class MPICollectProcessor(Processor):
    """A Processor that collects the distributed worker data from
    MPIMapProcessor on the root node.
    """
    def process(self, data, comm, root_as_worker=True):
        """Collect data on root node.

        :param data: Input data.
        :type data: list[PipelineData]
        :param comm: MPI communicator.
        :type comm: mpi4py.MPI.Comm, optional
        :param root_as_worker: Use the root node as a worker,
            defaults to `True`.
        :type root_as_worker: bool, optional
        :return: A list of the distributed worker data on the
            root node.
        """
        num_proc = comm.Get_size()
        rank = comm.Get_rank()
        if root_as_worker:
            data = self.unwrap_pipelinedata(data)[-1]
            if num_proc > 1:
                data = comm.gather(data, root=0)
        else:
            for n_worker in range(1, num_proc):
                if rank == n_worker:
                    comm.send(self.unwrap_pipelinedata(data)[-1], dest=0)
                    data = None
                elif not rank:
                    if n_worker == 1:
                        data = [comm.recv(source=n_worker)]
                    else:
                        data.append(comm.recv(source=n_worker))
        #FIX RV TODO Merge the list of data items in some generic fashion
        return data


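A hedged sketch of the two collection modes `MPICollectProcessor` implements, stripped of the pipeline plumbing (the `result` payload is a stand-in for real worker data):

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
result = {'rank': rank}  # stand-in for each worker's pipeline output

# root_as_worker=True: every rank, root included, contributes via gather
gathered = comm.gather(result, root=0)  # list of results on rank 0, else None

# root_as_worker=False: workers send point-to-point, root receives in order
if rank:
    comm.send(result, dest=0)
else:
    collected = [comm.recv(source=src) for src in range(1, comm.Get_size())]
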
class MPIMapProcessor(Processor):
    """A Processor that applies a parallel generic sub-pipeline to
    a map configuration.
    """
    def process(self, data, config=None, sub_pipeline=None):
        """Run a parallel generic sub-pipeline.

        :param data: Input data.
        :type data: list[PipelineData]
        :param config: Initialization parameters for an instance of
            common.models.map.MapConfig.
        :type config: dict, optional
        :param sub_pipeline: The sub-pipeline.
        :type sub_pipeline: Pipeline, optional
        :return: The `data` field of the first item in the returned
            list of sub-pipeline items.
        """
        # Third party modules
        from mpi4py import MPI

        # Local modules
        from CHAP.models import RunConfig
        from CHAP.runner import run
        from CHAP.common.models.map import SpecScans

        raise NotImplementedError('MPIMapProcessor needs updating and testing')
        # pylint: disable=c-extension-no-member
        comm = MPI.COMM_WORLD
        num_proc = comm.Get_size()
        rank = comm.Get_rank()

        # Get the validated map configuration
        map_config = self.get_config(
            data=data, config=config, schema='common.models.map.MapConfig')

        # Create the spec reader configuration for each processor
        # FIX: catered to EDD with one spec scan
        assert len(map_config.spec_scans) == 1
        spec_scans = map_config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        scans_per_proc = num_scan//num_proc
        n_scan = 0
        for n_proc in range(num_proc):
            num = scans_per_proc
            if n_proc == rank:
                if rank < num_scan - scans_per_proc*num_proc:
                    num += 1
                scan_numbers = scan_numbers[n_scan:n_scan+num]
            n_scan += num
        spec_config = {
            'station': map_config.station,
            'experiment_type': map_config.experiment_type,
            'spec_scans': [SpecScans(
                spec_file=spec_scans.spec_file, scan_numbers=scan_numbers)]}

        # Get the run configuration to use for the sub-pipeline
        if sub_pipeline is None:
            sub_pipeline = {}
        run_config = {'inputdir': self.inputdir, 'outputdir': self.outputdir,
                      'interactive': self.interactive,
                      'log_level': self.log_level}
        run_config.update(sub_pipeline.get('config'))
        run_config = RunConfig(**run_config, comm=comm)
        pipeline_config = []
        for item in sub_pipeline['pipeline']:
            if isinstance(item, dict):
                for k, v in deepcopy(item).items():
                    if k.endswith('Reader'):
                        v['config'] = spec_config
                        item[k] = v
                    if num_proc > 1 and k.endswith('Writer'):
                        r, e = os.path.splitext(v['filename'])
                        v['filename'] = f'{r}_{rank}{e}'
                        item[k] = v
            pipeline_config.append(item)

        # Run the sub-pipeline on each processor
        return run(run_config, pipeline_config, logger=self.logger, comm=comm)


class MPISpawnMapProcessor(Processor):
    """A Processor that applies a parallel generic sub-pipeline to
    a map configuration by spawning worker processes.
    """
    def process(
            self, data, num_proc=1, root_as_worker=True, collect_on_root=False,
            sub_pipeline=None):
        """Spawn workers running a parallel generic sub-pipeline.

        :param data: Input data.
        :type data: list[PipelineData]
        :param num_proc: Number of spawned processors, defaults to `1`.
        :type num_proc: int, optional
        :param root_as_worker: Use the root node as a worker,
            defaults to `True`.
        :type root_as_worker: bool, optional
        :param collect_on_root: Collect the result of the spawned
            workers on the root node, defaults to `False`.
        :type collect_on_root: bool, optional
        :param sub_pipeline: The sub-pipeline.
        :type sub_pipeline: Pipeline, optional
        :return: The `data` field of the first item in the returned
            list of sub-pipeline items.
        """
        # Third party modules
        from mpi4py import MPI
        import yaml

        # Local modules
        from CHAP.models import RunConfig
        from CHAP.runner import runner
        from CHAP.common.models.map import SpecScans

        raise NotImplementedError(
            'MPISpawnMapProcessor needs updating and testing')
        # Get the map configuration from data
        map_config = self.get_config(
            data=data, schema='common.models.map.MapConfig')

        # Get the run configuration to use for the sub-pipeline
        # Optionally include the root node as a worker node
        if sub_pipeline is None:
            sub_pipeline = {}
        run_config = {'inputdir': self.inputdir, 'outputdir': self.outputdir,
                      'interactive': self.interactive,
                      'log_level': self.log_level}
        run_config.update(sub_pipeline.get('config'))
        if root_as_worker:
            first_proc = 1
            spawn = 1
        else:
            first_proc = 0
            spawn = -1
        run_config = RunConfig(**run_config, logger=self.logger, spawn=spawn)

        # Create the sub-pipeline configuration for each processor
        spec_scans = map_config.spec_scans[0]
        scan_numbers = spec_scans.scan_numbers
        num_scan = len(scan_numbers)
        scans_per_proc = num_scan//num_proc
        n_scan = 0
        pipeline_config = []
        for n_proc in range(num_proc):
            num = scans_per_proc
            if n_proc < num_scan - scans_per_proc*num_proc:
                num += 1
            spec_config = {
                'station': map_config.station,
                'experiment_type': map_config.experiment_type,
                'spec_scans': [SpecScans(
                    spec_file=spec_scans.spec_file,
                    scan_numbers=scan_numbers[n_scan:n_scan+num]).__dict__]}
            sub_pipeline_config = []
            for item in deepcopy(sub_pipeline['pipeline']):
                if isinstance(item, dict):
                    for k, v in deepcopy(item).items():
                        if k.endswith('Reader'):
                            v['config'] = spec_config
                            item[k] = v
                        if num_proc > 1 and k.endswith('Writer'):
                            r, e = os.path.splitext(v['filename'])
                            v['filename'] = f'{r}_{n_proc}{e}'
                            item[k] = v
                sub_pipeline_config.append(item)
            if collect_on_root and (not root_as_worker or num_proc > 1):
                sub_pipeline_config += [
                    {'common.MPICollectProcessor': {
                        'root_as_worker': root_as_worker}}]
            pipeline_config.append(sub_pipeline_config)
            n_scan += num

        # Spawn the workers to run the sub-pipeline
        if num_proc > first_proc:
            # System modules
            from tempfile import NamedTemporaryFile

            tmp_names = []
            with NamedTemporaryFile(delete=False) as fp:
                # pylint: disable=c-extension-no-member
                fp_name = fp.name
                tmp_names.append(fp_name)
                with open(fp_name, 'w') as f:
                    yaml.dump(
                        {'config': {'spawn': run_config.spawn}}, f,
                        sort_keys=False)
                for n_proc in range(first_proc, num_proc):
                    f_name = f'{fp_name}_{n_proc}'
                    tmp_names.append(f_name)
                    with open(f_name, 'w') as f:
                        yaml.dump(
                            #FIX once comm is a field of RunConfig
                            #{'config': run_config.model_dump(exclude='comm'),
                            {'config': run_config.model_dump(),
                             'pipeline': pipeline_config[n_proc]},
                            f, sort_keys=False)
            # pylint: disable=used-before-assignment
            sub_comm = MPI.COMM_SELF.Spawn(
                'CHAP', args=[fp_name], maxprocs=num_proc-first_proc)
            common_comm = sub_comm.Merge(False)
            if run_config.spawn > 0:
                # Align with the barrier in RunConfig() on common_comm
                # called from the spawned main()
                common_comm.barrier()
        else:
            common_comm = None

        # Run the sub-pipeline on the root node
        if root_as_worker:
            data = runner(run_config, pipeline_config[0], comm=common_comm)
        elif collect_on_root:
            run_config.spawn = 0
            pipeline_config = [{'common.MPICollectProcessor': {
                'root_as_worker': root_as_worker}}]
            data = runner(run_config, pipeline_config, common_comm)
        else:
            # Align with the barrier in run() on common_comm
            # called from the spawned main()
            common_comm.barrier()
            data = None

        # Disconnect spawned workers and cleanup temporary files
        if num_proc > first_proc:
            # Align with the barrier in main() on common_comm
            # when disconnecting the spawned worker
            common_comm.barrier()
            sub_comm.Disconnect()
            for tmp_name in tmp_names:
                os.remove(tmp_name)

        return data


class NexusToNumpyProcessor(Processor):
    """A Processor to convert the default plottable data in a NeXus
    object into a `numpy.ndarray`.
    """
    def process(self, data):
        """Return the default plottable data signal in a NeXus object
        contained in `data` as a `numpy.ndarray`.

        :param data: Input data.
        :type data: nexusformat.nexus.NXobject
        :raises ValueError: If `data` has no default plottable data
            signal.
        :return: The default plottable data signal.
        :rtype: numpy.ndarray
        """
        # Third party modules
        from nexusformat.nexus import NXdata

        data = self.unwrap_pipelinedata(data)[-1]

        if isinstance(data, NXdata):
            default_data = data
        else:
            default_data = data.plottable_data
            if default_data is None:
                default_data_path = data.attrs.get('default')
                default_data = data.get(default_data_path)
            if default_data is None:
                raise ValueError(
                    f'The structure of {data} contains no default data')

        try:
            default_signal = default_data.attrs['signal']
        except Exception as exc:
            raise ValueError(
                f'The signal of {default_data} is unknown') from exc

        np_data = default_data[default_signal].nxdata

        return np_data


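The same default-data resolution can be done by hand outside a pipeline. A hedged usage sketch (the file name `map.nxs` is a placeholder; assumes the file declares the usual `default` and `signal` attributes):

from nexusformat.nexus import nxload

root = nxload('map.nxs')
nxdata = root.plottable_data          # NXdata selected via 'default' attrs
signal = nxdata.attrs['signal']       # name of the default signal field
array = nxdata[signal].nxdata         # plain numpy.ndarray
print(type(array), array.shape)
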
class NexusToXarrayProcessor(Processor):
    """A Processor to convert the default plottable data in a
    NeXus object into an `xarray.DataArray`.
    """
    def process(self, data):
        """Return the default plottable data signal in a NeXus object
        contained in `data` as an `xarray.DataArray`.

        :param data: Input data.
        :type data: nexusformat.nexus.NXobject
        :raises ValueError: If metadata for `xarray` is absent from
            `data`.
        :return: The default plottable data signal.
        :rtype: xarray.DataArray
        """
        # Third party modules
        from nexusformat.nexus import NXdata
        from xarray import DataArray

        data = self.unwrap_pipelinedata(data)[-1]

        if isinstance(data, NXdata):
            default_data = data
        else:
            default_data = data.plottable_data
            if default_data is None:
                default_data_path = data.attrs.get('default')
                default_data = data.get(default_data_path)
            if default_data is None:
                raise ValueError(
                    f'The structure of {data} contains no default data')

        try:
            default_signal = default_data.attrs['signal']
        except Exception as exc:
            raise ValueError(
                f'The signal of {default_data} is unknown') from exc
        signal_data = default_data[default_signal].nxdata

        axes = default_data.attrs['axes']
        if isinstance(axes, str):
            axes = [axes]
        coords = {}
        for axis_name in axes:
            axis = default_data[axis_name]
            coords[axis_name] = (axis_name, axis.nxdata, axis.attrs)

        dims = tuple(axes)
        name = default_signal
        attrs = default_data[default_signal].attrs

        return DataArray(data=signal_data,
                         coords=coords,
                         dims=dims,
                         name=name,
                         attrs=attrs)


class NormalizeNexusProcessor(Processor):
    """Processor for scaling one or more NXfields in the input nexus
    structure by the values of another NXfield in the same
    structure."""
    def process(self, data, normalize_nxfields, normalize_by_nxfield):
        """Return a copy of the original input nexus structure with
        additional fields containing the normalized data of each field
        in `normalize_nxfields`.

        :param data: Input nexus structure containing all fields to be
            normalized and the field by which to normalize them.
        :type data: nexusformat.nexus.NXgroup
        :param normalize_nxfields: Paths in `data` to the `NXfield`s
            to normalize.
        :type normalize_nxfields: list[str]
        :param normalize_by_nxfield: Path in `data` to the `NXfield`
            containing normalization data.
        :type normalize_by_nxfield: str
        :returns: Copy of input data with additional normalized fields.
        :rtype: nexusformat.nexus.NXgroup
        """
        # Third party modules
        from nexusformat.nexus import (
            NXgroup,
            NXfield,
        )

        # Local modules
        from CHAP.utils.general import nxcopy

        # Check input data
        data = self.unwrap_pipelinedata(data)[0]
        data = nxcopy(data)
        if not isinstance(data, NXgroup):
            raise TypeError(f'Expected NXgroup, got {type(data)}')

        # Check normalize_by_nxfield
        if normalize_by_nxfield not in data:
            raise ValueError(
                f'{normalize_by_nxfield} not present in input data')
        if not isinstance(data[normalize_by_nxfield], NXfield):
            raise TypeError(
                f'{normalize_by_nxfield} is {type(data[normalize_by_nxfield])}'
                + ', expected NXfield')
        normalization_data = data[normalize_by_nxfield].nxdata

        # Process normalize_nxfields
        for nxfield in normalize_nxfields:
            if nxfield not in data:
                self.logger.error(f'{nxfield} not present in input data')
            elif not isinstance(data[nxfield], NXfield):
                self.logger.error(
                    f'{nxfield} is {type(data[nxfield])}, expected NXfield')
            else:
                field_shape = data[nxfield].nxdata.shape
                if not normalization_data.shape == \
                        field_shape[:normalization_data.ndim]:
                    self.logger.error(
                        f'Incompatible dataset shapes: {normalize_by_nxfield} '
                        + f'is {normalization_data.shape}, '
                        + f'{nxfield} is {field_shape}'
                    )
                else:
                    self.logger.info(f'Normalizing {nxfield}')
                    # Make the shapes compatible for broadcasting
                    _normalization_data = normalization_data.reshape(
                        normalization_data.shape + (1,)
                        * (data[nxfield].nxdata.ndim
                           - normalization_data.ndim))
                    data[f'{nxfield}_normalized'] = NXfield(
                        value=data[nxfield].nxdata / _normalization_data,
                        attrs={**data[nxfield].attrs,
                               'normalized_by': normalize_by_nxfield}
                    )
        return data


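The shape-alignment trick used in `NormalizeNexusProcessor` is worth isolating: length-1 axes are appended to the normalization array so that numpy broadcasting divides every trailing detector dimension by the matching per-point value. A minimal numpy sketch with invented shapes:

import numpy as np

field = np.ones((5, 7, 10, 10))            # e.g. map points x frames x image
norm = np.arange(1.0, 36.0).reshape(5, 7)  # one value per leading point
norm = norm.reshape(norm.shape + (1,) * (field.ndim - norm.ndim))
normalized = field / norm                  # broadcasts to shape (5, 7, 10, 10)
assert np.allclose(normalized[2, 3], 1.0 / norm[2, 3, 0, 0])
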
class NormalizeMapProcessor(Processor):
    """Processor for calling `NormalizeNexusProcessor` for (usually
    all) detector data in an `NXroot` resulting from
    `MapProcessor`."""
    def process(self, data, normalize_by_nxfield, detector_ids=None):
        """Return a copy of the original input map `NXroot` with
        additional fields containing normalized detector data.

        :param data: Input nexus structure containing all fields to be
            normalized and the field by which to normalize them.
        :type data: nexusformat.nexus.NXroot
        :param normalize_by_nxfield: Path in `data` to the `NXfield`
            containing normalization data.
        :type normalize_by_nxfield: str
        :param detector_ids: Detector ids to normalize; defaults to
            all non-linked fields in the map's data group.
        :type detector_ids: list, optional
        :returns: Copy of input data with additional normalized fields.
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        from nexusformat.nexus import (
            NXentry,
            NXlink,
        )

        # Check input data
        data = self.unwrap_pipelinedata(data)[0]
        map_title = None
        for k, v in data.items():
            if isinstance(v, NXentry):
                map_title = k
                break
        if map_title is None:
            self.logger.error('Input data contains no NXentry')
        else:
            self.logger.info(f'Got map_title: {map_title}')

        # Check detector_ids
        normalize_nxfields = []
        if detector_ids is None:
            detector_ids = [k for k in data[map_title].data.keys()
                            if not isinstance(data[map_title].data[k], NXlink)]
            self.logger.info(f'Using detector_ids: {detector_ids}')
        normalize_nxfields = [f'{map_title}/data/{_id}'
                              for _id in detector_ids]

        # Normalize
        normalizer = NormalizeNexusProcessor()
        normalizer.logger = self.logger
        return normalizer.process(
            data, normalize_nxfields, normalize_by_nxfield)


class PrintProcessor(Processor):
    """A Processor to simply print the input data to stdout and return
    the original input data, unchanged in any way.
    """
    def process(self, data):
        """Print and return the input data.

        :param data: Input data.
        :type data: object
        :return: `data`
        :rtype: object
        """
        if callable(getattr(data, '_str_tree', None)):
            # If data is likely a NeXus NXobject, print its tree
            # representation (since NXobjects' str representations are
            # just their nxname)
            print(data._str_tree(attrs=True, recursive=True))
        else:
            # System modules
            from pprint import pprint

            pprint(data)

        return data


class PyfaiAzimuthalIntegrationProcessor(Processor):
    """Processor to azimuthally integrate one or more frames of 2d
    detector data using the
    [pyFAI](https://pyfai.readthedocs.io/en/v2023.1/index.html)
    package.
    """
    def process(
            self, data, poni_file, npt, mask_file=None,
            integrate1d_kwargs=None):
        """Azimuthally integrate the detector data provided and return
        the result as a dictionary of numpy arrays containing the
        values of the radial coordinate of the result, the intensities
        along the radial direction, and the Poisson errors for each
        intensity spectrum.

        :param data: Detector data to integrate.
        :type data: Union[PipelineData, list[np.ndarray]]
        :param poni_file: Name of the pyFAI PONI file containing the
            detector properties pyFAI needs to perform azimuthal
            integration.
        :type poni_file: str
        :param npt: Number of points in the output pattern.
        :type npt: int
        :param mask_file: A file to use for masking the input data.
        :type mask_file: str, optional
        :param integrate1d_kwargs: Optional dictionary of keyword
            arguments passed on to `integrate1d`.
        :type integrate1d_kwargs: Optional[dict]
        :returns: Azimuthal integration results as a dictionary of
            numpy arrays.
        """
        # Third party modules
        from pyFAI import load

        if not os.path.isabs(poni_file):
            poni_file = os.path.join(self.inputdir, poni_file)
        ai = load(poni_file)

        if mask_file is None:
            mask = None
        else:
            # Third party modules
            import fabio

            if not os.path.isabs(mask_file):
                mask_file = os.path.join(self.inputdir, mask_file)
            mask = fabio.open(mask_file).data

        try:
            det_data = self.unwrap_pipelinedata(data)[0]
        except Exception:
            det_data = data

        if integrate1d_kwargs is None:
            integrate1d_kwargs = {}
        integrate1d_kwargs['mask'] = mask

        return [ai.integrate1d(d, npt, **integrate1d_kwargs) for d in det_data]


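A hedged usage sketch of the pyFAI calls wrapped above, outside the pipeline (the PONI file name is a placeholder; assumes pyFAI is installed and the PONI file defines the detector, so `ai.detector.shape` is available):

import numpy as np
from pyFAI import load

ai = load('detector.poni')                   # geometry + detector calibration
frame = np.random.random(ai.detector.shape)  # stand-in for a detector image
result = ai.integrate1d(frame, 1800, unit='q_A^-1')
radial, intensity = result.radial, result.intensity
print(radial.shape, intensity.shape)         # both (1800,)
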
class RawDetectorDataMapProcessor(Processor):
    """A Processor to return a map of raw detector data in a
    NeXus NXroot object.
    """
    def process(self, data, detector_name, detector_shape):
        """Process configurations for a map and return the raw
        detector data collected over the map.

        :param data: Input map configuration.
        :type data: list[PipelineData]
        :param detector_name: The detector prefix.
        :type detector_name: str
        :param detector_shape: The shape of detector data for a single
            scan step.
        :type detector_shape: list
        :return: Map of raw detector data.
        :rtype: nexusformat.nexus.NXroot
        """
        map_config = self.get_config(data)
        nxroot = self.get_nxroot(map_config, detector_name, detector_shape)

        return nxroot

    def get_config(self, data):
        """Get instances of the map configuration object needed by this
        `Processor`.

        :param data: Result of `Reader.read` where at least one item
            has the value `'common.models.map.MapConfig'` for the
            `'schema'` key.
        :type data: list[PipelineData]
        :raises Exception: If a valid map config object cannot be
            constructed from `data`.
        :return: A valid instance of the map configuration object with
            field values taken from `data`.
        :rtype: common.models.map.MapConfig
        """
        # Local modules
        from CHAP.common.models.map import MapConfig

        map_config = False
        if isinstance(data, list):
            for item in data:
                if isinstance(item, dict):
                    if item.get('schema') == 'common.models.map.MapConfig':
                        map_config = item.get('data')

        if not map_config:
            raise ValueError('No map configuration found in input data')

        return MapConfig(**map_config)

    def get_nxroot(self, map_config, detector_name, detector_shape):
        """Get a map of the detector data collected by the scans in
        `map_config`. The data will be returned along with some
        relevant metadata in the form of a NeXus structure.

        :param map_config: The map configuration.
        :type map_config: common.models.map.MapConfig
        :param detector_name: The detector prefix.
        :type detector_name: str
        :param detector_shape: The shape of detector data for a single
            scan step.
        :type detector_shape: list
        :return: A map of the raw detector data.
        :rtype: nexusformat.nexus.NXroot
        """
        # Third party modules
        # pylint: disable=no-name-in-module
        from nexusformat.nexus import (
            NXdata,
            NXdetector,
            NXinstrument,
            NXroot,
        )
        # pylint: enable=no-name-in-module

        raise RuntimeError('Not updated for the new MapProcessor')
        nxroot = NXroot()

        nxroot[map_config.title] = MapProcessor.get_nxentry(map_config)
        nxentry = nxroot[map_config.title]

        nxentry.instrument = NXinstrument()
        nxentry.instrument.detector = NXdetector()

        nxentry.instrument.detector.data = NXdata()
        nxdata = nxentry.instrument.detector.data
        nxdata.raw = np.empty((*map_config.shape, *detector_shape))
        nxdata.raw.attrs['units'] = 'counts'
        for i, det_axis_size in enumerate(detector_shape):
            nxdata[f'detector_axis_{i}_index'] = np.arange(det_axis_size)

        for map_index in np.ndindex(map_config.shape):
            scans, scan_number, scan_step_index = \
                map_config.get_scan_step_index(map_index)
            scanparser = scans.get_scanparser(scan_number)
            self.logger.debug(
                f'Adding data to nxroot for map point {map_index}')
            nxdata.raw[map_index] = scanparser.get_detector_data(
                detector_name,
                scan_step_index)

        nxentry.data.makelink(
            nxdata.raw,
            name=detector_name)
        for i, det_axis_size in enumerate(detector_shape):
            nxentry.data.makelink(
                nxdata[f'detector_axis_{i}_index'],
                name=f'{detector_name}_axis_{i}_index'
            )
            if isinstance(nxentry.data.attrs['axes'], str):
                nxentry.data.attrs['axes'] = [
                    nxentry.data.attrs['axes'],
                    f'{detector_name}_axis_{i}_index']
            else:
                nxentry.data.attrs['axes'] += [
                    f'{detector_name}_axis_{i}_index']

        nxentry.data.attrs['signal'] = detector_name

        return nxroot


class SetupNXdataProcessor(Processor):
    """Processor to set up and return an "empty" NeXus representation
    of a structured dataset. This representation will be an instance
    of a NeXus NXdata object that has:
    - A NeXus NXfield entry for every coordinate/signal specified.
    - `nxaxes` that are the NeXus NXfield entries for the coordinates
      and contain the values provided for each coordinate.
    - NeXus NXfield entries of appropriate shape, but containing all
      zeros, for every signal.
    - Attributes that define the axes, plus any additional attributes
      specified by the user.

    This `Processor` is most useful as a "setup" step for constructing
    a representation of / container for a complete dataset that will
    be filled out in pieces later by `UpdateNXdataProcessor`.
    """
    def process(
            self, data, nxname='data', coords=None, signals=None, attrs=None,
            data_points=None, extra_nxfields=None, duplicates='overwrite'):
        """Return a NeXus NXdata object that has the requisite axes
        and NeXus NXfield entries to represent a structured dataset
        with the properties provided. Properties may be provided either
        through the `data` argument (from an appropriate `PipelineItem`
        that immediately precedes this one in a `Pipeline`), or through
        the `coords`, `signals`, `attrs`, and/or `data_points`
        arguments. If any of the latter are used, their values will
        completely override any values for these parameters found from
        `data`.

        :param data: Data from the previous item in a `Pipeline`.
        :type data: list[PipelineData]
        :param nxname: Name for the returned NeXus NXdata object,
            defaults to `'data'`.
        :type nxname: str, optional
        :param coords: List of dictionaries defining the coordinates
            of the dataset. Each dictionary must have the keys
            `'name'` and `'values'`, whose values are the name of the
            coordinate axis (a string) and all the unique values of
            that coordinate for the structured dataset (a list of
            numbers), respectively. A third item in the dictionary is
            optional, but highly recommended: `'attrs'` may provide a
            dictionary of attributes to attach to the coordinate axis
            that assist in interpreting the returned NeXus NXdata
            representation of the dataset. It is strongly recommended
            to provide the units of the values along an axis in the
            `attrs` dictionary.
        :type coords: list[dict[str, object]], optional
        :param signals: List of dictionaries defining the signals of
            the dataset. Each dictionary must have the keys `'name'`
            and `'shape'`, whose values are the name of the signal
            field (a string) and the shape of the signal's value at
            each point in the dataset (a list of zero or more
            integers), respectively. A third item in the dictionary is
            optional, but highly recommended: `'attrs'` may provide a
            dictionary of attributes to attach to the signal field
            that assist in interpreting the returned NeXus NXdata
            representation of the dataset. It is strongly recommended
            to provide the units of the signal's values in the `attrs`
            dictionary.
        :type signals: list[dict[str, object]], optional
        :param attrs: An arbitrary dictionary of attributes to assign
            to the returned NeXus NXdata object.
        :type attrs: dict[str, object], optional
        :param data_points: A list of data points to partially (or
            even entirely) fill out the "empty" signal NeXus NXfield's
            before returning the NeXus NXdata object.
        :type data_points: list[dict[str, object]], optional
        :param extra_nxfields: List of "extra" NeXus NXfields to
            include that can be described neither as a signal of the
            dataset, nor a dedicated coordinate. This parameter is
            good for including "alternate" values for one of the
            coordinate dimensions -- the same coordinate axis
            expressed in different units, for instance. Each item in
            the list should be a dictionary of parameters for the
            `nexusformat.nexus.NXfield` constructor.
        :type extra_nxfields: list[dict[str, object]], optional
        :param duplicates: Behavior to use if any new data points occur
            at the same point in the dataset's coordinate space as an
            existing data point. Allowed values for `duplicates` are:
            `'overwrite'` and `'block'`. Defaults to `'overwrite'`.
        :type duplicates: Literal['overwrite', 'block']
        :returns: A NeXus NXdata object that represents the structured
            dataset as specified.
        :rtype: nexusformat.nexus.NXdata
        """
        self.nxname = nxname

        if coords is None:
            coords = []
        if signals is None:
            signals = []
        if attrs is None:
            attrs = {}
        if extra_nxfields is None:
            extra_nxfields = []
        self.coords = coords
        self.signals = signals
        self.attrs = attrs
        self.data_points = data_points
        try:
            setup_params = self.unwrap_pipelinedata(data)[0]
        except Exception:
            setup_params = None
        if isinstance(setup_params, dict):
            for a in ('coords', 'signals', 'attrs', 'data_points'):
                setup_param = setup_params.get(a)
                if not getattr(self, a) and setup_param is not None:
                    self.logger.info(f'Using input data from pipeline for {a}')
                    setattr(self, a, setup_param)
                else:
                    self.logger.info(
                        f'Ignoring input data from pipeline for {a}')
        else:
            self.logger.warning('Ignoring all input data from pipeline')
        self.shape = tuple(len(c['values']) for c in self.coords)
        self.extra_nxfields = extra_nxfields
        self.duplicates = duplicates
        self.init_nxdata()

        if self.data_points is not None:
            for d in self.data_points:
                self.add_data_point(d)

        return self.nxdata

    def add_data_point(self, data_point):
        """Add a data point to this dataset.
        1. Validate `data_point`.
        2. Append `data_point` to `self.data_points`.
        3. Update signal `NXfield`s in `self.nxdata`.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        :returns: None
        """
        self.logger.info(
            f'Adding data point no. {data_point["dataset_point_index"]+1} of '
            f'{len(self.data_points)}')
        self.logger.debug(f'New data point: {data_point}')
        valid, msg = self.validate_data_point(data_point)
        if not valid:
            self.logger.error(f'Cannot add data point: {msg}')
        else:
            self.update_nxdata(data_point)

    def validate_data_point(self, data_point):
        """Return `True` if `data_point` occurs at a valid point in
        this structured dataset's coordinate space, `False`
        otherwise. Also validate shapes of signal values and add zero
        values for any missing signals.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        :returns: Validity of `data_point`, message
        :rtype: bool, str
        """
        valid = True
        msg = ''
        # Convert all values to numpy types
        data_point = {k: np.asarray(v) for k, v in data_point.items()}
        # Ensure data_point defines a specific point in the dataset's
        # coordinate space
        if not all(c['name'] in data_point for c in self.coords):
            valid = False
            msg = 'Missing coordinate values'
        # Ensure a value is present for all signals
        for s in self.signals:
            name = s['name']
            if name not in data_point:
                data_point[name] = np.full(s['shape'], 0)
            else:
                if not data_point[name].shape == tuple(s['shape']):
                    valid = False
                    msg = f'Shape mismatch for signal {s}'
        return valid, msg

    def init_nxdata(self):
        """Initialize an empty NeXus NXdata representing this dataset
        to `self.nxdata`; values for axes' `NXfield`s are filled out,
        values for signals' `NXfield`s are empty and can be filled out
        later. Save the empty NeXus NXdata object to the NeXus file.
        Initialize `self.nxfile` and `self.nxdata_path` with the
        `NXFile` object and actual nxpath used to save and make updates
        to the NeXus NXdata object.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        axes = tuple(NXfield(
            value=c['values'],
            name=c['name'],
            attrs=c.get('attrs'),
            dtype=c.get('dtype', 'float64')) for c in self.coords)
        entries = {s['name']: NXfield(
            value=np.full((*self.shape, *s['shape']), 0),
            name=s['name'],
            attrs=s.get('attrs'),
            dtype=s.get('dtype', 'float64')) for s in self.signals}
        extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
        extra_nxfields = {f.nxname: f for f in extra_nxfields}
        entries.update(extra_nxfields)
        self.nxdata = NXdata(
            name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)

    def update_nxdata(self, data_point):
        """Update `self.nxdata`'s NXfield values.

        :param data_point: Data point defining a point in the
            dataset's coordinate space and the new signal values at
            that point.
        :type data_point: dict[str, object]
        :returns: None
        """
        index = self.get_index(data_point)
        for s in self.signals:
            name = s['name']
            if name in data_point:
                self.nxdata[name][index] = data_point[name]

    def get_index(self, data_point):
        """Return a tuple representing the array index of `data_point`
        in the coordinate space of the dataset.

        :param data_point: Data point defining a point in the
            dataset's coordinate space.
        :type data_point: dict[str, object]
        :returns: Multi-dimensional index of `data_point` in the
            dataset's coordinate space.
        :rtype: tuple
        """
        return tuple(c['values'].index(data_point[c['name']])
                     for c in self.coords)


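A hedged sketch of the `coords`/`signals` layout `SetupNXdataProcessor` consumes (the axis names, units, and shapes below are invented for illustration):

coords = [
    {'name': 'x', 'values': [0.0, 0.5, 1.0], 'attrs': {'units': 'mm'}},
    {'name': 'y', 'values': [0.0, 1.0], 'attrs': {'units': 'mm'}},
]
signals = [
    # a 2048-channel spectrum at every (x, y) point
    {'name': 'counts', 'shape': [2048], 'attrs': {'units': 'counts'}},
    # a scalar at every (x, y) point
    {'name': 'presample_intensity', 'shape': []},
]
# The resulting NXdata then holds zero-filled signal fields of shape
# (3, 2, *signal_shape), to be filled in later point by point.
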
class UnstructuredToStructuredProcessor(Processor):
    """Processor to reshape data in an NXdata from an "unstructured"
    to a "structured" representation.
    """
    def process(self, data, nxpath=None):
        """Return a structured representation of the unstructured
        NXdata contained in `data` (optionally located at `nxpath`).
        """
        # Third party modules
        from nexusformat.nexus import NXdata

        try:
            nxobject = self.get_data(data)
        except Exception:
            nxobject = self.unwrap_pipelinedata(data)[0]
        if isinstance(nxobject, NXdata):
            return self.convert_nxdata(nxobject)
        if nxpath is not None:
            # Local modules
            # from CHAP.utils.general import nxcopy
            try:
                nxobject = nxobject[nxpath]
            except Exception as exc:
                raise ValueError(
                    f'Invalid parameter nxpath ({nxpath})') from exc
        else:
            raise ValueError(f'Invalid input data ({data})')
        return self.convert_nxdata(nxobject)

    def convert_nxdata(self, nxdata):
        """Convert a single unstructured NXdata object to its
        structured counterpart.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        # Local modules
        from CHAP.edd.processor import get_axes

        # Extract axes from the NXdata attributes
        axes = get_axes(nxdata)
        for a in axes:
            if a not in nxdata:
                raise ValueError(f'Missing coordinates for {a}')

        # Check the independent dimensions and axes
        unstructured_axes = []
        unstructured_dim = None
        for a in axes:
            if not isinstance(nxdata[a], NXfield):
                raise ValueError(
                    f'Invalid axis field type ({type(nxdata[a])})')
            if len(nxdata[a].shape) == 1:
                if not unstructured_axes:
                    unstructured_axes.append(a)
                    unstructured_dim = nxdata[a].size
                else:
                    if nxdata[a].size == unstructured_dim:
                        unstructured_axes.append(a)
                    elif 'unstructured_axes' in nxdata.attrs:
                        raise ValueError('Inconsistent axes dimensions')
            elif 'unstructured_axes' in nxdata.attrs:
                raise ValueError(
                    f'Invalid unstructured axis shape ({nxdata[a].shape})')
        if not axes and hasattr(nxdata, 'signal'):
            if len(nxdata[nxdata.signal].shape) < 2:
                raise ValueError(
                    f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
            unstructured_dim = nxdata[nxdata.signal].shape[0]
            for k, v in nxdata.items():
                if (isinstance(v, NXfield) and len(v.shape) == 1
                        and v.shape[0] == unstructured_dim):
                    unstructured_axes.append(k)
        if unstructured_dim is None:
            raise ValueError('Unable to determine the unstructured axes')
        axes = unstructured_axes

        # Identify unique coordinate points for each axis
        unique_coords = {}
        coords = {}
        axes_attrs = {}
        for a in axes:
            coords[a] = nxdata[a].nxdata
            unique_coords[a] = np.sort(np.unique(nxdata[a].nxdata))
            axes_attrs[a] = deepcopy(nxdata[a].attrs)
            if 'target' in axes_attrs[a]:
                del axes_attrs[a]['target']

        # Calculate the total number of unique coordinate points
        unique_npts = np.prod([len(v) for k, v in unique_coords.items()])
        if unique_npts != unstructured_dim:
            self.logger.warning('The unstructured grid does not fully map to '
                                'a structured one (there are missing points)')

        # Identify the signals and the data point axes
        signals = []
        data_point_axes = []
        data_point_shape = []
        if hasattr(nxdata, 'signal'):
            if (len(nxdata[nxdata.signal].shape) < 2
                    or nxdata[nxdata.signal].shape[0] != unstructured_dim):
                raise ValueError(
                    f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
            signals = [nxdata.signal]
            data_point_shape = [nxdata[nxdata.signal].shape[1:]]
        for k, v in nxdata.items():
            if (isinstance(v, NXfield) and k not in axes and k not in signals
                    and v.shape[0] == unstructured_dim):
                signals.append(k)
                if not data_point_shape:
                    data_point_shape.append(v.shape[1:])
        if len(data_point_shape) == 1:
            data_point_shape = data_point_shape[0]
        else:
            data_point_shape = []
        for _ in data_point_shape:
            for k, v in nxdata.items():
                if (isinstance(v, NXfield) and k not in axes
                        and v.shape == data_point_shape):
                    data_point_axes.append(k)

        # Create the structured NXdata object
        structured_shape = tuple(len(unique_coords[a]) for a in axes)
        attrs = deepcopy(nxdata.attrs)
        if 'unstructured_axes' in attrs:
            attrs.pop('unstructured_axes')
        attrs['axes'] = axes
        nxdata_structured = NXdata(
            name=f'{nxdata.nxname}_structured',
            **{a: NXfield(
                value=unique_coords[a],
                attrs=axes_attrs[a])
               for a in axes},
            **{s: NXfield(
                # value=np.reshape(  # FIX not always a sound way to reshape.
                #     nxdata[s], (*structured_shape, *nxdata[s].shape[1:])),
                dtype=nxdata[s].dtype,
                shape=(*structured_shape, *nxdata[s].shape[1:]),
                attrs=nxdata[s].attrs)
               for s in signals},
            attrs=attrs)
        if len(data_point_axes) == 1:
            axes = nxdata_structured.attrs['axes']
            if isinstance(axes, str):
                axes = [axes]
            nxdata_structured.attrs['axes'] = axes + data_point_axes
            for a in data_point_axes:
                nxdata_structured[a] = NXfield(
                    value=nxdata[a], attrs=nxdata[a].attrs)

        # Populate the structured NXdata object with values
        for i, coord in enumerate(zip(*tuple(nxdata[a].nxdata for a in axes))):
            structured_index = tuple(
                np.asarray(
                    coord[ii] == unique_coords[axes[ii]]).nonzero()[0][0]
                for ii in range(len(axes)))
            for s in signals:
                nxdata_structured[s][structured_index] = nxdata[s][i]

        return nxdata_structured


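The core reindexing step in `UnstructuredToStructuredProcessor` can be isolated: each flat point's coordinates are looked up in the sorted unique values to find its slot on the structured grid. A minimal numpy sketch with invented data:

import numpy as np

x = np.array([0.0, 1.0, 0.0, 1.0])    # unstructured coordinates, 4 points
y = np.array([0.0, 0.0, 1.0, 1.0])
sig = np.array([10.0, 11.0, 12.0, 13.0])
ux, uy = np.sort(np.unique(x)), np.sort(np.unique(y))
grid = np.zeros((ux.size, uy.size))
for i, (xi, yi) in enumerate(zip(x, y)):
    grid[(xi == ux).nonzero()[0][0], (yi == uy).nonzero()[0][0]] = sig[i]
print(grid)  # [[10. 12.] [11. 13.]]
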
class UpdateNXvalueProcessor(Processor):
|
|
2565
|
+
"""Processor to fill in part(s) of a NeXus object representing a
|
|
2566
|
+
structured dataset that's already been written to a NeXus file.
|
|
2567
|
+
|
|
2568
|
+
This Processor is most useful as an "update" step for a NeXus
|
|
2569
|
+
NXdata object created by `common.SetupNXdataProcessor`, and is
|
|
2570
|
+
most easy to use in a `Pipeline` immediately after another
|
|
2571
|
+
`PipelineItem` designed specifically to return a value that can
|
|
2572
|
+
be used as input to this `Processor`.
|
|
2573
|
+
"""
|
|
2574
|
+
    def process(self, data, nxfilename, data_points=None):
        """Write new data values to an existing NeXus object
        representing an unstructured dataset in a NeXus file.
        Return the list of data points used to update the dataset.

        :param data: Data from the previous item in a `Pipeline`. May
            contain a list of data points that will extend the list of
            data points optionally provided with the `data_points`
            argument.
        :type data: list[PipelineData]
        :param nxfilename: Name of the NeXus file containing the
            NeXus object to update.
        :type nxfilename: str
        :param data_points: List of data points, each one a dictionary
            with the keys 'nxpath', 'index', and 'value', giving the
            path to the NeXus field to update, the index of the data
            point in that field, and the new data value itself.
        :type data_points: Optional[list[dict[str, object]]]
        :returns: Complete list of data points used to update the
            dataset.
        :rtype: list[dict[str, object]]
        """
        # Third party modules
        from nexusformat.nexus import NXFile

        if data_points is None:
            data_points = []
        self.logger.debug(f'Got {len(data_points)} data points from keyword')
        ddata_points = self.unwrap_pipelinedata(data)[0]
        if isinstance(ddata_points, list):
            self.logger.debug(f'Got {len(ddata_points)} from pipeline data')
            data_points.extend(ddata_points)
        self.logger.info(f'Updating a total of {len(data_points)} data points')

        if not os.path.isabs(nxfilename):
            nxfilename = os.path.join(self.inputdir, nxfilename)
        nxfile = NXFile(nxfilename, 'rw')

        # Track the indices of the successfully updated data points
        indices = []
        for data_point in data_points:
            try:
                nxfile.writevalue(
                    data_point['nxpath'], np.asarray(data_point['value']),
                    data_point['index'])
                indices.append(data_point['index'])
            except Exception as exc:
                self.logger.error(
                    f'Error updating {data_point["nxpath"]} for data point '
                    f'{data_point["index"]}: {exc}')
            else:
                self.logger.debug(f'Updated data point {data_point}')

        nxfile.close()

        return data_points

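# Illustrative sketch (not part of the original module): the data
# points consumed by UpdateNXvalueProcessor are plain dictionaries,
# each addressing one slice of one field. All paths and values below
# are hypothetical.
def _example_update_nxvalue_data_points():
    """Return a data_points list of the shape process() expects."""
    return [
        {'nxpath': 'entry/data/presample_intensity', 'index': 0,
         'value': 1234},
        {'nxpath': 'entry/data/detector_data', 'index': 1,
         'value': np.zeros((256, 256))},
    ]
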
class UpdateNXdataProcessor(Processor):
    """Processor to fill in part(s) of a NeXus NXdata representing a
    structured dataset that has already been written to a NeXus file.

    This Processor is most useful as an "update" step for a NeXus
    NXdata object created by `common.SetupNXdataProcessor`, and is
    easiest to use in a `Pipeline` immediately after another
    `PipelineItem` designed specifically to return a value that can
    be used as input to this `Processor`.
    """
    def process(
            self, data, nxfilename, nxdata_path, data_points=None,
            allow_approximate_coordinates=False):
        """Write new data points to the signal fields of an existing
        NeXus NXdata object representing a structured dataset in a
        NeXus file. Return the list of data points used to update the
        dataset.

        :param data: Data from the previous item in a `Pipeline`. May
            contain a list of data points that will extend the list of
            data points optionally provided with the `data_points`
            argument.
        :type data: list[PipelineData]
        :param nxfilename: Name of the NeXus file containing the
            NeXus NXdata object to update.
        :type nxfilename: str
        :param nxdata_path: The path to the NeXus NXdata object to
            update in the file.
        :type nxdata_path: str
        :param data_points: List of data points, each one a dictionary
            whose keys are the names of the coordinates and signals,
            and whose values are the values of each coordinate /
            signal at a single point in the dataset. Defaults to None.
        :type data_points: Optional[list[dict[str, object]]]
        :param allow_approximate_coordinates: Parameter to allow the
            nearest existing match for the new data points'
            coordinates to be used if an exact match cannot be found
            (sometimes this is due simply to differences in rounding
            conventions). Defaults to False.
        :type allow_approximate_coordinates: bool, optional
        :returns: Complete list of data points used to update the
            dataset.
        :rtype: list[dict[str, object]]
        """
        # Third party modules
        from nexusformat.nexus import NXFile

        if data_points is None:
            data_points = []
        self.logger.debug(f'Got {len(data_points)} data points from keyword')
        _data_points = self.unwrap_pipelinedata(data)[0]
        if isinstance(_data_points, list):
            self.logger.debug(f'Got {len(_data_points)} from pipeline data')
            data_points.extend(_data_points)
        self.logger.info(f'Updating {len(data_points)} data points total')

        if not os.path.isabs(nxfilename):
            nxfilename = os.path.join(self.inputdir, nxfilename)
        nxfile = NXFile(nxfilename, 'rw')
        nxdata = nxfile.readfile()[nxdata_path]
        axes_names = [a.nxname for a in nxdata.nxaxes]

        data_points_used = []
        for i, d in enumerate(data_points):
            # Verify that the data point contains a value for all
            # coordinates in the dataset.
            if not all(a in d for a in axes_names):
                self.logger.error(
                    f'Data point {i} is missing a value for at least one '
                    f'axis. Skipping. Axes are: {", ".join(axes_names)}')
                continue
            self.logger.debug(
                f'Coordinates for data point {i}: '
                + ', '.join([f'{a}={d[a]}' for a in axes_names]))
            # Get the index of the data point in the dataset based on
            # its values for each coordinate.
            try:
                index = tuple(np.where(a.nxdata == d[a.nxname])[0][0]
                              for a in nxdata.nxaxes)
            except Exception:
                if allow_approximate_coordinates:
                    try:
                        index = tuple(
                            np.argmin(np.abs(a.nxdata - d[a.nxname]))
                            for a in nxdata.nxaxes)
                        self.logger.warning(
                            f'Nearest match for coordinates of data point '
                            f'{i}: '
                            + ', '.join(
                                [f'{a.nxname}={a[_i]}'
                                 for _i, a in zip(index, nxdata.nxaxes)]))
                    except Exception:
                        self.logger.error(
                            f'Cannot get the index of data point {i}. '
                            'Skipping.')
                        continue
                else:
                    self.logger.error(
                        f'Cannot get the index of data point {i}. Skipping.')
                    continue
            self.logger.debug(f'Index of data point {i}: {index}')
            # Update the signals contained in this data point at the
            # proper index in the dataset's signal `NXfield`s
            for k, v in d.items():
                if k in axes_names:
                    continue
                try:
                    nxfile.writevalue(
                        os.path.join(nxdata_path, k), np.asarray(v), index)
                    # self.logger.debug(
                    #     f'Wrote to {os.path.join(nxdata_path, k)} in '
                    #     f'{nxfilename} at index {index} value: '
                    #     f'{np.asarray(v)} (type: {type(v)})')
                except Exception as exc:
                    self.logger.error(
                        f'Error updating signal {k} for new data point '
                        f'{i} (dataset index {index}): {exc}')
            data_points_used.append(d)

        nxfile.close()

        return data_points_used

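# Illustrative sketch (not part of the original module) of the
# nearest-match fallback used when allow_approximate_coordinates is
# True: an exact equality test fails on a rounding difference, but
# np.argmin over the absolute differences recovers the intended index.
# The axis values below are hypothetical.
def _example_approximate_coordinate_match():
    """Show exact lookup failing and the nearest-match succeeding."""
    axis_values = np.array([0.0, 0.25, 0.5, 0.75])
    new_coord = 0.2499999  # differs from 0.25 only by rounding
    exact = np.where(axis_values == new_coord)[0]  # empty array
    nearest = np.argmin(np.abs(axis_values - new_coord))  # 1
    return exact.size, nearest  # (0, 1)
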
class NXdataToDataPointsProcessor(Processor):
    """Transform a NeXus NXdata object into a list of dictionaries.
    Each dictionary represents a single data point in the coordinate
    space of the dataset. The keys are the names of the signals and
    axes in the dataset, and the values are a single scalar value (in
    the case of axes) or the value of the signal at that point in the
    coordinate space of the dataset (in the case of signals -- this
    means that values for signals may be any shape, depending on the
    shape of the signal itself).
    """
    def process(self, data):
        """Return a list of dictionaries representing the coordinate
        and signal values at every point in the dataset provided.

        :param data: Input pipeline data containing a NeXus NXdata
            object.
        :type data: list[PipelineData]
        :returns: List of all data points in the dataset.
        :rtype: list[dict[str, object]]
        """
        nxdata = self.unwrap_pipelinedata(data)[0]

        data_points = []
        axes_names = [a.nxname for a in nxdata.nxaxes]
        self.logger.info(f'Dataset axes: {axes_names}')
        dataset_shape = tuple(a.size for a in nxdata.nxaxes)
        self.logger.info(f'Dataset shape: {dataset_shape}')
        signal_names = [k for k, v in nxdata.entries.items()
                        if k not in axes_names
                        and v.shape[:len(dataset_shape)] == dataset_shape]
        self.logger.info(f'Dataset signals: {signal_names}')
        other_fields = [k for k, v in nxdata.entries.items()
                        if k not in axes_names + signal_names]
        if other_fields:
            self.logger.warning(
                'Ignoring the following fields that cannot be interpreted as '
                f'either dataset coordinates or signals: {other_fields}')
        for i in np.ndindex(dataset_shape):
            data_points.append({
                **{a: nxdata[a][_i] for a, _i in zip(axes_names, i)},
                **{s: nxdata[s].nxdata[i] for s in signal_names},
            })
        return data_points

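# Illustrative sketch (not part of the original module) of the
# flattening NXdataToDataPointsProcessor performs: np.ndindex walks
# every grid point, pairing scalar axis values with the signal value
# at that point. The arrays below are hypothetical.
def _example_ndindex_flattening():
    """Flatten a 2x3 structured dataset into six data points."""
    x = np.array([0.0, 1.0])
    y = np.array([10.0, 20.0, 30.0])
    intensity = np.arange(6).reshape(2, 3)
    return [
        {'x': x[i], 'y': y[j], 'intensity': intensity[i, j]}
        for i, j in np.ndindex((2, 3))]
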
class XarrayToNexusProcessor(Processor):
    """A Processor to convert the data in an `xarray` structure to a
    NeXus NXdata object.
    """
    def process(self, data):
        """Return `data` represented as a NeXus NXdata object.

        :param data: The input `xarray` structure.
        :type data: Union[xarray.DataArray, xarray.Dataset]
        :return: The data and metadata in `data`.
        :rtype: nexusformat.nexus.NXdata
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        data = self.unwrap_pipelinedata(data)[-1]
        signal = NXfield(value=data.data, name=data.name, attrs=data.attrs)
        axes = []
        for name, coord in data.coords.items():
            axes.append(
                NXfield(value=coord.data, name=name, attrs=coord.attrs))
        axes = tuple(axes)

        return NXdata(signal=signal, axes=axes)

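# Illustrative sketch (not part of the original module): a minimal
# xarray.DataArray of the kind XarrayToNexusProcessor accepts. Passed
# through process(), 'intensity' becomes the NXdata signal and 'x' and
# 'y' become its axes. Requires xarray; all names are hypothetical.
def _example_dataarray_for_nexus():
    """Return a named, coordinate-labeled xarray.DataArray."""
    # Third party modules
    import xarray as xr

    return xr.DataArray(
        np.arange(6.0).reshape(2, 3),
        name='intensity',
        coords={'x': [0.0, 1.0], 'y': [10.0, 20.0, 30.0]},
        dims=('x', 'y'))
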
class XarrayToNumpyProcessor(Processor):
    """A Processor to convert the data in an `xarray.DataArray`
    structure to a `numpy.ndarray`.
    """
    def process(self, data):
        """Return just the signal values contained in `data`.

        :param data: The input `xarray.DataArray`.
        :type data: xarray.DataArray
        :return: The data in `data`.
        :rtype: numpy.ndarray
        """
        return self.unwrap_pipelinedata(data)[-1].data

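# Illustrative sketch (not part of the original module):
# XarrayToNumpyProcessor simply returns the array underneath a
# DataArray, i.e. its `.data` attribute. Requires xarray; the array
# below is hypothetical.
def _example_dataarray_to_numpy():
    """Show that .data on a DataArray is a plain numpy.ndarray."""
    # Third party modules
    import xarray as xr

    da = xr.DataArray(np.arange(6).reshape(2, 3), dims=('x', 'y'))
    assert isinstance(da.data, np.ndarray)
    return da.data
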
# class SumProcessor(Processor):
#     """A Processor to sum the data in a NeXus NXobject, given a set
#     of nxpaths.
#     """
#     def process(self, data):
#         """Return the summed data array
#
#         :param data:
#         :type data:
#         :return: The summed data.
#         :rtype: numpy.ndarray
#         """
#         nxentry, nxpaths = self.unwrap_pipelinedata(data)[-1]
#         if len(nxpaths) == 1:
#             return nxentry[nxpaths[0]]
#         sum_data = deepcopy(nxentry[nxpaths[0]])
#         for nxpath in nxpaths[1:]:
#             nxdata = nxentry[nxpath]
#             for entry in nxdata.entries:
#                 sum_data[entry] += nxdata[entry]
#
#         return sum_data

if __name__ == '__main__':
    # Local modules
    from CHAP.processor import main

    main()