ChessAnalysisPipeline 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


CHAP/common/processor.py CHANGED
@@ -134,6 +134,7 @@ class AnimationProcessor(Processor):
         a_max = frames[0].max()
         for n in range(1, num_frames):
             a_max = min(a_max, frames[n].max())
+        a_max = float(a_max)
         if vmin is None:
            vmin = -a_max
         if vmax is None:
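
The new `float()` cast is worth a note: `frames[n].max()` returns a NumPy scalar that keeps the frame's dtype, and negating an unsigned scalar (as `vmin = -a_max` does next) silently wraps around instead of going negative. A minimal sketch of the failure mode the cast plausibly guards against:

```python
import numpy as np

frames = [np.array([[1, 5], [3, 2]], dtype=np.uint8)]
a_max = frames[0].max()   # numpy.uint8 scalar
print(-a_max)             # 251: unsigned negation wraps around
a_max = float(a_max)      # plain Python float
print(-a_max)             # -5.0, as intended for vmin
```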
@@ -248,7 +249,7 @@ class BinarizeProcessor(Processor):
         :raises ValueError: Upon invalid input parameters.
         :return: The binarized dataset with a return type equal to
             that of the input dataset.
-        :rtype: numpy.ndarray, nexusformat.nexus.NXobject
+        :rtype: typing.Union[numpy.ndarray, nexusformat.nexus.NXobject]
         """
         # System modules
         from os.path import join as os_join
@@ -494,12 +495,11 @@ class BinarizeProcessor(Processor):
         # Select the ROIs orthogonal to the selected averaging direction
         bounds = []
         for i, bound in enumerate(['"0"', '"1"']):
-            _, roi = select_roi_2d(
+            roi = select_roi_2d(
                 mean_data,
                 title=f'Select the ROI to obtain the {bound} data value',
                 title_a=f'Data averaged in the {axes[axis]}-direction',
                 row_label=subaxes[0], column_label=subaxes[1])
-            plt.close()

             # Select the index range in the selected averaging direction
             if not axis:
@@ -512,12 +512,11 @@ class BinarizeProcessor(Processor):
                 mean_roi_data = data[roi[2]:roi[3],roi[0]:roi[1],:].mean(
                     axis=(0,1))

-            _, _range = select_roi_1d(
+            _range = select_roi_1d(
                 mean_roi_data, preselected_roi=(0, data.shape[axis]),
                 title=f'Select the {axes[axis]}-direction range to obtain '
                       f'the {bound} data bound',
                 xlabel=axes[axis], ylabel='Average data')
-            plt.close()

             # Obtain the lower/upper data bound
             if not axis:
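
The two ROI-selection helpers now return just the selected ROI instead of a `(figure, roi)` tuple, and the explicit `plt.close()` calls are gone, suggesting the helpers now manage their own figure lifecycle. A minimal sketch of the updated calling convention, assuming `select_roi_1d` lives in `CHAP.utils.general` (its import is not shown in this hunk):

```python
import numpy as np
from CHAP.utils.general import select_roi_1d  # assumed import location

# Synthetic 1d profile to pick a range from
spectrum = np.exp(-np.linspace(-3.0, 3.0, 200)**2)
_range = select_roi_1d(
    spectrum, preselected_roi=(0, spectrum.size),
    title='Select a range', xlabel='channel', ylabel='Average data')
print(_range)  # (lo, hi) index bounds of the selected range
```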
@@ -574,9 +573,259 @@ class BinarizeProcessor(Processor):
         nxentry.data = NXdata(
             NXlink(nxdata.nxsignal.nxpath),
             [NXlink(os_join(nxdata.nxpath, axis)) for axis in nxdata.axes])
+        nxentry.data.set_default()
         return nxobject


+class ConstructBaseline(Processor):
+    """A Processor to construct a baseline for a dataset.
+    """
+    def process(
+            self, data, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
+            save_figures=False, outputdir='.', interactive=False):
+        """Construct and return the baseline for a dataset.
+
+        :param data: Input data.
+        :type data: list[PipelineData]
+        :param mask: A mask to apply to the spectrum before baseline
+            construction, defaults to `None`.
+        :type mask: array-like, optional
+        :param tol: The convergence tolerance, defaults to `1.e-6`.
+        :type tol: float, optional
+        :param lam: The λ (smoothness) parameter (the balance
+            between the residual of the data and the baseline and the
+            smoothness of the baseline). The suggested range is between
+            100 and 10^8, defaults to `10^6`.
+        :type lam: float, optional
+        :param max_iter: The maximum number of iterations,
+            defaults to `20`.
+        :type max_iter: int, optional
+        :param save_figures: Save .pngs of plots for checking inputs &
+            outputs of this Processor, defaults to `False`.
+        :type save_figures: bool, optional
+        :param outputdir: Directory to which any output figures will
+            be saved, defaults to `'.'`.
+        :type outputdir: str, optional
+        :param interactive: Allows for user interactions, defaults to
+            `False`.
+        :type interactive: bool, optional
+        :return: The smoothed baseline and the configuration.
+        :rtype: numpy.array, dict
+        """
+        # System modules
+        from os.path import join as os_join
+
+        try:
+            data = np.asarray(self.unwrap_pipelinedata(data)[0])
+        except:
+            raise ValueError(
+                f'The structure of {data} contains no valid data')
+
+        filename = None
+        if save_figures:
+            # Placeholder output name for the saved baseline figure
+            filename = os_join(outputdir, 'baseline.png')
+        return self.construct_baseline(
+            data, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
+            interactive=interactive, filename=filename)
+
+    @staticmethod
+    def construct_baseline(
+            y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20, title=None,
+            xlabel=None, ylabel=None, interactive=False, filename=None):
+        """Construct and return the baseline for a dataset.
+
+        :param y: Input data.
+        :type y: numpy.array
+        :param x: Independent dimension (only used when interactive is
+            `True` or when filename is set), defaults to `None`.
+        :type x: array-like, optional
+        :param mask: A mask to apply to the spectrum before baseline
+            construction, defaults to `None`.
+        :type mask: array-like, optional
+        :param tol: The convergence tolerance, defaults to `1.e-6`.
+        :type tol: float, optional
+        :param lam: The λ (smoothness) parameter (the balance
+            between the residual of the data and the baseline and the
+            smoothness of the baseline). The suggested range is between
+            100 and 10^8, defaults to `10^6`.
+        :type lam: float, optional
+        :param max_iter: The maximum number of iterations,
+            defaults to `20`.
+        :type max_iter: int, optional
+        :param title: Title for the displayed figure, defaults to `None`.
+        :type title: str, optional
+        :param xlabel: Label for the x-axis of the displayed figure,
+            defaults to `None`.
+        :type xlabel: str, optional
+        :param ylabel: Label for the y-axis of the displayed figure,
+            defaults to `None`.
+        :type ylabel: str, optional
+        :param interactive: Allows for user interactions, defaults to
+            `False`.
+        :type interactive: bool, optional
+        :param filename: Save a .png of the plot to filename, defaults to
+            `None`, in which case the plot is not saved.
+        :type filename: str, optional
+        :return: The smoothed baseline and the configuration.
+        :rtype: numpy.array, dict
+        """
+        # Third party modules
+        if interactive or filename is not None:
+            from matplotlib.widgets import TextBox, Button
+            import matplotlib.pyplot as plt
+
+        # Local modules
+        from CHAP.utils.general import baseline_arPLS
+
+        def change_fig_subtitle(maxed_out=False, subtitle=None):
+            if fig_subtitles:
+                fig_subtitles[0].remove()
+                fig_subtitles.pop()
+            if subtitle is None:
+                subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
+                if maxed_out:
+                    subtitle += f'# iter = {num_iters[-1]} (maxed out) '
+                else:
+                    subtitle += f'# iter = {num_iters[-1]} '
+                subtitle += f'error = {errors[-1]:.2e}'
+            fig_subtitles.append(
+                plt.figtext(*subtitle_pos, subtitle, **subtitle_props))
+
+        def select_lambda(expression):
+            """Callback function for the "Select lambda" TextBox."""
+            if not len(expression):
+                return
+            try:
+                lam = float(expression)
+                if lam < 0:
+                    raise ValueError
+            except ValueError:
+                change_fig_subtitle(
+                    subtitle='Invalid lambda, enter a positive number')
+            else:
+                lambdas.pop()
+                lambdas.append(10**lam)
+                baseline, _, w, num_iter, error = baseline_arPLS(
+                    y, mask=mask, tol=tol, lam=lambdas[-1], max_iter=max_iter,
+                    full_output=True)
+                num_iters.pop()
+                num_iters.append(num_iter)
+                errors.pop()
+                errors.append(error)
+                if num_iter < max_iter:
+                    change_fig_subtitle()
+                else:
+                    change_fig_subtitle(maxed_out=True)
+                baseline_handle.set_ydata(baseline)
+            lambda_box.set_val('')
+            plt.draw()
+
+        def continue_iter(event):
+            """Callback function for the "Continue" button."""
+            baseline, _, w, n_iter, error = baseline_arPLS(
+                y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
+                max_iter=max_iter, full_output=True)
+            num_iters[-1] += n_iter
+            errors.pop()
+            errors.append(error)
+            if n_iter < max_iter:
+                change_fig_subtitle()
+            else:
+                change_fig_subtitle(maxed_out=True)
+            baseline_handle.set_ydata(baseline)
+            plt.draw()
+            weights.pop()
+            weights.append(w)
+
+        def confirm(event):
+            """Callback function for the "Confirm" button."""
+            plt.close()
+
+        baseline, _, w, num_iter, error = baseline_arPLS(
+            y, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
+            full_output=True)
+
+        if not interactive and filename is None:
+            config = {
+                'tol': tol, 'lambda': lam, 'max_iter': max_iter,
+                'num_iter': num_iter, 'error': error, 'mask': mask}
+            return baseline, config
+
+        lambdas = [lam]
+        weights = [w]
+        num_iters = [num_iter]
+        errors = [error]
+        fig_subtitles = []
+
+        # Check inputs
+        if x is None:
+            x = np.arange(y.size)
+
+        # Set up the Matplotlib figure
+        title_pos = (0.5, 0.95)
+        title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
+                       'verticalalignment': 'bottom'}
+        subtitle_pos = (0.5, 0.90)
+        subtitle_props = {'fontsize': 'x-large',
+                          'horizontalalignment': 'center',
+                          'verticalalignment': 'bottom'}
+        fig, ax = plt.subplots(figsize=(11, 8.5))
+        if mask is None:
+            ax.plot(x, y, label='input data')
+        else:
+            ax.plot(
+                x[mask.astype(bool)], y[mask.astype(bool)], label='input data')
+        baseline_handle = ax.plot(x, baseline, label='baseline')[0]
+        # ax.plot(x, y-baseline, label='baseline corrected data')
+        ax.set_xlabel(xlabel, fontsize='x-large')
+        ax.set_ylabel(ylabel, fontsize='x-large')
+        ax.legend()
+        if title is None:
+            fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
+        else:
+            fig_title = plt.figtext(*title_pos, title, **title_props)
+        if num_iter < max_iter:
+            change_fig_subtitle()
+        else:
+            change_fig_subtitle(maxed_out=True)
+        fig.subplots_adjust(bottom=0.0, top=0.85)
+
+        if interactive:
+
+            fig.subplots_adjust(bottom=0.2)
+
+            # Set up the TextBox
+            lambda_box = TextBox(
+                plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
+            lambda_cid = lambda_box.on_submit(select_lambda)
+
+            # Set up the "Continue" button
+            continue_btn = Button(
+                plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
+            continue_cid = continue_btn.on_clicked(continue_iter)
+
+            # Set up the "Confirm" button
+            confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
+            confirm_cid = confirm_btn.on_clicked(confirm)
+
+            # Show the figure for user interaction
+            plt.show()
+
+            # Disconnect all widget callbacks when the figure is closed
+            lambda_box.disconnect(lambda_cid)
+            continue_btn.disconnect(continue_cid)
+            confirm_btn.disconnect(confirm_cid)
+
+            # ... and remove the buttons before returning the figure
+            lambda_box.ax.remove()
+            continue_btn.ax.remove()
+            confirm_btn.ax.remove()
+
+        if filename is not None:
+            fig_title.set_in_layout(True)
+            fig_subtitles[-1].set_in_layout(True)
+            fig.tight_layout(rect=(0, 0, 1, 0.90))
+            fig.savefig(filename)
+        plt.close()
+
+        config = {
+            'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
+            'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
+        return baseline, config
+
+
 class ImageProcessor(Processor):
     """A Processor to plot an image (slice) from a NeXus object.
     """
@@ -958,7 +1207,7 @@ class MapProcessor(Processor):
     NXentry object representing that map's metadata and any
     scalar-valued raw data requested by the supplied map configuration.
     """
-    def process(self, data):
+    def process(self, data, detector_names=[]):
         """Process the output of a `Reader` that contains a map
         configuration and returns a NeXus NXentry object representing
         the map.
@@ -966,20 +1215,36 @@ class MapProcessor(Processor):
         :param data: Result of `Reader.read` where at least one item
             has the value `'MapConfig'` for the `'schema'` key.
         :type data: list[PipelineData]
+        :param detector_names: Prefixes of the detectors whose raw
+            data will be included in the returned NeXus NXentry
+            object, defaults to `[]`.
+        :type detector_names: list[str], optional
         :return: Map data and metadata.
         :rtype: nexusformat.nexus.NXentry
         """
+        # Local modules
+        from CHAP.utils.general import string_to_list
+        if isinstance(detector_names, str):
+            try:
+                detector_names = [
+                    str(v) for v in string_to_list(
+                        detector_names, raise_error=True)]
+            except:
+                raise ValueError(
+                    f'Invalid parameter detector_names ({detector_names})')
         map_config = self.get_config(data, 'common.models.map.MapConfig')
-        nxentry = self.__class__.get_nxentry(map_config)
+        nxentry = self.__class__.get_nxentry(map_config, detector_names)

         return nxentry

     @staticmethod
-    def get_nxentry(map_config):
+    def get_nxentry(map_config, detector_names=[]):
         """Use a `MapConfig` to construct a NeXus NXentry object.

         :param map_config: A valid map configuration.
         :type map_config: MapConfig
+        :param detector_names: Prefixes of the detectors whose raw
+            data will be included in the returned NeXus NXentry object.
+        :type detector_names: list[str]
         :return: The map's data and metadata contained in a NeXus
             structure.
         :rtype: nexusformat.nexus.NXentry
@@ -1000,6 +1265,8 @@ class MapProcessor(Processor):
         nxentry.map_config = dumps(map_config.dict())
         nxentry[map_config.sample.name] = NXsample(**map_config.sample.dict())
         nxentry.attrs['station'] = map_config.station
+        for key, value in map_config.attrs.items():
+            nxentry.attrs[key] = value

         nxentry.spec_scans = NXcollection()
         for scans in map_config.spec_scans:
@@ -1039,10 +1306,26 @@ class MapProcessor(Processor):
         nxentry.data.attrs['signal'] = signal
         nxentry.data.attrs['auxilliary_signals'] = auxilliary_signals

-        for data in map_config.all_scalar_data:
-            for map_index in np.ndindex(map_config.shape):
+        # Create empty NXfields of appropriate shape for raw
+        # detector data
+        for detector_name in detector_names:
+            if not isinstance(detector_name, str):
+                detector_name = str(detector_name)
+            detector_data = map_config.get_detector_data(
+                detector_name, (0,) * len(map_config.shape))
+            nxentry.data[detector_name] = NXfield(
+                value=np.zeros((*map_config.shape, *detector_data.shape)),
+                dtype=detector_data.dtype)
+
+        for map_index in np.ndindex(map_config.shape):
+            for data in map_config.all_scalar_data:
                 nxentry.data[data.label][map_index] = map_config.get_value(
                     data, map_index)
+            for detector_name in detector_names:
+                if not isinstance(detector_name, str):
+                    detector_name = str(detector_name)
+                nxentry.data[detector_name][map_index] = \
+                    map_config.get_detector_data(detector_name, map_index)

         return nxentry
@@ -1172,6 +1455,66 @@ class PrintProcessor(Processor):
         return data


+class PyfaiAzimuthalIntegrationProcessor(Processor):
+    """Processor to azimuthally integrate one or more frames of 2d
+    detector data using the
+    [pyFAI](https://pyfai.readthedocs.io/en/v2023.1/index.html)
+    package.
+    """
+    def process(self, data, poni_file, npt, mask_file=None,
+                integrate1d_kwargs=None, inputdir='.'):
+        """Azimuthally integrate the detector data provided and return
+        the results as a list of integration results containing the
+        values of the radial coordinate, the intensities along the
+        radial direction, and the Poisson errors for each intensity
+        spectrum.
+
+        :param data: Detector data to integrate.
+        :type data: Union[PipelineData, list[np.ndarray]]
+        :param poni_file: Name of the [pyFAI PONI
+            file](https://pyfai.readthedocs.io/en/v2023.1/glossary.html?highlight=poni%20file#poni-file)
+            containing the detector properties pyFAI needs to perform
+            azimuthal integration.
+        :type poni_file: str
+        :param npt: Number of points in the output pattern.
+        :type npt: int
+        :param mask_file: A file to use for masking the input data.
+        :type mask_file: str, optional
+        :param integrate1d_kwargs: Optional dictionary of keyword
+            arguments to use with
+            [`pyFAI.azimuthalIntegrator.AzimuthalIntegrator.integrate1d`](https://pyfai.readthedocs.io/en/v2023.1/api/pyFAI.html#pyFAI.azimuthalIntegrator.AzimuthalIntegrator.integrate1d).
+            Defaults to `None`.
+        :type integrate1d_kwargs: Optional[dict]
+        :param inputdir: Directory against which relative `poni_file`
+            and `mask_file` paths are resolved, defaults to `'.'`.
+        :type inputdir: str, optional
+        :returns: Azimuthal integration results for each input frame.
+        :rtype: list
+        """
+        import os
+        from pyFAI import load
+
+        if not os.path.isabs(poni_file):
+            poni_file = os.path.join(inputdir, poni_file)
+        ai = load(poni_file)
+
+        if mask_file is None:
+            mask = None
+        else:
+            if not os.path.isabs(mask_file):
+                mask_file = os.path.join(inputdir, mask_file)
+            import fabio
+            mask = fabio.open(mask_file).data
+
+        try:
+            det_data = self.unwrap_pipelinedata(data)[0]
+        except:
+            # Fall back to treating `data` as a plain list of frames
+            det_data = data
+
+        if integrate1d_kwargs is None:
+            integrate1d_kwargs = {}
+        integrate1d_kwargs['mask'] = mask
+
+        return [ai.integrate1d(d, npt, **integrate1d_kwargs) for d in det_data]
+
+
 class RawDetectorDataMapProcessor(Processor):
     """A Processor to return a map of raw detector data in a
     NeXus NXroot object.
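
For readers unfamiliar with pyFAI, the processor is a thin wrapper around `pyFAI.load` and `AzimuthalIntegrator.integrate1d`. A standalone sketch of the same flow (`'detector.poni'` is a hypothetical calibration file and the frame here is random noise):

```python
import numpy as np
from pyFAI import load

ai = load('detector.poni')  # build an AzimuthalIntegrator from a PONI file
frames = [np.random.poisson(100, size=(487, 407)).astype(np.int32)]
results = [ai.integrate1d(frame, 1000, unit='2th_deg') for frame in frames]
for res in results:
    radial, intensity = res.radial, res.intensity  # 1d numpy arrays
```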
@@ -1344,6 +1687,499 @@ class StrainAnalysisProcessor(Processor):
         return strain_analysis_config


+class SetupNXdataProcessor(Processor):
+    """Processor to set up and return an "empty" NeXus representation
+    of a structured dataset. This representation will be an instance
+    of `NXdata` that has:
+    1. An `NXfield` entry for every coordinate and signal specified.
+    1. `nxaxes` that are the `NXfield` entries for the coordinates and
+       contain the values provided for each coordinate.
+    1. `NXfield` entries of appropriate shape, but containing all
+       zeros, for every signal.
+    1. Attributes that define the axes, plus any additional attributes
+       specified by the user.
+
+    This `Processor` is most useful as a "setup" step for
+    constructing a representation of / container for a complete
+    dataset that will be filled out in pieces later by
+    `UpdateNXdataProcessor`.
+
+    Examples of use in a `Pipeline` configuration:
+    - With inputs from a previous `PipelineItem` specifically written
+      to provide inputs to this `Processor`:
+      ```yaml
+      config:
+        inputdir: /rawdata/samplename
+        outputdir: /reduceddata/samplename
+      pipeline:
+        - edd.SetupNXdataReader:
+            filename: SpecInput.txt
+            dataset_id: 1
+        - common.SetupNXdataProcessor:
+            nxname: samplename_dataset_1
+        - common.NexusWriter:
+            filename: data.nxs
+      ```
+    - With inputs provided directly through the optional arguments:
+      ```yaml
+      config:
+        outputdir: /reduceddata/samplename
+      pipeline:
+        - common.SetupNXdataProcessor:
+            nxname: your_dataset_name
+            coords:
+              - name: x
+                values: [0.0, 0.5, 1.0]
+                attrs:
+                  units: mm
+                  yourkey: yourvalue
+              - name: temperature
+                values: [200, 250, 275]
+                attrs:
+                  units: Celsius
+                  yourotherkey: yourothervalue
+            signals:
+              - name: raw_detector_data
+                shape: [407, 487]
+                attrs:
+                  local_name: PIL11
+                  foo: bar
+              - name: presample_intensity
+                shape: []
+                attrs:
+                  local_name: a3ic0
+                  zebra: fish
+            attrs:
+              arbitrary: metadata
+              from: users
+              goes: here
+        - common.NexusWriter:
+            filename: data.nxs
+      ```
+    """
+    def process(self, data, nxname='data',
+                coords=[], signals=[], attrs={}, data_points=[],
+                extra_nxfields=[], duplicates='overwrite'):
+        """Return an `NXdata` that has the requisite axes and
+        `NXfield` entries to represent a structured dataset with the
+        properties provided. Properties may be provided either through
+        the `data` argument (from an appropriate `PipelineItem` that
+        immediately precedes this one in a `Pipeline`), or through the
+        `coords`, `signals`, `attrs`, and/or `data_points`
+        arguments. If any of the latter are used, their values will
+        completely override any values for these parameters found from
+        `data`.
+
+        :param data: Data from the previous item in a `Pipeline`.
+        :type data: list[PipelineData]
+        :param nxname: Name for the returned `NXdata` object. Defaults
+            to `'data'`.
+        :type nxname: str, optional
+        :param coords: List of dictionaries defining the coordinates
+            of the dataset. Each dictionary must have the keys
+            `'name'` and `'values'`, whose values are the name of the
+            coordinate axis (a string) and all the unique values of
+            that coordinate for the structured dataset (a list of
+            numbers), respectively. A third item in the dictionary is
+            optional, but highly recommended: `'attrs'` may provide a
+            dictionary of attributes to attach to the coordinate axis
+            that assist in interpreting the returned `NXdata`
+            representation of the dataset. It is strongly recommended
+            to provide the units of the values along an axis in the
+            `attrs` dictionary. Defaults to `[]`.
+        :type coords: list[dict[str, object]], optional
+        :param signals: List of dictionaries defining the signals of
+            the dataset. Each dictionary must have the keys `'name'`
+            and `'shape'`, whose values are the name of the signal
+            field (a string) and the shape of the signal's value at
+            each point in the dataset (a list of zero or more
+            integers), respectively. A third item in the dictionary is
+            optional, but highly recommended: `'attrs'` may provide a
+            dictionary of attributes to attach to the signal field
+            that assist in interpreting the returned `NXdata`
+            representation of the dataset. It is strongly recommended
+            to provide the units of the signal's values in the `attrs`
+            dictionary. Defaults to `[]`.
+        :type signals: list[dict[str, object]], optional
+        :param attrs: An arbitrary dictionary of attributes to assign
+            to the returned `NXdata`. Defaults to `{}`.
+        :type attrs: dict[str, object], optional
+        :param data_points: A list of data points to partially (or
+            even entirely) fill out the "empty" signal `NXfield`s
+            before returning the `NXdata`. Defaults to `[]`.
+        :type data_points: list[dict[str, object]], optional
+        :param extra_nxfields: List of "extra" `NXfield`s to include
+            that can be described neither as a signal of the dataset,
+            nor as a dedicated coordinate. This parameter is good for
+            including "alternate" values for one of the coordinate
+            dimensions -- the same coordinate axis expressed in
+            different units, for instance. Each item in the list
+            should be a dictionary of parameters for the
+            `nexusformat.nexus.NXfield` constructor. Defaults to `[]`.
+        :type extra_nxfields: list[dict[str, object]], optional
+        :param duplicates: Behavior to use if any new data points occur
+            at the same point in the dataset's coordinate space as an
+            existing data point. Allowed values for `duplicates` are:
+            `'overwrite'` and `'block'`. Defaults to `'overwrite'`.
+        :type duplicates: Literal['overwrite', 'block']
+        :returns: An `NXdata` that represents the structured dataset
+            as specified.
+        :rtype: nexusformat.nexus.NXdata
+        """
+        self.nxname = nxname
+
+        self.coords = coords
+        self.signals = signals
+        self.attrs = attrs
+        try:
+            setup_params = self.unwrap_pipelinedata(data)[0]
+        except:
+            setup_params = None
+        if isinstance(setup_params, dict):
+            for a in ('coords', 'signals', 'attrs'):
+                setup_param = setup_params.get(a)
+                if not getattr(self, a) and setup_param:
+                    self.logger.info(
+                        f'Using input data from pipeline for {a}')
+                    setattr(self, a, setup_param)
+                else:
+                    self.logger.info(
+                        f'Ignoring input data from pipeline for {a}')
+        else:
+            self.logger.warning('Ignoring all input data from pipeline')
+
+        self.shape = tuple(len(c['values']) for c in self.coords)
+
+        self.extra_nxfields = extra_nxfields
+        self._data_points = []
+        self.duplicates = duplicates
+        self.init_nxdata()
+        for d in data_points:
+            self.add_data_point(d)
+
+        return self.nxdata
+
+    def add_data_point(self, data_point):
+        """Add a data point to this dataset.
+        1. Validate `data_point`.
+        2. Append `data_point` to `self._data_points`.
+        3. Update signal `NXfield`s in `self.nxdata`.
+
+        :param data_point: Data point defining a point in the
+            dataset's coordinate space and the new signal values at
+            that point.
+        :type data_point: dict[str, object]
+        :returns: None
+        """
+        self.logger.info(f'Adding data point no. {len(self._data_points)}')
+        self.logger.debug(f'New data point: {data_point}')
+        valid, msg = self.validate_data_point(data_point)
+        if not valid:
+            self.logger.error(f'Cannot add data point: {msg}')
+        else:
+            self._data_points.append(data_point)
+            self.update_nxdata(data_point)
+
+    def validate_data_point(self, data_point):
+        """Return `True` if `data_point` occurs at a valid point in
+        this structured dataset's coordinate space, `False`
+        otherwise. Also validate shapes of signal values and add
+        placeholder values for any missing signals.
+
+        :param data_point: Data point defining a point in the
+            dataset's coordinate space and the new signal values at
+            that point.
+        :type data_point: dict[str, object]
+        :returns: Validity of `data_point`, message
+        :rtype: bool, str
+        """
+        import numpy as np
+
+        valid = True
+        msg = ''
+        # Convert all values to numpy types
+        data_point = {k: np.asarray(v) for k, v in data_point.items()}
+        # Ensure data_point defines a specific point in the dataset's
+        # coordinate space
+        if not all(c['name'] in data_point for c in self.coords):
+            valid = False
+            msg = 'Missing coordinate values'
+        # Find & handle any duplicates
+        for i, d in enumerate(self._data_points):
+            is_duplicate = all(
+                data_point[c['name']] == d[c['name']] for c in self.coords)
+            if is_duplicate:
+                if self.duplicates == 'overwrite':
+                    self._data_points.pop(i)
+                elif self.duplicates == 'block':
+                    valid = False
+                    msg = 'Duplicate point will be blocked'
+        # Ensure a value is present for all signals
+        for s in self.signals:
+            if s['name'] not in data_point:
+                data_point[s['name']] = np.full(s['shape'], 0)
+            else:
+                if not data_point[s['name']].shape == tuple(s['shape']):
+                    valid = False
+                    msg = f'Shape mismatch for signal {s}'
+        return valid, msg
+
+    def init_nxdata(self):
+        """Initialize an empty `NXdata` representing this dataset to
+        `self.nxdata`; values for the axes' `NXfield`s are filled out,
+        values for the signals' `NXfield`s are empty and can be filled
+        out later. Save the empty `NXdata` to the NeXus file.
+        Initialize `self.nxfile` and `self.nxdata_path` with the
+        `NXFile` object and actual nxpath used to save and make
+        updates to the `NXdata`.
+
+        :returns: None
+        """
+        from nexusformat.nexus import NXdata, NXfield
+        import numpy as np
+
+        axes = tuple(NXfield(
+            value=c['values'],
+            name=c['name'],
+            attrs=c.get('attrs')) for c in self.coords)
+        entries = {s['name']: NXfield(
+            value=np.full((*self.shape, *s['shape']), 0),
+            name=s['name'],
+            attrs=s.get('attrs')) for s in self.signals}
+        extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
+        extra_nxfields = {f.nxname: f for f in extra_nxfields}
+        entries.update(extra_nxfields)
+        self.nxdata = NXdata(
+            name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)
+
+    def update_nxdata(self, data_point):
+        """Update `self.nxdata`'s NXfield values.
+
+        :param data_point: Data point defining a point in the
+            dataset's coordinate space and the new signal values at
+            that point.
+        :type data_point: dict[str, object]
+        :returns: None
+        """
+        index = self.get_index(data_point)
+        for s in self.signals:
+            if s['name'] in data_point:
+                self.nxdata[s['name']][index] = data_point[s['name']]
+
+    def get_index(self, data_point):
+        """Return a tuple representing the array index of `data_point`
+        in the coordinate space of the dataset.
+
+        :param data_point: Data point defining a point in the
+            dataset's coordinate space.
+        :type data_point: dict[str, object]
+        :returns: Multi-dimensional index of `data_point` in the
+            dataset's coordinate space.
+        :rtype: tuple
+        """
+        return tuple(c['values'].index(data_point[c['name']])
+                     for c in self.coords)
+
+
+class UpdateNXdataProcessor(Processor):
+    """Processor to fill in part(s) of an `NXdata` representing a
+    structured dataset that's already been written to a NeXus file.
+
+    This Processor is most useful as an "update" step for an `NXdata`
+    created by `common.SetupNXdataProcessor`, and is easiest to use
+    in a `Pipeline` immediately after another `PipelineItem` designed
+    specifically to return a value that can be used as input to this
+    `Processor`.
+
+    Example of use in a `Pipeline` configuration:
+    ```yaml
+    config:
+      inputdir: /rawdata/samplename
+    pipeline:
+      - edd.UpdateNXdataReader:
+          spec_file: spec.log
+          scan_number: 1
+      - common.UpdateNXdataProcessor:
+          nxfilename: /reduceddata/samplename/data.nxs
+          nxdata_path: /entry/samplename_dataset_1
+    ```
+    """
+
+    def process(self, data, nxfilename, nxdata_path, data_points=[],
+                allow_approximate_coordinates=True):
+        """Write new data points to the signal fields of an existing
+        `NXdata` object representing a structured dataset in a NeXus
+        file. Return the list of data points used to update the
+        dataset.
+
+        :param data: Data from the previous item in a `Pipeline`. May
+            contain a list of data points that will extend the list of
+            data points optionally provided with the `data_points`
+            argument.
+        :type data: list[PipelineData]
+        :param nxfilename: Name of the NeXus file containing the
+            `NXdata` to update.
+        :type nxfilename: str
+        :param nxdata_path: The path to the `NXdata` to update in the
+            file.
+        :type nxdata_path: str
+        :param data_points: List of data points, each one a dictionary
+            whose keys are the names of the coordinates and signals,
+            and whose values are the values of each coordinate /
+            signal at a single point in the dataset. Defaults to `[]`.
+        :type data_points: list[dict[str, object]]
+        :param allow_approximate_coordinates: Parameter to allow the
+            nearest existing match for the new data points'
+            coordinates to be used if an exact match cannot be found
+            (sometimes this is due simply to differences in rounding
+            conventions). Defaults to `True`.
+        :type allow_approximate_coordinates: bool, optional
+        :returns: Complete list of data points used to update the
+            dataset.
+        :rtype: list[dict[str, object]]
+        """
+        from nexusformat.nexus import NXFile
+        import numpy as np
+        import os
+
+        _data_points = self.unwrap_pipelinedata(data)[0]
+        if isinstance(_data_points, list):
+            data_points.extend(_data_points)
+        self.logger.info(f'Updating {len(data_points)} data points')
+
+        nxfile = NXFile(nxfilename, 'rw')
+        nxdata = nxfile.readfile()[nxdata_path]
+        axes_names = [a.nxname for a in nxdata.nxaxes]
+
+        data_points_used = []
+        for i, d in enumerate(data_points):
+            # Verify that the data point contains a value for all
+            # coordinates in the dataset.
+            if not all(a in d for a in axes_names):
+                self.logger.error(
+                    f'Data point {i} is missing a value for at least one '
+                    + f'axis. Skipping. Axes are: {", ".join(axes_names)}')
+                continue
+            self.logger.info(
+                f'Coordinates for data point {i}: '
+                + ', '.join([f'{a}={d[a]}' for a in axes_names]))
+            # Get the index of the data point in the dataset based on
+            # its values for each coordinate.
+            try:
+                index = tuple(np.where(a.nxdata == d[a.nxname])[0][0]
+                              for a in nxdata.nxaxes)
+            except:
+                if allow_approximate_coordinates:
+                    try:
+                        index = tuple(
+                            np.argmin(np.abs(a.nxdata - d[a.nxname]))
+                            for a in nxdata.nxaxes)
+                        self.logger.warning(
+                            f'Nearest match for coordinates of data point {i}: '
+                            + ', '.join(
+                                [f'{a.nxname}={a[_i]}'
+                                 for _i, a in zip(index, nxdata.nxaxes)]))
+                    except:
+                        self.logger.error(
+                            f'Cannot get the index of data point {i}. '
+                            + 'Skipping.')
+                        continue
+                else:
+                    self.logger.error(
+                        f'Cannot get the index of data point {i}. Skipping.')
+                    continue
+            self.logger.info(f'Index of data point {i}: {index}')
+            # Update the signals contained in this data point at the
+            # proper index in the dataset's signal `NXfield`s
+            for k, v in d.items():
+                if k in axes_names:
+                    continue
+                try:
+                    nxfile.writevalue(
+                        os.path.join(nxdata_path, k), np.asarray(v), index)
+                except Exception as e:
+                    self.logger.error(
+                        f'Error updating signal {k} for new data point '
+                        + f'{i} (dataset index {index}): {e}')
+            data_points_used.append(d)
+
+        nxfile.close()
+
+        return data_points_used
+
+
+class NXdataToDataPointsProcessor(Processor):
+    """Transform an `NXdata` object into a list of dictionaries. Each
+    dictionary represents a single data point in the coordinate space
+    of the dataset. The keys are the names of the signals and axes in
+    the dataset, and the values are a single scalar value (in the case
+    of axes) or the value of the signal at that point in the
+    coordinate space of the dataset (in the case of signals -- this
+    means that values for signals may be any shape, depending on the
+    shape of the signal itself).
+
+    Example of use in a pipeline configuration:
+    ```yaml
+    config:
+      inputdir: /reduceddata/samplename
+    pipeline:
+      - common.NXdataReader:
+          name: data
+          axes_names:
+            - x
+            - y
+          signal_name: z
+          nxfield_params:
+            - filename: data.nxs
+              nxpath: entry/data/x
+              slice_params:
+                - step: 2
+            - filename: data.nxs
+              nxpath: entry/data/y
+              slice_params:
+                - step: 2
+            - filename: data.nxs
+              nxpath: entry/data/z
+              slice_params:
+                - step: 2
+                - step: 2
+      - common.NXdataToDataPointsProcessor
+      - common.UpdateNXdataProcessor:
+          nxfilename: /reduceddata/samplename/sparsedata.nxs
+          nxdata_path: /entry/data
+    ```
+    """
+    def process(self, data):
+        """Return a list of dictionaries representing the coordinate
+        and signal values at every point in the dataset provided.
+
+        :param data: Input pipeline data containing an `NXdata`.
+        :type data: list[PipelineData]
+        :returns: List of all data points in the dataset.
+        :rtype: list[dict[str, object]]
+        """
+        import numpy as np
+
+        nxdata = self.unwrap_pipelinedata(data)[0]
+
+        data_points = []
+        axes_names = [a.nxname for a in nxdata.nxaxes]
+        self.logger.info(f'Dataset axes: {axes_names}')
+        dataset_shape = tuple([a.size for a in nxdata.nxaxes])
+        self.logger.info(f'Dataset shape: {dataset_shape}')
+        signal_names = [k for k, v in nxdata.entries.items()
+                        if k not in axes_names
+                        and v.shape[:len(dataset_shape)] == dataset_shape]
+        self.logger.info(f'Dataset signals: {signal_names}')
+        other_fields = [k for k, v in nxdata.entries.items()
+                        if k not in axes_names + signal_names]
+        if len(other_fields) > 0:
+            self.logger.warning(
+                'Ignoring the following fields that cannot be interpreted as '
+                + f'either dataset coordinates or signals: {other_fields}')
+        for i in np.ndindex(dataset_shape):
+            data_points.append(
+                {**{a: nxdata[a][_i] for a, _i in zip(axes_names, i)},
+                 **{s: nxdata[s].nxdata[i] for s in signal_names}})
+        return data_points
+
+
 class XarrayToNexusProcessor(Processor):
     """A Processor to convert the data in an `xarray` structure to a
     NeXus NXdata object.
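
Taken together, the three new classes support a setup/update workflow for large structured datasets: `SetupNXdataProcessor` builds the zero-filled container, `UpdateNXdataProcessor` streams data points into an on-disk copy, and `NXdataToDataPointsProcessor` converts an existing `NXdata` back into data points. A minimal in-memory sketch of the setup step (instantiating the `Processor` directly like this assumes the base class needs no constructor arguments; names and values are illustrative):

```python
from CHAP.common.processor import SetupNXdataProcessor

coords = [{'name': 'x', 'values': [0.0, 0.5, 1.0], 'attrs': {'units': 'mm'}}]
signals = [{'name': 'counts', 'shape': [], 'attrs': {'units': 'a.u.'}}]

processor = SetupNXdataProcessor()
nxdata = processor.process(
    [], nxname='example', coords=coords, signals=signals,
    data_points=[{'x': 0.5, 'counts': 42}])
print(nxdata.counts.nxdata)  # [0 42 0]: the added point is filled in
```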