ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. CHAP/TaskManager.py +216 -0
  2. CHAP/__init__.py +27 -0
  3. CHAP/common/__init__.py +57 -0
  4. CHAP/common/models/__init__.py +8 -0
  5. CHAP/common/models/common.py +124 -0
  6. CHAP/common/models/integration.py +659 -0
  7. CHAP/common/models/map.py +1291 -0
  8. CHAP/common/processor.py +2869 -0
  9. CHAP/common/reader.py +658 -0
  10. CHAP/common/utils.py +110 -0
  11. CHAP/common/writer.py +730 -0
  12. CHAP/edd/__init__.py +23 -0
  13. CHAP/edd/models.py +876 -0
  14. CHAP/edd/processor.py +3069 -0
  15. CHAP/edd/reader.py +1023 -0
  16. CHAP/edd/select_material_params_gui.py +348 -0
  17. CHAP/edd/utils.py +1572 -0
  18. CHAP/edd/writer.py +26 -0
  19. CHAP/foxden/__init__.py +19 -0
  20. CHAP/foxden/models.py +71 -0
  21. CHAP/foxden/processor.py +124 -0
  22. CHAP/foxden/reader.py +224 -0
  23. CHAP/foxden/utils.py +80 -0
  24. CHAP/foxden/writer.py +168 -0
  25. CHAP/giwaxs/__init__.py +11 -0
  26. CHAP/giwaxs/models.py +491 -0
  27. CHAP/giwaxs/processor.py +776 -0
  28. CHAP/giwaxs/reader.py +8 -0
  29. CHAP/giwaxs/writer.py +8 -0
  30. CHAP/inference/__init__.py +7 -0
  31. CHAP/inference/processor.py +69 -0
  32. CHAP/inference/reader.py +8 -0
  33. CHAP/inference/writer.py +8 -0
  34. CHAP/models.py +227 -0
  35. CHAP/pipeline.py +479 -0
  36. CHAP/processor.py +125 -0
  37. CHAP/reader.py +124 -0
  38. CHAP/runner.py +277 -0
  39. CHAP/saxswaxs/__init__.py +7 -0
  40. CHAP/saxswaxs/processor.py +8 -0
  41. CHAP/saxswaxs/reader.py +8 -0
  42. CHAP/saxswaxs/writer.py +8 -0
  43. CHAP/server.py +125 -0
  44. CHAP/sin2psi/__init__.py +7 -0
  45. CHAP/sin2psi/processor.py +8 -0
  46. CHAP/sin2psi/reader.py +8 -0
  47. CHAP/sin2psi/writer.py +8 -0
  48. CHAP/tomo/__init__.py +15 -0
  49. CHAP/tomo/models.py +210 -0
  50. CHAP/tomo/processor.py +3862 -0
  51. CHAP/tomo/reader.py +9 -0
  52. CHAP/tomo/writer.py +59 -0
  53. CHAP/utils/__init__.py +6 -0
  54. CHAP/utils/converters.py +188 -0
  55. CHAP/utils/fit.py +2947 -0
  56. CHAP/utils/general.py +2655 -0
  57. CHAP/utils/material.py +274 -0
  58. CHAP/utils/models.py +595 -0
  59. CHAP/utils/parfile.py +224 -0
  60. CHAP/writer.py +122 -0
  61. MLaaS/__init__.py +0 -0
  62. MLaaS/ktrain.py +205 -0
  63. MLaaS/mnist_img.py +83 -0
  64. MLaaS/tfaas_client.py +371 -0
  65. chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
  66. chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
  67. chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
  68. chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
  69. chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
  70. chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
CHAP/edd/processor.py ADDED
@@ -0,0 +1,3069 @@
1
+ #!/usr/bin/env python
2
+ #-*- coding: utf-8 -*-
3
+ """
4
+ File : processor.py
5
+ Author : Keara Soloway, Rolf Verberg
6
+ Description: Module for Processors used only by EDD experiments
7
+ """
8
+
9
+ # System modules
10
+ from copy import deepcopy
11
+ import os
12
+ from sys import float_info
13
+ from time import time
14
+ from typing import Optional
15
+
16
+ # Third party modules
17
+ import numpy as np
18
+ from pydantic import (
19
+ Field,
20
+ PrivateAttr,
21
+ model_validator,
22
+ )
23
+
24
+ # Local modules
25
+ from CHAP import Processor
26
+ from CHAP.pipeline import PipelineData
27
+ from CHAP.common.models.map import DetectorConfig
28
+ from CHAP.utils.general import fig_to_iobuf
29
+ from CHAP.edd.models import (
30
+ MCADetectorCalibration,
31
+ MCADetectorDiffractionVolumeLength,
32
+ MCADetectorStrainAnalysis,
33
+ MCADetectorConfig,
34
+ DiffractionVolumeLengthConfig,
35
+ MCAEnergyCalibrationConfig,
36
+ MCATthCalibrationConfig,
37
+ StrainAnalysisConfig,
38
+ )
39
+
40
# Smallest positive normalized float of the platform
# (sys.float_info.min). Not referenced in the visible part of this
# module; presumably used elsewhere as a floor to avoid division by
# zero -- TODO confirm.
FLOAT_MIN: float = float_info.min

# Current good detector channels for the 23 channel EDD detector:
# 0, 2, 3, 5, 6, 7, 8, 10, 13, 14, 16, 17, 18, 19, 21, 22
44
+
45
def get_axes(nxdata, skip_axes=None):
    """Return the names of the axes of an NXdata object used in EDD.

    The ``unstructured_axes`` attribute takes precedence over the
    regular ``axes`` attribute; a single axis name stored as a string
    is treated as a one-element list.

    :param nxdata: Object whose ``attrs`` mapping is inspected.
    :type nxdata: nexusformat.nexus.NXdata
    :param skip_axes: Axis names to leave out of the result,
        defaults to `None` (nothing skipped).
    :type skip_axes: list[str], optional
    :return: The remaining axis names converted to `str`, or an
        empty list when neither attribute is present.
    :rtype: list[str]
    """
    excluded = [] if skip_axes is None else skip_axes
    attrs = nxdata.attrs
    for attr_name in ('unstructured_axes', 'axes'):
        if attr_name in attrs:
            axis_names = attrs[attr_name]
            break
    else:
        return []
    if isinstance(axis_names, str):
        axis_names = [axis_names]
    return [str(name) for name in axis_names if name not in excluded]
58
+
59
+
60
class BaseEddProcessor(Processor):
    """Base processor for the EDD processors.

    Per-detector working state is kept in parallel private lists that
    the helper methods below populate and transform in place:
    bin energies, masks, mask index ranges, mean spectra and the
    per-detector NXdata objects. Subclasses are expected to provide
    `config` and `detector_config`.
    """
    # default_factory gives every instance its own fresh list
    _energies: list = PrivateAttr(default_factory=list)
    _figures: list = PrivateAttr(default_factory=list)
    _masks: list = PrivateAttr(default_factory=list)
    _mask_index_ranges: list = PrivateAttr(default_factory=list)
    _mean_data: list = PrivateAttr(default_factory=list)
    _nxdata_detectors: list = PrivateAttr(default_factory=list)

    def _apply_combined_mask(self, calibration_method=None):
        """Apply the combined mask over the combined included energy
        ranges.

        Trims each detector's energies, mean spectrum and NXdata signal
        to the span between the first and last included MCA channel and
        records the trimmed mask and its ``(low, upp)`` index range.

        :param calibration_method: When ``'direct_fit_tth_ecc'``, the
            fluorescence-peak mask ranges from the energy calibration
            are merged into the detector mask first.
        :type calibration_method: str, optional
        :raises ValueError: If a detector's mask selects no channels.
        """
        for index, (energies, mean_data, nxdata, detector) in enumerate(
                zip(self._energies, self._mean_data, self._nxdata_detectors,
                    self.detector_config.detectors)):
            # Add the mask for the fluorescence peaks from the
            # energy calibration for certain tth calibrations
            if calibration_method == 'direct_fit_tth_ecc':
                detector.convert_mask_ranges(
                    detector._energy_calibration_mask_ranges
                    + detector.get_mask_ranges())

            mask = detector.mca_mask()
            # An empty mask would make argmax return 0 for both ends and
            # silently keep the full range with nothing selected
            if not np.any(mask):
                raise ValueError(
                    f'Empty mask for detector {detector.get_id()}: no MCA '
                    'channels are included')
            low, upp = np.argmax(mask), mask.size - np.argmax(mask[::-1])
            self._energies[index] = energies[low:upp]
            self._masks.append(mask[low:upp])
            self._mask_index_ranges.append((low, upp))
            self._mean_data[index] = mean_data[low:upp]
            self._nxdata_detectors[index].nxsignal = nxdata.nxsignal[:,low:upp]

    def _apply_energy_mask(self, lower_cutoff=25, upper_cutoff=200):
        """Apply an energy mask by blanking out data below and/or
        above a certain threshold.

        Mean spectra and detector signals are modified in place.

        :param lower_cutoff: Lowest energy kept, defaults to `25`.
        :type lower_cutoff: float, optional
        :param upper_cutoff: Highest energy kept, defaults to `200`.
        :type upper_cutoff: float, optional
        """
        dtype = self._nxdata_detectors[0].nxsignal.dtype
        for index, (energies, _) in enumerate(
                zip(self._energies, self.detector_config.detectors)):
            energy_mask = np.where(energies >= lower_cutoff, 1, 0)
            energy_mask = np.where(energies <= upper_cutoff, energy_mask, 0)
            # Also blank out the last channel, which has shown to be
            # troublesome
            energy_mask[-1] = 0
            self._mean_data[index] *= energy_mask
            self._nxdata_detectors[index].nxsignal.nxdata *= \
                energy_mask.astype(dtype)

    def _apply_flux_correction(self):
        """Apply the flux correction.

        Not yet re-enabled: raises whenever a flux file is configured.

        :raises RuntimeError: If `config.flux_file` is set.
        """
        # Check each detector's include_energy_ranges field against the
        # flux file, if available.
        if self.config.flux_file is not None:
            raise RuntimeError('Flux correction not tested after updates')
            # flux = np.loadtxt(self.config.flux_file)
            # flux_file_energies = flux[:,0]/1.e3
            # flux_e_min = flux_file_energies.min()
            # flux_e_max = flux_file_energies.max()
            # for detector in self.detector_config.detectors:
            #     for i, (det_e_min, det_e_max) in enumerate(
            #             deepcopy(detector.include_energy_ranges)):
            #         if det_e_min < flux_e_min or det_e_max > flux_e_max:
            #             energy_range = [float(max(det_e_min, flux_e_min)),
            #                             float(min(det_e_max, flux_e_max))]
            #             print(
            #                 f'WARNING: include_energy_ranges[{i}] out of range'
            #                 f' ({detector.include_energy_ranges[i]}): adjusted'
            #                 f' to {energy_range}')
            #             detector.include_energy_ranges[i] = energy_range

    def _get_mask_hkls(self):
        """Get the mask and HKLs used in the current processor.

        For every detector, selects (interactively or from the
        preselected values) the energy mask ranges and HKL indices and
        stores them on the detector.

        :raises ValueError: If a detector ends up without energy mask
            ranges or without HKL indices.
        """
        # Local modules
        from CHAP.edd.utils import (
            get_unique_hkls_ds,
            select_mask_and_hkls,
        )

        if self.save_figures:
            basename = {
                'MCATthCalibrationProcessor': 'tth_calibration_mask_hkls',
                'StrainAnalysisProcessor': 'strainanalysis_mask_hkls',
                'LatticeParameterRefinementProcessor':
                    'lp_refinement_mask_hkls',
            }.get(self.__name__, f'{self.__name__}_mask_hkls')

        for energies, mean_data, nxdata, detector in zip(
                self._energies, self._mean_data, self._nxdata_detectors,
                self.detector_config.detectors):

            # Get the unique HKLs and lattice spacings used in the
            # current processor
            hkls, ds = get_unique_hkls_ds(
                self.config.materials, tth_max=detector.tth_max,
                tth_tol=detector.tth_tol)

            # Interactively adjust the mask and HKLs used in the
            # current processor
            if detector.processor_type == 'strainanalysis':
                calibration_bin_ranges = detector.get_calibration_mask_ranges()
            else:
                calibration_bin_ranges = None
            if detector.tth_calibrated is None:
                tth = detector.tth_initial_guess
            else:
                tth = detector.tth_calibrated
            mask_ranges, hkl_indices, buf = select_mask_and_hkls(
                energies, mean_data, hkls, ds, tth,
                preselected_bin_ranges=detector.get_mask_ranges(),
                preselected_hkl_indices=detector.hkl_indices,
                detector_id=detector.get_id(),
                ref_map=nxdata.nxsignal.nxdata,
                calibration_bin_ranges=calibration_bin_ranges,
                label='Sum of the spectra in the map',
                interactive=self.interactive,
                return_buf=self.save_figures)
            if self.save_figures:
                self._figures.append((buf, f'{detector.get_id()}_{basename}'))
            detector.hkl_indices = hkl_indices
            detector.convert_mask_ranges(mask_ranges)
            self.logger.debug(
                f'energy mask_ranges for detector {detector.get_id()}:'
                f' {detector.energy_mask_ranges}')
            self.logger.debug(
                f'hkl_indices for detector {detector.get_id()}:'
                f' {detector.hkl_indices}')
            if not detector.energy_mask_ranges:
                raise ValueError(
                    'No value provided for energy_mask_ranges. Provide '
                    'them in the tth calibration configuration, or re-run the '
                    'pipeline with the interactive flag set.')
            if not detector.hkl_indices:
                raise ValueError(
                    'No value provided for hkl_indices. Provide them in '
                    'the tth calibration configuration, or re-run the '
                    'pipeline with the interactive flag set.')

    def _setup_detector_data(self, nxobject, **kwargs):
        """Load the raw MCA data from the SpecReader output and compute
        the detector bin energies and the mean spectra.

        For the DiffractionVolumeLengthProcessor the raw data of all
        scans is summed point by point; for other processors each scan
        is summed over its first axis.

        :param nxobject: SpecReader output with a `spec_scans` group.
        :param kwargs: Must contain `available_detector_ids` (ids in the
            order of the raw data's detector axis); may contain
            `max_energy_kev`, required when a detector lacks energy
            calibration coefficients.
        :raises ValueError: If no scans are found, the raw data is not
            3D, the number of channels is inconsistent, or
            `max_energy_kev` is needed but missing.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        available_detector_ids = kwargs['available_detector_ids']
        max_energy_kev = kwargs.get('max_energy_kev')
        is_dvl = self.__name__ == 'DiffractionVolumeLengthProcessor'

        scans = []
        raw_data = []
        for scan_name in nxobject.spec_scans:
            spec_scan = nxobject.spec_scans[scan_name]
            for scan_number, scan_data in spec_scan.items():
                scans.append(f'{scan_name}_{scan_number}')
                data = scan_data.data.data.nxdata
                if data.ndim != 3:
                    raise ValueError(
                        f'Illegal raw detector data shape ({data.shape})')
                if is_dvl:
                    raw_data.append(data)
                else:
                    raw_data.append(data.sum(axis=0))
        if not raw_data:
            raise ValueError('No raw detector data found in the SPEC scans')
        if is_dvl:
            raw_data = np.sum(raw_data, axis=0)
        else:
            raw_data = np.asarray(raw_data)
        num_bins = raw_data.shape[-1]

        for detector in self.detector_config.detectors:
            if detector.num_bins is None:
                detector.num_bins = num_bins
            elif detector.num_bins != num_bins:
                raise ValueError(
                    'Inconsistent number of MCA detector channels between '
                    'the raw data and the detector configuration '
                    f'({num_bins} vs {detector.num_bins})')
            if detector.energy_calibration_coeffs is None:
                if max_energy_kev is None:
                    raise ValueError(
                        'Missing max_energy_kev parameter')
                # Linear default: channel 0 at 0 keV, last channel at
                # max_energy_kev
                detector.energy_calibration_coeffs = [
                    0.0, max_energy_kev/(num_bins-1.0), 0.0]
            self._energies.append(detector.energies)
            index = int(available_detector_ids.index(detector.get_id()))
            nxdata_det = NXdata(
                NXfield(raw_data[:,index,:], 'detector_data'),
                NXfield(scans, 'scans'))
            self._nxdata_detectors.append(nxdata_det)

        # Mean spectrum per detector over all points with nonzero counts
        self._mean_data = []
        for nxdata in self._nxdata_detectors:
            signal = nxdata.nxsignal.nxdata
            nonzero = [i for i in range(signal.shape[0]) if signal[i].sum()]
            self._mean_data.append(
                np.mean(signal[nonzero], axis=tuple(range(signal.ndim-1))))
        self.logger.debug(
            f'data shape: {self._nxdata_detectors[0].nxsignal.shape}')
        self.logger.debug(
            f'mean_data shape: {np.asarray(self._mean_data).shape}')

    def _subtract_baselines(self):
        """Get and subtract the detector baselines.

        For every detector with a baseline configured, constructs the
        baseline of its mean spectrum and subtracts it in place from both
        the mean spectrum and the NXdata signal. Relies on
        `self._mask_index_ranges`, which `_apply_combined_mask` fills.
        """
        # Local modules
        from CHAP.edd.models import BaselineConfig
        from CHAP.common.processor import ConstructBaseline

        if self.save_figures:
            basename = {
                'LatticeParameterRefinementProcessor':
                    'lp_refinement_baseline',
                'DiffractionVolumeLengthProcessor': 'dvl_baseline',
                'MCAEnergyCalibrationProcessor':
                    'energy_calibration_baseline',
                'MCATthCalibrationProcessor': 'tth_calibration_baseline',
                'StrainAnalysisProcessor': 'strainanalysis_baseline',
            }.get(self.__name__, f'{self.__name__}_baseline')

        for energies, mean_data, (low, _), nxdata, detector in zip(
                self._energies, self._mean_data, self._mask_index_ranges,
                self._nxdata_detectors, self.detector_config.detectors):
            if not detector.baseline:
                continue
            if isinstance(detector.baseline, bool):
                detector.baseline = BaselineConfig()
            # Uncalibrated processors work in channel numbers, the others
            # in energy
            if self.__name__ in ('DiffractionVolumeLengthProcessor',
                                 'MCAEnergyCalibrationProcessor'):
                x = low + np.arange(mean_data.size)
                xlabel = 'Detector Channel (-)'
            else:
                x = energies
                xlabel = 'Energy (keV)'

            baseline, baseline_config, buf = \
                ConstructBaseline.construct_baseline(
                    mean_data, x=x, tol=detector.baseline.tol,
                    lam=detector.baseline.lam,
                    max_iter=detector.baseline.max_iter,
                    title=f'Baseline for detector {detector.get_id()}',
                    xlabel=xlabel, ylabel='Intensity (counts)',
                    interactive=self.interactive,
                    return_buf=self.save_figures)
            if self.save_figures:
                self._figures.append(
                    (buf, f'{detector.get_id()}_{basename}'))

            detector.baseline.lam = baseline_config['lambda']
            detector.baseline.attrs['num_iter'] = \
                baseline_config['num_iter']
            detector.baseline.attrs['error'] = \
                baseline_config['error']

            nxdata.nxsignal -= baseline
            mean_data -= baseline
322
+
323
+
324
class BaseStrainProcessor(BaseEddProcessor):
    """Base processor for LatticeParameterRefinementProcessor and
    StrainAnalysisProcessor.
    """
    def _adjust_material_props(self, materials, index=0):
        """Adjust the material properties.

        Uses the GUI selector when running interactively and the
        non-interactive one from `CHAP.edd.utils` otherwise.

        :param materials: Preselected materials.
        :param index: Index of the detector whose mean spectrum and
            calibrated 2&theta are used, defaults to `0`.
        :type index: int, optional
        :return: Whatever `select_material_params` returns.
        """
        # Local modules
        if self.interactive:
            from CHAP.edd.select_material_params_gui import \
                select_material_params
        else:
            from CHAP.edd.utils import select_material_params

        detector = self.detector_config.detectors[index]
        return select_material_params(
            self._energies[index], self._mean_data[index],
            detector.tth_calibrated, label='Sum of the spectra in the map',
            preselected_materials=materials, interactive=self.interactive,
            return_buf=self.save_figures)

    def _get_sum_axes_data(self, nxdata, detector_id, sum_axes=True):
        """Get the raw MCA data collected by the scan averaged over the
        sum_axes.

        Points that share identical coordinates along every remaining
        (non-summed) axis are grouped and their spectra averaged; the
        groups appear in order of first occurrence.

        :param nxdata: Map data holding the detector field and axes.
        :param detector_id: Name of the detector field in `nxdata`.
        :param sum_axes: Axes to average over. A list is used as is;
            otherwise a truthy value selects the `fly_axis_labels`
            attribute (when present) and anything else selects none.
        :type sum_axes: bool or list[str], optional
        :return: The averaged detector data with the remaining axes.
        :rtype: nexusformat.nexus.NXdata
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        data = nxdata[detector_id].nxdata
        if not isinstance(sum_axes, list):
            if sum_axes and 'fly_axis_labels' in nxdata.attrs:
                sum_axes = nxdata.attrs['fly_axis_labels']
                if isinstance(sum_axes, str):
                    sum_axes = [sum_axes]
            else:
                sum_axes = []
        axes = get_axes(nxdata, skip_axes=sum_axes)
        if not axes:
            return NXdata(NXfield([np.mean(data, axis=0)], 'detector_data'))
        dims = np.asarray(
            [nxdata[a].nxdata for a in axes], dtype=np.float64).T

        # Group point indices by coordinates; a dict keeps first
        # occurrence order and avoids a quadratic linear search
        sum_indices = {}
        for i in range(data.shape[0]):
            sum_indices.setdefault(tuple(dims[i].tolist()), []).append(i)
        unique_points = np.asarray(list(sum_indices)).T
        mean_data = np.empty((unique_points.shape[1], data.shape[-1]))
        for i, indices in enumerate(sum_indices.values()):
            mean_data[i] = np.mean(data[indices], axis=0)
        nxdata_det = NXdata(
            NXfield(mean_data, 'detector_data'),
            tuple(NXfield(unique_points[i], a, attrs=nxdata[a].attrs)
                  for i, a in enumerate(axes)))
        if len(axes) > 1:
            nxdata_det.attrs['unstructured_axes'] = \
                nxdata_det.attrs.pop('axes')
        return nxdata_det

    def _setup_detector_data(self, nxobject, **kwargs):
        """Load the raw MCA data map accounting for oversampling or
        axes summation if requested and compute the detector bin
        energies and the mean spectra.

        :param nxobject: Map data holding the detector fields and axes.
        :param kwargs: Must contain `strain_analysis_config`; may
            contain `update` (default `True`). When `update` is false
            the mean spectra are initialized to zeros instead of being
            computed.
        :raises RuntimeError: If oversampling (scan type 4) is requested.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        strain_analysis_config = kwargs['strain_analysis_config']
        update = kwargs.get('update', True)

        have_raw_detector_data = False
        oversampling_axis = {}
        if strain_analysis_config.sum_axes:
            scan_type = int(str(nxobject.attrs.get('scan_type', 0)))
            if scan_type == 4:
                # Local modules
                from CHAP.utils.general import rolling_average

                # Check for oversampling axis and create the binned
                # coordinates
                raise RuntimeError('oversampling needs testing')
                # NOTE: unreachable until oversampling is re-enabled
                fly_axis = nxobject.attrs.get('fly_axis_labels').nxdata[0]
                oversampling = strain_analysis_config.oversampling
                oversampling_axis[fly_axis] = rolling_average(
                    nxobject[fly_axis].nxdata,
                    start=oversampling.get('start', 0),
                    end=oversampling.get('end'),
                    width=oversampling.get('width'),
                    stride=oversampling.get('stride'),
                    num=oversampling.get('num'),
                    mode=oversampling.get('mode', 'valid'))
            elif (scan_type > 2
                    or isinstance(strain_analysis_config.sum_axes, list)):
                # Collect the raw MCA data averaged over sum_axes
                for detector in self.detector_config.detectors:
                    self._nxdata_detectors.append(
                        self._get_sum_axes_data(
                            nxobject, detector.get_id(),
                            sum_axes=strain_analysis_config.sum_axes))
                have_raw_detector_data = True
        if not have_raw_detector_data:
            # Collect the raw MCA data if not averaged over sum_axes
            axes = get_axes(nxobject)
            for detector in self.detector_config.detectors:
                nxdata_det = NXdata(
                    NXfield(
                        nxobject[detector.get_id()].nxdata, 'detector_data'),
                    tuple(
                        NXfield(
                            nxobject[a].nxdata, a, attrs=nxobject[a].attrs)
                        for a in axes))
                if len(axes) > 1:
                    nxdata_det.attrs['unstructured_axes'] = \
                        nxdata_det.attrs.pop('axes')
                self._nxdata_detectors.append(nxdata_det)
        if update:
            # Mean spectrum per detector over all points with nonzero
            # counts
            self._mean_data = []
            for nxdata in self._nxdata_detectors:
                signal = nxdata.nxsignal.nxdata
                nonzero = [
                    i for i in range(signal.shape[0]) if signal[i].sum()]
                self._mean_data.append(
                    np.mean(signal[nonzero],
                            axis=tuple(range(signal.ndim-1))))
        else:
            # One independent zero array per detector: the later in-place
            # masking and baseline subtraction must not alias them
            self._mean_data = [
                np.zeros(nxdata.nxsignal.shape[-1])
                for nxdata in self._nxdata_detectors]
        for detector in self.detector_config.detectors:
            self._energies.append(detector.energies)
        self.logger.debug(
            'data shape: '
            f'{nxobject[self.detector_config.detectors[0].get_id()].nxdata.shape}')
        self.logger.debug(
            f'mean_data shape: {np.asarray(self._mean_data).shape}')
470
+
471
+
472
class DiffractionVolumeLengthProcessor(BaseEddProcessor):
    """A Processor using a steel foil raster scan to calculate the
    diffraction volume length for an EDD setup.

    :ivar config: Initialization parameters for an instance of
        CHAP.edd.models.DiffractionVolumeLengthConfig.
    :type config: dict, optional
    :ivar detector_config: Initialization parameters for an instance of
        CHAP.edd.models.MCADetectorConfig. Defaults to the detector
        configuration of the raw detector data.
    :type detector_config: dict, optional
    :ivar save_figures: Save .pngs of plots for checking inputs &
        outputs of this Processor, defaults to `False`.
    :type save_figures: bool, optional
    """
    pipeline_fields: dict = Field(
        default={
            'config': 'edd.models.DiffractionVolumeLengthConfig',
            'detector_config': {
                'schema': ['edd.models.DiffractionVolumeLengthConfig',
                           'edd.models.MCADetectorConfig'],
                'merge_key_paths': {'key_path': 'detectors/id', 'type': int}},
        },
        init_var=True)

    config: Optional[
        DiffractionVolumeLengthConfig] = DiffractionVolumeLengthConfig()
    detector_config: MCADetectorConfig
    save_figures: Optional[bool] = False

    @model_validator(mode='before')
    @classmethod
    def validate_diffractionvolumeLengthprocessor_before(cls, data):
        """Force the detector configuration's `processor_type` to
        ``'diffractionvolumelength'``.

        Works on shallow copies so the caller's dictionaries are not
        mutated.
        """
        if isinstance(data, dict):
            data = {**data}
            detector_config = data.pop('detector_config', {})
            if isinstance(detector_config, dict):
                detector_config = {**detector_config}
            detector_config['processor_type'] = 'diffractionvolumelength'
            data['detector_config'] = detector_config
        return data

    @model_validator(mode='after')
    def validate_diffractionvolumeLengthprocessor_after(self):
        """Require `config.sample_thickness` to be set.

        :raises ValueError: If `sample_thickness` is missing.
        """
        if self.config.sample_thickness is None:
            raise ValueError('Missing parameter "sample_thickness"')
        return self

    def process(self, data):
        """Return the calculated value of the DVL.

        :param data: DVL calculation input configuration.
        :type data: list[PipelineData]
        :raises ValueError: If none of the requested detectors has raw
            data.
        :return: DVL configuration.
        :rtype: dict, PipelineData
        """
        # Standard library modules
        from json import loads

        # Load the detector data
        # FIX input a numpy and create/use NXobject to numpy proc
        # FIX right now spec info is lost in output yaml, add to it?
        nxentry = self.get_default_nxentry(self.get_data(data))

        # Validate the detector configuration
        raw_detector_config = DetectorConfig(**loads(str(nxentry.detectors)))
        raw_detector_ids = [d.get_id() for d in raw_detector_config.detectors]
        if not self.detector_config.detectors:
            self.detector_config.detectors = [
                MCADetectorDiffractionVolumeLength(
                    **d.model_dump(), processor_type='diffractionvolumelength')
                for d in raw_detector_config.detectors]
            self.detector_config.update_detectors()
        else:
            skipped_detectors = []
            detectors = []
            for detector in self.detector_config.detectors:
                if detector.get_id() not in raw_detector_ids:
                    skipped_detectors.append(detector.get_id())
                    continue
                raw_detector = raw_detector_config.detectors[
                    int(raw_detector_ids.index(detector.get_id()))]
                # Copy over raw detector attributes not set explicitly
                for k, v in raw_detector.attrs.items():
                    if k not in detector.attrs:
                        if isinstance(v, list):
                            detector.attrs[k] = np.asarray(v)
                        else:
                            detector.attrs[k] = v
                detector.energy_mask_ranges = None
                detectors.append(detector)
            if len(skipped_detectors) == 1:
                self.logger.warning(
                    f'Skipping detector {skipped_detectors[0]} '
                    '(no raw data)')
            elif skipped_detectors:
                # Local modules
                from CHAP.utils.general import list_to_string

                skipped_detectors = [int(d) for d in skipped_detectors]
                self.logger.warning(
                    'Skipping detectors '
                    f'{list_to_string(skipped_detectors)} (no raw data)')
            self.detector_config.detectors = detectors
        if not self.detector_config.detectors:
            raise ValueError(
                'No raw data for the requested DVL measurement detectors')

        # Load the raw MCA data and compute the detector bin energies
        # and the mean spectra
        self._setup_detector_data(
            nxentry, available_detector_ids=raw_detector_ids,
            max_energy_kev=self.config.max_energy_kev)

        # Load the scanned motor position values
        scanned_vals = self._get_scanned_vals(nxentry)

        # Apply the flux correction
        # self._apply_flux_correction()

        # Apply the energy mask
        self._apply_energy_mask()

        # Get the mask used in the DVL measurement
        self._get_mask()

        # Apply the combined energy ranges mask
        self._apply_combined_mask()

        # Get and subtract the detector baselines
        self._subtract_baselines()

        # Calculate or manually select the diffraction volume lengths
        self._measure_dvl(scanned_vals)

        # Combine the diffraction volume length and detector
        # configuration and move default detector fields to the
        # detector attrs
        for d in self.detector_config.detectors:
            # model_fields is read from the class: instance access is
            # deprecated in recent Pydantic releases
            d.attrs['default_fields'] = {
                k: v.default for k, v in type(d).model_fields.items()
                if (k != 'attrs' and (k not in d.model_fields_set
                    or v.default == getattr(d, k)))}
        configs = {
            **self.config.model_dump(),
            'detectors': [d.model_dump(exclude_defaults=True)
                          for d in self.detector_config.detectors]}
        return configs, PipelineData(
            name=self.__name__, data=self._figures,
            schema='common.write.ImageWriter')

    def _get_mask(self):
        """Get the mask used in the DVL measurement.

        Selects (interactively or from the preselected ranges) at least
        one channel index range per detector and stores it in
        `detector.mask_ranges`.

        :raises ValueError: If a detector ends up without mask ranges.
        """
        # Local modules
        from CHAP.utils.general import select_mask_1d

        for mean_data, detector in zip(
                self._mean_data, self.detector_config.detectors):

            # Interactively adjust the mask used in the DVL
            # measurement
            buf, _, detector.mask_ranges = select_mask_1d(
                mean_data, preselected_index_ranges=detector.mask_ranges,
                title=f'Mask for detector {detector.get_id()}',
                xlabel='Detector Channel (-)',
                ylabel='Intensity (counts)',
                min_num_index_ranges=1, interactive=self.interactive,
                return_buf=self.save_figures)
            if self.save_figures:
                self._figures.append((buf, f'{detector.get_id()}_dvl_mask'))
            self.logger.debug(
                f'mask_ranges for detector {detector.get_id()}:'
                f' {detector.mask_ranges}')
            if not detector.mask_ranges:
                raise ValueError(
                    'No value provided for mask_ranges. Provide it in '
                    'the DVL configuration, or re-run the pipeline '
                    'with the interactive flag set.')

    def _get_scanned_vals(self, nxentry):
        """Get the scanned motor positions from the SpecReader output.

        Uses the scan column of the first scanned motor of each scan.

        :param nxentry: SpecReader output with a `spec_scans` group.
        :raises ValueError: If the scans have different motor positions.
        :return: Scanned motor positions, or `None` if there are no
            scans.
        :rtype: numpy.ndarray
        """
        # Standard library modules
        from json import loads

        scanned_vals = None
        for scan_name in nxentry.spec_scans:
            for scan_data in nxentry.spec_scans[scan_name].values():
                motor_mnes = loads(str(scan_data.spec_scan_motor_mnes))
                vals = np.asarray(
                    loads(str(scan_data.scan_columns))[motor_mnes[0]])
                if scanned_vals is None:
                    scanned_vals = vals
                elif not np.array_equal(scanned_vals, vals):
                    # Explicit check: an assert is stripped under -O
                    raise ValueError(
                        'Inconsistent scanned motor positions between '
                        f'scans ({scan_name})')
        return scanned_vals

    def _measure_dvl(self, scanned_vals):
        """Return a measured value for the length of the diffraction
        volume. Use the iron foil raster scan data provided in
        `dvl_config` and fit a gaussian to the sum of all MCA channel
        counts vs scanned motor position in the raster scan. The
        computed diffraction volume length is approximately equal to
        the standard deviation of the fitted peak.

        Results are stored on each detector (`dvl`, `fit_amplitude`,
        `fit_center`, `fit_sigma`).

        :param scanned_vals: The scanned motor position values.
        :type scanned_vals: numpy.ndarray
        :return: Updated DVL measurement configuration and a list of
            byte stream representations of Matplotlib figures.
        :rtype: dict, PipelineData
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        # Local modules
        from CHAP.utils.fit import FitProcessor
        from CHAP.utils.general import (
            index_nearest,
            select_mask_1d,
        )

        for mask, nxdata, detector in zip(
                self._masks, self._nxdata_detectors,
                self.detector_config.detectors):

            self.logger.info(f'Measuring DVL for detector {detector.get_id()}')

            masked_data = nxdata.nxsignal.nxdata[:,mask]
            masked_max = np.max(masked_data, axis=1).astype(float)
            masked_sum = np.sum(masked_data, axis=1).astype(float)

            # Find the motor position corresponding roughly to the center
            # of the diffraction volume (intensity-weighted mean)
            scan_center = np.sum(scanned_vals*masked_sum) / np.sum(masked_sum)
            x = scanned_vals - scan_center

            # Normalize the data
            masked_max /= masked_max.max()
            masked_sum /= masked_sum.max()

            # Construct the fit model and perform the fit
            models = []
            if detector.background is not None:
                if len(detector.background) == 1:
                    models.append(
                        {'model': detector.background[0], 'prefix': 'bkgd_'})
                else:
                    models.extend(
                        {'model': model, 'prefix': f'{model}_'}
                        for model in detector.background)
            models.append({'model': 'gaussian'})
            self.logger.debug('Fitting mean spectrum')
            fit = FitProcessor(**self.run_config)
            result = fit.process(
                NXdata(
                    NXfield(masked_sum, 'y'), NXfield(x, 'x')),
                {'models': models, 'method': 'trf'})

            # Calculate / manually select diffraction volume length
            detector.dvl = float(
                result.best_values['sigma'] * self.config.sigma_to_dvl_factor
                - self.config.sample_thickness)
            detector.fit_amplitude = float(result.best_values['amplitude'])
            detector.fit_center = float(
                scan_center + result.best_values['center'])
            detector.fit_sigma = float(result.best_values['sigma'])
            if self.config.measurement_mode == 'manual':
                if self.interactive:
                    _, _, dvl_bounds = select_mask_1d(
                        masked_sum, x=x,
                        preselected_index_ranges=[
                            (index_nearest(x, -0.5*detector.dvl),
                             index_nearest(x, 0.5*detector.dvl))],
                        title=('Diffraction volume length'),
                        xlabel=('Beam direction (offset from scan "center")'),
                        ylabel='Normalized intensity (-)',
                        min_num_index_ranges=1, max_num_index_ranges=1,
                        interactive=self.interactive)
                    dvl_bounds = dvl_bounds[0]
                    # Store a plain float like the fitted value above
                    detector.dvl = float(
                        abs(x[dvl_bounds[1]] - x[dvl_bounds[0]]))
                else:
                    self.logger.warning(
                        'Cannot manually indicate DVL when running CHAP '
                        'non-interactively. Using default DVL calculation '
                        'instead.')

            if self.interactive or self.save_figures:
                # Third party modules
                import matplotlib.pyplot as plt

                fig, ax = plt.subplots()
                ax.set_title(f'Diffraction Volume ({detector.get_id()})')
                ax.set_xlabel('Beam direction (offset from scan "center")')
                ax.set_ylabel('Normalized intensity (-)')
                ax.plot(x, masked_sum, label='Sum of masked data')
                ax.plot(x, masked_max, label='Maximum of masked data')
                ax.plot(x, result.best_fit, label='Gaussian fit (to sum)')
                ax.axvspan(
                    result.best_values['center'] - 0.5*detector.dvl,
                    result.best_values['center'] + 0.5*detector.dvl,
                    color='gray', alpha=0.5,
                    label=f'diffraction volume ({self.config.measurement_mode})')
                ax.legend()
                fig.text(
                    0.5, 0.95,
                    f'Diffraction volume length: {detector.dvl:.2f}',
                    fontsize='x-large',
                    horizontalalignment='center',
                    verticalalignment='bottom')
                if self.save_figures:
                    fig.tight_layout(rect=(0, 0, 1, 0.95))
                    self._figures.append((
                        fig_to_iobuf(fig), f'{detector.get_id()}_dvl'))
                if self.interactive:
                    plt.show()
                # Close this specific figure, not whichever is current
                plt.close(fig)

        return self.config.model_dump(), PipelineData(
            name=self.__name__, data=self._figures,
            schema='common.write.ImageWriter')
792
+
793
+
794
class LatticeParameterRefinementProcessor(BaseStrainProcessor):
    """Processor to get a refined estimate for a sample's lattice
    parameters.

    :ivar config: Initialization parameters for an instance of
        CHAP.edd.models.StrainAnalysisConfig.
    :type config: dict, optional
    :ivar detector_config: Initialization parameters for an instance of
        CHAP.edd.models.MCADetectorConfig. Defaults to the detector
        configuration of the raw detector data merged with that of the
        2&theta calibration step.
    :type detector_config: dict, optional
    :ivar save_figures: Save .pngs of plots for checking inputs &
        outputs of this Processor, defaults to `False`.
    :type save_figures: bool, optional
    """
    pipeline_fields: dict = Field(
        default = {
            'config': 'edd.models.StrainAnalysisConfig',
            'detector_config': {
                'schema': 'edd.models.MCADetectorConfig',
                'merge_key_paths': {'key_path': 'detectors/id', 'type': int}},
        },
        init_var=True)
    config: Optional[StrainAnalysisConfig] = StrainAnalysisConfig()
    detector_config: MCADetectorConfig
    save_figures: Optional[bool] = False

    @model_validator(mode='before')
    @classmethod
    def validate_latticeparameterrefinementprocessor_before(cls, data):
        """Tag the detector configuration with the `'strainanalysis'`
        processor type before field validation.

        :param data: Raw initialization data.
        :type data: dict
        :return: The (possibly updated) initialization data.
        :rtype: dict
        """
        if isinstance(data, dict):
            # An explicit `detector_config=None` is treated like an
            # absent one instead of failing on item assignment
            detector_config = data.pop('detector_config', None) or {}
            detector_config['processor_type'] = 'strainanalysis'
            data['detector_config'] = detector_config
        return data

    def process(self, data):
        """Given a strain analysis configuration, return a copy
        containing refined values for the materials' lattice
        parameters.

        :param data: Input data for the lattice parameter refinement
            procedure.
        :type data: list[PipelineData]
        :raises RuntimeError: No valid NeXus input in the pipeline
            data.
        :raises ValueError: No detector with both valid raw data and
            matching calibration data.
        :return: The strain analysis configuration with the refined
            lattice parameter configuration and, optionally, a list of
            byte stream representations of Matplotlib figures.
        :rtype: dict, PipelineData
        """
        # Third party modules
        from nexusformat.nexus import (
            NXentry,
            NXroot,
        )

        # Local modules
        from CHAP.utils.general import list_to_string

        def warn_skipped(detector_ids, reason):
            """Log a single warning for the skipped detectors."""
            if len(detector_ids) == 1:
                self.logger.warning(
                    f'Skipping detector {detector_ids[0]} ({reason})')
            elif detector_ids:
                self.logger.warning(
                    'Skipping detectors '
                    f'{list_to_string([int(d) for d in detector_ids])} '
                    f'({reason})')

        # Load the pipeline input data
        try:
            nxobject = self.get_data(data)
            if isinstance(nxobject, NXroot):
                nxroot = nxobject
            elif isinstance(nxobject, NXentry):
                nxroot = NXroot()
                nxroot[nxobject.nxname] = nxobject
                nxobject.set_default()
            else:
                # Without this branch `nxroot` would be unbound below
                raise TypeError(
                    f'Unexpected pipeline input type ({type(nxobject)})')
        except Exception as exc:
            raise RuntimeError(
                'No valid input in the pipeline data') from exc

        # Load the detector data
        nxentry = self.get_default_nxentry(nxroot)
        nxdata = nxentry[nxentry.default]

        # Load the calibration configuration. The returned value is not
        # used here; NOTE(review): the call is kept presumably for its
        # schema validation of the input, confirm before removing.
        self.get_config(
            data, schema='edd.models.MCATthCalibrationConfig', remove=False)

        # Load the validated calibration detector configurations
        calibration_detector_config = self.get_data(
            data, schema='edd.models.MCATthCalibrationConfig')
        calibration_detectors = [
            MCADetectorCalibration(**d)
            for d in calibration_detector_config.get('detectors', [])]
        calibration_detector_ids = [d.get_id() for d in calibration_detectors]

        # Check for available raw detector data and for the available
        # calibration data
        if not self.detector_config.detectors:
            self.detector_config.detectors = [
                MCADetectorStrainAnalysis(
                    id=id_, processor_type='strainanalysis')
                for id_ in nxentry.detector_ids]
            self.detector_config.update_detectors()
        no_raw_data_ids = []
        no_calibration_ids = []
        detectors = []
        for detector in self.detector_config.detectors:
            detector_id = detector.get_id()
            if detector_id not in nxdata:
                no_raw_data_ids.append(detector_id)
                continue
            if detector_id not in calibration_detector_ids:
                no_calibration_ids.append(detector_id)
                continue
            raw_detector_data = nxdata[detector_id].nxdata
            if raw_detector_data.ndim != 2:
                self.logger.warning(
                    f'Skipping detector {detector_id} (Illegal data shape '
                    f'{raw_detector_data.shape})')
            elif raw_detector_data.sum():
                for k, v in nxdata[detector_id].attrs.items():
                    detector.attrs[k] = v.nxdata
                if self.config.rel_height_cutoff is not None:
                    detector.rel_height_cutoff = \
                        self.config.rel_height_cutoff
                detector.add_calibration(
                    calibration_detectors[
                        calibration_detector_ids.index(detector_id)])
                detectors.append(detector)
            else:
                self.logger.warning(
                    f'Skipping detector {detector_id} (zero intensity)')
        warn_skipped(no_raw_data_ids, 'no raw data')
        warn_skipped(no_calibration_ids, 'no calibration data')
        self.detector_config.detectors = detectors
        if not self.detector_config.detectors:
            raise ValueError('No valid data or unable to match an available '
                             'calibrated detector for the strain analysis')

        # Load the raw MCA data and compute the detector bin energies
        # and the mean spectra
        self._setup_detector_data(nxdata, strain_analysis_config=self.config)

        # Apply the energy mask
        self._apply_energy_mask()

        # Get the mask and HKLs used in the strain analysis
        self._get_mask_hkls()

        # Apply the combined energy ranges mask
        self._apply_combined_mask()

        # Get and subtract the detector baselines
        self._subtract_baselines()

        # Get the refined values for the material properties
        self._refine_lattice_parameters()

        # Return the lattice parameter refinement from visual inspection
        if self._figures:
            return (
                self.config.model_dump(),
                PipelineData(
                    name=self.__name__, data=self._figures,
                    schema='common.write.ImageWriter'))
        return self.config.model_dump()

    def _refine_lattice_parameters(self):
        """Update the strain analysis configuration with the refined
        values for the material properties.

        Each detector yields a set of adjusted materials. The refined
        lattice parameters of a material are the mean over all
        detectors of its adjusted lattice parameters. Materials keep
        the order in which they first appear.
        """
        # Local modules
        from CHAP.edd.models import MaterialConfig

        # Accumulators keyed by material name (dicts keep insertion
        # order, so the material order of the first detector is kept)
        sgnums = {}
        lattice_parameters = {}
        for i, detector in enumerate(self.detector_config.detectors):
            materials, buf = self._adjust_material_props(
                self.config.materials, i)
            for m in materials:
                sgnums.setdefault(m.material_name, m.sgnum)
                lattice_parameters.setdefault(m.material_name, []).append(
                    m.lattice_parameters)
            if self.save_figures:
                self._figures.append((
                    buf, f'{detector.get_id()}_lp_refinement_material_config'))
        # Every accumulated list holds at least one entry, so the mean
        # is always defined
        self.config.materials = [
            MaterialConfig(
                material_name=name, sgnum=sgnums[name],
                lattice_parameters=np.asarray(lat_params).mean(axis=0))
            for name, lat_params in lattice_parameters.items()]

    # """
    # Method: given
    # a scan of a material, fit the peaks of each MCA spectrum for a
    # given detector. Based on those fitted peak locations,
    # calculate the lattice parameters that would produce them.
    # Return the averaged value of the calculated lattice parameters
    # across all spectra.
    # """
    # # Get the interplanar spacings measured for each fit HKL peak
    # # at the spectrum averaged over every point in the map to get
    # # the refined estimate for the material's lattice parameter
    # uniform_fit_centers = uniform_results['centers']
    # uniform_best_fit = uniform_results['best_fits']
    # unconstrained_fit_centers = unconstrained_results['centers']
    # unconstrained_best_fit = unconstrained_results['best_fits']
    # d_uniform = get_peak_locations(
    #     np.asarray(uniform_fit_centers), detector.tth_calibrated)
    # d_unconstrained = get_peak_locations(
    #     np.asarray(unconstrained_fit_centers), detector.tth_calibrated)
    # a_uniform = float((rs * d_uniform).mean())
    # a_unconstrained = rs * d_unconstrained
    # self.logger.warning(
    #     'Lattice parameter refinement assumes cubic lattice')
    # self.logger.info(
    #     'Refined lattice parameter from uniform fit over averaged '
    #     f'spectrum: {a_uniform}')
    # self.logger.info(
    #     'Refined lattice parameter from unconstrained fit over averaged '
    #     f'spectrum: {a_unconstrained}')
    #
    # if interactive or save_figures:
    #     # Third party modules
    #     import matplotlib.pyplot as plt
    #
    #     fig, ax = plt.subplots(figsize=(11, 8.5))
    #     ax.set_title(
    #         f'Detector {detector.get_id()}: Lattice Parameter Refinement')
    #     ax.set_xlabel('Detector energy (keV)')
    #     ax.set_ylabel('Mean intensity (counts)')
    #     ax.plot(energies, mean_intensity, 'k.', label='MCA data')
    #     ax.plot(energies, uniform_best_fit, 'r', label='Best uniform fit')
    #     ax.plot(
    #         energies, unconstrained_best_fit, 'b',
    #         label='Best unconstrained fit')
    #     ax.legend()
    #     for i, loc in enumerate(peak_locations):
    #         ax.axvline(loc, c='k', ls='--')
    #         ax.text(loc, 1, str(hkls_fit[i])[1:-1],
    #                 ha='right', va='top', rotation=90,
    #                 transform=ax.get_xaxis_transform())
    #     if save_figures:
    #         fig.tight_layout()#rect=(0, 0, 1, 0.95))
    #         figfile = os.path.join(
    #             outputdir, f'{detector.get_id()}_lat_param_fits')
    #         plt.savefig(figfile)
    #         self.logger.info(f'Saved figure to {figfile}')
    #     if interactive:
    #         plt.show()
    #
    # return [
    #     a_uniform, a_uniform, a_uniform, 90., 90., 90.]
+
1068
+
1069
class MCAEnergyCalibrationProcessor(BaseEddProcessor):
    """Processor to return parameters for linearly transforming MCA
    channel indices to energies (in keV). Procedure: provide a
    spectrum from the MCA element to be calibrated and the theoretical
    location of at least one peak present in that spectrum (peak
    locations must be given in keV). It is strongly recommended to use
    the location of fluorescence peaks whenever possible, _not_
    diffraction peaks, as this Processor does not account for
    2&theta.

    :ivar config: Initialization parameters for an instance of
        CHAP.edd.models.MCAEnergyCalibrationConfig.
    :type config: dict, optional
    :ivar detector_config: Initialization parameters for an instance of
        CHAP.edd.models.MCADetectorConfig. Defaults to the detector
        configuration of the raw detector data.
    :type detector_config: dict, optional
    :ivar save_figures: Save .pngs of plots for checking inputs &
        outputs of this Processor, defaults to `False`.
    :type save_figures: bool, optional
    """
    pipeline_fields: dict = Field(
        default = {
            'config': 'edd.models.MCAEnergyCalibrationConfig',
            'detector_config': {
                'schema': ['edd.models.MCAEnergyCalibrationConfig',
                           'edd.models.MCADetectorConfig'],
                'merge_key_paths': {'key_path': 'detectors/id', 'type': int}},
        },
        init_var=True)
    config: Optional[MCAEnergyCalibrationConfig] = MCAEnergyCalibrationConfig()
    detector_config: MCADetectorConfig
    save_figures: Optional[bool] = False

    @model_validator(mode='before')
    @classmethod
    def validate_mcaenergycalibrationprocessor_before(cls, data):
        """Tag the detector configuration with the `'calibration'`
        processor type before field validation.

        :param data: Raw initialization data.
        :type data: dict
        :return: The (possibly updated) initialization data.
        :rtype: dict
        """
        if isinstance(data, dict):
            # An explicit `detector_config=None` is treated like an
            # absent one instead of failing on item assignment
            detector_config = data.pop('detector_config', None) or {}
            detector_config['processor_type'] = 'calibration'
            data['detector_config'] = detector_config
        return data

    def process(self, data):
        """For each detector in the `MCAEnergyCalibrationConfig`
        provided with `data`, fit the specified peaks in the MCA
        spectrum specified. Using the difference between the provided
        peak locations and the fit centers of those peaks, compute
        the correction coefficients to convert uncalibrated MCA
        channel energies to calibrated channel energies. For each
        detector, set `energy_calibration_coeffs` in the calibration
        config provided to these values and return the updated
        configuration.

        :param data: Energy calibration configuration.
        :type data: list[PipelineData]
        :raises ValueError: No raw data for any of the requested
            detectors, or no mask ranges available.
        :returns: Dictionary representing the energy-calibrated
            version of the calibrated configuration and a list of
            byte stream representations of Matplotlib figures.
        :rtype: dict, PipelineData
        """
        # System modules
        from json import loads

        # Load the detector data
        # FIX input a numpy and create/use NXobject to numpy proc
        # FIX right now spec info is lost in output yaml, add to it?
        nxentry = self.get_default_nxentry(self.get_data(data))

        # Check for available detectors and validate the raw detector
        # configuration
        raw_detector_config = DetectorConfig(**loads(str(nxentry.detectors)))
        raw_detector_ids = [d.get_id() for d in raw_detector_config.detectors]
        if not self.detector_config.detectors:
            self.detector_config.detectors = [
                MCADetectorCalibration(
                    **d.model_dump(), processor_type='calibration')
                for d in raw_detector_config.detectors]
            self.detector_config.update_detectors()
        else:
            skipped_detectors = []
            detectors = []
            for detector in self.detector_config.detectors:
                # First raw detector with a matching ID, if any
                raw_detector = next(
                    (raw for raw in raw_detector_config.detectors
                     if raw.get_id() == detector.get_id()), None)
                if raw_detector is None:
                    skipped_detectors.append(detector.get_id())
                    continue
                # Fill in attributes missing from the requested
                # detector configuration from the raw one
                for k, v in raw_detector.attrs.items():
                    if k not in detector.attrs:
                        detector.attrs[k] = (
                            np.asarray(v) if isinstance(v, list) else v)
                detector.energy_mask_ranges = None
                detectors.append(detector)
            if len(skipped_detectors) == 1:
                self.logger.warning(
                    f'Skipping detector {skipped_detectors[0]} '
                    '(no raw data)')
            elif skipped_detectors:
                # Local modules
                from CHAP.utils.general import list_to_string

                skipped_detectors = [int(d) for d in skipped_detectors]
                self.logger.warning(
                    'Skipping detectors '
                    f'{list_to_string(skipped_detectors)} (no raw data)')
            self.detector_config.detectors = detectors
        if not self.detector_config.detectors:
            raise ValueError(
                'No raw data for the requested calibration detectors')

        # Load the raw MCA data and compute the detector bin energies
        # and the mean spectra
        self._setup_detector_data(
            nxentry, available_detector_ids=raw_detector_ids,
            max_energy_kev=self.config.max_energy_kev)

        # Apply the flux correction
        self._apply_flux_correction()

        # Apply the energy mask
        self._apply_energy_mask()

        # Get the mask used in the energy calibration
        self._get_mask()

        # Apply the combined energy ranges mask
        self._apply_combined_mask()

        # Get and subtract the detector baselines
        self._subtract_baselines()

        # Calibrate detector channel energies based on fluorescence peaks
        self._calibrate()

        # Combine the calibration and detector configuration
        # and move default detector fields to the detector attrs.
        # `model_fields` is read from the class: instance access is
        # deprecated as of Pydantic 2.11.
        for d in self.detector_config.detectors:
            d.attrs['default_fields'] = {
                k: v.default for k, v in type(d).model_fields.items()
                if (k != 'attrs' and (k not in d.model_fields_set
                    or v.default == getattr(d, k)))}
        configs = {
            **self.config.model_dump(),
            'detectors': [d.model_dump(exclude_defaults=True)
                          for d in self.detector_config.detectors]}
        return configs, PipelineData(
            name=self.__name__, data=self._figures,
            schema='common.write.ImageWriter')

    def _get_mask(self):
        """Get the mask used in the energy calibration.

        Each detector's `mask_ranges` is set from `select_mask_1d`
        (interactively when running interactively).

        :raises ValueError: A detector ends up without mask ranges.
        """
        # Local modules
        from CHAP.utils.general import select_mask_1d

        for mean_data, detector in zip(
                self._mean_data, self.detector_config.detectors):
            # Interactively adjust the mask used in the energy
            # calibration
            buf, _, detector.mask_ranges = select_mask_1d(
                mean_data, preselected_index_ranges=detector.mask_ranges,
                title=f'Mask for detector {detector.get_id()}',
                xlabel='Detector Channel (-)',
                ylabel='Intensity (counts)',
                min_num_index_ranges=1, interactive=self.interactive,
                return_buf=self.save_figures)
            self.logger.debug(
                f'mask_ranges for detector {detector.get_id()}:'
                f' {detector.mask_ranges}')
            if self.save_figures:
                self._figures.append((
                    buf, f'{detector.get_id()}_energy_calibration_mask'))
            if not detector.mask_ranges:
                raise ValueError(
                    'No value provided for mask_ranges. Provide it in '
                    'the energy calibration configuration, or re-run '
                    'the pipeline with the interactive flag set.')

    def _calibrate(self):
        """Calibrate the energy_calibration_coeffs (a, b, and c) for
        quadratically converting each detector's MCA channels to bin
        energies.

        The selected peaks are fit in each detector's masked mean
        spectrum, and a line is fit through the theoretical peak
        energies versus the fitted peak centers (in channels). The
        coefficients are stored in place on each detector as
        `[0.0, slope, intercept]`. Nothing is returned.
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )

        # Local modules
        from CHAP.utils.fit import FitProcessor
        from CHAP.utils.general import index_nearest

        # Conversion factor from a Gaussian sigma to its FWHM,
        # 2*sqrt(2*ln(2))
        sigma_to_fwhm = 2.35482

        # Locate the reference (maximum) peak after sorting the peak
        # energies in increasing order
        max_peak_energy = self.config.peak_energies[
            self.config.max_peak_index]
        peak_energies = list(np.sort(self.config.peak_energies))
        max_peak_index = peak_energies.index(max_peak_energy)

        for energies, mask, mean_data, (low, _), detector in zip(
                self._energies, self._masks, self._mean_data,
                self._mask_index_ranges, self.detector_config.detectors):

            self.logger.info(f'Calibrating detector {detector.get_id()}')

            # Absolute channel indices; the default integer dtype avoids
            # the int16 overflow risk of the previous implementation
            bins = low + np.arange(energies.size)

            # Get the initial peak positions for fitting
            input_indices = [low + index_nearest(energies, energy)
                             for energy in peak_energies]
            buf, initial_peak_indices = self._get_initial_peak_positions(
                mean_data*np.asarray(mask).astype(np.int32), low,
                detector.mask_ranges, input_indices, max_peak_index,
                detector.get_id(), return_buf=self.save_figures)
            if self.save_figures:
                self._figures.append(
                    (buf,
                     f'{detector.get_id()}'
                     '_energy_calibration_initial_peak_positions'))

            # Construct the fit model and perform the fit
            models = []
            if detector.background is not None:
                if len(detector.background) == 1:
                    models.append(
                        {'model': detector.background[0], 'prefix': 'bkgd_'})
                else:
                    for model in detector.background:
                        models.append({'model': model, 'prefix': f'{model}_'})
            models.append(
                {'model': 'multipeak', 'centers': initial_peak_indices,
                 'centers_range': detector.centers_range,
                 'fwhm_min': detector.fwhm_min,
                 'fwhm_max': detector.fwhm_max})
            self.logger.debug('Fitting spectrum')
            fit = FitProcessor(**self.run_config)
            mean_data_fit = fit.process(
                NXdata(
                    NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')),
                {'models': models, 'method': 'trf'})

            # Extract the fit results for the peaks, ordered jointly by
            # peak center so that amplitudes, centers and widths stay
            # paired and the centers line up with the sorted energies
            best_values = mean_data_fit.best_values
            peak_params = sorted(
                (best_values[f'peak{i}_center'],
                 best_values[f'peak{i}_amplitude'],
                 best_values[f'peak{i}_sigma'])
                for i in range(1, len(initial_peak_indices)+1))
            fit_peak_indices = [center for center, _, _ in peak_params]
            fit_peak_amplitudes = [amp for _, amp, _ in peak_params]
            fit_peak_fwhms = [
                sigma_to_fwhm*sigma for _, _, sigma in peak_params]
            self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}')
            self.logger.debug(f'Fit peak center indices: {fit_peak_indices}')
            self.logger.debug(f'Fit peak fwhms: {fit_peak_fwhms}')

            # FIX for now stick with a linear energy correction
            fit = FitProcessor(**self.run_config)
            energy_fit = fit.process(
                NXdata(
                    NXfield(peak_energies, 'y'),
                    NXfield(fit_peak_indices, 'x')),
                {'models': [{'model': 'linear'}]})
            a = 0.0
            b = float(energy_fit.best_values['slope'])
            c = float(energy_fit.best_values['intercept'])
            detector.energy_calibration_coeffs = [a, b, c]
            delta_energy = self.config.max_energy_kev/detector.num_bins
            # Symmetric +/- 5% window around the nominal bin width
            if not 0.95*delta_energy < b < 1.05*delta_energy:
                self.logger.warning(
                    f'Calibrated slope ({b}) is outside the 5% tolerance '
                    'of the detector input value (max_energy_kev/num_bins '
                    f'= {delta_energy}.)')

            # Reference plot to visualize the fit results:
            if self.interactive or self.save_figures:
                # Third party modules
                import matplotlib.pyplot as plt

                bins_masked = bins[mask]
                fig, axs = plt.subplots(1, 2, figsize=(11, 4.25))
                fig.suptitle(
                    f'Detector {detector.get_id()} energy calibration')
                # Left plot: raw MCA data & best fit of peaks
                axs[0].set_title('MCA spectrum peak fit')
                axs[0].set_xlabel('Detector Channel (-)')
                axs[0].set_ylabel('Intensity (counts)')
                axs[0].plot(
                    bins_masked, mean_data[mask], 'b.', label='MCA data')
                axs[0].plot(
                    bins_masked, mean_data_fit.best_fit, 'r', label='Best fit')
                axs[0].plot(
                    bins_masked, mean_data_fit.residual, 'g', label='Residual')
                axs[0].legend()
                # Right plot: linear fit of theoretical peak energies vs
                # fit peak locations
                axs[1].set_title(
                    'Detector energy vs. detector channel')
                axs[1].set_xlabel('Detector Channel (-)')
                axs[1].set_ylabel('Detector Energy (keV)')
                axs[1].plot(
                    fit_peak_indices, peak_energies, c='b', marker='o',
                    ms=6, mfc='none', ls='', label='Fit peak positions')
                axs[1].plot(
                    bins_masked, b*bins_masked + c, 'r',
                    label=f'Best linear fit:\nm = {b:.5f} $keV$/channel\n'
                          f'b = {c:.5f} $keV$')
                axs[1].set_ylim(
                    (None, 1.2*axs[1].get_ylim()[1]-0.2*axs[1].get_ylim()[0]))
                axs[1].legend()
                ax2 = axs[1].twinx()
                ax2.set_ylabel('Residual (keV)', color='g')
                ax2.tick_params(axis='y', labelcolor='g')
                ax2.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
                ax2.plot(
                    fit_peak_indices, peak_energies-energy_fit.best_fit,
                    c='g', marker='o', ms=6, ls='', label='Residual')
                ax2.set_ylim((None, 2*ax2.get_ylim()[1]-ax2.get_ylim()[0]))
                ax2.legend()
                fig.tight_layout()

                if self.save_figures:
                    self._figures.append(
                        (fig_to_iobuf(fig),
                         f'{detector.get_id()}_energy_calibration_fit'))
                if self.interactive:
                    plt.show()
                plt.close()

    def _get_initial_peak_positions(
            self, y, low, index_ranges, input_indices, input_max_peak_index,
            detector_id, reset_flag=0, return_buf=False):
        """Get the initial peak positions (in detector channels) for
        the energy calibration peak fit.

        Non-interactively the positions come from a peak finding
        routine. Interactively the user can confirm those positions,
        or select them manually; the method recurses until exactly
        one position per input peak is available.

        :param y: Masked mean detector spectrum.
        :type y: array-like
        :param low: Channel index of the first element of `y`.
        :type low: int
        :param index_ranges: Channel index ranges of the mask.
        :type index_ranges: list[list[int]]
        :param input_indices: Expected channel index of each peak
            (from the current, uncalibrated channel energies).
        :type input_indices: list[int]
        :param input_max_peak_index: Position in `input_indices` of
            the reference peak, which is matched to the highest peak
            found.
        :type input_max_peak_index: int
        :param detector_id: Detector ID, used in the figure title.
        :type detector_id: str, optional
        :param reset_flag: Number of times the selection has been
            reset, defaults to `0`.
        :type reset_flag: int, optional
        :param return_buf: Return a byte stream of the figure,
            defaults to `False`.
        :type return_buf: bool, optional
        :raises ValueError: No peaks found when running
            non-interactively.
        :return: The figure byte stream (`None` unless `return_buf`)
            and the initial peak channel indices.
        :rtype: tuple
        """
        # Third party modules
        import matplotlib.pyplot as plt
        from matplotlib.widgets import Button

        def change_fig_title(title):
            """Change the figure title."""
            if fig_title:
                fig_title[0].remove()
                fig_title.pop()
            fig_title.append(plt.figtext(*title_pos, title, **title_props))

        def change_error_text(error=''):
            """Change the error text."""
            if error_texts:
                error_texts[0].remove()
                error_texts.pop()
            error_texts.append(plt.figtext(*error_pos, error, **error_props))

        def reset(event):
            """Callback function for the "Reset" button."""
            peak_indices.clear()
            plt.close()

        def confirm(event):
            """Callback function for the "Confirm" button."""
            if error_texts:
                error_texts[0].remove()
                error_texts.pop()
            plt.close()

        def find_peaks(min_height=0.05, min_width=5, tolerance=10):
            """Find the peaks.

            :param min_height: Minimum peak height in search, defaults
                to `0.05`.
            :type min_height: float, optional
            :param min_width: Minimum peak width in search, defaults
                to `5`.
            :type min_width: float, optional
            :param tolerance: Tolerance in peak index in channels for
                finding matching peaks, defaults to `10`.
            :type tolerance: int, optional
            :return: The peak indices (empty if no peak was found).
            :rtype: list[int]
            """
            # Third party modules
            from scipy.signal import find_peaks as find_peaks_scipy

            # Find peaks
            peaks = find_peaks_scipy(y, height=min_height,
                                     prominence=0.05*y.max(), width=min_width)
            if not peaks[0].size:
                # argmax below would fail on an empty sequence
                return []
            available_peak_indices = [low + i for i in peaks[0]]
            max_peak_index = np.asarray(peaks[1]["peak_heights"]).argmax()
            # Scale factor between observed and expected positions,
            # taken from the highest peak
            ratio = (available_peak_indices[max_peak_index]
                     / input_indices[input_max_peak_index])
            peak_indices = [-1]*num_peak
            peak_indices[input_max_peak_index] = \
                available_peak_indices.pop(max_peak_index)
            for i, input_index in enumerate(input_indices):
                if i == input_max_peak_index:
                    continue
                index_guess = int(input_index * ratio)
                min_error = np.inf
                best_index = -1
                for index in available_peak_indices:
                    error = abs(index_guess-index)
                    if error < tolerance and error < min_error:
                        best_index = index
                        min_error = error
                if best_index < 0:
                    # No found peak close enough: use the scaled guess
                    best_index = index_guess
                else:
                    available_peak_indices.remove(best_index)
                peak_indices[i] = best_index
            return peak_indices

        def select_peaks():
            """Manually select initial peak indices."""
            # A retry replaces the figure; rebind the outer names so
            # the caller draws on, shows and saves the new figure
            nonlocal fig, ax
            peak_indices = []
            while len(set(peak_indices)) < num_peak:
                error_text = ''
                change_fig_title(f'Select {num_peak} peak positions')
                peak_indices = [
                    int(pt[0]) for pt in plt.ginput(num_peak, timeout=30)]
                if len(set(peak_indices)) < num_peak:
                    error_text = f'Choose {num_peak} unique positions'
                    peak_indices.clear()
                outside_indices = []
                for index in peak_indices:
                    if not any(l <= index <= u for l, u in index_ranges):
                        outside_indices.append(index)
                if len(outside_indices) == 1:
                    error_text = \
                        f'Index {outside_indices[0]} outside of selection ' \
                        f'region ({index_ranges}), try again'
                    peak_indices.clear()
                elif outside_indices:
                    error_text = \
                        f'Indices {outside_indices} outside of selection ' \
                        'region, try again'
                    peak_indices.clear()
                if not peak_indices:
                    plt.close()
                    fig, ax = plt.subplots(figsize=(11, 8.5))
                    ax.set_xlabel('Detector Channel (-)', fontsize='x-large')
                    ax.set_ylabel('Intensity (counts)', fontsize='x-large')
                    ax.set_xlim(index_ranges[0][0], index_ranges[-1][1])
                    fig.subplots_adjust(bottom=0.0, top=0.85)
                    ax.plot(low + np.arange(y.size), y, color='k')
                    fig.subplots_adjust(bottom=0.2)
                    change_error_text(error_text)
                    plt.draw()
            return peak_indices

        peak_indices = []
        fig_title = []
        error_texts = []

        y = np.asarray(y)
        if detector_id is None:
            detector_id = ''
        elif not reset_flag:
            # On a reset the already formatted string is passed back in
            detector_id = f' on detector {detector_id}'
        num_peak = len(input_indices)

        # Setup the Matplotlib figure
        title_pos = (0.5, 0.95)
        title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                       'verticalalignment': 'bottom'}
        error_pos = (0.5, 0.90)
        error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
                       'verticalalignment': 'bottom'}
        selected_peak_props = {
            'color': 'red', 'linestyle': '-', 'linewidth': 2,
            'marker': 10, 'markersize': 10, 'fillstyle': 'full'}

        fig, ax = plt.subplots(figsize=(11, 8.5))
        ax.plot(low + np.arange(y.size), y, color='k')
        ax.set_xlabel('Detector Channel (-)', fontsize='x-large')
        ax.set_ylabel('Intensity (counts)', fontsize='x-large')
        ax.set_xlim(index_ranges[0][0], index_ranges[-1][1])
        fig.subplots_adjust(bottom=0.0, top=0.85)

        if not self.interactive:

            peak_indices += find_peaks()
            if not peak_indices:
                plt.close()
                raise ValueError(f'Unable to find any peaks{detector_id}')

            for index in peak_indices:
                ax.axvline(index, **selected_peak_props)
            change_fig_title('Initial peak positions from peak finding '
                             f'routine{detector_id}')

        else:

            fig.subplots_adjust(bottom=0.2)

            # Get initial peak indices
            if not reset_flag:
                peak_indices += find_peaks()
                change_fig_title('Initial peak positions from peak finding '
                                 f'routine{detector_id}')
            # Discard the automatic result if any peak lies outside the
            # mask
            if any(not any(l <= index <= u for l, u in index_ranges)
                   for index in peak_indices):
                peak_indices.clear()
            if not peak_indices:
                peak_indices += select_peaks()
                change_fig_title(
                    f'Selected initial peak positions{detector_id}')

            for index in peak_indices:
                ax.axvline(index, **selected_peak_props)

            # Setup "Reset" button
            if not reset_flag:
                reset_btn = Button(
                    plt.axes([0.1, 0.05, 0.15, 0.075]), 'Manually select')
            else:
                reset_btn = Button(
                    plt.axes([0.1, 0.05, 0.15, 0.075]), 'Reset')
            reset_cid = reset_btn.on_clicked(reset)

            # Setup "Confirm" button
            confirm_btn = Button(
                plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
            confirm_cid = confirm_btn.on_clicked(confirm)

            plt.show()

            # Disconnect all widget callbacks when figure is closed
            reset_btn.disconnect(reset_cid)
            confirm_btn.disconnect(confirm_cid)

            # ... and remove the buttons before returning the figure
            reset_btn.ax.remove()
            confirm_btn.ax.remove()

        if return_buf:
            fig_title[0].set_in_layout(True)
            fig.tight_layout(rect=(0, 0, 1, 0.95))
            buf = fig_to_iobuf(fig)
        else:
            buf = None
        plt.close()

        # A reset (or an incomplete selection) starts a new round
        if self.interactive and len(peak_indices) != num_peak:
            reset_flag += 1
            return self._get_initial_peak_positions(
                y, low, index_ranges, input_indices, input_max_peak_index,
                detector_id, reset_flag=reset_flag, return_buf=return_buf)
        return buf, peak_indices
1618
+
1619
+
1620
class MCATthCalibrationProcessor(BaseEddProcessor):
    """Processor to calibrate the 2&theta angle and fine tune the
    energy calibration coefficients for an EDD experimental setup.

    :ivar config: Initialization parameters for an instance of
        CHAP.edd.models.MCATthCalibrationConfig.
    :type config: dict, optional
    :ivar detector_config: Initialization parameters for an instance of
        CHAP.edd.models.MCADetectorConfig. Defaults to the detector
        configuration of the raw detector data merged with that of the
        energy calibration step.
    :type detector_config: dict, optional
    :ivar save_figures: Save .pngs of plots for checking inputs &
        outputs of this Processor, defaults to `False`.
    :type save_figures: bool, optional
    """
    pipeline_fields: dict = Field(
        default = {
            'config': ['edd.models.MCAEnergyCalibrationConfig',
                       'edd.models.MCATthCalibrationConfig'],
            'detector_config': {
                'schema': ['edd.models.MCAEnergyCalibrationConfig',
                           'edd.models.MCATthCalibrationConfig',
                           'edd.models.MCADetectorConfig'],
                'merge_key_paths': {'key_path': 'detectors/id', 'type': int}},
        },
        init_var=True)
    config: Optional[MCATthCalibrationConfig] = MCATthCalibrationConfig()
    detector_config: MCADetectorConfig
    save_figures: Optional[bool] = False

    @model_validator(mode='before')
    @classmethod
    def validate_mcatthcalibrationprocessor_before(cls, data):
        """Ensure a detector configuration exists in the input and tag
        it with `processor_type='calibration'`.

        :param data: Raw model input.
        :type data: dict
        :return: The (possibly updated) model input.
        :rtype: dict
        """
        if isinstance(data, dict):
            detector_config = data.pop('detector_config', {})
            detector_config['processor_type'] = 'calibration'
            data['detector_config'] = detector_config
        return data

    def process(self, data):
        """Return the calibrated 2&theta value and the fine tuned
        energy calibration coefficients to convert MCA channel
        indices to MCA channel energies.

        :param data: 2&theta calibration configuration.
        :type data: list[PipelineData]
        :raises RuntimeError: Invalid or missing input configuration.
        :raises ValueError: A detector is missing `energy_mask_ranges`.
        :return: Original configuration with the tuned values for
            2&theta and the linear correction parameters added and a
            list of byte stream representations of Matplotlib figures.
        :rtype: dict, PipelineData
        """
        # Third party modules
        from json import loads

        # Local modules
        from CHAP.utils.general import list_to_string

        # Load the detector data
        # FIX input a numpy and create/use NXobject to numpy proc
        # FIX right now spec info is lost in output yaml, add to it?
        nxentry = self.get_default_nxentry(self.get_data(data))

        # Check for available detectors and validate the raw detector
        # configuration
        if not self.detector_config.detectors:
            raise RuntimeError('No calibrated detectors')
        raw_detector_config = DetectorConfig(**loads(str(nxentry.detectors)))
        raw_detector_ids = [d.get_id() for d in raw_detector_config.detectors]
        skipped_detectors = []
        detectors = []
        for detector in self.detector_config.detectors:
            if detector.get_id() in raw_detector_ids:
                raw_detector = raw_detector_config.detectors[
                    raw_detector_ids.index(detector.get_id())]
                # Inherit raw detector attributes not set explicitly
                for k, v in raw_detector.attrs.items():
                    if k not in detector.attrs:
                        detector.attrs[k] = v
                detectors.append(detector)
            else:
                skipped_detectors.append(detector.get_id())
        if len(skipped_detectors) == 1:
            self.logger.warning(
                f'Skipping detector {skipped_detectors[0]} '
                '(no raw data)')
        elif skipped_detectors:
            skipped_detectors = [int(d) for d in skipped_detectors]
            self.logger.warning(
                'Skipping detectors '
                f'{list_to_string(skipped_detectors)} (no raw data)')

        # Drop detectors without an energy calibration; iterate in
        # reverse so popping by index is safe
        skipped_detectors = []
        for i, detector in reversed(list(enumerate(detectors))):
            if detector.energy_calibration_coeffs is None:
                skipped_detectors.append(detector.get_id())
                detectors.pop(i)
            else:
                detector.set_energy_calibration_mask_ranges()
                detector.tth_initial_guess = self.config.tth_initial_guess
                if detector.energy_mask_ranges is None:
                    raise ValueError('energy_mask_ranges is required for '
                                     'all detectors')
                detector.mask_ranges = None
        if len(skipped_detectors) == 1:
            self.logger.warning(
                f'Skipping detector {skipped_detectors[0]} '
                '(no calibration data)')
        elif skipped_detectors:
            skipped_detectors = [int(d) for d in skipped_detectors]
            self.logger.warning(
                'Skipping detectors '
                f'{list_to_string(skipped_detectors)} (no calibration data)')
        self.detector_config.detectors = detectors
        if not self.detector_config.detectors:
            raise RuntimeError('No raw or calibrated detectors')

        # Load the raw MCA data and compute the detector bin energies
        # and the mean spectra
        self._setup_detector_data(
            nxentry, available_detector_ids=raw_detector_ids)

        # Apply the flux correction
        self._apply_flux_correction()

        # Apply the energy mask
        self._apply_energy_mask()

        # Select the initial tth value
        if self.interactive or self.save_figures:
            self._select_tth_init()

        # Get the mask used in the energy calibration
        self._get_mask_hkls()

        # Apply the combined energy ranges mask
        self._apply_combined_mask(self.config.calibration_method)

        # Get and subtract the detector baselines
        self._subtract_baselines()

        # Calibrate detector channel energies
        self._calibrate()

        # Combine the calibration and detector configuration
        # and move default detector fields to the detector attrs
        # (model_fields is read from the class: instance access is
        # deprecated in recent pydantic versions)
        for d in self.detector_config.detectors:
            d.attrs['default_fields'] = {
                k: v.default for k, v in type(d).model_fields.items()
                if (k != 'attrs' and (k not in d.model_fields_set
                                      or v.default == getattr(d, k)))}
        configs = {
            **self.config.model_dump(),
            'detectors': [d.model_dump(exclude_defaults=True)
                          for d in self.detector_config.detectors]}
        return configs, PipelineData(
            name=self.__name__, data=self._figures,
            schema='common.write.ImageWriter')

    @staticmethod
    def _get_backgroundpeak_models_in_channels(energies, detector):
        """Return the detector's background peak models with their
        centers, center range and FWHM bounds converted from energy
        (keV) to MCA channel units, each prefixed with `bkgd_`.

        The conversion divides by `energies[1]-energies[0]`, i.e. it
        assumes a constant channel width over `energies`.

        :param energies: MCA channel energies.
        :type energies: numpy.ndarray
        :param detector: Detector configuration.
        :return: The background peak models (empty when the detector
            has no `backgroundpeaks`).
        :rtype: list
        """
        # Local modules
        from CHAP.utils.fit import FitProcessor

        if detector.backgroundpeaks is None:
            return []
        # Work on a copy so the detector configuration stays in keV
        backgroundpeaks = deepcopy(detector.backgroundpeaks)
        delta_energy = energies[1]-energies[0]
        if backgroundpeaks.centers_range is not None:
            backgroundpeaks.centers_range /= delta_energy
        if backgroundpeaks.fwhm_min is not None:
            backgroundpeaks.fwhm_min /= delta_energy
        if backgroundpeaks.fwhm_max is not None:
            backgroundpeaks.fwhm_max /= delta_energy
        backgroundpeaks.centers = [
            c/delta_energy for c in backgroundpeaks.centers]
        _, backgroundpeaks = FitProcessor.create_multipeak_model(
            backgroundpeaks)
        for peak in backgroundpeaks:
            peak.prefix = f'bkgd_{peak.prefix}'
        return list(backgroundpeaks)

    def _calibrate(self):
        """Calibrate 2&theta and fine tune the energy calibration
        coefficients that convert MCA channel indices to MCA channel
        energies, for every detector in `self.detector_config`.

        Each detector's `tth_calibrated` and
        `energy_calibration_coeffs` are updated in place. When
        `interactive` or `save_figures` is set, a summary figure of
        the fit is shown and/or appended to `self._figures`.

        :raises ValueError: Unimplemented calibration method.
        """
        # Local modules
        from CHAP.edd.utils import (
            get_peak_locations,
            get_unique_hkls_ds,
        )

        quadratic_energy_calibration = \
            self.config.quadratic_energy_calibration
        calibration_method = self.config.calibration_method

        for energies, mask, mean_data, (low, upp), detector in zip(
                self._energies, self._masks, self._mean_data,
                self._mask_index_ranges, self.detector_config.detectors):

            self.logger.info(f'Calibrating detector {detector.get_id()}')

            tth = detector.tth_initial_guess
            bins = low + np.arange(energies.size, dtype=np.int16)

            # Correct raw MCA data for variable flux at different energies
            # NOTE(review): process() also calls _apply_flux_correction()
            # before this; confirm the correction is not applied twice
            flux_correct = \
                self.config.flux_correction_interpolation_function()
            if flux_correct is not None:
                mca_intensity_weights = flux_correct(energies)
                mean_data = mean_data / mca_intensity_weights

            # Get the Bragg peak HKLs, lattice spacings and energies
            hkls, ds = get_unique_hkls_ds(
                self.config.materials, tth_max=detector.tth_max,
                tth_tol=detector.tth_tol)
            hkls = np.asarray([hkls[i] for i in detector.hkl_indices])
            ds = np.asarray([ds[i] for i in detector.hkl_indices])
            e_bragg = get_peak_locations(ds, tth)

            # Perform the fit
            t0 = time()
            if calibration_method == 'direct_fit_bragg':
                results = self._direct_bragg_peak_fit(
                    energies, mean_data, bins, mask, detector,
                    e_bragg, tth, quadratic_energy_calibration)
            elif calibration_method == 'direct_fit_tth_ecc':
                results = self._direct_fit_tth_ecc(
                    energies, mean_data, bins, mask, detector, ds, e_bragg,
                    self.config.peak_energies, tth,
                    quadratic_energy_calibration)
            else:
                raise ValueError(
                    f'calibration_method {calibration_method} not implemented')
            self.logger.info(
                f'Fitting detector {detector.get_id()} took '
                f'{time()-t0:.3f} seconds')
            detector.tth_calibrated = results['tth']
            detector.energy_calibration_coeffs = \
                results['energy_calibration_coeffs']

            # Update the peak energies and the MCA channel energies
            # with the newly calibrated coefficients
            energies = detector.energies[low:upp]
            if calibration_method == 'direct_fit_tth_ecc':
                e_bragg = get_peak_locations(ds, detector.tth_calibrated)

            if self.interactive or self.save_figures:
                # Third party modules
                import matplotlib.pyplot as plt

                # Create the figure
                fig, axs = plt.subplots(2, 2, sharex='all', figsize=(11, 8.5))
                fig.suptitle(
                    f'Detector {detector.get_id()} 'r'2$\theta$ Calibration')

                # Upper left axes: best fit with calibrated peak centers
                axs[0,0].set_title(r'2$\theta$ Calibration Fits')
                axs[0,0].set_xlabel('Energy (keV)')
                axs[0,0].set_ylabel('Intensity (counts)')
                for i, e_peak in enumerate(e_bragg):
                    axs[0,0].axvline(e_peak, c='k', ls='--')
                    axs[0,0].text(
                        e_peak, 1, str(hkls[i])[1:-1], ha='right', va='top',
                        rotation=90, transform=axs[0,0].get_xaxis_transform())
                if flux_correct is None:
                    axs[0,0].plot(
                        energies[mask], mean_data[mask], marker='.', c='C2',
                        ms=3, ls='', label='MCA data')
                else:
                    axs[0,0].plot(
                        energies[mask], mean_data[mask], marker='.', c='C2',
                        ms=3, ls='', label='Flux-corrected MCA data')
                if quadratic_energy_calibration:
                    label = 'Unconstrained fit using calibrated a, b, and c'
                else:
                    label = 'Unconstrained fit using calibrated b and c'
                axs[0,0].plot(
                    energies[mask], results['best_fit_unconstrained'], c='C1',
                    label=label)
                axs[0,0].legend()

                # Lower left axes: fit residual
                axs[1,0].set_title('Fit Residuals')
                axs[1,0].set_xlabel('Energy (keV)')
                axs[1,0].set_ylabel('Residual (counts)')
                axs[1,0].plot(
                    energies[mask], results['residual_unconstrained'], c='C1',
                    label=label)
                axs[1,0].legend()

                # Upper right axes: E vs strain for each fit
                strains_unconstrained = 1.e6 * results['strains_unconstrained']
                strain_unconstrained = np.mean(strains_unconstrained)
                axs[0,1].set_title('Peak Energy vs. Microstrain')
                axs[0,1].set_xlabel('Energy (keV)')
                axs[0,1].set_ylabel('Strain (\u03BC\u03B5)')
                axs[0,1].plot(
                    e_bragg, strains_unconstrained, marker='o', mfc='none',
                    c='C1', label='Unconstrained')
                axs[0,1].axhline(
                    strain_unconstrained, ls='--', c='C1',
                    label='Unconstrained: unweighted mean')
                axs[0,1].legend()

                # Lower right axes: theoretical E vs fitted E for all peaks
                a_fit, b_fit, c_fit = detector.energy_calibration_coeffs
                e_bragg_unconstrained = results['e_bragg_unconstrained']
                axs[1,1].set_title('Theoretical vs. Fitted Peak Energies')
                axs[1,1].set_xlabel('Energy (keV)')
                axs[1,1].set_ylabel('Energy (keV)')
                if calibration_method == 'direct_fit_tth_ecc':
                    e_fit = np.concatenate(
                        (self.config.peak_energies, e_bragg))
                    e_fit_unconstrained = np.concatenate(
                        (results['e_xrf_unconstrained'],
                         e_bragg_unconstrained))
                else:
                    e_fit = e_bragg
                    e_fit_unconstrained = e_bragg_unconstrained
                if quadratic_energy_calibration:
                    label = 'Unconstrained: quadratic fit'
                else:
                    label = 'Unconstrained: linear fit'
                # Report the calibrated (not the initial) takeoff angle.
                # Keep the number of "$" even: matplotlib only parses
                # the label as mathtext when the count is even.
                label += (f'\nTakeoff angle: {detector.tth_calibrated:.5f}'
                          r'$^\circ$')
                if quadratic_energy_calibration:
                    label += f'\na = {a_fit:.5e} $keV$/channel$^2$'
                    label += f'\nb = {b_fit:.5f} $keV$/channel'
                    label += f'\nc = {c_fit:.5f} $keV$'
                else:
                    label += f'\nm = {b_fit:.5f} $keV$/channel'
                    label += f'\nb = {c_fit:.5f} $keV$'
                axs[1,1].plot(
                    e_fit, e_fit, marker='o', mfc='none', ls='',
                    label='Theoretical peak positions')
                axs[1,1].plot(
                    e_fit, e_fit_unconstrained, c='C1', label=label)
                # Leave headroom above the data for the legend
                axs[1,1].set_ylim(
                    (None,
                     1.2*axs[1,1].get_ylim()[1]-0.2*axs[1,1].get_ylim()[0]))
                axs[1,1].legend()
                ax2 = axs[1,1].twinx()
                ax2.set_ylabel('Residual (keV)', color='g')
                ax2.tick_params(axis='y', labelcolor='g')
                ax2.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
                ax2.plot(
                    e_fit, abs(e_fit-e_fit_unconstrained), c='g', marker='o',
                    ms=6, ls='', label='Residual')
                ax2.set_ylim((None, 2*ax2.get_ylim()[1]-ax2.get_ylim()[0]))
                ax2.legend()
                fig.tight_layout()

                if self.save_figures:
                    self._figures.append((
                        fig_to_iobuf(fig),
                        f'{detector.get_id()}_tth_calibration_fit'))
                if self.interactive:
                    plt.show()
                plt.close()

    def _direct_bragg_peak_fit(
            self, energies, mean_data, bins, mask, detector, e_bragg, tth,
            quadratic_energy_calibration):
        """Perform an unconstrained fit minimizing the residual on the
        Bragg peaks only for a given 2&theta.

        The Bragg peaks are first fit in MCA channel units. Then a
        linear (or quadratic) model is fit to the theoretical peak
        energies vs the fitted peak channels to obtain the energy
        calibration coefficients.

        :return: Fit results: best fit, residual, fitted Bragg peak
            energies, strains, 2&theta (unchanged input) and the
            energy calibration coefficients `[a, b, c]`.
        :rtype: dict
        """
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )
        from scipy.signal import find_peaks as find_peaks_scipy

        # Local modules
        from CHAP.utils.fit import FitProcessor
        from CHAP.utils.general import index_nearest

        # Get initial peak centers: the detected peak nearest to each
        # theoretical Bragg peak energy
        peaks = find_peaks_scipy(
            mean_data, width=5, height=0.005*mean_data.max())
        centers = list(peaks[0])
        centers = [bins[0] + centers[index_nearest(centers, c)]
                   for c in [index_nearest(energies, e) for e in e_bragg]]

        # Construct the fit model
        models = []
        if detector.background is not None:
            background = detector.background
            # A single model name may be given as a plain string;
            # don't iterate over its characters
            if isinstance(background, str):
                background = [background]
            if len(background) == 1:
                models.append({'model': background[0], 'prefix': 'bkgd_'})
            else:
                for model in background:
                    models.append({'model': model, 'prefix': f'{model}_'})
        models += self._get_backgroundpeak_models_in_channels(
            energies, detector)
        models.append(
            {'model': 'multipeak', 'centers': centers,
             'centers_range': detector.centers_range,
             'fwhm_min': detector.fwhm_min,
             'fwhm_max': detector.fwhm_max})

        # Perform an unconstrained fit in terms of MCA bin index
        fit = FitProcessor(**self.run_config)
        result = fit.process(
            NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')),
            {'models': models, 'method': 'trf'})
        best_fit = result.best_fit
        residual = result.residual

        # Extract the Bragg peak indices from the fit
        i_bragg_fit = np.asarray(
            [result.best_values[f'peak{i+1}_center']
             for i in range(len(e_bragg))])

        # Fit a line through zero strain peak energies vs detector
        # energy bins
        model = 'quadratic' if quadratic_energy_calibration else 'linear'
        fit = FitProcessor(**self.run_config)
        result = fit.process(
            NXdata(NXfield(e_bragg, 'y'), NXfield(i_bragg_fit, 'x')),
            {'models': [{'model': model}]})
        if quadratic_energy_calibration:
            a_fit = result.best_values['a']
            b_fit = result.best_values['b']
            c_fit = result.best_values['c']
        else:
            a_fit = 0.0
            b_fit = result.best_values['slope']
            c_fit = result.best_values['intercept']
        e_bragg_unconstrained = (
            (a_fit*i_bragg_fit + b_fit) * i_bragg_fit + c_fit)

        return {
            'best_fit_unconstrained': best_fit,
            'residual_unconstrained': residual,
            'e_bragg_unconstrained': e_bragg_unconstrained,
            'strains_unconstrained': np.log(e_bragg / e_bragg_unconstrained),
            'tth': float(tth),
            'energy_calibration_coeffs': [
                float(a_fit), float(b_fit), float(c_fit)],
        }

    def _direct_fit_tth_ecc(
            self, energies, mean_data, bins, mask, detector, ds, e_bragg,
            e_xrf, tth, quadratic_energy_calibration):
        """Perform a fit minimizing the residual on the Bragg peaks
        only or on both the Bragg peaks and the fluorescence peaks
        for a given 2&theta in terms of the energy calibration
        coefficients.

        Stage 1 fits 2&theta and the calibration coefficients directly,
        with peak centers (in MCA channels) expressed through them.
        Stage 2 performs an unconstrained multipeak fit in energy
        units using the stage 1 coefficients, for diagnostics.

        :return: Fit results: best fit, residual, unconstrained XRF
            and Bragg peak energies, strains, the fitted 2&theta and
            the energy calibration coefficients `[a, b, c]`.
        :rtype: dict
        """
        # RV FIX Right now only implemented for a cubic lattice
        # Third party modules
        from nexusformat.nexus import (
            NXdata,
            NXfield,
        )
        from scipy.constants import physical_constants

        # Local modules
        from CHAP.edd.utils import get_peak_locations
        from CHAP.utils.fit import FitProcessor

        # Collect the free fit parameters
        # RV FIX Confine b to a limited range about its expect value?
        a, b, c = detector.energy_calibration_coeffs
        parameters = [{'name': 'tth', 'value': np.radians(tth)}]
        if quadratic_energy_calibration:
            parameters.append({'name': 'a', 'value': a})
        parameters.append({'name': 'b', 'value': b})
        parameters.append({'name': 'c', 'value': c})

        # Construct the fit model
        num_bragg = len(e_bragg)
        num_xrf = len(e_xrf)

        # Get the background
        bkgd_models = []
        if detector.background is not None:
            if isinstance(detector.background, str):
                bkgd_models.append(
                    {'model': detector.background, 'prefix': 'bkgd_'})
            else:
                for model in detector.background:
                    bkgd_models.append({'model': model, 'prefix': f'{model}_'})

        # Add the background peaks in MCA channels
        models = deepcopy(bkgd_models)
        models += self._get_backgroundpeak_models_in_channels(
            energies, detector)

        # Add the fluorescent peaks; convert the FWHM bounds to
        # Gaussian sigma bounds
        sig_min, sig_max = (
            np.asarray((detector.fwhm_min, detector.fwhm_max))
            / (2.0*np.sqrt(2.0*np.log(2.0))))
        for i, e_peak in enumerate(e_xrf):
            # Channel of energy e_peak from the inverted calibration
            expr = f'({e_peak}-c)/b'
            if quadratic_energy_calibration:
                expr = '(' + expr + f')*(1.0-a*(({e_peak}-c)/(b*b)))'
            models.append(
                {'model': 'gaussian', 'prefix': f'xrf{i+1}_',
                 'parameters': [
                    {'name': 'amplitude', 'min': FLOAT_MIN},
                    {'name': 'center', 'expr': expr},
                    {'name': 'sigma', 'min': sig_min, 'max': sig_max}]})

        # Add the Bragg peaks: E = hc/(2 d sin(tth/2)), with hc scaled
        # by 1e7 to match the units of ds and the peak energies
        hc = 1.e7 * physical_constants['Planck constant in eV/Hz'][0] \
            * physical_constants['speed of light in vacuum'][0]
        for i, d in enumerate(ds):
            norm = 0.5*hc/d
            expr = f'(({norm}/sin(0.5*tth))-c)/b'
            if quadratic_energy_calibration:
                expr = '(' + expr \
                    + f')*(1.0-a*((({norm}/sin(0.5*tth))-c)/(b*b)))'
            models.append(
                {'model': 'gaussian', 'prefix': f'peak{i+1}_',
                 'parameters': [
                    {'name': 'amplitude', 'min': FLOAT_MIN},
                    {'name': 'center', 'expr': expr},
                    {'name': 'sigma', 'min': sig_min, 'max': sig_max}]})

        # Perform the fit
        fit = FitProcessor(**self.run_config)
        result = fit.process(
            NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')),
            {'parameters': parameters, 'models': models, 'method': 'trf'})

        # Extract values of interest from the best values
        tth_fit = np.degrees(result.best_values['tth'])
        if quadratic_energy_calibration:
            a_fit = result.best_values['a']
        else:
            a_fit = 0.0
        b_fit = result.best_values['b']
        c_fit = result.best_values['c']
        i_peak_fit = np.asarray(
            [result.best_values[f'xrf{i+1}_center'] for i in range(num_xrf)]
            + [result.best_values[f'peak{i+1}_center']
               for i in range(num_bragg)])
        e_peak_fit = (a_fit*i_peak_fit + b_fit) * i_peak_fit + c_fit

        # Add the background peaks in keV
        if detector.backgroundpeaks is not None:
            _, backgroundpeaks = FitProcessor.create_multipeak_model(
                detector.backgroundpeaks)
            for peak in backgroundpeaks:
                peak.prefix = f'bkgd_{peak.prefix}'
            bkgd_models += backgroundpeaks

        # Get an unconstrained fit in energy units for the fitted energy
        # calibration coefficients (widths/ranges converted from
        # channels to keV via the slope b)
        models = bkgd_models + [{
            'model': 'multipeak', 'centers': list(e_peak_fit),
            'centers_range': b_fit * detector.centers_range,
            'fwhm_min': b_fit * detector.fwhm_min,
            'fwhm_max': b_fit * detector.fwhm_max}]
        fit = FitProcessor(**self.run_config)
        result = fit.process(
            NXdata(NXfield(mean_data[mask], 'y'),
                   NXfield(energies[mask], 'x')),
            {'models': models, 'method': 'trf'})
        e_xrf_unconstrained = np.sort(
            [result.best_values[f'peak{i+1}_center']
             for i in range(num_xrf)])
        e_bragg_unconstrained = np.sort(
            [result.best_values[f'peak{i+1}_center']
             for i in range(num_xrf, num_xrf+num_bragg)])

        # Update the peak energies with the newly calibrated tth
        e_bragg = get_peak_locations(ds, tth_fit)

        return {
            'best_fit_unconstrained': result.best_fit,
            'residual_unconstrained': result.residual,
            'e_xrf_unconstrained': e_xrf_unconstrained,
            'e_bragg_unconstrained': e_bragg_unconstrained,
            'strains_unconstrained': np.log(e_bragg / e_bragg_unconstrained),
            'tth': float(tth_fit),
            'energy_calibration_coeffs': [
                float(a_fit), float(b_fit), float(c_fit)],
        }

    def _select_tth_init(self):
        """Select the initial 2&theta guess from the mean MCA
        spectrum of each detector and store it in the detector's
        `tth_initial_guess`.
        """
        # Local modules
        from CHAP.edd.utils import (
            get_unique_hkls_ds,
            select_tth_initial_guess,
        )

        for energies, mean_data, detector in zip(
                self._energies, self._mean_data,
                self.detector_config.detectors):

            # Get the unique HKLs and lattice spacings for the tth
            # calibration
            hkls, ds = get_unique_hkls_ds(
                self.config.materials, tth_max=detector.tth_max,
                tth_tol=detector.tth_tol)

            detector.tth_initial_guess, buf = select_tth_initial_guess(
                energies, mean_data, hkls, ds, detector.tth_initial_guess,
                detector.get_id(), self.interactive, self.save_figures)
            if self.save_figures:
                self._figures.append((
                    buf, f'{detector.get_id()}_tth_calibration_initial_guess'))
            self.logger.debug(
                f'tth_initial_guess for detector {detector.get_id()}: '
                f'{detector.tth_initial_guess}')
+
2246
+
2247
+ class StrainAnalysisProcessor(BaseStrainProcessor):
2248
+ """Processor that takes a map of MCA data and returns a map of
2249
+ sample strains.
2250
+
2251
+ :ivar config: Initialization parameters for an instance of
2252
+ CHAP.edd.models.StrainAnalysisConfig.
2253
+ :type config: dict, optional
2254
+ :ivar detector_config: Initialization parameters for an instance of
2255
+ CHAP.edd.models.MCADetectorConfig. Defaults to the detector
2256
+ configuration of the raw detector data merged with that of the
2257
+ 2&theta calibration step..
2258
+ :ivar save_figures: Save .pngs of plots for checking inputs &
2259
+ outputs of this Processor, defaults to `False`.
2260
+ :type save_figures: bool, optional
2261
+ :ivar setup: Setup the strain analysis
2262
+ `nexusformat.nexus.NXroot` object, defaults to `True`.
2263
+ :type setup: bool, optional
2264
+ :ivar update: Perform the strain analysis and return the
2265
+ results as a list of updated points or update the result
2266
+ from the `setup` stage, defaults to `True`.
2267
+ :type update: bool, optional
2268
+ """
2269
+ pipeline_fields: dict = Field(
2270
+ default = {
2271
+ 'config': 'edd.models.StrainAnalysisConfig',
2272
+ 'detector_config': {
2273
+ 'schema': 'edd.models.MCADetectorConfig',
2274
+ 'merge_key_paths': {'key_path': 'detectors/id', 'type': int}},
2275
+ },
2276
+ init_var=True)
2277
+ config: Optional[StrainAnalysisConfig] = StrainAnalysisConfig()
2278
+ detector_config: MCADetectorConfig
2279
+ save_figures: Optional[bool] = False
2280
+ setup: Optional[bool] = True
2281
+ update: Optional[bool] = True
2282
+
2283
+ @model_validator(mode='before')
2284
+ @classmethod
2285
+ def validate_strainanalysisprocessor_before(cls, data):
2286
+ if isinstance(data, dict):
2287
+ detector_config = data.pop('detector_config', {})
2288
+ detector_config['processor_type'] = 'strainanalysis'
2289
+ data['detector_config'] = detector_config
2290
+ return data
2291
+
2292
+ @staticmethod
2293
+ def add_points(nxroot, points, logger=None):
2294
+ """Add or update the strain analysis for a set of map points
2295
+ in a `nexusformat.nexus.NXroot` object.
2296
+
2297
+ :param nxroot: The strain analysis object to add/update the
2298
+ points to.
2299
+ :type nxroot: nexusformat.nexus.NXroot
2300
+ :param points: The strain analysis results for a set of points.
2301
+ :type points: list[dict[str, object]
2302
+ """
2303
+ # Third party modules
2304
+ # pylint: disable=no-name-in-module
2305
+ from nexusformat.nexus import (
2306
+ NXdetector,
2307
+ NXprocess,
2308
+ )
2309
+ # pylint: enable=no-name-in-module
2310
+
2311
+ nxprocess = None
2312
+ for nxobject in nxroot.values():
2313
+ if isinstance(nxobject, NXprocess):
2314
+ nxprocess = nxobject
2315
+ break
2316
+ if nxprocess is None:
2317
+ raise RuntimeError('Unable to find the strainanalysis object')
2318
+
2319
+ nxdata_detectors = []
2320
+ for nxobject in nxprocess.values():
2321
+ if isinstance(nxobject, NXdetector):
2322
+ nxdata_detectors.append(nxobject.data)
2323
+ if not nxdata_detectors:
2324
+ raise RuntimeError(
2325
+ 'Unable to find detector data in strainanalysis object')
2326
+ axes = get_axes(nxdata_detectors[0], skip_axes=['energy'])
2327
+
2328
+ if axes:
2329
+ coords = np.asarray(
2330
+ [nxdata_detectors[0][a].nxdata for a in axes]).T
2331
+
2332
+ def get_matching_indices(all_coords, point_coords, decimals=None):
2333
+ if isinstance(decimals, int):
2334
+ all_coords = np.round(all_coords, decimals=decimals)
2335
+ point_coords = np.round(point_coords, decimals=decimals)
2336
+ coords_match = np.all(all_coords == point_coords, axis=1)
2337
+ index = np.where(coords_match)[0]
2338
+ return index
2339
+
2340
+ # FIX: can we round to 3 decimals right away in general?
2341
+ # FIX: assumes points contains a sorted and continous
2342
+ # slice of updates
2343
+ i_0 = get_matching_indices(
2344
+ coords,
2345
+ np.asarray([points[0][a] for a in axes]), decimals=3)[0]
2346
+ i_f = get_matching_indices(
2347
+ coords,
2348
+ np.asarray([points[-1][a] for a in axes]), decimals=3)[0]
2349
+ slices = {k: np.asarray([p[k] for p in points]) for k in points[0]}
2350
+ for k, v in slices.items():
2351
+ if k not in axes:
2352
+ logger.debug(f'Updating field {k}')
2353
+ nxprocess[k][i_0:i_f+1] = v
2354
+ else:
2355
+ for k, v in points[0].items():
2356
+ nxprocess[k].nxdata = v
2357
+
2358
+ # Add the summed intensity for each detector
2359
+ for nxdata in nxdata_detectors:
2360
+ nxdata.summed_intensity = nxdata.intensity.sum(axis=0)
2361
+
2362
def process(self, data):
    """Setup the strain analysis and/or return the strain analysis
    results as a list of updated points or a
    `nexusformat.nexus.NXroot` object.

    :param data: Input data containing configurations for a map,
        completed energy/tth calibration, and (optionally)
        parameters for the strain analysis.
    :type data: list[PipelineData]
    :raises RuntimeError: Neither setup nor update is requested, or
        no valid NeXus input is found in the pipeline data.
    :raises ValueError: No detector has both valid raw data and
        calibration data.
    :return: The strain analysis setup (`NXroot`, when `setup` is
        set) or the list of updated points (update only). If any
        figures or animations were produced, a tuple of that result
        followed by `PipelineData` objects holding the byte stream
        representations of the Matplotlib figures and/or the
        animation of the fit results.
    :rtype: Union[list[dict[str, object]],
        nexusformat.nexus.NXroot], PipelineData, PipelineData
    """
    # Third party modules
    from nexusformat.nexus import (
        NXentry,
        NXroot,
    )

    # Local modules
    from CHAP.utils.general import list_to_string

    if not (self.setup or self.update):
        raise RuntimeError('Illegal combination of setup and update')
    if not self.update:
        # No fits are performed during a setup-only run, so there is
        # nothing to plot or save
        if self.interactive:
            self.logger.warning('Interactive option disabled during setup')
            self.interactive = False
        if self.save_figures:
            self.logger.warning(
                'Saving figures option disabled during setup')
            self.save_figures = False
    # Always (re)initialize: it is read unconditionally below
    self._animation = []

    # Load the pipeline input data
    try:
        nxobject = self.get_data(data)
        if isinstance(nxobject, NXroot):
            nxroot = nxobject
        elif isinstance(nxobject, NXentry):
            nxroot = NXroot()
            nxroot[nxobject.nxname] = nxobject
            nxobject.set_default()
        else:
            raise RuntimeError
    except Exception as exc:
        raise RuntimeError(
            'No valid input in the pipeline data') from exc

    # Load the detector data
    # FIX set rel_height_cutoff
    nxentry = self.get_default_nxentry(nxroot)
    nxdata = nxentry[nxentry.default]

    # Load the validated calibration configuration
    calibration_config = self.get_config(
        data, schema='edd.models.MCATthCalibrationConfig', remove=False)

    # Load the validated calibration detector configurations
    calibration_detector_config = self.get_data(
        data, schema='edd.models.MCATthCalibrationConfig')
    calibration_detectors = [
        MCADetectorCalibration(**d)
        for d in calibration_detector_config.get('detectors', [])]
    calibration_detector_ids = [d.get_id() for d in calibration_detectors]

    # Check for available raw detector data and for the available
    # calibration data
    if not self.detector_config.detectors:
        self.detector_config.detectors = [
            MCADetectorStrainAnalysis(
                id=id_, processor_type='strainanalysis')
            for id_ in nxentry.detector_ids]
        self.detector_config.update_detectors()
    missing_data_detectors = []
    uncalibrated_detectors = []
    detectors = []
    for detector in self.detector_config.detectors:
        detector_id = detector.get_id()
        if detector_id not in nxdata:
            missing_data_detectors.append(detector_id)
        elif detector_id not in calibration_detector_ids:
            uncalibrated_detectors.append(detector_id)
        else:
            raw_detector_data = nxdata[detector_id].nxdata
            if raw_detector_data.ndim != 2:
                self.logger.warning(
                    f'Skipping detector {detector_id} (Illegal data shape '
                    f'{raw_detector_data.shape})')
            elif raw_detector_data.sum():
                for k, v in nxdata[detector_id].attrs.items():
                    detector.attrs[k] = v.nxdata
                if self.config.rel_height_cutoff is not None:
                    detector.rel_height_cutoff = \
                        self.config.rel_height_cutoff
                detector.add_calibration(
                    calibration_detectors[
                        calibration_detector_ids.index(detector_id)])
                detectors.append(detector)
            else:
                self.logger.warning(
                    f'Skipping detector {detector_id} (zero intensity)')

    def log_skipped(detector_ids, reason):
        """Log one warning for the detectors skipped for `reason`."""
        if len(detector_ids) == 1:
            self.logger.warning(
                f'Skipping detector {detector_ids[0]} ({reason})')
        elif detector_ids:
            self.logger.warning(
                'Skipping detectors '
                f'{list_to_string([int(d) for d in detector_ids])} '
                f'({reason})')

    log_skipped(missing_data_detectors, 'no raw data')
    log_skipped(uncalibrated_detectors, 'no calibration data')
    self.detector_config.detectors = detectors
    if not self.detector_config.detectors:
        raise ValueError('No valid data or unable to match an available '
                         'calibrated detector for the strain analysis')

    # Load the raw MCA data and compute the detector bin energies
    # and the mean spectra
    self._setup_detector_data(
        nxdata, strain_analysis_config=self.config, update=self.update)

    # Apply the energy mask
    self._apply_energy_mask()

    # Get the mask and HKLs used in the strain analysis
    self._get_mask_hkls()

    # Apply the combined energy ranges mask
    self._apply_combined_mask()

    # Setup and/or run the strain analysis
    points = []
    if self.update:
        points = self._strain_analysis()
    if self.setup:
        nxroot = self._get_nxroot(nxentry, calibration_config)
        if points:
            self.logger.info(f'Adding {len(points)} points')
            self.add_points(nxroot, points, logger=self.logger)
            self.logger.info('... done')
        else:
            self.logger.warning('Skip adding points')
        if not (self._figures or self._animation):
            return nxroot
        ret = [nxroot]
    else:
        if not (self._figures or self._animation):
            return points
        ret = [points]
    if self._figures:
        ret.append(
            PipelineData(
                name=self.__name__, data=self._figures,
                schema='common.write.ImageWriter'))
    if self._animation:
        ret.append(
            PipelineData(
                name=self.__name__, data=self._animation,
                schema='common.write.ImageWriter'))
    return tuple(ret)
2536
+
2537
def _add_fit_nxcollection(self, nxdetector, fit_type, hkls):
    """Attach a pre-allocated `nexusformat.nexus.NXcollection` named
    `<fit_type>_fit` to a detector group.

    The collection holds a `results` group (best fit, residual,
    reduced chi-square and success flag per map point) and, for each
    HKL, an `NXparameters` group with the initial center guess plus
    the fitted peak centers, amplitudes and sigmas and their errors.

    :param nxdetector: Detector group with a `data` NXdata child
        containing an `intensity` field.
    :type nxdetector: nexusformat.nexus.NXdetector
    :param fit_type: Fit label used in the collection name;
        `'uniform'` stores a scalar initial center guess.
    :type fit_type: str
    :param hkls: HKL indices of the fitted peaks.
    """
    # Third party modules
    # pylint: disable=no-name-in-module
    from nexusformat.nexus import (
        NXcollection,
        NXdata,
        NXfield,
        NXparameters,
    )
    # pylint: enable=no-name-in-module

    collection_name = f'{fit_type}_fit'
    nxdetector[collection_name] = NXcollection()
    fit_collection = nxdetector[collection_name]
    source_nxdata = nxdetector.data

    # Allocate the results with the same shape as the intensities
    data_shape = source_nxdata.intensity.shape
    point_shape = [data_shape[0]]

    # Full map of results
    fit_collection.results = NXdata()
    map_results = fit_collection.results
    self._linkdims(map_results, source_nxdata)
    map_results.best_fit = NXfield(shape=data_shape, dtype=np.float64)
    map_results.residual = NXfield(shape=data_shape, dtype=np.float64)
    map_results.redchi = NXfield(shape=point_shape, dtype=np.float64)
    map_results.success = NXfield(shape=point_shape, dtype='bool')

    # Per-peak quantities and the units of their values
    peak_quantities = (
        ('centers', 'keV'),
        ('amplitudes', 'counts'),
        ('sigmas', 'keV'),
    )

    # Peak-by-peak results
    for hkl in hkls:
        hkl_label = '_'.join(str(hkl)[1:-1].split(' '))
        fit_collection[hkl_label] = NXparameters()
        hkl_params = fit_collection[hkl_label]
        # Initial center guess: a scalar for the uniform fit, a
        # per-point NXdata group otherwise
        if fit_type != 'uniform':
            hkl_params.center_initial_guess = NXdata()
            self._linkdims(
                hkl_params.center_initial_guess, source_nxdata,
                skip_field_dims=['energy'])
        else:
            hkl_params.center_initial_guess = 0.0
        for quantity, units in peak_quantities:
            hkl_params[quantity] = NXdata()
            quantity_group = hkl_params[quantity]
            self._linkdims(
                quantity_group, source_nxdata, skip_field_dims=['energy'])
            quantity_group.values = NXfield(
                shape=point_shape, dtype=np.float64,
                attrs={'units': units})
            quantity_group.errors = NXfield(
                shape=point_shape, dtype=np.float64)
            quantity_group.attrs['signal'] = 'values'
2609
+
2610
def _create_animation(
        self, nxdata, energies, intensities, best_fits, detector_id):
    """Create an animation of the unconstrained fit results.

    Without `save_figures`, a `matplotlib.animation.FuncAnimation`
    steps through the map points (displayed only when `interactive`).
    With `save_figures`, every frame is rendered and appended to
    `self._figures`; those frames are then combined into a
    `matplotlib.animation.ArtistAnimation` that is appended to
    `self._animation` (tagged `'gif'`).

    :param nxdata: Detector data; its non-energy axes values label
        each frame.
    :param energies: Energies of the (masked) spectrum bins.
    :param intensities: Measured spectra, one per map point.
    :param best_fits: Fitted spectra, one per map point.
    :param detector_id: Detector ID used to name the output.
    """
    # Third party modules
    from matplotlib import animation
    import matplotlib.pyplot as plt

    def animate(i):
        """Draw map point `i` (normalized to its own maximum)."""
        data = intensities[i]
        max_ = data.max()
        norm = max(1.0, max_)
        intensity.set_ydata(data / norm)
        best_fit.set_ydata(best_fits[i] / norm)
        index.set_text('\n'.join(
            [f'norm = {int(max_)}'] +
            [f'relative norm = {(max_ / norm_all_data):.5f}'] +
            [f'{a}[{i}] = {nxdata[a][i]}' for a in axes]))
        if self.save_figures:
            self._figures.append((
                fig_to_iobuf(fig),
                os.path.join(path, f'frame_{str(i).zfill(num_digit)}')))
        return intensity, best_fit, index

    path = f'{detector_id}_strainanalysis_unconstrained_fits'

    axes = get_axes(nxdata)
    if 'energy' in axes:
        axes.remove('energy')
    norm_all_data = max(1.0, intensities.max())

    fig, ax = plt.subplots()
    data = intensities[0]
    norm = max(1.0, data.max())
    intensity, = ax.plot(energies, data / norm, 'b.', label='data')
    best_fit, = ax.plot(energies, best_fits[0] / norm, 'k-', label='fit')
    ax.set(
        title='Unconstrained Fits',
        xlabel='Energy (keV)',
        ylabel='Normalized Intensity (-)')
    ax.legend(loc='upper right')
    ax.set_ylim(-0.05, 1.05)
    index = ax.text(
        0.05, 0.95, '', transform=ax.transAxes, va='top')

    num_frame = intensities.size // intensities.shape[-1]
    num_digit = len(str(num_frame))
    if not self.save_figures:
        ani = animation.FuncAnimation(
            fig, animate, frames=num_frame, interval=1000, blit=False,
            repeat=False)
    else:
        # Render every frame into self._figures, then rebuild the
        # animation from those saved images
        start_index = len(self._figures)
        for i in range(num_frame):
            animate(i)

        plt.close()
        plt.subplots_adjust(top=1, bottom=0, left=0, right=1)

        frames = []
        for frame_index, ((buf, _), _) in enumerate(
                self._figures[start_index:]):
            buf.seek(0)
            frame = plt.imread(buf)
            im = plt.imshow(frame, animated=True)
            if not frame_index:
                # Show the first frame up front (previously this tested
                # the stale render-loop index and rarely triggered)
                plt.imshow(frame)
            frames.append([im])

        ani = animation.ArtistAnimation(
            plt.gcf(), frames, interval=1000, blit=False,
            repeat=False)

    if self.interactive:
        plt.show()

    if self.save_figures:
        self._animation.append(((ani, 'gif'), path))
    plt.close()
2692
+
2693
def _get_nxroot(self, nxentry, calibration_config):
    """Return a `nexusformat.nexus.NXroot` object initialized for
    the strain analysis.

    The root contains the input entry and an `NXprocess` group
    `<entry>_strainanalysis`. That group holds the serialized
    configurations and one `NXdetector` group per detector, with the
    masked intensities, allocated 2θ/strain fields and empty
    uniform/unconstrained fit collections.

    :param nxentry: Strain analysis map, including the raw
        MCA data.
    :type nxentry: nexusformat.nexus.NXentry
    :param calibration_config: 2&theta calibration configuration.
    :type calibration_config:
        CHAP.edd.models.MCATthCalibrationConfig
    :raises ValueError: No material is configured and the processor
        is not running interactively.
    :return: Strain analysis results & associated metadata.
    :rtype: nexusformat.nexus.NXroot
    """
    # Third party modules
    # pylint: disable=no-name-in-module
    from nexusformat.nexus import (
        NXdetector,
        NXfield,
        NXprocess,
        NXroot,
    )
    # pylint: enable=no-name-in-module

    # Local modules
    from CHAP.edd.utils import get_unique_hkls_ds
    from CHAP.utils.general import nxcopy

    if not self.interactive and not self.config.materials:
        raise ValueError(
            'No material provided. Provide a material in the '
            'StrainAnalysis Configuration, or re-run the pipeline with '
            'the --interactive flag.')

    # Create the NXroot object
    nxroot = NXroot()
    nxroot[nxentry.nxname] = nxentry
    process_name = f'{nxentry.nxname}_strainanalysis'
    nxroot[process_name] = NXprocess()
    nxprocess = nxroot[process_name]
    nxprocess.calibration_config = \
        calibration_config.model_dump_json()
    nxprocess.strain_analysis_config = \
        self.config.model_dump_json()

    # Loop over the detectors to fill in the nxprocess
    for energies, mask, nxdata, detector in zip(
            self._energies, self._masks, self._nxdata_detectors,
            self.detector_config.detectors):

        # Get the current data object
        data = nxdata.nxsignal
        num_points = data.shape[0]
        detector_id = detector.get_id()

        # Setup the NXdetector object for the current detector
        self.logger.debug(f'Setting up NXdetector group for {detector_id}')
        nxdetector = NXdetector()
        nxprocess[detector_id] = nxdetector
        nxdetector.local_name = detector_id
        nxdetector.detector_config = detector.model_dump_json()
        nxdetector.data = nxcopy(nxdata, exclude_nxpaths='detector_data')
        det_nxdata = nxdetector.data

        # Append 'energy' as the last axis. A multi-valued attribute
        # may come back as a numpy array or a copy (no in-place
        # `append`), so normalize to a list and reassign
        if 'axes' in det_nxdata.attrs:
            axes = det_nxdata.attrs['axes']
            axes = [axes] if isinstance(axes, str) else list(axes)
        else:
            axes = []
        det_nxdata.attrs['axes'] = axes + ['energy']

        det_nxdata.energy = NXfield(
            value=energies[mask], attrs={'units': 'keV'})
        det_nxdata.tth = NXfield(
            dtype=np.float64,
            shape=(num_points,),
            attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'})
        det_nxdata.uniform_strain = NXfield(
            dtype=np.float64,
            shape=(num_points,),
            attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'})
        det_nxdata.unconstrained_strain = NXfield(
            dtype=np.float64,
            shape=(num_points,),
            attrs={'long_name':
                   'Strain from unconstrained fit (\u03BC\u03B5)'})

        # Add the (masked) detector data
        det_nxdata.intensity = NXfield(
            value=np.asarray([data[i].astype(np.float64)[mask]
                              for i in range(num_points)]),
            attrs={'units': 'counts'})
        det_nxdata.attrs['signal'] = 'intensity'

        # Get the unique HKLs for the strain analysis materials
        hkls, _ = get_unique_hkls_ds(
            self.config.materials, tth_max=detector.tth_max,
            tth_tol=detector.tth_tol)

        # Get the HKLs that will be used for fitting
        hkls_fit = np.asarray([hkls[i] for i in detector.hkl_indices])

        # Add the uniform fit nxcollection
        self._add_fit_nxcollection(nxdetector, 'uniform', hkls_fit)

        # Add the unconstrained fit nxcollection
        self._add_fit_nxcollection(nxdetector, 'unconstrained', hkls_fit)

        # Fill in the 2θ value for each map point
        tth_map = detector.get_tth_map((num_points,))
        det_nxdata.tth.nxdata = tth_map

    return nxroot
2806
+
2807
def _linkdims(
        self, nxgroup, nxdata_source, add_field_dims=None,
        skip_field_dims=None, oversampling_axis=None):
    """Link the axes fields of a `nexusformat.nexus.NXdata` source
    into a `nexusformat.nexus.NXgroup` object and copy the
    corresponding axes attributes.

    :param nxgroup: Group that receives the axes links/fields.
    :param nxdata_source: Group whose `axes` and
        `unstructured_axes` attributes list the fields to link.
    :param add_field_dims: Extra axis names appended to the target
        `axes` attribute; when given, `axes` is always written.
    :type add_field_dims: Sequence[str], optional
    :param skip_field_dims: Names in `axes` that are not linked.
    :type skip_field_dims: list[str], optional
    :param oversampling_axis: Maps an axis name to oversampled
        values. Such an axis is not linked; instead a new field named
        with `fly_` replaced by `bin_` is created and the axis is
        renamed in the target's axes attributes.
    :type oversampling_axis: dict, optional
    """
    # Third party modules
    from nexusformat.nexus import NXfield
    from nexusformat.nexus.tree import NXlinkfield

    if skip_field_dims is None:
        skip_field_dims = []
    if oversampling_axis is None:
        oversampling_axis = {}

    def _attr_as_list(name):
        """Return a (possibly missing, single- or multi-valued)
        axes attribute of the source as a new list.
        """
        if name not in nxdata_source.attrs:
            return []
        value = nxdata_source.attrs[name]
        if isinstance(value, str):
            return [value]
        # Multi-valued attributes may be arrays; a list is needed for
        # concatenation and item replacement below
        return list(value)

    axes = [a for a in _attr_as_list('axes') if a not in skip_field_dims]
    unstructured_axes = _attr_as_list('unstructured_axes')

    # Iterate over a snapshot since the lists may be renamed in place
    for dim in axes + unstructured_axes:
        if dim in oversampling_axis:
            bin_name = dim.replace('fly_', 'bin_')
            # Rename the axis wherever it is listed
            if dim in axes:
                axes[axes.index(dim)] = bin_name
            if dim in unstructured_axes:
                unstructured_axes[unstructured_axes.index(dim)] = bin_name
            nxgroup[bin_name] = NXfield(
                value=oversampling_axis[dim],
                units=nxdata_source[dim].units,
                attrs={
                    'long_name':
                        f'oversampled {nxdata_source[dim].long_name}',
                    'data_type': nxdata_source[dim].data_type,
                    'local_name': 'oversampled '
                                  f'{nxdata_source[dim].local_name}'})
        else:
            if isinstance(nxdata_source[dim], NXlinkfield):
                nxgroup[dim] = nxdata_source[dim]
            else:
                nxgroup.makelink(nxdata_source[dim])
            if f'{dim}_indices' in nxdata_source.attrs:
                nxgroup.attrs[f'{dim}_indices'] = \
                    nxdata_source.attrs[f'{dim}_indices']
    if add_field_dims is None:
        if axes:
            nxgroup.attrs['axes'] = axes
    else:
        nxgroup.attrs['axes'] = axes + list(add_field_dims)
    if unstructured_axes:
        nxgroup.attrs['unstructured_axes'] = unstructured_axes
2866
+
2867
def _strain_analysis(self):
    """Perform the strain analysis on the full or partial map.

    Subtracts the detector baselines, then fits each detector's
    masked spectra with a uniform and an unconstrained peak model and
    records the fit results and derived strains for every map point.
    A detector for which no expected peak matches a detected peak is
    skipped, and the other detectors are still analyzed.

    :return: One dictionary per map point holding the map axes values
        and the fit results keyed by their path relative to the
        strain analysis `NXprocess` group, or an empty list if no
        detector could be fitted.
    :rtype: list[dict[str, object]]
    """
    # Local modules
    from CHAP.edd.utils import (
        get_peak_locations,
        get_spectra_fits,
        get_unique_hkls_ds,
    )

    # Get and subtract the detector baselines
    self._subtract_baselines()

    # Adjust the material properties
    _, buf = self._adjust_material_props(self.config.materials)
    if self.save_figures:
        self._figures.append((
            buf,
            f'{self.detector_config.detectors[0].get_id()}_'
            'strainanalysis_material_config'))

    # Setup the points list with the map axes values
    nxdata_ref = self._nxdata_detectors[0]
    axes = get_axes(nxdata_ref)
    if axes:
        points = [
            {a: nxdata_ref[a].nxdata[i] for a in axes}
            for i in range(nxdata_ref[axes[0]].size)]
    else:
        points = [{}]

    # Loop over the detectors to fill in the points
    num_fitted = 0
    for energies, mask, mean_data, nxdata, detector in zip(
            self._energies, self._masks, self._mean_data,
            self._nxdata_detectors, self.detector_config.detectors):

        detector_id = detector.get_id()
        self.logger.debug(f'Beginning strain analysis for {detector_id}')

        # Get the (masked) spectra for this detector
        intensities = nxdata.nxsignal.nxdata.T[mask].T

        # Get the unique HKLs and lattice spacings for the strain
        # analysis materials
        hkls, ds = get_unique_hkls_ds(
            self.config.materials, tth_max=detector.tth_max,
            tth_tol=detector.tth_tol)

        # Get the HKLs and lattice spacings that will be used for
        # fitting
        hkls_fit = np.asarray([hkls[i] for i in detector.hkl_indices])
        ds_fit = np.asarray([ds[i] for i in detector.hkl_indices])
        peak_locations = get_peak_locations(
            ds_fit, detector.tth_calibrated)

        # Find initial peak estimates
        if (not self.config.find_peaks
                or detector.rel_height_cutoff is None):
            use_peaks = np.ones(peak_locations.size, dtype=bool)
        else:
            # Third party modules
            from scipy.signal import find_peaks as find_peaks_scipy

            peaks = find_peaks_scipy(
                mean_data, width=5,
                height=detector.rel_height_cutoff*mean_data.max())
            widths = peaks[1]['widths']
            centers = [energies[v] for v in peaks[0]]
            use_peaks = np.zeros(peak_locations.size, dtype=bool)
            # FIX Potentially use peak_heights/widths as initial
            # values in fit?
            # Flag, for each detected peak, the first expected peak
            # location within its width; widen the window (up to four
            # attempts) until at least one match is found
            delta = energies[1] - energies[0]
            for _ in range(4):
                for width, center in zip(widths, centers):
                    for n, loc in enumerate(peak_locations):
                        # FIX Hardwired range now, use
                        # detector.centers_range?
                        if center-width*delta < loc < center+width*delta:
                            use_peaks[n] = True
                            break
                if any(use_peaks):
                    break
                delta *= 2
            if any(use_peaks):
                self.logger.debug(
                    f'Using peaks with centers at {peak_locations[use_peaks]}')
            else:
                # Skip only this detector; keep the other detectors'
                # results
                self.logger.warning(
                    'No matching peaks with heights above the threshold, '
                    f'skipping the fit for detector {detector_id}')
                continue
            hkls_fit = hkls_fit[use_peaks]

        # Perform the fit
        self.logger.info(f'Fitting detector {detector_id} ...')
        uniform_results, unconstrained_results = get_spectra_fits(
            np.squeeze(intensities), energies[mask],
            peak_locations[use_peaks], detector,
            num_proc=self.config.num_proc, **self.run_config)
        if intensities.shape[0] == 1:
            # Single spectrum: add the map point dimension so that the
            # results index as [peak][point] like the map case
            uniform_results = {k: [v] for k, v in uniform_results.items()}
            unconstrained_results = {
                k: [v] for k, v in unconstrained_results.items()}
            for field in ('centers', 'amplitudes', 'sigmas'):
                uniform_results[field] = np.asarray(
                    uniform_results[field]).T
                uniform_results[f'{field}_errors'] = np.asarray(
                    uniform_results[f'{field}_errors']).T
                unconstrained_results[field] = np.asarray(
                    unconstrained_results[field]).T
                unconstrained_results[f'{field}_errors'] = np.asarray(
                    unconstrained_results[f'{field}_errors']).T

        self.logger.info('... done')

        # Compute the strains from the nominal and fitted peak centers
        tth_map = detector.get_tth_map((nxdata.shape[0],))
        nominal_centers = np.asarray(
            [get_peak_locations(d0, tth_map)
             for d0, use_peak in zip(ds_fit, use_peaks) if use_peak])
        uniform_strains = np.log(
            nominal_centers / uniform_results['centers'])
        uniform_strain = np.mean(uniform_strains, axis=0)
        unconstrained_strains = np.log(
            nominal_centers / unconstrained_results['centers'])
        unconstrained_strain = np.mean(unconstrained_strains, axis=0)

        # Add the fit results to the list of points
        hkl_names = [
            '_'.join(str(hkl)[1:-1].split(' ')) for hkl in hkls_fit]
        for i, point in enumerate(points):
            point.update({
                f'{detector_id}/data/intensity': intensities[i],
                f'{detector_id}/data/uniform_strain':
                    uniform_strain[i],
                f'{detector_id}/data/unconstrained_strain':
                    unconstrained_strain[i],
                f'{detector_id}/uniform_fit/results/best_fit':
                    uniform_results['best_fits'][i],
                f'{detector_id}/uniform_fit/results/residual':
                    uniform_results['residuals'][i],
                f'{detector_id}/uniform_fit/results/redchi':
                    uniform_results['redchis'][i],
                f'{detector_id}/uniform_fit/results/success':
                    uniform_results['success'][i],
                f'{detector_id}/unconstrained_fit/results/best_fit':
                    unconstrained_results['best_fits'][i],
                f'{detector_id}/unconstrained_fit/results/residual':
                    unconstrained_results['residuals'][i],
                f'{detector_id}/unconstrained_fit/results/redchi':
                    unconstrained_results['redchis'][i],
                f'{detector_id}/unconstrained_fit/results/success':
                    unconstrained_results['success'][i],
            })
            for j, hkl_name in enumerate(hkl_names):
                uniform_fit_path = f'{detector_id}/uniform_fit/{hkl_name}'
                unconstrained_fit_path = \
                    f'{detector_id}/unconstrained_fit/{hkl_name}'
                point.update({
                    f'{uniform_fit_path}/centers/values':
                        uniform_results['centers'][j][i],
                    f'{uniform_fit_path}/centers/errors':
                        uniform_results['centers_errors'][j][i],
                    f'{uniform_fit_path}/amplitudes/values':
                        uniform_results['amplitudes'][j][i],
                    f'{uniform_fit_path}/amplitudes/errors':
                        uniform_results['amplitudes_errors'][j][i],
                    f'{uniform_fit_path}/sigmas/values':
                        uniform_results['sigmas'][j][i],
                    f'{uniform_fit_path}/sigmas/errors':
                        uniform_results['sigmas_errors'][j][i],
                    f'{unconstrained_fit_path}/centers/values':
                        unconstrained_results['centers'][j][i],
                    f'{unconstrained_fit_path}/centers/errors':
                        unconstrained_results['centers_errors'][j][i],
                    f'{unconstrained_fit_path}/amplitudes/values':
                        unconstrained_results['amplitudes'][j][i],
                    f'{unconstrained_fit_path}/amplitudes/errors':
                        unconstrained_results['amplitudes_errors'][j][i],
                    f'{unconstrained_fit_path}/sigmas/values':
                        unconstrained_results['sigmas'][j][i],
                    f'{unconstrained_fit_path}/sigmas/errors':
                        unconstrained_results['sigmas_errors'][j][i],
                })
        num_fitted += 1

        # Create an animation of the fit points
        if (not self.config.skip_animation
                and (self.interactive or self.save_figures)):
            self._create_animation(
                nxdata, energies[mask], intensities,
                unconstrained_results['best_fits'], detector_id)

    if not num_fitted:
        # No detector produced fit results
        return []
    return points
3063
+
3064
+
3065
if __name__ == '__main__':
    # Local modules
    import CHAP.processor

    # Run this module through CHAP.processor.main
    CHAP.processor.main()