ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. CHAP/TaskManager.py +216 -0
  2. CHAP/__init__.py +27 -0
  3. CHAP/common/__init__.py +57 -0
  4. CHAP/common/models/__init__.py +8 -0
  5. CHAP/common/models/common.py +124 -0
  6. CHAP/common/models/integration.py +659 -0
  7. CHAP/common/models/map.py +1291 -0
  8. CHAP/common/processor.py +2869 -0
  9. CHAP/common/reader.py +658 -0
  10. CHAP/common/utils.py +110 -0
  11. CHAP/common/writer.py +730 -0
  12. CHAP/edd/__init__.py +23 -0
  13. CHAP/edd/models.py +876 -0
  14. CHAP/edd/processor.py +3069 -0
  15. CHAP/edd/reader.py +1023 -0
  16. CHAP/edd/select_material_params_gui.py +348 -0
  17. CHAP/edd/utils.py +1572 -0
  18. CHAP/edd/writer.py +26 -0
  19. CHAP/foxden/__init__.py +19 -0
  20. CHAP/foxden/models.py +71 -0
  21. CHAP/foxden/processor.py +124 -0
  22. CHAP/foxden/reader.py +224 -0
  23. CHAP/foxden/utils.py +80 -0
  24. CHAP/foxden/writer.py +168 -0
  25. CHAP/giwaxs/__init__.py +11 -0
  26. CHAP/giwaxs/models.py +491 -0
  27. CHAP/giwaxs/processor.py +776 -0
  28. CHAP/giwaxs/reader.py +8 -0
  29. CHAP/giwaxs/writer.py +8 -0
  30. CHAP/inference/__init__.py +7 -0
  31. CHAP/inference/processor.py +69 -0
  32. CHAP/inference/reader.py +8 -0
  33. CHAP/inference/writer.py +8 -0
  34. CHAP/models.py +227 -0
  35. CHAP/pipeline.py +479 -0
  36. CHAP/processor.py +125 -0
  37. CHAP/reader.py +124 -0
  38. CHAP/runner.py +277 -0
  39. CHAP/saxswaxs/__init__.py +7 -0
  40. CHAP/saxswaxs/processor.py +8 -0
  41. CHAP/saxswaxs/reader.py +8 -0
  42. CHAP/saxswaxs/writer.py +8 -0
  43. CHAP/server.py +125 -0
  44. CHAP/sin2psi/__init__.py +7 -0
  45. CHAP/sin2psi/processor.py +8 -0
  46. CHAP/sin2psi/reader.py +8 -0
  47. CHAP/sin2psi/writer.py +8 -0
  48. CHAP/tomo/__init__.py +15 -0
  49. CHAP/tomo/models.py +210 -0
  50. CHAP/tomo/processor.py +3862 -0
  51. CHAP/tomo/reader.py +9 -0
  52. CHAP/tomo/writer.py +59 -0
  53. CHAP/utils/__init__.py +6 -0
  54. CHAP/utils/converters.py +188 -0
  55. CHAP/utils/fit.py +2947 -0
  56. CHAP/utils/general.py +2655 -0
  57. CHAP/utils/material.py +274 -0
  58. CHAP/utils/models.py +595 -0
  59. CHAP/utils/parfile.py +224 -0
  60. CHAP/writer.py +122 -0
  61. MLaaS/__init__.py +0 -0
  62. MLaaS/ktrain.py +205 -0
  63. MLaaS/mnist_img.py +83 -0
  64. MLaaS/tfaas_client.py +371 -0
  65. chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
  66. chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
  67. chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
  68. chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
  69. chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
  70. chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
@@ -0,0 +1,2869 @@
1
+ #!/usr/bin/env python
2
+ #-*- coding: utf-8 -*-
3
+ """
4
+ File : processor.py
5
+ Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
6
+ Description: Module for Processors used in multiple experiment-specific workflows.
7
+ """
8
+
9
+ # System modules
10
+ from copy import deepcopy
11
+ import os
12
+ from typing import Optional
13
+
14
+ # Third party modules
15
+ import numpy as np
16
+ from pydantic import (
17
+ Field,
18
+ conint,
19
+ field_validator,
20
+ )
21
+
22
+ # Local modules
23
+ from CHAP import Processor
24
+ from CHAP.common.models.map import (
25
+ DetectorConfig,
26
+ MapConfig,
27
+ )
28
+
29
+
30
+ class AsyncProcessor(Processor):
31
+ """A Processor to process multiple sets of input data via asyncio
32
+ module.
33
+
34
+ :ivar mgr: The `Processor` used to process every set of input data.
35
+ :type mgr: Processor
36
+ """
37
+ def __init__(self, mgr):
38
+ super().__init__()
39
+ self.mgr = mgr
40
+
41
+ def process(self, data):
42
+ """Asynchronously process the input documents with the
43
+ `self.mgr` `Processor`.
44
+
45
+ :param data: Input data documents to process.
46
+ :type data: iterable
47
+ """
48
+ # System modules
49
+ import asyncio
50
+
51
+ async def task(mgr, doc):
52
+ """Process given data using provided `Processor`.
53
+
54
+ :param mgr: The object that will process given data.
55
+ :type mgr: Processor
56
+ :param doc: The data to process.
57
+ :type doc: object
58
+ :return: The processed data.
59
+ :rtype: object
60
+ """
61
+ return mgr.process(doc)
62
+
63
+ async def execute_tasks(mgr, docs):
64
+ """Process given set of documents using provided task
65
+ manager.
66
+
67
+ :param mgr: The object that will process all documents.
68
+ :type mgr: Processor
69
+ :param docs: The set of data documents to process.
70
+ :type docs: iterable
71
+ """
72
+ coroutines = [task(mgr, d) for d in docs]
73
+ await asyncio.gather(*coroutines)
74
+
75
+ asyncio.run(execute_tasks(self.mgr, data))
76
+
77
+
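
For readers unfamiliar with the wrap-and-gather pattern AsyncProcessor uses, the standalone sketch below (a toy Doubler class, not part of CHAP, and plain asyncio) shows the same idea: wrap a synchronous process() call in a coroutine and gather the results. Because process() itself is synchronous, the tasks still execute one after another inside the event loop.

    import asyncio

    class Doubler:
        """Toy stand-in for a CHAP Processor with a synchronous process()."""
        def process(self, doc):
            return 2 * doc

    async def task(mgr, doc):
        # Wrap the synchronous call so it can be scheduled with the others
        return mgr.process(doc)

    async def execute_tasks(mgr, docs):
        return await asyncio.gather(*(task(mgr, d) for d in docs))

    print(asyncio.run(execute_tasks(Doubler(), [1, 2, 3])))  # [2, 4, 6]
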
78
+ class BinarizeProcessor(Processor):
79
+ """A Processor to binarize a dataset."""
80
+ def process(self, data, config=None):
81
+ """Plot and return a binarized dataset from a dataset contained
82
+ in `data`. The dataset must either be `array-like` or a NeXus
83
+ NXobject object with a default plottable data path or a
84
+ specified path to a NeXus NXdata or NXfield object.
85
+
86
+ :param data: Input data.
87
+ :type data: list[PipelineData]
88
+ :param config: Initialization parameters for an instance of
89
+ CHAP.common.models.BinarizeProcessorConfig
90
+ :type config: dict, optional
91
+ :return: The binarized dataset for an `array-like` input or
92
+ a return type equal to that of the input object with the
93
+ binarized dataset added.
94
+ :rtype: Union[numpy.ndarray, nexusformat.nexus.NXobject]
95
+ """
96
+ # Third party modules
97
+ from nexusformat.nexus import (
98
+ NXdata,
99
+ NXfield,
100
+ NXlink,
101
+ nxsetconfig,
102
+ )
103
+
104
+ # Local modules
105
+ from CHAP.utils.general import nxcopy
106
+
107
+ nxsetconfig(memory=100000)
108
+
109
+ # Load the validated binarize processor configuration
110
+ if config is None:
111
+ # Local modules
112
+ from CHAP.common.models.common import BinarizeProcessorConfig
113
+
114
+ config = BinarizeProcessorConfig()
115
+ else:
116
+ config = self.get_config(
117
+ data, config=config,
118
+ schema='common.models.BinarizeProcessorConfig')
119
+
120
+ # Load the default data
121
+ try:
122
+ nxobject = self.get_data(data)
123
+ if config.nxpath is None:
124
+ dataset = nxobject.get_default()
125
+ else:
126
+ dataset = nxobject[config.nxpath]
127
+ if isinstance(dataset, NXdata):
128
+ nxsignal = dataset.nxsignal
129
+ data = nxsignal.nxdata
130
+ else:
131
+ data = dataset.nxdata
132
+ assert isinstance(data, np.ndarray)
133
+ except Exception:
134
+ try:
135
+ dataset = self.unwrap_pipelinedata(data)[-1]
136
+ assert isinstance(dataset, np.ndarray)
137
+ data = dataset
138
+ except Exception as exc:
139
+ raise ValueError(
140
+ 'Unable to load a valid input data object') from exc
141
+
142
+ if config.method == 'yen':
143
+ min_ = data.min()
144
+ max_ = data.max()
145
+ data = 1 + (config.num_bin - 1) * (data - min_) / (max_ - min_)
146
+
147
+ # Get a histogram of the data
148
+ counts, edges = np.histogram(data, bins=config.num_bin)
149
+ centers = edges[:-1] + 0.5 * np.diff(edges)
150
+
151
+ # Calculate the data cutoff threshold
152
+ # pylint: disable=no-name-in-module
153
+ if config.method == 'CHAP':
154
+ weights = np.cumsum(counts)
155
+ means = np.cumsum(counts * centers)
156
+ weights = weights[0:-1] / weights[-1]
157
+ means = means[0:-1] / means[-1]
158
+ variances = (means-weights)**2 / (weights * (1. - weights))
159
+ threshold = centers[np.argmax(variances)]
160
+ elif config.method == 'otsu':
161
+ # Third party modules
162
+ from skimage.filters import threshold_otsu
163
+
164
+ threshold = threshold_otsu(hist=(counts, centers))
165
+ elif config.method == 'yen':
166
+ # Third party modules
167
+ from skimage.filters import threshold_yen
168
+
169
+ threshold = threshold_yen(hist=(counts, centers))
170
+ elif config.method == 'isodata':
171
+ # Third party modules
172
+ from skimage.filters import threshold_isodata
173
+
174
+ threshold = threshold_isodata(hist=(counts, centers))
175
+ else:
176
+ # Third party modules
177
+ from skimage.filters import threshold_minimum
178
+
179
+ threshold = threshold_minimum(hist=(counts, centers))
180
+ # pylint: enable=no-name-in-module
181
+
182
+ # Apply the data cutoff threshold
183
+ data = np.where(data < threshold, 0, 1).astype(np.ubyte)
184
+
185
+ # Return the output for array-like or NeXus NXfield inputs
186
+ if isinstance(dataset, np.ndarray):
187
+ return data
188
+ if isinstance(dataset, NXfield):
189
+ attrs = dataset.attrs
190
+ attrs.pop('target', None)
191
+ nxfield = NXfield(
192
+ value=data, name=f'{dataset.nxname}_binarized', attrs=attrs)
193
+ return nxfield
194
+
195
+ # Otherwise create a copy of the input NeXus, add the binarized
196
+ # data to the copied original dataset, and remove the original
197
+ # dataset if config.remove_original_data is set
198
+ name = f'{nxsignal.nxname}_binarized'
199
+ nxdefault = nxobject.get_default()
200
+ if isinstance(nxsignal, NXlink):
201
+ link = dataset.nxpath
202
+ path = os.path.split(nxsignal.nxtarget)[0]
203
+ else:
204
+ link = nxdefault.nxpath
205
+ path = os.path.split(nxsignal.nxpath)[0]
206
+ exclude_nxpaths = []
207
+ if config.remove_original_data:
208
+ if link is not None:
209
+ exclude_nxpaths.append(os.path.relpath(
210
+ f'{link}/{nxsignal.nxname}', nxobject.nxpath))
211
+ exclude_nxpaths.append(os.path.relpath(
212
+ f'{path}/{nxsignal.nxname}', nxobject.nxpath))
213
+ nxobject = nxcopy(nxobject, exclude_nxpaths=exclude_nxpaths)
214
+ attrs = nxsignal.attrs
215
+ attrs.pop('target', None)
216
+ nxobject[f'{path}/{name}'] = NXfield(
217
+ value=data, name=name, attrs=attrs)
218
+ nxobject[path].attrs['signal'] = name
219
+ if link is not None:
220
+ nxobject[f'{link}/{name}'] = NXlink(f'{path}/{name}')
221
+ nxobject[link].attrs['signal'] = name
222
+
223
+ return nxobject
224
+
225
+
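
As a standalone illustration of the 'CHAP' branch of the thresholding above (an Otsu-style criterion that maximizes the between-class variance over the histogram bin centers), the numpy-only sketch below reproduces the same arithmetic on synthetic bimodal data; the helper name chap_threshold and the test data are ours, not part of the package.

    import numpy as np

    def chap_threshold(data, num_bin=256):
        """Histogram-based threshold following the 'CHAP' method above."""
        counts, edges = np.histogram(data, bins=num_bin)
        centers = edges[:-1] + 0.5 * np.diff(edges)
        weights = np.cumsum(counts)
        means = np.cumsum(counts * centers)
        weights = weights[:-1] / weights[-1]
        means = means[:-1] / means[-1]
        variances = (means - weights)**2 / (weights * (1. - weights))
        return centers[np.argmax(variances)]

    rng = np.random.default_rng(0)
    data = np.concatenate([rng.normal(0., 1., 1000), rng.normal(10., 1., 1000)])
    binarized = np.where(data < chap_threshold(data), 0, 1).astype(np.ubyte)
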
226
+ class ConstructBaseline(Processor):
227
+ """A Processor to construct a baseline for a dataset."""
228
+ def process(
229
+ self, data, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
230
+ save_figures=False):
231
+ """Construct and return the baseline for a dataset.
232
+
233
+ :param data: Input data.
234
+ :type data: list[PipelineData]
235
+ :param x: Independent dimension (only used when running
236
+ interactively or when filename is set).
237
+ :param mask: A mask to apply to the spectrum before baseline
238
+ construction.
239
+ :type mask: array-like, optional
240
+ :param tol: The convergence tolerance, defaults to `1.e-6`.
241
+ :type tol: float, optional
242
+ :param lam: The lambda (smoothness) parameter (the balance
243
+ between the residual of the data and the baseline and the
244
+ smoothness of the baseline). The suggested range is between
245
+ 100 and 10^8, defaults to `10^6`.
246
+ :type lam: float, optional
247
+ :param max_iter: The maximum number of iterations,
248
+ defaults to `20`.
249
+ :type max_iter: int, optional
250
+ :param save_figures: Save .pngs of plots for checking inputs &
251
+ outputs of this Processor, defaults to `False`.
252
+ :type save_figures: bool, optional
253
+ :return: The smoothed baseline, the configuration, and the figure buffer.
254
+ :rtype: numpy.array, dict, Union[io.BytesIO, None]
255
+ """
256
+ try:
257
+ data = np.asarray(self.unwrap_pipelinedata(data)[0])
258
+ except Exception as exc:
259
+ raise ValueError(
260
+ f'The structure of {data} contains no valid data') from exc
261
+
262
+ return self.construct_baseline(
263
+ data, x=x, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
264
+ return_buf=save_figures)
265
+
266
+ @staticmethod
267
+ def construct_baseline(
268
+ y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20, title=None,
269
+ xlabel=None, ylabel=None, interactive=False, return_buf=False):
270
+ """Construct and return the baseline for a dataset.
271
+
272
+ :param y: Input data.
273
+ :type y: numpy.array
274
+ :param x: Independent dimension (only used when interactive is
275
+ `True` or when filename is set).
276
+ :type x: array-like, optional
277
+ :param mask: A mask to apply to the spectrum before baseline
278
+ construction.
279
+ :type mask: array-like, optional
280
+ :param tol: The convergence tolerance, defaults to `1.e-6`.
281
+ :type tol: float, optional
282
+ :param lam: The lambda (smoothness) parameter (the balance
283
+ between the residual of the data and the baseline and the
284
+ smoothness of the baseline). The suggested range is between
285
+ 100 and 10^8, defaults to `10^6`.
286
+ :type lam: float, optional
287
+ :param max_iter: The maximum number of iterations,
288
+ defaults to `20`.
289
+ :type max_iter: int, optional
290
+ :param title: Title for the displayed figure.
291
+ :type title: str, optional
292
+ :param xlabel: Label for the x-axis of the displayed figure.
293
+ :type xlabel: str, optional
294
+ :param ylabel: Label for the y-axis of the displayed figure.
295
+ :type ylabel: str, optional
296
+ :param interactive: Allows for user interactions,
297
+ defaults to `False`.
298
+ :type interactive: bool, optional
299
+ :param return_buf: Return an in-memory object as a byte stream
300
+ representation of the Matplotlib figure, defaults to `False`.
301
+ :type return_buf: bool, optional
302
+ :return: The smoothed baseline and the configuration and a
303
+ byte stream representation of the Matplotlib figure if
304
+ return_buf is `True` (`None` otherwise)
305
+ :rtype: numpy.array, dict, Union[io.BytesIO, None]
306
+ """
307
+ # Third party modules
308
+ from matplotlib.widgets import TextBox, Button
309
+ import matplotlib.pyplot as plt
310
+
311
+ # Local modules
312
+ from CHAP.utils.general import (
313
+ baseline_arPLS,
314
+ fig_to_iobuf,
315
+ )
316
+
317
+ def change_fig_subtitle(maxed_out=False, subtitle=None):
318
+ """Change the figure's subtitle."""
319
+ if fig_subtitles:
320
+ fig_subtitles[0].remove()
321
+ fig_subtitles.pop()
322
+ if subtitle is None:
323
+ subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
324
+ if maxed_out:
325
+ subtitle += f'# iter = {num_iters[-1]} (maxed out) '
326
+ else:
327
+ subtitle += f'# iter = {num_iters[-1]} '
328
+ subtitle += f'error = {errors[-1]:.2e}'
329
+ fig_subtitles.append(
330
+ plt.figtext(*subtitle_pos, subtitle, **subtitle_props))
331
+
332
+ def select_lambda(expression):
333
+ """Callback function for the "Select lambda" TextBox."""
334
+ if not expression:
335
+ return
336
+ try:
337
+ lam = float(expression)
338
+ if lam < 0:
339
+ raise ValueError
340
+ except ValueError:
341
+ change_fig_subtitle(
342
+ subtitle='Invalid lambda, enter a positive number')
343
+ else:
344
+ lambdas.pop()
345
+ lambdas.append(10**lam)
346
+ baseline, _, _, num_iter, error = get_baseline(
347
+ y, mask=mask, tol=tol, lam=lambdas[-1], max_iter=max_iter)
348
+ num_iters.pop()
349
+ num_iters.append(num_iter)
350
+ errors.pop()
351
+ errors.append(error)
352
+ if num_iter < max_iter:
353
+ change_fig_subtitle()
354
+ else:
355
+ change_fig_subtitle(maxed_out=True)
356
+ baseline_handle.set_ydata(baseline)
357
+ lambda_box.set_val('')
358
+ plt.draw()
359
+
360
+ def continue_iter(event):
361
+ """Callback function for the "Continue" button."""
362
+ baseline, _, w, n_iter, error = get_baseline(
363
+ y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
364
+ max_iter=max_iter)
365
+ num_iters[-1] += n_iter
366
+ errors.pop()
367
+ errors.append(error)
368
+ if n_iter < max_iter:
369
+ change_fig_subtitle()
370
+ else:
371
+ change_fig_subtitle(maxed_out=True)
372
+ baseline_handle.set_ydata(baseline)
373
+ plt.draw()
374
+ weights.pop()
375
+ weights.append(w)
376
+
377
+ def confirm(event):
378
+ """Callback function for the "Confirm" button."""
379
+ plt.close()
380
+
381
+ def get_baseline(
382
+ y, mask=None, w=None, tol=1.e-6, lam=1.6, max_iter=20):
383
+ """Get a baseline."""
384
+ return baseline_arPLS(
385
+ y, mask=mask, w=w, tol=tol, lam=lam, max_iter=max_iter,
386
+ full_output=True)
387
+
388
+ baseline, _, w, num_iter, error = get_baseline(
389
+ y, mask=mask, tol=tol, lam=lam, max_iter=max_iter)
390
+
391
+ if not interactive and not return_buf:
392
+ config = {
393
+ 'tol': tol, 'lambda': lam, 'max_iter': max_iter,
394
+ 'num_iter': num_iter, 'error': error, 'mask': mask}
395
+ return baseline, config, None
396
+
397
+ lambdas = [lam]
398
+ weights = [w]
399
+ num_iters = [num_iter]
400
+ errors = [error]
401
+ fig_subtitles = []
402
+
403
+ # Check inputs
404
+ if x is None:
405
+ x = np.arange(y.size)
406
+
407
+ # Setup the Matplotlib figure
408
+ title_pos = (0.5, 0.95)
409
+ title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
410
+ 'verticalalignment': 'bottom'}
411
+ subtitle_pos = (0.5, 0.90)
412
+ subtitle_props = {'fontsize': 'x-large',
413
+ 'horizontalalignment': 'center',
414
+ 'verticalalignment': 'bottom'}
415
+ fig, ax = plt.subplots(figsize=(11, 8.5))
416
+ if mask is None:
417
+ ax.plot(x, y, label='input data')
418
+ else:
419
+ ax.plot(
420
+ x[mask.astype(bool)], y[mask.astype(bool)], label='input data')
421
+ baseline_handle = ax.plot(x, baseline, label='baseline')[0]
422
+ # ax.plot(x, y-baseline, label='baseline corrected data')
423
+ ax.legend()
424
+ ax.set_xlabel(xlabel, fontsize='x-large')
425
+ ax.set_ylabel(ylabel, fontsize='x-large')
426
+ ax.set_xlim(x[0], x[-1])
427
+ if title is None:
428
+ fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
429
+ else:
430
+ fig_title = plt.figtext(*title_pos, title, **title_props)
431
+ if num_iter < max_iter:
432
+ change_fig_subtitle()
433
+ else:
434
+ change_fig_subtitle(maxed_out=True)
435
+ fig.subplots_adjust(bottom=0.0, top=0.85)
436
+
437
+ lambda_box = None
438
+ if interactive:
439
+
440
+ fig.subplots_adjust(bottom=0.2)
441
+
442
+ # Setup TextBox
443
+ lambda_box = TextBox(
444
+ plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
445
+ lambda_cid = lambda_box.on_submit(select_lambda)
446
+
447
+ # Setup "Continue" button
448
+ continue_btn = Button(
449
+ plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
450
+ continue_cid = continue_btn.on_clicked(continue_iter)
451
+
452
+ # Setup "Confirm" button
453
+ confirm_btn = Button(
454
+ plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
455
+ confirm_cid = confirm_btn.on_clicked(confirm)
456
+
457
+ # Show figure for user interaction
458
+ plt.show()
459
+
460
+ # Disconnect all widget callbacks when figure is closed
461
+ lambda_box.disconnect(lambda_cid)
462
+ continue_btn.disconnect(continue_cid)
463
+ confirm_btn.disconnect(confirm_cid)
464
+
465
+ # ... and remove the buttons before returning the figure
466
+ lambda_box.ax.remove()
467
+ continue_btn.ax.remove()
468
+ confirm_btn.ax.remove()
469
+
470
+ if return_buf:
471
+ fig_title.set_in_layout(True)
472
+ fig_subtitles[-1].set_in_layout(True)
473
+ fig.tight_layout(rect=(0, 0, 1, 0.90))
474
+ buf = fig_to_iobuf(fig)
475
+ else:
476
+ buf = None
477
+ plt.close()
478
+
479
+ config = {
480
+ 'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
481
+ 'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
482
+ return baseline, config, buf
483
+
484
+
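
For reference, a minimal non-interactive call of the static helper above might look like the sketch below (assuming the CHAP package and its matplotlib/scipy dependencies are installed; the synthetic peak-on-slope data is ours).

    import numpy as np
    from CHAP.common.processor import ConstructBaseline

    x = np.linspace(0., 10., 500)
    y = np.exp(-0.5 * (x - 5.)**2 / 0.1) + 0.1 * x   # peak on a sloping background
    # Returns the smoothed baseline, the configuration used, and None
    # (no figure buffer), since interactive and return_buf are both False.
    baseline, config, buf = ConstructBaseline.construct_baseline(
        y, x=x, lam=1.e6, max_iter=20)
    corrected = y - baseline
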
485
+ class ConvertStructuredProcessor(Processor):
486
+ """Processor for converting map data between structured /
487
+ unstructured formats.
488
+ """
489
+ def process(self, data):
490
+ # Local modules
491
+ from CHAP.utils.converters import convert_structured_unstructured
492
+
493
+ data = self.unwrap_pipelinedata(data)[0]
494
+ return convert_structured_unstructured(data)
495
+
496
+
497
+ class ImageProcessor(Processor):
498
+ """A Processor to perform various visualization operations on
499
+ images (slices) selected from a NeXus object."""
500
+ def __init__(self):
501
+ super().__init__()
502
+ self._figconfig = None
503
+
504
+ def process(self, data, config=None, save_figures=True):
505
+ """Plot and/or return image slices from a NeXus NXobject
506
+ object with a default plottable data path.
507
+
508
+ :param data: Input data.
509
+ :type data: list[PipelineData]
510
+ :param config: Initialization parameters for an instance of
511
+ CHAP.common.models.ImageProcessorConfig
512
+ :type config: dict, optional
513
+ :param save_figures: Return the plottable image(s) to be
514
+ written to file downstream in the pipeline,
515
+ defaults to `True`.
516
+ :type save_figures: bool, optional
517
+ :return: The plottable image(s) (for save_figures = `True`)
518
+ or the input default NeXus NXdata object
519
+ (for save_figures = `False`).
520
+ :rtype: Union[bytes, nexusformat.nexus.NXdata, numpy.ndarray]
521
+ """
522
+ if not save_figures and not self.interactive:
523
+ return None
524
+
525
+ # Third party modules
526
+ from nexusformat.nexus import nxsetconfig
527
+
528
+ nxsetconfig(memory=100000)
529
+
530
+ # Load the default data
531
+ try:
532
+ nxdata = self.get_data(data).get_default()
533
+ except Exception as exc:
534
+ raise ValueError(
535
+ 'Unable to load the default NXdata object from the input '
536
+ f'pipeline ({data})') from exc
537
+
538
+ # Load the validated image processor configuration
539
+ if config is None:
540
+ # Local modules
541
+ from CHAP.common.models.common import ImageProcessorConfig
542
+
543
+ config = ImageProcessorConfig()
544
+ else:
545
+ config = self.get_config(
546
+ data, config=config,
547
+ schema='common.models.ImageProcessorConfig')
548
+
549
+ # Get the axes info and image slice(s)
550
+ try:
551
+ data = nxdata.nxsignal
552
+ except Exception as exc:
553
+ raise ValueError('Unable to find the default signal in:\n'
554
+ f'({nxdata.tree})') from exc
555
+ axis = config.axis
556
+ axes = nxdata.attrs.get('axes', None)
557
+ if axes is not None:
558
+ axes = list(axes.nxdata)
559
+ if nxdata.nxsignal.ndim == 2:
560
+ exit('ImageProcessor not tested yet for a 2D dataset')
561
+ if axis is not None:
562
+ axis = None
563
+ self.logger.warning('Ignoring parameter axis')
564
+ if index is not None:
565
+ index = None
566
+ self.logger.warning('Ignoring parameter index')
567
+ if coord is not None:
568
+ coord = None
569
+ self.logger.warning('Ignoring parameter coord')
570
+ elif nxdata.nxsignal.ndim == 3:
571
+ if isinstance(axis, int):
572
+ if not 0 <= axis < nxdata.nxsignal.ndim:
573
+ raise ValueError(f'axis index out of range ({axis} not in '
574
+ f'[0, {nxdata.nxsignal.ndim-1}])')
575
+ elif isinstance(axis, str):
576
+ if axes is None or axis not in axes:
577
+ raise ValueError(
578
+ f'Unable to match axis = {axis} in {nxdata.tree}')
579
+ axis = axes.index(axis)
580
+ else:
581
+ raise ValueError(f'Invalid parameter axis ({axis})')
582
+ if axis:
583
+ data = np.moveaxis(data, axis, 0)
584
+ if axes is not None and hasattr(nxdata, axes[axis]):
585
+ if axis == 1:
586
+ axes = [axes[1], axes[0], axes[2]]
587
+ elif axis:
588
+ axes = [axes[2], axes[0], axes[1]]
589
+ axis_name = axes[0]
590
+ if 'units' in nxdata[axis_name].attrs:
591
+ axis_unit = f' ({nxdata[axis_name].units})'
592
+ else:
593
+ axis_unit = ''
594
+ row_label = axes[2]
595
+ row_coords = nxdata[row_label].nxdata
596
+ column_label = axes[1]
597
+ column_coords = nxdata[column_label].nxdata
598
+ if 'units' in nxdata[row_label].attrs:
599
+ row_label += f' ({nxdata[row_label].units})'
600
+ if 'units' in nxdata[column_label].attrs:
601
+ column_label += f' ({nxdata[column_label].units})'
602
+ else:
603
+ exit('Case with no axes attribute not tested yet')
604
+ axes = [0, 1, 2]
605
+ axes.pop(axis)
606
+ axis_name = f'axis {axis}'
607
+ axis_unit = ''
608
+ # row_label = f'axis {axis[1]}'
609
+ # row_coords = None
610
+ # column_label = f'axis {axis[0]}'
611
+ # column_coords = None
612
+ axis_coords = nxdata[axis_name].nxdata
613
+ else:
614
+ raise ValueError('Invalid data dimension (must be 2D or 3D)')
615
+ if config.coord_range is None:
616
+ index_range = config.index_range
617
+ else:
618
+ # Local modules
619
+ from CHAP.utils.general import (
620
+ index_nearest_down,
621
+ index_nearest_up,
622
+ )
623
+
624
+ if config.index_range is not None:
625
+ self.logger.warning('Ignoring parameter index_range')
626
+ if isinstance(config.coord_range, (int, float)):
627
+ index_range = index_nearest_up(
628
+ axis_coords, config.coord_range)
629
+ elif len(config.coord_range) == 2:
630
+ index_range = [
631
+ index_nearest_up(axis_coords, config.coord_range[0]),
632
+ index_nearest_down(axis_coords, config.coord_range[1])]
633
+ else:
634
+ index_range = [
635
+ index_nearest_up(axis_coords, config.coord_range[0]),
636
+ index_nearest_down(axis_coords, config.coord_range[1]),
637
+ int(max(1, config.coord_range[2] /
638
+ ((axis_coords[-1]-axis_coords[0])/data.shape[0])))]
639
+ if index_range == -1:
640
+ index_range = nxdata.nxsignal.shape[axis] // 2
641
+ if isinstance(index_range, int):
642
+ data = data[index_range]
643
+ axis_coords = [axis_coords[index_range]]
644
+ elif index_range is not None:
645
+ slice_ = slice(*tuple(index_range))
646
+ data = data[slice_]
647
+ axis_coords = axis_coords[slice_]
648
+ if config.vrange is None:
649
+ vrange = (data.min(), data.max())
650
+ else:
651
+ vrange = config.vrange
652
+
653
+ # Create the figure configuration
654
+ self._figconfig = {
655
+ 'title': f'{nxdata.nxpath}/{nxdata.signal}',
656
+ 'axis_name': axis_name,
657
+ 'axis_unit': axis_unit,
658
+ 'axis_coords': axis_coords,
659
+ 'row_label': row_label,
660
+ 'column_label': column_label,
661
+ 'extent': (row_coords[0], row_coords[-1],
662
+ column_coords[-1], column_coords[0]),
663
+ 'vrange': vrange,
664
+ }
665
+ self.logger.debug(f'figure configuration:\n{self._figconfig}')
666
+
667
+ if len(axis_coords) == 1:
668
+ # Create a figure for a single image slice
669
+ if config.animation:
670
+ self.logger.warning(
671
+ 'Ignoring animation parameter for a single image')
672
+ fileformat = 'png'
673
+ if config.fileformat is None:
674
+ fileformat = 'png'
675
+ else:
676
+ fileformat = config.fileformat
677
+ fig, plt = self._create_figure(np.squeeze(data))
678
+ if self.interactive:
679
+ plt.show()
680
+ if save_figures:
681
+ # Local modules
682
+ from CHAP.utils.general import fig_to_iobuf
683
+
684
+ # Return a binary image of the figure
685
+ buf, fileformat = fig_to_iobuf(fig, fileformat=fileformat)
686
+ else:
687
+ buf = None
688
+ plt.close()
689
+ if save_figures:
690
+ return {'image_data': buf, 'fileformat': fileformat}
691
+ return nxdata
692
+
693
+ # Create an animation for a set of image slices
694
+ if self.interactive or config.animation:
695
+ ani = self._create_animation(data)
696
+ else:
697
+ ani = None
698
+
699
+ if save_figures:
700
+ if config.animation:
701
+ # Return the animation object
702
+ if (config.fileformat is not None
703
+ and config.fileformat != 'gif'):
704
+ self.logger.warning(
705
+ 'Ignoring inconsistent file extension')
706
+ fileformat = 'gif'
707
+ image_data = ani
708
+ else:
709
+ # Return the set of image slices as a tif stack
710
+ if (config.fileformat is not None
711
+ and config.fileformat != 'tif'):
712
+ self.logger.warning(
713
+ 'Ignoring inconsistent file extension')
714
+ fileformat = 'tif'
715
+ data = 255.0*((data - vrange[0])/
716
+ (vrange[1] - vrange[0]))
717
+ image_data = data.astype(np.uint8)
718
+ return {'image_data': image_data, 'fileformat': fileformat}
719
+ return nxdata
720
+
721
+ def _create_animation(self, data):
722
+ # Third party modules
723
+ from functools import partial
724
+ from matplotlib import animation
725
+
726
+ def animate(i, plt, title):
727
+ im.set_array(data[i])
728
+ title.set_text(self._set_title(i))
729
+ plt.draw()
730
+ return im,
731
+
732
+ fig, im, plt, title = self._create_figure(data[0], animated=True)
733
+ ani = animation.FuncAnimation(
734
+ fig, partial(animate, plt=plt, title=title), frames=data.shape[0],
735
+ interval=50, blit=True)
736
+ if self.interactive:
737
+ plt.show()
738
+ plt.close()
739
+
740
+ return ani
741
+
742
+ def _create_figure(self, image, animated=False):
743
+ # Third party modules
744
+ import matplotlib.pyplot as plt
745
+
746
+ fig, ax = plt.subplots()
747
+ im = plt.imshow(
748
+ image, extent=self._figconfig['extent'], origin='lower',
749
+ vmin=self._figconfig['vrange'][0],
750
+ vmax=self._figconfig['vrange'][1], cmap='gray', animated=animated)
751
+ fig.suptitle(self._figconfig['title'], fontsize='x-large')
752
+ title = ax.set_title(self._set_title(0), fontsize='x-large', pad=10)
753
+ ax.set_xlabel(self._figconfig['row_label'], fontsize='x-large')
754
+ ax.set_ylabel(self._figconfig['column_label'], fontsize='x-large')
755
+ plt.colorbar()
756
+ fig.tight_layout()
757
+ if animated:
758
+ return fig, im, plt, title
759
+ return fig, plt
760
+
761
+ def _set_title(self, i):
762
+ return self._figconfig['axis_name'] +\
763
+ f' = {self._figconfig["axis_coords"][i]:.3f}' +\
764
+ self._figconfig['axis_unit']
765
+
766
+
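
The tif-stack branch above maps intensities in vrange onto 0-255 and casts to 8-bit; a small standalone sketch of that same rescale follows (the function name to_uint8 and the random stack are ours).

    import numpy as np

    def to_uint8(data, vrange=None):
        """Map values in [vmin, vmax] onto [0, 255] and cast to uint8."""
        vmin, vmax = vrange if vrange is not None else (data.min(), data.max())
        return (255.0 * (data - vmin) / (vmax - vmin)).astype(np.uint8)

    stack = to_uint8(np.random.default_rng(0).random((5, 64, 64)))
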
767
+ class MapProcessor(Processor):
768
+ """A Processor that takes a map configuration and returns a NeXus
769
+ NXentry object representing that map's metadata and any
770
+ scalar-valued raw data requested by the supplied map configuration.
771
+
772
+ :ivar config: Map configuration parameters to initialize an
773
+ instance of common.models.map.MapConfig. Any values in
774
+ `'config'` supplant their corresponding values obtained from
775
+ the pipeline data configuration.
776
+ :type config: Union[dict, common.models.map.MapConfig]
777
+ :ivar detector_config: Detector configurations of the detectors to
778
+ include raw data for in the returned NeXus NXentry object
779
+ (overruling detector info in the pipeline data, if present).
780
+ :type detector_config: Union[
781
+ dict, common.models.map.DetectorConfig]
782
+ :ivar num_proc: Number of processors used to read map,
783
+ defaults to `1`.
784
+ :type num_proc: int, optional
785
+ """
786
+ pipeline_fields: dict = Field(
787
+ default = {
788
+ 'config': 'common.models.map.MapConfig',
789
+ 'detector_config': 'common.models.map.DetectorConfig'},
790
+ init_var=True)
791
+ config: MapConfig
792
+ detector_config: DetectorConfig
793
+ num_proc: Optional[conint(gt=0)] = 1
794
+
795
+ @field_validator('num_proc')
796
+ @classmethod
797
+ def validate_num_proc(cls, num_proc, info):
798
+ """Validate the number of processors.
799
+
800
+ :param num_proc: Number of processors used to read map,
801
+ defaults to `1`.
802
+ :type num_proc: int, optional
803
+ :return: Number of processors
804
+ :rtype: int
805
+ """
806
+ if num_proc > 1:
807
+ logger = info['logger']
808
+ try:
809
+ # Third party modules
810
+ from mpi4py import MPI
811
+
812
+ if num_proc > os.cpu_count():
813
+ logger.warning(
814
+ f'The requested number of processors ({num_proc}) '
815
+ 'exceeds the maximum number of processors '
816
+ f'({os.cpu_count()}): reset it to {os.cpu_count()}')
817
+ num_proc = os.cpu_count()
818
+ except Exception:
819
+ logger.warning('Unable to load mpi4py, running serially')
820
+ num_proc = 1
821
+ logger.debug(f'Number of processors: {num_proc}')
822
+ return num_proc
823
+
824
+ def process(
825
+ self, data, placeholder_data=False, comm=None):
826
+
827
+ return self._process(data, placeholder_data, comm)
828
+
829
+ # @profile
830
+ def _process(self, data, placeholder_data=False, comm=None):
831
+ """Process that takes a map configuration and returns a NeXus
832
+ NXentry object representing the map.
833
+
834
+ :param data: Pipeline data list with an optional item for the
835
+ map configuration parameters with
836
+ `'common.models.map.MapConfig'` as its `'schema'` key.
837
+ :type data: list[PipelineData]
838
+ :param placeholder_data: For SMB EDD maps only. Value to use
839
+ for missing detector data frames, or `False` if missing
840
+ data should raise an error, defaults to `False`.
841
+ :type placeholder_data: object, optional
842
+ :param comm: MPI communicator.
843
+ :type comm: mpi4py.MPI.Comm, optional
844
+ :return: Map data and metadata.
845
+ :rtype: nexusformat.nexus.NXentry
846
+ """
847
+ # System modules
848
+ import logging
849
+
850
+ # Third party modules
851
+ import yaml
852
+
853
+ # Check for available metadata
854
+ metadata = {}
855
+ if data:
856
+ try:
857
+ for d in data:
858
+ if d.get('schema') == 'metadata':
859
+ metadata = d.get('data')
860
+ break
861
+ except Exception:
862
+ pass
863
+ if len(metadata) > 1:
864
+ raise ValueError(
865
+ f'Unable to find unique data for schema "metadata"')
866
+ if metadata:
867
+ metadata = self._get_metadata_config(metadata[0])
868
+
869
+ # Create the sub-pipeline configuration for each processor
870
+ # FIX: catered to EDD with one spec scan
871
+ assert len(self.config.spec_scans) == 1
872
+ spec_scans = self.config.spec_scans[0]
873
+ scan_numbers = spec_scans.scan_numbers
874
+ num_scan = len(scan_numbers)
875
+ if num_scan < self.num_proc:
876
+ self.logger.warning(
877
+ f'Requested number of processors ({self.num_proc}) exceeds '
878
+ f'the number of scans ({num_scan}): reset it to {num_scan}')
879
+ self.num_proc = num_scan
880
+ if self.num_proc == 1:
881
+ common_comm = comm
882
+ offsets = [0]
883
+ else:
884
+ # System modules
885
+ from tempfile import NamedTemporaryFile
886
+
887
+ # Local modules
888
+ from CHAP.models import RunConfig
889
+
890
+ raise NotImplementedError(
891
+ 'MapProcessor needs testing for num_proc>1')
892
+ scans_per_proc = num_scan//self.num_proc
893
+ num = scans_per_proc
894
+ if num_scan - scans_per_proc*self.num_proc > 0:
895
+ num += 1
896
+ spec_scans.scan_numbers = scan_numbers[:num]
897
+ n_scan = num
898
+ pipeline_config = []
899
+ offsets = [0]
900
+ for n_proc in range(1, self.num_proc):
901
+ num = scans_per_proc
902
+ if n_proc < num_scan - scans_per_proc*self.num_proc:
903
+ num += 1
904
+ config = self.config.model_dump()
905
+ config['spec_scans'][0]['scan_numbers'] = \
906
+ scan_numbers[n_scan:n_scan+num]
907
+ pipeline_config.append(
908
+ [{'common.MapProcessor': {
909
+ 'config': config,
910
+ 'detector_config': self.detector_config.model_dump(),
911
+ }}])
912
+ offsets.append(n_scan)
913
+ n_scan += num
914
+
915
+ # Spawn the workers to run the sub-pipeline
916
+ run_config = RunConfig(
917
+ log_level=logging.getLevelName(self.logger.level), spawn=1)
918
+ tmp_names = []
919
+ with NamedTemporaryFile(delete=False) as fp:
920
+ # pylint: disable=c-extension-no-member
921
+ fp_name = fp.name
922
+ tmp_names.append(fp_name)
923
+ with open(fp_name, 'w') as f:
924
+ yaml.dump({'config': {'spawn': 1}}, f, sort_keys=False)
925
+ for n_proc in range(1, self.num_proc):
926
+ f_name = f'{fp_name}_{n_proc}'
927
+ tmp_names.append(f_name)
928
+ with open(f_name, 'w') as f:
929
+ yaml.dump(
930
+ # FIX once comm is a field of RunConfig
931
+ # {'config': run_config.model_dump(exclude='comm'),
932
+ {'config': run_config.model_dump(),
933
+ 'pipeline': pipeline_config[n_proc-1]},
934
+ f, sort_keys=False)
935
+ # pylint: disable=used-before-assignment
936
+ sub_comm = MPI.COMM_SELF.Spawn(
937
+ 'CHAP', args=[fp_name], maxprocs=self.num_proc-1)
938
+ common_comm = sub_comm.Merge(False)
939
+ # Align with the barrier in RunConfig() on common_comm
940
+ # called from the spawned main() in common_comm
941
+ common_comm.barrier()
942
+ # Align with the barrier in run() on common_comm
943
+ # called from the spawned main()
944
+ common_comm.barrier()
945
+
946
+ if common_comm is None:
947
+ self.num_proc = 1
948
+ rank = 0
949
+ else:
950
+ self.num_proc = common_comm.Get_size()
951
+ rank = common_comm.Get_rank()
952
+ if self.num_proc == 1:
953
+ offset = 0
954
+ else:
955
+ num_scan = common_comm.bcast(num_scan, root=0)
956
+ offset = common_comm.scatter(offsets, root=0)
957
+
958
+ # Read the raw data
959
+ if self.config.experiment_type == 'EDD':
960
+ data, independent_dimensions, all_scalar_data = \
961
+ self._read_raw_data_edd(
962
+ common_comm, num_scan, offset, placeholder_data)
963
+ else:
964
+ data, independent_dimensions, all_scalar_data = \
965
+ self._read_raw_data(common_comm, num_scan, offset)
966
+ if not rank:
967
+ self.logger.debug(f'Data shape: {data.shape}')
968
+ if independent_dimensions is not None:
969
+ self.logger.debug('Independent dimensions shape: '
970
+ f'{independent_dimensions.shape}')
971
+ if all_scalar_data is not None:
972
+ self.logger.debug('Scalar data shape: '
973
+ f'{all_scalar_data.shape}')
974
+
975
+ if rank:
976
+ return None
977
+
978
+ if self.num_proc > 1:
979
+ # Reset the scan_numbers to the original full set
980
+ spec_scans.scan_numbers = scan_numbers
981
+ # Align with the barrier in main() on common_comm
982
+ # when disconnecting the spawned worker
983
+ common_comm.barrier()
984
+ # Disconnect spawned workers and cleanup temporary files
985
+ sub_comm.Disconnect()
986
+ for tmp_name in tmp_names:
987
+ os.remove(tmp_name)
988
+
989
+ # Construct and return the NeXus NXroot object
990
+ return self._get_nxroot(
991
+ data, independent_dimensions, all_scalar_data, placeholder_data)
992
+
993
+ def _get_metadata_config(self, metadata):
994
+ """Get experiment specific configurational data from the
995
+ FOXDEN metadata record
996
+
997
+ :param metadata: FOXDEN metadata record.
998
+ :type metadata: dict
999
+ :return: Experiment specific configurational data.
1000
+ :rtype: dict
1001
+ """
1002
+ config = {'did': metadata.get('did')}
1003
+ experiment_type = metadata.get('technique')
1004
+ if 'tomography' in experiment_type:
1005
+ config['title'] = metadata.get('sample_name')
1006
+ station = metadata.get('beamline')[0]
1007
+ if station == '3A':
1008
+ station = 'id3a'
1009
+ else:
1010
+ raise ValueError(f'Invalid beamline parameter ({station})')
1011
+ config['station'] = station
1012
+ config['experiment_type'] = 'TOMO'
1013
+ config['sample'] = {'name': config['title'],
1014
+ 'description': metadata.get('description')}
1015
+ if station == 'id3a':
1016
+ config['spec_file'] = os.path.join(
1017
+ metadata.get('data_location_raw'), 'spec.log')
1018
+ else:
1019
+ raise ValueError(
1020
+ f'Experiment type {experiment_type} not implemented yet')
1021
+ return config
1022
+
1023
+ def _get_nxroot(
1024
+ self, data, independent_dimensions, all_scalar_data,
1025
+ placeholder_data):
1026
+ """Use a `MapConfig` to construct a NeXus NXroot object.
1027
+
1028
+ :param data: The map's raw data.
1029
+ :type data: numpy.ndarray
1030
+ :param independent_dimensions: The map's independent
1031
+ coordinates.
1032
+ :type independent_dimensions: numpy.ndarray
1033
+ :param all_scalar_data: The map's scalar data.
1034
+ :type all_scalar_data: numpy.ndarray
1035
+ :param placeholder_data: For SMB EDD maps only. Value to use
1036
+ for missing detector data frames, or `False` if missing
1037
+ data should raise an error.
1038
+ :type placeholder_data: object
1039
+ :return: The map's data and metadata contained in a NeXus
1040
+ structure.
1041
+ :rtype: nexusformat.nexus.NXroot
1042
+ """
1043
+ # Third party modules
1044
+ # pylint: disable=no-name-in-module
1045
+ from nexusformat.nexus import (
1046
+ NXcollection,
1047
+ NXdata,
1048
+ NXentry,
1049
+ NXfield,
1050
+ NXlinkfield,
1051
+ NXsample,
1052
+ NXroot,
1053
+ )
1054
+ # pylint: enable=no-name-in-module
1055
+
1056
+ # Local modules:
1057
+ from CHAP.common.models.map import PointByPointScanData
1058
+
1059
+ def linkdims(nxgroup, nxdata_source):
1060
+ """Link the dimensions for an NXgroup."""
1061
+ source_axes = [k for k in nxdata_source.keys()]
1062
+ if isinstance(source_axes, str):
1063
+ source_axes = [source_axes]
1064
+ axes = []
1065
+ for dim in source_axes:
1066
+ axes.append(dim)
1067
+ if isinstance(nxdata_source[dim], NXlinkfield):
1068
+ nxgroup[dim] = nxdata_source[dim]
1069
+ else:
1070
+ nxgroup.makelink(nxdata_source[dim])
1071
+ if f'{dim}_indices' in nxdata_source.attrs:
1072
+ nxgroup.attrs[f'{dim}_indices'] = \
1073
+ nxdata_source.attrs[f'{dim}_indices']
1074
+ if len(axes) == 1:
1075
+ nxgroup.attrs['axes'] = axes
1076
+ else:
1077
+ nxgroup.attrs['unstructured_axes'] = axes
1078
+
1079
+ # Set up NeXus NXroot/NXentry and add CHESS-specific metadata
1080
+ nxroot = NXroot()
1081
+ nxentry = NXentry(name=self.config.title)
1082
+ nxroot[nxentry.nxname] = nxentry
1083
+ nxentry.set_default()
1084
+ nxentry.map_config = self.config.model_dump_json()
1085
+ nxentry.attrs['station'] = self.config.station
1086
+ for k, v in self.config.attrs.items():
1087
+ nxentry.attrs[k] = v
1088
+ nxentry.spec_scans = NXcollection()
1089
+ for scans in self.config.spec_scans:
1090
+ nxentry.spec_scans[scans.scanparsers[0].scan_name] = \
1091
+ NXfield(value=scans.scan_numbers,
1092
+ dtype='int8',
1093
+ attrs={'spec_file': str(scans.spec_file)})
1094
+
1095
+ # Add sample metadata
1096
+ nxentry[self.config.sample.name] = NXsample(
1097
+ **self.config.sample.model_dump())
1098
+
1099
+ # Set up independent dimensions NeXus NXdata group
1100
+ # (squeeze out constant dimensions)
1101
+ constant_dim = []
1102
+ for i, dim in enumerate(self.config.independent_dimensions):
1103
+ unique = np.unique(independent_dimensions[i])
1104
+ if unique.size == 1:
1105
+ constant_dim.append(i)
1106
+ nxentry.independent_dimensions = NXdata()
1107
+ for i, dim in enumerate(self.config.independent_dimensions):
1108
+ if i not in constant_dim:
1109
+ nxentry.independent_dimensions[dim.label] = NXfield(
1110
+ independent_dimensions[i], dim.label,
1111
+ attrs={'units': dim.units,
1112
+ 'long_name': f'{dim.label} ({dim.units})',
1113
+ 'data_type': dim.data_type,
1114
+ 'local_name': dim.name})
1115
+
1116
+ # Set up scalar data NeXus NXdata group
1117
+ # (add the constant independent dimensions)
1118
+ if all_scalar_data is not None:
1119
+ self.logger.debug(
1120
+ f'all_scalar_data.shape = {all_scalar_data.shape}\n\n')
1121
+ scalar_signals = []
1122
+ scalar_data = []
1123
+ for i, dim in enumerate(self.config.all_scalar_data):
1124
+ scalar_signals.append(dim.label)
1125
+ scalar_data.append(NXfield(
1126
+ value=all_scalar_data[i],
1127
+ units=dim.units,
1128
+ attrs={'long_name': f'{dim.label} ({dim.units})',
1129
+ 'data_type': dim.data_type,
1130
+ 'local_name': dim.name}))
1131
+ if (self.config.experiment_type == 'EDD'
1132
+ and not placeholder_data is False):
1133
+ scalar_signals.append('placeholder_data_used')
1134
+ scalar_data.append(NXfield(
1135
+ value=all_scalar_data[-1],
1136
+ attrs={'description':
1137
+ 'Indicates whether placeholder data may be present for '
1138
+ 'the corresponding frames of detector data.'}))
1139
+ for i, dim in enumerate(deepcopy(self.config.independent_dimensions)):
1140
+ if i in constant_dim:
1141
+ scalar_signals.append(dim.label)
1142
+ scalar_data.append(NXfield(
1143
+ independent_dimensions[i], dim.label,
1144
+ attrs={'units': dim.units,
1145
+ 'long_name': f'{dim.label} ({dim.units})',
1146
+ 'data_type': dim.data_type,
1147
+ 'local_name': dim.name}))
1148
+ self.config.all_scalar_data.append(
1149
+ PointByPointScanData(**dim.model_dump()))
1150
+ self.config.independent_dimensions.remove(dim)
1151
+ if scalar_signals:
1152
+ nxentry.scalar_data = NXdata()
1153
+ for k, v in zip(scalar_signals, scalar_data):
1154
+ nxentry.scalar_data[k] = v
1155
+ if 'SCAN_N' in scalar_signals:
1156
+ nxentry.scalar_data.attrs['signal'] = 'SCAN_N'
1157
+ else:
1158
+ nxentry.scalar_data.attrs['signal'] = scalar_signals[0]
1159
+ scalar_signals.remove(nxentry.scalar_data.attrs['signal'])
1160
+ nxentry.scalar_data.attrs['auxiliary_signals'] = scalar_signals
1161
+
1162
+ # Add detector data
1163
+ nxdata = NXdata()
1164
+ nxentry.data = nxdata
1165
+ nxentry.data.set_default()
1166
+ detector_ids = []
1167
+ for k, v in self.config.attrs.items():
1168
+ nxdata.attrs[k] = v
1169
+ min_ = np.min(data, axis=tuple(range(1, data.ndim)))
1170
+ max_ = np.max(data, axis=tuple(range(1, data.ndim)))
1171
+ for i, detector in enumerate(self.detector_config.detectors):
1172
+ nxdata[detector.get_id()] = NXfield(
1173
+ value=data[i],
1174
+ attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]})
1175
+ detector_ids.append(detector.get_id())
1176
+ linkdims(nxdata, nxentry.independent_dimensions)
1177
+ if len(self.detector_config.detectors) == 1:
1178
+ nxdata.attrs['signal'] = self.detector_config.detectors[0].get_id()
1179
+ nxentry.detector_ids = detector_ids
1180
+
1181
+ return nxroot
1182
+
1183
+ def _read_raw_data_edd(
1184
+ self, comm, num_scan, offset, placeholder_data):
1185
+ """Read the raw EDD data for a given map configuration.
1186
+
1187
+ :param comm: MPI communicator.
1188
+ :type comm: mpi4py.MPI.Comm, optional
1189
+ :param num_scan: Number of scans in the map.
1190
+ :type num_scan: int
1191
+ :param offset: Offset scan number of current processor.
1192
+ :type offset: int
1193
+ :param placeholder_data: Value to use for missing detector
1194
+ data frames, or `False` if missing data should raise an
1195
+ error.
1196
+ :type placeholder_data: object
1197
+ :return: The map's raw data, independent dimensions and scalar
1198
+ data.
1199
+ :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
1200
+ """
1201
+ # Third party modules
1202
+ try:
1203
+ from mpi4py import MPI
1204
+ from mpi4py.util import dtlib
1205
+ except Exception:
1206
+ pass
1207
+
1208
+ # Local modules
1209
+ from CHAP.utils.general import list_to_string
1210
+
1211
+ if comm is None:
1212
+ self.num_proc = 1
1213
+ rank = 0
1214
+ else:
1215
+ self.num_proc = comm.Get_size()
1216
+ rank = comm.Get_rank()
1217
+ if not rank:
1218
+ self.logger.debug(f'Number of processors: {self.num_proc}')
1219
+ self.logger.debug(f'Number of scans: {num_scan}')
1220
+
1221
+ # Create the shared data buffers
1222
+ # FIX: just one spec scan at this point
1223
+ assert len(self.config.spec_scans) == 1
1224
+ scan = self.config.spec_scans[0]
1225
+ scan_numbers = scan.scan_numbers
1226
+ scanparser = scan.get_scanparser(scan_numbers[0])
1227
+ detector_ids = [
1228
+ int(d.get_id()) for d in self.detector_config.detectors]
1229
+ ddata, placeholder_used = scanparser.get_detector_data(
1230
+ detector_ids, placeholder_data=placeholder_data)
1231
+ spec_scan_shape = scanparser.spec_scan_shape
1232
+ num_dim = np.prod(spec_scan_shape)
1233
+ num_id = len(self.config.independent_dimensions)
1234
+ num_sd = len(self.config.all_scalar_data)
1235
+ if placeholder_data is not False:
1236
+ num_sd += 1
1237
+ if self.num_proc == 1:
1238
+ assert num_scan == len(scan_numbers)
1239
+ data = np.empty((num_scan, *ddata.shape), dtype=ddata.dtype)
1240
+ independent_dimensions = np.empty(
1241
+ (num_id, num_scan*num_dim), dtype=np.float64)
1242
+ all_scalar_data = np.empty(
1243
+ (num_sd, num_scan*num_dim), dtype=np.float64)
1244
+ else:
1245
+ self.logger.debug(f'Scan offset on processor {rank}: {offset}')
1246
+ self.logger.debug(f'Scan numbers on processor {rank}: '
1247
+ f'{list_to_string(scan_numbers)}')
1248
+ datatype = dtlib.from_numpy_dtype(ddata.dtype)
1249
+ itemsize = datatype.Get_size()
1250
+ if not rank:
1251
+ nbytes = num_scan * np.prod(ddata.shape) * itemsize
1252
+ else:
1253
+ nbytes = 0
1254
+ win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1255
+ buf, itemsize = win.Shared_query(0)
1256
+ assert itemsize == datatype.Get_size()
1257
+ data = np.ndarray(
1258
+ buffer=buf, dtype=ddata.dtype, shape=(num_scan, *ddata.shape))
1259
+ datatype = dtlib.from_numpy_dtype(np.float64)
1260
+ itemsize = datatype.Get_size()
1261
+ if not rank:
1262
+ nbytes = num_id * num_scan * num_dim * itemsize
1263
+ win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1264
+ buf_id, _ = win_id.Shared_query(0)
1265
+ independent_dimensions = np.ndarray(
1266
+ buffer=buf_id, dtype=np.float64,
1267
+ shape=(num_id, num_scan*num_dim))
1268
+ if not rank:
1269
+ nbytes = num_sd * num_scan * num_dim * itemsize
1270
+ win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1271
+ buf_sd, _ = win_sd.Shared_query(0)
1272
+ all_scalar_data = np.ndarray(
1273
+ buffer=buf_sd, dtype=np.float64,
1274
+ shape=(num_sd, num_scan*num_dim))
1275
+
1276
+ # Read the raw data
1277
+ init = True
1278
+ for scan in self.config.spec_scans:
1279
+ for scan_number in scan.scan_numbers:
1280
+ if init:
1281
+ init = False
1282
+ else:
1283
+ scanparser = scan.get_scanparser(scan_number)
1284
+ assert spec_scan_shape == scanparser.spec_scan_shape
1285
+ ddata, placeholder_used = scanparser.get_detector_data(
1286
+ detector_ids, placeholder_data=placeholder_data)
1287
+ data[offset] = ddata
1288
+ start_dim = offset * num_dim
1289
+ end_dim = start_dim + num_dim
1290
+ for i, dim in enumerate(self.config.independent_dimensions):
1291
+ independent_dimensions[i][start_dim:end_dim] = \
1292
+ dim.get_value(
1293
+ scan, scan_number, scan_step_index=-1,
1294
+ relative=False)
1295
+ for i, dim in enumerate(self.config.all_scalar_data):
1296
+ all_scalar_data[i][start_dim:end_dim] = dim.get_value(
1297
+ scan, scan_number, scan_step_index=-1,
1298
+ relative=False)
1299
+ if placeholder_data is not False:
1300
+ all_scalar_data[-1][start_dim:end_dim] = \
1301
+ placeholder_used
1302
+ offset += 1
1303
+
1304
+ return (
1305
+ np.swapaxes(
1306
+ data.reshape((np.prod(data.shape[:2]), *data.shape[2:])),
1307
+ 0, 1),
1308
+ independent_dimensions, all_scalar_data)
1309
+
1310
+ # @profile
1311
+ def _read_raw_data(self, comm, num_scan, offset):
1312
+ """Read the raw data for a given map configuration.
1313
+
1314
+ :param comm: MPI communicator.
1315
+ :type comm: mpi4py.MPI.Comm, optional
1316
+ :param num_scan: Number of scans in the map.
1317
+ :type num_scan: int
1318
+ :param offset: Offset scan number of current processor.
1319
+ :type offset: int
1320
+ :return: The map's raw data, independent dimensions and scalar
1321
+ data.
1322
+ :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
1323
+ """
1324
+ # Third party modules
1325
+ try:
1326
+ from mpi4py import MPI
1327
+ from mpi4py.util import dtlib
1328
+ except Exception:
1329
+ pass
1330
+
1331
+ # Local modules
1332
+ from CHAP.utils.general import list_to_string
1333
+
1334
+ if comm is None:
1335
+ self.num_proc = 1
1336
+ rank = 0
1337
+ else:
1338
+ self.num_proc = comm.Get_size()
1339
+ rank = comm.Get_rank()
1340
+ if not rank:
1341
+ self.logger.debug(f'Number of processors: {self.num_proc}')
1342
+ self.logger.debug(f'Number of scans: {num_scan}')
1343
+
1344
+ # Create the shared data buffers
1345
+ assert len(self.config.spec_scans) == 1
1346
+ scans = self.config.spec_scans[0]
1347
+ scan_numbers = scans.scan_numbers
1348
+ scanparser = scans.get_scanparser(scan_numbers[0])
1349
+ #RV only correct for multiple detectors if the image sizes are the same
1350
+ if len(self.detector_config.detectors) != 1:
1351
+ raise ValueError('Multiple detectors not tested yet')
1352
+ if self.config.experiment_type == 'TOMO':
1353
+ dtype = np.float32
1354
+ ddata = scanparser.get_detector_data(
1355
+ self.detector_config.detectors[0].get_id(), dtype=dtype)
1356
+ else:
1357
+ dtype = None
1358
+ ddata = scanparser.get_detector_data(
1359
+ self.detector_config.detectors[0].get_id())
1360
+ num_det = len(self.detector_config.detectors)
1361
+ num_dim = ddata.shape[0]
1362
+ num_id = len(self.config.independent_dimensions)
1363
+ num_sd = len(self.config.all_scalar_data)
1364
+ if self.num_proc == 1:
1365
+ assert num_scan == len(scan_numbers)
1366
+ data = num_det*[num_scan*[None]]
1367
+ independent_dimensions = np.empty(
1368
+ (num_scan, num_id, num_dim), dtype=np.float64)
1369
+ if num_sd:
1370
+ all_scalar_data = np.empty(
1371
+ (num_scan, num_sd, num_dim), dtype=np.float64)
1372
+ else:
1373
+ self.logger.debug(f'Scan offset on processor {rank}: {offset}')
1374
+ self.logger.debug(f'Scan numbers on processor {rank}: '
1375
+ f'{list_to_string(scan_numbers)}')
1376
+ datatype = dtlib.from_numpy_dtype(dtype)
1377
+ itemsize = datatype.Get_size()
1378
+ if not rank:
1379
+ nbytes = num_scan * np.prod(ddata.shape) * itemsize
1380
+ else:
1381
+ nbytes = 0
1382
+ win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1383
+ buf, _ = win.Shared_query(0)
1384
+ #RV improve memory requirements ala single processor case?
1385
+ data = np.ndarray(
1386
+ buffer=buf, dtype=dtype,
1387
+ shape=(num_det, num_scan, *ddata.shape))
1388
+ datatype = dtlib.from_numpy_dtype(np.float64)
1389
+ itemsize = datatype.Get_size()
1390
+ if not rank:
1391
+ nbytes = num_scan * num_id * num_dim * itemsize
1392
+ else:
1393
+ nbytes = 0
1394
+ win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1395
+ buf_id, _ = win_id.Shared_query(0)
1396
+ independent_dimensions = np.ndarray(
1397
+ buffer=buf_id, dtype=np.float64,
1398
+ shape=(num_scan, num_id, num_dim))
1399
+ if num_sd:
1400
+ if not rank:
1401
+ nbytes = num_scan * num_sd * num_dim * itemsize
1402
+ win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1403
+ buf_sd, _ = win_sd.Shared_query(0)
1404
+ all_scalar_data = np.ndarray(
1405
+ buffer=buf_sd, dtype=np.float64,
1406
+ shape=(num_scan, num_sd, num_dim))
1407
+ else:
1408
+ all_scalar_data = None
1409
+
1410
+ # Read the raw data
1411
+ init = True
1412
+ for scans in self.config.spec_scans:
1413
+ for scan_number in scans.scan_numbers:
1414
+ for i in range(len((self.detector_config.detectors))):
1415
+ if init:
1416
+ init = False
1417
+ data[i][offset] = ddata
1418
+ del ddata
1419
+ else:
1420
+ scanparser = scans.get_scanparser(scan_number)
1421
+ data[i][offset] = scanparser.get_detector_data(
1422
+ self.detector_config.detectors[i].get_id(),
1423
+ dtype=dtype)
1424
+ for i, dim in enumerate(self.config.independent_dimensions):
1425
+ if dim.data_type in ['scan_column',
1426
+ 'detector_log_timestamps']:
1427
+ independent_dimensions[offset,i] = dim.get_value(
1428
+ scans, scan_number, scan_step_index=-1,
1429
+ relative=False)[:num_dim]
1430
+ elif dim.data_type in ['smb_par', 'spec_motor',
1431
+ 'expression']:
1432
+ independent_dimensions[offset,i] = dim.get_value(
1433
+ scans, scan_number, scan_step_index=-1,
1434
+ relative=False,
1435
+ scalar_data=self.config.scalar_data)
1436
+ else:
1437
+ raise RuntimeError(
1438
+ f'{dim.data_type} in data_type not tested')
1439
+ for i, dim in enumerate(self.config.all_scalar_data):
1440
+ all_scalar_data[offset,i] = dim.get_value(
1441
+ scans, scan_number, scan_step_index=-1,
1442
+ relative=False)
1443
+ offset += 1
1444
+ if self.num_proc == 1:
1445
+ data = np.asarray(data)
1446
+ if num_sd:
1447
+ return (
1448
+ data.reshape(
1449
+ (data.shape[0], np.prod(data.shape[1:3]),
1450
+ *data.shape[3:])),
1451
+ np.stack(tuple([independent_dimensions[:,i].flatten()
1452
+ for i in range(num_id)])),
1453
+ np.stack(tuple([all_scalar_data[:,i].flatten()
1454
+ for i in range(num_sd)])))
1455
+ return (
1456
+ data.reshape(
1457
+ (data.shape[0], np.prod(data.shape[1:3]), *data.shape[3:])),
1458
+ np.stack(tuple([independent_dimensions[:,i].flatten()
1459
+ for i in range(num_id)])),
1460
+ None)
1461
+
1462
+
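
The multi-processor branches of _read_raw_data_edd and _read_raw_data above build shared numpy arrays from an MPI shared-memory window. The standalone mpi4py sketch below shows that same Allocate_shared / Shared_query / ndarray-over-buffer pattern on a toy array (run under mpiexec on a single node; the array shape is arbitrary).

    import numpy as np
    from mpi4py import MPI
    from mpi4py.util import dtlib

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    shape = (comm.Get_size(), 4)
    itemsize = dtlib.from_numpy_dtype(np.float64).Get_size()
    # Only the root allocates the window; the other ranks attach to it
    nbytes = int(np.prod(shape)) * itemsize if not rank else 0
    win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
    buf, _ = win.Shared_query(0)
    data = np.ndarray(buffer=buf, dtype=np.float64, shape=shape)
    data[rank] = rank          # each rank fills its own row, no copies needed
    comm.barrier()
    if not rank:
        print(data)
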
1463
+ class MPICollectProcessor(Processor):
1464
+ """A Processor that collects the distributed worker data from
1465
+ MPIMapProcessor on the root node.
1466
+ """
1467
+ def process(self, data, comm, root_as_worker=True):
1468
+ """Collect data on root node.
1469
+
1470
+ :param data: Input data.
1471
+ :type data: list[PipelineData]
1472
+ :param comm: MPI communicator.
1473
+ :type comm: mpi4py.MPI.Comm, optional
1474
+ :param root_as_worker: Use the root node as a worker,
1475
+ defaults to `True`.
1476
+ :type root_as_worker: bool, optional
1477
+ :return: Returns a list of the distributed worker data on the
1478
+ root node.
1479
+ """
1480
+ num_proc = comm.Get_size()
1481
+ rank = comm.Get_rank()
1482
+ if root_as_worker:
1483
+ data = self.unwrap_pipelinedata(data)[-1]
1484
+ if num_proc > 1:
1485
+ data = comm.gather(data, root=0)
1486
+ else:
1487
+ for n_worker in range(1, num_proc):
1488
+ if rank == n_worker:
1489
+ comm.send(self.unwrap_pipelinedata(data)[-1], dest=0)
1490
+ data = None
1491
+ elif not rank:
1492
+ if n_worker == 1:
1493
+ data = [comm.recv(source=n_worker)]
1494
+ else:
1495
+ data.append(comm.recv(source=n_worker))
1496
+ #FIX RV TODO Merge the list of data items in some generic fashion
1497
+ return data
1498
+
1499
+
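
MPICollectProcessor's root_as_worker=True path relies on comm.gather; a minimal standalone mpi4py example of that collective is sketched below (toy payloads, run under mpiexec).

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    payload = {'rank': rank, 'value': rank**2}   # stand-in for a worker's result
    collected = comm.gather(payload, root=0)     # list of payloads on root, None elsewhere
    if rank == 0:
        print(collected)
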
1500
+ class MPIMapProcessor(Processor):
1501
+ """A Processor that applies a parallel generic sub-pipeline to
1502
+ a map configuration.
1503
+ """
1504
+ def process(self, data, config=None, sub_pipeline=None):
1505
+ """Run a parallel generic sub-pipeline.
1506
+
1507
+ :param data: Input data.
1508
+ :type data: list[PipelineData]
1509
+ :param config: Initialization parameters for an instance of
1510
+ common.models.map.MapConfig.
1511
+ :type config: dict, optional
1512
+ :param sub_pipeline: The sub-pipeline.
1513
+ :type sub_pipeline: Pipeline, optional
1514
+ :return: The `data` field of the first item in the returned
1515
+ list of sub-pipeline items.
1516
+ """
1517
+ # Third party modules
1518
+ from mpi4py import MPI
1519
+
1520
+ # Local modules
1521
+ from CHAP.models import RunConfig
1522
+ from CHAP.runner import run
1523
+ from CHAP.common.models.map import SpecScans
1524
+
1525
+ raise NotImplementedError('MPIMapProcessor needs updating and testing')
1526
+ # pylint: disable=c-extension-no-member
1527
+ comm = MPI.COMM_WORLD
1528
+ num_proc = comm.Get_size()
1529
+ rank = comm.Get_rank()
1530
+
1531
+ # Get the validated map configuration
1532
+ map_config = self.get_config(
1533
+ data=data, config=config, schema='common.models.map.MapConfig')
1534
+
1535
+ # Create the spec reader configuration for each processor
1536
+ # FIX: catered to EDD with one spec scan
1537
+ assert len(map_config.spec_scans) == 1
1538
+ spec_scans = map_config.spec_scans[0]
1539
+ scan_numbers = spec_scans.scan_numbers
1540
+ num_scan = len(scan_numbers)
1541
+ scans_per_proc = num_scan//num_proc
1542
+ n_scan = 0
1543
+ for n_proc in range(num_proc):
1544
+ num = scans_per_proc
1545
+ if n_proc == rank:
1546
+ if rank < num_scan - scans_per_proc*num_proc:
1547
+ num += 1
1548
+ scan_numbers = scan_numbers[n_scan:n_scan+num]
1549
+ n_scan += num
1550
+ spec_config = {
1551
+ 'station': map_config.station,
1552
+ 'experiment_type': map_config.experiment_type,
1553
+ 'spec_scans': [SpecScans(
1554
+ spec_file=spec_scans.spec_file, scan_numbers=scan_numbers)]}
1555
+
1556
+ # Get the run configuration to use for the sub-pipeline
1557
+ if sub_pipeline is None:
1558
+ sub_pipeline = {}
1559
+ run_config = {'inputdir': self.inputdir, 'outputdir': self.outputdir,
1560
+ 'interactive': self.interactive, 'log_level': self.log_level}
1561
+ run_config.update(sub_pipeline.get('config', {}))
1562
+ run_config = RunConfig(**run_config, comm=comm)
1563
+ pipeline_config = []
1564
+ for item in sub_pipeline['pipeline']:
1565
+ if isinstance(item, dict):
1566
+ for k, v in deepcopy(item).items():
1567
+ if k.endswith('Reader'):
1568
+ v['config'] = spec_config
1569
+ item[k] = v
1570
+ if num_proc > 1 and k.endswith('Writer'):
1571
+ r, e = os.path.splitext(v['filename'])
1572
+ v['filename'] = f'{r}_{rank}{e}'
1573
+ item[k] = v
1574
+ pipeline_config.append(item)
1575
+
1576
+ # Run the sub-pipeline on each processor
1577
+ return run(run_config, pipeline_config, logger=self.logger, comm=comm)
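+ # Writer filename sketch for the per-rank output above (hypothetical
+ # name): with num_proc > 1, a writer configured with filename 'map.nxs'
+ # writes 'map_0.nxs' on rank 0, 'map_1.nxs' on rank 1, and so on.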
1578
+
1579
+
1580
+ class MPISpawnMapProcessor(Processor):
1581
+ """A Processor that applies a parallel generic sub-pipeline to
1582
+ a map configuration by spawning worker processes.
1583
+ """
1584
+ def process(
1585
+ self, data, num_proc=1, root_as_worker=True, collect_on_root=False,
1586
+ sub_pipeline=None):
1587
+ """Spawn workers running a parallel generic sub-pipeline.
1588
+
1589
+ :param data: Input data.
1590
+ :type data: list[PipelineData]
1591
+ :param num_proc: Number of spawned processors, defaults to `1`.
1592
+ :type num_proc: int, optional
1593
+ :param root_as_worker: Use the root node as a worker,
1594
+ defaults to `True`.
1595
+ :type root_as_worker: bool, optional
1596
+ :param collect_on_root: Collect the result of the spawned
1597
+ workers on the root node, defaults to `False`.
1598
+ :type collect_on_root: bool, optional
1599
+ :param sub_pipeline: The sub-pipeline.
1600
+ :type sub_pipeline: Pipeline, optional
1601
+ :return: The `data` field of the first item in the returned
1602
+ list of sub-pipeline items.
1603
+ """
1604
+ # Third party modules
1605
+ from mpi4py import MPI
1606
+ import yaml
1607
+
1608
+ # Local modules
1609
+ from CHAP.models import RunConfig
1610
+ from CHAP.runner import runner
1611
+ from CHAP.common.models.map import SpecScans
1612
+
1613
+ raise NotImplementedError('MPISpawnMapProcessor needs updating and testing')
1614
+ # Get the map configuration from data
1615
+ map_config = self.get_config(
1616
+ data=data, schema='common.models.map.MapConfig')
1617
+
1618
+ # Get the run configuration to use for the sub-pipeline
1619
+ # Optionally include the root node as a worker node
1620
+ if sub_pipeline is None:
1621
+ sub_pipeline = {}
1622
+ run_config = {'inputdir': self.inputdir, 'outputdir': self.outputdir,
1623
+ 'interactive': self.interactive, 'log_level': self.log_level}
1624
+ run_config.update(sub_pipeline.get('config', {}))
1625
+ if root_as_worker:
1626
+ first_proc = 1
1627
+ spawn = 1
1628
+ else:
1629
+ first_proc = 0
1630
+ spawn = -1
1631
+ run_config = RunConfig(**run_config, logger=self.logger, spawn=spawn)
1632
+
1633
+ # Create the sub-pipeline configuration for each processor
1634
+ spec_scans = map_config.spec_scans[0]
1635
+ scan_numbers = spec_scans.scan_numbers
1636
+ num_scan = len(scan_numbers)
1637
+ scans_per_proc = num_scan//num_proc
1638
+ n_scan = 0
1639
+ pipeline_config = []
1640
+ for n_proc in range(num_proc):
1641
+ num = scans_per_proc
1642
+ if n_proc < num_scan - scans_per_proc*num_proc:
1643
+ num += 1
1644
+ spec_config = {
1645
+ 'station': map_config.station,
1646
+ 'experiment_type': map_config.experiment_type,
1647
+ 'spec_scans': [SpecScans(
1648
+ spec_file=spec_scans.spec_file,
1649
+ scan_numbers=scan_numbers[n_scan:n_scan+num]).__dict__]}
1650
+ sub_pipeline_config = []
1651
+ for item in deepcopy(sub_pipeline['pipeline']):
1652
+ if isinstance(item, dict):
1653
+ for k, v in deepcopy(item).items():
1654
+ if k.endswith('Reader'):
1655
+ v['config'] = spec_config
1656
+ item[k] = v
1657
+ if num_proc > 1 and k.endswith('Writer'):
1658
+ r, e = os.path.splitext(v['filename'])
1659
+ v['filename'] = f'{r}_{n_proc}{e}'
1660
+ item[k] = v
1661
+ sub_pipeline_config.append(item)
1662
+ if collect_on_root and (not root_as_worker or num_proc > 1):
1663
+ sub_pipeline_config += [
1664
+ {'common.MPICollectProcessor': {
1665
+ 'root_as_worker': root_as_worker}}]
1666
+ pipeline_config.append(sub_pipeline_config)
1667
+ n_scan += num
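+ # Worked example of the scan split above (hypothetical numbers): with
+ # num_scan = 10 and num_proc = 3, scans_per_proc = 3 and the remainder
+ # 10 - 3*3 = 1 goes to the lowest-numbered workers, so worker 0 gets
+ # scan_numbers[0:4] and workers 1 and 2 get scan_numbers[4:7] and
+ # scan_numbers[7:10], respectively.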
1668
+
1669
+ # Spawn the workers to run the sub-pipeline
1670
+ if num_proc > first_proc:
1671
+ # System modules
1672
+ from tempfile import NamedTemporaryFile
1673
+
1674
+ tmp_names = []
1675
+ with NamedTemporaryFile(delete=False) as fp:
1676
+ # pylint: disable=c-extension-no-member
1677
+ fp_name = fp.name
1678
+ tmp_names.append(fp_name)
1679
+ with open(fp_name, 'w') as f:
1680
+ yaml.dump(
1681
+ {'config': {'spawn': run_config.spawn}}, f,
1682
+ sort_keys=False)
1683
+ for n_proc in range(first_proc, num_proc):
1684
+ f_name = f'{fp_name}_{n_proc}'
1685
+ tmp_names.append(f_name)
1686
+ with open(f_name, 'w') as f:
1687
+ yaml.dump(
1688
+ #FIX once comm is a field of RunConfig
1689
+ #{'config': run_config.model_dump(exclude='comm'),
1690
+ {'config': run_config.model_dump(),
1691
+ 'pipeline': pipeline_config[n_proc]},
1692
+ f, sort_keys=False)
1693
+ # pylint: disable=used-before-assignment
1694
+ sub_comm = MPI.COMM_SELF.Spawn(
1695
+ 'CHAP', args=[fp_name], maxprocs=num_proc-first_proc)
1696
+ common_comm = sub_comm.Merge(False)
1697
+ if run_config.spawn > 0:
1698
+ # Align with the barrier in RunConfig() on common_comm
1699
+ # called from the spawned main()
1700
+ common_comm.barrier()
1701
+ else:
1702
+ common_comm = None
1703
+
1704
+ # Run the sub-pipeline on the root node
1705
+ if root_as_worker:
1706
+ data = runner(run_config, pipeline_config[0], comm=common_comm)
1707
+ elif collect_on_root:
1708
+ run_config.spawn = 0
1709
+ pipeline_config = [{'common.MPICollectProcessor': {
1710
+ 'root_as_worker': root_as_worker}}]
1711
+ data = runner(run_config, pipeline_config, common_comm)
1712
+ else:
1713
+ # Align with the barrier in run() on common_comm
1714
+ # called from the spawned main()
1715
+ common_comm.barrier()
1716
+ data = None
1717
+
1718
+ # Disconnect spawned workers and cleanup temporary files
1719
+ if num_proc > first_proc:
1720
+ # Align with the barrier in main() on common_comm
1721
+ # when disconnecting the spawned worker
1722
+ common_comm.barrier()
1723
+ # Disconnect spawned workers and cleanup temporary files
1724
+ sub_comm.Disconnect()
1725
+ for tmp_name in tmp_names:
1726
+ os.remove(tmp_name)
1727
+
1728
+ return data
1729
+
1730
+
1731
+ class NexusToNumpyProcessor(Processor):
1732
+ """A Processor to convert the default plottable data in a NeXus
1733
+ object into a `numpy.ndarray`.
1734
+ """
1735
+ def process(self, data):
1736
+ """Return the default plottable data signal in a NeXus object
1737
+ contained in `data` as a `numpy.ndarray`.
1738
+
1739
+ :param data: Input data.
1740
+ :type data: nexusformat.nexus.NXobject
1741
+ :raises ValueError: If `data` has no default plottable data
1742
+ signal.
1743
+ :return: The default plottable data signal.
1744
+ :rtype: numpy.ndarray
1745
+ """
1746
+ # Third party modules
1747
+ from nexusformat.nexus import NXdata
1748
+
1749
+ data = self.unwrap_pipelinedata(data)[-1]
1750
+
1751
+ if isinstance(data, NXdata):
1752
+ default_data = data
1753
+ else:
1754
+ default_data = data.plottable_data
1755
+ if default_data is None:
1756
+ default_data_path = data.attrs.get('default')
1757
+ default_data = data.get(default_data_path)
1758
+ if default_data is None:
1759
+ raise ValueError(
1760
+ f'The structure of {data} contains no default data')
1761
+
1762
+ try:
1763
+ default_signal = default_data.attrs['signal']
1764
+ except Exception as exc:
1765
+ raise ValueError(
1766
+ f'The signal of {default_data} is unknown') from exc
1767
+
1768
+ np_data = default_data[default_signal].nxdata
1769
+
1770
+ return np_data
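+ # Minimal sketch of the signal extraction above (hypothetical data):
+ #     import numpy as np
+ #     from nexusformat.nexus import NXdata, NXfield
+ #     nxdata = NXdata(NXfield(np.arange(10), name='counts'),
+ #                     [NXfield(np.linspace(0.0, 1.0, 10), name='x')])
+ #     signal = nxdata[nxdata.attrs['signal']].nxdata  # numpy.ndarray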
1771
+
1772
+
1773
+ class NexusToXarrayProcessor(Processor):
1774
+ """A Processor to convert the default plottable data in a
1775
+ NeXus object into an `xarray.DataArray`.
1776
+ """
1777
+ def process(self, data):
1778
+ """Return the default plottable data signal in a NeXus object
1779
+ contained in `data` as an `xarray.DataArray`.
1780
+
1781
+ :param data: Input data.
1782
+ :type data: nexusformat.nexus.NXobject
1783
+ :raises ValueError: If metadata for `xarray` is absent from
1784
+ `data`
1785
+ :return: The default plottable data signal.
1786
+ :rtype: xarray.DataArray
1787
+ """
1788
+ # Third party modules
1789
+ from nexusformat.nexus import NXdata
1790
+ from xarray import DataArray
1791
+
1792
+ data = self.unwrap_pipelinedata(data)[-1]
1793
+
1794
+ if isinstance(data, NXdata):
1795
+ default_data = data
1796
+ else:
1797
+ default_data = data.plottable_data
1798
+ if default_data is None:
1799
+ default_data_path = data.attrs.get('default')
1800
+ default_data = data.get(default_data_path)
1801
+ if default_data is None:
1802
+ raise ValueError(
1803
+ f'The structure of {data} contains no default data')
1804
+
1805
+ try:
1806
+ default_signal = default_data.attrs['signal']
1807
+ except Exception as exc:
1808
+ raise ValueError(
1809
+ f'The signal of {default_data} is unknown') from exc
1810
+ signal_data = default_data[default_signal].nxdata
1811
+
1812
+ axes = default_data.attrs['axes']
1813
+ if isinstance(axes, str):
1814
+ axes = [axes]
1815
+ coords = {}
1816
+ for axis_name in axes:
1817
+ axis = default_data[axis_name]
1818
+ coords[axis_name] = (axis_name, axis.nxdata, axis.attrs)
1819
+
1820
+ dims = tuple(axes)
1821
+ name = default_signal
1822
+ attrs = default_data[default_signal].attrs
1823
+
1824
+ return DataArray(data=signal_data,
1825
+ coords=coords,
1826
+ dims=dims,
1827
+ name=name,
1828
+ attrs=attrs)
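+ # Sketch of the equivalent xarray construction (hypothetical values):
+ #     import numpy as np
+ #     from xarray import DataArray
+ #     da = DataArray(data=np.zeros(3),
+ #                    coords={'x': ('x', np.array([0.0, 0.5, 1.0]),
+ #                                  {'units': 'mm'})},
+ #                    dims=('x',), name='counts')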
1829
+
1830
+
1831
+ class NormalizeNexusProcessor(Processor):
1832
+ """Processor for scaling one or more NXfields in the input nexus
1833
+ structure by the values of another NXfield in the same
1834
+ structure."""
1835
+ def process(self, data, normalize_nxfields, normalize_by_nxfield):
1836
+ """Return copy of the original input nexus structure with
1837
+ additional fields containing the normalized data of each field
1838
+ in `normalize_nxfields`.
1839
+
1840
+ :param data: Input nexus structure containing all fields to be
1841
+ normalized and the field by which to normalize them.
1842
+ :type data: nexusformat.nexus.NXgroup
1843
+ :param normalize_nxfields:
1844
+ :type normalize_nxfields: list[str]
1845
+ :param normalize_by_nxfield: Path in `data` to the `NXfield`
1846
+ containing normalization data
1847
+ :type normalize_by_nxfield: str
1848
+ :returns: Copy of input data with additional normalized fields
1849
+ :rtype: nexusformat.nexus.NXgroup
1850
+ """
1851
+ # Third party modules
1852
+ from nexusformat.nexus import (
1853
+ NXgroup,
1854
+ NXfield,
1855
+ )
1856
+
1857
+ # Local modules
1858
+ from CHAP.utils.general import nxcopy
1859
+
1860
+ # Check input data
1861
+ data = self.unwrap_pipelinedata(data)[0]
1862
+ data = nxcopy(data)
1863
+ if not isinstance(data, NXgroup):
1864
+ raise TypeError(f'Expected NXgroup, got {type(data)}')
1865
+
1866
+ # Check normalize_by_nxfield
1867
+ if normalize_by_nxfield not in data:
1868
+ raise ValueError(
1869
+ f'{normalize_by_nxfield} not present in input data')
1870
+ if not isinstance(data[normalize_by_nxfield], NXfield):
1871
+ raise TypeError(
1872
+ f'{normalize_by_nxfield} is {type(data[normalize_by_nxfield])}'
1873
+ + ', expected NXfield')
1874
+ normalization_data = data[normalize_by_nxfield].nxdata
1875
+
1876
+ # Process normalize_nxfields
1877
+ for nxfield in normalize_nxfields:
1878
+ if nxfield not in data:
1879
+ self.logger.error(f'{nxfield} not present in input data')
1880
+ elif not isinstance(data[nxfield], NXfield):
1881
+ self.logger.error(
1882
+ f'{nxfield} is {type(data[nxfield])}, expected NXfield')
1883
+ else:
1884
+ field_shape = data[nxfield].nxdata.shape
1885
+ if not normalization_data.shape == \
1886
+ field_shape[:normalization_data.ndim]:
1887
+ self.logger.error(
1888
+ f'Incompatible dataset shapes: {normalize_by_nxfield} '
1889
+ + f'is {normalization_data.shape}, '
1890
+ + f'{nxfield} is {field_shape}'
1891
+ )
1892
+ else:
1893
+ self.logger.info(f'Normalizing {nxfield}')
1894
+ # make shapes compatible
1895
+ _normalization_data = normalization_data.reshape(
1896
+ normalization_data.shape + (1,)
1897
+ * (data[nxfield].nxdata.ndim
1898
+ - normalization_data.ndim))
1899
+ data[f'{nxfield}_normalized'] = NXfield(
1900
+ value=data[nxfield].nxdata / _normalization_data,
1901
+ attrs={**data[nxfield].attrs,
1902
+ 'normalized_by': normalize_by_nxfield}
1903
+ )
1904
+ return data
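+ # Broadcasting sketch for the reshape above (hypothetical shapes): a
+ # detector field of shape (5, 4, 100) normalized by a monitor field of
+ # shape (5, 4) is divided by the monitor reshaped to (5, 4, 1):
+ #     detector = np.ones((5, 4, 100))
+ #     monitor = np.full((5, 4), 2.0)
+ #     normalized = detector / monitor.reshape(
+ #         monitor.shape + (1,) * (detector.ndim - monitor.ndim))
+ #     # normalized.shape == (5, 4, 100); every element equals 0.5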
1905
+
1906
+
1907
+ class NormalizeMapProcessor(Processor):
1908
+ """Processor for calling `NormalizeNexusProcessor` for (usually
1909
+ all) detector data in an `NXroot` resulting from
1910
+ `MapProcessor`"""
1911
+ def process(self, data, normalize_by_nxfield, detector_ids=None):
1912
+ """Return copy of the original input map `NXroot` with
1913
+ additional fields containing normalized detector data.
1914
+
1915
+ :param data: Input nexus structure containing all fields to be
1916
+ normalized and the field by which to normalize them.
1917
+ :type data: nexusformat.nexus.NXroot
1918
+ :param normalize_by_nxfield: Path in `data` to the `NXfield`
1919
+ containing normalization data
1920
+ :type normalize_by_nxfield: str
1921
+ :returns: Copy of input data with additional normalized fields
1922
+ :rtype: nexusformat.nexus.NXroot
1923
+ """
1924
+ # Third party modules
1925
+ from nexusformat.nexus import (
1926
+ NXentry,
1927
+ NXlink,
1928
+ )
1929
+
1930
+ # Check input data
1931
+ data = self.unwrap_pipelinedata(data)[0]
1932
+ map_title = None
1933
+ for k, v in data.items():
1934
+ if isinstance(v, NXentry):
1935
+ map_title = k
1936
+ break
1937
+ if map_title is None:
1938
+ self.logger.error('Input data contains no NXentry')
1939
+ else:
1940
+ self.logger.info(f'Got map_title: {map_title}')
1941
+
1942
+ # Check detector_ids
1943
+ normalize_nxfields = []
1944
+ if detector_ids is None:
1945
+ detector_ids = [k for k in data[map_title].data.keys()
1946
+ if not isinstance(data[map_title].data[k], NXlink)]
1947
+ self.logger.info(f'Using detector_ids: {detector_ids}')
1948
+ normalize_nxfields = [f'{map_title}/data/{_id}'
1949
+ for _id in detector_ids]
1950
+
1951
+ # Normalize
1952
+ normalizer = NormalizeNexusProcessor()
1953
+ normalizer.logger = self.logger
1954
+ return normalizer.process(
1955
+ data, normalize_nxfields, normalize_by_nxfield)
1956
+
1957
+
1958
+ class PrintProcessor(Processor):
1959
+ """A Processor to simply print the input data to stdout and return
1960
+ the original input data, unchanged in any way.
1961
+ """
1962
+ def process(self, data):
1963
+ """Print and return the input data.
1964
+
1965
+ :param data: Input data.
1966
+ :type data: object
1967
+ :return: `data`
1968
+ :rtype: object
1969
+ """
1970
+ if callable(getattr(data, '_str_tree', None)):
1971
+ # If data is likely a NeXus NXobject, print its tree
1972
+ # representation (since NXobjects' str representations are
1973
+ # just their nxname)
1974
+ print(data._str_tree(attrs=True, recursive=True))
1975
+ else:
1976
+ # System modules
1977
+ from pprint import pprint
1978
+
1979
+ pprint(data)
1980
+
1981
+ return data
1982
+
1983
+
1984
+ class PyfaiAzimuthalIntegrationProcessor(Processor):
1985
+ """Processor to azimuthally integrate one or more frames of 2d
1986
+ detector data using the
1987
+ [pyFAI](https://pyfai.readthedocs.io/en/v2023.1/index.html)
1988
+ package.
1989
+ """
1990
+ def process(
1991
+ self, data, poni_file, npt, mask_file=None,
1992
+ integrate1d_kwargs=None):
1993
+ """Azimuthally integrate the detector data provided and return
1994
+ the result as a list of pyFAI integration results, one per frame,
1995
+ each containing the values of the radial coordinate, the
1996
+ intensities along the radial direction, and the Poisson errors for
1997
+ the intensity spectrum.
1998
+
1999
+ :param data: Detector data to integrate.
2000
+ :type data: Union[PipelineData, list[np.ndarray]]
2001
+ :param poni_file: Name of the [pyFAI PONI file]
2002
+ containing the detector properties pyFAI needs to perform
2003
+ azimuthal integration.
2004
+ :type poni_file: str
2005
+ :param npt: Number of points in the output pattern.
2006
+ :type npt: int
2007
+ :param mask_file: A file to use for masking the input data.
2008
+ :type mask_file: str, optional
2009
+ :param integrate1d_kwargs: Optional dictionary of keyword arguments to pass to integrate1d.
2010
+ :type integrate1d_kwargs: Optional[dict]
2011
+ :returns: List of azimuthal integration results, one per input
2012
+ frame.
2013
+ """
2014
+ # Third party modules
2015
+ from pyFAI import load
2016
+
2017
+ if not os.path.isabs(poni_file):
2018
+ poni_file = os.path.join(self.inputdir, poni_file)
2019
+ ai = load(poni_file)
2020
+
2021
+ if mask_file is None:
2022
+ mask = None
2023
+ else:
2024
+ # Third party modules
2025
+ import fabio
2026
+
2027
+ if not os.path.isabs(mask_file):
2028
+ mask_file = os.path.join(self.inputdir, mask_file)
2029
+ mask = fabio.open(mask_file).data
2030
+
2031
+ try:
2032
+ det_data = self.unwrap_pipelinedata(data)[0]
2033
+ except Exception:
2034
+ det_data = data
2035
+
2036
+ if integrate1d_kwargs is None:
2037
+ integrate1d_kwargs = {}
2038
+ integrate1d_kwargs['mask'] = mask
2039
+
2040
+ return [ai.integrate1d(d, npt, **integrate1d_kwargs) for d in det_data]
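+ # Standalone pyFAI sketch of the call above (hypothetical file name and
+ # point count):
+ #     from pyFAI import load
+ #     ai = load('detector.poni')
+ #     result = ai.integrate1d(frame, 1800, mask=mask)  # frame: 2D array
+ #     q, intensity = result.radial, result.intensity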
2041
+
2042
+
2043
+ class RawDetectorDataMapProcessor(Processor):
2044
+ """A Processor to return a map of raw detector data in a
2045
+ NeXus NXroot object.
2046
+ """
2047
+ def process(self, data, detector_name, detector_shape):
2048
+ """Process configurations for a map and return the raw
2049
+ detector data collected over the map.
2050
+
2051
+ :param data: Input map configuration.
2052
+ :type data: list[PipelineData]
2053
+ :param detector_name: The detector prefix.
2054
+ :type detector_name: str
2055
+ :param detector_shape: The shape of detector data for a single
2056
+ scan step.
2057
+ :type detector_shape: list
2058
+ :return: Map of raw detector data.
2059
+ :rtype: nexusformat.nexus.NXroot
2060
+ """
2061
+ map_config = self.get_config(data)
2062
+ nxroot = self.get_nxroot(map_config, detector_name, detector_shape)
2063
+
2064
+ return nxroot
2065
+
2066
+ def get_config(self, data):
2067
+ """Get instances of the map configuration object needed by this
2068
+ `Processor`.
2069
+
2070
+ :param data: Result of `Reader.read` where at least one item
2071
+ has the value `'common.models.map.MapConfig'` for the
2072
+ `'schema'` key.
2073
+ :type data: list[PipelineData]
2074
+ :raises Exception: If a valid map config object cannot be
2075
+ constructed from `data`.
2076
+ :return: A valid instance of the map configuration object with
2077
+ field values taken from `data`.
2078
+ :rtype: common.models.map.MapConfig
2079
+ """
2080
+ # Local modules
2081
+ from CHAP.common.models.map import MapConfig
2082
+
2083
+ map_config = False
2084
+ if isinstance(data, list):
2085
+ for item in data:
2086
+ if isinstance(item, dict):
2087
+ if item.get('schema') == 'common.models.map.MapConfig':
2088
+ map_config = item.get('data')
2089
+
2090
+ if not map_config:
2091
+ raise ValueError('No map configuration found in input data')
2092
+
2093
+ return MapConfig(**map_config)
2094
+
2095
+ def get_nxroot(self, map_config, detector_name, detector_shape):
2096
+ """Get a map of the detector data collected by the scans in
2097
+ `map_config`. The data will be returned along with some
2098
+ relevant metadata in the form of a NeXus structure.
2099
+
2100
+ :param map_config: The map configuration.
2101
+ :type map_config: common.models.map.MapConfig
2102
+ :param detector_name: The detector prefix.
2103
+ :type detector_name: str
2104
+ :param detector_shape: The shape of detector data for a single
2105
+ scan step.
2106
+ :type detector_shape: list
2107
+ :return: A map of the raw detector data.
2108
+ :rtype: nexusformat.nexus.NXroot
2109
+ """
2110
+ # Third party modules
2111
+ # pylint: disable=no-name-in-module
2112
+ from nexusformat.nexus import (
2113
+ NXdata,
2114
+ NXdetector,
2115
+ NXinstrument,
2116
+ NXroot,
2117
+ )
2118
+ # pylint: enable=no-name-in-module
2119
+
2120
+ raise RuntimeError('Not updated for the new MapProcessor')
2121
+ nxroot = NXroot()
2122
+
2123
+ nxroot[map_config.title] = MapProcessor.get_nxentry(map_config)
2124
+ nxentry = nxroot[map_config.title]
2125
+
2126
+ nxentry.instrument = NXinstrument()
2127
+ nxentry.instrument.detector = NXdetector()
2128
+
2129
+ nxentry.instrument.detector.data = NXdata()
2130
+ nxdata = nxentry.instrument.detector.data
2131
+ nxdata.raw = np.empty((*map_config.shape, *detector_shape))
2132
+ nxdata.raw.attrs['units'] = 'counts'
2133
+ for i, det_axis_size in enumerate(detector_shape):
2134
+ nxdata[f'detector_axis_{i}_index'] = np.arange(det_axis_size)
2135
+
2136
+ for map_index in np.ndindex(map_config.shape):
2137
+ scans, scan_number, scan_step_index = \
2138
+ map_config.get_scan_step_index(map_index)
2139
+ scanparser = scans.get_scanparser(scan_number)
2140
+ self.logger.debug(
2141
+ f'Adding data to nxroot for map point {map_index}')
2142
+ nxdata.raw[map_index] = scanparser.get_detector_data(
2143
+ detector_name,
2144
+ scan_step_index)
2145
+
2146
+ nxentry.data.makelink(
2147
+ nxdata.raw,
2148
+ name=detector_name)
2149
+ for i, det_axis_size in enumerate(detector_shape):
2150
+ nxentry.data.makelink(
2151
+ nxdata[f'detector_axis_{i}_index'],
2152
+ name=f'{detector_name}_axis_{i}_index'
2153
+ )
2154
+ if isinstance(nxentry.data.attrs['axes'], str):
2155
+ nxentry.data.attrs['axes'] = [
2156
+ nxentry.data.attrs['axes'],
2157
+ f'{detector_name}_axis_{i}_index']
2158
+ else:
2159
+ nxentry.data.attrs['axes'] += [
2160
+ f'{detector_name}_axis_{i}_index']
2161
+
2162
+ nxentry.data.attrs['signal'] = detector_name
2163
+
2164
+ return nxroot
2165
+
2166
+
2167
+ class SetupNXdataProcessor(Processor):
2168
+ """Processor to set up and return an "empty" NeXus representation
2169
+ of a structured dataset. This representation will be an instance
2170
+ of a NeXus NXdata object that has:
2171
+ A NeXus NXfield entry for every coordinate/signal specified.
2172
+ `nxaxes` that are the NeXus NXfield entries for the coordinates and contain the values provided for each coordinate.
2173
+ NeXus NXfield entries of appropriate shape, but containing all zeros, for every signal.
2174
+ Attributes that define the axes, plus any additional attributes specified by the user.
2175
+
2176
+ This `Processor` is most useful as a "setup" step for constructing
2177
+ a representation of / container for a complete dataset that will
2178
+ be filled out in pieces later by `UpdateNXdataProcessor`.
2179
+ """
2180
+ def process(
2181
+ self, data, nxname='data', coords=None, signals=None, attrs=None,
2182
+ data_points=None, extra_nxfields=None, duplicates='overwrite'):
2183
+ """Return a NeXus NXdata object that has the requisite axes
2184
+ and NeXus NXfield entries to represent a structured dataset
2185
+ with the properties provided. Properties may be provided either
2186
+ through the `data` argument (from an appropriate `PipelineItem`
2187
+ that immediately precedes this one in a `Pipeline`), or through
2188
+ the `coords`, `signals`, `attrs`, and/or `data_points`
2189
+ arguments. If any of the latter are used, their values will
2190
+ completely override any values for these parameters found from
2191
+ `data`.
2192
+
2193
+ :param data: Data from the previous item in a `Pipeline`.
2194
+ :type data: list[PipelineData]
2195
+ :param nxname: Name for the returned NeXus NXdata object,
2196
+ defaults to `'data'`.
2197
+ :type nxname: str, optional
2198
+ :param coords: List of dictionaries defining the coordinates
2199
+ of the dataset. Each dictionary must have the keys
2200
+ `'name'` and `'values'`, whose values are the name of the
2201
+ coordinate axis (a string) and all the unique values of
2202
+ that coordinate for the structured dataset (a list of
2203
+ numbers), respectively. A third item in the dictionary is
2204
+ optional, but highly recommended: `'attrs'` may provide a
2205
+ dictionary of attributes to attach to the coordinate axis
2206
+ that assist in interpreting the returned NeXus NXdata
2207
+ representation of the dataset. It is strongly recommended
2208
+ to provide the units of the values along an axis in the
2209
+ `attrs` dictionary.
2210
+ :type coords: list[dict[str, object]], optional
2211
+ :param signals: List of dictionaries defining the signals of
2212
+ the dataset. Each dictionary must have the keys `'name'`
2213
+ and `'shape'`, whose values are the name of the signal
2214
+ field (a string) and the shape of the signal's value at
2215
+ each point in the dataset (a list of zero or more
2216
+ integers), respectively. A third item in the dictionary is
2217
+ optional, but highly recommended: `'attrs'` may provide a
2218
+ dictionary of attributes to attach to the signal field that
2219
+ assist in interpreting the returned NeXus NXdata
2220
+ representation of the dataset. It is strongly recommended
2221
+ to provide the units of the signal's values in the `attrs`
2222
+ dictionary.
2223
+ :type signals: list[dict[str, object]], optional
2224
+ :param attrs: An arbitrary dictionary of attributes to assign
2225
+ to the returned NeXus NXdata object.
2226
+ :type attrs: dict[str, object], optional
2227
+ :param data_points: A list of data points to partially (or
2228
+ even entirely) fill out the "empty" signal NeXus NXfields
2229
+ before returning the NeXus NXdata object.
2230
+ :type data_points: list[dict[str, object]], optional
2231
+ :param extra_nxfields: List of "extra" NeXus NXfields to include
2232
+ that can be described neither as a signal of the dataset,
2233
+ nor a dedicated coordinate. This parameter is useful for
2234
+ including "alternate" values for one of the coordinate
2235
+ dimensions -- the same coordinate axis expressed in
2236
+ different units, for instance. Each item in the list should
2237
+ be a dictionary of parameters for the
2238
+ `nexusformat.nexus.NXfield` constructor.
2239
+ :type extra_nxfields: list[dict[str, object]], optional
2240
+ :param duplicates: Behavior to use if any new data points occur
2241
+ at the same point in the dataset's coordinate space as an
2242
+ existing data point. Allowed values for `duplicates` are:
2243
+ `'overwrite'` and `'block'`. Defaults to `'overwrite'`.
2244
+ :type duplicates: Literal['overwrite', 'block']
2245
+ :returns: A NeXus NXdata object that represents the structured
2246
+ dataset as specified.
2247
+ :rtype: nexusformat.nexus.NXdata
2248
+ """
2249
+ self.nxname = nxname
2250
+
2251
+ if coords is None:
2252
+ coords = []
2253
+ if signals is None:
2254
+ signals = []
2255
+ if attrs is None:
2256
+ attrs = {}
2257
+ if extra_nxfields is None:
2258
+ extra_nxfields = []
2259
+ self.coords = coords
2260
+ self.signals = signals
2261
+ self.attrs = attrs
2262
+ self.data_points = data_points
2263
+ try:
2264
+ setup_params = self.unwrap_pipelinedata(data)[0]
2265
+ except Exception:
2266
+ setup_params = None
2267
+ if isinstance(setup_params, dict):
2268
+ for a in ('coords', 'signals', 'attrs', 'data_points'):
2269
+ setup_param = setup_params.get(a)
2270
+ if not getattr(self, a) and setup_param is not None:
2271
+ self.logger.info(f'Using input data from pipeline for {a}')
2272
+ setattr(self, a, setup_param)
2273
+ else:
2274
+ self.logger.info(
2275
+ f'Ignoring input data from pipeline for {a}')
2276
+ else:
2277
+ self.logger.warning('Ignoring all input data from pipeline')
2278
+ self.shape = tuple(len(c['values']) for c in self.coords)
2279
+ self.extra_nxfields = extra_nxfields
2280
+ self.duplicates = duplicates
2281
+ self.init_nxdata()
2282
+
2283
+ if self.data_points is not None:
2284
+ for d in self.data_points:
2285
+ self.add_data_point(d)
2286
+
2287
+ return self.nxdata
2288
+
2289
+ def add_data_point(self, data_point):
2290
+ """Add a data point to this dataset.
2291
+ 1. Validate `data_point`.
2292
+ 2. Append `data_point` to `self.data_points`.
2293
+ 3. Update signal `NXfield`s in `self.nxdata`.
2294
+
2295
+ :param data_point: Data point defining a point in the
2296
+ dataset's coordinate space and the new signal values at
2297
+ that point.
2298
+ :type data_point: dict[str, object]
2299
+ :returns: None
2300
+ """
2301
+ self.logger.info(
2302
+ f'Adding data point no. {data_point["dataset_point_index"]+1} of '
2303
+ f'{len(self.data_points)}')
2304
+ self.logger.debug(f'New data point: {data_point}')
2305
+ valid, msg = self.validate_data_point(data_point)
2306
+ if not valid:
2307
+ self.logger.error(f'Cannot add data point: {msg}')
2308
+ else:
2309
+ self.update_nxdata(data_point)
2310
+
2311
+ def validate_data_point(self, data_point):
2312
+ """Return `True` if `data_point` occurs at a valid point in
2313
+ this structured dataset's coordinate space, `False`
2314
+ otherwise. Also validate shapes of signal values and add zero
2315
+ values for any missing signals.
2316
+
2317
+ :param data_point: Data point defining a point in the
2318
+ dataset's coordinate space and the new signal values at
2319
+ that point.
2320
+ :type data_point: dict[str, object]
2321
+ :returns: Validity of `data_point`, message
2322
+ :rtype: bool, str
2323
+ """
2324
+ valid = True
2325
+ msg = ''
2326
+ # Convert all values to numpy types
2327
+ data_point = {k: np.asarray(v) for k, v in data_point.items()}
2328
+ # Ensure data_point defines a specific point in the dataset's
2329
+ # coordinate space
2330
+ if not all(c['name'] in data_point for c in self.coords):
2331
+ valid = False
2332
+ msg = 'Missing coordinate values'
2333
+ # Ensure a value is present for all signals
2334
+ for s in self.signals:
2335
+ name = s['name']
2336
+ if name not in data_point:
2337
+ data_point[name] = np.full(s['shape'], 0)
2338
+ else:
2339
+ if not data_point[name].shape == tuple(s['shape']):
2340
+ valid = False
2341
+ msg = f'Shape mismatch for signal {s}'
2342
+ return valid, msg
2343
+
2344
+ def init_nxdata(self):
2345
+ """Initialize an empty NeXus NXdata representing this dataset
2346
+ to `self.nxdata`; values for axes' `NXfield`s are filled out,
2347
+ values for signals' `NXfield`s are empty and can be filled out
2348
+ later. Save the empty NeXus NXdata object to the NeXus file.
2349
+ Initialise `self.nxfile` and `self.nxdata_path` with the
2350
+ `NXFile` object and actual nxpath used to save and make updates
2351
+ to the Nexus NXdata object.
2352
+ """
2353
+ # Third party modules
2354
+ from nexusformat.nexus import (
2355
+ NXdata,
2356
+ NXfield,
2357
+ )
2358
+
2359
+ axes = tuple(NXfield(
2360
+ value=c['values'],
2361
+ name=c['name'],
2362
+ attrs=c.get('attrs'),
2363
+ dtype=c.get('dtype', 'float64')) for c in self.coords)
2364
+ entries = {s['name']: NXfield(
2365
+ value=np.full((*self.shape, *s['shape']), 0),
2366
+ name=s['name'],
2367
+ attrs=s.get('attrs'),
2368
+ dtype=s.get('dtype', 'float64')) for s in self.signals}
2369
+ extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
2370
+ extra_nxfields = {f.nxname: f for f in extra_nxfields}
2371
+ entries.update(extra_nxfields)
2372
+ self.nxdata = NXdata(
2373
+ name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)
2374
+
2375
+ def update_nxdata(self, data_point):
2376
+ """Update `self.nxdata`'s NXfield values.
2377
+
2378
+ :param data_point: Data point defining a point in the
2379
+ dataset's coordinate space and the new signal values at
2380
+ that point.
2381
+ :type data_point: dict[str, object]
2382
+ :returns: None
2383
+ """
2384
+ index = self.get_index(data_point)
2385
+ for s in self.signals:
2386
+ name = s['name']
2387
+ if name in data_point:
2388
+ self.nxdata[name][index] = data_point[name]
2389
+
2390
+ def get_index(self, data_point):
2391
+ """Return a tuple representing the array index of `data_point`
2392
+ in the coordinate space of the dataset.
2393
+
2394
+ :param data_point: Data point defining a point in the
2395
+ dataset's coordinate space.
2396
+ :type data_point: dict[str, object]
2397
+ :returns: Multi-dimensional index of `data_point` in the
2398
+ dataset's coordinate space.
2399
+ :rtype: tuple
2400
+ """
2401
+ return tuple(c['values'].index(data_point[c['name']])
2402
+ for c in self.coords)
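+ # Worked example (hypothetical coordinates): with
+ #     coords = [{'name': 'x', 'values': [0.0, 0.5, 1.0]},
+ #               {'name': 'y', 'values': [10, 20]}]
+ # a data point {'x': 0.5, 'y': 20, ...} maps to index (1, 1).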
2403
+
2404
+
2405
+ class UnstructuredToStructuredProcessor(Processor):
2406
+ """Processor to reshape data in an NXdata from an "unstructured"
2407
+ to a "structured" representation.
2408
+ """
2409
+ def process(self, data, nxpath=None):
2410
+ # Third party modules
2411
+ from nexusformat.nexus import NXdata
2412
+
2413
+ try:
2414
+ nxobject = self.get_data(data)
2415
+ except Exception:
2416
+ nxobject = self.unwrap_pipelinedata(data)[0]
2417
+ if isinstance(nxobject, NXdata):
2418
+ return self.convert_nxdata(nxobject)
2419
+ if nxpath is not None:
2420
+ # Local modules
2421
+ # from CHAP.utils.general import nxcopy
2422
+ try:
2423
+ nxobject = nxobject[nxpath]
2424
+ except Exception as exc:
2425
+ raise ValueError(
2426
+ f'Invalid parameter nxpath ({nxpath})') from exc
2427
+ else:
2428
+ raise ValueError(f'Invalid input data ({data})')
2429
+ return self.convert_nxdata(nxobject)
2430
+
2431
+ def convert_nxdata(self, nxdata):
2432
+ # Third party modules
2433
+ from nexusformat.nexus import (
2434
+ NXdata,
2435
+ NXfield,
2436
+ )
2437
+
2438
+ # Local modules
2439
+ from CHAP.edd.processor import get_axes
2440
+
2441
+ # Extract axes from the NXdata attributes
2442
+ axes = get_axes(nxdata)
2443
+ for a in axes:
2444
+ if a not in nxdata:
2445
+ raise ValueError(f'Missing coordinates for {a}')
2446
+
2447
+ # Check the independent dimensions and axes
2448
+ unstructured_axes = []
2449
+ unstructured_dim = None
2450
+ for a in axes:
2451
+ if not isinstance(nxdata[a], NXfield):
2452
+ raise ValueError(
2453
+ f'Invalid axis field type ({type(nxdata[a])})')
2454
+ if len(nxdata[a].shape) == 1:
2455
+ if not unstructured_axes:
2456
+ unstructured_axes.append(a)
2457
+ unstructured_dim = nxdata[a].size
2458
+ else:
2459
+ if nxdata[a].size == unstructured_dim:
2460
+ unstructured_axes.append(a)
2461
+ elif 'unstructured_axes' in nxdata.attrs:
2462
+ raise ValueError('Inconsistent axes dimensions')
2463
+ elif 'unstructured_axes' in nxdata.attrs:
2464
+ raise ValueError(
2465
+ f'Invalid unstructured axis shape ({nxdata[a].shape})')
2466
+ if not axes and hasattr(nxdata, 'signal'):
2467
+ if len(nxdata[nxdata.signal].shape) < 2:
2468
+ raise ValueError(
2469
+ f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
2470
+ unstructured_dim = nxdata[nxdata.signal].shape[0]
2471
+ for k, v in nxdata.items():
2472
+ if (isinstance(v, NXfield) and len(v.shape) == 1
2473
+ and v.shape[0] == unstructured_dim):
2474
+ unstructured_axes.append(k)
2475
+ if unstructured_dim is None:
2476
+ raise ValueError('Unable to determine the unstructured axes')
2477
+ axes = unstructured_axes
2478
+
2479
+ # Identify unique coordinate points for each axis
2480
+ unique_coords = {}
2481
+ coords = {}
2482
+ axes_attrs = {}
2483
+ for a in axes:
2484
+ coords[a] = nxdata[a].nxdata
2485
+ unique_coords[a] = np.sort(np.unique(nxdata[a].nxdata))
2486
+ axes_attrs[a] = deepcopy(nxdata[a].attrs)
2487
+ if 'target' in axes_attrs[a]:
2488
+ del axes_attrs[a]['target']
2489
+
2490
+ # Calculate the total number of unique coordinate points
2491
+ unique_npts = np.prod([len(v) for k, v in unique_coords.items()])
2492
+ if unique_npts != unstructured_dim:
2493
+ self.logger.warning('The unstructured grid does not fully map to '
2494
+ 'a structured one (there are missing points)')
2495
+
2496
+ # Identify the signals and the data point axes
2497
+ signals = []
2498
+ data_point_axes = []
2499
+ data_point_shape = []
2500
+ if hasattr(nxdata, 'signal'):
2501
+ if (len(nxdata[nxdata.signal].shape) < 2
2502
+ or nxdata[nxdata.signal].shape[0] != unstructured_dim):
2503
+ raise ValueError(
2504
+ f'Invalid signal shape ({nxdata[nxdata.signal].shape})')
2505
+ signals = [nxdata.signal]
2506
+ data_point_shape = [nxdata[nxdata.signal].shape[1:]]
2507
+ for k, v in nxdata.items():
2508
+ if (isinstance(v, NXfield) and k not in axes and k not in signals
2509
+ and v.shape[0] == unstructured_dim):
2510
+ signals.append(k)
2511
+ if not data_point_shape:
2512
+ data_point_shape.append(v.shape[1:])
2513
+ if len(data_point_shape) == 1:
2514
+ data_point_shape = data_point_shape[0]
2515
+ else:
2516
+ data_point_shape = []
2517
+ for _ in data_point_shape:
2518
+ for k, v in nxdata.items():
2519
+ if (isinstance(v, NXfield) and k not in axes
2520
+ and v.shape == data_point_shape):
2521
+ data_point_axes.append(k)
2522
+
2523
+ # Create the structured NXdata object
2524
+ structured_shape = tuple(len(unique_coords[a]) for a in axes)
2525
+ attrs = deepcopy(nxdata.attrs)
2526
+ if 'unstructured_axes' in attrs:
2527
+ attrs.pop('unstructured_axes')
2528
+ attrs['axes'] = axes
2529
+ nxdata_structured = NXdata(
2530
+ name=f'{nxdata.nxname}_structured',
2531
+ **{a: NXfield(
2532
+ value=unique_coords[a],
2533
+ attrs=axes_attrs[a])
2534
+ for a in axes},
2535
+ **{s: NXfield(
2536
+ # value=np.reshape( # FIX not always a sound way to reshape.
2537
+ # nxdata[s], (*structured_shape, *nxdata[s].shape[1:])),
2538
+ dtype=nxdata[s].dtype,
2539
+ shape=(*structured_shape, *nxdata[s].shape[1:]),
2540
+ attrs=nxdata[s].attrs)
2541
+ for s in signals},
2542
+ attrs=attrs)
2543
+ if len(data_point_axes) == 1:
2544
+ axes = nxdata_structured.attrs['axes']
2545
+ if isinstance(axes, str):
2546
+ axes = [axes]
2547
+ nxdata_structured.attrs['axes'] = axes + data_point_axes
2548
+ for a in data_point_axes:
2549
+ nxdata_structured[a] = NXfield(
2550
+ value=nxdata[a], attrs=nxdata[a].attrs)
2551
+
2552
+ # Populate the structured NXdata object with values
2553
+ for i, coord in enumerate(zip(*tuple(nxdata[a].nxdata for a in axes))):
2554
+ structured_index = tuple(
2555
+ np.asarray(
2556
+ coord[ii] == unique_coords[axes[ii]]).nonzero()[0][0]
2557
+ for ii in range(len(axes)))
2558
+ for s in signals:
2559
+ nxdata_structured[s][structured_index] = nxdata[s][i]
2560
+
2561
+ return nxdata_structured
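+ # Minimal numpy sketch of the regridding above (hypothetical values):
+ #     x = np.array([0., 0., 1., 1.])
+ #     y = np.array([0., 1., 0., 1.])
+ #     signal = np.array([10., 11., 12., 13.])  # one value per (x, y)
+ #     ux, uy = np.unique(x), np.unique(y)
+ #     structured = np.zeros((ux.size, uy.size))
+ #     for i, (xi, yi) in enumerate(zip(x, y)):
+ #         structured[(ux == xi).nonzero()[0][0],
+ #                    (uy == yi).nonzero()[0][0]] = signal[i]
+ #     # structured == [[10., 11.], [12., 13.]]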
2562
+
2563
+
2564
+ class UpdateNXvalueProcessor(Processor):
2565
+ """Processor to fill in part(s) of a NeXus object representing a
2566
+ structured dataset that's already been written to a NeXus file.
2567
+
2568
+ This Processor is most useful as an "update" step for a NeXus
2569
+ NXdata object created by `common.SetupNXdataProcessor`, and is
2570
+ most easy to use in a `Pipeline` immediately after another
2571
+ `PipelineItem` designed specifically to return a value that can
2572
+ be used as input to this `Processor`.
2573
+ """
2574
+ def process(self, data, nxfilename, data_points=None):
2575
+ """Write new data values to an existing NeXus object
2576
+ representing an unstructured dataset in a NeXus file.
2577
+ Return the list of data points used to update the dataset.
2578
+
2579
+ :param data: Data from the previous item in a `Pipeline`. May
2580
+ contain a list of data points that will extend the list of
2581
+ data points optionally provided with the `data_points`
2582
+ argument.
2583
+ :type data: list[PipelineData]
2584
+ :param nxfilename: Name of the NeXus file containing the
2585
+ NeXus object to update.
2586
+ :type nxfilename: str
2587
+ :param data_points: List of data points, each one a dictionary
2588
+ whose keys are the names of the nxpath, the index of the
2589
+ data point in the dataset, and the data value.
2590
+ :type data_points: Optional[list[dict[str, object]]]
2591
+ :returns: Complete list of data points used to update the
2592
+ dataset.
2593
+ :rtype: list[dict[str, object]]
2594
+ """
2595
+ # Third party modules
2596
+ from nexusformat.nexus import NXFile
2597
+
2598
+ if data_points is None:
2599
+ data_points = []
2600
+ self.logger.debug(f'Got {len(data_points)} data points from keyword')
2601
+ ddata_points = self.unwrap_pipelinedata(data)[0]
2602
+ if isinstance(ddata_points, list):
2603
+ self.logger.debug(f'Got {len(ddata_points)} from pipeline data')
2604
+ data_points.extend(ddata_points)
2605
+ self.logger.info(f'Updating a total of {len(data_points)} data points')
2606
+
2607
+ if not os.path.isabs(nxfilename):
2608
+ nxfilename = os.path.join(self.inputdir, nxfilename)
2609
+ nxfile = NXFile(nxfilename, 'rw')
2610
+
2611
+ indices = []
2612
+ for data_point in data_points:
2613
+ try:
2614
+ nxfile.writevalue(
2615
+ data_point['nxpath'], np.asarray(data_point['value']),
2616
+ data_point['index'])
2617
+ indices.append(data_point['index'])
2618
+ except Exception as exc:
2619
+ self.logger.error(
2620
+ f'Error updating {data_point["nxpath"]} for data point '
2621
+ f'{data_point["index"]}: {exc}')
2622
+ else:
2623
+ self.logger.debug(f'Updated data point {data_point}')
2624
+
2625
+ nxfile.close()
2626
+
2627
+ return data_points
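+ # Data point sketch (hypothetical path and index): each entry looks like
+ #     {'nxpath': 'entry/data/counts', 'index': (2, 3),
+ #      'value': np.zeros(4096)}
+ # and is written with NXFile.writevalue at that index of the target
+ # field.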
2628
+
2629
+
2630
+ class UpdateNXdataProcessor(Processor):
2631
+ """Processor to fill in part(s) of a NeXus NXdata representing a
2632
+ structured dataset that's already been written to a NeXus file.
2633
+
2634
+ This Processor is most useful as an "update" step for a NeXus
2635
+ NXdata object created by `common.SetupNXdataProcessor`, and is
2636
+ most easy to use in a `Pipeline` immediately after another
2637
+ `PipelineItem` designed specifically to return a value that can
2638
+ be used as input to this `Processor`.
2639
+ """
2640
+ def process(
2641
+ self, data, nxfilename, nxdata_path, data_points=None,
2642
+ allow_approximate_coordinates=False):
2643
+ """Write new data points to the signal fields of an existing
2644
+ NeXus NXdata object representing a structured dataset in a NeXus
2645
+ file. Return the list of data points used to update the
2646
+ dataset.
2647
+
2648
+ :param data: Data from the previous item in a `Pipeline`. May
2649
+ contain a list of data points that will extend the list of
2650
+ data points optionally provided with the `data_points`
2651
+ argument.
2652
+ :type data: list[PipelineData]
2653
+ :param nxfilename: Name of the NeXus file containing the
2654
+ NeXus NXdata object to update.
2655
+ :type nxfilename: str
2656
+ :param nxdata_path: The path to the NeXus NXdata object to
2657
+ update in the file.
2658
+ :type nxdata_path: str
2659
+ :param data_points: List of data points, each one a dictionary
2660
+ whose keys are the names of the coordinates and axes, and
2661
+ whose values are the values of each coordinate / signal at
2662
+ a single point in the dataset. Defaults to None.
2663
+ :type data_points: Optional[list[dict[str, object]]]
2664
+ :param allow_approximate_coordinates: Parameter to allow the
2665
+ nearest existing match for the new data points'
2666
+ coordinates to be used if an exact match cannot be found
2667
+ (sometimes this is due simply to differences in rounding
2668
+ conventions). Defaults to False.
2669
+ :type allow_approximate_coordinates: bool, optional
2670
+ :returns: Complete list of data points used to update the dataset.
2671
+ :rtype: list[dict[str, object]]
2672
+ """
2673
+ # Third party modules
2674
+ from nexusformat.nexus import NXFile
2675
+
2676
+ if data_points is None:
2677
+ data_points = []
2678
+ self.logger.debug(f'Got {len(data_points)} data points from keyword')
2679
+ _data_points = self.unwrap_pipelinedata(data)[0]
2680
+ if isinstance(_data_points, list):
2681
+ self.logger.debug(f'Got {len(_data_points)} from pipeline data')
2682
+ data_points.extend(_data_points)
2683
+ self.logger.info(f'Updating {len(data_points)} data points total')
2684
+
2685
+ if not os.path.isabs(nxfilename):
2686
+ nxfilename = os.path.join(self.inputdir, nxfilename)
2687
+ nxfile = NXFile(nxfilename, 'rw')
2688
+ nxdata = nxfile.readfile()[nxdata_path]
2689
+ axes_names = [a.nxname for a in nxdata.nxaxes]
2690
+
2691
+ data_points_used = []
2692
+ for i, d in enumerate(data_points):
2693
+ # Verify that the data point contains a value for all
2694
+ # coordinates in the dataset.
2695
+ if not all(a in d for a in axes_names):
2696
+ self.logger.error(
2697
+ f'Data point {i} is missing a value for at least one '
2698
+ f'axis. Skipping. Axes are: {", ".join(axes_names)}')
2699
+ continue
2700
+ self.logger.debug(
2701
+ f'Coordinates for data point {i}: ' +
2702
+ ', '.join([f'{a}={d[a]}' for a in axes_names]))
2703
+ # Get the index of the data point in the dataset based on
2704
+ # its values for each coordinate.
2705
+ try:
2706
+ index = tuple(np.where(a.nxdata == d[a.nxname])[0][0]
2707
+ for a in nxdata.nxaxes)
2708
+ except Exception:
2709
+ if allow_approximate_coordinates:
2710
+ try:
2711
+ index = tuple(
2712
+ np.argmin(np.abs(a.nxdata - d[a.nxname]))
2713
+ for a in nxdata.nxaxes)
2714
+ self.logger.warning(
2715
+ f'Nearest match for coordinates of data point {i}: '
2716
+ + ', '.join(
2717
+ [f'{a.nxname}={a[_i]}'
2718
+ for _i, a in zip(index, nxdata.nxaxes)]))
2719
+ except Exception:
2720
+ self.logger.error(
2721
+ f'Cannot get the index of data point {i}. '
2722
+ 'Skipping.')
2723
+ continue
2724
+ else:
2725
+ self.logger.error(
2726
+ f'Cannot get the index of data point {i}. Skipping.')
2727
+ continue
2728
+ self.logger.debug(f'Index of data point {i}: {index}')
2729
+ # Update the signals contained in this data point at the
2730
+ proper index in the dataset's signal `NXfield`s
2731
+ for k, v in d.items():
2732
+ if k in axes_names:
2733
+ continue
2734
+ try:
2735
+ nxfile.writevalue(
2736
+ os.path.join(nxdata_path, k), np.asarray(v), index)
2737
+ # self.logger.debug(
2738
+ # f'Wrote to {os.path.join(nxdata_path, k)} in '
2739
+ # f'{nxfilename} at index {index} value: {np.asarray(v)}'
2740
+ # f' (type: {type(v)})')
2741
+ except Exception as exc:
2742
+ self.logger.error(
2743
+ f'Error updating signal {k} for new data point '
2744
+ f'{i} (dataset index {index}): {exc}')
2745
+ data_points_used.append(d)
2746
+
2747
+ nxfile.close()
2748
+
2749
+ return data_points_used
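+ # Data point sketch (hypothetical names): for an NXdata with axes 'x'
+ # and 'y' and a signal 'counts', each entry looks like
+ #     {'x': 0.5, 'y': 20, 'counts': np.zeros(4096)}
+ # The dataset index is found by matching 'x' and 'y' against the axis
+ # values, and 'counts' is written at that index.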
2750
+
2751
+
2752
+ class NXdataToDataPointsProcessor(Processor):
2753
+ """Transform a NeXus NXdata object into a list of dictionaries.
2754
+ Each dictionary represents a single data point in the coordinate
2755
+ space of the dataset. The keys are the names of the signals and
2756
+ axes in the dataset, and the values are a single scalar value (in
2757
+ the case of axes) or the value of the signal at that point in the
2758
+ coordinate space of the dataset (in the case of signals -- this
2759
+ means that values for signals may be any shape, depending on the
2760
+ shape of the signal itself).
2761
+ """
2762
+ def process(self, data):
2763
+ """Return a list of dictionaries representing the coordinate
2764
+ and signal values at every point in the dataset provided.
2765
+
2766
+ :param data: Input pipeline data containing a NeXus NXdata
2767
+ object.
2768
+ :type data: list[PipelineData]
2769
+ :returns: List of all data points in the dataset.
2770
+ :rtype: list[dict[str,object]]
2771
+ """
2772
+ nxdata = self.unwrap_pipelinedata(data)[0]
2773
+
2774
+ data_points = []
2775
+ axes_names = [a.nxname for a in nxdata.nxaxes]
2776
+ self.logger.info(f'Dataset axes: {axes_names}')
2777
+ dataset_shape = tuple([a.size for a in nxdata.nxaxes])
2778
+ self.logger.info(f'Dataset shape: {dataset_shape}')
2779
+ signal_names = [k for k, v in nxdata.entries.items()
2780
+ if not k in axes_names \
2781
+ and v.shape[:len(dataset_shape)] == dataset_shape]
2782
+ self.logger.info(f'Dataset signals: {signal_names}')
2783
+ other_fields = [k for k, v in nxdata.entries.items()
2784
+ if not k in axes_names + signal_names]
2785
+ if len(other_fields) > 0:
2786
+ self.logger.warning(
2787
+ 'Ignoring the following fields that cannot be interpreted as '
2788
+ f'either dataset coordinates or signals: {other_fields}')
2789
+ for i in np.ndindex(dataset_shape):
2790
+ data_points.append({
2791
+ **{a: nxdata[a][_i] for a, _i in zip(axes_names, i)},
2792
+ **{s: nxdata[s].nxdata[i] for s in signal_names},
2793
+ })
2794
+ return data_points
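+ # Output sketch (hypothetical names): for axes x = [0., 1.] and
+ # y = [0., 1.] and a signal 'counts' of shape (2, 2, 4096), the returned
+ # list holds four dictionaries such as
+ #     {'x': 0.0, 'y': 1.0, 'counts': <ndarray of shape (4096,)>}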
2795
+
2796
+
2797
+ class XarrayToNexusProcessor(Processor):
2798
+ """A Processor to convert the data in an `xarray` structure to a
2799
+ NeXus NXdata object.
2800
+ """
2801
+ def process(self, data):
2802
+ """Return `data` represented as a NeXus NXdata object.
2803
+
2804
+ :param data: The input `xarray` structure.
2805
+ :type data: Union[xarray.DataArray, xarray.Dataset]
2806
+ :return: The data and metadata in `data`.
2807
+ :rtype: nexusformat.nexus.NXdata
2808
+ """
2809
+ # Third party modules
2810
+ from nexusformat.nexus import (
2811
+ NXdata,
2812
+ NXfield,
2813
+ )
2814
+
2815
+ data = self.unwrap_pipelinedata(data)[-1]
2816
+ signal = NXfield(value=data.data, name=data.name, attrs=data.attrs)
2817
+ axes = []
2818
+ for name, coord in data.coords.items():
2819
+ axes.append(
2820
+ NXfield(value=coord.data, name=name, attrs=coord.attrs))
2821
+ axes = tuple(axes)
2822
+
2823
+ return NXdata(signal=signal, axes=axes)
2824
+
2825
+
2826
+ class XarrayToNumpyProcessor(Processor):
2827
+ """A Processor to convert the data in an `xarray.DataArray`
2828
+ structure to an `numpy.ndarray`.
2829
+ """
2830
+ def process(self, data):
2831
+ """Return just the signal values contained in `data`.
2832
+
2833
+ :param data: The input `xarray.DataArray`.
2834
+ :type data: xarray.DataArray
2835
+ :return: The data in `data`.
2836
+ :rtype: numpy.ndarray
2837
+ """
2838
+ return self.unwrap_pipelinedata(data)[-1].data
2839
+
2840
+
2841
+ #class SumProcessor(Processor):
2842
+ # """A Processor to sum the data in a NeXus NXobject, given a set of
2843
+ # nxpaths.
2844
+ # """
2845
+ # def process(self, data):
2846
+ # """Return the summed data array
2847
+ #
2848
+ # :param data:
2849
+ # :type data:
2850
+ # :return: The summed data.
2851
+ # :rtype: numpy.ndarray
2852
+ # """
2853
+ # nxentry, nxpaths = self.unwrap_pipelinedata(data)[-1]
2854
+ # if len(nxpaths) == 1:
2855
+ # return nxentry[nxpaths[0]]
2856
+ # sum_data = deepcopy(nxentry[nxpaths[0]])
2857
+ # for nxpath in nxpaths[1:]:
2858
+ # nxdata = nxentry[nxpath]
2859
+ # for entry in nxdata.entries:
2860
+ # sum_data[entry] += nxdata[entry]
2861
+ #
2862
+ # return sum_data
2863
+
2864
+
2865
+ if __name__ == '__main__':
2866
+ # Local modules
2867
+ from CHAP.processor import main
2868
+
2869
+ main()