nxs-analysis-tools 0.0.35__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nxs-analysis-tools might be problematic. Click here for more details.

@@ -14,18 +14,20 @@ from nexusformat.nexus import NXfield, NXdata, nxload, NeXusError, NXroot, NXent
14
14
  from scipy import ndimage
15
15
 
16
16
  # Specify items on which users are allowed to perform standalone imports
17
- __all__ = ['load_data', 'load_transform', 'plot_slice', 'Scissors', 'reciprocal_lattice_params', 'rotate_data',
17
+ __all__ = ['load_data', 'load_transform', 'plot_slice', 'Scissors',
18
+ 'reciprocal_lattice_params', 'rotate_data',
18
19
  'array_to_nxdata', 'Padder']
19
20
 
20
21
 
21
22
  def load_data(path):
22
23
  """
23
- Load data from a specified path.
24
+ Load data from a NeXus file at a specified path. It is assumed that the data follows the CHESS
25
+ file structure (i.e., root/entry/data/counts, etc.).
24
26
 
25
27
  Parameters
26
28
  ----------
27
29
  path : str
28
- The path to the data file.
30
+ The path to the NeXus data file.
29
31
 
30
32
  Returns
31
33
  -------
@@ -45,11 +47,11 @@ def load_data(path):
45
47
 
46
48
  def load_transform(path):
47
49
  """
48
- Load nxrefine-transformed data from a specified path.
50
+ Load data obtained from nxrefine output from a specified path.
49
51
 
50
52
  Parameters
51
53
  ----------
52
- path : str The path to the data file.
54
+ path : str The path to the transform data file.
53
55
 
54
56
  Returns
55
57
  -------
@@ -62,7 +64,8 @@ def load_transform(path):
62
64
 
63
65
  def array_to_nxdata(array, data_template, signal_name='counts'):
64
66
  """
65
- Create an NXdata object from an input array and an NXdata template, with an optional signal name.
67
+ Create an NXdata object from an input array and an NXdata template,
68
+ with an optional signal name.
66
69
 
67
70
  Parameters
68
71
  ----------
@@ -70,7 +73,8 @@ def array_to_nxdata(array, data_template, signal_name='counts'):
70
73
  The data array to be included in the NXdata object.
71
74
 
72
75
  data_template : NXdata
73
- An NXdata object serving as a template, which provides information about axes and other metadata.
76
+ An NXdata object serving as a template, which provides information
77
+ about axes and other metadata.
74
78
 
75
79
  signal_name : str, optional
76
80
  The name of the signal within the NXdata object. If not provided,
@@ -79,86 +83,107 @@ def array_to_nxdata(array, data_template, signal_name='counts'):
79
83
  Returns
80
84
  -------
81
85
  NXdata
82
- An NXdata object containing the input data array and associated axes based on the template.
86
+ An NXdata object containing the input data array and associated axes
87
+ based on the template.
83
88
  """
84
89
  d = data_template
85
- return NXdata(NXfield(array, name=signal_name), tuple([d[d.axes[i]] for i in range(len(d.axes))]))
90
+ return NXdata(NXfield(array, name=signal_name),
91
+ tuple(d[d.axes[i]] for i in range(len(d.axes))))
86
92
 
87
93
 
88
- def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew_angle=90,
89
- ax=None, xlim=None, ylim=None, xticks=None, yticks=None, cbar=True, logscale=False,
90
- symlogscale=False, cmap='viridis', linthresh=1, title=None, mdheading=None, cbartitle=None,
94
+ def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None,
95
+ skew_angle=90, ax=None, xlim=None, ylim=None,
96
+ xticks=None, yticks=None, cbar=True, logscale=False,
97
+ symlogscale=False, cmap='viridis', linthresh=1,
98
+ title=None, mdheading=None, cbartitle=None,
91
99
  **kwargs):
92
100
  """
101
+ Plot a 2D slice of the provided dataset, with optional transformations
102
+ and customizations.
103
+
93
104
  Parameters
94
105
  ----------
95
- data : :class:`nexusformat.nexus.NXdata` object or ndarray
96
- The NXdata object containing the dataset to plot.
106
+ data : :class:`nexusformat.nexus.NXdata` or ndarray
107
+ The dataset to plot. Can be an `NXdata` object or a `numpy` array.
97
108
 
98
109
  X : NXfield, optional
99
- The X axis values. Default is first axis of `data`.
110
+ The X axis values. If None, a default range from 0 to the number of
111
+ columns in `data` is used.
100
112
 
101
113
  Y : NXfield, optional
102
- The y axis values. Default is second axis of `data`.
114
+ The Y axis values. If None, a default range from 0 to the number of
115
+ rows in `data` is used.
103
116
 
104
117
  transpose : bool, optional
105
- If True, tranpose the dataset and its axes before plotting. Default is False.
118
+ If True, transpose the dataset and its axes before plotting.
119
+ Default is False.
106
120
 
107
121
  vmin : float, optional
108
- The minimum value to plot in the dataset.
109
- If not provided, the minimum of the dataset will be used.
122
+ The minimum value for the color scale. If not provided, the minimum
123
+ value of the dataset is used.
110
124
 
111
125
  vmax : float, optional
112
- The maximum value to plot in the dataset.
113
- If not provided, the maximum of the dataset will be used.
126
+ The maximum value for the color scale. If not provided, the maximum
127
+ value of the dataset is used.
114
128
 
115
129
  skew_angle : float, optional
116
- The angle to shear the plot in degrees. Defaults to 90 degrees (no skewing).
130
+ The angle in degrees to shear the plot. Default is 90 degrees (no skew).
117
131
 
118
132
  ax : matplotlib.axes.Axes, optional
119
- An optional axis object to plot the heatmap onto.
133
+ The `matplotlib` axis to plot on. If None, a new figure and axis will
134
+ be created.
120
135
 
121
136
  xlim : tuple, optional
122
- The limits of the x-axis. If not provided, the limits will be automatically set.
137
+ The limits for the x-axis. If None, the limits are set automatically
138
+ based on the data.
123
139
 
124
140
  ylim : tuple, optional
125
- The limits of the y-axis. If not provided, the limits will be automatically set.
141
+ The limits for the y-axis. If None, the limits are set automatically
142
+ based on the data.
126
143
 
127
- xticks : float, optional
128
- The major tick interval for the x-axis.
129
- If not provided, the function will use a default minor tick interval of 1.
144
+ xticks : float or list of float, optional
145
+ The major tick interval or specific tick locations for the x-axis.
146
+ Default is to use a minor tick interval of 1.
130
147
 
131
- yticks : float, optional
132
- The major tick interval for the y-axis.
133
- If not provided, the function will use a default minor tick interval of 1.
148
+ yticks : float or list of float, optional
149
+ The major tick interval or specific tick locations for the y-axis.
150
+ Default is to use a minor tick interval of 1.
134
151
 
135
152
  cbar : bool, optional
136
- Whether to include a colorbar in the plot. Defaults to True.
153
+ Whether to include a colorbar. Default is True.
137
154
 
138
155
  logscale : bool, optional
139
- Whether to use a logarithmic color scale. Defaults to False.
156
+ Whether to use a logarithmic color scale. Default is False.
140
157
 
141
158
  symlogscale : bool, optional
142
- Whether to use a symmetrical logarithmic color scale. Defaults to False.
159
+ Whether to use a symmetrical logarithmic color scale. Default is False.
143
160
 
144
161
  cmap : str or Colormap, optional
145
- The color map to use. Defaults to 'viridis'.
162
+ The colormap to use for the plot. Default is 'viridis'.
146
163
 
147
164
  linthresh : float, optional
148
- The linear threshold for the symmetrical logarithmic color scale. Defaults to 1.
165
+ The linear threshold for symmetrical logarithmic scaling. Default is 1.
166
+
167
+ title : str, optional
168
+ The title for the plot. If None, no title is set.
149
169
 
150
170
  mdheading : str, optional
151
- A string containing the Markdown heading for the plot. Default `None`.
171
+ A Markdown heading to display above the plot. If 'None' or not provided,
172
+ no heading is displayed.
173
+
174
+ cbartitle : str, optional
175
+ The title for the colorbar. If None, the colorbar label will be set to
176
+ the name of the signal.
177
+
178
+ **kwargs
179
+ Additional keyword arguments passed to `pcolormesh`.
152
180
 
153
181
  Returns
154
182
  -------
155
183
  p : :class:`matplotlib.collections.QuadMesh`
156
-
157
- A :class:`matplotlib.collections.QuadMesh` object, to mimick behavior of
158
- :class:`matplotlib.pyplot.pcolormesh`.
159
-
184
+ The `matplotlib` QuadMesh object representing the plotted data.
160
185
  """
161
- if type(data) == np.ndarray:
186
+ if isinstance(data, np.ndarray):
162
187
  if X is None:
163
188
  X = NXfield(np.linspace(0, data.shape[1], data.shape[1]), name='x')
164
189
  if Y is None:
@@ -168,7 +193,7 @@ def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew
168
193
  data = data.transpose()
169
194
  data = NXdata(NXfield(data, name='value'), (X, Y))
170
195
  data_arr = data
171
- elif type(data) == NXdata or type(data) == NXfield:
196
+ elif isinstance(data, (NXdata, NXfield)):
172
197
  if X is None:
173
198
  X = data[data.axes[0]]
174
199
  if Y is None:
@@ -178,7 +203,8 @@ def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew
178
203
  data = data.transpose()
179
204
  data_arr = data[data.signal].nxdata.transpose()
180
205
  else:
181
- raise TypeError(f"Unexpected data type: {type(data)}. Supported types are np.ndarray and NXdata.")
206
+ raise TypeError(f"Unexpected data type: {type(data)}. "
207
+ f"Supported types are np.ndarray and NXdata.")
182
208
 
183
209
  # Display Markdown heading
184
210
  if mdheading is None:
@@ -190,7 +216,6 @@ def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew
190
216
 
191
217
  # Inherit axes if user provides some
192
218
  if ax is not None:
193
- ax = ax
194
219
  fig = ax.get_figure()
195
220
  # Otherwise set up some default axes
196
221
  else:
@@ -235,7 +260,9 @@ def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew
235
260
  else:
236
261
  ymin, ymax = ax.get_ylim()
237
262
  # Use ylims to calculate translation (necessary to display axes in correct position)
238
- p.set_transform(t + Affine2D().translate(-ymin * np.sin(skew_angle_adj * np.pi / 180), 0) + ax.transData)
263
+ p.set_transform(t
264
+ + Affine2D().translate(-ymin * np.sin(skew_angle_adj * np.pi / 180), 0)
265
+ + ax.transData)
239
266
 
240
267
  # Set x limits
241
268
  if xlim is not None:
@@ -332,32 +359,32 @@ class Scissors:
332
359
  Extents of the window for integration along each axis.
333
360
  axis : int or None
334
361
  Axis along which to perform the integration.
335
- data_cut : ndarray or None
362
+ integration_volume : :class:`nexusformat.nexus.NXdata` or None
336
363
  Data array after applying the integration window.
337
364
  integrated_axes : tuple or None
338
365
  Indices of axes that were integrated.
339
366
  linecut : :class:`nexusformat.nexus.NXdata` or None
340
367
  1D linecut data after integration.
341
- window_plane_slice_obj : list or None
368
+ integration_window : tuple or None
342
369
  Slice object representing the integration window in the data array.
343
370
 
344
371
  Methods
345
372
  -------
346
373
  set_data(data)
347
- Set the input :class:`nexusformat.nexus.NXdata`
374
+ Set the input :class:`nexusformat.nexus.NXdata`.
348
375
  get_data()
349
376
  Get the input :class:`nexusformat.nexus.NXdata`.
350
377
  set_center(center)
351
378
  Set the central coordinate for the linecut.
352
- set_window(window)
379
+ set_window(window, axis=None, verbose=False)
353
380
  Set the extents of the integration window.
354
381
  get_window()
355
382
  Get the extents of the integration window.
356
- cut_data(axis=None)
383
+ cut_data(center=None, window=None, axis=None, verbose=False)
357
384
  Reduce data to a 1D linecut using the integration window.
358
- show_integration_window(label=None)
385
+ highlight_integration_window(data=None, label=None, highlight_color='red', **kwargs)
359
386
  Plot the integration window highlighted on a 2D heatmap of the full dataset.
360
- plot_window()
387
+ plot_integration_window(**kwargs)
361
388
  Plot a 2D heatmap of the integration window data.
362
389
  """
363
390
 
@@ -378,8 +405,8 @@ class Scissors:
378
405
  """
379
406
 
380
407
  self.data = data
381
- self.center = tuple([float(i) for i in center]) if center is not None else None
382
- self.window = tuple([float(i) for i in window]) if window is not None else None
408
+ self.center = tuple(float(i) for i in center) if center is not None else None
409
+ self.window = tuple(float(i) for i in window) if window is not None else None
383
410
  self.axis = axis
384
411
 
385
412
  self.integration_volume = None
@@ -418,9 +445,9 @@ class Scissors:
418
445
  center : tuple
419
446
  Central coordinate around which to perform the linecut.
420
447
  """
421
- self.center = tuple([float(i) for i in center]) if center is not None else None
448
+ self.center = tuple(float(i) for i in center) if center is not None else None
422
449
 
423
- def set_window(self, window):
450
+ def set_window(self, window, axis=None, verbose=False):
424
451
  """
425
452
  Set the extents of the integration window.
426
453
 
@@ -428,16 +455,25 @@ class Scissors:
428
455
  ----------
429
456
  window : tuple
430
457
  Extents of the window for integration along each axis.
458
+ axis : int or None, optional
459
+ The axis along which to perform the linecut. If not specified, the value from the
460
+ object's attribute will be used.
461
+ verbose : bool, optional
462
+ Enables printout of linecut axis and integrated axes. Default False.
463
+
431
464
  """
432
- self.window = tuple([float(i) for i in window]) if window is not None else None
465
+ self.window = tuple(float(i) for i in window) if window is not None else None
433
466
 
434
467
  # Determine the axis for integration
435
- self.axis = window.index(max(window))
436
- print("Linecut axis: " + str(self.data.axes[self.axis]))
468
+ self.axis = window.index(max(window)) if axis is None else axis
437
469
 
438
470
  # Determine the integrated axes (axes other than the integration axis)
439
471
  self.integrated_axes = tuple(i for i in range(self.data.ndim) if i != self.axis)
440
- print("Integrated axes: " + str([self.data.axes[axis] for axis in self.integrated_axes]))
472
+
473
+ if verbose:
474
+ print("Linecut axis: " + str(self.data.axes[self.axis]))
475
+ print("Integrated axes: " + str([self.data.axes[axis]
476
+ for axis in self.integrated_axes]))
441
477
 
442
478
  def get_window(self):
443
479
  """
@@ -450,13 +486,13 @@ class Scissors:
450
486
  """
451
487
  return self.window
452
488
 
453
- def cut_data(self, center=None, window=None, axis=None):
489
+ def cut_data(self, center=None, window=None, axis=None, verbose=False):
454
490
  """
455
- Reduces data to a 1D linecut with integration extents specified by the window about a central
456
- coordinate.
491
+ Reduces data to a 1D linecut with integration extents specified by the
492
+ window about a central coordinate.
457
493
 
458
494
  Parameters
459
- -----------
495
+ ----------
460
496
  center : float or None, optional
461
497
  Central coordinate for the linecut. If not specified, the value from the object's
462
498
  attribute will be used.
@@ -466,11 +502,14 @@ class Scissors:
466
502
  axis : int or None, optional
467
503
  The axis along which to perform the linecut. If not specified, the value from the
468
504
  object's attribute will be used.
505
+ verbose : bool
506
+ Enables printout of linecut axis and integrated axes. Default False.
469
507
 
470
508
  Returns
471
- --------
509
+ -------
472
510
  integrated_data : :class:`nexusformat.nexus.NXdata`
473
511
  1D linecut data after integration.
512
+
474
513
  """
475
514
 
476
515
  # Extract necessary attributes from the object
@@ -478,8 +517,7 @@ class Scissors:
478
517
  center = center if center is not None else self.center
479
518
  self.set_center(center)
480
519
  window = window if window is not None else self.window
481
- self.set_window(window)
482
- axis = axis if axis is not None else self.axis
520
+ self.set_window(window, axis, verbose)
483
521
 
484
522
  # Convert the center to a tuple of floats
485
523
  center = tuple(float(c) for c in center)
@@ -500,7 +538,7 @@ class Scissors:
500
538
 
501
539
  # Create an NXdata object for the linecut data
502
540
  self.linecut = NXdata(NXfield(integrated_data, name=self.integration_volume.signal),
503
- self.integration_volume[self.integration_volume.axes[axis]])
541
+ self.integration_volume[self.integration_volume.axes[self.axis]])
504
542
  self.linecut.nxname = self.integration_volume.nxname
505
543
 
506
544
  return self.linecut
@@ -526,7 +564,6 @@ class Scissors:
526
564
  data = self.data if data is None else data
527
565
  center = self.center
528
566
  window = self.window
529
- integrated_axes = self.integrated_axes
530
567
 
531
568
  # Create a figure and subplots
532
569
  fig, axes = plt.subplots(1, 3, figsize=(15, 4))
@@ -545,7 +582,8 @@ class Scissors:
545
582
  (center[0] - window[0],
546
583
  center[1] - window[1]),
547
584
  2 * window[0], 2 * window[1],
548
- linewidth=1, edgecolor=highlight_color, facecolor='none', transform=p1.get_transform(), label=label,
585
+ linewidth=1, edgecolor=highlight_color,
586
+ facecolor='none', transform=p1.get_transform(), label=label,
549
587
  )
550
588
  ax.add_patch(rect_diffuse)
551
589
 
@@ -563,7 +601,8 @@ class Scissors:
563
601
  (center[0] - window[0],
564
602
  center[2] - window[2]),
565
603
  2 * window[0], 2 * window[2],
566
- linewidth=1, edgecolor=highlight_color, facecolor='none', transform=p2.get_transform(), label=label,
604
+ linewidth=1, edgecolor=highlight_color,
605
+ facecolor='none', transform=p2.get_transform(), label=label,
567
606
  )
568
607
  ax.add_patch(rect_diffuse)
569
608
 
@@ -581,7 +620,8 @@ class Scissors:
581
620
  (center[1] - window[1],
582
621
  center[2] - window[2]),
583
622
  2 * window[1], 2 * window[2],
584
- linewidth=1, edgecolor=highlight_color, facecolor='none', transform=p3.get_transform(), label=label,
623
+ linewidth=1, edgecolor=highlight_color,
624
+ facecolor='none', transform=p3.get_transform(), label=label,
585
625
  )
586
626
  ax.add_patch(rect_diffuse)
587
627
 
@@ -605,10 +645,7 @@ class Scissors:
605
645
  Additional keyword arguments to customize the plot.
606
646
  """
607
647
  data = self.integration_volume
608
- axis = self.axis
609
648
  center = self.center
610
- window = self.window
611
- integrated_axes = self.integrated_axes
612
649
 
613
650
  fig, axes = plt.subplots(1, 3, figsize=(15, 4))
614
651
 
@@ -651,6 +688,23 @@ class Scissors:
651
688
 
652
689
 
653
690
  def reciprocal_lattice_params(lattice_params):
691
+ """
692
+ Calculate the reciprocal lattice parameters from the given direct lattice parameters.
693
+
694
+ Parameters
695
+ ----------
696
+ lattice_params : tuple
697
+ A tuple containing the direct lattice parameters (a, b, c, alpha, beta, gamma), where
698
+ a, b, and c are the magnitudes of the lattice vectors, and alpha, beta, and gamma are the
699
+ angles between them in degrees.
700
+
701
+ Returns
702
+ -------
703
+ tuple
704
+ A tuple containing the reciprocal lattice parameters (a*, b*, c*, alpha*, beta*, gamma*),
705
+ where a*, b*, and c* are the magnitudes of the reciprocal lattice vectors, and alpha*,
706
+ beta*, and gamma* are the angles between them in degrees.
707
+ """
654
708
  a_mag, b_mag, c_mag, alpha, beta, gamma = lattice_params
655
709
  # Convert angles to radians
656
710
  alpha = np.deg2rad(alpha)
@@ -659,17 +713,20 @@ def reciprocal_lattice_params(lattice_params):
659
713
 
660
714
  # Calculate unit cell volume
661
715
  V = a_mag * b_mag * c_mag * np.sqrt(
662
- 1 - np.cos(alpha) ** 2 - np.cos(beta) ** 2 - np.cos(gamma) ** 2 + 2 * np.cos(alpha) * np.cos(beta) * np.cos(
663
- gamma)
716
+ 1 - np.cos(alpha) ** 2 - np.cos(beta) ** 2 - np.cos(gamma) ** 2
717
+ + 2 * np.cos(alpha) * np.cos(beta) * np.cos(gamma)
664
718
  )
665
719
 
666
720
  # Calculate reciprocal lattice parameters
667
721
  a_star = (b_mag * c_mag * np.sin(alpha)) / V
668
722
  b_star = (a_mag * c_mag * np.sin(beta)) / V
669
723
  c_star = (a_mag * b_mag * np.sin(gamma)) / V
670
- alpha_star = np.rad2deg(np.arccos((np.cos(beta) * np.cos(gamma) - np.cos(alpha)) / (np.sin(beta) * np.sin(gamma))))
671
- beta_star = np.rad2deg(np.arccos((np.cos(alpha) * np.cos(gamma) - np.cos(beta)) / (np.sin(alpha) * np.sin(gamma))))
672
- gamma_star = np.rad2deg(np.arccos((np.cos(alpha) * np.cos(beta) - np.cos(gamma)) / (np.sin(alpha) * np.sin(beta))))
724
+ alpha_star = np.rad2deg(np.arccos((np.cos(beta) * np.cos(gamma) - np.cos(alpha))
725
+ / (np.sin(beta) * np.sin(gamma))))
726
+ beta_star = np.rad2deg(np.arccos((np.cos(alpha) * np.cos(gamma) - np.cos(beta))
727
+ / (np.sin(alpha) * np.sin(gamma))))
728
+ gamma_star = np.rad2deg(np.arccos((np.cos(alpha) * np.cos(beta) - np.cos(gamma))
729
+ / (np.sin(alpha) * np.sin(beta))))
673
730
 
674
731
  return a_star, b_star, c_star, alpha_star, beta_star, gamma_star
675
732
 
@@ -689,9 +746,10 @@ def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=Fal
689
746
  rotation_axis : int
690
747
  Axis of rotation (0, 1, or 2).
691
748
  printout : bool, optional
692
- Enables printout of rotation progress. If set to True, information about each rotation slice will be printed
693
- to the console, indicating the axis being rotated and the corresponding
694
- coordinate value. Defaults to False.
749
+ Enables printout of rotation progress. If set to True, information
750
+ about each rotation slice will be printed to the console, indicating
751
+ the axis being rotated and the corresponding coordinate value.
752
+ Defaults to False.
695
753
 
696
754
 
697
755
  Returns
@@ -714,7 +772,8 @@ def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=Fal
714
772
 
715
773
  for i in range(len(data[data.axes[rotation_axis]])):
716
774
  if printout:
717
- print(f'\rRotating {data.axes[rotation_axis]}={data[data.axes[rotation_axis]][i]}... ',
775
+ print(f'\rRotating {data.axes[rotation_axis]}'
776
+ f'={data[data.axes[rotation_axis]][i]}... ',
718
777
  end='', flush=True)
719
778
  # Identify current slice
720
779
  if rotation_axis == 0:
@@ -727,12 +786,14 @@ def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=Fal
727
786
  sliced_data = None
728
787
 
729
788
  p = Padder(sliced_data)
730
- padding = tuple([len(sliced_data[axis]) for axis in sliced_data.axes])
789
+ padding = tuple(len(sliced_data[axis]) for axis in sliced_data.axes)
731
790
  counts = p.pad(padding).counts
732
791
 
733
792
  counts_skewed = ndimage.affine_transform(counts,
734
793
  t.inverted().get_matrix()[:2, :2],
735
- offset=[counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180), 0],
794
+ offset=[counts.shape[0] / 2
795
+ * np.sin(skew_angle_adj * np.pi / 180),
796
+ 0],
736
797
  order=0,
737
798
  )
738
799
  scale1 = np.cos(skew_angle_adj * np.pi / 180)
@@ -751,15 +812,18 @@ def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=Fal
751
812
  counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)
752
813
 
753
814
  counts_unscaled2 = ndimage.affine_transform(counts_rotated,
754
- Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
815
+ Affine2D().scale(
816
+ scale2, 1
817
+ ).inverted().get_matrix()[:2, :2],
755
818
  offset=[-(1 - scale2) * counts.shape[
756
819
  0] / 2 / scale2, 0],
757
820
  order=0,
758
821
  )
759
822
 
760
823
  counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
761
- Affine2D().scale(scale1,
762
- 1).inverted().get_matrix()[:2, :2],
824
+ Affine2D().scale(
825
+ scale1, 1
826
+ ).inverted().get_matrix()[:2, :2],
763
827
  offset=[-(1 - scale1) * counts.shape[
764
828
  0] / 2 / scale1, 0],
765
829
  order=0,
@@ -768,7 +832,8 @@ def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=Fal
768
832
  counts_unskewed = ndimage.affine_transform(counts_unscaled1,
769
833
  t.get_matrix()[:2, :2],
770
834
  offset=[
771
- (-counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180)),
835
+ (-counts.shape[0] / 2
836
+ * np.sin(skew_angle_adj * np.pi / 180)),
772
837
  0],
773
838
  order=0,
774
839
  )
@@ -786,9 +851,10 @@ def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=Fal
786
851
  return NXdata(NXfield(output_array, name='counts'),
787
852
  (data[data.axes[0]], data[data.axes[1]], data[data.axes[2]]))
788
853
 
854
+
789
855
  def rotate_data2D(data, lattice_angle, rotation_angle):
790
856
  """
791
- Rotates 3D data around a specified axis.
857
+ Rotates 2D data.
792
858
 
793
859
  Parameters
794
860
  ----------
@@ -797,7 +863,7 @@ def rotate_data2D(data, lattice_angle, rotation_angle):
797
863
  lattice_angle : float
798
864
  Angle between the two in-plane lattice axes in degrees.
799
865
  rotation_angle : float
800
- Angle of rotation in degrees..
866
+ Angle of rotation in degrees.
801
867
 
802
868
 
803
869
  Returns
@@ -817,12 +883,13 @@ def rotate_data2D(data, lattice_angle, rotation_angle):
817
883
  t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()
818
884
 
819
885
  p = Padder(data)
820
- padding = tuple([len(data[axis]) for axis in data.axes])
886
+ padding = tuple(len(data[axis]) for axis in data.axes)
821
887
  counts = p.pad(padding).counts
822
888
 
823
889
  counts_skewed = ndimage.affine_transform(counts,
824
890
  t.inverted().get_matrix()[:2, :2],
825
- offset=[counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180), 0],
891
+ offset=[counts.shape[0] / 2
892
+ * np.sin(skew_angle_adj * np.pi / 180), 0],
826
893
  order=0,
827
894
  )
828
895
  scale1 = np.cos(skew_angle_adj * np.pi / 180)
@@ -841,15 +908,18 @@ def rotate_data2D(data, lattice_angle, rotation_angle):
841
908
  counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)
842
909
 
843
910
  counts_unscaled2 = ndimage.affine_transform(counts_rotated,
844
- Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
911
+ Affine2D().scale(
912
+ scale2, 1
913
+ ).inverted().get_matrix()[:2, :2],
845
914
  offset=[-(1 - scale2) * counts.shape[
846
915
  0] / 2 / scale2, 0],
847
916
  order=0,
848
917
  )
849
918
 
850
919
  counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
851
- Affine2D().scale(scale1,
852
- 1).inverted().get_matrix()[:2, :2],
920
+ Affine2D().scale(
921
+ scale1, 1
922
+ ).inverted().get_matrix()[:2, :2],
853
923
  offset=[-(1 - scale1) * counts.shape[
854
924
  0] / 2 / scale1, 0],
855
925
  order=0,
@@ -858,7 +928,8 @@ def rotate_data2D(data, lattice_angle, rotation_angle):
858
928
  counts_unskewed = ndimage.affine_transform(counts_unscaled1,
859
929
  t.get_matrix()[:2, :2],
860
930
  offset=[
861
- (-counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180)),
931
+ (-counts.shape[0] / 2
932
+ * np.sin(skew_angle_adj * np.pi / 180)),
862
933
  0],
863
934
  order=0,
864
935
  )
@@ -870,20 +941,44 @@ def rotate_data2D(data, lattice_angle, rotation_angle):
870
941
  (data[data.axes[0]], data[data.axes[1]]))
871
942
 
872
943
 
873
- class Padder():
944
+ class Padder:
874
945
  """
875
- A class to pad and unpad datasets with a symmetric region of zeros.
946
+ A class to symmetrically pad and unpad datasets with a region of zeros.
947
+
948
+ Attributes
949
+ ----------
950
+ data : NXdata or None
951
+ The input data to be padded.
952
+ padded : NXdata or None
953
+ The padded data with symmetric zero padding.
954
+ padding : tuple or None
955
+ The number of zero-value pixels added along each edge of the array.
956
+ steps : tuple or None
957
+ The step sizes along each axis of the dataset.
958
+ maxes : tuple or None
959
+ The maximum values along each axis of the dataset.
960
+
961
+ Methods
962
+ -------
963
+ set_data(data)
964
+ Set the input data for padding.
965
+ pad(padding)
966
+ Symmetrically pads the data with zero values.
967
+ save(fout_name=None)
968
+ Saves the padded dataset to a .nxs file.
969
+ unpad(data)
970
+ Removes the padded region from the data.
876
971
  """
877
972
 
878
973
  def __init__(self, data=None):
879
974
  """
880
- Initialize the Symmetrizer3D object.
975
+ Initialize the Padder object.
881
976
 
882
977
  Parameters
883
978
  ----------
884
979
  data : NXdata, optional
885
- The input data to be symmetrized. If provided, the `set_data` method is called to set the data.
886
-
980
+ The input data to be padded. If provided, the `set_data` method
981
+ is called to set the data.
887
982
  """
888
983
  self.padded = None
889
984
  self.padding = None
@@ -892,20 +987,21 @@ class Padder():
892
987
 
893
988
  def set_data(self, data):
894
989
  """
895
- Set the input data for symmetrization.
990
+ Set the input data for padding.
896
991
 
897
992
  Parameters
898
993
  ----------
899
994
  data : NXdata
900
- The input data to be symmetrized.
901
-
995
+ The input data to be padded.
902
996
  """
903
997
  self.data = data
904
998
 
905
- self.steps = tuple([(data[axis].nxdata[1] - data[axis].nxdata[0]) for axis in data.axes])
999
+ self.steps = tuple((data[axis].nxdata[1] - data[axis].nxdata[0])
1000
+ for axis in data.axes)
906
1001
 
907
- # Absolute value of the maximum value; assumes the domain of the input is symmetric (eg, -H_min = H_max)
908
- self.maxes = tuple([data[axis].nxdata.max() for axis in data.axes])
1002
+ # Absolute value of the maximum value; assumes the domain of the input
1003
+ # is symmetric (eg, -H_min = H_max)
1004
+ self.maxes = tuple(data[axis].nxdata.max() for axis in data.axes)
909
1005
 
910
1006
  def pad(self, padding):
911
1007
  """
@@ -915,11 +1011,17 @@ class Padder():
915
1011
  ----------
916
1012
  padding : tuple
917
1013
  The number of zero-value pixels to add along each edge of the array.
1014
+
1015
+ Returns
1016
+ -------
1017
+ NXdata
1018
+ The padded data with symmetric zero padding.
918
1019
  """
919
1020
  data = self.data
920
1021
  self.padding = padding
921
1022
 
922
- padded_shape = tuple([data[data.signal].nxdata.shape[i] + self.padding[i] * 2 for i in range(data.ndim)])
1023
+ padded_shape = tuple(data[data.signal].nxdata.shape[i]
1024
+ + self.padding[i] * 2 for i in range(data.ndim))
923
1025
 
924
1026
  # Create padded dataset
925
1027
  padded = np.zeros(padded_shape)
@@ -930,12 +1032,13 @@ class Padder():
930
1032
  slice_obj = tuple(slice_obj)
931
1033
  padded[slice_obj] = data[data.signal].nxdata
932
1034
 
933
- padmaxes = tuple([self.maxes[i] + self.padding[i] * self.steps[i] for i in range(data.ndim)])
1035
+ padmaxes = tuple(self.maxes[i] + self.padding[i] * self.steps[i]
1036
+ for i in range(data.ndim))
934
1037
 
935
1038
  padded = NXdata(NXfield(padded, name=data.signal),
936
- tuple([NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
937
- name=data.axes[i])
938
- for i in range(data.ndim)]))
1039
+ tuple(NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
1040
+ name=data.axes[i])
1041
+ for i in range(data.ndim)))
939
1042
 
940
1043
  self.padded = padded
941
1044
  return padded
@@ -974,16 +1077,38 @@ class Padder():
974
1077
  -------
975
1078
  ndarray or NXdata
976
1079
  The unpadded data, with the symmetric padding region removed.
977
-
978
- Notes
979
- -----
980
- This method removes the symmetric padding region that was added using the `pad` method. It returns the data
981
- without the padded region.
982
-
983
-
984
1080
  """
985
1081
  slice_obj = [slice(None)] * data.ndim
986
1082
  for i in range(data.ndim):
987
1083
  slice_obj[i] = slice(self.padding[i], -self.padding[i], None)
988
1084
  slice_obj = tuple(slice_obj)
989
1085
  return data[slice_obj]
1086
+
1087
+
1088
+ def load_discus_nxs(path):
1089
+ """
1090
+ Load .nxs format data from the DISCUS program (by T. Proffen and R. Neder)
1091
+ and convert it to the CHESS format.
1092
+
1093
+ Parameters
1094
+ ----------
1095
+ path : str
1096
+ The file path to the .nxs file generated by DISCUS.
1097
+
1098
+ Returns
1099
+ -------
1100
+ NXdata
1101
+ The data converted to the CHESS format, with axes labeled 'H', 'K', and 'L',
1102
+ and the signal labeled 'counts'.
1103
+
1104
+ """
1105
+ filename = path
1106
+ root = nxload(filename)
1107
+ hlim, klim, llim = root.lower_limits
1108
+ hstep, kstep, lstep = root.step_sizes
1109
+ h = NXfield(np.linspace(hlim, -hlim, int(np.abs(hlim * 2) / hstep) + 1), name='H')
1110
+ k = NXfield(np.linspace(klim, -klim, int(np.abs(klim * 2) / kstep) + 1), name='K')
1111
+ l = NXfield(np.linspace(llim, -llim, int(np.abs(llim * 2) / lstep) + 1), name='L')
1112
+ data = NXdata(NXfield(root.data[:, :, :], name='counts'), (h, k, l))
1113
+
1114
+ return data