ChessAnalysisPipeline 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

CHAP/common/processor.py CHANGED
@@ -19,64 +19,145 @@ class AnimationProcessor(Processor):
19
19
  """A Processor to show and return an animation.
20
20
  """
21
21
  def process(
22
- self, data, num_frames, axis=0, interval=1000, blit=True,
23
- repeat=True, repeat_delay=1000):
22
+ self, data, num_frames, vmin=None, vmax=None, axis=None,
23
+ interval=1000, blit=True, repeat=True, repeat_delay=1000,
24
+ interactive=False):
24
25
  """Show and return an animation of image slices from a dataset
25
26
  contained in `data`.
26
27
 
27
28
  :param data: Input data.
28
- :type data: CHAP.pipeline.PipelineData
29
+ :type data: list[PipelineData]
29
30
  :param num_frames: Number of frames for the animation.
30
31
  :type num_frames: int
31
- :param axis: Axis direction of the image slices,
32
+ :param vmin: Minimum array value in image slice, defaults to
33
+ `None`, which uses the actual minimum value in the slice.
34
+ :type vmin: float
35
+ :param vmax: Maximum array value in image slice, defaults to
36
+ `None`, which uses the actual maximum value in the slice.
37
+ :type vmax: float
38
+ :param axis: Axis direction or name of the image slices,
32
39
  defaults to `0`
33
- :type axis: int, optional
34
- :param interval: Delay between frames in milliseconds,
35
- defaults to `1000`
40
+ :type axis: Union[int, str], optional
41
+ :param interval: Delay between frames in milliseconds (only
42
+ used when interactive=True), defaults to `1000`
36
43
  :type interval: int, optional
37
44
  :param blit: Whether blitting is used to optimize drawing,
38
45
  defaults to `True`
39
46
  :type blit: bool, optional
40
47
  :param repeat: Whether the animation repeats when the sequence
41
- of frames is completed, defaults to `True`
48
+ of frames is completed (only used when interactive=True),
49
+ defaults to `True`
42
50
  :type repeat: bool, optional
43
51
  :param repeat_delay: Delay in milliseconds between consecutive
44
- animation runs if repeat is `True`, defaults to `1000`
52
+ animation runs if repeat is `True` (only used when
53
+ interactive=True), defaults to `1000`
45
54
  :type repeat_delay: int, optional
55
+ :param interactive: Allows for user interactions, defaults to
56
+ `False`.
57
+ :type interactive: bool, optional
46
58
  :return: The matplotlib animation.
47
59
  :rtype: matplotlib.animation.ArtistAnimation
48
60
  """
61
+ # System modules
62
+ from os.path import (
63
+ isabs,
64
+ join,
65
+ )
66
+
49
67
  # Third party modules
50
68
  import matplotlib.animation as animation
51
69
  import matplotlib.pyplot as plt
52
70
 
71
+ # Get the default Nexus NXdata object
72
+ data = self.unwrap_pipelinedata(data)[0]
73
+ try:
74
+ nxdata = data.get_default()
75
+ except:
76
+ if nxdata.nxclass != 'NXdata':
77
+ raise ValueError('Invalid default pathway to an NXdata object '
78
+ f'in ({data})')
79
+
53
80
  # Get the frames
54
- data = self.unwrap_pipelinedata(data)[-1]
55
- delta = int(data.shape[axis]/(num_frames+1))
56
- indices = np.linspace(delta, data.shape[axis]-delta, num_frames)
57
- if data.ndim == 3:
81
+ axes = nxdata.attrs.get('axes', None)
82
+ title = f'{nxdata.nxpath}/{nxdata.signal}'
83
+ if nxdata.nxsignal.ndim == 2:
84
+ exit('AnimationProcessor not tested yet for a 2D dataset')
85
+ elif nxdata.nxsignal.ndim == 3:
86
+ if isinstance(axis, int):
87
+ if not 0 <= axis < nxdata.nxsignal.ndim:
88
+ raise ValueError(f'axis index out of range ({axis} not in '
89
+ f'[0, {nxdata.nxsignal.ndim-1}])')
90
+ axis_name = 'axis {axis}'
91
+ elif isinstance(axis, str):
92
+ if axes is None or axis not in list(axes.nxdata):
93
+ raise ValueError(
94
+ f'Unable to match axis = {axis} in {nxdata.tree}')
95
+ axes = list(axes.nxdata)
96
+ axis_name = axis
97
+ axis = axes.index(axis)
98
+ else:
99
+ raise ValueError(f'Invalid parameter axis ({axis})')
100
+ delta = int(nxdata.nxsignal.shape[axis]/(num_frames+1))
101
+ indices = np.linspace(
102
+ delta, nxdata.nxsignal.shape[axis]-delta, num_frames)
58
103
  if not axis:
59
- frames = [data[int(index)] for index in indices]
104
+ frames = [nxdata[nxdata.signal][int(index),:,:]
105
+ for index in indices]
60
106
  elif axis == 1:
61
- frames = [data[:,int(index),:] for index in indices]
107
+ frames = [nxdata[nxdata.signal][:,int(index),:]
108
+ for index in indices]
62
109
  elif axis == 2:
63
- frames = [data[:,:,int(index)] for index in indices]
110
+ frames = [nxdata[nxdata.signal][:,:,int(index)]
111
+ for index in indices]
112
+ if axes is None:
113
+ axes = [i for i in range(3) if i != axis]
114
+ row_coords = range(a.shape[1])
115
+ row_label = f'axis {axes[1]} index'
116
+ column_coords = range(a.shape[0])
117
+ column_label = f'axis {axes[0]} index'
118
+ else:
119
+ axes.pop(axis)
120
+ row_coords = nxdata[axes[1]].nxdata
121
+ row_label = axes[1]
122
+ if 'units' in nxdata[axes[1]].attrs:
123
+ row_label += f' ({nxdata[axes[1]].units})'
124
+ column_coords = nxdata[axes[0]].nxdata
125
+ column_label = axes[0]
126
+ if 'units' in nxdata[axes[0]].attrs:
127
+ column_label += f' ({nxdata[axes[0]].units})'
64
128
  else:
65
129
  raise ValueError('Invalid data dimension (must be 2D or 3D)')
66
130
 
67
- fig = plt.figure()
68
- # vmin = np.min(frames)/8
69
- # vmax = np.max(frames)/8
131
+
132
+ # Create the movie
133
+ if vmin is None or vmax is None:
134
+ a_max = frames[0].max()
135
+ for n in range(1, num_frames):
136
+ a_max = min(a_max, frames[n].max())
137
+ if vmin is None:
138
+ vmin = -a_max
139
+ if vmax is None:
140
+ vmax = a_max
141
+ extent = (
142
+ row_coords[0], row_coords[-1], column_coords[-1], column_coords[0])
143
+ fig, ax = plt.subplots(figsize=(11, 8.5))
144
+ ax.set_title(title, fontsize='xx-large', pad=20)
145
+ ax.set_xlabel(row_label, fontsize='x-large')
146
+ ax.set_ylabel(column_label, fontsize='x-large')
147
+ fig.tight_layout()
70
148
  ims = [[plt.imshow(
71
- #frames[n], vmin=vmin,vmax=vmax, cmap='gray',
72
- frames[n], cmap='gray',
149
+ frames[n], extent=extent, origin='lower',
150
+ vmin=vmin, vmax=vmax, cmap='gray',
73
151
  animated=True)]
74
152
  for n in range(num_frames)]
75
- ani = animation.ArtistAnimation(
76
- fig, ims, interval=interval, blit=blit, repeat=repeat,
77
- repeat_delay=repeat_delay)
78
-
79
- plt.show()
153
+ plt.colorbar()
154
+ if interactive:
155
+ ani = animation.ArtistAnimation(
156
+ fig, ims, interval=interval, blit=blit, repeat=repeat,
157
+ repeat_delay=repeat_delay)
158
+ plt.show()
159
+ else:
160
+ ani = animation.ArtistAnimation(fig, ims, blit=blit)
80
161
 
81
162
  return ani
82
163
 
@@ -129,43 +210,561 @@ class AsyncProcessor(Processor):
129
210
  asyncio.run(execute_tasks(self.mgr, data))
130
211
 
131
212
 
213
+ class BinarizeProcessor(Processor):
214
+ """A Processor to binarize a dataset.
215
+ """
216
+ def process(
217
+ self, data, nxpath='', interactive=False, method='CHAP',
218
+ num_bin=256, axis=None, remove_original_data=False):
219
+ """Show and return a binarized dataset from a dataset
220
+ contained in `data`. The dataset must either be of type
221
+ `numpy.ndarray` or a NeXus NXobject object with a default path
222
+ to a NeXus NXfield object.
223
+
224
+ :param data: Input data.
225
+ :type data: list[PipelineData]
226
+ :param nxpath: The relative path to a specific NeXus NXentry or
227
+ NeXus NXdata object in the NeXus file tree to read the
228
+ input data from (ignored for Numpy or NeXus NXfield input
229
+ datasets), defaults to `''`
230
+ :type nxpath: str, optional
231
+ :param interactive: Allows for user interactions (ignored
232
+ for any method other than `'manual'`), defaults to `False`.
233
+ :type interactive: bool, optional
234
+ :param method: Binarization method, defaults to `'CHAP'`
235
+ (CHAP's internal implementation of Otsu's method).
236
+ :type method: Literal['CHAP', 'manual', 'otsu', 'yen', 'isodata',
237
+ 'minimum']
238
+ :param num_bin: The number of bins used to calculate the
239
+ histogram in the binarization algorithms (ignored for
240
+ method = `'manual'`), defaults to `256`.
241
+ :type num_bin: int, optional
242
+ :param axis: Axis direction of the image slices (ignored
243
+ for any method other than `'manual'`), defaults to `None`
244
+ :type axis: int, optional
245
+ :param remove_original_data: Removes the original data field
246
+ (ignored for Numpy input datasets), defaults to `False`.
247
+ :type remove_original_data: bool, optional
248
+ :raises ValueError: Upon invalid input parameters.
249
+ :return: The binarized dataset with a return type equal to
250
+ that of the input dataset.
251
+ :rtype: numpy.ndarray, nexusformat.nexus.NXobject
252
+ """
253
+ # System modules
254
+ from os.path import join as os_join
255
+ from os.path import relpath
256
+
257
+ # Local modules
258
+ from CHAP.utils.general import (
259
+ is_int,
260
+ nxcopy,
261
+ )
262
+ from nexusformat.nexus import (
263
+ NXdata,
264
+ NXfield,
265
+ NXlink,
266
+ NXprocess,
267
+ nxsetconfig,
268
+ )
269
+
270
+ if method not in [
271
+ 'CHAP', 'manual', 'otsu', 'yen', 'isodata', 'minimum']:
272
+ raise ValueError(f'Invalid parameter method ({method})')
273
+ if not is_int(num_bin, gt=0):
274
+ raise ValueError(f'Invalid parameter num_bin ({num_bin})')
275
+ if not isinstance(remove_original_data, bool):
276
+ raise ValueError('Invalid parameter remove_original_data '
277
+ f'({remove_original_data})')
278
+
279
+ nxsetconfig(memory=100000)
280
+
281
+ # Get the dataset and make a copy if it is a NeXus NXgroup
282
+ dataset = self.unwrap_pipelinedata(data)[-1]
283
+ if isinstance(dataset, np.ndarray):
284
+ if method == 'manual':
285
+ if axis is not None and not is_int(axis, gt=0, lt=3):
286
+ raise ValueError(f'Invalid parameter axis ({axis})')
287
+ axes = ['i', 'j', 'k']
288
+ data = dataset
289
+ elif isinstance(dataset, NXfield):
290
+ if method == 'manual':
291
+ if axis is not None and not is_int(axis, gt=0, lt=3):
292
+ raise ValueError(f'Invalid parameter axis ({axis})')
293
+ axes = ['i', 'j', 'k']
294
+ if isinstance(dataset, NXfield):
295
+ if nxpath not in ('', '/'):
296
+ self.logger.warning('Ignoring parameter nxpath')
297
+ data = dataset.nxdata
298
+ else:
299
+ try:
300
+ data = dataset[nxpath].nxdata
301
+ except:
302
+ raise ValueError(f'Invalid parameter nxpath ({nxpath})')
303
+ else:
304
+ # Get the default Nexus NXdata object
305
+ try:
306
+ nxdefault = dataset.get_default()
307
+ except:
308
+ nxdefault = None
309
+ if nxdefault is not None and nxdefault.nxclass != 'NXdata':
310
+ raise ValueError('Invalid default pathway NXobject type '
311
+ f'({nxdefault.nxclass})')
312
+ # Get the requested NeXus NXdata object to binarize
313
+ if nxpath is None:
314
+ nxclass = dataset.nxclass
315
+ else:
316
+ try:
317
+ nxclass = dataset[nxpath].nxclass
318
+ except:
319
+ raise ValueError(f'Invalid parameter nxpath ({nxpath})')
320
+ if nxclass == 'NXdata':
321
+ nxdata = dataset[nxpath]
322
+ else:
323
+ if nxdefault is None:
324
+ raise ValueError(f'No default pathway to a NXdata object')
325
+ nxdata = nxdefault
326
+ nxsignal = nxdata.nxsignal
327
+ if method == 'manual':
328
+ if hasattr(nxdata.attrs, 'axes'):
329
+ axes = nxdata.attrs['axes']
330
+ if isinstance(axis, str):
331
+ if axis not in axes:
332
+ raise ValueError(f'Invalid parameter axis ({axis})')
333
+ axis = axes.index(axis)
334
+ elif axis is not None and not is_int(axis, gt=0, lt=3):
335
+ raise ValueError(f'Invalid parameter axis ({axis})')
336
+ else:
337
+ axes = ['i', 'j', 'k']
338
+ if nxsignal.ndim != 3:
339
+ raise ValueError('Invalid data dimension (must be 3D)')
340
+ data = nxsignal.nxdata
341
+ # Create a copy of the input NeXus object, removing the
342
+ # default NeXus NXdata object as well as the original
343
+ # dataset if the remove_original_data parameter is set
344
+ exclude_nxpaths = []
345
+ if nxdefault is not None:
346
+ exclude_nxpaths.append(
347
+ os_join(relpath(nxdefault.nxpath, dataset.nxpath)))
348
+ if remove_original_data:
349
+ if (nxdefault is None
350
+ or nxdefault.nxpath != nxdata.nxpath):
351
+ relpath_nxdata = relpath(nxdata.nxpath, dataset.nxpath)
352
+ keys = list(nxdata.keys())
353
+ keys.remove(nxsignal.nxname)
354
+ for axis in nxdata.axes:
355
+ keys.remove(axis)
356
+ if len(keys):
357
+ raise RuntimeError('Not tested yet')
358
+ exclude_nxpaths.append(os_join(
359
+ relpath(nxsignal.nxpath, dataset.nxpath)))
360
+ elif relpath_nxdata == '.':
361
+ exclude_nxpaths.append(nxsignal.nxname)
362
+ if dataset.nxclass != 'NXdata':
363
+ exclude_nxpaths += nxdata.axes
364
+ else:
365
+ exclude_nxpaths.append(relpath_nxdata)
366
+ if not (dataset.nxclass == 'NXdata'
367
+ or nxdata.nxsignal.nxtarget is None):
368
+ nxsignal = dataset[nxsignal.nxtarget]
369
+ nxgroup = nxsignal.nxgroup
370
+ keys = list(nxgroup.keys())
371
+ keys.remove(nxsignal.nxname)
372
+ for axis in nxgroup.axes:
373
+ keys.remove(axis)
374
+ if len(keys):
375
+ raise RuntimeError('Not tested yet')
376
+ exclude_nxpaths.append(os_join(
377
+ relpath(nxsignal.nxpath, dataset.nxpath)))
378
+ else:
379
+ exclude_nxpaths.append(os_join(
380
+ relpath(nxgroup.nxpath, dataset.nxpath)))
381
+ nxobject = nxcopy(dataset, exclude_nxpaths=exclude_nxpaths)
382
+
383
+ # Get a histogram of the data
384
+ if method not in ['manual', 'yen']:
385
+ counts, edges = np.histogram(data, bins=num_bin)
386
+ centers = edges[:-1] + 0.5 * np.diff(edges)
387
+
388
+ # Calculate the data cutoff threshold
389
+ if method == 'CHAP':
390
+ weights = np.cumsum(counts)
391
+ means = np.cumsum(counts * centers)
392
+ weights = weights[0:-1]/weights[-1]
393
+ means = means[0:-1]/means[-1]
394
+ variances = (means-weights)**2/(weights*(1.-weights))
395
+ threshold = centers[np.argmax(variances)]
396
+ elif method == 'otsu':
397
+ # Third party modules
398
+ from skimage.filters import threshold_otsu
399
+
400
+ threshold = threshold_otsu(hist=(counts, centers))
401
+ elif method == 'yen':
402
+ # Third party modules
403
+ from skimage.filters import threshold_yen
404
+
405
+ _min = data.min()
406
+ _max = data.max()
407
+ data = 1+(num_bin-1)*(data-_min)/(_max-_min)
408
+ counts, edges = np.histogram(data, bins=num_bin)
409
+ centers = edges[:-1] + 0.5 * np.diff(edges)
410
+
411
+ threshold = threshold_yen(hist=(counts, centers))
412
+ elif method == 'isodata':
413
+ # Third party modules
414
+ from skimage.filters import threshold_isodata
415
+
416
+ threshold = threshold_isodata(hist=(counts, centers))
417
+ elif method == 'minimum':
418
+ # Third party modules
419
+ from skimage.filters import threshold_minimum
420
+
421
+ threshold = threshold_minimum(hist=(counts, centers))
422
+ else:
423
+ # Third party modules
424
+ import matplotlib.pyplot as plt
425
+ from matplotlib.widgets import RadioButtons, Button
426
+
427
+ # Local modules
428
+ from CHAP.utils.general import (
429
+ select_roi_1d,
430
+ select_roi_2d,
431
+ )
432
+
433
+ def select_direction(direction):
434
+ """Callback function for the "Select direction" input."""
435
+ selected_direction.append(radio_btn.value_selected)
436
+ plt.close()
437
+
438
+ def accept(event):
439
+ """Callback function for the "Accept" button."""
440
+ selected_direction.append(radio_btn.value_selected)
441
+ plt.close()
442
+
443
+ # Select the direction for data averaging
444
+ if axis is not None:
445
+ mean_data = data.mean(axis=axis)
446
+ subaxes = [i for i in range(3) if i != axis]
447
+ else:
448
+ selected_direction = []
449
+
450
+ # Setup figure
451
+ title_pos = (0.5, 0.95)
452
+ title_props = {'fontsize': 'xx-large',
453
+ 'horizontalalignment': 'center',
454
+ 'verticalalignment': 'bottom'}
455
+ fig, axs = plt.subplots(ncols=3, figsize=(17, 8.5))
456
+ mean_data = []
457
+ for i, ax in enumerate(axs):
458
+ mean_data.append(data.mean(axis=i))
459
+ subaxes = [a for a in axes if a != axes[i]]
460
+ ax.imshow(mean_data[i], aspect='auto', cmap='gray')
461
+ ax.set_title(
462
+ f'Data averaged in {axes[i]}-direction',
463
+ fontsize='x-large')
464
+ ax.set_xlabel(subaxes[1], fontsize='x-large')
465
+ ax.set_ylabel(subaxes[0], fontsize='x-large')
466
+ fig_title = plt.figtext(
467
+ *title_pos,
468
+ 'Select a direction or press "Accept" for the default one '
469
+ f'({axes[0]}) to obtain the binary threshold value',
470
+ **title_props)
471
+ fig.subplots_adjust(bottom=0.25, top=0.85)
472
+
473
+ # Setup RadioButtons
474
+ select_text = plt.figtext(
475
+ 0.225, 0.175, 'Averaging direction', fontsize='x-large',
476
+ horizontalalignment='center', verticalalignment='center')
477
+ radio_btn = RadioButtons(
478
+ plt.axes([0.175, 0.05, 0.1, 0.1]), labels=axes, active=0)
479
+ radio_cid = radio_btn.on_clicked(select_direction)
480
+
481
+ # Setup "Accept" button
482
+ accept_btn = Button(
483
+ plt.axes([0.7, 0.05, 0.15, 0.075]), 'Accept')
484
+ accept_cid = accept_btn.on_clicked(accept)
485
+
486
+ plt.show()
487
+
488
+ axis = axes.index(selected_direction[0])
489
+ mean_data = mean_data[axis]
490
+ subaxes = [a for a in axes if a != axes[axis]]
491
+
492
+ plt.close()
493
+
494
+ # Select the ROIs orthogonal to the selected averaging direction
495
+ bounds = []
496
+ for i, bound in enumerate(['"0"', '"1"']):
497
+ _, roi = select_roi_2d(
498
+ mean_data,
499
+ title=f'Select the ROI to obtain the {bound} data value',
500
+ title_a=f'Data averaged in the {axes[axis]}-direction',
501
+ row_label=subaxes[0], column_label=subaxes[1])
502
+ plt.close()
503
+
504
+ # Select the index range in the selected averaging direction
505
+ if not axis:
506
+ mean_roi_data = data[:,roi[2]:roi[3],roi[0]:roi[1]].mean(
507
+ axis=(1,2))
508
+ elif axis == 1:
509
+ mean_roi_data = data[roi[2]:roi[3],:,roi[0]:roi[1]].mean(
510
+ axis=(0,2))
511
+ elif axis == 2:
512
+ mean_roi_data = data[roi[2]:roi[3],roi[0]:roi[1],:].mean(
513
+ axis=(0,1))
514
+
515
+ _, _range = select_roi_1d(
516
+ mean_roi_data, preselected_roi=(0, data.shape[axis]),
517
+ title=f'Select the {axes[axis]}-direction range to obtain '
518
+ f'the {bound} data bound',
519
+ xlabel=axes[axis], ylabel='Average data')
520
+ plt.close()
521
+
522
+ # Obtain the lower/upper data bound
523
+ if not axis:
524
+ bounds.append(
525
+ data[
526
+ _range[0]:_range[1],roi[2]:roi[3],roi[0]:roi[1]
527
+ ].mean())
528
+ elif axis == 1:
529
+ bounds.append(
530
+ data[
531
+ roi[2]:roi[3],_range[0]:_range[1],roi[0]:roi[1]
532
+ ].mean())
533
+ elif axis == 2:
534
+ bounds.append(
535
+ data[
536
+ roi[2]:roi[3],roi[0]:roi[1],_range[0]:_range[1]
537
+ ].mean())
538
+
539
+ # Get the data cutoff threshold
540
+ threshold = np.mean(bounds)
541
+
542
+ # Apply the data cutoff threshold and return the output
543
+ data = np.where(data<threshold, 0, 1).astype(np.ubyte)
544
+ # from CHAP.utils.general import quick_imshow
545
+ # quick_imshow(data[int(data.shape[0]/2),:,:], block=True)
546
+ # quick_imshow(data[:,int(data.shape[1]/2),:], block=True)
547
+ # quick_imshow(data[:,:,int(data.shape[2]/2)], block=True)
548
+ if isinstance(dataset, np.ndarray):
549
+ return data
550
+ if isinstance(dataset, NXfield):
551
+ attrs = dataset.attrs
552
+ attrs.pop('target', None)
553
+ return NXfield(
554
+ value=data, name=dataset.nxname, attrs=dataset.attrs)
555
+ name = nxsignal.nxname + '_binarized'
556
+ if nxobject.nxclass == 'NXdata':
557
+ nxobject[name] = data
558
+ nxobject.attrs['signal'] = name
559
+ return nxobject
560
+ if nxobject.nxclass == 'NXroot':
561
+ nxentry = nxobject[nxobject.default]
562
+ else:
563
+ nxentry = nxobject
564
+ axes = []
565
+ for axis in nxdata.axes:
566
+ attrs = nxdata[axis].attrs
567
+ attrs.pop('target', None)
568
+ axes.append(
569
+ NXfield(nxdata[axis], name=axis, attrs=attrs))
570
+ nxentry[name] = NXprocess(
571
+ NXdata(NXfield(data, name=name), axes),
572
+ attrs={'source': nxsignal.nxpath})
573
+ nxdata = nxentry[name].data
574
+ nxentry.data = NXdata(
575
+ NXlink(nxdata.nxsignal.nxpath),
576
+ [NXlink(os_join(nxdata.nxpath, axis)) for axis in nxdata.axes])
577
+ return nxobject
578
+
579
+
132
580
  class ImageProcessor(Processor):
133
- """A Processor to plot an image slice from a dataset.
581
+ """A Processor to plot an image (slice) from a NeXus object.
134
582
  """
135
- def process(self, data, index=0, axis=0):
136
- """Plot an image from a dataset contained in `data` and return
137
- the full dataset.
583
+ def process(
584
+ self, data, vmin=None, vmax=None, axis=0, index=None,
585
+ coord=None, interactive=False, save_figure=True, outputdir='.',
586
+ filename='image.png'):
587
+ """Plot and/or save an image (slice) from a NeXus NXobject object with
588
+ a default data path contained in `data` and return the NeXus NXdata
589
+ data object.
138
590
 
139
591
  :param data: Input data.
140
- :type data: CHAP.pipeline.PipelineData
141
- :param index: Array index of the slice of data to plot,
592
+ :type data: list[PipelineData]
593
+ :param vmin: Minimum array value in image slice, defaults to
594
+ `None`, which uses the actual minimum value in the slice.
595
+ :type vmin: float
596
+ :param vmax: Maximum array value in image slice, defaults to
597
+ `None`, which uses the actual maximum value in the slice.
598
+ :type vmax: float
599
+ :param axis: Axis direction or name of the image slice,
142
600
  defaults to `0`
601
+ :type axis: Union[int, str], optional
602
+ :param index: Array index of the slice of data to plot,
603
+ defaults to `None`
143
604
  :type index: int, optional
144
- :param axis: Axis direction of the image slice,
145
- defaults to `0`
146
- :type axis: int, optional
147
- :return: The full input dataset.
148
- :rtype: object
605
+ :param coord: Coordinate value of the slice of data to plot,
606
+ defaults to `None`
607
+ :type coord: Union[int, float], optional
608
+ :param interactive: Allows for user interactions, defaults to
609
+ `False`.
610
+ :type interactive: bool, optional
611
+ :param save_figure: Save a .png of the image, defaults to `True`.
612
+ :type save_figure: bool, optional
613
+ :param outputdir: Directory to which any output figure will
614
+ be saved, defaults to `'.'`
615
+ :type outputdir: str, optional
616
+ :param filename: Image filename, defaults to `"image.png"`.
617
+ :type filename: str, optional
618
+ :return: The input data object.
619
+ :rtype: nexusformat.nexus.NXdata
149
620
  """
150
- # Local modules
151
- from CHAP.utils.general import quick_imshow
621
+ # System modules
622
+ from os.path import (
623
+ isabs,
624
+ join,
625
+ )
626
+
627
+ # Third party modules
628
+ import matplotlib.pyplot as plt
152
629
 
630
+ # Local modules
631
+ from CHAP.utils.general import index_nearest
632
+
633
+ # Validate input parameters
634
+ if not isinstance(interactive, bool):
635
+ raise ValueError(f'Invalid parameter interactive ({interactive})')
636
+ if not isinstance(save_figure, bool):
637
+ raise ValueError(f'Invalid parameter save_figure ({save_figure})')
638
+ if not isinstance(outputdir, str):
639
+ raise ValueError(f'Invalid parameter outputdir ({outputdir})')
640
+ if not isinstance(filename, str):
641
+ raise ValueError(f'Invalid parameter filename ({filename})')
642
+ if not isabs(filename):
643
+ filename = join(outputdir, filename)
644
+
645
+ # Get the default Nexus NXdata object
153
646
  data = self.unwrap_pipelinedata(data)[0]
154
- if data.ndim == 2:
155
- quick_imshow(data, block=True)
156
- elif data.ndim == 3:
647
+ try:
648
+ nxdata = data.get_default()
649
+ except:
650
+ if nxdata.nxclass != 'NXdata':
651
+ raise ValueError('Invalid default pathway to an NXdata object '
652
+ f'in ({data})')
653
+
654
+ # Get the data slice
655
+ axes = nxdata.attrs.get('axes', None)
656
+ if axes is not None:
657
+ axes = list(axes.nxdata)
658
+ coords = None
659
+ title = f'{nxdata.nxpath}/{nxdata.signal}'
660
+ if nxdata.nxsignal.ndim == 2:
661
+ exit('ImageProcessor not tested yet for a 2D dataset')
662
+ if axis is not None:
663
+ axis = None
664
+ self.logger.warning('Ignoring parameter axis')
665
+ if index is not None:
666
+ index = None
667
+ self.logger.warning('Ignoring parameter index')
668
+ if coord is not None:
669
+ coord = None
670
+ self.logger.warning('Ignoring parameter coord')
671
+ a = nxdata.nxsignal
672
+ elif nxdata.nxsignal.ndim == 3:
673
+ if isinstance(axis, int):
674
+ if not 0 <= axis < nxdata.nxsignal.ndim:
675
+ raise ValueError(f'axis index out of range ({axis} not in '
676
+ f'[0, {nxdata.nxsignal.ndim-1}])')
677
+ elif isinstance(axis, str):
678
+ if axes is None or axis not in axes:
679
+ raise ValueError(
680
+ f'Unable to match axis = {axis} in {nxdata.tree}')
681
+ axis = axes.index(axis)
682
+ else:
683
+ raise ValueError(f'Invalid parameter axis ({axis})')
684
+ if axes is not None and hasattr(nxdata, axes[axis]):
685
+ coords = nxdata[axes[axis]].nxdata
686
+ axis_name = axes[axis]
687
+ else:
688
+ axis_name = f'axis {axis}'
689
+ if index is None and coord is None:
690
+ index = nxdata.nxsignal.shape[axis] // 2
691
+ else:
692
+ if index is not None:
693
+ if coord is not None:
694
+ coord = None
695
+ self.logger.warning('Ignoring parameter coord')
696
+ if not isinstance(index, int):
697
+ raise ValueError(f'Invalid parameter index ({index})')
698
+ elif not 0 <= index < nxdata.nxsignal.shape[axis]:
699
+ raise ValueError(
700
+ f'index value out of range ({index} not in '
701
+ f'[0, {nxdata.nxsignal.shape[axis]-1}])')
702
+ else:
703
+ if not isinstance(coord, (int, float)):
704
+ raise ValueError(f'Invalid parameter coord ({coord})')
705
+ if coords is None:
706
+ raise ValueError(
707
+ f'Unable to get coordinates for {axis_name} '
708
+ f'in {nxdata.tree}')
709
+ index = index_nearest(nxdata[axis_name], coord)
710
+ if coords is None:
711
+ slice_info = f'slice at {axis_name} and index {index}'
712
+ else:
713
+ coord = coords[index]
714
+ slice_info = f'slice at {axis_name} = '\
715
+ f'{nxdata[axis_name][index]:.3f}'
716
+ if 'units' in nxdata[axis_name].attrs:
717
+ slice_info += f' ({nxdata[axis_name].units})'
157
718
  if not axis:
158
- quick_imshow(data[index], block=True)
719
+ a = nxdata[nxdata.signal][index,:,:]
159
720
  elif axis == 1:
160
- quick_imshow(data[:,index,:], block=True)
721
+ a = nxdata[nxdata.signal][:,index,:]
161
722
  elif axis == 2:
162
- quick_imshow(data[:,:,index], block=True)
723
+ a = nxdata[nxdata.signal][:,:,index]
724
+ if coords is None:
725
+ axes = [i for i in range(3) if i != axis]
726
+ row_coords = range(a.shape[1])
727
+ row_label = f'axis {axes[1]} index'
728
+ column_coords = range(a.shape[0])
729
+ column_label = f'axis {axes[0]} index'
163
730
  else:
164
- raise ValueError(f'Invalid parameter axis ({axis})')
731
+ axes.pop(axis)
732
+ row_coords = nxdata[axes[1]].nxdata
733
+ row_label = axes[1]
734
+ if 'units' in nxdata[axes[1]].attrs:
735
+ row_label += f' ({nxdata[axes[1]].units})'
736
+ column_coords = nxdata[axes[0]].nxdata
737
+ column_label = axes[0]
738
+ if 'units' in nxdata[axes[0]].attrs:
739
+ column_label += f' ({nxdata[axes[0]].units})'
165
740
  else:
166
741
  raise ValueError('Invalid data dimension (must be 2D or 3D)')
167
742
 
168
- return data
743
+ # Create figure
744
+ a_max = a.max()
745
+ if vmin is None:
746
+ vmin = -a_max
747
+ if vmax is None:
748
+ vmax = a_max
749
+ extent = (
750
+ row_coords[0], row_coords[-1], column_coords[-1], column_coords[0])
751
+ fig, ax = plt.subplots(figsize=(11, 8.5))
752
+ plt.imshow(
753
+ a, extent=extent, origin='lower', vmin=vmin, vmax=vmax,
754
+ cmap='gray')
755
+ fig.suptitle(title, fontsize='xx-large')
756
+ ax.set_title(slice_info, fontsize='xx-large', pad=20)
757
+ ax.set_xlabel(row_label, fontsize='x-large')
758
+ ax.set_ylabel(column_label, fontsize='x-large')
759
+ plt.colorbar()
760
+ fig.tight_layout()
761
+ if interactive:
762
+ plt.show()
763
+ if save_figure:
764
+ fig.savefig(filename)
765
+ plt.close()
766
+
767
+ return nxdata
169
768
 
170
769
 
171
770
  class IntegrationProcessor(Processor):
@@ -177,7 +776,7 @@ class IntegrationProcessor(Processor):
177
776
 
178
777
  :param data: Input data, containing the raw data, integration
179
778
  method, and keyword args for the integration method.
180
- :type data: CHAP.pipeline.PipelineData
779
+ :type data: list[PipelineData]
181
780
  :return: Integrated raw data.
182
781
  :rtype: pyFAI.containers.IntegrateResult
183
782
  """
@@ -200,7 +799,7 @@ class IntegrateMapProcessor(Processor):
200
799
  with the value `'MapConfig'` for the `'schema'` key, and at
201
800
  least one item with the value `'IntegrationConfig'` for the
202
801
  `'schema'` key.
203
- :type data: CHAP.pipeline.PipelineData
802
+ :type data: list[PipelineData]
204
803
  :return: Integrated data and process metadata.
205
804
  :rtype: nexusformat.nexus.NXprocess
206
805
  """
@@ -366,7 +965,7 @@ class MapProcessor(Processor):
366
965
 
367
966
  :param data: Result of `Reader.read` where at least one item
368
967
  has the value `'MapConfig'` for the `'schema'` key.
369
- :type data: CHAP.pipeline.PipelineData
968
+ :type data: list[PipelineData]
370
969
  :return: Map data and metadata.
371
970
  :rtype: nexusformat.nexus.NXentry
372
971
  """
@@ -582,7 +1181,7 @@ class RawDetectorDataMapProcessor(Processor):
582
1181
  detector data collected over the map.
583
1182
 
584
1183
  :param data: Input map configuration.
585
- :type data: CHAP.pipeline.PipelineData
1184
+ :type data: list[PipelineData]
586
1185
  :param detector_name: The detector prefix.
587
1186
  :type detector_name: str
588
1187
  :param detector_shape: The shape of detector data for a single
@@ -602,7 +1201,7 @@ class RawDetectorDataMapProcessor(Processor):
602
1201
 
603
1202
  :param data: Result of `Reader.read` where at least one item
604
1203
  has the value `'MapConfig'` for the `'schema'` key.
605
- :type data: CHAP.pipeline.PipelineData
1204
+ :type data: list[PipelineData]
606
1205
  :raises Exception: If a valid map config object cannot be
607
1206
  constructed from `data`.
608
1207
  :return: A valid instance of the map configuration object with
@@ -709,7 +1308,7 @@ class StrainAnalysisProcessor(Processor):
709
1308
 
710
1309
  :param data: Results of `MultipleReader.read` containing input
711
1310
  map detector data and strain analysis configuration
712
- :type data: CHAP.pipeline.PipelineData
1311
+ :type data: list[PipelineData]
713
1312
  :return: A map of sample strains.
714
1313
  :rtype: xarray.Dataset
715
1314
  """
@@ -724,7 +1323,7 @@ class StrainAnalysisProcessor(Processor):
724
1323
  :param data: Result of `Reader.read` where at least one item
725
1324
  has the value `'StrainAnalysisConfig'` for the `'schema'`
726
1325
  key.
727
- :type data: CHAP.pipeline.PipelineData
1326
+ :type data: list[PipelineData]
728
1327
  :raises Exception: If valid config objects cannot be
729
1328
  constructed from `data`.
730
1329
  :return: A valid instance of the configuration object with