ChessAnalysisPipeline 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (38)
  1. CHAP/__init__.py +1 -1
  2. CHAP/common/__init__.py +13 -0
  3. CHAP/common/models/integration.py +29 -26
  4. CHAP/common/models/map.py +395 -224
  5. CHAP/common/processor.py +1725 -93
  6. CHAP/common/reader.py +265 -28
  7. CHAP/common/writer.py +191 -18
  8. CHAP/edd/__init__.py +9 -2
  9. CHAP/edd/models.py +886 -665
  10. CHAP/edd/processor.py +2592 -936
  11. CHAP/edd/reader.py +889 -0
  12. CHAP/edd/utils.py +846 -292
  13. CHAP/foxden/__init__.py +6 -0
  14. CHAP/foxden/processor.py +42 -0
  15. CHAP/foxden/writer.py +65 -0
  16. CHAP/giwaxs/__init__.py +8 -0
  17. CHAP/giwaxs/models.py +100 -0
  18. CHAP/giwaxs/processor.py +520 -0
  19. CHAP/giwaxs/reader.py +5 -0
  20. CHAP/giwaxs/writer.py +5 -0
  21. CHAP/pipeline.py +48 -10
  22. CHAP/runner.py +161 -72
  23. CHAP/tomo/models.py +31 -29
  24. CHAP/tomo/processor.py +169 -118
  25. CHAP/utils/__init__.py +1 -0
  26. CHAP/utils/fit.py +1292 -1315
  27. CHAP/utils/general.py +411 -53
  28. CHAP/utils/models.py +594 -0
  29. CHAP/utils/parfile.py +10 -2
  30. ChessAnalysisPipeline-0.0.16.dist-info/LICENSE +60 -0
  31. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/METADATA +1 -1
  32. ChessAnalysisPipeline-0.0.16.dist-info/RECORD +62 -0
  33. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/WHEEL +1 -1
  34. CHAP/utils/scanparsers.py +0 -1431
  35. ChessAnalysisPipeline-0.0.14.dist-info/LICENSE +0 -21
  36. ChessAnalysisPipeline-0.0.14.dist-info/RECORD +0 -54
  37. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/entry_points.txt +0 -0
  38. {ChessAnalysisPipeline-0.0.14.dist-info → ChessAnalysisPipeline-0.0.16.dist-info}/top_level.txt +0 -0
CHAP/common/processor.py CHANGED
@@ -8,6 +8,9 @@ Description: Module for Processors used in multiple experiment-specific
8
8
  workflows.
9
9
  """
10
10
 
11
+ # System modules
12
+ import os
13
+
11
14
  # Third party modules
12
15
  import numpy as np
13
16
 
@@ -58,12 +61,6 @@ class AnimationProcessor(Processor):
58
61
  :return: The matplotlib animation.
59
62
  :rtype: matplotlib.animation.ArtistAnimation
60
63
  """
61
- # System modules
62
- from os.path import (
63
- isabs,
64
- join,
65
- )
66
-
67
64
  # Third party modules
68
65
  import matplotlib.animation as animation
69
66
  import matplotlib.pyplot as plt
@@ -134,6 +131,7 @@ class AnimationProcessor(Processor):
134
131
  a_max = frames[0].max()
135
132
  for n in range(1, num_frames):
136
133
  a_max = min(a_max, frames[n].max())
134
+ a_max = float(a_max)
137
135
  if vmin is None:
138
136
  vmin = -a_max
139
137
  if vmax is None:
@@ -248,17 +246,9 @@ class BinarizeProcessor(Processor):
248
246
  :raises ValueError: Upon invalid input parameters.
249
247
  :return: The binarized dataset with a return type equal to
250
248
  that of the input dataset.
251
- :rtype: numpy.ndarray, nexusformat.nexus.NXobject
249
+ :rtype: typing.Union[numpy.ndarray, nexusformat.nexus.NXobject]
252
250
  """
253
- # System modules
254
- from os.path import join as os_join
255
- from os.path import relpath
256
-
257
- # Local modules
258
- from CHAP.utils.general import (
259
- is_int,
260
- nxcopy,
261
- )
251
+ # Third party modules
262
252
  from nexusformat.nexus import (
263
253
  NXdata,
264
254
  NXfield,
@@ -267,6 +257,12 @@ class BinarizeProcessor(Processor):
267
257
  nxsetconfig,
268
258
  )
269
259
 
260
+ # Local modules
261
+ from CHAP.utils.general import (
262
+ is_int,
263
+ nxcopy,
264
+ )
265
+
270
266
  if method not in [
271
267
  'CHAP', 'manual', 'otsu', 'yen', 'isodata', 'minimum']:
272
268
  raise ValueError(f'Invalid parameter method ({method})')
@@ -344,19 +340,21 @@ class BinarizeProcessor(Processor):
344
340
  exclude_nxpaths = []
345
341
  if nxdefault is not None:
346
342
  exclude_nxpaths.append(
347
- os_join(relpath(nxdefault.nxpath, dataset.nxpath)))
343
+ os.path.join(os.path.relpath(
344
+ nxdefault.nxpath, dataset.nxpath)))
348
345
  if remove_original_data:
349
346
  if (nxdefault is None
350
347
  or nxdefault.nxpath != nxdata.nxpath):
351
- relpath_nxdata = relpath(nxdata.nxpath, dataset.nxpath)
348
+ relpath_nxdata = os.path.relpath(
349
+ nxdata.nxpath, dataset.nxpath)
352
350
  keys = list(nxdata.keys())
353
351
  keys.remove(nxsignal.nxname)
354
352
  for axis in nxdata.axes:
355
353
  keys.remove(axis)
356
354
  if len(keys):
357
355
  raise RuntimeError('Not tested yet')
358
- exclude_nxpaths.append(os_join(
359
- relpath(nxsignal.nxpath, dataset.nxpath)))
356
+ exclude_nxpaths.append(os.path.join(
357
+ os.path.relpath(nxsignal.nxpath, dataset.nxpath)))
360
358
  elif relpath_nxdata == '.':
361
359
  exclude_nxpaths.append(nxsignal.nxname)
362
360
  if dataset.nxclass != 'NXdata':
@@ -373,11 +371,11 @@ class BinarizeProcessor(Processor):
373
371
  keys.remove(axis)
374
372
  if len(keys):
375
373
  raise RuntimeError('Not tested yet')
376
- exclude_nxpaths.append(os_join(
377
- relpath(nxsignal.nxpath, dataset.nxpath)))
374
+ exclude_nxpaths.append(os.path.join(
375
+ os.path.relpath(nxsignal.nxpath, dataset.nxpath)))
378
376
  else:
379
- exclude_nxpaths.append(os_join(
380
- relpath(nxgroup.nxpath, dataset.nxpath)))
377
+ exclude_nxpaths.append(os.path.join(
378
+ os.path.relpath(nxgroup.nxpath, dataset.nxpath)))
381
379
  nxobject = nxcopy(dataset, exclude_nxpaths=exclude_nxpaths)
382
380
 
383
381
  # Get a histogram of the data
@@ -494,12 +492,11 @@ class BinarizeProcessor(Processor):
494
492
  # Select the ROI's orthogonal to the selected averaging direction
495
493
  bounds = []
496
494
  for i, bound in enumerate(['"0"', '"1"']):
497
- _, roi = select_roi_2d(
495
+ roi = select_roi_2d(
498
496
  mean_data,
499
497
  title=f'Select the ROI to obtain the {bound} data value',
500
498
  title_a=f'Data averaged in the {axes[axis]}-direction',
501
499
  row_label=subaxes[0], column_label=subaxes[1])
502
- plt.close()
503
500
 
504
501
  # Select the index range in the selected averaging direction
505
502
  if not axis:
@@ -512,12 +509,11 @@ class BinarizeProcessor(Processor):
512
509
  mean_roi_data = data[roi[2]:roi[3],roi[0]:roi[1],:].mean(
513
510
  axis=(0,1))
514
511
 
515
- _, _range = select_roi_1d(
512
+ _range = select_roi_1d(
516
513
  mean_roi_data, preselected_roi=(0, data.shape[axis]),
517
514
  title=f'Select the {axes[axis]}-direction range to obtain '
518
515
  f'the {bound} data bound',
519
516
  xlabel=axes[axis], ylabel='Average data')
520
- plt.close()
521
517
 
522
518
  # Obtain the lower/upper data bound
523
519
  if not axis:
@@ -573,10 +569,261 @@ class BinarizeProcessor(Processor):
573
569
  nxdata = nxentry[name].data
574
570
  nxentry.data = NXdata(
575
571
  NXlink(nxdata.nxsignal.nxpath),
576
- [NXlink(os_join(nxdata.nxpath, axis)) for axis in nxdata.axes])
572
+ [NXlink(os.path.join(nxdata.nxpath, axis))
573
+ for axis in nxdata.axes])
574
+ nxentry.data.set_default()
577
575
  return nxobject
578
576
 
579
577
 
578
+ class ConstructBaseline(Processor):
579
+ """A Processor to construct a baseline for a dataset.
580
+ """
581
+ def process(
582
+ self, data, mask=None, tol=1.e-6, lam=1.e6, max_iter=20,
583
+ save_figures=False, outputdir='.', interactive=False):
584
+ """Construct and return the baseline for a dataset.
585
+
586
+ :param data: Input data.
587
+ :type data: list[PipelineData]
588
+ :param mask: A mask to apply to the spectrum before baseline
589
+ construction, default to `None`.
590
+ :type mask: array-like, optional
591
+ :param tol: The convergence tolerence, defaults to `1.e-6`.
592
+ :type tol: float, optional
593
+ :param lam: The &lambda (smoothness) parameter (the balance
594
+ between the residual of the data and the baseline and the
595
+ smoothness of the baseline). The suggested range is between
596
+ 100 and 10^8, defaults to `10^6`.
597
+ :type lam: float, optional
598
+ :param max_iter: The maximum number of iterations,
599
+ defaults to `20`.
600
+ :type max_iter: int, optional
601
+ :param save_figures: Save .pngs of plots for checking inputs &
602
+ outputs of this Processor, defaults to False.
603
+ :type save_figures: bool, optional
604
+ :param outputdir: Directory to which any output figures will
605
+ be saved, defaults to '.'
606
+ :type outputdir: str, optional
607
+ :param interactive: Allows for user interactions, defaults to
608
+ False.
609
+ :type interactive: bool, optional
610
+ :return: The smoothed baseline and the configuration.
611
+ :rtype: numpy.array, dict
612
+ """
613
+ try:
614
+ data = np.asarray(self.unwrap_pipelinedata(data)[0])
615
+ except:
616
+ raise ValueError(
617
+ f'The structure of {data} contains no valid data')
618
+
619
+ return self.construct_baseline(
620
+ data, mask, tol, lam, max_iter, save_figures, outputdir,
621
+ interactive)
622
+
623
+ @staticmethod
624
+ def construct_baseline(
625
+ y, x=None, mask=None, tol=1.e-6, lam=1.e6, max_iter=20, title=None,
626
+ xlabel=None, ylabel=None, interactive=False, filename=None):
627
+ """Construct and return the baseline for a dataset.
628
+
629
+ :param y: Input data.
630
+ :type y: numpy.array
631
+ :param x: Independent dimension (only used when interactive is
632
+ `True` of when filename is set), defaults to `None`.
633
+ :type x: array-like, optional
634
+ :param mask: A mask to apply to the spectrum before baseline
635
+ construction, default to `None`.
636
+ :type mask: array-like, optional
637
+ :param tol: The convergence tolerence, defaults to `1.e-6`.
638
+ :type tol: float, optional
639
+ :param lam: The &lambda (smoothness) parameter (the balance
640
+ between the residual of the data and the baseline and the
641
+ smoothness of the baseline). The suggested range is between
642
+ 100 and 10^8, defaults to `10^6`.
643
+ :type lam: float, optional
644
+ :param max_iter: The maximum number of iterations,
645
+ defaults to `20`.
646
+ :type max_iter: int, optional
647
+ :param xlabel: Label for the x-axis of the displayed figure,
648
+ defaults to `None`.
649
+ :param title: Title for the displayed figure, defaults to `None`.
650
+ :type title: str, optional
651
+ :type xlabel: str, optional
652
+ :param ylabel: Label for the y-axis of the displayed figure,
653
+ defaults to `None`.
654
+ :type ylabel: str, optional
655
+ :param interactive: Allows for user interactions, defaults to
656
+ False.
657
+ :type interactive: bool, optional
658
+ :param filename: Save a .png of the plot to filename, defaults to
659
+ `None`, in which case the plot is not saved.
660
+ :type filename: str, optional
661
+ :return: The smoothed baseline and the configuration.
662
+ :rtype: numpy.array, dict
663
+ """
664
+ # Third party modules
665
+ if interactive or filename is not None:
666
+ from matplotlib.widgets import TextBox, Button
667
+ import matplotlib.pyplot as plt
668
+
669
+ # Local modules
670
+ from CHAP.utils.general import baseline_arPLS
671
+
672
+ def change_fig_subtitle(maxed_out=False, subtitle=None):
673
+ if fig_subtitles:
674
+ fig_subtitles[0].remove()
675
+ fig_subtitles.pop()
676
+ if subtitle is None:
677
+ subtitle = r'$\lambda$ = 'f'{lambdas[-1]:.2e}, '
678
+ if maxed_out:
679
+ subtitle += f'# iter = {num_iters[-1]} (maxed out) '
680
+ else:
681
+ subtitle += f'# iter = {num_iters[-1]} '
682
+ subtitle += f'error = {errors[-1]:.2e}'
683
+ fig_subtitles.append(
684
+ plt.figtext(*subtitle_pos, subtitle, **subtitle_props))
685
+
686
+ def select_lambda(expression):
687
+ """Callback function for the "Select lambda" TextBox.
688
+ """
689
+ if not len(expression):
690
+ return
691
+ try:
692
+ lam = float(expression)
693
+ if lam < 0:
694
+ raise ValueError
695
+ except ValueError:
696
+ change_fig_subtitle(
697
+ subtitle=f'Invalid lambda, enter a positive number')
698
+ else:
699
+ lambdas.pop()
700
+ lambdas.append(10**lam)
701
+ baseline, _, w, num_iter, error = baseline_arPLS(
702
+ y, mask=mask, tol=tol, lam=lambdas[-1], max_iter=max_iter,
703
+ full_output=True)
704
+ num_iters.pop()
705
+ num_iters.append(num_iter)
706
+ errors.pop()
707
+ errors.append(error)
708
+ if num_iter < max_iter:
709
+ change_fig_subtitle()
710
+ else:
711
+ change_fig_subtitle(maxed_out=True)
712
+ baseline_handle.set_ydata(baseline)
713
+ lambda_box.set_val('')
714
+ plt.draw()
715
+
716
+ def continue_iter(event):
717
+ """Callback function for the "Continue" button."""
718
+ baseline, _, w, n_iter, error = baseline_arPLS(
719
+ y, mask=mask, w=weights[-1], tol=tol, lam=lambdas[-1],
720
+ max_iter=max_iter, full_output=True)
721
+ num_iters[-1] += n_iter
722
+ errors.pop()
723
+ errors.append(error)
724
+ if n_iter < max_iter:
725
+ change_fig_subtitle()
726
+ else:
727
+ change_fig_subtitle(maxed_out=True)
728
+ baseline_handle.set_ydata(baseline)
729
+ plt.draw()
730
+ weights.pop()
731
+ weights.append(w)
732
+
733
+ def confirm(event):
734
+ """Callback function for the "Confirm" button."""
735
+ plt.close()
736
+
737
+ baseline, _, w, num_iter, error = baseline_arPLS(
738
+ y, mask=mask, tol=tol, lam=lam, max_iter=max_iter,
739
+ full_output=True)
740
+
741
+ if not interactive and filename is None:
742
+ return baseline
743
+
744
+ lambdas = [lam]
745
+ weights = [w]
746
+ num_iters = [num_iter]
747
+ errors = [error]
748
+ fig_subtitles = []
749
+
750
+ # Check inputs
751
+ if x is None:
752
+ x = np.arange(y.size)
753
+
754
+ # Setup the Matplotlib figure
755
+ title_pos = (0.5, 0.95)
756
+ title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
757
+ 'verticalalignment': 'bottom'}
758
+ subtitle_pos = (0.5, 0.90)
759
+ subtitle_props = {'fontsize': 'x-large',
760
+ 'horizontalalignment': 'center',
761
+ 'verticalalignment': 'bottom'}
762
+ fig, ax = plt.subplots(figsize=(11, 8.5))
763
+ if mask is None:
764
+ ax.plot(x, y, label='input data')
765
+ else:
766
+ ax.plot(
767
+ x[mask.astype(bool)], y[mask.astype(bool)], label='input data')
768
+ baseline_handle = ax.plot(x, baseline, label='baseline')[0]
769
+ # ax.plot(x, y-baseline, label='baseline corrected data')
770
+ ax.set_xlabel(xlabel, fontsize='x-large')
771
+ ax.set_ylabel(ylabel, fontsize='x-large')
772
+ ax.legend()
773
+ if title is None:
774
+ fig_title = plt.figtext(*title_pos, 'Baseline', **title_props)
775
+ else:
776
+ fig_title = plt.figtext(*title_pos, title, **title_props)
777
+ if num_iter < max_iter:
778
+ change_fig_subtitle()
779
+ else:
780
+ change_fig_subtitle(maxed_out=True)
781
+ fig.subplots_adjust(bottom=0.0, top=0.85)
782
+
783
+ if interactive:
784
+
785
+ fig.subplots_adjust(bottom=0.2)
786
+
787
+ # Setup TextBox
788
+ lambda_box = TextBox(
789
+ plt.axes([0.15, 0.05, 0.15, 0.075]), r'log($\lambda$)')
790
+ lambda_cid = lambda_box.on_submit(select_lambda)
791
+
792
+ # Setup "Continue" button
793
+ continue_btn = Button(
794
+ plt.axes([0.45, 0.05, 0.15, 0.075]), 'Continue smoothing')
795
+ continue_cid = continue_btn.on_clicked(continue_iter)
796
+
797
+ # Setup "Confirm" button
798
+ confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
799
+ confirm_cid = confirm_btn.on_clicked(confirm)
800
+
801
+ # Show figure for user interaction
802
+ plt.show()
803
+
804
+ # Disconnect all widget callbacks when figure is closed
805
+ lambda_box.disconnect(lambda_cid)
806
+ continue_btn.disconnect(continue_cid)
807
+ confirm_btn.disconnect(confirm_cid)
808
+
809
+ # ... and remove the buttons before returning the figure
810
+ lambda_box.ax.remove()
811
+ continue_btn.ax.remove()
812
+ confirm_btn.ax.remove()
813
+
814
+ if filename is not None:
815
+ fig_title.set_in_layout(True)
816
+ fig_subtitles[-1].set_in_layout(True)
817
+ fig.tight_layout(rect=(0, 0, 1, 0.90))
818
+ fig.savefig(filename)
819
+ plt.close()
820
+
821
+ config = {
822
+ 'tol': tol, 'lambda': lambdas[-1], 'max_iter': max_iter,
823
+ 'num_iter': num_iters[-1], 'error': errors[-1], 'mask': mask}
824
+ return baseline, config
825
+
826
+
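For reference, a minimal sketch of how the new `ConstructBaseline` helper could be called directly on a 1-D spectrum outside of a pipeline run. The spectrum below is synthetic and the parameter values are illustrative only; with `interactive=False` and no `filename`, the static method returns just the baseline array, as in the code above.

```python
# Illustrative only -- synthetic spectrum, default arPLS settings.
import numpy as np
from CHAP.common.processor import ConstructBaseline

x = np.linspace(0., 10., 500)
y = np.exp(-0.5*((x - 5.)/0.2)**2) + 0.05*x    # peak on a sloping background
baseline = ConstructBaseline.construct_baseline(y, x=x, lam=1.e6)
corrected = y - baseline                        # baseline-corrected spectrum
```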
580
827
  class ImageProcessor(Processor):
581
828
  """A Processor to plot an image (slice) from a NeXus object.
582
829
  """
@@ -584,9 +831,9 @@ class ImageProcessor(Processor):
584
831
  self, data, vmin=None, vmax=None, axis=0, index=None,
585
832
  coord=None, interactive=False, save_figure=True, outputdir='.',
586
833
  filename='image.png'):
587
- """Plot and/or save an image (slice) from a NeXus NXobject object with
588
- a default data path contained in `data` and return the NeXus NXdata
589
- data object.
834
+ """Plot and/or save an image (slice) from a NeXus NXobject
835
+ object with a default data path contained in `data` and return
836
+ the NeXus NXdata data object.
590
837
 
591
838
  :param data: Input data.
592
839
  :type data: list[PipelineData]
@@ -618,12 +865,6 @@ class ImageProcessor(Processor):
618
865
  :return: The input data object.
619
866
  :rtype: nexusformat.nexus.NXdata
620
867
  """
621
- # System modules
622
- from os.path import (
623
- isabs,
624
- join,
625
- )
626
-
627
868
  # Third party modules
628
869
  import matplotlib.pyplot as plt
629
870
 
@@ -639,8 +880,8 @@ class ImageProcessor(Processor):
639
880
  raise ValueError(f'Invalid parameter outputdir ({outputdir})')
640
881
  if not isinstance(filename, str):
641
882
  raise ValueError(f'Invalid parameter filename ({filename})')
642
- if not isabs(filename):
643
- filename = join(outputdir, filename)
883
+ if not os.path.isabs(filename):
884
+ filename = os.path.join(outputdir, filename)
644
885
 
645
886
  # Get the default Nexus NXdata object
646
887
  data = self.unwrap_pipelinedata(data)[0]
@@ -796,8 +1037,9 @@ class IntegrateMapProcessor(Processor):
796
1037
  containing a map of the integrated detector data requested.
797
1038
 
798
1039
  :param data: Input data, containing at least one item
799
- with the value `'MapConfig'` for the `'schema'` key, and at
800
- least one item with the value `'IntegrationConfig'` for the
1040
+ with the value `'common.models.map.MapConfig'` for the
1041
+ `'schema'` key, and at least one item with the value
1042
+ `'common.models.integration.IntegrationConfig'` for the
801
1043
  `'schema'` key.
802
1044
  :type data: list[PipelineData]
803
1045
  :return: Integrated data and process metadata.
@@ -815,10 +1057,11 @@ class IntegrateMapProcessor(Processor):
815
1057
  """Use a `MapConfig` and `IntegrationConfig` to construct a
816
1058
  NeXus NXprocess object.
817
1059
 
818
- :param map_config: A valid map configuration.
819
- :type map_config: MapConfig
820
- :param integration_config: A valid integration configuration
821
- :type integration_config: IntegrationConfig.
1060
+ :param map_config: A valid map configuration.
1061
+ :type map_config: common.models.map.MapConfig
1062
+ :param integration_config: A valid integration configuration.
1063
+ :type integration_config:
1064
+ common.models.integration.IntegrationConfig
822
1065
  :return: The integrated detector data and metadata.
823
1066
  :rtype: nexusformat.nexus.NXprocess
824
1067
  """
@@ -871,7 +1114,7 @@ class IntegrateMapProcessor(Processor):
871
1114
  *map_config.dims,
872
1115
  *integration_config.integrated_data_dims
873
1116
  )
874
- for i, dim in enumerate(map_config.independent_dimensions[::-1]):
1117
+ for i, dim in enumerate(map_config.independent_dimensions):
875
1118
  nxprocess.data[dim.label] = NXfield(
876
1119
  value=map_config.coords[dim.label],
877
1120
  units=dim.units,
@@ -901,7 +1144,7 @@ class IntegrateMapProcessor(Processor):
901
1144
  value=np.empty(
902
1145
  (*tuple(
903
1146
  [len(coord_values) for coord_name, coord_values
904
- in map_config.coords.items()][::-1]),
1147
+ in map_config.coords.items()]),
905
1148
  *integration_config.integrated_data_shape)),
906
1149
  units='a.u',
907
1150
  attrs={'long_name':'Intensity (a.u)'})
@@ -958,33 +1201,256 @@ class MapProcessor(Processor):
958
1201
  NXentry object representing that map's metadata and any
959
1202
  scalar-valued raw data requested by the supplied map configuration.
960
1203
  """
961
- def process(self, data):
1204
+ def process(
1205
+ self, data, config=None, detector_names=None, num_proc=1,
1206
+ comm=None, inputdir=None):
962
1207
  """Process the output of a `Reader` that contains a map
963
1208
  configuration and returns a NeXus NXentry object representing
964
1209
  the map.
965
1210
 
966
1211
  :param data: Result of `Reader.read` where at least one item
967
- has the value `'MapConfig'` for the `'schema'` key.
1212
+ has the value `'common.models.map.MapConfig'` for the
1213
+ `'schema'` key.
968
1214
  :type data: list[PipelineData]
1215
+ :param config: Initialization parameters for an instance of
1216
+ common.models.map.MapConfig, defaults to `None`.
1217
+ :type config: dict, optional
1218
+ :param detector_names: Detector names/prefixes to include raw
1219
+ data for in the returned NeXus NXentry object,
1220
+ defaults to `None`.
1221
+ :type detector_names: Union(int, str, list[int], list[str]),
1222
+ optional
1223
+ :param num_proc: Number of processors used to read map,
1224
+ defaults to `1`.
1225
+ :type num_proc: int, optional
969
1226
  :return: Map data and metadata.
970
1227
  :rtype: nexusformat.nexus.NXentry
971
1228
  """
972
- map_config = self.get_config(data, 'common.models.map.MapConfig')
973
- nxentry = self.__class__.get_nxentry(map_config)
1229
+ # System modules
1230
+ from copy import deepcopy
1231
+ import logging
1232
+ from tempfile import NamedTemporaryFile
1233
+
1234
+ # Third party modules
1235
+ import yaml
1236
+
1237
+ # Local modules
1238
+ from CHAP.runner import (
1239
+ RunConfig,
1240
+ runner,
1241
+ )
1242
+ from CHAP.utils.general import (
1243
+ is_str_series,
1244
+ string_to_list,
1245
+ )
1246
+
1247
+ # Get the validated map configuration
1248
+ try:
1249
+ map_config = self.get_config(
1250
+ data, 'common.models.map.MapConfig', inputdir=inputdir)
1251
+ except Exception as data_exc:
1252
+ self.logger.info('No valid Map configuration in input pipeline '
1253
+ 'data, using config parameter instead.')
1254
+ try:
1255
+ # Local modules
1256
+ from CHAP.common.models.map import MapConfig
1257
+
1258
+ map_config = MapConfig(**config, inputdir=inputdir)
1259
+ except Exception as dict_exc:
1260
+ raise RuntimeError from dict_exc
1261
+
1262
+ # Validate the number of processors
1263
+ if not isinstance(num_proc, int):
1264
+ self.logger.warning('Ignoring invalid parameter num_proc '
1265
+ f'({num_proc}), running serially')
1266
+ num_proc = 1
1267
+ elif num_proc > 1:
1268
+ try:
1269
+ # System modules
1270
+ from os import cpu_count
1271
+
1272
+ # Third party modules
1273
+ from mpi4py import MPI
1274
+
1275
+ if num_proc > cpu_count():
1276
+ self.logger.warning(
1277
+ f'The requested number of processors ({num_proc}) '
1278
+ 'exceeds the maximum number of processors '
1279
+ f'({cpu_count()}): reset it to {cpu_count()}')
1280
+ num_proc = cpu_count()
1281
+ except:
1282
+ self.logger.warning('Unable to load mpi4py, running serially')
1283
+ num_proc = 1
1284
+
1285
+ # Validate the detector names/prefixes
1286
+ if map_config.experiment_type == 'EDD':
1287
+ if detector_names is None:
1288
+ detector_indices = None
1289
+ else:
1290
+ # Local modules
1291
+ from CHAP.utils.general import is_str_series
1292
+
1293
+ if isinstance(detector_names, int):
1294
+ detector_names = [str(detector_names)]
1295
+ elif isinstance(detector_names, str):
1296
+ try:
1297
+ detector_names = [
1298
+ str(v) for v in string_to_list(
1299
+ detector_names, raise_error=True)]
1300
+ except:
1301
+ raise ValueError('Invalid parameter detector_names '
1302
+ f'({detector_names})')
1303
+ else:
1304
+ detector_names = [str(v) for v in detector_names]
1305
+ detector_indices = [int(name) for name in detector_names]
1306
+ else:
1307
+ if detector_names is None:
1308
+ raise ValueError(
1309
+ 'Missing "detector_names" parameter')
1310
+ if isinstance(detector_names, str):
1311
+ detector_names = [detector_names]
1312
+ if not is_str_series(detector_names, log=False):
1313
+ raise ValueError(
1314
+ f'Invalid "detector_names" parameter ({detector_names})')
1315
+
1316
+ # Create the sub-pipeline configuration for each processor
1317
+ # FIX: catered to EDD with one spec scan
1318
+ assert len(map_config.spec_scans) == 1
1319
+ spec_scans = map_config.spec_scans[0]
1320
+ scan_numbers = spec_scans.scan_numbers
1321
+ num_scan = len(scan_numbers)
1322
+ if num_scan < num_proc:
1323
+ self.logger.warning(
1324
+ f'The requested number of processors ({num_proc}) exceeds '
1325
+ f'the number of scans ({num_scan}): reset it to {num_scan}')
1326
+ num_proc = num_scan
1327
+ if num_proc == 1:
1328
+ common_comm = comm
1329
+ offsets = [0]
1330
+ else:
1331
+ scans_per_proc = num_scan//num_proc
1332
+ num = scans_per_proc
1333
+ if num_scan - scans_per_proc*num_proc > 0:
1334
+ num += 1
1335
+ spec_scans.scan_numbers = scan_numbers[:num]
1336
+ n_scan = num
1337
+ pipeline_config = []
1338
+ offsets = [0]
1339
+ for n_proc in range(1, num_proc):
1340
+ num = scans_per_proc
1341
+ if n_proc < num_scan - scans_per_proc*num_proc:
1342
+ num += 1
1343
+ config = deepcopy(map_config.dict())
1344
+ config['spec_scans'][0]['scan_numbers'] = \
1345
+ scan_numbers[n_scan:n_scan+num]
1346
+ pipeline_config.append(
1347
+ [{'common.MapProcessor': {
1348
+ 'config': config, 'detector_names': detector_names}}])
1349
+ offsets.append(n_scan)
1350
+ n_scan += num
1351
+
1352
+ # Spawn the workers to run the sub-pipeline
1353
+ run_config = RunConfig(
1354
+ config={'log_level': logging.getLevelName(self.logger.level),
1355
+ 'spawn': 1})
1356
+ tmp_names = []
1357
+ with NamedTemporaryFile(delete=False) as fp:
1358
+ fp_name = fp.name
1359
+ tmp_names.append(fp_name)
1360
+ with open(fp_name, 'w') as f:
1361
+ yaml.dump({'config': {'spawn': 1}}, f, sort_keys=False)
1362
+ for n_proc in range(1, num_proc):
1363
+ f_name = f'{fp_name}_{n_proc}'
1364
+ tmp_names.append(f_name)
1365
+ with open(f_name, 'w') as f:
1366
+ yaml.dump(
1367
+ {'config': run_config.__dict__,
1368
+ 'pipeline': pipeline_config[n_proc-1]},
1369
+ f, sort_keys=False)
1370
+ sub_comm = MPI.COMM_SELF.Spawn(
1371
+ 'CHAP', args=[fp_name], maxprocs=num_proc-1)
1372
+ common_comm = sub_comm.Merge(False)
1373
+ # Align with the barrier in RunConfig() on common_comm
1374
+ # called from the spawned main()
1375
+ common_comm.barrier()
1376
+ # Align with the barrier in run() on common_comm
1377
+ # called from the spawned main()
1378
+ common_comm.barrier()
1379
+
1380
+ if common_comm is None:
1381
+ num_proc = 1
1382
+ rank = 0
1383
+ else:
1384
+ num_proc = common_comm.Get_size()
1385
+ rank = common_comm.Get_rank()
1386
+ if num_proc == 1:
1387
+ offset = 0
1388
+ else:
1389
+ num_scan = common_comm.bcast(num_scan, root=0)
1390
+ offset = common_comm.scatter(offsets, root=0)
1391
+
1392
+ # Read the raw data
1393
+ if map_config.experiment_type == 'EDD':
1394
+ data, independent_dimensions, all_scalar_data = \
1395
+ self._read_raw_data_edd(
1396
+ map_config, detector_indices, common_comm, num_scan,
1397
+ offset)
1398
+ else:
1399
+ data, independent_dimensions, all_scalar_data = \
1400
+ self._read_raw_data(
1401
+ map_config, detector_names, common_comm, num_scan, offset)
1402
+ if not rank:
1403
+ self.logger.debug(f'Data shape: {data.shape}')
1404
+ if independent_dimensions is not None:
1405
+ self.logger.debug('Independent dimensions shape: '
1406
+ f'{independent_dimensions.shape}')
1407
+ if all_scalar_data is not None:
1408
+ self.logger.debug('Scalar data shape: '
1409
+ f'{all_scalar_data.shape}')
1410
+
1411
+ if rank:
1412
+ return None
1413
+
1414
+ if num_proc > 1:
1415
+ # Reset the scan_numbers to the original full set
1416
+ spec_scans.scan_numbers = scan_numbers
1417
+ # Disconnect spawned workers and cleanup temporary files
1418
+ common_comm.barrier()
1419
+ sub_comm.Disconnect()
1420
+ for tmp_name in tmp_names:
1421
+ os.remove(tmp_name)
1422
+
1423
+ # Construct the NeXus NXentry object
1424
+ nxentry = self._get_nxentry(
1425
+ map_config, detector_names, data, independent_dimensions,
1426
+ all_scalar_data)
974
1427
 
975
1428
  return nxentry
976
1429
 
977
- @staticmethod
978
- def get_nxentry(map_config):
1430
+ def _get_nxentry(
1431
+ self, map_config, detector_names, data, independent_dimensions,
1432
+ all_scalar_data):
979
1433
  """Use a `MapConfig` to construct a NeXus NXentry object.
980
1434
 
981
1435
  :param map_config: A valid map configuration.
982
- :type map_config: MapConfig
1436
+ :type map_config: common.models.map.MapConfig
1437
+ :param detector_names: Detector names to include raw data
1438
+ for in the returned NeXus NXentry object,
1439
+ defaults to `None`.
1440
+ :type detector_names: list[str]
1441
+ :param data: The map's raw data.
1442
+ :type data: numpy.ndarray
1443
+ :param independent_dimensions: The map's independent
1444
+ coordinates.
1445
+ :type independent_dimensions: numpy.ndarray
1446
+ :param all_scalar_data: The map's scalar data.
1447
+ :type all_scalar_data: numpy.ndarray
983
1448
  :return: The map's data and metadata contained in a NeXus
984
1449
  structure.
985
1450
  :rtype: nexusformat.nexus.NXentry
986
1451
  """
987
1452
  # System modules
1453
+ from copy import deepcopy
988
1454
  from json import dumps
989
1455
 
990
1456
  # Third party modules
@@ -996,11 +1462,16 @@ class MapProcessor(Processor):
996
1462
  NXsample,
997
1463
  )
998
1464
 
1465
+ # Local modules:
1466
+ from CHAP.common.models.map import PointByPointScanData
1467
+ from CHAP.utils.general import is_int_series
1468
+
1469
+ # Set up NeXus NXentry and add misc. CHESS-specific metadata
999
1470
  nxentry = NXentry(name=map_config.title)
1000
- nxentry.map_config = dumps(map_config.dict())
1001
- nxentry[map_config.sample.name] = NXsample(**map_config.sample.dict())
1002
1471
  nxentry.attrs['station'] = map_config.station
1003
-
1472
+ for key, value in map_config.attrs.items():
1473
+ nxentry.attrs[key] = value
1474
+ nxentry.detector_names = detector_names
1004
1475
  nxentry.spec_scans = NXcollection()
1005
1476
  for scans in map_config.spec_scans:
1006
1477
  nxentry.spec_scans[scans.scanparsers[0].scan_name] = \
@@ -1008,44 +1479,618 @@ class MapProcessor(Processor):
1008
1479
  dtype='int8',
1009
1480
  attrs={'spec_file': str(scans.spec_file)})
1010
1481
 
1011
- nxentry.data = NXdata()
1012
- if map_config.map_type == 'structured':
1013
- nxentry.data.attrs['axes'] = map_config.dims
1014
- for i, dim in enumerate(map_config.independent_dimensions[::-1]):
1015
- nxentry.data[dim.label] = NXfield(
1016
- value=map_config.coords[dim.label],
1482
+ # Add sample metadata
1483
+ nxentry[map_config.sample.name] = NXsample(**map_config.sample.dict())
1484
+
1485
+ # Set up default NeXus NXdata group (squeeze out constant dimensions)
1486
+ constant_dim = []
1487
+ for i, dim in enumerate(map_config.independent_dimensions):
1488
+ unique = np.unique(independent_dimensions[i])
1489
+ if unique.size == 1:
1490
+ constant_dim.append(i)
1491
+ nxentry.data = NXdata(
1492
+ NXfield(data, 'detector_data'),
1493
+ tuple([
1494
+ NXfield(
1495
+ independent_dimensions[i], dim.label,
1496
+ attrs={'units': dim.units,
1497
+ 'long_name': f'{dim.label} ({dim.units})',
1498
+ 'data_type': dim.data_type,
1499
+ 'local_name': dim.name})
1500
+ for i, dim in enumerate(map_config.independent_dimensions)
1501
+ if i not in constant_dim]))
1502
+ nxentry.data.set_default()
1503
+
1504
+ # Set up auxiliary NeXus NXdata group (add the constant dimensions)
1505
+ auxiliary_signals = []
1506
+ auxiliary_data = []
1507
+ for i, dim in enumerate(map_config.all_scalar_data):
1508
+ auxiliary_signals.append(dim.label)
1509
+ auxiliary_data.append(NXfield(
1510
+ value=all_scalar_data[i],
1017
1511
  units=dim.units,
1018
1512
  attrs={'long_name': f'{dim.label} ({dim.units})',
1019
1513
  'data_type': dim.data_type,
1020
- 'local_name': dim.name})
1021
- if map_config.map_type == 'structured':
1022
- nxentry.data.attrs[f'{dim.label}_indices'] = i
1023
-
1024
- signal = False
1025
- auxilliary_signals = []
1026
- for data in map_config.all_scalar_data:
1027
- nxentry.data[data.label] = NXfield(
1028
- value=np.empty(map_config.shape),
1029
- units=data.units,
1030
- attrs={'long_name': f'{data.label} ({data.units})',
1031
- 'data_type': data.data_type,
1032
- 'local_name': data.name})
1033
- if not signal:
1034
- signal = data.label
1514
+ 'local_name': dim.name}))
1515
+ for i, dim in enumerate(deepcopy(map_config.independent_dimensions)):
1516
+ if i in constant_dim:
1517
+ auxiliary_signals.append(dim.label)
1518
+ auxiliary_data.append(NXfield(
1519
+ independent_dimensions[i], dim.label,
1520
+ attrs={'units': dim.units,
1521
+ 'long_name': f'{dim.label} ({dim.units})',
1522
+ 'data_type': dim.data_type,
1523
+ 'local_name': dim.name}))
1524
+ map_config.all_scalar_data.append(
1525
+ PointByPointScanData(**dict(dim)))
1526
+ map_config.independent_dimensions.remove(dim)
1527
+ if auxiliary_signals:
1528
+ nxentry.auxdata = NXdata()
1529
+ for label, data in zip(auxiliary_signals, auxiliary_data):
1530
+ nxentry.auxdata[label] = data
1531
+ if 'SCAN_N' in auxiliary_signals:
1532
+ nxentry.auxdata.attrs['signal'] = 'SCAN_N'
1035
1533
  else:
1036
- auxilliary_signals.append(data.label)
1534
+ nxentry.auxdata.attrs['signal'] = auxiliary_signals[0]
1535
+ auxiliary_signals.remove(nxentry.auxdata.attrs['signal'])
1536
+ nxentry.auxdata.attrs['auxiliary_signals'] = auxiliary_signals
1037
1537
 
1038
- if signal:
1039
- nxentry.data.attrs['signal'] = signal
1040
- nxentry.data.attrs['auxilliary_signals'] = auxilliary_signals
1041
-
1042
- for data in map_config.all_scalar_data:
1043
- for map_index in np.ndindex(map_config.shape):
1044
- nxentry.data[data.label][map_index] = map_config.get_value(
1045
- data, map_index)
1538
+ nxentry.map_config = dumps(map_config.dict())
1046
1539
 
1047
1540
  return nxentry
1048
1541
 
1542
+ def _read_raw_data_edd(
1543
+ self, map_config, detector_indices, comm, num_scan, offset):
1544
+ """Read the raw EDD data for a given map configuration.
1545
+
1546
+ :param map_config: A valid map configuration.
1547
+ :type map_config: common.models.map.MapConfig
1548
+ :param detector_indices: Indices to the corresponding
1549
+ detector names.
1550
+ :type detector_indices: list[int]
1551
+ :return: The map's raw data, independent dimensions and scalar
1552
+ data
1553
+ :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
1554
+ """
1555
+ # Third party modules
1556
+ try:
1557
+ from mpi4py import MPI
1558
+ from mpi4py.util import dtlib
1559
+ except:
1560
+ pass
1561
+
1562
+ # Local modules
1563
+ from CHAP.utils.general import list_to_string
1564
+
1565
+ if comm is None:
1566
+ num_proc = 1
1567
+ rank = 0
1568
+ else:
1569
+ num_proc = comm.Get_size()
1570
+ rank = comm.Get_rank()
1571
+ if not rank:
1572
+ self.logger.debug(f'Number of processors: {num_proc}')
1573
+ self.logger.debug(f'Number of scans: {num_scan}')
1574
+
1575
+ # Create the shared data buffers
1576
+ # FIX: just one spec scan at this point
1577
+ assert len(map_config.spec_scans) == 1
1578
+ scan = map_config.spec_scans[0]
1579
+ scan_numbers = scan.scan_numbers
1580
+ scanparser = scan.get_scanparser(scan_numbers[0])
1581
+ ddata = scanparser.get_detector_data(detector_indices)
1582
+ spec_scan_shape = scanparser.spec_scan_shape
1583
+ num_dim = np.prod(spec_scan_shape)
1584
+ num_id = len(map_config.independent_dimensions)
1585
+ num_sd = len(map_config.all_scalar_data)
1586
+ if num_proc == 1:
1587
+ assert num_scan == len(scan_numbers)
1588
+ data = np.empty((num_scan, *ddata.shape), dtype=ddata.dtype)
1589
+ independent_dimensions = np.empty(
1590
+ (num_id, num_scan*num_dim), dtype=np.float64)
1591
+ all_scalar_data = np.empty(
1592
+ (num_sd, num_scan*num_dim), dtype=np.float64)
1593
+ else:
1594
+ self.logger.debug(f'Scan offset on processor {rank}: {offset}')
1595
+ self.logger.debug(f'Scan numbers on processor {rank}: '
1596
+ f'{list_to_string(scan_numbers)}')
1597
+ datatype = dtlib.from_numpy_dtype(ddata.dtype)
1598
+ itemsize = datatype.Get_size()
1599
+ if not rank:
1600
+ nbytes = num_scan * np.prod(ddata.shape) * itemsize
1601
+ else:
1602
+ nbytes = 0
1603
+ win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1604
+ buf, itemsize = win.Shared_query(0)
1605
+ assert itemsize == datatype.Get_size()
1606
+ data = np.ndarray(
1607
+ buffer=buf, dtype=ddata.dtype, shape=(num_scan, *ddata.shape))
1608
+ datatype = dtlib.from_numpy_dtype(np.float64)
1609
+ itemsize = datatype.Get_size()
1610
+ if not rank:
1611
+ nbytes = num_id * num_scan * num_dim * itemsize
1612
+ win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1613
+ buf_id, _ = win_id.Shared_query(0)
1614
+ independent_dimensions = np.ndarray(
1615
+ buffer=buf_id, dtype=np.float64,
1616
+ shape=(num_id, num_scan*num_dim))
1617
+ if not rank:
1618
+ nbytes = num_sd * num_scan * num_dim * itemsize
1619
+ win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1620
+ buf_sd, _ = win_sd.Shared_query(0)
1621
+ all_scalar_data = np.ndarray(
1622
+ buffer=buf_sd, dtype=np.float64,
1623
+ shape=(num_sd, num_scan*num_dim))
1624
+
1625
+ # Read the raw data
1626
+ init = True
1627
+ for scan in map_config.spec_scans:
1628
+ for scan_number in scan.scan_numbers:
1629
+ if init:
1630
+ init = False
1631
+ else:
1632
+ scanparser = scan.get_scanparser(scan_number)
1633
+ assert spec_scan_shape == scanparser.spec_scan_shape
1634
+ ddata = scanparser.get_detector_data(detector_indices)
1635
+ data[offset] = ddata
1636
+ spec_scan_motor_mnes = scanparser.spec_scan_motor_mnes
1637
+ start_dim = offset * num_dim
1638
+ end_dim = start_dim + num_dim
1639
+ if len(spec_scan_shape) == 1:
1640
+ for i, dim in enumerate(map_config.independent_dimensions):
1641
+ v = dim.get_value(
1642
+ scan, scan_number, scan_step_index=-1,
1643
+ relative=False)
1644
+ if dim.name in spec_scan_motor_mnes:
1645
+ independent_dimensions[i][start_dim:end_dim] = v
1646
+ else:
1647
+ independent_dimensions[i][start_dim:end_dim] = \
1648
+ np.repeat(v, spec_scan_shape[0])
1649
+ for i, dim in enumerate(map_config.all_scalar_data):
1650
+ v = dim.get_value(
1651
+ scan, scan_number, scan_step_index=-1,
1652
+ relative=False)
1653
+ #if dim.name in spec_scan_motor_mnes:
1654
+ if dim.data_type == 'scan_column':
1655
+ all_scalar_data[i][start_dim:end_dim] = v
1656
+ else:
1657
+ all_scalar_data[i][start_dim:end_dim] = \
1658
+ np.repeat(v, spec_scan_shape[0])
1659
+ else:
1660
+ for i, dim in enumerate(map_config.independent_dimensions):
1661
+ v = dim.get_value(
1662
+ scan, scan_number, scan_step_index=-1,
1663
+ relative=False)
1664
+ if dim.name == spec_scan_motor_mnes[0]:
1665
+ # Fast motor
1666
+ independent_dimensions[i][start_dim:end_dim] = \
1667
+ np.concatenate((v,)*spec_scan_shape[1])
1668
+ elif dim.name == spec_scan_motor_mnes[1]:
1669
+ # Slow motor
1670
+ independent_dimensions[i][start_dim:end_dim] = \
1671
+ np.repeat(v, spec_scan_shape[0])
1672
+ else:
1673
+ independent_dimensions[i][start_dim:end_dim] = v
1674
+ for i, dim in enumerate(map_config.all_scalar_data):
1675
+ v = dim.get_value(
1676
+ scan, scan_number, scan_step_index=-1,
1677
+ relative=False)
1678
+ if dim.data_type == 'scan_column':
1679
+ all_scalar_data[i][start_dim:end_dim] = v
1680
+ elif dim.data_type == 'smb_par':
1681
+ if dim.name == spec_scan_motor_mnes[0]:
1682
+ # Fast motor
1683
+ all_scalar_data[i][start_dim:end_dim] = \
1684
+ np.concatenate((v,)*spec_scan_shape[1])
1685
+ elif dim.name == spec_scan_motor_mnes[1]:
1686
+ # Slow motor
1687
+ all_scalar_data[i][start_dim:end_dim] = \
1688
+ np.repeat(v, spec_scan_shape[0])
1689
+ else:
1690
+ all_scalar_data[i][start_dim:end_dim] = v
1691
+ else:
1692
+ raise RuntimeError(
1693
+ f'{dim.data_type} in data_type not tested')
1694
+ offset += 1
1695
+
1696
+ return (
1697
+ data.reshape((np.prod(data.shape[:2]), *data.shape[2:])),
1698
+ independent_dimensions, all_scalar_data)
1699
+
1700
+ def _read_raw_data(
1701
+ self, map_config, detector_names, comm, num_scan, offset):
1702
+ """Read the raw data for a given map configuration.
1703
+
1704
+ :param map_config: A valid map configuration.
1705
+ :type map_config: common.models.map.MapConfig
1706
+ :param detector_names: Detector names to include raw data
1707
+ for in the returned NeXus NXentry object,
1708
+ defaults to `None`.
1709
+ :type detector_names: list[str]
1710
+ :return: The map's raw data, independent dimensions and scalar
1711
+ data
1712
+ :rtype: numpy.ndarray, numpy.ndarray, numpy.ndarray
1713
+ """
1714
+ # Third party modules
1715
+ try:
1716
+ from mpi4py import MPI
1717
+ from mpi4py.util import dtlib
1718
+ except:
1719
+ pass
1720
+
1721
+ # Local modules
1722
+ from CHAP.utils.general import list_to_string
1723
+
1724
+ if comm is None:
1725
+ num_proc = 1
1726
+ rank = 0
1727
+ else:
1728
+ num_proc = comm.Get_size()
1729
+ rank = comm.Get_rank()
1730
+ if not rank:
1731
+ self.logger.debug(f'Number of processors: {num_proc}')
1732
+ self.logger.debug(f'Number of scans: {num_scan}')
1733
+
1734
+ # Create the shared data buffers
1735
+ # FIX: just one spec scan and one detector at this point
1736
+ assert len(map_config.spec_scans) == 1
1737
+ assert len(detector_names) == 1
1738
+ scans = map_config.spec_scans[0]
1739
+ scan_numbers = scans.scan_numbers
1740
+ scanparser = scans.get_scanparser(scan_numbers[0])
1741
+ ddata = scanparser.get_detector_data(detector_names[0])
1742
+ num_dim = ddata.shape[0]
1743
+ num_id = len(map_config.independent_dimensions)
1744
+ num_sd = len(map_config.all_scalar_data)
1745
+ if not num_sd:
1746
+ all_scalar_data = None
1747
+ if num_proc == 1:
1748
+ assert num_scan == len(scan_numbers)
1749
+ data = np.empty((num_scan, *ddata.shape), dtype=ddata.dtype)
1750
+ independent_dimensions = np.empty(
1751
+ (num_scan, num_id, num_dim), dtype=np.float64)
1752
+ if num_sd:
1753
+ all_scalar_data = np.empty(
1754
+ (num_scan, num_sd, num_dim), dtype=np.float64)
1755
+ else:
1756
+ self.logger.debug(f'Scan offset on processor {rank}: {offset}')
1757
+ self.logger.debug(f'Scan numbers on processor {rank}: '
1758
+ f'{list_to_string(scan_numbers)}')
1759
+ datatype = dtlib.from_numpy_dtype(ddata.dtype)
1760
+ itemsize = datatype.Get_size()
1761
+ if not rank:
1762
+ nbytes = num_scan * np.prod(ddata.shape) * itemsize
1763
+ else:
1764
+ nbytes = 0
1765
+ win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1766
+ buf, _ = win.Shared_query(0)
1767
+ data = np.ndarray(
1768
+ buffer=buf, dtype=ddata.dtype, shape=(num_scan, *ddata.shape))
1769
+ datatype = dtlib.from_numpy_dtype(np.float64)
1770
+ itemsize = datatype.Get_size()
1771
+ if not rank:
1772
+ nbytes = num_scan * num_id * num_dim * itemsize
1773
+ else:
1774
+ nbytes = 0
1775
+ win_id = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1776
+ buf_id, _ = win_id.Shared_query(0)
1777
+ independent_dimensions = np.ndarray(
1778
+ buffer=buf_id, dtype=np.float64,
1779
+ shape=(num_scan, num_id, num_dim))
1780
+ if num_sd:
1781
+ if not rank:
1782
+ nbytes = num_scan * num_sd * num_dim * itemsize
1783
+ win_sd = MPI.Win.Allocate_shared(nbytes, itemsize, comm=comm)
1784
+ buf_sd, _ = win_sd.Shared_query(0)
1785
+ all_scalar_data = np.ndarray(
1786
+ buffer=buf_sd, dtype=np.float64,
1787
+ shape=(num_scan, num_sd, num_dim))
1788
+
1789
+ # Read the raw data
1790
+ init = True
1791
+ for scans in map_config.spec_scans:
1792
+ for scan_number in scans.scan_numbers:
1793
+ if init:
1794
+ init = False
1795
+ else:
1796
+ scanparser = scans.get_scanparser(scan_number)
1797
+ ddata = scanparser.get_detector_data(detector_names[0])
1798
+ data[offset] = ddata
1799
+ for i, dim in enumerate(map_config.independent_dimensions):
1800
+ if dim.data_type == 'scan_column':
1801
+ independent_dimensions[offset,i] = dim.get_value(
1802
+ #v = dim.get_value(
1803
+ scans, scan_number, scan_step_index=-1,
1804
+ relative=False)[:num_dim]
1805
+ #print(f'\ndim: {dim}\nv {np.asarray(v).shape}: {v}')
1806
+ #independent_dimensions[offset,i] = v[:num_dim]
1807
+ elif dim.data_type in ['smb_par', 'spec_motor']:
1808
+ independent_dimensions[offset,i] = dim.get_value(
1809
+ #v = dim.get_value(
1810
+ scans, scan_number, scan_step_index=-1,
1811
+ relative=False)
1812
+ #print(f'\ndim: {dim}\nv {np.asarray(v).shape}: {v}')
1813
+ #independent_dimensions[offset,i] = v
1814
+ else:
1815
+ raise RuntimeError(
1816
+ f'{dim.data_type} in data_type not tested')
1817
+ for i, dim in enumerate(map_config.all_scalar_data):
1818
+ all_scalar_data[offset,i] = dim.get_value(
1819
+ scans, scan_number, scan_step_index=-1,
1820
+ relative=False)
1821
+ offset += 1
1822
+
1823
+ if num_sd:
1824
+ return (
1825
+ data.reshape((1, np.prod(data.shape[:2]), *data.shape[2:])),
1826
+ np.stack(tuple([independent_dimensions[:,i].flatten()
1827
+ for i in range(num_id)])),
1828
+ np.stack(tuple([all_scalar_data[:,i].flatten()
1829
+ for i in range(num_sd)])))
1830
+ return (
1831
+ data.reshape((1, np.prod(data.shape[:2]), *data.shape[2:])),
1832
+ np.stack(tuple([independent_dimensions[:,i].flatten()
1833
+ for i in range(num_id)])),
1834
+ all_scalar_data)
1835
+
1836
+
1837
+ class MPITestProcessor(Processor):
1838
+ """A test MPI Processor.
1839
+ """
1840
+ def process(self, data, sub_pipeline={}):
1841
+ # Third party modules
1842
+ import mpi4py as mpi4py
1843
+ from mpi4py import MPI
1844
+
1845
+ my_rank = MPI.COMM_WORLD.Get_rank()
1846
+ size = MPI.COMM_WORLD.Get_size()
1847
+ (version, subversion) = MPI.Get_version()
1848
+
1849
+ mpi4py_version = mpi4py.__version__
1850
+
1851
+ if (my_rank == 0):
1852
+ if (size > 1):
1853
+ print('Successful first MPI test executed in parallel on '
1854
+ f'{size} processes using mpi4py version '
1855
+ f'{mpi4py_version}.')
1856
+ if int(mpi4py_version[0]) < 3:
1857
+ print('CAUTION: You are using an mpi4py version '
1858
+ 'below 3.0.0.')
1859
+ else:
1860
+ print('CAUTION: This MPI test is executed only on one MPI '
1861
+ 'process, i.e., sequentially!')
1862
+ print('Your installation supports MPI standard version '
1863
+ f'{version}.{subversion}.')
1864
+ print(f'Finished on processor {my_rank} of {size}')
1865
+
1866
+
1867
+ class MPICollectProcessor(Processor):
1868
+ """A Processor that collects the distributed worker data from
1869
+ MPIMapProcessor on the root node
1870
+ """
1871
+ def process(self, data, comm, root_as_worker=True):
1872
+ # Third party modules
1873
+ from mpi4py import MPI
1874
+
1875
+ num_proc = comm.Get_size()
1876
+ rank = comm.Get_rank()
1877
+ if root_as_worker:
1878
+ data = self.unwrap_pipelinedata(data)[-1]
1879
+ if num_proc > 1:
1880
+ data = comm.gather(data, root=0)
1881
+ else:
1882
+ for n_worker in range(1, num_proc):
1883
+ if rank == n_worker:
1884
+ comm.send(self.unwrap_pipelinedata(data)[-1], dest=0)
1885
+ data = None
1886
+ elif not rank:
1887
+ if n_worker == 1:
1888
+ data = [comm.recv(source=n_worker)]
1889
+ else:
1890
+ data.append(comm.recv(source=n_worker))
1891
+ return data
1892
+
1893
+
1894
+ class MPIMapProcessor(Processor):
1895
+ """A Processor that applies a parallel generic sub-pipeline to
1896
+ a map configuration.
1897
+ """
1898
+ def process(self, data, sub_pipeline={}):
1899
+ # System modules
1900
+ from copy import deepcopy
1901
+
1902
+ # Third party modules
1903
+ from mpi4py import MPI
1904
+
1905
+ # Local modules
1906
+ from CHAP.runner import (
1907
+ RunConfig,
1908
+ run,
1909
+ )
1910
+ from CHAP.common.models.map import (
1911
+ SpecScans,
1912
+ SpecConfig,
1913
+ )
1914
+
1915
+ comm = MPI.COMM_WORLD
1916
+ num_proc = comm.Get_size()
1917
+ rank = comm.Get_rank()
1918
+
1919
+ # Get the map configuration from data
1920
+ map_config = self.get_config(
1921
+ data, 'common.models.map.MapConfig')
1922
+
1923
+ # Create the spec reader configuration for each processor
1924
+ spec_scans = map_config.spec_scans[0]
1925
+ scan_numbers = spec_scans.scan_numbers
1926
+ num_scan = len(scan_numbers)
1927
+ scans_per_proc = num_scan//num_proc
1928
+ n_scan = 0
1929
+ for n_proc in range(num_proc):
1930
+ num = scans_per_proc
1931
+ if n_proc == rank:
1932
+ if rank < num_scan - scans_per_proc*num_proc:
1933
+ num += 1
1934
+ scan_numbers = scan_numbers[n_scan:n_scan+num]
1935
+ n_scan += num
1936
+ spec_config = {
1937
+ 'station': map_config.station,
1938
+ 'experiment_type': map_config.experiment_type,
1939
+ 'spec_scans': [SpecScans(
1940
+ spec_file=spec_scans.spec_file, scan_numbers=scan_numbers)]}
1941
+
1942
+ # Get the run configuration to use for the sub-pipeline
1943
+ run_config = RunConfig(sub_pipeline.get('config', {}), comm)
1944
+ pipeline_config = []
1945
+ for item in sub_pipeline['pipeline']:
1946
+ if isinstance(item, dict):
1947
+ for k, v in deepcopy(item).items():
1948
+ if k.endswith('Reader'):
1949
+ v['config'] = spec_config
1950
+ item[k] = v
1951
+ if num_proc > 1 and k.endswith('Writer'):
1952
+ r, e = os.path.splitext(v['filename'])
1953
+ v['filename'] = f'{r}_{rank}{e}'
1954
+ item[k] = v
1955
+ pipeline_config.append(item)
1956
+
1957
+ # Run the sub-pipeline on each processor
1958
+ return run(
1959
+ pipeline_config, inputdir=run_config.inputdir,
1960
+ outputdir=run_config.outputdir,
1961
+ interactive=run_config.interactive, comm=comm)
1962
+
1963
+
1964
+ class MPISpawnMapProcessor(Processor):
1965
+ """A Processor that applies a parallel generic sub-pipeline to
1966
+ a map configuration by spawning workers processes.
1967
+ """
1968
+ def process(
1969
+ self, data, num_proc=1, root_as_worker=True, collect_on_root=True,
1970
+ sub_pipeline={}):
1971
+ # System modules
1972
+ from copy import deepcopy
1973
+ from tempfile import NamedTemporaryFile
1974
+
1975
+ # Third party modules
1976
+ try:
1977
+ from mpi4py import MPI
1978
+ except:
1979
+ raise ImportError('Unable to import mpi4py')
1980
+ import yaml
1981
+
1982
+ # Local modules
1983
+ from CHAP.runner import (
1984
+ RunConfig,
1985
+ runner,
1986
+ )
1987
+ from CHAP.common.models.map import (
1988
+ SpecScans,
1989
+ SpecConfig,
1990
+ )
1991
+
1992
+ # Get the map configuration from data
1993
+ map_config = self.get_config(
1994
+ data, 'common.models.map.MapConfig')
1995
+
1996
+ # Get the run configuration to use for the sub-pipeline
1997
+ run_config = RunConfig(config=sub_pipeline.get('config', {}))
1998
+
1999
+ # Create the sub-pipeline configuration for each processor
2000
+ spec_scans = map_config.spec_scans[0]
2001
+ scan_numbers = spec_scans.scan_numbers
2002
+ num_scan = len(scan_numbers)
2003
+ scans_per_proc = num_scan//num_proc
2004
+ n_scan = 0
2005
+ pipeline_config = []
2006
+ for n_proc in range(num_proc):
2007
+ num = scans_per_proc
2008
+ if n_proc < num_scan - scans_per_proc*num_proc:
2009
+ num += 1
2010
+ spec_config = {
2011
+ 'station': map_config.station,
2012
+ 'experiment_type': map_config.experiment_type,
2013
+ 'spec_scans': [SpecScans(
2014
+ spec_file=spec_scans.spec_file,
2015
+ scan_numbers=scan_numbers[n_scan:n_scan+num]).__dict__]}
2016
+ sub_pipeline_config = []
2017
+ for item in deepcopy(sub_pipeline['pipeline']):
2018
+ if isinstance(item, dict):
2019
+ for k, v in deepcopy(item).items():
2020
+ if k.endswith('Reader'):
2021
+ v['config'] = spec_config
2022
+ item[k] = v
2023
+ if num_proc > 1 and k.endswith('Writer'):
2024
+ r, e = os.path.splitext(v['filename'])
2025
+ v['filename'] = f'{r}_{n_proc}{e}'
2026
+ item[k] = v
2027
+ sub_pipeline_config.append(item)
2028
+ if collect_on_root and (not root_as_worker or num_proc > 1):
2029
+ sub_pipeline_config += [
2030
+ {'common.MPICollectProcessor': {
2031
+ 'root_as_worker': root_as_worker}}]
2032
+ pipeline_config.append(sub_pipeline_config)
2033
+ n_scan += num
2034
+
2035
+ # Optionally include the root node as a worker node
2036
+ if root_as_worker:
2037
+ first_proc = 1
2038
+ run_config.spawn = 1
2039
+ else:
2040
+ first_proc = 0
2041
+ run_config.spawn = -1
2042
+
2043
+ # Spawn the workers to run the sub-pipeline
2044
+ if num_proc > first_proc:
2045
+ tmp_names = []
2046
+ with NamedTemporaryFile(delete=False) as fp:
2047
+ fp_name = fp.name
2048
+ tmp_names.append(fp_name)
2049
+ with open(fp_name, 'w') as f:
2050
+ yaml.dump(
2051
+ {'config': {'spawn': run_config.spawn}}, f,
2052
+ sort_keys=False)
2053
+ for n_proc in range(first_proc, num_proc):
2054
+ f_name = f'{fp_name}_{n_proc}'
2055
+ tmp_names.append(f_name)
2056
+ with open(f_name, 'w') as f:
2057
+ yaml.dump(
2058
+ {'config': run_config.__dict__,
2059
+ 'pipeline': pipeline_config[n_proc]},
2060
+ f, sort_keys=False)
2061
+ sub_comm = MPI.COMM_SELF.Spawn(
2062
+ 'CHAP', args=[fp_name], maxprocs=num_proc-first_proc)
2063
+ common_comm = sub_comm.Merge(False)
2064
+ if run_config.spawn > 0:
2065
+ # Align with the barrier in RunConfig() on common_comm
2066
+ # called from the spawned main()
2067
+ common_comm.barrier()
2068
+ else:
2069
+ common_comm = None
2070
+
2071
+ # Run the sub-pipeline on the root node
2072
+ if root_as_worker:
2073
+ data = runner(run_config, pipeline_config[0], common_comm)
2074
+ elif collect_on_root:
2075
+ run_config.spawn = 0
2076
+ pipeline_config = [{'common.MPICollectProcessor': {
2077
+ 'root_as_worker': root_as_worker}}]
2078
+ data = runner(run_config, pipeline_config, common_comm)
2079
+ else:
2080
+ # Align with the barrier in run() on common_comm
2081
+ # called from the spawned main()
2082
+ common_comm.barrier()
2083
+ data = None
2084
+
2085
+ # Disconnect spawned workers and cleanup temporary files
2086
+ if num_proc > first_proc:
2087
+ common_comm.barrier()
2088
+ sub_comm.Disconnect()
2089
+ for tmp_name in tmp_names:
2090
+ os.remove(tmp_name)
2091
+
2092
+ return data
2093
+
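The scan-partitioning rule applied above by `MapProcessor`, `MPIMapProcessor` and `MPISpawnMapProcessor` can be summarized in isolation: the scans are split into contiguous blocks of `num_scan // num_proc`, with the first `num_scan % num_proc` workers each taking one extra scan. The helper below is a stand-alone illustration of that arithmetic; `partition_scans` is a hypothetical name, not a CHAP function.

```python
# Stand-alone illustration of the scan-partitioning arithmetic used above.
def partition_scans(scan_numbers, num_proc):
    """Return one contiguous block of scan numbers per worker."""
    num_scan = len(scan_numbers)
    scans_per_proc = num_scan // num_proc
    remainder = num_scan - scans_per_proc * num_proc
    blocks, start = [], 0
    for n_proc in range(num_proc):
        num = scans_per_proc + (1 if n_proc < remainder else 0)
        blocks.append(scan_numbers[start:start + num])
        start += num
    return blocks

print(partition_scans(list(range(1, 11)), 3))
# -> [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]
```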
1049
2094
 
1050
2095
  class NexusToNumpyProcessor(Processor):
1051
2096
  """A Processor to convert the default plottable data in a NeXus
@@ -1162,7 +2207,7 @@ class PrintProcessor(Processor):
1162
2207
  """
1163
2208
  print(f'{self.__name__} data :')
1164
2209
  if callable(getattr(data, '_str_tree', None)):
1165
- # If data is likely an NXobject, print its tree
2210
+ # If data is likely a NeXus NXobject, print its tree
1166
2211
  # representation (since NXobjects' str representations are
1167
2212
  # just their nxname)
1168
2213
  print(data._str_tree(attrs=True, recursive=True))
@@ -1172,6 +2217,67 @@ class PrintProcessor(Processor):
1172
2217
  return data
1173
2218
 
1174
2219
 
2220
+ class PyfaiAzimuthalIntegrationProcessor(Processor):
2221
+ """Processor to azimuthally integrate one or more frames of 2d
2222
+ detector data using the
2223
+ [pyFAI](https://pyfai.readthedocs.io/en/v2023.1/index.html)
2224
+ package.
2225
+ """
2226
+ def process(self, data, poni_file, npt, mask_file=None,
2227
+ integrate1d_kwargs=None, inputdir='.'):
2228
+ """Azimuthally integrate the detector data provided and return
2229
+ the result as a dictionary of numpy arrays containing the
2230
+ values of the radial coordinate of the result, the intensities
2231
+ along the radial direction, and the poisson errors for each
2232
+ intensity spectrum.
2233
+
2234
+ :param data: Detector data to integrate.
2235
+ :type data: Union[PipelineData, list[np.ndarray]]
2236
+ :param poni_file: Name of the [pyFAI PONI
2237
+ file](https://pyfai.readthedocs.io/en/v2023.1/glossary.html?highlight=poni%20file#poni-file)
2238
+ containing the detector properties pyFAI needs to perform
2239
+ azimuthal integration.
2240
+ :type poni_file: str
2241
+ :param npt: Number of points in the output pattern.
2242
+ :type npt: int
2243
+ :param mask_file: A file to use for masking the input data.
2244
+ :type mask_file: str
2245
+ :param integrate1d_kwargs: Optional dictionary of keyword
2246
+ arguments to use with
2247
+ [`pyFAI.azimuthalIntegrator.AzimuthalIntegrator.integrate1d`](https://pyfai.readthedocs.io/en/v2023.1/api/pyFAI.html#pyFAI.azimuthalIntegrator.AzimuthalIntegrator.integrate1d). Defaults
2248
+ to `None`.
2249
+ :type integrate1d_kwargs: Optional[dict]
2250
+ :returns: Azimuthal integration results as a dictionary of
2251
+ numpy arrays.
2252
+ """
2253
+ # Third party modules
2254
+ from pyFAI import load
2255
+
2256
+ if not os.path.isabs(poni_file):
2257
+ poni_file = os.path.join(inputdir, poni_file)
2258
+ ai = load(poni_file)
2259
+
2260
+ if mask_file is None:
2261
+ mask = None
2262
+ else:
2263
+ # Third party modules
2264
+ import fabio
2265
+ if not os.path.isabs(mask_file):
2266
+ mask_file = os.path.join(inputdir, mask_file)
2267
+ mask = fabio.open(mask_file).data
2268
+
2269
+ try:
2270
+ det_data = self.unwrap_pipelinedata(data)[0]
2271
+ except Exception:
2272
+ # Fall back to the raw input if it is not wrapped as PipelineData
+ det_data = data
2273
+
2274
+ if integrate1d_kwargs is None:
2275
+ integrate1d_kwargs = {}
2276
+ integrate1d_kwargs['mask'] = mask
2277
+
2278
+ return [ai.integrate1d(d, npt, **integrate1d_kwargs) for d in det_data]
2279
+
2280
+
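For orientation, the core pyFAI call this processor wraps looks like the sketch below; the PONI file name is a placeholder and the frame is random data standing in for a detector image whose shape matches the detector described by the PONI file.

```python
# Hedged sketch of a single azimuthal integration with pyFAI;
# 'detector.poni' is a placeholder and the frame is synthetic data.
import numpy as np
from pyFAI import load

ai = load('detector.poni')             # AzimuthalIntegrator from the PONI file
frame = np.random.random((487, 407))   # stand-in for one detector frame

# integrate1d returns a result holding the radial coordinate, the
# intensity profile, and (with an error model) the per-bin errors.
result = ai.integrate1d(frame, 1000, error_model='poisson', unit='2th_deg')
radial, intensity, sigma = result.radial, result.intensity, result.sigma
```

Passing `mask=...` through `integrate1d_kwargs`, as the processor does, excludes the masked pixels from every bin.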
1175
2281
  class RawDetectorDataMapProcessor(Processor):
1176
2282
  """A Processor to return a map of raw derector data in a
1177
2283
  NeXus NXroot object.
@@ -1200,13 +2306,14 @@ class RawDetectorDataMapProcessor(Processor):
1200
2306
  `Processor`.
1201
2307
 
1202
2308
  :param data: Result of `Reader.read` where at least one item
1203
- has the value `'MapConfig'` for the `'schema'` key.
2309
+ has the value `'common.models.map.MapConfig'` for the
2310
+ `'schema'` key.
1204
2311
  :type data: list[PipelineData]
1205
2312
  :raises Exception: If a valid map config object cannot be
1206
2313
  constructed from `data`.
1207
2314
  :return: A valid instance of the map configuration object with
1208
2315
  field values taken from `data`.
1209
- :rtype: MapConfig
2316
+ :rtype: common.models.map.MapConfig
1210
2317
  """
1211
2318
  # Local modules
1212
2319
  from CHAP.common.models.map import MapConfig
@@ -1216,7 +2323,7 @@ class RawDetectorDataMapProcessor(Processor):
1216
2323
  for item in data:
1217
2324
  if isinstance(item, dict):
1218
2325
  schema = item.get('schema')
1219
- if schema == 'MapConfig':
2326
+ if schema == 'common.models.map.MapConfig':
1220
2327
  map_config = item.get('data')
1221
2328
 
1222
2329
  if not map_config:
@@ -1230,7 +2337,7 @@ class RawDetectorDataMapProcessor(Processor):
1230
2337
  relevant metadata in the form of a NeXus structure.
1231
2338
 
1232
2339
  :param map_config: The map configuration.
1233
- :type map_config: MapConfig
2340
+ :type map_config: common.models.map.MapConfig
1234
2341
  :param detector_name: The detector prefix.
1235
2342
  :type detector_name: str
1236
2343
  :param detector_shape: The shape of detector data for a single
@@ -1344,6 +2451,504 @@ class StrainAnalysisProcessor(Processor):
1344
2451
  return strain_analysis_config
1345
2452
 
1346
2453
 
2454
+ class SetupNXdataProcessor(Processor):
2455
+ """Processor to set up and return an "empty" NeXus representation
2456
+ of a structured dataset. This representation will be an instance
2457
+ of a NeXus NXdata object that has:
2458
+ 1. A NeXus NXfield entry for every coordinate/signal specified.
2459
+ 1. `nxaxes` that are the NeXus NXfield entries for the coordinates
2460
+ and contain the values provided for each coordinate.
2461
+ 1. NeXus NXfield entries of appropriate shape, but containing all
2462
+ zeros, for every signal.
2463
+ 1. Attributes that define the axes, plus any additional attributes
2464
+ specified by the user.
2465
+
2466
+ This `Processor` is most useful as a "setup" step for
2467
+ constucting a representation of / container for a complete dataset
2468
+ that will be filled out in pieces later by
2469
+ `UpdateNXdataProcessor`.
2470
+
2471
+ Examples of use in a `Pipeline` configuration:
2472
+ - With inputs from a previous `PipelineItem` specifically written
2473
+ to provide inputs to this `Processor`:
2474
+ ```yaml
2475
+ config:
2476
+ inputdir: /rawdata/samplename
2477
+ outputdir: /reduceddata/samplename
2478
+ pipeline:
2479
+ - edd.SetupNXdataReader:
2480
+ filename: SpecInput.txt
2481
+ dataset_id: 1
2482
+ - common.SetupNXdataProcessor:
2483
+ nxname: samplename_dataset_1
2484
+ - common.NexusWriter:
2485
+ filename: data.nxs
2486
+ ```
2487
+ - With inputs provided directly though the optional arguments:
2488
+ ```yaml
2489
+ config:
2490
+ outputdir: /reduceddata/samplename
2491
+ pipeline:
2492
+ - common.SetupNXdataProcessor:
2493
+ nxname: your_dataset_name
2494
+ coords:
2495
+ - name: x
2496
+ values: [0.0, 0.5, 1.0]
2497
+ attrs:
2498
+ units: mm
2499
+ yourkey: yourvalue
2500
+ - name: temperature
2501
+ values: [200, 250, 275]
2502
+ attrs:
2503
+ units: Celsius
2504
+ yourotherkey: yourothervalue
2505
+ signals:
2506
+ - name: raw_detector_data
2507
+ shape: [407, 487]
2508
+ attrs:
2509
+ local_name: PIL11
2510
+ foo: bar
2511
+ - name: presample_intensity
2512
+ shape: []
2513
+ attrs:
2514
+ local_name: a3ic0
2515
+ zebra: fish
2516
+ attrs:
2517
+ arbitrary: metadata
2518
+ from: users
2519
+ goes: here
2520
+ - common.NexusWriter:
2521
+ filename: data.nxs
2522
+ ```
2523
+ """
2524
+ def process(self, data, nxname='data',
2525
+ coords=[], signals=[], attrs={}, data_points=[],
2526
+ extra_nxfields=[], duplicates='overwrite'):
2527
+ """Return a NeXus NXdata object that has the requisite axes
2528
+ and NeXus NXfield entries to represent a structured dataset
2529
+ with the properties provided. Properties may be provided either
2530
+ through the `data` argument (from an appropriate `PipelineItem`
2531
+ that immediately precedes this one in a `Pipeline`), or through
2532
+ the `coords`, `signals`, `attrs`, and/or `data_points`
2533
+ arguments. If any of the latter are used, their values will
2534
+ completely override any values for these parameters found from
2535
+ `data`.
2536
+
2537
+ :param data: Data from the previous item in a `Pipeline`.
2538
+ :type data: list[PipelineData]
2539
+ :param nxname: Name for the returned NeXus NXdata object.
2540
+ Defaults to `'data'`.
2541
+ :type nxname: str, optional
2542
+ :param coords: List of dictionaries defining the coordinates
2543
+ of the dataset. Each dictionary must have the keys
2544
+ `'name'` and `'values'`, whose values are the name of the
2545
+ coordinate axis (a string) and all the unique values of
2546
+ that coordinate for the structured dataset (a list of
2547
+ numbers), respectively. A third item in the dictionary is
2548
+ optional, but highly recommended: `'attrs'` may provide a
2549
+ dictionary of attributes to attach to the coordinate axis
2550
+ that assist in in interpreting the returned NeXus NXdata
2551
+ representation of the dataset. It is strongly recommended
2552
+ to provide the units of the values along an axis in the
2553
+ `attrs` dictionary. Defaults to [].
2554
+ :type coords: list[dict[str, object]], optional
2555
+ :param signals: List of dictionaries defining the signals of
2556
+ the dataset. Each dictionary must have the keys `'name'`
2557
+ and `'shape'`, whose values are the name of the signal
2558
+ field (a string) and the shape of the signal's value at
2559
+ each point in the dataset (a list of zero or more
2560
+ integers), respectively. A third item in the dictionary is
2561
+ optional, but highly recommended: `'attrs'` may provide a
2562
+ dictionary of attributes to attach to the signal field that
2563
+ assist in interpreting the returned NeXus NXdata
2564
+ representation of the dataset. It is strongly recommended
2565
+ to provide the units of the signal's values in the `attrs`
2566
+ dictionary. Defaults to [].
2567
+ :type signals: list[dict[str, object]], optional
2568
+ :param attrs: An arbitrary dictionary of attributes to assign
2569
+ to the returned NeXus NXdata object. Defaults to {}.
2570
+ :type attrs: dict[str, object], optional
2571
+ :param data_points: A list of data points to partially (or
2572
+ even entirely) fill out the "empty" signal NeXus NXfields
2573
+ before returning the NeXus NXdata object. Defaults to [].
2574
+ :type data_points: list[dict[str, object]], optional
2575
+ :param extra_nxfields: List "extra" NeXus NXfield's to include that
2576
+ can be described neither as a signal of the dataset, not a
2577
+ dedicated coordinate. This paramteter is good for
2578
+ including "alternate" values for one of the coordinate
2579
+ dimensions -- the same coordinate axis expressed in
2580
+ different units, for instance. Each item in the list
2581
+ shoulde be a dictionary of parameters for the
2582
+ `nexusformat.nexus.NXfield` constructor. Defaults to `[]`.
2583
+ :type extra_nxfields: list[dict[str, object]], optional
2584
+ :param duplicates: Behavior to use if any new data points occur
2585
+ at the same point in the dataset's coordinate space as an
2586
+ existing data point. Allowed values for `duplicates` are:
2587
+ `'overwrite'` and `'block'`. Defaults to `'overwrite'`.
2588
+ :type duplicates: Literal['overwrite', 'block']
2589
+ :returns: A NeXus NXdata object that represents the structured
2590
+ dataset as specified.
2591
+ :rtype: nexusformat.nexus.NXdata
2592
+ """
2593
+ self.nxname = nxname
2594
+
2595
+ self.coords = coords
2596
+ self.signals = signals
2597
+ self.attrs = attrs
2598
+ try:
2599
+ setup_params = self.unwrap_pipelinedata(data)[0]
2600
+ except Exception:
2601
+ setup_params = None
2602
+ if isinstance(setup_params, dict):
2603
+ for a in ('coords', 'signals', 'attrs'):
2604
+ setup_param = setup_params.get(a)
2605
+ if not getattr(self, a) and setup_param:
2606
+ self.logger.info(f'Using input data from pipeline for {a}')
2607
+ setattr(self, a, setup_param)
2608
+ else:
2609
+ self.logger.info(
2610
+ f'Ignoring input data from pipeline for {a}')
2611
+ else:
2612
+ self.logger.warning('Ignoring all input data from pipeline')
2613
+
2614
+ self.shape = tuple(len(c['values']) for c in self.coords)
2615
+
2616
+ self.extra_nxfields = extra_nxfields
2617
+ self._data_points = []
2618
+ self.duplicates = duplicates
2619
+ self.init_nxdata()
2620
+ for d in data_points:
2621
+ self.add_data_point(d)
2622
+
2623
+ return self.nxdata
2624
+
2625
+ def add_data_point(self, data_point):
2626
+ """Add a data point to this dataset.
2627
+ 1. Validate `data_point`.
2628
+ 2. Append `data_point` to `self._data_points`.
2629
+ 3. Update signal `NXfield`s in `self.nxdata`.
2630
+
2631
+ :param data_point: Data point defining a point in the
2632
+ dataset's coordinate space and the new signal values at
2633
+ that point.
2634
+ :type data_point: dict[str, object]
2635
+ :returns: None
2636
+ """
2637
+ self.logger.info(f'Adding data point no. {len(self._data_points)}')
2638
+ self.logger.debug(f'New data point: {data_point}')
2639
+ valid, msg = self.validate_data_point(data_point)
2640
+ if not valid:
2641
+ self.logger.error(f'Cannot add data point: {msg}')
2642
+ else:
2643
+ self._data_points.append(data_point)
2644
+ self.update_nxdata(data_point)
2645
+
2646
+ def validate_data_point(self, data_point):
2647
+ """Return `True` if `data_point` occurs at a valid point in
2648
+ this structured dataset's coordinate space, `False`
2649
+ otherwise. Also validate shapes of signal values and add NaN
2650
+ values for any missing signals.
2651
+
2652
+ :param data_point: Data point defining a point in the
2653
+ dataset's coordinate space and the new signal values at
2654
+ that point.
2655
+ :type data_point: dict[str, object]
2656
+ :returns: Validity of `data_point`, message
2657
+ :rtype: bool, str
2658
+ """
2659
+ # Third party modules
2660
+ import numpy as np
2661
+
2662
+ valid = True
2663
+ msg = ''
2664
+ # Convert all values to numpy types
2665
+ data_point = {k: np.asarray(v) for k, v in data_point.items()}
2666
+ # Ensure data_point defines a specific point in the dataset's
2667
+ # coordinate space
2668
+ if not all(c['name'] in data_point for c in self.coords):
2669
+ valid = False
2670
+ msg = 'Missing coordinate values'
2671
+ # Find & handle any duplicates
2672
+ for i, d in enumerate(self._data_points):
2673
+ is_duplicate = all(data_point[c['name']] == d[c['name']]
+ for c in self.coords)
2674
+ if is_duplicate:
2675
+ if self.duplicates == 'overwrite':
2676
+ self._data_points.pop(i)
2677
+ elif self.duplicates == 'block':
2678
+ valid = False
2679
+ msg = 'Duplicate point will be blocked'
2680
+ # Ensure a value is present for all signals
2681
+ for s in self.signals:
2682
+ if s['name'] not in data_point:
2683
+ data_point[s['name']] = np.full(s['shape'], 0)
2684
+ else:
2685
+ if not data_point[s['name']].shape == tuple(s['shape']):
2686
+ valid = False
2687
+ msg = f'Shape mismatch for signal {s}'
2688
+ return valid, msg
2689
+
2690
+ def init_nxdata(self):
2691
+ """Initialize an empty NeXus NXdata representing this dataset
2692
+ to `self.nxdata`; values for axes' `NXfield`s are filled out,
2693
+ values for signals' `NXfield`s are empty an can be filled out
2694
+ later. Save the empty NeXus NXdata object to the NeXus file.
2695
+ Initialise `self.nxfile` and `self.nxdata_path` with the
2696
+ `NXFile` object and actual nxpath used to save and make updates
2697
+ to the Nexus NXdata object.
2698
+
2699
+ :returns: None
2700
+ """
2701
+ # Third party modules
2702
+ from nexusformat.nexus import NXdata, NXfield
2703
+ import numpy as np
2704
+
2705
+ axes = tuple(NXfield(
2706
+ value=c['values'],
2707
+ name=c['name'],
2708
+ attrs=c.get('attrs')) for c in self.coords)
2709
+ entries = {s['name']: NXfield(
2710
+ value=np.full((*self.shape, *s['shape']), 0),
2711
+ name=s['name'],
2712
+ attrs=s.get('attrs')) for s in self.signals}
2713
+ extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
2714
+ extra_nxfields = {f.nxname: f for f in extra_nxfields}
2715
+ entries.update(extra_nxfields)
2716
+ self.nxdata = NXdata(
2717
+ name=self.nxname, axes=axes, entries=entries, attrs=self.attrs)
2718
+
2719
+ def update_nxdata(self, data_point):
2720
+ """Update `self.nxdata`'s NXfield values.
2721
+
2722
+ :param data_point: Data point defining a point in the
2723
+ dataset's coordinate space and the new signal values at
2724
+ that point.
2725
+ :type data_point: dict[str, object]
2726
+ :returns: None
2727
+ """
2728
+ index = self.get_index(data_point)
2729
+ for s in self.signals:
2730
+ if s['name'] in data_point:
2731
+ self.nxdata[s['name']][index] = data_point[s['name']]
2732
+
2733
+ def get_index(self, data_point):
2734
+ """Return a tuple representing the array index of `data_point`
2735
+ in the coordinate space of the dataset.
2736
+
2737
+ :param data_point: Data point defining a point in the
2738
+ dataset's coordinate space.
2739
+ :type data_point: dict[str, object]
2740
+ :returns: Multi-dimensional index of `data_point` in the
2741
+ dataset's coordinate space.
2742
+ :rtype: tuple
2743
+ """
2744
+ return tuple(c['values'].index(data_point[c['name']]) \
2745
+ for c in self.coords)
2746
+
2747
+
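For reference, the container this processor builds can be reproduced directly with `nexusformat`: each coordinate becomes an axis `NXfield` holding its values, and each signal becomes a zero-filled `NXfield` whose leading dimensions match the coordinate grid. A minimal sketch with invented coordinate and signal names:

```python
# Hedged sketch of the "empty" NXdata container this processor sets up;
# the coordinate and signal definitions below are illustrative only.
import numpy as np
from nexusformat.nexus import NXdata, NXfield

coords = [
    {'name': 'x', 'values': [0.0, 0.5, 1.0], 'attrs': {'units': 'mm'}},
    {'name': 'temperature', 'values': [200, 250, 275],
     'attrs': {'units': 'Celsius'}},
]
signals = [
    {'name': 'raw_detector_data', 'shape': [407, 487]},
    {'name': 'presample_intensity', 'shape': []},
]

grid_shape = tuple(len(c['values']) for c in coords)
axes = tuple(
    NXfield(value=c['values'], name=c['name'], attrs=c.get('attrs'))
    for c in coords)
entries = {
    s['name']: NXfield(value=np.zeros((*grid_shape, *s['shape'])),
                       name=s['name'])
    for s in signals}
nxdata = NXdata(name='data', axes=axes, entries=entries)

# Filling in one data point amounts to indexing a signal field at the
# grid index of the point's coordinates:
index = (1, 2)                                  # x=0.5, temperature=275
nxdata['presample_intensity'][index] = 1234.5
```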
2748
+ class UpdateNXdataProcessor(Processor):
2749
+ """Processor to fill in part(s) of a NeXus NXdata representing a
2750
+ structured dataset that's already been written to a NeXus file.
2751
+
2752
+ This Processor is most useful as an "update" step for a NeXus
2753
+ NXdata object created by `common.SetupNXdataProcessor`, and is
2754
+ easiest to use in a `Pipeline` immediately after another
2755
+ `PipelineItem` designed specifically to return a value that can
2756
+ be used as input to this `Processor`.
2757
+
2758
+ Example of use in a `Pipeline` configuration:
2759
+ ```yaml
2760
+ config:
2761
+ inputdir: /rawdata/samplename
2762
+ pipeline:
2763
+ - edd.UpdateNXdataReader:
2764
+ spec_file: spec.log
2765
+ scan_number: 1
2766
+ - common.SetupNXdataProcessor:
2767
+ nxfilename: /reduceddata/samplename/data.nxs
2768
+ nxdata_path: /entry/samplename_dataset_1
2769
+ ```
2770
+ """
2771
+
2772
+ def process(self, data, nxfilename, nxdata_path, data_points=[],
2773
+ allow_approximate_coordinates=True):
2774
+ """Write new data points to the signal fields of an existing
2775
+ NeXus NXdata object representing a structured dataset in a NeXus
2776
+ file. Return the list of data points used to update the
2777
+ dataset.
2778
+
2779
+ :param data: Data from the previous item in a `Pipeline`. May
2780
+ contain a list of data points that will extend the list of
2781
+ data points optionally provided with the `data_points`
2782
+ argument.
2783
+ :type data: list[PipelineData]
2784
+ :param nxfilename: Name of the NeXus file containing the
2785
+ NeXus NXdata object to update.
2786
+ :type nxfilename: str
2787
+ :param nxdata_path: The path to the NeXus NXdata object to
2788
+ update in the file.
2789
+ :type nxdata_path: str
2790
+ :param data_points: List of data points, each one a dictionary
2791
+ whose keys are the names of the coordinates and signals, and
2792
+ whose values are the values of each coordinate / signal at
2793
+ a single point in the dataset. Defaults to [].
2794
+ :type data_points: list[dict[str, object]]
2795
+ :param allow_approximate_coordinates: Parameter to allow the
2796
+ nearest existing match for the new data points'
2797
+ coordinates to be used if an exact match cannot be found
2798
+ (sometimes this is due simply to differences in rounding
2799
+ conventions). Defaults to True.
2800
+ :type allow_approximate_coordinates: bool, optional
2801
+ :returns: Complete list of data points used to update the dataset.
2802
+ :rtype: list[dict[str, object]]
2803
+ """
2804
+ # Third party modules
2805
+ from nexusformat.nexus import NXFile
2806
+ import numpy as np
2807
+
2808
+ _data_points = self.unwrap_pipelinedata(data)[0]
2809
+ # Copy to avoid extending the shared default argument in place
+ data_points = list(data_points)
+ if isinstance(_data_points, list):
2809
+ data_points.extend(_data_points)
2811
+ self.logger.info(f'Updating {len(data_points)} data points')
2812
+
2813
+ nxfile = NXFile(nxfilename, 'rw')
2814
+ nxdata = nxfile.readfile()[nxdata_path]
2815
+ axes_names = [a.nxname for a in nxdata.nxaxes]
2816
+
2817
+ data_points_used = []
2818
+ for i, d in enumerate(data_points):
2819
+ # Verify that the data point contains a value for all
2820
+ # coordinates in the dataset.
2821
+ if not all(a in d for a in axes_names):
2822
+ self.logger.error(
2823
+ f'Data point {i} is missing a value for at least one '
2824
+ + f'axis. Skipping. Axes are: {", ".join(axes_names)}')
2825
+ continue
2826
+ self.logger.info(
2827
+ f'Coordinates for data point {i}: '
2828
+ + ', '.join([f'{a}={d[a]}' for a in axes_names]))
2829
+ # Get the index of the data point in the dataset based on
2830
+ # its values for each coordinate.
2831
+ try:
2832
+ index = tuple(np.where(a.nxdata == d[a.nxname])[0][0] \
2833
+ for a in nxdata.nxaxes)
2834
+ except:
2835
+ if allow_approximate_coordinates:
2836
+ try:
2837
+ index = tuple(
2838
+ np.argmin(np.abs(a.nxdata - d[a.nxname])) \
2839
+ for a in nxdata.nxaxes)
2840
+ self.logger.warning(
2841
+ f'Nearest match for coordinates of data point {i}: '
2842
+ + ', '.join(
2843
+ [f'{a.nxname}={a[_i]}' \
2844
+ for _i, a in zip(index, nxdata.nxaxes)]))
2845
+ except:
2846
+ self.logger.error(
2847
+ f'Cannot get the index of data point {i}. '
2848
+ + f'Skipping.')
2849
+ continue
2850
+ else:
2851
+ self.logger.error(
2852
+ f'Cannot get the index of data point {i}. Skipping.')
2853
+ continue
2854
+ self.logger.info(f'Index of data point {i}: {index}')
2855
+ # Update the signals contained in this data point at the
2856
+ # proper index in the dataset's signal `NXfield`s
2857
+ for k, v in d.items():
2858
+ if k in axes_names:
2859
+ continue
2860
+ try:
2861
+ nxfile.writevalue(
2862
+ os.path.join(nxdata_path, k), np.asarray(v), index)
2863
+ except Exception as e:
2864
+ self.logger.error(
2865
+ f'Error updating signal {k} for new data point '
2866
+ + f'{i} (dataset index {index}): {e}')
2867
+ data_points_used.append(d)
2868
+
2869
+ nxfile.close()
2870
+
2871
+ return data_points_used
2872
+
2873
+
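The index lookup above amounts to matching each new point's coordinate values against the stored axis arrays, falling back to the nearest stored value when the exact comparison fails (e.g. because of rounding). A self-contained sketch of that matching logic with invented values:

```python
# Hedged sketch of the exact-then-nearest coordinate matching used above;
# the axis arrays and the data point below are invented.
import numpy as np

axes = {'x': np.array([0.0, 0.5, 1.0]),
        'temperature': np.array([200.0, 250.0, 275.0])}
point = {'x': 0.4999999, 'temperature': 250.0, 'intensity': 42.0}

index = []
for name, values in axes.items():
    exact = np.where(values == point[name])[0]
    if exact.size:
        index.append(int(exact[0]))
    else:
        # No exact match (rounding differences): use the nearest value.
        index.append(int(np.argmin(np.abs(values - point[name]))))
index = tuple(index)
print(index)  # (1, 1)
```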
2874
+ class NXdataToDataPointsProcessor(Processor):
2875
+ """Transform a NeXus NXdata object into a list of dictionaries.
2876
+ Each dictionary represents a single data point in the coordinate
2877
+ space of the dataset. The keys are the names of the signals and
2878
+ axes in the dataset, and the values are a single scalar value (in
2879
+ the case of axes) or the value of the signal at that point in the
2880
+ coordinate space of the dataset (in the case of signals -- this
2881
+ means that values for signals may be any shape, depending on the
2882
+ shape of the signal itself).
2883
+
2884
+ Example of use in a pipeline configuration:
2885
+ ```yaml
2886
+ config:
2887
+ inputdir: /reduceddata/samplename
2888
+ pipeline:
+ - common.NXdataReader:
2889
+ name: data
2890
+ axes_names:
2891
+ - x
2892
+ - y
2893
+ signal_name: z
2894
+ nxfield_params:
2895
+ - filename: data.nxs
2896
+ nxpath: entry/data/x
2897
+ slice_params:
2898
+ - step: 2
2899
+ - filename: data.nxs
2900
+ nxpath: entry/data/y
2901
+ slice_params:
2902
+ - step: 2
2903
+ - filename: data.nxs
2904
+ nxpath: entry/data/z
2905
+ slice_params:
2906
+ - step: 2
2907
+ - step: 2
2908
+ - common.NXdataToDataPointsProcessor
2909
+ - common.UpdateNXdataProcessor:
2910
+ nxfilename: /reduceddata/samplename/sparsedata.nxs
2911
+ nxdata_path: /entry/data
2912
+ ```
2913
+ """
2914
+ def process(self, data):
2915
+ """Return a list of dictionaries representing the coordinate
2916
+ and signal values at every point in the dataset provided.
2917
+
2918
+ :param data: Input pipeline data containing a NeXus NXdata
2919
+ object.
2920
+ :type data: list[PipelineData]
2921
+ :returns: List of all data points in the dataset.
2922
+ :rtype: list[dict[str,object]]
2923
+ """
2924
+ # Third party modules
2925
+ import numpy as np
2926
+
2927
+ nxdata = self.unwrap_pipelinedata(data)[0]
2928
+
2929
+ data_points = []
2930
+ axes_names = [a.nxname for a in nxdata.nxaxes]
2931
+ self.logger.info(f'Dataset axes: {axes_names}')
2932
+ dataset_shape = tuple([a.size for a in nxdata.nxaxes])
2933
+ self.logger.info(f'Dataset shape: {dataset_shape}')
2934
+ signal_names = [k for k, v in nxdata.entries.items() \
2935
+ if not k in axes_names \
2936
+ and v.shape[:len(dataset_shape)] == dataset_shape]
2937
+ self.logger.info(f'Dataset signals: {signal_names}')
2938
+ other_fields = [k for k, v in nxdata.entries.items() \
2939
+ if not k in axes_names + signal_names]
2940
+ if len(other_fields) > 0:
2941
+ self.logger.warning(
2942
+ 'Ignoring the following fields that cannot be interpreted as '
2943
+ + f'either dataset coordinates or signals: {other_fields}')
2944
+ for i in np.ndindex(dataset_shape):
2945
+ data_points.append({**{a: nxdata[a][_i] \
2946
+ for a, _i in zip(axes_names, i)},
2947
+ **{s: nxdata[s].nxdata[i] \
2948
+ for s in signal_names}})
2949
+ return data_points
2950
+
2951
+
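Flattening a gridded dataset into per-point dictionaries reduces to iterating `numpy.ndindex` over the coordinate grid, taking one scalar per axis and one slice per signal. A small sketch with plain arrays standing in for the NXdata fields:

```python
# Hedged sketch of flattening a 2 x 3 grid into per-point dictionaries;
# the arrays are invented stand-ins for NXdata axes and a signal.
import numpy as np

x = np.array([0.0, 1.0])            # first axis
y = np.array([10.0, 20.0, 30.0])    # second axis
z = np.arange(6).reshape(2, 3)      # signal defined on the (x, y) grid

data_points = [
    {'x': x[i[0]], 'y': y[i[1]], 'z': z[i]}
    for i in np.ndindex(z.shape)]

print(data_points[4]['x'], data_points[4]['y'], data_points[4]['z'])  # 1.0 20.0 4
```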
1347
2952
  class XarrayToNexusProcessor(Processor):
1348
2953
  """A Processor to convert the data in an `xarray` structure to a
1349
2954
  NeXus NXdata object.
@@ -1394,3 +2999,30 @@ if __name__ == '__main__':
1394
2999
  from CHAP.processor import main
1395
3000
 
1396
3001
  main()
3002
+
3003
+
3004
+ class SumProcessor(Processor):
3005
+ """A Processor to sum the data in a NeXus NXobject, given a set of
3006
+ nxpaths
3007
+ """
3008
+ def process(self, data):
3009
+ """Return the summed data array
3010
+
3011
+ :param data:
3012
+ :type data:
3013
+ :return: The summed data.
3014
+ :rtype: numpy.ndarray
3015
+ """
3016
+ from copy import deepcopy
3017
+
3018
+ nxentry, nxpaths = self.unwrap_pipelinedata(data)[-1]
3019
+ if len(nxpaths) == 1:
3020
+ return nxentry[nxpaths[0]]
3021
+ sum_data = deepcopy(nxentry[nxpaths[0]])
3022
+ for nxpath in nxpaths[1:]:
3023
+ nxdata = nxentry[nxpath]
3024
+ for entry in nxdata.entries:
3025
+ sum_data[entry] += nxdata[entry]
3026
+
3027
+ return sum_data
3028
+
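A quick sketch of the summing pattern applied above, using two small NXdata groups with matching entries (names and values invented):

```python
# Hedged sketch of summing matching entries across NXdata groups, as the
# processor does for the groups found at its nxpaths; data are invented.
from copy import deepcopy

import numpy as np
from nexusformat.nexus import NXdata, NXfield

det_a = NXdata(name='det_a', entries={
    'counts': NXfield(np.array([1, 2, 3]), name='counts')})
det_b = NXdata(name='det_b', entries={
    'counts': NXfield(np.array([10, 20, 30]), name='counts')})

total = deepcopy(det_a)
for entry in det_b.entries:
    total[entry] += det_b[entry]
print(total.counts.nxdata)  # [11 22 33]
```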