ChessAnalysisPipeline 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

CHAP/common/processor.py CHANGED
@@ -19,64 +19,146 @@ class AnimationProcessor(Processor):
19
19
  """A Processor to show and return an animation.
20
20
  """
21
21
  def process(
22
- self, data, num_frames, axis=0, interval=1000, blit=True,
23
- repeat=True, repeat_delay=1000):
22
+ self, data, num_frames, vmin=None, vmax=None, axis=None,
23
+ interval=1000, blit=True, repeat=True, repeat_delay=1000,
24
+ interactive=False):
24
25
  """Show and return an animation of image slices from a dataset
25
26
  contained in `data`.
26
27
 
27
28
  :param data: Input data.
28
- :type data: CHAP.pipeline.PipelineData
29
+ :type data: list[PipelineData]
29
30
  :param num_frames: Number of frames for the animation.
30
31
  :type num_frames: int
31
- :param axis: Axis direction of the image slices,
32
+ :param vmin: Minimum array value in image slice, default to
33
+ `None`, which uses the actual minimum value in the slice.
34
+ :type vmin: float
35
+ :param vmax: Maximum array value in image slice, default to
36
+ `None`, which uses the actual maximum value in the slice.
37
+ :type vmax: float
38
+ :param axis: Axis direction or name of the image slices,
32
39
  defaults to `0`
33
- :type axis: int, optional
34
- :param interval: Delay between frames in milliseconds,
35
- defaults to `1000`
40
+ :type axis: Union[int, str], optional
41
+ :param interval: Delay between frames in milliseconds (only
42
+ used when interactive=True), defaults to `1000`
36
43
  :type interval: int, optional
37
44
  :param blit: Whether blitting is used to optimize drawing,
38
45
  default to `True`
39
46
  :type blit: bool, optional
40
47
  :param repeat: Whether the animation repeats when the sequence
41
- of frames is completed, defaults to `True`
48
+ of frames is completed (only used when interactive=True),
49
+ defaults to `True`
42
50
  :type repeat: bool, optional
43
51
  :param repeat_delay: Delay in milliseconds between consecutive
44
- animation runs if repeat is `True`, defaults to `1000`
52
+ animation runs if repeat is `True` (only used when
53
+ interactive=True), defaults to `1000`
45
54
  :type repeat_delay: int, optional
55
+ :param interactive: Allows for user interactions, defaults to
56
+ `False`.
57
+ :type interactive: bool, optional
46
58
  :return: The matplotlib animation.
47
59
  :rtype: matplotlib.animation.ArtistAnimation
48
60
  """
61
+ # System modules
62
+ from os.path import (
63
+ isabs,
64
+ join,
65
+ )
66
+
49
67
  # Third party modules
50
68
  import matplotlib.animation as animation
51
69
  import matplotlib.pyplot as plt
52
70
 
71
+ # Get the default Nexus NXdata object
72
+ data = self.unwrap_pipelinedata(data)[0]
73
+ try:
74
+ nxdata = data.get_default()
75
+ except:
76
+ if nxdata.nxclass != 'NXdata':
77
+ raise ValueError('Invalid default pathway to an NXdata object '
78
+ f'in ({data})')
79
+
53
80
  # Get the frames
54
- data = self.unwrap_pipelinedata(data)[-1]
55
- delta = int(data.shape[axis]/(num_frames+1))
56
- indices = np.linspace(delta, data.shape[axis]-delta, num_frames)
57
- if data.ndim == 3:
81
+ axes = nxdata.attrs.get('axes', None)
82
+ title = f'{nxdata.nxpath}/{nxdata.signal}'
83
+ if nxdata.nxsignal.ndim == 2:
84
+ exit('AnimationProcessor not tested yet for a 2D dataset')
85
+ elif nxdata.nxsignal.ndim == 3:
86
+ if isinstance(axis, int):
87
+ if not 0 <= axis < nxdata.nxsignal.ndim:
88
+ raise ValueError(f'axis index out of range ({axis} not in '
89
+ f'[0, {nxdata.nxsignal.ndim-1}])')
90
+ axis_name = 'axis {axis}'
91
+ elif isinstance(axis, str):
92
+ if axes is None or axis not in list(axes.nxdata):
93
+ raise ValueError(
94
+ f'Unable to match axis = {axis} in {nxdata.tree}')
95
+ axes = list(axes.nxdata)
96
+ axis_name = axis
97
+ axis = axes.index(axis)
98
+ else:
99
+ raise ValueError(f'Invalid parameter axis ({axis})')
100
+ delta = int(nxdata.nxsignal.shape[axis]/(num_frames+1))
101
+ indices = np.linspace(
102
+ delta, nxdata.nxsignal.shape[axis]-delta, num_frames)
58
103
  if not axis:
59
- frames = [data[int(index)] for index in indices]
104
+ frames = [nxdata[nxdata.signal][int(index),:,:]
105
+ for index in indices]
60
106
  elif axis == 1:
61
- frames = [data[:,int(index),:] for index in indices]
107
+ frames = [nxdata[nxdata.signal][:,int(index),:]
108
+ for index in indices]
62
109
  elif axis == 2:
63
- frames = [data[:,:,int(index)] for index in indices]
110
+ frames = [nxdata[nxdata.signal][:,:,int(index)]
111
+ for index in indices]
112
+ if axes is None:
113
+ axes = [i for i in range(3) if i != axis]
114
+ row_coords = range(a.shape[1])
115
+ row_label = f'axis {axes[1]} index'
116
+ column_coords = range(a.shape[0])
117
+ column_label = f'axis {axes[0]} index'
118
+ else:
119
+ axes.pop(axis)
120
+ row_coords = nxdata[axes[1]].nxdata
121
+ row_label = axes[1]
122
+ if 'units' in nxdata[axes[1]].attrs:
123
+ row_label += f' ({nxdata[axes[1]].units})'
124
+ column_coords = nxdata[axes[0]].nxdata
125
+ column_label = axes[0]
126
+ if 'units' in nxdata[axes[0]].attrs:
127
+ column_label += f' ({nxdata[axes[0]].units})'
64
128
  else:
65
129
  raise ValueError('Invalid data dimension (must be 2D or 3D)')
66
130
 
67
- fig = plt.figure()
68
- # vmin = np.min(frames)/8
69
- # vmax = np.max(frames)/8
131
+
132
+ # Create the movie
133
+ if vmin is None or vmax is None:
134
+ a_max = frames[0].max()
135
+ for n in range(1, num_frames):
136
+ a_max = min(a_max, frames[n].max())
137
+ a_max = float(a_max)
138
+ if vmin is None:
139
+ vmin = -a_max
140
+ if vmax is None:
141
+ vmax = a_max
142
+ extent = (
143
+ row_coords[0], row_coords[-1], column_coords[-1], column_coords[0])
144
+ fig, ax = plt.subplots(figsize=(11, 8.5))
145
+ ax.set_title(title, fontsize='xx-large', pad=20)
146
+ ax.set_xlabel(row_label, fontsize='x-large')
147
+ ax.set_ylabel(column_label, fontsize='x-large')
148
+ fig.tight_layout()
70
149
  ims = [[plt.imshow(
71
- #frames[n], vmin=vmin,vmax=vmax, cmap='gray',
72
- frames[n], cmap='gray',
150
+ frames[n], extent=extent, origin='lower',
151
+ vmin=vmin, vmax=vmax, cmap='gray',
73
152
  animated=True)]
74
153
  for n in range(num_frames)]
75
- ani = animation.ArtistAnimation(
76
- fig, ims, interval=interval, blit=blit, repeat=repeat,
77
- repeat_delay=repeat_delay)
78
-
79
- plt.show()
154
+ plt.colorbar()
155
+ if interactive:
156
+ ani = animation.ArtistAnimation(
157
+ fig, ims, interval=interval, blit=blit, repeat=repeat,
158
+ repeat_delay=repeat_delay)
159
+ plt.show()
160
+ else:
161
+ ani = animation.ArtistAnimation(fig, ims, blit=blit)
80
162
 
81
163
  return ani
82
164
 
@@ -129,43 +211,809 @@ class AsyncProcessor(Processor):
129
211
  asyncio.run(execute_tasks(self.mgr, data))
130
212
 
131
213
 
214
+ class BinarizeProcessor(Processor):
215
+ """A Processor to binarize a dataset.
216
+ """
217
+ def process(
218
+ self, data, nxpath='', interactive=False, method='CHAP',
219
+ num_bin=256, axis=None, remove_original_data=False):
220
+ """Show and return a binarized dataset from a dataset
221
+ contained in `data`. The dataset must either be of type
222
+ `numpy.ndarray` or a NeXus NXobject object with a default path
223
+ to a NeXus NXfield object.
224
+
225
+ :param data: Input data.
226
+ :type data: list[PipelineData]
227
+ :param nxpath: The relative path to a specific NeXus NXentry or
228
+ NeXus NXdata object in the NeXus file tree to read the
229
+ input data from (ignored for Numpy or NeXus NXfield input
230
+ datasets), defaults to `''`
231
+ :type nxpath: str, optional
232
+ :param interactive: Allows for user interactions (ignored
233
+ for any method other than `'manual'`), defaults to `False`.
234
+ :type interactive: bool, optional
235
+ :param method: Binarization method, defaults to `'CHAP'`
236
+ (CHAP's internal implementation of Otsu's method).
237
+ :type method: Literal['CHAP', 'manual', 'otsu', 'yen', 'isodata',
238
+ 'minimum']
239
+ :param num_bin: The number of bins used to calculate the
240
+ histogram in the binarization algorithms (ignored for
241
+ method = `'manual'`), defaults to `256`.
242
+ :type num_bin: int, optional
243
+ :param axis: Axis direction of the image slices (ignored
244
+ for any method other than `'manual'`), defaults to `None`
245
+ :type axis: int, optional
246
+ :param remove_original_data: Removes the original data field
247
+ (ignored for Numpy input datasets), defaults to `False`.
248
+ :type remove_original_data: bool, optional
249
+ :raises ValueError: Upon invalid input parameters.
250
+ :return: The binarized dataset with a return type equal to
251
+ that of the input dataset.
252
+ :rtype: typing.Union[numpy.ndarray, nexusformat.nexus.NXobject]
253
+ """
254
+ # System modules
255
+ from os.path import join as os_join
256
+ from os.path import relpath
257
+
258
+ # Local modules
259
+ from CHAP.utils.general import (
260
+ is_int,
261
+ nxcopy,
262
+ )
263
+ from nexusformat.nexus import (
264
+ NXdata,
265
+ NXfield,
266
+ NXlink,
267
+ NXprocess,
268
+ nxsetconfig,
269
+ )
270
+
271
+ if method not in [
272
+ 'CHAP', 'manual', 'otsu', 'yen', 'isodata', 'minimum']:
273
+ raise ValueError(f'Invalid parameter method ({method})')
274
+ if not is_int(num_bin, gt=0):
275
+ raise ValueError(f'Invalid parameter num_bin ({num_bin})')
276
+ if not isinstance(remove_original_data, bool):
277
+ raise ValueError('Invalid parameter remove_original_data '
278
+ f'({remove_original_data})')
279
+
280
+ nxsetconfig(memory=100000)
281
+
282
+ # Get the dataset and make a copy if it is a NeXus NXgroup
283
+ dataset = self.unwrap_pipelinedata(data)[-1]
284
+ if isinstance(dataset, np.ndarray):
285
+ if method == 'manual':
286
+ if axis is not None and not is_int(axis, gt=0, lt=3):
287
+ raise ValueError(f'Invalid parameter axis ({axis})')
288
+ axes = ['i', 'j', 'k']
289
+ data = dataset
290
+ elif isinstance(dataset, NXfield):
291
+ if method == 'manual':
292
+ if axis is not None and not is_int(axis, gt=0, lt=3):
293
+ raise ValueError(f'Invalid parameter axis ({axis})')
294
+ axes = ['i', 'j', 'k']
295
+ if isinstance(dataset, NXfield):
296
+ if nxpath not in ('', '/'):
297
+ self.logger.warning('Ignoring parameter nxpath')
298
+ data = dataset.nxdata
299
+ else:
300
+ try:
301
+ data = dataset[nxpath].nxdata
302
+ except:
303
+ raise ValueError(f'Invalid parameter nxpath ({nxpath})')
304
+ else:
305
+ # Get the default Nexus NXdata object
306
+ try:
307
+ nxdefault = dataset.get_default()
308
+ except:
309
+ nxdefault = None
310
+ if nxdefault is not None and nxdefault.nxclass != 'NXdata':
311
+ raise ValueError('Invalid default pathway NXobject type '
312
+ f'({nxdefault.nxclass})')
313
+ # Get the requested NeXus NXdata object to binarize
314
+ if nxpath is None:
315
+ nxclass = dataset.nxclass
316
+ else:
317
+ try:
318
+ nxclass = dataset[nxpath].nxclass
319
+ except:
320
+ raise ValueError(f'Invalid parameter nxpath ({nxpath})')
321
+ if nxclass == 'NXdata':
322
+ nxdata = dataset[nxpath]
323
+ else:
324
+ if nxdefault is None:
325
+ raise ValueError(f'No default pathway to a NXdata object')
326
+ nxdata = nxdefault
327
+ nxsignal = nxdata.nxsignal
328
+ if method == 'manual':
329
+ if hasattr(nxdata.attrs, 'axes'):
330
+ axes = nxdata.attrs['axes']
331
+ if isinstance(axis, str):
332
+ if axis not in axes:
333
+ raise ValueError(f'Invalid parameter axis ({axis})')
334
+ axis = axes.index(axis)
335
+ elif axis is not None and not is_int(axis, gt=0, lt=3):
336
+ raise ValueError(f'Invalid parameter axis ({axis})')
337
+ else:
338
+ axes = ['i', 'j', 'k']
339
+ if nxsignal.ndim != 3:
340
+ raise ValueError('Invalid data dimension (must be 3D)')
341
+ data = nxsignal.nxdata
342
+ # Create a copy of the input NeXus object, removing the
343
+ # default NeXus NXdata object as well as the original
344
+ # dataset if the remove_original_data parameter is set
345
+ exclude_nxpaths = []
346
+ if nxdefault is not None:
347
+ exclude_nxpaths.append(
348
+ os_join(relpath(nxdefault.nxpath, dataset.nxpath)))
349
+ if remove_original_data:
350
+ if (nxdefault is None
351
+ or nxdefault.nxpath != nxdata.nxpath):
352
+ relpath_nxdata = relpath(nxdata.nxpath, dataset.nxpath)
353
+ keys = list(nxdata.keys())
354
+ keys.remove(nxsignal.nxname)
355
+ for axis in nxdata.axes:
356
+ keys.remove(axis)
357
+ if len(keys):
358
+ raise RuntimeError('Not tested yet')
359
+ exclude_nxpaths.append(os_join(
360
+ relpath(nxsignal.nxpath, dataset.nxpath)))
361
+ elif relpath_nxdata == '.':
362
+ exclude_nxpaths.append(nxsignal.nxname)
363
+ if dataset.nxclass != 'NXdata':
364
+ exclude_nxpaths += nxdata.axes
365
+ else:
366
+ exclude_nxpaths.append(relpath_nxdata)
367
+ if not (dataset.nxclass == 'NXdata'
368
+ or nxdata.nxsignal.nxtarget is None):
369
+ nxsignal = dataset[nxsignal.nxtarget]
370
+ nxgroup = nxsignal.nxgroup
371
+ keys = list(nxgroup.keys())
372
+ keys.remove(nxsignal.nxname)
373
+ for axis in nxgroup.axes:
374
+ keys.remove(axis)
375
+ if len(keys):
376
+ raise RuntimeError('Not tested yet')
377
+ exclude_nxpaths.append(os_join(
378
+ relpath(nxsignal.nxpath, dataset.nxpath)))
379
+ else:
380
+ exclude_nxpaths.append(os_join(
381
+ relpath(nxgroup.nxpath, dataset.nxpath)))
382
+ nxobject = nxcopy(dataset, exclude_nxpaths=exclude_nxpaths)
383
+
384
+ # Get a histogram of the data
385
+ if method not in ['manual', 'yen']:
386
+ counts, edges = np.histogram(data, bins=num_bin)
387
+ centers = edges[:-1] + 0.5 * np.diff(edges)
388
+
389
+ # Calculate the data cutoff threshold
390
+ if method == 'CHAP':
391
+ weights = np.cumsum(counts)
392
+ means = np.cumsum(counts * centers)
393
+ weights = weights[0:-1]/weights[-1]
394
+ means = means[0:-1]/means[-1]
395
+ variances = (means-weights)**2/(weights*(1.-weights))
396
+ threshold = centers[np.argmax(variances)]
397
+ elif method == 'otsu':
398
+ # Third party modules
399
+ from skimage.filters import threshold_otsu
400
+
401
+ threshold = threshold_otsu(hist=(counts, centers))
402
+ elif method == 'yen':
403
+ # Third party modules
404
+ from skimage.filters import threshold_yen
405
+
406
+ _min = data.min()
407
+ _max = data.max()
408
+ data = 1+(num_bin-1)*(data-_min)/(_max-_min)
409
+ counts, edges = np.histogram(data, bins=num_bin)
410
+ centers = edges[:-1] + 0.5 * np.diff(edges)
411
+
412
+ threshold = threshold_yen(hist=(counts, centers))
413
+ elif method == 'isodata':
414
+ # Third party modules
415
+ from skimage.filters import threshold_isodata
416
+
417
+ threshold = threshold_isodata(hist=(counts, centers))
418
+ elif method == 'minimum':
419
+ # Third party modules
420
+ from skimage.filters import threshold_minimum
421
+
422
+ threshold = threshold_minimum(hist=(counts, centers))
423
+ else:
424
+ # Third party modules
425
+ import matplotlib.pyplot as plt
426
+ from matplotlib.widgets import RadioButtons, Button
427
+
428
+ # Local modules
429
+ from CHAP.utils.general import (
430
+ select_roi_1d,
431
+ select_roi_2d,
432
+ )
433
+
434
+ def select_direction(direction):
435
+ """Callback function for the "Select direction" input."""
436
+ selected_direction.append(radio_btn.value_selected)
437
+ plt.close()
438
+
439
+ def accept(event):
440
+ """Callback function for the "Accept" button."""
441
+ selected_direction.append(radio_btn.value_selected)
442
+ plt.close()
443
+
444
+ # Select the direction for data averaging
445
+ if axis is not None:
446
+ mean_data = data.mean(axis=axis)
447
+ subaxes = [i for i in range(3) if i != axis]
448
+ else:
449
+ selected_direction = []
450
+
451
+ # Setup figure
452
+ title_pos = (0.5, 0.95)
453
+ title_props = {'fontsize': 'xx-large',
454
+ 'horizontalalignment': 'center',
455
+ 'verticalalignment': 'bottom'}
456
+ fig, axs = plt.subplots(ncols=3, figsize=(17, 8.5))
457
+ mean_data = []
458
+ for i, ax in enumerate(axs):
459
+ mean_data.append(data.mean(axis=i))
460
+ subaxes = [a for a in axes if a != axes[i]]
461
+ ax.imshow(mean_data[i], aspect='auto', cmap='gray')
462
+ ax.set_title(
463
+ f'Data averaged in {axes[i]}-direction',
464
+ fontsize='x-large')
465
+ ax.set_xlabel(subaxes[1], fontsize='x-large')
466
+ ax.set_ylabel(subaxes[0], fontsize='x-large')
467
+ fig_title = plt.figtext(
468
+ *title_pos,
469
+ 'Select a direction or press "Accept" for the default one '
470
+ f'({axes[0]}) to obtain the binary threshold value',
471
+ **title_props)
472
+ fig.subplots_adjust(bottom=0.25, top=0.85)
473
+
474
+ # Setup RadioButtons
475
+ select_text = plt.figtext(
476
+ 0.225, 0.175, 'Averaging direction', fontsize='x-large',
477
+ horizontalalignment='center', verticalalignment='center')
478
+ radio_btn = RadioButtons(
479
+ plt.axes([0.175, 0.05, 0.1, 0.1]), labels=axes, active=0)
480
+ radio_cid = radio_btn.on_clicked(select_direction)
481
+
482
+ # Setup "Accept" button
483
+ accept_btn = Button(
484
+ plt.axes([0.7, 0.05, 0.15, 0.075]), 'Accept')
485
+ accept_cid = accept_btn.on_clicked(accept)
486
+
487
+ plt.show()
488
+
489
+ axis = axes.index(selected_direction[0])
490
+ mean_data = mean_data[axis]
491
+ subaxes = [a for a in axes if a != axes[axis]]
492
+
493
+ plt.close()
494
+
495
+ # Select the ROI's orthogonal to the selected averaging direction
496
+ bounds = []
497
+ for i, bound in enumerate(['"0"', '"1"']):
498
+ roi = select_roi_2d(
499
+ mean_data,
500
+ title=f'Select the ROI to obtain the {bound} data value',
501
+ title_a=f'Data averaged in the {axes[axis]}-direction',
502
+ row_label=subaxes[0], column_label=subaxes[1])
503
+
504
+ # Select the index range in the selected averaging direction
505
+ if not axis:
506
+ mean_roi_data = data[:,roi[2]:roi[3],roi[0]:roi[1]].mean(
507
+ axis=(1,2))
508
+ elif axis == 1:
509
+ mean_roi_data = data[roi[2]:roi[3],:,roi[0]:roi[1]].mean(
510
+ axis=(0,2))
511
+ elif axis == 2:
512
+ mean_roi_data = data[roi[2]:roi[3],roi[0]:roi[1],:].mean(
513
+ axis=(0,1))
514
+
515
+ _range = select_roi_1d(
516
+ mean_roi_data, preselected_roi=(0, data.shape[axis]),
517
+ title=f'Select the {axes[axis]}-direction range to obtain '
518
+ f'the {bound} data bound',
519
+ xlabel=axes[axis], ylabel='Average data')
520
+
521
+ # Obtain the lower/upper data bound
522
+ if not axis:
523
+ bounds.append(
524
+ data[
525
+ _range[0]:_range[1],roi[2]:roi[3],roi[0]:roi[1]
526
+ ].mean())
527
+ elif axis == 1:
528
+ bounds.append(
529
+ data[
530
+ roi[2]:roi[3],_range[0]:_range[1],roi[0]:roi[1]
531
+ ].mean())
532
+ elif axis == 2:
533
+ bounds.append(
534
+ data[
535
+ roi[2]:roi[3],roi[0]:roi[1],_range[0]:_range[1]
536
+ ].mean())
537
+
538
+ # Get the data cutoff threshold
539
+ threshold = np.mean(bounds)
540
+
541
+ # Apply the data cutoff threshold and return the output
542
+ data = np.where(data<threshold, 0, 1).astype(np.ubyte)
543
+ # from CHAP.utils.general import quick_imshow
544
+ # quick_imshow(data[int(data.shape[0]/2),:,:], block=True)
545
+ # quick_imshow(data[:,int(data.shape[1]/2),:], block=True)
546
+ # quick_imshow(data[:,:,int(data.shape[2]/2)], block=True)
547
+ if isinstance(dataset, np.ndarray):
548
+ return data
549
+ if isinstance(dataset, NXfield):
550
+ attrs = dataset.attrs
551
+ attrs.pop('target', None)
552
+ return NXfield(
553
+ value=data, name=dataset.nxname, attrs=dataset.attrs)
554
+ name = nxsignal.nxname + '_binarized'
555
+ if nxobject.nxclass == 'NXdata':
556
+ nxobject[name] = data
557
+ nxobject.attrs['signal'] = name
558
+ return nxobject
559
+ if nxobject.nxclass == 'NXroot':
560
+ nxentry = nxobject[nxobject.default]
561
+ else:
562
+ nxentry = nxobject
563
+ axes = []
564
+ for axis in nxdata.axes:
565
+ attrs = nxdata[axis].attrs
566
+ attrs.pop('target', None)
567
+ axes.append(
568
+ NXfield(nxdata[axis], name=axis, attrs=attrs))
569
+ nxentry[name] = NXprocess(
570
+ NXdata(NXfield(data, name=name), axes),
571
+ attrs={'source': nxsignal.nxpath})
572
+ nxdata = nxentry[name].data
573
+ nxentry.data = NXdata(
574
+ NXlink(nxdata.nxsignal.nxpath),
575
+ [NXlink(os_join(nxdata.nxpath, axis)) for axis in nxdata.axes])
576
+ nxentry.data.set_default()
577
+ return nxobject
578
+
579
+
580
+ class ConstructBaseline(Processor):
581
+ """A Processor to construct a baseline for a dataset.
582
+ """
583
+ def process(
584
+ self, data, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
585
+ save_figures=False, outputdir='.', interactive=False):
586
+ """Construct and return the baseline for a dataset.
587
+
588
+ :param data: Input data.
589
+ :type data: list[PipelineData]
590
+ :param mask: A mask to apply to the spectrum before baseline
591
+ construction, default to `None`.
592
+ :type mask: array-like, optional
593
+ :param tol: The convergence tolerance, defaults to `1.e-6`.
594
+ :type tol: float, optional
595
+ :param lam: The lambda (smoothness) parameter (the balance
596
+ between the residual of the data and the baseline and the
597
+ smoothness of the baseline). The suggested range is between
598
+ 100 and 10^8, defaults to `10^6`.
599
+ :type lam: float, optional
600
+ :param max_iter: The maximum number of iterations,
601
+ defaults to `20`.
602
+ :type max_iter: int, optional
603
+ :param save_figures: Save .pngs of plots for checking inputs &
604
+ outputs of this Processor, defaults to False.
605
+ :type save_figures: bool, optional
606
+ :param outputdir: Directory to which any output figures will
607
+ be saved, defaults to '.'
608
+ :type outputdir: str, optional
609
+ :param interactive: Allows for user interactions, defaults to
610
+ False.
611
+ :type interactive: bool, optional
612
+ :return: The smoothed baseline and the configuration.
613
+ :rtype: numpy.array, dict
614
+ """
615
+ try:
616
+ data = np.asarray(self.unwrap_pipelinedata(data)[0])
617
+ except:
618
+ raise ValueError(
619
+ f'The structure of {data} contains no valid data')
620
+
621
+ return self.construct_baseline(
622
+ data, mask, tol, lam, max_iter, save_figures, outputdir,
623
+ interactive)
624
+
625
+ @staticmethod
626
+ def construct_baseline(
627
+ y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20, title=None,
628
+ xlabel=None, ylabel=None, interactive=False, filename=None):
629
+ """Construct and return the baseline for a dataset.
630
+
631
+ :param y: Input data.
632
+ :type y: numpy.array
633
+ :param x: Independent dimension (only used when interactive is
634
+ `True` or when filename is set), defaults to `None`.
635
+ :type x: array-like, optional
636
+ :param mask: A mask to apply to the spectrum before baseline
637
+ construction, default to `None`.
638
+ :type mask: array-like, optional
639
+ :param tol: The convergence tolerance, defaults to `1.e-6`.
640
+ :type tol: float, optional
641
+ :param lam: The lambda (smoothness) parameter (the balance
642
+ between the residual of the data and the baseline and the
643
+ smoothness of the baseline). The suggested range is between
644
+ 100 and 10^8, defaults to `10^6`.
645
+ :type lam: float, optional
646
+ :param max_iter: The maximum number of iterations,
647
+ defaults to `20`.
648
+ :type max_iter: int, optional
649
+ :param xlabel: Label for the x-axis of the displayed figure,
650
+ defaults to `None`.
651
+ :param title: Title for the displayed figure, defaults to `None`.
652
+ :type title: str, optional
653
+ :type xlabel: str, optional
654
+ :param ylabel: Label for the y-axis of the displayed figure,
655
+ defaults to `None`.
656
+ :type ylabel: str, optional
657
+ :param interactive: Allows for user interactions, defaults to
658
+ False.
659
+ :type interactive: bool, optional
660
+ :param filename: Save a .png of the plot to filename, defaults to
661
+ `None`, in which case the plot is not saved.
662
+ :type filename: str, optional
663
+ :return: The smoothed baseline and the configuration.
664
+ :rtype: numpy.array, dict
665
+ """
666
+ # Third party modules
667
+ if interactive or filename is not None:
668
+ from matplotlib.widgets import TextBox, Button
669
+ import matplotlib.pyplot as plt
670
+
671
+ # Local modules
672
+ from CHAP.utils.general import baseline_arPLS
673
+
674
+ def change_fig_subtitle(maxed_out=False, subtitle=None):
675
+ if fig_subtitles:
676
+ fig_subtitles[0].remove()
677
+ fig_subtitles.pop()
678
+ if subtitle is None:
679
+ subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
680
+ if maxed_out:
681
+ subtitle += f'# iter = {num_iters[-1]} (maxed out) '
682
+ else:
683
+ subtitle += f'# iter = {num_iters[-1]} '
684
+ subtitle += f'error = {errors[-1]:.2e}'
685
+ fig_subtitles.append(
686
+ plt.figtext(*subtitle_pos, subtitle, **subtitle_props))
687
+
688
+ def select_lambda(expression):
689
+ """Callback function for the "Select lambda" TextBox.
690
+ """
691
+ if not len(expression):
692
+ return
693
+ try:
694
+ lam = float(expression)
695
+ if lam < 0:
696
+ raise ValueError
697
+ except ValueError:
698
+ change_fig_subtitle(
699
+ subtitle=f'Invalid lambda, enter a positive number')
700
+ else:
701
+ lambdas.pop()
702
+ lambdas.append(10**lam)
703
+ baseline, _, w, num_iter, error = baseline_arPLS(
704
+ y, mask=mask, tol=tol, lam=lambdas[-1], max_iter=max_iter,
705
+ full_output=True)
706
+ num_iters.pop()
707
+ num_iters.append(num_iter)
708
+ errors.pop()
709
+ errors.append(error)
710
+ if num_iter < max_iter:
711
+ change_fig_subtitle()
712
+ else:
713
+ change_fig_subtitle(maxed_out=True)
714
+ baseline_handle.set_ydata(baseline)
715
+ lambda_box.set_val('')
716
+ plt.draw()
717
+
718
+ def continue_iter(event):
719
+ """Callback function for the "Continue" button."""
720
+ baseline, _, w, n_iter, error = baseline_arPLS(
721
+ y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
722
+ max_iter=max_iter, full_output=True)
723
+ num_iters[-1] += n_iter
724
+ errors.pop()
725
+ errors.append(error)
726
+ if n_iter < max_iter:
727
+ change_fig_subtitle()
728
+ else:
729
+ change_fig_subtitle(maxed_out=True)
730
+ baseline_handle.set_ydata(baseline)
731
+ plt.draw()
732
+ weights.pop()
733
+ weights.append(w)
734
+
735
+ def confirm(event):
736
+ """Callback function for the "Confirm" button."""
737
+ plt.close()
738
+
739
+ baseline, _, w, num_iter, error = baseline_arPLS(
740
+ y, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
741
+ full_output=True)
742
+
743
+ if not interactive and filename is None:
744
+ return baseline
745
+
746
+ lambdas = [lam]
747
+ weights = [w]
748
+ num_iters = [num_iter]
749
+ errors = [error]
750
+ fig_subtitles = []
751
+
752
+ # Check inputs
753
+ if x is None:
754
+ x = np.arange(y.size)
755
+
756
+ # Setup the Matplotlib figure
757
+ title_pos = (0.5, 0.95)
758
+ title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
759
+ 'verticalalignment': 'bottom'}
760
+ subtitle_pos = (0.5, 0.90)
761
+ subtitle_props = {'fontsize': 'x-large',
762
+ 'horizontalalignment': 'center',
763
+ 'verticalalignment': 'bottom'}
764
+ fig, ax = plt.subplots(figsize=(11, 8.5))
765
+ if mask is None:
766
+ ax.plot(x, y, label='input data')
767
+ else:
768
+ ax.plot(
769
+ x[mask.astype(bool)], y[mask.astype(bool)], label='input data')
770
+ baseline_handle = ax.plot(x, baseline, label='baseline')[0]
771
+ # ax.plot(x, y-baseline, label='baseline corrected data')
772
+ ax.set_xlabel(xlabel, fontsize='x-large')
773
+ ax.set_ylabel(ylabel, fontsize='x-large')
774
+ ax.legend()
775
+ if title is None:
776
+ fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
777
+ else:
778
+ fig_title = plt.figtext(*title_pos, title, **title_props)
779
+ if num_iter < max_iter:
780
+ change_fig_subtitle()
781
+ else:
782
+ change_fig_subtitle(maxed_out=True)
783
+ fig.subplots_adjust(bottom=0.0, top=0.85)
784
+
785
+ if interactive:
786
+
787
+ fig.subplots_adjust(bottom=0.2)
788
+
789
+ # Setup TextBox
790
+ lambda_box = TextBox(
791
+ plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
792
+ lambda_cid = lambda_box.on_submit(select_lambda)
793
+
794
+ # Setup "Continue" button
795
+ continue_btn = Button(
796
+ plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
797
+ continue_cid = continue_btn.on_clicked(continue_iter)
798
+
799
+ # Setup "Confirm" button
800
+ confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
801
+ confirm_cid = confirm_btn.on_clicked(confirm)
802
+
803
+ # Show figure for user interaction
804
+ plt.show()
805
+
806
+ # Disconnect all widget callbacks when figure is closed
807
+ lambda_box.disconnect(lambda_cid)
808
+ continue_btn.disconnect(continue_cid)
809
+ confirm_btn.disconnect(confirm_cid)
810
+
811
+ # ... and remove the buttons before returning the figure
812
+ lambda_box.ax.remove()
813
+ continue_btn.ax.remove()
814
+ confirm_btn.ax.remove()
815
+
816
+ if filename is not None:
817
+ fig_title.set_in_layout(True)
818
+ fig_subtitles[-1].set_in_layout(True)
819
+ fig.tight_layout(rect=(0, 0, 1, 0.90))
820
+ fig.savefig(filename)
821
+ plt.close()
822
+
823
+ config = {
824
+ 'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
825
+ 'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
826
+ return baseline, config
827
+
828
+
132
829
  class ImageProcessor(Processor):
133
- """A Processor to plot an image slice from a dataset.
830
+ """A Processor to plot an image (slice) from a NeXus object.
134
831
  """
135
- def process(self, data, index=0, axis=0):
136
- """Plot an image from a dataset contained in `data` and return
137
- the full dataset.
832
+ def process(
833
+ self, data, vmin=None, vmax=None, axis=0, index=None,
834
+ coord=None, interactive=False, save_figure=True, outputdir='.',
835
+ filename='image.png'):
836
+ """Plot and/or save an image (slice) from a NeXus NXobject object with
837
+ a default data path contained in `data` and return the NeXus NXdata
838
+ data object.
138
839
 
139
840
  :param data: Input data.
140
- :type data: CHAP.pipeline.PipelineData
141
- :param index: Array index of the slice of data to plot,
841
+ :type data: list[PipelineData]
842
+ :param vmin: Minimum array value in image slice, default to
843
+ `None`, which uses the actual minimum value in the slice.
844
+ :type vmin: float
845
+ :param vmax: Maximum array value in image slice, default to
846
+ `None`, which uses the actual maximum value in the slice.
847
+ :type vmax: float
848
+ :param axis: Axis direction or name of the image slice,
142
849
  defaults to `0`
850
+ :type axis: Union[int, str], optional
851
+ :param index: Array index of the slice of data to plot,
852
+ defaults to `None`
143
853
  :type index: int, optional
144
- :param axis: Axis direction of the image slice,
145
- defaults to `0`
146
- :type axis: int, optional
147
- :return: The full input dataset.
148
- :rtype: object
854
+ :param coord: Coordinate value of the slice of data to plot,
855
+ defaults to `None`
856
+ :type coord: Union[int, float], optional
857
+ :param interactive: Allows for user interactions, defaults to
858
+ `False`.
859
+ :type interactive: bool, optional
860
+ :param save_figure: Save a .png of the image, defaults to `True`.
861
+ :type save_figure: bool, optional
862
+ :param outputdir: Directory to which any output figure will
863
+ be saved, defaults to `'.'`
864
+ :type outputdir: str, optional
865
+ :param filename: Image filename, defaults to `"image.png"`.
866
+ :type filename: str, optional
867
+ :return: The input data object.
868
+ :rtype: nexusformat.nexus.NXdata
149
869
  """
150
- # Local modules
151
- from CHAP.utils.general import quick_imshow
870
+ # System modules
871
+ from os.path import (
872
+ isabs,
873
+ join,
874
+ )
152
875
 
876
+ # Third party modules
877
+ import matplotlib.pyplot as plt
878
+
879
+ # Local modules
880
+ from CHAP.utils.general import index_nearest
881
+
882
+ # Validate input parameters
883
+ if not isinstance(interactive, bool):
884
+ raise ValueError(f'Invalid parameter interactive ({interactive})')
885
+ if not isinstance(save_figure, bool):
886
+ raise ValueError(f'Invalid parameter save_figure ({save_figure})')
887
+ if not isinstance(outputdir, str):
888
+ raise ValueError(f'Invalid parameter outputdir ({outputdir})')
889
+ if not isinstance(filename, str):
890
+ raise ValueError(f'Invalid parameter filename ({filename})')
891
+ if not isabs(filename):
892
+ filename = join(outputdir, filename)
893
+
894
+ # Get the default Nexus NXdata object
153
895
  data = self.unwrap_pipelinedata(data)[0]
154
- if data.ndim == 2:
155
- quick_imshow(data, block=True)
156
- elif data.ndim == 3:
896
+ try:
897
+ nxdata = data.get_default()
898
+ except:
899
+ if nxdata.nxclass != 'NXdata':
900
+ raise ValueError('Invalid default pathway to an NXdata object '
901
+ f'in ({data})')
902
+
903
+ # Get the data slice
904
+ axes = nxdata.attrs.get('axes', None)
905
+ if axes is not None:
906
+ axes = list(axes.nxdata)
907
+ coords = None
908
+ title = f'{nxdata.nxpath}/{nxdata.signal}'
909
+ if nxdata.nxsignal.ndim == 2:
910
+ exit('ImageProcessor not tested yet for a 2D dataset')
911
+ if axis is not None:
912
+ axis = None
913
+ self.logger.warning('Ignoring parameter axis')
914
+ if index is not None:
915
+ index = None
916
+ self.logger.warning('Ignoring parameter index')
917
+ if coord is not None:
918
+ coord = None
919
+ self.logger.warning('Ignoring parameter coord')
920
+ a = nxdata.nxsignal
921
+ elif nxdata.nxsignal.ndim == 3:
922
+ if isinstance(axis, int):
923
+ if not 0 <= axis < nxdata.nxsignal.ndim:
924
+ raise ValueError(f'axis index out of range ({axis} not in '
925
+ f'[0, {nxdata.nxsignal.ndim-1}])')
926
+ elif isinstance(axis, str):
927
+ if axes is None or axis not in axes:
928
+ raise ValueError(
929
+ f'Unable to match axis = {axis} in {nxdata.tree}')
930
+ axis = axes.index(axis)
931
+ else:
932
+ raise ValueError(f'Invalid parameter axis ({axis})')
933
+ if axes is not None and hasattr(nxdata, axes[axis]):
934
+ coords = nxdata[axes[axis]].nxdata
935
+ axis_name = axes[axis]
936
+ else:
937
+ axis_name = f'axis {axis}'
938
+ if index is None and coord is None:
939
+ index = nxdata.nxsignal.shape[axis] // 2
940
+ else:
941
+ if index is not None:
942
+ if coord is not None:
943
+ coord = None
944
+ self.logger.warning('Ignoring parameter coord')
945
+ if not isinstance(index, int):
946
+ raise ValueError(f'Invalid parameter index ({index})')
947
+ elif not 0 <= index < nxdata.nxsignal.shape[axis]:
948
+ raise ValueError(
949
+ f'index value out of range ({index} not in '
950
+ f'[0, {nxdata.nxsignal.shape[axis]-1}])')
951
+ else:
952
+ if not isinstance(coord, (int, float)):
953
+ raise ValueError(f'Invalid parameter coord ({coord})')
954
+ if coords is None:
955
+ raise ValueError(
956
+ f'Unable to get coordinates for {axis_name} '
957
+ f'in {nxdata.tree}')
958
+ index = index_nearest(nxdata[axis_name], coord)
959
+ if coords is None:
960
+ slice_info = f'slice at {axis_name} and index {index}'
961
+ else:
962
+ coord = coords[index]
963
+ slice_info = f'slice at {axis_name} = '\
964
+ f'{nxdata[axis_name][index]:.3f}'
965
+ if 'units' in nxdata[axis_name].attrs:
966
+ slice_info += f' ({nxdata[axis_name].units})'
157
967
  if not axis:
158
- quick_imshow(data[index], block=True)
968
+ a = nxdata[nxdata.signal][index,:,:]
159
969
  elif axis == 1:
160
- quick_imshow(data[:,index,:], block=True)
970
+ a = nxdata[nxdata.signal][:,index,:]
161
971
  elif axis == 2:
162
- quick_imshow(data[:,:,index], block=True)
972
+ a = nxdata[nxdata.signal][:,:,index]
973
+ if coords is None:
974
+ axes = [i for i in range(3) if i != axis]
975
+ row_coords = range(a.shape[1])
976
+ row_label = f'axis {axes[1]} index'
977
+ column_coords = range(a.shape[0])
978
+ column_label = f'axis {axes[0]} index'
163
979
  else:
164
- raise ValueError(f'Invalid parameter axis ({axis})')
980
+ axes.pop(axis)
981
+ row_coords = nxdata[axes[1]].nxdata
982
+ row_label = axes[1]
983
+ if 'units' in nxdata[axes[1]].attrs:
984
+ row_label += f' ({nxdata[axes[1]].units})'
985
+ column_coords = nxdata[axes[0]].nxdata
986
+ column_label = axes[0]
987
+ if 'units' in nxdata[axes[0]].attrs:
988
+ column_label += f' ({nxdata[axes[0]].units})'
165
989
  else:
166
990
  raise ValueError('Invalid data dimension (must be 2D or 3D)')
167
991
 
168
- return data
992
+ # Create figure
993
+ a_max = a.max()
994
+ if vmin is None:
995
+ vmin = -a_max
996
+ if vmax is None:
997
+ vmax = a_max
998
+ extent = (
999
+ row_coords[0], row_coords[-1], column_coords[-1], column_coords[0])
1000
+ fig, ax = plt.subplots(figsize=(11, 8.5))
1001
+ plt.imshow(
1002
+ a, extent=extent, origin='lower', vmin=vmin, vmax=vmax,
1003
+ cmap='gray')
1004
+ fig.suptitle(title, fontsize='xx-large')
1005
+ ax.set_title(slice_info, fontsize='xx-large', pad=20)
1006
+ ax.set_xlabel(row_label, fontsize='x-large')
1007
+ ax.set_ylabel(column_label, fontsize='x-large')
1008
+ plt.colorbar()
1009
+ fig.tight_layout()
1010
+ if interactive:
1011
+ plt.show()
1012
+ if save_figure:
1013
+ fig.savefig(filename)
1014
+ plt.close()
1015
+
1016
+ return nxdata
169
1017
 
170
1018
 
171
1019
  class IntegrationProcessor(Processor):
@@ -177,7 +1025,7 @@ class IntegrationProcessor(Processor):
177
1025
 
178
1026
  :param data: Input data, containing the raw data, integration
179
1027
  method, and keyword args for the integration method.
180
- :type data: CHAP.pipeline.PipelineData
1028
+ :type data: list[PipelineData]
181
1029
  :return: Integrated raw data.
182
1030
  :rtype: pyFAI.containers.IntegrateResult
183
1031
  """
@@ -200,7 +1048,7 @@ class IntegrateMapProcessor(Processor):
200
1048
  with the value `'MapConfig'` for the `'schema'` key, and at
201
1049
  least one item with the value `'IntegrationConfig'` for the
202
1050
  `'schema'` key.
203
- :type data: CHAP.pipeline.PipelineData
1051
+ :type data: list[PipelineData]
204
1052
  :return: Integrated data and process metadata.
205
1053
  :rtype: nexusformat.nexus.NXprocess
206
1054
  """
@@ -359,28 +1207,44 @@ class MapProcessor(Processor):
359
1207
  NXentry object representing that map's metadata and any
360
1208
  scalar-valued raw data requested by the supplied map configuration.
361
1209
  """
362
- def process(self, data):
1210
+ def process(self, data, detector_names=[]):
363
1211
  """Process the output of a `Reader` that contains a map
364
1212
  configuration and returns a NeXus NXentry object representing
365
1213
  the map.
366
1214
 
367
1215
  :param data: Result of `Reader.read` where at least one item
368
1216
  has the value `'MapConfig'` for the `'schema'` key.
369
- :type data: CHAP.pipeline.PipelineData
1217
+ :type data: list[PipelineData]
1218
+ :param detector_names: Detector prefixes to include raw data
1219
+ for in the returned NeXus NXentry object, defaults to `[]`.
1220
+ :type detector_names: list[str], optional
370
1221
  :return: Map data and metadata.
371
1222
  :rtype: nexusformat.nexus.NXentry
372
1223
  """
1224
+ # Local modules
1225
+ from CHAP.utils.general import string_to_list
1226
+ if isinstance(detector_names, str):
1227
+ try:
1228
+ detector_names = [
1229
+ str(v) for v in string_to_list(
1230
+ detector_names, raise_error=True)]
1231
+ except:
1232
+ raise ValueError(
1233
+ f'Invalid parameter detector_names ({detector_names})')
373
1234
  map_config = self.get_config(data, 'common.models.map.MapConfig')
374
- nxentry = self.__class__.get_nxentry(map_config)
1235
+ nxentry = self.__class__.get_nxentry(map_config, detector_names)
375
1236
 
376
1237
  return nxentry
377
1238
 
378
1239
  @staticmethod
379
- def get_nxentry(map_config):
1240
+ def get_nxentry(map_config, detector_names=[]):
380
1241
  """Use a `MapConfig` to construct a NeXus NXentry object.
381
1242
 
382
1243
  :param map_config: A valid map configuration.
383
1244
  :type map_config: MapConfig
1245
+ :param detector_names: Detector prefixes to include raw data
1246
+ for in the returned NeXus NXentry object.
1247
+ :type detector_names: list[str]
384
1248
  :return: The map's data and metadata contained in a NeXus
385
1249
  structure.
386
1250
  :rtype: nexusformat.nexus.NXentry
@@ -401,6 +1265,8 @@ class MapProcessor(Processor):
401
1265
  nxentry.map_config = dumps(map_config.dict())
402
1266
  nxentry[map_config.sample.name] = NXsample(**map_config.sample.dict())
403
1267
  nxentry.attrs['station'] = map_config.station
1268
+ for key, value in map_config.attrs.items():
1269
+ nxentry.attrs[key] = value
404
1270
 
405
1271
  nxentry.spec_scans = NXcollection()
406
1272
  for scans in map_config.spec_scans:
@@ -440,10 +1306,26 @@ class MapProcessor(Processor):
440
1306
  nxentry.data.attrs['signal'] = signal
441
1307
  nxentry.data.attrs['auxilliary_signals'] = auxilliary_signals
442
1308
 
443
- for data in map_config.all_scalar_data:
444
- for map_index in np.ndindex(map_config.shape):
1309
+ # Create empty NXfields of appropriate shape for raw
1310
+ # detector data
1311
+ for detector_name in detector_names:
1312
+ if not isinstance(detector_name, str):
1313
+ detector_name = str(detector_name)
1314
+ detector_data = map_config.get_detector_data(
1315
+ detector_name, (0,) * len(map_config.shape))
1316
+ nxentry.data[detector_name] = NXfield(value=np.zeros(
1317
+ (*map_config.shape, *detector_data.shape)),
1318
+ dtype=detector_data.dtype)
1319
+
1320
+ for map_index in np.ndindex(map_config.shape):
1321
+ for data in map_config.all_scalar_data:
445
1322
  nxentry.data[data.label][map_index] = map_config.get_value(
446
1323
  data, map_index)
1324
+ for detector_name in detector_names:
1325
+ if not isinstance(detector_name, str):
1326
+ detector_name = str(detector_name)
1327
+ nxentry.data[detector_name][map_index] = \
1328
+ map_config.get_detector_data(detector_name, map_index)
447
1329
 
448
1330
  return nxentry
449
1331
 
@@ -573,6 +1455,66 @@ class PrintProcessor(Processor):
573
1455
  return data
574
1456
 
575
1457
 
1458
+ class PyfaiAzimuthalIntegrationProcessor(Processor):
1459
+ """Processor to azimuthally integrate one or more frames of 2d
1460
+ detector data using the
1461
+ [pyFAI](https://pyfai.readthedocs.io/en/v2023.1/index.html)
1462
+ package.
1463
+ """
1464
+ def process(self, data, poni_file, npt, mask_file=None,
1465
+ integrate1d_kwargs=None, inputdir='.'):
1466
+ """Azimuthally integrate the detector data provided and return
1467
+ the result as a dictionary of numpy arrays containing the
1468
+ values of the radial coordinate of the result, the intensities
1469
+ along the radial direction, and the poisson errors for each
1470
+ intensity spectrum.
1471
+
1472
+ :param data: Detector data to integrate.
1473
+ :type data: Union[PipelineData, list[np.ndarray]]
1474
+ :param poni_file: Name of the [pyFAI PONI
1475
+ file](https://pyfai.readthedocs.io/en/v2023.1/glossary.html?highlight=poni%20file#poni-file)
1476
+ containing the detector properties pyFAI needs to perform
1477
+ azimuthal integration.
1478
+ :type poni_file: str
1479
+ :param npt: Number of points in the output pattern.
1480
+ :type npt: int
1481
+ :param mask_file: A file to use for masking the input data.
1482
+ :type mask_file: str, optional
1483
+ :param integrate1d_kwargs: Optional dictionary of keyword
1484
+ arguments to use with
1485
+ [`pyFAI.azimuthalIntegrator.AzimuthalIntegrator.integrate1d`](https://pyfai.readthedocs.io/en/v2023.1/api/pyFAI.html#pyFAI.azimuthalIntegrator.AzimuthalIntegrator.integrate1d). Defaults
1486
+ to `None`.
1487
+ :type integrate1d_kwargs: Optional[dict]
1488
+ :returns: Azimuthal integration results as a dictionary of
1489
+ numpy arrays.
1490
+ """
1491
+ import os
1492
+ from pyFAI import load
1493
+
1494
+ if not os.path.isabs(poni_file):
1495
+ poni_file = os.path.join(inputdir, poni_file)
1496
+ ai = load(poni_file)
1497
+
1498
+ if mask_file is None:
1499
+ mask = None
1500
+ else:
1501
+ if not os.path.isabs(mask_file):
1502
+ mask_file = os.path.join(inputdir, mask_file)
1503
+ import fabio
1504
+ mask = fabio.open(mask_file).data
1505
+
1506
+ try:
1507
+ det_data = self.unwrap_pipelinedata(data)[0]
1508
+ except:
1509
+ det_data = det_data
1510
+
1511
+ if integrate1d_kwargs is None:
1512
+ integrate1d_kwargs = {}
1513
+ integrate1d_kwargs['mask'] = mask
1514
+
1515
+ return [ai.integrate1d(d, npt, **integrate1d_kwargs) for d in det_data]
1516
+
1517
+
576
1518
  class RawDetectorDataMapProcessor(Processor):
577
1519
  """A Processor to return a map of raw derector data in a
578
1520
  NeXus NXroot object.
@@ -582,7 +1524,7 @@ class RawDetectorDataMapProcessor(Processor):
582
1524
  detector data data collected over the map.
583
1525
 
584
1526
  :param data: Input map configuration.
585
- :type data: CHAP.pipeline.PipelineData
1527
+ :type data: list[PipelineData]
586
1528
  :param detector_name: The detector prefix.
587
1529
  :type detector_name: str
588
1530
  :param detector_shape: The shape of detector data for a single
@@ -602,7 +1544,7 @@ class RawDetectorDataMapProcessor(Processor):
602
1544
 
603
1545
  :param data: Result of `Reader.read` where at least one item
604
1546
  has the value `'MapConfig'` for the `'schema'` key.
605
- :type data: CHAP.pipeline.PipelineData
1547
+ :type data: list[PipelineData]
606
1548
  :raises Exception: If a valid map config object cannot be
607
1549
  constructed from `data`.
608
1550
  :return: A valid instance of the map configuration object with
@@ -709,7 +1651,7 @@ class StrainAnalysisProcessor(Processor):
709
1651
 
710
1652
  :param data: Results of `MutlipleReader.read` containing input
711
1653
  map detector data and strain analysis configuration
712
- :type data: CHAP.pipeline.PipelineData
1654
+ :type data: list[PipelineData]
713
1655
  :return: A map of sample strains.
714
1656
  :rtype: xarray.Dataset
715
1657
  """
@@ -724,7 +1666,7 @@ class StrainAnalysisProcessor(Processor):
724
1666
  :param data: Result of `Reader.read` where at least one item
725
1667
  has the value `'StrainAnalysisConfig'` for the `'schema'`
726
1668
  key.
727
- :type data: CHAP.pipeline.PipelineData
1669
+ :type data: list[PipelineData]
728
1670
  :raises Exception: If valid config objects cannot be
729
1671
  constructed from `data`.
730
1672
  :return: A valid instance of the configuration object with
@@ -745,6 +1687,499 @@ class StrainAnalysisProcessor(Processor):
745
1687
  return strain_analysis_config
746
1688
 
747
1689
 
1690
+ class SetupNXdataProcessor(Processor):
1691
+ """Processor to set up and return an "empty" NeXus representation
1692
+ of a structured dataset. This representation will be an instance
1693
+ of `NXdata` that has:
1694
+ 1. An `NXfield` entry for every coordinate and signal specified.
1695
+ 1. `nxaxes` that are the `NXfield` entries for the coordinates and
1696
+ contain the values provided for each coordinate.
1697
+ 1. `NXfield` entries of appropriate shape, but containing all
1698
+ zeros, for every signal.
1699
+ 1. Attributes that define the axes, plus any additional attributes
1700
+ specified by the user.
1701
+
1702
+ This `Processor` is most useful as a "setup" step for
1703
+ constructing a representation of / container for a complete dataset
1704
+ that will be filled out in pieces later by
1705
+ `UpdateNXdataProcessor`.
1706
+
1707
+ Examples of use in a `Pipeline` configuration:
1708
+ - With inputs from a previous `PipelineItem` specifically written
1709
+ to provide inputs to this `Processor`:
1710
+ ```yaml
1711
+ config:
1712
+ inputdir: /rawdata/samplename
1713
+ outputdir: /reduceddata/samplename
1714
+ pipeline:
1715
+ - edd.SetupNXdataReader:
1716
+ filename: SpecInput.txt
1717
+ dataset_id: 1
1718
+ - common.SetupNXdataProcessor:
1719
+ nxname: samplename_dataset_1
1720
+ - common.NexusWriter:
1721
+ filename: data.nxs
1722
+ ```
1723
+ - With inputs provided directly through the optional arguments:
1724
+ ```yaml
1725
+ config:
1726
+ outputdir: /reduceddata/samplename
1727
+ pipeline:
1728
+ - common.SetupNXdataProcessor:
1729
+ nxname: your_dataset_name
1730
+ coords:
1731
+ - name: x
1732
+ values: [0.0, 0.5, 1.0]
1733
+ attrs:
1734
+ units: mm
1735
+ yourkey: yourvalue
1736
+ - name: temperature
1737
+ values: [200, 250, 275]
1738
+ attrs:
1739
+ units: Celsius
1740
+ yourotherkey: yourothervalue
1741
+ signals:
1742
+ - name: raw_detector_data
1743
+ shape: [407, 487]
1744
+ attrs:
1745
+ local_name: PIL11
1746
+ foo: bar
1747
+ - name: presample_intensity
1748
+ shape: []
1749
+ attrs:
1750
+ local_name: a3ic0
1751
+ zebra: fish
1752
+ attrs:
1753
+ arbitrary: metadata
1754
+ from: users
1755
+ goes: here
1756
+ - common.NexusWriter:
1757
+ filename: data.nxs
1758
+ ```
1759
+ """
1760
+ def process(self, data, nxname='data',
1761
+ coords=[], signals=[], attrs={}, data_points=[],
1762
+ extra_nxfields=[], duplicates='overwrite'):
1763
+ """Return an `NXdata` that has the requisite axes and
1764
+ `NXfield` entries to represent a structured dataset with the
1765
+ properties provided. Properties may be provided either through
1766
+ the `data` argument (from an appropriate `PipelineItem` that
1767
+ immediately precedes this one in a `Pipeline`), or through the
1768
+ `coords`, `signals`, `attrs`, and/or `data_points`
1769
+ arguments. If any of the latter are used, their values will
1770
+ completely override any values for these parameters found from
1771
+ `data.`
1772
+
1773
+ :param data: Data from the previous item in a `Pipeline`.
1774
+ :type data: list[PipelineData]
1775
+ :param nxname: Name for the returned `NXdata` object. Defaults
1776
+ to `'data'`.
1777
+ :type nxname: str, optional
1778
+ :param coords: List of dictionaries defining the coordinates
1779
+ of the dataset. Each dictionary must have the keys
1780
+ `'name'` and `'values'`, whose values are the name of the
1781
+ coordinate axis (a string) and all the unique values of
1782
+ that coordinate for the structured dataset (a list of
1783
+ numbers), respectively. A third item in the dictionary is
1784
+ optional, but highly recommended: `'attrs'` may provide a
1785
+ dictionary of attributes to attach to the coordinate axis
1786
+ that assist in interpreting the returned `NXdata`
1787
+ representation of the dataset. It is strongly recommended
1788
+ to provide the units of the values along an axis in the
1789
+ `attrs` dictionary. Defaults to [].
1790
+ :type coords: list[dict[str, object]], optional
1791
+ :param signals: List of dictionaries defining the signals of
1792
+ the dataset. Each dictionary must have the keys `'name'`
1793
+ and `'shape'`, whose values are the name of the signal
1794
+ field (a string) and the shape of the signal's value at
1795
+ each point in the dataset (a list of zero or more
1796
+ integers), respectively. A third item in the dictionary is
1797
+ optional, but highly recommended: `'attrs'` may provide a
1798
+ dictionary of attributes to attach to the signal field that
1799
+ assist in interpreting the returned `NXdata`
1800
+ representation of the dataset. It is strongly recommended
1801
+ to provide the units of the signal's values in the `attrs`
1802
+ dictionary. Defaults to [].
1803
+ :type signals: list[dict[str, object]], optional
1804
+ :param attrs: An arbitrary dictionary of attributes to assign
1805
+ to the returned `NXdata`. Defaults to {}.
1806
+ :type attrs: dict[str, object], optional
1807
+ :param data_points: A list of data points to partially (or
1808
+ even entirely) fill out the "empty" signal `NXfield`s
1809
+ before returning the `NXdata`. Defaults to [].
1810
+ :type data_points: list[dict[str, object]], optional
1811
+ :param extra_nxfields: List of "extra" `NXfield`s to include that
1812
+ can be described neither as a signal of the dataset, nor a
1813
+ dedicated coordinate. This parameter is good for
1814
+ including "alternate" values for one of the coordinate
1815
+ dimensions -- the same coordinate axis expressed in
1816
+ different units, for instance. Each item in the list
1817
+ should be a dictionary of parameters for the
1818
+ `nexusformat.nexus.NXfield` constructor. Defaults to `[]`.
1819
+ :type extra_nxfields: list[dict[str, object]], optional
1820
+ :param duplicates: Behavior to use if any new data points occur
1821
+ at the same point in the dataset's coordinate space as an
1822
+ existing data point. Allowed values for `duplicates` are:
1823
+ `'overwrite'` and `'block'`. Defaults to `'overwrite'`.
1824
+ :type duplicates: Literal['overwrite', 'block']
1825
+ :returns: An `NXdata` that represents the structured dataset
1826
+ as specified.
1827
+ :rtype: nexusformat.nexus.NXdata
1828
+ """
1829
+ self.nxname = nxname
1830
+
1831
+ self.coords = coords
1832
+ self.signals = signals
1833
+ self.attrs = attrs
1834
+ try:
1835
+ setup_params = self.unwrap_pipelinedata(data)[0]
1836
+ except:
1837
+ setup_params = None
1838
+ if isinstance(setup_params, dict):
1839
+ for a in ('coords', 'signals', 'attrs'):
1840
+ setup_param = setup_params.get(a)
1841
+ if not getattr(self, a) and setup_param:
1842
+ self.logger.info(f'Using input data from pipeline for {a}')
1843
+ setattr(self, a, setup_param)
1844
+ else:
1845
+ self.logger.info(
1846
+ f'Ignoring input data from pipeline for {a}')
1847
+ else:
1848
+ self.logger.warning('Ignoring all input data from pipeline')
1849
+
1850
+ self.shape = tuple(len(c['values']) for c in self.coords)
1851
+
1852
+ self.extra_nxfields = extra_nxfields
1853
+ self._data_points = []
1854
+ self.duplicates = duplicates
1855
+ self.init_nxdata()
1856
+ for d in data_points:
1857
+ self.add_data_point(d)
1858
+
1859
+ return self.nxdata
1860
+
1861
+ def add_data_point(self, data_point):
1862
+ """Add a data point to this dataset.
1863
+ 1. Validate `data_point`.
1864
+ 2. Append `data_point` to `self._data_points`.
1865
+ 3. Update signal `NXfield`s in `self.nxdata`.
1866
+
1867
+ :param data_point: Data point defining a point in the
1868
+ dataset's coordinate space and the new signal values at
1869
+ that point.
1870
+ :type data_point: dict[str, object]
1871
+ :returns: None
1872
+ """
1873
+ self.logger.info(f'Adding data point no. {len(self._data_points)}')
1874
+ self.logger.debug(f'New data point: {data_point}')
1875
+ valid, msg = self.validate_data_point(data_point)
1876
+ if not valid:
1877
+ self.logger.error(f'Cannot add data point: {msg}')
1878
+ else:
1879
+ self._data_points.append(data_point)
1880
+ self.update_nxdata(data_point)
1881
+
1882
+ def validate_data_point(self, data_point):
1883
+ """Return `True` if `data_point` occurs at a valid point in
1884
+ this structured dataset's coordinate space, `False`
1885
+ otherwise. Also validate shapes of signal values and add NaN
1886
+ values for any missing signals.
1887
+
1888
+ :param data_point: Data point defining a point in the
1889
+ dataset's coordinate space and the new signal values at
1890
+ that point.
1891
+ :type data_point: dict[str, object]
1892
+ :returns: Validity of `data_point`, message
1893
+ :rtype: bool, str
1894
+ """
1895
+ import numpy as np
1896
+
1897
+ valid = True
1898
+ msg = ''
1899
+ # Convert all values to numpy types
1900
+ data_point = {k: np.asarray(v) for k, v in data_point.items()}
1901
+ # Ensure data_point defines a specific point in the dataset's
1902
+ # coordinate space
1903
+ if not all(c['name'] in data_point for c in self.coords):
1904
+ valid = False
1905
+ msg = 'Missing coordinate values'
1906
+ # Find & handle any duplicates
1907
+ for i, d in enumerate(self._data_points):
1908
+ is_duplicate = all(data_point[c] == d[c] for c in self.coord_names)
1909
+ if is_duplicate:
1910
+ if self.duplicates == 'overwrite':
1911
+ self._data_points.pop(i)
1912
+ elif self.duplicates == 'block':
1913
+ valid = False
1914
+ msg = 'Duplicate point will be blocked'
1915
+ # Ensure a value is present for all signals
1916
+ for s in self.signals:
1917
+ if s['name'] not in data_point:
1918
+ data_point[s['name']] = np.full(s['shape'], 0)
1919
+ else:
1920
+ if not data_point[s['name']].shape == tuple(s['shape']):
1921
+ valid = False
1922
+ msg = f'Shape mismatch for signal {s}'
1923
+ return valid, msg
1924
+
1925
+ def init_nxdata(self):
1926
+ """Initialize an empty `NXdata` representing this dataset to
1927
+ `self.nxdata`; values for axes' `NXfield`s are filled out,
1928
+ values for signals' `NXfield`s are empty and can be filled out
1929
+ later. Save the empty `NXdata` to the NeXus file. Initialise
1930
+ `self.nxfile` and `self.nxdata_path` with the `NXFile` object
1931
+ and actual nxpath used to save and make updates to the
1932
+ `NXdata`.
1933
+
1934
+ :returns: None
1935
+ """
1936
+ from nexusformat.nexus import NXdata, NXfield
1937
+ import numpy as np
1938
+
1939
+ axes = tuple(NXfield(
1940
+ value=c['values'],
1941
+ name=c['name'],
1942
+ attrs=c.get('attrs')) for c in self.coords)
1943
+ entries = {s['name']: NXfield(
1944
+ value=np.full((*self.shape, *s['shape']), 0),
1945
+ name=s['name'],
1946
+ attrs=s.get('attrs')) for s in self.signals}
1947
+ extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
1948
+ extra_nxfields = {f.nxname: f for f in extra_nxfields}
1949
+ entries.update(extra_nxfields)
1950
+ self.nxdata = NXdata(
1951
+ name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)
1952
+
1953
+ def update_nxdata(self, data_point):
1954
+ """Update `self.nxdata`'s NXfield values.
1955
+
1956
+ :param data_point: Data point defining a point in the
1957
+ dataset's coordinate space and the new signal values at
1958
+ that point.
1959
+ :type data_point: dict[str, object]
1960
+ :returns: None
1961
+ """
1962
+ index = self.get_index(data_point)
1963
+ for s in self.signals:
1964
+ if s['name'] in data_point:
1965
+ self.nxdata[s['name']][index] = data_point[s['name']]
1966
+
1967
+ def get_index(self, data_point):
1968
+ """Return a tuple representing the array index of `data_point`
1969
+ in the coordinate space of the dataset.
1970
+
1971
+ :param data_point: Data point defining a point in the
1972
+ dataset's coordinate space.
1973
+ :type data_point: dict[str, object]
1974
+ :returns: Multi-dimensional index of `data_point` in the
1975
+ dataset's coordinate space.
1976
+ :rtype: tuple
1977
+ """
1978
+ return tuple(c['values'].index(data_point[c['name']]) \
1979
+ for c in self.coords)
1980
+
1981
+
1982
+ class UpdateNXdataProcessor(Processor):
1983
+ """Processor to fill in part(s) of an `NXdata` representing a
1984
+ structured dataset that's already been written to a NeXus file.
1985
+
1986
+ This Processor is most useful as an "update" step for an `NXdata`
1987
+ created by `common.SetupNXdataProcessor`, and is easiest to use
1988
+ in a `Pipeline` immediately after another `PipelineItem` designed
1989
+ specifically to return a value that can be used as input to this
1990
+ `Processor`.
1991
+
1992
+ Example of use in a `Pipeline` configuration:
1993
+ ```yaml
1994
+ config:
1995
+ inputdir: /rawdata/samplename
1996
+ pipeline:
1997
+ - edd.UpdateNXdataReader:
1998
+ spec_file: spec.log
1999
+ scan_number: 1
2000
+ - common.UpdateNXdataProcessor:
2001
+ nxfilename: /reduceddata/samplename/data.nxs
2002
+ nxdata_path: /entry/samplename_dataset_1
2003
+ ```
2004
+ """
2005
+
2006
+ def process(self, data, nxfilename, nxdata_path, data_points=[],
2007
+ allow_approximate_coordinates=True):
2008
+ """Write new data points to the signal fields of an existing
2009
+ `NXdata` object representing a structured dataset in a NeXus
2010
+ file. Return the list of data points used to update the
2011
+ dataset.
2012
+
2013
+ :param data: Data from the previous item in a `Pipeline`. May
2014
+ contain a list of data points that will extend the list of
2015
+ data points optionally provided with the `data_points`
2016
+ argument.
2017
+ :type data: list[PipelineData]
2018
+ :param nxfilename: Name of the NeXus file containing the
2019
+ `NXdata` to update.
2020
+ :type nxfilename: str
2021
+ :param nxdata_path: The path to the `NXdata` to update in the file.
2022
+ :type nxdata_path: str
2023
+ :param data_points: List of data points, each one a dictionary
2024
+ whose keys are the names of the coordinates and signals, and
2025
+ whose values are the values of each coordinate / signal at
2026
+ a single point in the dataset. Defaults to [].
2027
+ :type data_points: list[dict[str, object]]
2028
+ :param allow_approximate_coordinates: Parameter to allow the
2029
+ nearest existing match for the new data points'
2030
+ coordinates to be used if an exact match cannot be found
2031
+ (sometimes this is due simply to differences in rounding
2032
+ conventions). Defaults to True.
2033
+ :type allow_approximate_coordinates: bool, optional
2034
+ :returns: Complete list of data points used to update the dataset.
2035
+ :rtype: list[dict[str, object]]
2036
+ """
2037
+ from nexusformat.nexus import NXFile
2038
+ import numpy as np
2039
+ import os
2040
+
2041
+ _data_points = self.unwrap_pipelinedata(data)[0]
2042
+ if isinstance(_data_points, list):
2043
+ data_points.extend(_data_points)
2044
+ self.logger.info(f'Updating {len(data_points)} data points')
2045
+
2046
+ nxfile = NXFile(nxfilename, 'rw')
2047
+ nxdata = nxfile.readfile()[nxdata_path]
2048
+ axes_names = [a.nxname for a in nxdata.nxaxes]
2049
+
2050
+ data_points_used = []
2051
+ for i, d in enumerate(data_points):
2052
+ # Verify that the data point contains a value for all
2053
+ # coordinates in the dataset.
2054
+ if not all(a in d for a in axes_names):
2055
+ self.logger.error(
2056
+ f'Data point {i} is missing a value for at least one '
2057
+ + f'axis. Skipping. Axes are: {", ".join(axes_names)}')
2058
+ continue
2059
+ self.logger.info(
2060
+ f'Coordinates for data point {i}: '
2061
+ + ', '.join([f'{a}={d[a]}' for a in axes_names]))
2062
+ # Get the index of the data point in the dataset based on
2063
+ # its values for each coordinate.
2064
+ try:
2065
+ index = tuple(np.where(a.nxdata == d[a.nxname])[0][0] \
2066
+ for a in nxdata.nxaxes)
2067
+ except:
2068
+ if allow_approximate_coordinates:
2069
+ try:
2070
+ index = tuple(
2071
+ np.argmin(np.abs(a.nxdata - d[a.nxname])) \
2072
+ for a in nxdata.nxaxes)
2073
+ self.logger.warning(
2074
+ f'Nearest match for coordinates of data point {i}:'
2075
+ + ', '.join(
2076
+ [f'{a.nxname}={a[_i]}' \
2077
+ for _i, a in zip(index, nxdata.nxaxes)]))
2078
+ except:
2079
+ self.logger.error(
2080
+ f'Cannot get the index of data point {i}. '
2081
+ + f'Skipping.')
2082
+ continue
2083
+ else:
2084
+ self.logger.error(
2085
+ f'Cannot get the index of data point {i}. Skipping.')
2086
+ continue
2087
+ self.logger.info(f'Index of data point {i}: {index}')
2088
+ # Update the signals contained in this data point at the
2089
+ # proper index in the dataset's signal `NXfield`s
2090
+ for k, v in d.items():
2091
+ if k in axes_names:
2092
+ continue
2093
+ try:
2094
+ nxfile.writevalue(
2095
+ os.path.join(nxdata_path, k), np.asarray(v), index)
2096
+ except Exception as e:
2097
+ self.logger.error(
2098
+ f'Error updating signal {k} for new data point '
2099
+ + f'{i} (dataset index {index}): {e}')
2100
+ data_points_used.append(d)
2101
+
2102
+ nxfile.close()
2103
+
2104
+ return data_points_used
2105
+
2106
+
2107
class NXdataToDataPointsProcessor(Processor):
    """A Processor to flatten an `NXdata` object into a list of
    dictionaries, one for every point on the grid spanned by the
    dataset's axes. Each dictionary is keyed by the names of the
    dataset's axes and signals: an axis name maps to the value of
    that axis at the point, and a signal name maps to the value of
    that signal at the point. A signal may have more dimensions than
    the dataset has axes, in which case its value at a point is
    whatever remains of the signal after indexing its leading
    dimensions (so it may be of any shape).

    Example of use in a pipeline configuration:
    ```yaml
    config:
      inputdir: /reduceddata/samplename
    - common.NXdataReader:
        name: data
        axes_names:
          - x
          - y
        signal_name: z
        nxfield_params:
          - filename: data.nxs
            nxpath: entry/data/x
            slice_params:
              - step: 2
          - filename: data.nxs
            nxpath: entry/data/y
            slice_params:
              - step: 2
          - filename: data.nxs
            nxpath: entry/data/z
            slice_params:
              - step: 2
              - step: 2
    - common.NXdataToDataPointsProcessor
    - common.UpdateNXdataProcessor:
        nxfilename: /reduceddata/samplename/sparsedata.nxs
        nxdata_path: /entry/data
    ```
    """
    def process(self, data):
        """Build and return one dictionary of axis and signal values
        for every point in the dataset provided.

        :param data: Input pipeline data containing an `NXdata`.
        :type data: list[PipelineData]
        :returns: List of all data points in the dataset.
        :rtype: list[dict[str,object]]
        """
        import numpy as np

        nxdata = self.unwrap_pipelinedata(data)[0]

        # The dataset grid is defined by the axes: one dimension per
        # axis, with the length of each dimension given by the size
        # of the corresponding axis.
        axes_names = [axis.nxname for axis in nxdata.nxaxes]
        self.logger.info(f'Dataset axes: {axes_names}')
        dataset_shape = tuple(axis.size for axis in nxdata.nxaxes)
        self.logger.info(f'Dataset shape: {dataset_shape}')

        # A non-axis field counts as a signal only if its leading
        # dimensions coincide with the dataset grid.
        num_dims = len(dataset_shape)
        signal_names = []
        for name, field in nxdata.entries.items():
            if name in axes_names:
                continue
            if field.shape[:num_dims] == dataset_shape:
                signal_names.append(name)
        self.logger.info(f'Dataset signals: {signal_names}')

        # Anything that is neither an axis nor a signal is left out
        # of the data points; report it so the omission is visible.
        recognized_names = axes_names + signal_names
        other_fields = [name for name, _ in nxdata.entries.items()
                        if name not in recognized_names]
        if other_fields:
            self.logger.warning(
                'Ignoring the following fields that cannot be interpreted as '
                + f'either dataset coordinates or signals: {other_fields}')

        # Visit every grid position and collect the axis values
        # first, followed by the signal values, for that position.
        data_points = []
        for grid_index in np.ndindex(dataset_shape):
            point = {}
            for axis_name, axis_index in zip(axes_names, grid_index):
                point[axis_name] = nxdata[axis_name][axis_index]
            for signal_name in signal_names:
                point[signal_name] = nxdata[signal_name].nxdata[grid_index]
            data_points.append(point)
        return data_points
2181
+
2182
+
748
2183
  class XarrayToNexusProcessor(Processor):
749
2184
  """A Processor to convert the data in an `xarray` structure to a
750
2185
  NeXus NXdata object.