nxs-analysis-tools 0.0.46__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nxs-analysis-tools might be problematic. Click here for more details.

@@ -1,1132 +1,1360 @@
1
- """
2
- Reduces scattering data into 2D and 1D datasets.
3
- """
4
- import os
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- from matplotlib.transforms import Affine2D
8
- from matplotlib.markers import MarkerStyle
9
- from matplotlib.ticker import MultipleLocator
10
- from matplotlib import colors
11
- from matplotlib import patches
12
- from IPython.display import display, Markdown
13
- from nexusformat.nexus import NXfield, NXdata, nxload, NeXusError, NXroot, NXentry, nxsave
14
- from scipy import ndimage
15
-
16
# Specify items on which users are allowed to perform standalone imports.
# NOTE: every entry needs its own trailing comma; adjacent string literals
# without one are silently concatenated into a single (nonexistent) name.
__all__ = [
    'load_data',
    'load_transform',
    'plot_slice',
    'Scissors',
    'reciprocal_lattice_params',
    'rotate_data',
    'rotate_data_2D',
    'array_to_nxdata',
    'Padder',
]
20
-
21
-
22
def load_data(path, print_tree=True):
    """
    Load the ``entry/data`` group of a NeXus file.

    The file is assumed to follow the CHESS layout, i.e. the data of interest
    lives at ``root/entry/data``.

    Parameters
    ----------
    path : str
        The path to the NeXus data file.

    print_tree : bool, optional
        Whether to print the data tree upon loading. Default True.

    Returns
    -------
    data : nxdata object
        The ``entry.data`` group of the loaded file.
    """
    root = nxload(path)

    if print_tree:
        try:
            print(root.entry.data.tree)
        except NeXusError:
            # Printing is best-effort only; a failure here must not stop
            # the load. The return below is deliberately left unguarded.
            pass

    return root.entry.data
49
-
50
-
51
def load_transform(path, print_tree=True):
    """
    Load the ``entry/transform`` group of a file written by nxrefine.

    The stored signal array is transposed with ``transpose(2, 1, 0)`` and
    wrapped in a new NXdata named ``counts`` whose axes are, in order,
    ``Qh``, ``Qk`` and ``Ql``.

    Parameters
    ----------
    path : str
        The path to the transform data file.

    print_tree : bool, optional
        Whether to print the data tree upon loading. Default True.

    Returns
    -------
    data : nxdata object
        The re-ordered data stored in a nxdata object.
    """
    root = nxload(path)

    counts = NXfield(root.entry.transform.data.nxdata.transpose(2, 1, 0),
                     name='counts')
    axes = (root.entry.transform.Qh,
            root.entry.transform.Qk,
            root.entry.transform.Ql)
    data = NXdata(counts, axes)

    if print_tree:
        print(data.tree)

    return data
77
-
78
-
79
def array_to_nxdata(array, data_template, signal_name=None):
    """
    Wrap an array in an NXdata object that reuses the axes of a template.

    Parameters
    ----------
    array : array-like
        The data array to be included in the NXdata object.

    data_template : NXdata
        An NXdata object whose axes (and, by default, signal name) are
        copied onto the new object.

    signal_name : str, optional
        The name of the signal within the NXdata object. If not provided,
        the signal name is inherited from the data_template.

    Returns
    -------
    NXdata
        An NXdata object holding `array` as its signal together with the
        axes of `data_template`.
    """
    name = data_template.signal if signal_name is None else signal_name
    template_axes = tuple(data_template[axis_name]
                          for axis_name in data_template.axes)
    return NXdata(NXfield(array, name=name), template_axes)
108
-
109
-
110
def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None,
               skew_angle=90, ax=None, xlim=None, ylim=None,
               xticks=None, yticks=None, cbar=True, logscale=False,
               symlogscale=False, cmap='viridis', linthresh=1,
               title=None, mdheading=None, cbartitle=None,
               **kwargs):
    """
    Plot a 2D slice of the provided dataset, with optional transformations
    and customizations.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata` or ndarray
        The dataset to plot. Can be an `NXdata` object or a `numpy` array.

    X : NXfield, optional
        The X axis values. If None, the first axis of an `NXdata` input is
        used; for an array input, ``data.shape[0]`` evenly spaced values
        from 0 to ``data.shape[0]`` are generated.

    Y : NXfield, optional
        The Y axis values. If None, the second axis of an `NXdata` input is
        used; for an array input, ``data.shape[1]`` evenly spaced values
        from 0 to ``data.shape[1]`` are generated.

    transpose : bool, optional
        If True, transpose the dataset and its axes before plotting.
        Default is False.

    vmin : float, optional
        The minimum value for the color scale. If not provided, the minimum
        value of the dataset is used.

    vmax : float, optional
        The maximum value for the color scale. If not provided, the maximum
        value of the dataset is used.

    skew_angle : float, optional
        The angle in degrees to shear the plot. Default is 90 degrees (no skew).

    ax : matplotlib.axes.Axes, optional
        The `matplotlib` axis to plot on. If None, a new figure is created
        with a single axis filling it.

    xlim : tuple, optional
        The limits for the x-axis. If None, the current limits of the axis
        are used. In both cases one limit is extended to make room for the
        sheared data.

    ylim : tuple, optional
        The limits for the y-axis. If None, the limits are set automatically
        based on the data.

    xticks : float, optional
        The major tick interval for the x-axis (passed to
        `MultipleLocator`). If None, the major ticks are left untouched.
        Minor ticks are placed with an interval of 1 in either case.

    yticks : float, optional
        The major tick interval for the y-axis (passed to
        `MultipleLocator`). If None, the major ticks are left untouched.
        Minor ticks are placed with an interval of 1 in either case.

    cbar : bool, optional
        Whether to include a colorbar. Default is True.

    logscale : bool, optional
        Whether to use a logarithmic color scale. Default is False.

    symlogscale : bool, optional
        Whether to use a symmetrical logarithmic color scale, spanning
        ``-vmax`` to ``vmax``. Takes precedence over `logscale`.
        Default is False.

    cmap : str or Colormap, optional
        The colormap to use for the plot. Default is 'viridis'.

    linthresh : float, optional
        The linear threshold for symmetrical logarithmic scaling. Default is 1.

    title : str, optional
        The title for the plot. If None, no title is set.

    mdheading : str, optional
        A Markdown heading to display above the plot. If None (default), no
        heading is displayed. The string ``"None"`` displays a bare
        "Figure" heading; any other string is appended to "Figure - ".

    cbartitle : str, optional
        The title for the colorbar. If None, the colorbar label will be set to
        the name of the signal.

    **kwargs
        Additional keyword arguments passed to `pcolormesh`.

    Returns
    -------
    p : :class:`matplotlib.collections.QuadMesh`
        The `matplotlib` QuadMesh object representing the plotted data.

    Raises
    ------
    TypeError
        If `data` is neither a numpy array nor an NXdata/NXfield.
    """
    if isinstance(data, np.ndarray):
        if X is None:
            X = NXfield(np.linspace(0, data.shape[0], data.shape[0]), name='x')
        if Y is None:
            Y = NXfield(np.linspace(0, data.shape[1], data.shape[1]), name='y')
        if transpose:
            X, Y = Y, X
            data = data.transpose()
        data = NXdata(NXfield(data, name='value'), (X, Y))
        # pcolormesh expects the array indexed as [y, x]
        data_arr = data[data.signal].nxdata.transpose()
    elif isinstance(data, (NXdata, NXfield)):
        if X is None:
            X = data[data.axes[0]]
        if Y is None:
            Y = data[data.axes[1]]
        if transpose:
            X, Y = Y, X
            data = data.transpose()
        # pcolormesh expects the array indexed as [y, x]
        data_arr = data[data.signal].nxdata.transpose()
    else:
        raise TypeError(f"Unexpected data type: {type(data)}. "
                        f"Supported types are np.ndarray and NXdata.")

    # Display Markdown heading
    if mdheading is None:
        pass
    elif mdheading == "None":
        display(Markdown('### Figure'))
    else:
        display(Markdown('### Figure - ' + mdheading))

    # Inherit axes if user provides some
    if ax is not None:
        fig = ax.get_figure()
    # Otherwise set up some default axes
    else:
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])

    # If limits not provided, use extrema
    if vmin is None:
        vmin = data_arr.min()
    if vmax is None:
        vmax = data_arr.max()

    # Set norm (linear scale, logscale, or symlogscale)
    if symlogscale:
        norm = colors.SymLogNorm(linthresh=linthresh, vmin=-1 * vmax, vmax=vmax)
    elif logscale:
        norm = colors.LogNorm(vmin=vmin, vmax=vmax)
    else:
        norm = colors.Normalize(vmin=vmin, vmax=vmax)

    # Plot data
    p = ax.pcolormesh(X.nxdata, Y.nxdata, data_arr, shading='auto', norm=norm, cmap=cmap, **kwargs)

    ## Transform data to new coordinate system if necessary
    # Correct skew angle
    skew_angle_adj = 90 - skew_angle
    skew_rad = skew_angle_adj * np.pi / 180
    # Create blank 2D affine transformation
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, np.cos(skew_rad))
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, np.cos(skew_rad)).inverted()
    ## Correct for x-displacement after shearing
    # If ylims provided, use those
    if ylim is not None:
        # Set ylims
        ax.set(ylim=ylim)
        ymin, ymax = ylim
    # Else, use current ylims
    else:
        ymin, ymax = ax.get_ylim()
    # Use ylims to calculate translation (necessary to display axes in correct position)
    p.set_transform(t
                    + Affine2D().translate(-ymin * np.sin(skew_rad), 0)
                    + ax.transData)

    # Set x limits
    if xlim is not None:
        xmin, xmax = xlim
    else:
        xmin, xmax = ax.get_xlim()
    # Widen the x-range on the side the sheared data leans towards
    if skew_angle <= 90:
        ax.set(xlim=(xmin, xmax + (ymax - ymin) / np.tan((90 - skew_angle_adj) * np.pi / 180)))
    else:
        ax.set(xlim=(xmin - (ymax - ymin) / np.tan((skew_angle_adj - 90) * np.pi / 180), xmax))

    # Correct aspect ratio for the x/y axes after transformation
    ax.set(aspect=np.cos(skew_rad))

    # Add tick marks all around
    ax.tick_params(direction='in', top=True, right=True, which='both')

    # Set tick locations: user-provided major interval (if any), plus
    # minor ticks with an interval of 1
    if xticks is not None:
        ax.xaxis.set_major_locator(MultipleLocator(xticks))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    if yticks is not None:
        ax.yaxis.set_major_locator(MultipleLocator(yticks))
    ax.yaxis.set_minor_locator(MultipleLocator(1))

    # Apply transform to tick marks. Tick lines alternate between the two
    # sides of the axis: even index = bottom, odd index = top.
    for i, line in enumerate(ax.xaxis.get_majorticklines()):
        # Tick marker
        m = MarkerStyle(3)
        if i % 2:
            # Top ticks (translation here makes their direction="in")
            m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
            # The tick marks are skewed in place, which can sometimes lead
            # to them being misaligned. Re-setting the unchanged transform
            # is kept from the original implementation.
            line.set_transform(line.get_transform())
        else:
            # Bottom ticks
            m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))

        line.set_marker(m)

    for i, line in enumerate(ax.xaxis.get_minorticklines()):
        m = MarkerStyle(2)
        if i % 2:
            m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
        else:
            m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))

        line.set_marker(m)

    if cbar:
        colorbar = fig.colorbar(p)
        # Fall back to the signal name only when no explicit title is given
        colorbar.set_label(data.signal if cbartitle is None else cbartitle)

    ax.set(
        xlabel=X.nxname,
        ylabel=Y.nxname,
    )

    if title is not None:
        ax.set_title(title)

    # Return the quadmesh object
    return p
361
-
362
-
363
class Scissors:
    """
    Scissors class provides functionality for reducing data to a 1D linecut using an integration
    window.

    Attributes
    ----------
    data : :class:`nexusformat.nexus.NXdata` or None
        Input :class:`nexusformat.nexus.NXdata`.
    center : tuple or None
        Central coordinate around which to perform the linecut.
    window : tuple or None
        Extents of the window for integration along each axis.
    axis : int or None
        Axis along which to perform the linecut.
    integration_volume : :class:`nexusformat.nexus.NXdata` or None
        Data array after applying the integration window.
    integrated_axes : tuple or None
        Indices of axes that were integrated.
    linecut : :class:`nexusformat.nexus.NXdata` or None
        1D linecut data after integration.
    integration_window : tuple or None
        Tuple of slice objects representing the integration window in the data array.

    Methods
    -------
    set_data(data)
        Set the input :class:`nexusformat.nexus.NXdata`.
    get_data()
        Get the input :class:`nexusformat.nexus.NXdata`.
    set_center(center)
        Set the central coordinate for the linecut.
    set_window(window, axis=None, verbose=False)
        Set the extents of the integration window.
    get_window()
        Get the extents of the integration window.
    cut_data(center=None, window=None, axis=None, verbose=False)
        Reduce data to a 1D linecut using the integration window.
    highlight_integration_window(data=None, label=None, highlight_color='red', **kwargs)
        Plot the integration window highlighted on a 2D heatmap of the full dataset.
    plot_integration_window(**kwargs)
        Plot a 2D heatmap of the integration window data.
    """

    # (x axis, y axis, fixed axis) for each of the three principal cross
    # sections, in the order in which they are plotted and returned.
    _CROSS_SECTIONS = ((0, 1, 2), (0, 2, 1), (1, 2, 0))

    def __init__(self, data=None, center=None, window=None, axis=None):
        """
        Initializes a Scissors object.

        Parameters
        ----------
        data : :class:`nexusformat.nexus.NXdata` or None, optional
            Input NXdata. Default is None.
        center : tuple or None, optional
            Central coordinate around which to perform the linecut. Default is None.
        window : tuple or None, optional
            Extents of the window for integration along each axis. Default is None.
        axis : int or None, optional
            Axis along which to perform the linecut. Default is None.
        """

        self.data = data
        self.center = tuple(float(i) for i in center) if center is not None else None
        self.window = tuple(float(i) for i in window) if window is not None else None
        self.axis = axis

        self.integration_volume = None
        self.integrated_axes = None
        self.linecut = None
        self.integration_window = None

    def set_data(self, data):
        """
        Set the input NXdata.

        Parameters
        ----------
        data : :class:`nexusformat.nexus.NXdata`
            Input data array.
        """
        self.data = data

    def get_data(self):
        """
        Get the input data.

        Returns
        -------
        :class:`nexusformat.nexus.NXdata` or None
            The object previously stored via the constructor or `set_data`.
        """
        return self.data

    def set_center(self, center):
        """
        Set the central coordinate for the linecut.

        Parameters
        ----------
        center : tuple
            Central coordinate around which to perform the linecut. Each
            entry is converted to float.
        """
        self.center = tuple(float(i) for i in center) if center is not None else None

    def set_window(self, window, axis=None, verbose=False):
        """
        Set the extents of the integration window.

        Parameters
        ----------
        window : iterable of float
            Extents of the window for integration along each axis. Any
            iterable of numbers is accepted (tuple, list, numpy array, ...);
            it is stored as a tuple of floats.
        axis : int or None, optional
            The axis along which to perform the linecut. If not specified,
            the axis with the largest window extent is used (the first one
            in case of a tie).
        verbose : bool, optional
            Enables printout of linecut axis and integrated axes. Default False.

        """
        self.window = tuple(float(i) for i in window) if window is not None else None

        # Determine the linecut axis. The normalized tuple is searched
        # rather than the raw argument so that inputs without an `.index`
        # method (e.g. numpy arrays) work as well.
        self.axis = self.window.index(max(self.window)) if axis is None else axis

        # Determine the integrated axes (axes other than the linecut axis)
        self.integrated_axes = tuple(i for i in range(self.data.ndim) if i != self.axis)

        if verbose:
            print("Linecut axis: " + str(self.data.axes[self.axis]))
            print("Integrated axes: " + str([self.data.axes[axis]
                                             for axis in self.integrated_axes]))

    def get_window(self):
        """
        Get the extents of the integration window.

        Returns
        -------
        tuple or None
            Extents of the integration window.
        """
        return self.window

    def cut_data(self, center=None, window=None, axis=None, verbose=False):
        """
        Reduces data to a 1D linecut with integration extents specified by the
        window about a central coordinate.

        Parameters
        ----------
        center : tuple or None, optional
            Central coordinate for the linecut. If not specified, the value from the object's
            attribute will be used.
        window : tuple or None, optional
            Integration window extents around the central coordinate. If not specified, the value
            from the object's attribute will be used.
        axis : int or None, optional
            The axis along which to perform the linecut. If not specified,
            the axis with the largest window extent is used (see
            `set_window`).
        verbose : bool
            Enables printout of linecut axis and integrated axes. Default False.

        Returns
        -------
        integrated_data : :class:`nexusformat.nexus.NXdata`
            1D linecut data after integration.

        """

        # Resolve arguments against the stored attributes and store them
        data = self.data
        center = center if center is not None else self.center
        self.set_center(center)
        window = window if window is not None else self.window
        self.set_window(window, axis, verbose)

        # Work with the normalized (tuple of float) values from here on
        center = self.center
        window = self.window

        # Calculate the start and stop coordinates for slicing the data
        start = np.subtract(center, window)
        stop = np.add(center, window)
        slice_obj = tuple(slice(s, e) for s, e in zip(start, stop))
        self.integration_window = slice_obj

        # Perform the data cut
        self.integration_volume = data[slice_obj]
        self.integration_volume.nxname = data.nxname

        # Perform integration along the integrated axes
        integrated_data = np.sum(self.integration_volume[self.integration_volume.signal].nxdata,
                                 axis=self.integrated_axes)

        # Create an NXdata object for the linecut data
        self.linecut = NXdata(NXfield(integrated_data, name=self.integration_volume.signal),
                              self.integration_volume[self.integration_volume.axes[self.axis]])
        self.linecut.nxname = self.integration_volume.nxname

        return self.linecut

    def highlight_integration_window(self, data=None, label=None, highlight_color='red', **kwargs):
        """
        Plots the integration window highlighted on the three principal cross sections of a
        dataset, each taken through the stored center.

        Parameters
        ----------
        data : :class:`nexusformat.nexus.NXdata`, optional
            The dataset to plot. If not provided, the dataset stored in `self.data` will
            be used.
        label : str, optional
            The label for the integration window plot. If given, a legend is added to each
            subplot.
        highlight_color : str, optional
            The edge color used to highlight the integration window. Default is 'red'.
        **kwargs : keyword arguments, optional
            Additional keyword arguments passed on to `plot_slice`.

        Returns
        -------
        tuple
            The three objects returned by `plot_slice`, in the order
            (axes 0-1, axes 0-2, axes 1-2).
        """
        data = self.data if data is None else data
        center = self.center
        window = self.window

        # Create a figure and subplots
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        meshes = []
        for ax, (ix, iy, ifixed) in zip(axes, self._CROSS_SECTIONS):
            # Cross section through the center along the fixed axis
            slice_obj = [slice(None)] * data.ndim
            slice_obj[ifixed] = center[ifixed]

            mesh = plot_slice(data[slice_obj],
                              X=data[data.axes[ix]],
                              Y=data[data.axes[iy]],
                              ax=ax,
                              **kwargs)
            # Outline of the integration window, drawn with the same
            # transform as the plotted data
            rect_diffuse = patches.Rectangle(
                (center[ix] - window[ix],
                 center[iy] - window[iy]),
                2 * window[ix], 2 * window[iy],
                linewidth=1, edgecolor=highlight_color,
                facecolor='none', transform=mesh.get_transform(), label=label,
            )
            ax.add_patch(rect_diffuse)
            meshes.append(mesh)

        # Adjust subplot padding
        fig.subplots_adjust(wspace=0.5)

        if label is not None:
            for ax in axes:
                ax.legend()

        plt.show()

        return tuple(meshes)

    def plot_integration_window(self, **kwargs):
        """
        Plots the three principal cross-sections of the integration volume on a single figure.

        Parameters
        ----------
        **kwargs : keyword arguments, optional
            Additional keyword arguments passed on to `plot_slice`.

        Returns
        -------
        tuple
            The three objects returned by `plot_slice`, in the order
            (axes 0-1, axes 0-2, axes 1-2).
        """
        data = self.integration_volume
        center = self.center

        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        meshes = []
        for ax, (ix, iy, ifixed) in zip(axes, self._CROSS_SECTIONS):
            # Cross section through the center along the fixed axis
            slice_obj = [slice(None)] * data.ndim
            slice_obj[ifixed] = center[ifixed]
            mesh = plot_slice(data[slice_obj],
                              X=data[data.axes[ix]],
                              Y=data[data.axes[iy]],
                              ax=ax,
                              **kwargs)
            # Aspect ratio from the number of points along each plotted axis
            ax.set_aspect(len(data[data.axes[ix]].nxdata) / len(data[data.axes[iy]].nxdata))
            meshes.append(mesh)

        # Adjust subplot padding
        fig.subplots_adjust(wspace=0.3)

        plt.show()

        return tuple(meshes)
704
-
705
-
706
def reciprocal_lattice_params(lattice_params):
    """
    Convert direct lattice parameters into reciprocal lattice parameters.

    No factor of 2*pi is applied: the reciprocal lengths are computed as
    (product of the other two direct lengths) * sin(enclosed angle) / V.

    Parameters
    ----------
    lattice_params : tuple
        The direct lattice parameters (a, b, c, alpha, beta, gamma), where
        a, b, and c are the magnitudes of the lattice vectors, and alpha,
        beta, and gamma are the angles between them in degrees.

    Returns
    -------
    tuple
        The reciprocal lattice parameters (a*, b*, c*, alpha*, beta*, gamma*),
        where a*, b*, and c* are the magnitudes of the reciprocal lattice
        vectors, and alpha*, beta*, and gamma* are the angles between them
        in degrees.
    """
    a_mag, b_mag, c_mag, alpha, beta, gamma = lattice_params

    # Work in radians and evaluate each trigonometric term once
    alpha = np.deg2rad(alpha)
    beta = np.deg2rad(beta)
    gamma = np.deg2rad(gamma)
    cos_a, cos_b, cos_g = np.cos(alpha), np.cos(beta), np.cos(gamma)
    sin_a, sin_b, sin_g = np.sin(alpha), np.sin(beta), np.sin(gamma)

    # Unit cell volume
    V = a_mag * b_mag * c_mag * np.sqrt(
        1 - cos_a ** 2 - cos_b ** 2 - cos_g ** 2
        + 2 * cos_a * cos_b * cos_g
    )

    # Reciprocal lengths
    a_star = (b_mag * c_mag * sin_a) / V
    b_star = (a_mag * c_mag * sin_b) / V
    c_star = (a_mag * b_mag * sin_g) / V

    def _star_angle(numerator, denominator):
        # Reciprocal angle in degrees from its cosine
        return np.rad2deg(np.arccos(numerator / denominator))

    alpha_star = _star_angle(cos_b * cos_g - cos_a, sin_b * sin_g)
    beta_star = _star_angle(cos_a * cos_g - cos_b, sin_a * sin_g)
    gamma_star = _star_angle(cos_a * cos_b - cos_g, sin_a * sin_b)

    return a_star, b_star, c_star, alpha_star, beta_star, gamma_star
748
-
749
-
750
def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=False):
    """
    Rotates 3D data around a specified axis.

    The data are processed one 2D slice at a time along `rotation_axis`.
    Each slice is padded, skewed and rescaled, rotated by `rotation_angle`,
    and then taken back through the inverse rescaling/skew and unpadded.
    All resampling steps use ``order=0``.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data. It is indexed with three indices, so it must be 3D.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.
    rotation_axis : int
        Axis of rotation (0, 1, or 2).
    printout : bool, optional
        Enables printout of rotation progress. If set to True, information
        about each rotation slice will be printed to the console, indicating
        the axis being rotated and the corresponding coordinate value.
        Defaults to False. A final "Done." line is printed regardless.

    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object, on the axes of the input.

    Raises
    ------
    ValueError
        If `rotation_axis` is not 0, 1, or 2.
    """
    # Fail early with a clear message rather than part-way through the
    # slice loop
    if rotation_axis not in (0, 1, 2):
        raise ValueError(f"rotation_axis must be 0, 1, or 2, got {rotation_axis!r}")

    # Define output array
    output_array = np.zeros(data[data.signal].shape)

    # Define transformation
    skew_angle_adj = 90 - lattice_angle
    skew_sin = np.sin(skew_angle_adj * np.pi / 180)
    scale1 = np.cos(skew_angle_adj * np.pi / 180)
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, scale1)
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, scale1).inverted()

    # Matrices that do not depend on the slice are built once
    skew_matrix = t.inverted().get_matrix()[:2, :2]
    unskew_matrix = t.get_matrix()[:2, :2]
    scale1_matrix = Affine2D().scale(scale1, 1).get_matrix()[:2, :2]
    unscale1_matrix = Affine2D().scale(scale1, 1).inverted().get_matrix()[:2, :2]

    axis_name = data.axes[rotation_axis]

    for i in range(len(data[axis_name])):
        if printout:
            print(f'\rRotating {axis_name}'
                  f'={data[axis_name][i]}... ',
                  end='', flush=True)

        # Index selecting the current slice: position i along the rotation
        # axis, everything along the other two axes
        index = [slice(None)] * 3
        index[rotation_axis] = i
        index = tuple(index)
        sliced_data = data[index]

        # Pad the slice by its own size along each axis
        p = Padder(sliced_data)
        padding = tuple(len(sliced_data[axis]) for axis in sliced_data.axes)
        p.pad(padding)
        counts = p.padded[p.padded.signal]
        n_rows = counts.shape[0]

        counts_skewed = ndimage.affine_transform(counts,
                                                 skew_matrix,
                                                 offset=[n_rows / 2 * skew_sin, 0],
                                                 order=0,
                                                 )
        counts_scaled1 = ndimage.affine_transform(counts_skewed,
                                                  scale1_matrix,
                                                  offset=[(1 - scale1) * n_rows / 2, 0],
                                                  order=0,
                                                  )
        scale2 = n_rows / counts.shape[1]
        counts_scaled2 = ndimage.affine_transform(counts_scaled1,
                                                  Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
                                                  offset=[(1 - scale2) * n_rows / 2, 0],
                                                  order=0,
                                                  )

        counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)

        # Undo the scalings and the skew in reverse order
        counts_unscaled2 = ndimage.affine_transform(counts_rotated,
                                                    Affine2D().scale(
                                                        scale2, 1
                                                    ).inverted().get_matrix()[:2, :2],
                                                    offset=[-(1 - scale2) * n_rows / 2 / scale2, 0],
                                                    order=0,
                                                    )

        counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
                                                    unscale1_matrix,
                                                    offset=[-(1 - scale1) * n_rows / 2 / scale1, 0],
                                                    order=0,
                                                    )

        counts_unskewed = ndimage.affine_transform(counts_unscaled1,
                                                   unskew_matrix,
                                                   offset=[(-n_rows / 2 * skew_sin), 0],
                                                   order=0,
                                                   )

        counts_unpadded = p.unpad(counts_unskewed)

        # Write current slice
        output_array[index] = counts_unpadded

    print('\nDone.')
    return NXdata(NXfield(output_array, name=p.padded.signal),
                  (data[data.axes[0]], data[data.axes[1]], data[data.axes[2]]))
870
-
871
-
872
def rotate_data_2D(data, lattice_angle, rotation_angle):
    """
    Rotates 2D data.

    The data are padded, skewed and rescaled, rotated by `rotation_angle`,
    and then taken back through the inverse rescaling/skew and unpadded.
    All resampling steps use ``order=0``. A "Done." line is printed at the
    end.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.


    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object, on the axes of the input.
    """

    def _warp(image, matrix, shift):
        # Resample `image` through a 2x2 matrix with an offset along the
        # first array axis only
        return ndimage.affine_transform(image, matrix, offset=[shift, 0], order=0)

    # Define transformation: scale y to preserve the norm while shearing,
    # shear along x, then return to the original y scaling
    skew_angle_adj = 90 - lattice_angle
    skew_sin = np.sin(skew_angle_adj * np.pi / 180)
    scale1 = np.cos(skew_angle_adj * np.pi / 180)
    t = (Affine2D()
         + Affine2D().scale(1, scale1)
         + Affine2D().skew_deg(skew_angle_adj, 0)
         + Affine2D().scale(1, scale1).inverted())

    # Pad the data by its own size along each axis
    padder = Padder(data)
    pad_widths = tuple(len(data[axis_name]) for axis_name in data.axes)
    padder.pad(pad_widths)
    counts = padder.padded[padder.padded.signal]
    n_rows = counts.shape[0]
    scale2 = n_rows / counts.shape[1]

    # Forward pass: skew, then two rescalings along the first array axis
    skewed = _warp(counts,
                   t.inverted().get_matrix()[:2, :2],
                   n_rows / 2 * skew_sin)
    scaled1 = _warp(skewed,
                    Affine2D().scale(scale1, 1).get_matrix()[:2, :2],
                    (1 - scale1) * n_rows / 2)
    scaled2 = _warp(scaled1,
                    Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
                    (1 - scale2) * n_rows / 2)

    rotated = ndimage.rotate(scaled2, rotation_angle, reshape=False, order=0)

    # Backward pass: undo the rescalings and the skew in reverse order
    unscaled2 = _warp(rotated,
                      Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
                      -(1 - scale2) * n_rows / 2 / scale2)
    unscaled1 = _warp(unscaled2,
                      Affine2D().scale(scale1, 1).inverted().get_matrix()[:2, :2],
                      -(1 - scale1) * n_rows / 2 / scale1)
    unskewed = _warp(unscaled1,
                     t.get_matrix()[:2, :2],
                     (-n_rows / 2 * skew_sin))

    unpadded = padder.unpad(unskewed)

    print('\nDone.')
    rotated_signal = NXfield(unpadded, name=padder.padded.signal)
    return NXdata(rotated_signal, (data[data.axes[0]], data[data.axes[1]]))
960
-
961
-
962
class Padder:
    """
    A class to symmetrically pad and unpad datasets with a region of zeros.

    Attributes
    ----------
    data : NXdata or None
        The input data to be padded.
    padded : NXdata or None
        The padded data with symmetric zero padding.
    padding : tuple or None
        The number of zero-value pixels added along each edge of the array.
    steps : tuple or None
        The step sizes along each axis of the dataset.
    maxes : tuple or None
        The maximum values along each axis of the dataset.

    Methods
    -------
    set_data(data)
        Set the input data for padding.
    pad(padding)
        Symmetrically pads the data with zero values.
    save(fout_name=None)
        Saves the padded dataset to a .nxs file.
    unpad(data)
        Removes the padded region from the data.
    """

    def __init__(self, data=None):
        """
        Initialize the Padder object.

        Parameters
        ----------
        data : NXdata, optional
            The input data to be padded. If provided, the `set_data` method
            is called to set the data.
        """
        # Define every documented attribute up front so that reading one
        # before `set_data`/`pad` yields None instead of an AttributeError.
        self.data = None
        self.steps = None
        self.maxes = None
        self.padded = None
        self.padding = None
        if data is not None:
            self.set_data(data)

    def set_data(self, data):
        """
        Set the input data for padding.

        Parameters
        ----------
        data : NXdata
            The input data to be padded.
        """
        self.data = data

        # Step size taken from the first two points of each axis
        self.steps = tuple((data[axis].nxdata[1] - data[axis].nxdata[0])
                           for axis in data.axes)

        # Absolute value of the maximum value; assumes the domain of the input
        # is symmetric (eg, -H_min = H_max)
        self.maxes = tuple(data[axis].nxdata.max() for axis in data.axes)

    def pad(self, padding):
        """
        Symmetrically pads the data with zero values.

        Parameters
        ----------
        padding : tuple
            The number of zero-value pixels to add along each edge of the array.
            An entry of 0 leaves that axis unpadded.

        Returns
        -------
        NXdata
            The padded data with symmetric zero padding.
        """
        data = self.data
        self.padding = padding

        signal = data[data.signal].nxdata
        padded_shape = tuple(signal.shape[i] + self.padding[i] * 2
                             for i in range(data.ndim))

        # Create padded dataset
        padded = np.zeros(padded_shape)

        # Use an explicit stop index: `slice(p, -p)` would be empty for p == 0
        interior = tuple(slice(self.padding[i], self.padding[i] + signal.shape[i])
                         for i in range(data.ndim))
        padded[interior] = signal

        padmaxes = tuple(self.maxes[i] + self.padding[i] * self.steps[i]
                         for i in range(data.ndim))

        padded = NXdata(NXfield(padded, name=data.signal),
                        tuple(NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
                                      name=data.axes[i])
                              for i in range(data.ndim)))

        self.padded = padded
        return padded

    def save(self, fout_name=None):
        """
        Saves the padded dataset to a .nxs file.

        Parameters
        ----------
        fout_name : str, optional
            The output file name. Default is ``padded_`` followed by the
            padding of each axis joined with underscores, e.g.
            padded_(Hpadding)_(Kpadding)_(Lpadding).nxs for a 3D dataset.

        Raises
        ------
        ValueError
            If `pad` has not been called yet.
        """
        if self.padded is None or self.padding is None:
            raise ValueError("No padded dataset to save; call pad() first.")

        # Save padded dataset
        print("Saving padded dataset...")
        f = NXroot()
        f['entry'] = NXentry()
        f['entry']['data'] = self.padded
        if fout_name is None:
            # Works for any number of axes, not just three
            fout_name = 'padded_' + '_'.join(str(pad) for pad in self.padding) + '.nxs'
        nxsave(fout_name, f)
        print("Output file saved to: " + os.path.join(os.getcwd(), fout_name))

    def unpad(self, data):
        """
        Removes the padded region from the data.

        Parameters
        ----------
        data : ndarray or NXdata
            The padded data from which to remove the padding.

        Returns
        -------
        ndarray or NXdata
            The unpadded data, with the symmetric padding region removed.
        """
        # Explicit stop index so that a padding of 0 keeps the whole axis
        interior = tuple(slice(self.padding[i], data.shape[i] - self.padding[i])
                         for i in range(data.ndim))
        return data[interior]
1104
-
1105
-
1106
def load_discus_nxs(path):
    """
    Load .nxs format data from the DISCUS program (by T. Proffen and R. Neder)
    and convert it to the CHESS format.

    Each axis is rebuilt as a symmetric range from the lower limit to its
    negative, with the number of points derived from the step size.

    Parameters
    ----------
    path : str
        The file path to the .nxs file generated by DISCUS.

    Returns
    -------
    NXdata
        The data converted to the CHESS format, with axes labeled 'H', 'K', and 'L',
        and the signal labeled 'counts'.

    """
    root = nxload(path)
    hlim, klim, llim = root.lower_limits
    hstep, kstep, lstep = root.step_sizes

    axes = []
    for lim, step, name in ((hlim, hstep, 'H'), (klim, kstep, 'K'), (llim, lstep, 'L')):
        # Round instead of truncating: float round-off (e.g. 0.6 / 0.1 ==
        # 5.999...) would otherwise drop one point from the axis.
        n_points = int(round(float(np.abs(lim * 2) / step))) + 1
        axes.append(NXfield(np.linspace(lim, -lim, n_points), name=name))

    data = NXdata(NXfield(root.data[:, :, :], name='counts'), tuple(axes))

    return data
1
+ """
2
+ Reduces scattering data into 2D and 1D datasets.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.transforms import Affine2D
8
+ from matplotlib.markers import MarkerStyle
9
+ from matplotlib.ticker import MultipleLocator
10
+ from matplotlib import colors
11
+ from matplotlib import patches
12
+ from IPython.display import display, Markdown
13
+ from nexusformat.nexus import NXfield, NXdata, nxload, NeXusError, NXroot, NXentry, nxsave
14
+ from scipy import ndimage
15
+
16
+ # Specify items on which users are allowed to perform standalone imports
17
+ __all__ = ['load_data', 'load_transform', 'plot_slice', 'Scissors',
18
+ 'reciprocal_lattice_params', 'rotate_data', 'rotate_data_2D',
19
+ 'convert_to_inverse_angstroms', 'array_to_nxdata', 'Padder',
20
+ 'rebin_nxdata', 'rebin_3d', 'rebin_1d']
21
+
22
+
23
def load_data(path, print_tree=True):
    """
    Load the data group of a NeXus file.

    The file is expected to follow the CHESS layout, i.e. the data group is
    found at root/entry/data.

    Parameters
    ----------
    path : str
        The path to the NeXus data file.

    print_tree : bool, optional
        Whether to print the data tree upon loading. Default True.

    Returns
    -------
    data : nxdata object
        The loaded data stored in a nxdata object.

    """
    root = nxload(path)

    if print_tree:
        # Printing is best effort; a NeXusError while rendering the tree is ignored
        try:
            print(root.entry.data.tree)
        except NeXusError:
            pass

    return root.entry.data
50
+
51
+
52
def load_transform(path, print_tree=True):
    """
    Load transform data from an nxrefine output file.

    The array stored at entry/transform/data is transposed with axis order
    (2, 1, 0) and paired with the Qh, Qk and Ql fields of the same group.

    Parameters
    ----------
    path : str
        The path to the transform data file.

    print_tree : bool, optional
        If True, prints the NeXus data tree upon loading. Default is True.

    Returns
    -------
    data : NXdata
        The loaded transform data as an NXdata object with signal 'counts'.
    """
    root = nxload(path)
    transform = root.entry.transform

    # Reverse the stored axis order so the signal lines up with (Qh, Qk, Ql)
    counts = NXfield(transform.data.nxdata.transpose(2, 1, 0), name='counts')
    data = NXdata(counts, (transform.Qh, transform.Qk, transform.Ql))

    if print_tree:
        print(data.tree)

    return data
78
+
79
+
80
def array_to_nxdata(array, data_template, signal_name=None):
    """
    Wrap an array in an NXdata object that reuses the axes of a template.

    Parameters
    ----------
    array : array-like
        The data array to be included in the NXdata object.

    data_template : NXdata
        An NXdata object serving as a template; its axes (and, by default,
        its signal name) are reused for the new object.

    signal_name : str, optional
        The name of the signal within the NXdata object. If not provided,
        the signal name is inherited from the data_template.

    Returns
    -------
    NXdata
        An NXdata object containing the input data array and the axes of
        the template.
    """
    name = data_template.signal if signal_name is None else signal_name
    axes = tuple(data_template[axis_name] for axis_name in data_template.axes)
    return NXdata(NXfield(array, name=name), axes)
109
+
110
+
111
def rebin_3d(array):
    """
    Rebins a 3D array by a factor of 2 along each dimension.

    This function reduces the size of the input array by averaging over non-overlapping
    2x2x2 blocks. Each dimension of the input array must be divisible by 2.

    Parameters
    ----------
    array : array-like
        A 3-dimensional array (or nested sequence) to be rebinned.

    Returns
    -------
    np.ndarray
        A rebinned array with shape (N//2, M//2, L//2) if the original shape was (N, M, L).

    Raises
    ------
    ValueError
        If the input is not 3-dimensional, or if any dimension is odd.
    """
    # asanyarray accepts nested sequences while preserving ndarray subclasses
    array = np.asanyarray(array)

    if array.ndim != 3:
        raise ValueError(f"rebin_3d expects a 3-dimensional array, got {array.ndim} dimension(s).")

    # Ensure the array shape is divisible by 2 in each dimension
    shape = array.shape
    if any(dim % 2 != 0 for dim in shape):
        raise ValueError("Each dimension of the array must be divisible by 2 to rebin.")

    # Reshape the array to group the data into 2x2x2 blocks
    reshaped = array.reshape(shape[0] // 2, 2, shape[1] // 2, 2, shape[2] // 2, 2)

    # Average over the 2x2x2 blocks
    return reshaped.mean(axis=(1, 3, 5))
141
+
142
+
143
def rebin_1d(array):
    """
    Rebins a 1D array by a factor of 2.

    This function reduces the size of the input array by averaging over non-overlapping
    pairs of elements. The input array length must be divisible by 2.

    Parameters
    ----------
    array : array-like
        A 1-dimensional array (or sequence) to be rebinned.

    Returns
    -------
    np.ndarray
        A rebinned array with length N//2 if the original length was N.

    Raises
    ------
    ValueError
        If the input is not 1-dimensional, or if its length is odd.
    """
    # asanyarray accepts plain sequences while preserving ndarray subclasses
    array = np.asanyarray(array)

    # Without this check a 2D input such as shape (2, 1) would be reshaped
    # and averaged across unrelated elements without any error.
    if array.ndim != 1:
        raise ValueError(f"rebin_1d expects a 1-dimensional array, got {array.ndim} dimension(s).")

    # Ensure the array length is divisible by 2
    if array.size % 2 != 0:
        raise ValueError("The length of the array must be divisible by 2 to rebin.")

    # Group elements into pairs and average over each pair
    return array.reshape(array.size // 2, 2).mean(axis=1)
172
+
173
+
174
def rebin_nxdata(data):
    """
    Rebins the signal and axes of an NXdata object by a factor of 2 along each dimension.

    For every dimension with an odd number of elements, the last element is
    dropped from both the axis and the signal before rebinning. Each axis is
    then rebinned with `rebin_1d`, and the signal is averaged over
    non-overlapping blocks of two elements per dimension (2x2x2 blocks for
    3D data). Data of any dimensionality is supported.

    Parameters
    ----------
    data : NXdata
        The NeXus data group containing the signal and axes to be rebinned.

    Returns
    -------
    NXdata
        A new NXdata object with signal and axes rebinned by a factor of 2 along each dimension.
    """
    # First, rebin axes (dropping the last point of odd-length dimensions)
    new_axes = []
    for i, dim in enumerate(data.shape):
        axis_values = data.nxaxes[i].nxdata
        if dim % 2 == 1:
            axis_values = axis_values[:-1]
        new_axes.append(NXfield(rebin_1d(axis_values), name=data.axes[i]))

    # Second, rebin signal
    data_arr = data.nxsignal.nxdata

    # Crop the array if the shape is odd in any direction
    even_shape = tuple(dim - dim % 2 for dim in data_arr.shape)
    data_arr = data_arr[tuple(slice(0, n) for n in even_shape)]

    # Split every dimension into (n // 2, 2) and average over the size-2 axes.
    # For 3D input this is the same reshape and mean that rebin_3d performs.
    block_shape = []
    for n in even_shape:
        block_shape.extend((n // 2, 2))
    pair_axes = tuple(range(1, 2 * len(even_shape), 2))
    data_arr = data_arr.reshape(block_shape).mean(axis=pair_axes)

    return NXdata(NXfield(data_arr, name=data.signal), tuple(new_axes))
233
+
234
+
235
def plot_slice(data, X=None, Y=None, sum_axis=None, transpose=False, vmin=None, vmax=None,
               skew_angle=90, ax=None, xlim=None, ylim=None,
               xticks=None, yticks=None, cbar=True, logscale=False,
               symlogscale=False, cmap='viridis', linthresh=1,
               title=None, mdheading=None, cbartitle=None,
               **kwargs):
    """
    Plot a 2D slice of the provided dataset, with optional transformations
    and customizations.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata` or ndarray
        The dataset to plot. Can be an `NXdata` object or a `numpy` array.

    X : NXfield, optional
        The X axis values. If None, the first axis of an `NXdata` input is
        used; for an array input a range from 0 to `data.shape[0]` is used.

    Y : NXfield, optional
        The Y axis values. If None, the second axis of an `NXdata` input is
        used; for an array input a range from 0 to `data.shape[1]` is used.

    sum_axis : int, optional
        If the input data is 3D, this specifies the axis to sum over in order
        to reduce the data to 2D for plotting. Required if `data` has three dimensions.

    transpose : bool, optional
        If True, transpose the dataset and its axes before plotting.
        Default is False.

    vmin : float, optional
        The minimum value for the color scale. If not provided, the minimum
        value of the dataset is used.

    vmax : float, optional
        The maximum value for the color scale. If not provided, the maximum
        value of the dataset is used.

    skew_angle : float, optional
        The angle in degrees to shear the plot. Default is 90 degrees (no skew).

    ax : matplotlib.axes.Axes, optional
        The `matplotlib` axis to plot on. If None, a new figure and axis will
        be created.

    xlim : tuple, optional
        The limits for the x-axis. If None, the limits are set automatically
        based on the data.

    ylim : tuple, optional
        The limits for the y-axis. If None, the limits are set automatically
        based on the data.

    xticks : float, optional
        The major tick interval for the x-axis (passed to `MultipleLocator`).
        Minor ticks are always placed at an interval of 1.

    yticks : float, optional
        The major tick interval for the y-axis (passed to `MultipleLocator`).
        Minor ticks are always placed at an interval of 1.

    cbar : bool, optional
        Whether to include a colorbar. Default is True.

    logscale : bool, optional
        Whether to use a logarithmic color scale. Default is False.

    symlogscale : bool, optional
        Whether to use a symmetrical logarithmic color scale spanning
        (-vmax, vmax). Takes precedence over `logscale`. Default is False.

    cmap : str or Colormap, optional
        The colormap to use for the plot. Default is 'viridis'.

    linthresh : float, optional
        The linear threshold for symmetrical logarithmic scaling. Default is 1.

    title : str, optional
        The title for the plot. If None, no title is set.

    mdheading : str, optional
        A Markdown heading to display above the plot. If None, no heading is
        displayed; the string "None" displays a bare "Figure" heading; any
        other string is appended as "Figure - <mdheading>".

    cbartitle : str, optional
        The title for the colorbar. If None, the colorbar label will be set to
        the name of the signal.

    **kwargs
        Additional keyword arguments passed to `pcolormesh`.

    Returns
    -------
    p : :class:`matplotlib.collections.QuadMesh`
        The `matplotlib` QuadMesh object representing the plotted data.

    Raises
    ------
    TypeError
        If `data` is neither an ndarray nor an NXdata/NXfield.
    """
    is_array = isinstance(data, np.ndarray)
    is_nxdata = isinstance(data, (NXdata, NXfield))

    if not (is_array or is_nxdata):
        raise TypeError(f"Unexpected data type: {type(data)}. "
                        f"Supported types are np.ndarray and NXdata.")

    # If three-dimensional, demand sum_axis to reduce to two dimensions.
    if is_array and len(data.shape) == 3:
        assert sum_axis is not None, "sum_axis must be specified when data is a 3D array"

        data = data.sum(axis=sum_axis)

    if is_nxdata and len(data.shape) == 3:
        assert sum_axis is not None, "sum_axis must be specified when data is a 3D array"

        arr = data.nxsignal.nxdata
        arr = arr.sum(axis=sum_axis)

        # Create a 2D template from the original nxdata
        slice_obj = [slice(None)] * len(data.shape)
        slice_obj[sum_axis] = 0

        # Use the 2D template to create a new nxdata
        data = array_to_nxdata(arr, data[slice_obj])

    if is_array:
        if X is None:
            X = NXfield(np.linspace(0, data.shape[0], data.shape[0]), name='x')
        if Y is None:
            Y = NXfield(np.linspace(0, data.shape[1], data.shape[1]), name='y')
        if transpose:
            X, Y = Y, X
            data = data.transpose()
        data = NXdata(NXfield(data, name='value'), (X, Y))
        # pcolormesh expects the array as (rows=y, columns=x)
        data_arr = data[data.signal].nxdata.transpose()
    elif is_nxdata:
        if X is None:
            X = data[data.axes[0]]
        if Y is None:
            Y = data[data.axes[1]]
        if transpose:
            X, Y = Y, X
            data = data.transpose()
        # pcolormesh expects the array as (rows=y, columns=x)
        data_arr = data[data.signal].nxdata.transpose()

    # Display Markdown heading
    if mdheading is not None:
        if mdheading == "None":
            display(Markdown('### Figure'))
        else:
            display(Markdown('### Figure - ' + mdheading))

    # Inherit axes if user provides some
    if ax is not None:
        fig = ax.get_figure()
    # Otherwise set up some default axes
    else:
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])

    # If limits not provided, use extrema
    if vmin is None:
        vmin = data_arr.min()
    if vmax is None:
        vmax = data_arr.max()

    # Set norm (linear scale, logscale, or symlogscale)
    if symlogscale:
        norm = colors.SymLogNorm(linthresh=linthresh, vmin=-1 * vmax, vmax=vmax)
    elif logscale:
        norm = colors.LogNorm(vmin=vmin, vmax=vmax)
    else:
        norm = colors.Normalize(vmin=vmin, vmax=vmax)  # Default: linear scale

    # Plot data
    p = ax.pcolormesh(X.nxdata, Y.nxdata, data_arr, shading='auto', norm=norm, cmap=cmap, **kwargs)

    ## Transform data to new coordinate system if necessary
    # Correct skew angle
    skew_angle_adj = 90 - skew_angle
    skew_rad = skew_angle_adj * np.pi / 180
    cos_skew = np.cos(skew_rad)
    # Create blank 2D affine transformation
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, cos_skew)
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, cos_skew).inverted()
    ## Correct for x-displacement after shearing
    # If ylims provided, use those
    if ylim is not None:
        # Set ylims
        ax.set(ylim=ylim)
        ymin, ymax = ylim
    # Else, use current ylims
    else:
        ymin, ymax = ax.get_ylim()
    # Use ylims to calculate translation (necessary to display axes in correct position)
    p.set_transform(t
                    + Affine2D().translate(-ymin * np.sin(skew_rad), 0)
                    + ax.transData)

    # Set x limits, widened to make room for the sheared data
    if xlim is not None:
        xmin, xmax = xlim
    else:
        xmin, xmax = ax.get_xlim()
    if skew_angle <= 90:
        ax.set(xlim=(xmin, xmax + (ymax - ymin) / np.tan((90 - skew_angle_adj) * np.pi / 180)))
    else:
        ax.set(xlim=(xmin - (ymax - ymin) / np.tan((skew_angle_adj - 90) * np.pi / 180), xmax))

    # Correct aspect ratio for the x/y axes after transformation
    ax.set(aspect=cos_skew)

    # Add tick marks all around
    ax.tick_params(direction='in', top=True, right=True, which='both')

    # Set tick locations: user-provided major interval (if any), minor ticks every 1
    if xticks is not None:
        ax.xaxis.set_major_locator(MultipleLocator(xticks))
    ax.xaxis.set_minor_locator(MultipleLocator(1))
    if yticks is not None:
        ax.yaxis.set_major_locator(MultipleLocator(yticks))
    ax.yaxis.set_minor_locator(MultipleLocator(1))

    # Apply transform to tick marks. The tick-line lists alternate between
    # bottom (even index) and top (odd index) lines; fetch each list once.
    for i, line in enumerate(ax.xaxis.get_majorticklines()):
        # Tick marker
        m = MarkerStyle(3)
        if i % 2:
            # Top ticks (translation here makes their direction="in")
            m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
            # The tick marks are skewed in place, which can sometimes lead to
            # them being misaligned. Shifting the top ticks horizontally to
            # match the skew angle was tried but does not look good in all cases.
            line.set_transform(line.get_transform())  # Re-applies the existing transform
        else:
            # Bottom ticks
            m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))

        line.set_marker(m)

    for i, line in enumerate(ax.xaxis.get_minorticklines()):
        m = MarkerStyle(2)
        if i % 2:
            m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
        else:
            m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))

        line.set_marker(m)

    if cbar:
        colorbar = fig.colorbar(p)
        # Honour a user-supplied title; fall back to the signal name
        colorbar.set_label(data.signal if cbartitle is None else cbartitle)

    ax.set(
        xlabel=X.nxname,
        ylabel=Y.nxname,
    )

    if title is not None:
        ax.set_title(title)

    # Return the quadmesh object
    return p
517
+
518
+
519
+ class Scissors:
520
+ """
521
+ Scissors class provides functionality for reducing data to a 1D linecut using an integration
522
+ window.
523
+
524
+ Attributes
525
+ ----------
526
+ data : :class:`nexusformat.nexus.NXdata` or None
527
+ Input :class:`nexusformat.nexus.NXdata`.
528
+ center : tuple or None
529
+ Central coordinate around which to perform the linecut.
530
+ window : tuple or None
531
+ Extents of the window for integration along each axis.
532
+ axis : int or None
533
+ Axis along which to perform the integration.
534
+ integration_volume : :class:`nexusformat.nexus.NXdata` or None
535
+ Data array after applying the integration window.
536
+ integrated_axes : tuple or None
537
+ Indices of axes that were integrated.
538
+ linecut : :class:`nexusformat.nexus.NXdata` or None
539
+ 1D linecut data after integration.
540
+ integration_window : tuple or None
541
+ Slice object representing the integration window in the data array.
542
+
543
+ Methods
544
+ -------
545
+ set_data(data)
546
+ Set the input :class:`nexusformat.nexus.NXdata`.
547
+ get_data()
548
+ Get the input :class:`nexusformat.nexus.NXdata`.
549
+ set_center(center)
550
+ Set the central coordinate for the linecut.
551
+ set_window(window, axis=None, verbose=False)
552
+ Set the extents of the integration window.
553
+ get_window()
554
+ Get the extents of the integration window.
555
+ cut_data(center=None, window=None, axis=None, verbose=False)
556
+ Reduce data to a 1D linecut using the integration window.
557
+ highlight_integration_window(data=None, label=None, highlight_color='red', **kwargs)
558
+ Plot the integration window highlighted on a 2D heatmap of the full dataset.
559
+ plot_integration_window(**kwargs)
560
+ Plot a 2D heatmap of the integration window data.
561
+ """
562
+
563
+ def __init__(self, data=None, center=None, window=None, axis=None):
564
+ """
565
+ Initializes a Scissors object.
566
+
567
+ Parameters
568
+ ----------
569
+ data : :class:`nexusformat.nexus.NXdata` or None, optional
570
+ Input NXdata. Default is None.
571
+ center : tuple or None, optional
572
+ Central coordinate around which to perform the linecut. Default is None.
573
+ window : tuple or None, optional
574
+ Extents of the window for integration along each axis. Default is None.
575
+ axis : int or None, optional
576
+ Axis along which to perform the integration. Default is None.
577
+ """
578
+
579
+ self.data = data
580
+ self.center = tuple(float(i) for i in center) if center is not None else None
581
+ self.window = tuple(float(i) for i in window) if window is not None else None
582
+ self.axis = axis
583
+
584
+ self.integration_volume = None
585
+ self.integrated_axes = None
586
+ self.linecut = None
587
+ self.integration_window = None
588
+
589
+ def set_data(self, data):
590
+ """
591
+ Set the input NXdata.
592
+
593
+ Parameters
594
+ ----------
595
+ data : :class:`nexusformat.nexus.NXdata`
596
+ Input data array.
597
+ """
598
+ self.data = data
599
+
600
+ def get_data(self):
601
+ """
602
+ Get the input data array.
603
+
604
+ Returns
605
+ -------
606
+ ndarray or None
607
+ Input data array.
608
+ """
609
+ return self.data
610
+
611
+ def set_center(self, center):
612
+ """
613
+ Set the central coordinate for the linecut.
614
+
615
+ Parameters
616
+ ----------
617
+ center : tuple
618
+ Central coordinate around which to perform the linecut.
619
+ """
620
+ self.center = tuple(float(i) for i in center) if center is not None else None
621
+
622
+ def set_window(self, window, axis=None, verbose=False):
623
+ """
624
+ Set the extents of the integration window.
625
+
626
+ Parameters
627
+ ----------
628
+ window : tuple
629
+ Extents of the window for integration along each axis.
630
+ axis : int or None, optional
631
+ The axis along which to perform the linecut. If not specified, the value from the
632
+ object's attribute will be used.
633
+ verbose : bool, optional
634
+ Enables printout of linecut axis and integrated axes. Default False.
635
+
636
+ """
637
+ self.window = tuple(float(i) for i in window) if window is not None else None
638
+
639
+ # Determine the axis for integration
640
+ self.axis = window.index(max(window)) if axis is None else axis
641
+
642
+ # Determine the integrated axes (axes other than the integration axis)
643
+ self.integrated_axes = tuple(i for i in range(self.data.ndim) if i != self.axis)
644
+
645
+ if verbose:
646
+ print("Linecut axis: " + str(self.data.axes[self.axis]))
647
+ print("Integrated axes: " + str([self.data.axes[axis]
648
+ for axis in self.integrated_axes]))
649
+
650
+ def get_window(self):
651
+ """
652
+ Get the extents of the integration window.
653
+
654
+ Returns
655
+ -------
656
+ tuple or None
657
+ Extents of the integration window.
658
+ """
659
+ return self.window
660
+
661
+ def cut_data(self, center=None, window=None, axis=None, verbose=False):
662
+ """
663
+ Reduces data to a 1D linecut with integration extents specified by the
664
+ window about a central coordinate.
665
+
666
+ Parameters
667
+ ----------
668
+ center : float or None, optional
669
+ Central coordinate for the linecut. If not specified, the value from the object's
670
+ attribute will be used.
671
+ window : tuple or None, optional
672
+ Integration window extents around the central coordinate. If not specified, the value
673
+ from the object's attribute will be used.
674
+ axis : int or None, optional
675
+ The axis along which to perform the linecut. If not specified, the value from the
676
+ object's attribute will be used.
677
+ verbose : bool
678
+ Enables printout of linecut axis and integrated axes. Default False.
679
+
680
+ Returns
681
+ -------
682
+ integrated_data : :class:`nexusformat.nexus.NXdata`
683
+ 1D linecut data after integration.
684
+
685
+ """
686
+
687
+ # Extract necessary attributes from the object
688
+ data = self.data
689
+ center = center if center is not None else self.center
690
+ self.set_center(center)
691
+ window = window if window is not None else self.window
692
+ self.set_window(window, axis, verbose)
693
+
694
+ # Convert the center to a tuple of floats
695
+ center = tuple(float(c) for c in center)
696
+
697
+ # Calculate the start and stop indices for slicing the data
698
+ start = np.subtract(center, window)
699
+ stop = np.add(center, window)
700
+ slice_obj = tuple(slice(s, e) for s, e in zip(start, stop))
701
+ self.integration_window = slice_obj
702
+
703
+ # Perform the data cut
704
+ self.integration_volume = data[slice_obj]
705
+ self.integration_volume.nxname = data.nxname
706
+
707
+ # Perform integration along the integrated axes
708
+ integrated_data = np.sum(self.integration_volume[self.integration_volume.signal].nxdata,
709
+ axis=self.integrated_axes)
710
+
711
+ # Create an NXdata object for the linecut data
712
+ self.linecut = NXdata(NXfield(integrated_data, name=self.integration_volume.signal),
713
+ self.integration_volume[self.integration_volume.axes[self.axis]])
714
+ self.linecut.nxname = self.integration_volume.nxname
715
+
716
+ return self.linecut
717
+
718
def highlight_integration_window(self, data=None, width=None, height=None,
                                 label=None, highlight_color='red', **kwargs):
    """
    Show where the integration window sits inside a 3D dataset.

    Three panels are drawn side by side, one for each principal plane. Every
    panel is a cut through `self.center`, and a rectangle extending
    `self.window` on either side of the center is outlined on top of it.

    Parameters
    ----------
    data : array-like, optional
        The 3D dataset to visualize. If not provided, uses `self.data`.
    width : float, optional
        Total extent of the visible x-axis range in each panel, centered on
        the integration window. Ignored when `xlim` is given in `kwargs`.
    height : float, optional
        Total extent of the visible y-axis range in each panel, centered on
        the integration window. Ignored when `ylim` is given in `kwargs`.
    label : str, optional
        Legend label for the rectangle. A legend is drawn on every panel
        only when a label is supplied.
    highlight_color : str, optional
        Edge color of the rectangle. Default is 'red'.
    **kwargs : dict, optional
        Forwarded unchanged to `plot_slice` (e.g., cmap, vmin, vmax).

    Returns
    -------
    p1, p2, p3 : matplotlib.collections.QuadMesh
        The plotted meshes, in panel order: XY at fixed Z, XZ at fixed Y,
        and YZ at fixed X.

    """
    if data is None:
        data = self.data
    center = self.center
    window = self.window

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # One entry per panel: (axis held fixed, horizontal axis, vertical axis)
    panel_layout = ((2, 0, 1), (1, 0, 2), (0, 1, 2))

    meshes = []
    for panel, (fixed, horiz, vert) in zip(axes, panel_layout):
        # Cut through the center along the fixed axis, keep the others whole
        index = [slice(None)] * data.ndim
        index[fixed] = center[fixed]

        mesh = plot_slice(data[index],
                          X=data[data.axes[horiz]],
                          Y=data[data.axes[vert]],
                          ax=panel,
                          **kwargs)

        # Outline the integration window in the coordinates of the mesh
        lower_left = (center[horiz] - window[horiz], center[vert] - window[vert])
        outline = patches.Rectangle(
            lower_left,
            2 * window[horiz], 2 * window[vert],
            linewidth=1, edgecolor=highlight_color,
            facecolor='none', transform=mesh.get_transform(), label=label,
        )
        panel.add_patch(outline)

        # Zoom in on the window unless the caller fixed the limits explicitly
        if 'xlim' not in kwargs and width is not None:
            panel.set(xlim=(center[horiz] - width / 2, center[horiz] + width / 2))
        if 'ylim' not in kwargs and height is not None:
            panel.set(ylim=(center[vert] - height / 2, center[vert] + height / 2))

        meshes.append(mesh)

    # Adjust subplot padding
    fig.subplots_adjust(wspace=0.5)

    if label is not None:
        for panel in axes:
            panel.legend()

    plt.show()

    return tuple(meshes)
834
+
835
def plot_integration_window(self, **kwargs):
    """
    Plots the three principal cross-sections of the integration volume on a single figure.

    Each panel is a cut of `self.integration_volume` through `self.center`
    along one axis.

    Parameters
    ----------
    **kwargs : keyword arguments, optional
        Additional keyword arguments forwarded to `plot_slice`.

    Returns
    -------
    p1, p2, p3 : matplotlib.collections.QuadMesh
        The plotted meshes in panel order (left to right): XY at fixed Z,
        XZ at fixed Y, and YZ at fixed X. This matches the order returned
        by `highlight_integration_window`.
    """
    data = self.integration_volume
    center = self.center

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # One entry per panel: (axis held fixed, horizontal axis, vertical axis)
    panel_layout = ((2, 0, 1), (1, 0, 2), (0, 1, 2))

    meshes = []
    for panel, (fixed, horiz, vert) in zip(axes, panel_layout):
        # Cut through the center along the fixed axis, keep the others whole
        slice_obj = [slice(None)] * data.ndim
        slice_obj[fixed] = center[fixed]

        mesh = plot_slice(data[slice_obj],
                          X=data[data.axes[horiz]],
                          Y=data[data.axes[vert]],
                          ax=panel,
                          **kwargs)

        # Aspect is set from the ratio of the number of points on each axis
        panel.set_aspect(len(data[data.axes[horiz]].nxdata)
                         / len(data[data.axes[vert]].nxdata))

        # Collect in plotting order so the returned tuple lines up with the
        # panels (previously the second and third meshes were swapped)
        meshes.append(mesh)

    # Adjust subplot padding
    fig.subplots_adjust(wspace=0.3)

    plt.show()

    p1, p2, p3 = meshes
    return p1, p2, p3
885
+
886
+
887
def reciprocal_lattice_params(lattice_params):
    """
    Derive the reciprocal lattice parameters of a unit cell.

    No factor of 2*pi is applied: the reciprocal lengths are computed as
    (product of the other two direct lengths) * sin(angle) / volume.

    Parameters
    ----------
    lattice_params : tuple
        The direct lattice parameters (a, b, c, alpha, beta, gamma), where
        a, b, and c are the magnitudes of the lattice vectors, and alpha,
        beta, and gamma are the angles between them in degrees.

    Returns
    -------
    tuple
        The reciprocal lattice parameters (a*, b*, c*, alpha*, beta*, gamma*),
        where a*, b*, and c* are the magnitudes of the reciprocal lattice
        vectors, and alpha*, beta*, and gamma* are the angles between them
        in degrees.
    """
    a_mag, b_mag, c_mag, alpha, beta, gamma = lattice_params

    # Work in radians and evaluate each trigonometric term once
    alpha, beta, gamma = np.deg2rad(alpha), np.deg2rad(beta), np.deg2rad(gamma)
    cos_a, cos_b, cos_g = np.cos(alpha), np.cos(beta), np.cos(gamma)
    sin_a, sin_b, sin_g = np.sin(alpha), np.sin(beta), np.sin(gamma)

    # Unit cell volume
    volume = a_mag * b_mag * c_mag * np.sqrt(
        1 - cos_a ** 2 - cos_b ** 2 - cos_g ** 2
        + 2 * cos_a * cos_b * cos_g
    )

    # Reciprocal lengths
    a_star = (b_mag * c_mag * sin_a) / volume
    b_star = (a_mag * c_mag * sin_b) / volume
    c_star = (a_mag * b_mag * sin_g) / volume

    # Reciprocal angles, converted back to degrees
    alpha_star = np.rad2deg(np.arccos((cos_b * cos_g - cos_a) / (sin_b * sin_g)))
    beta_star = np.rad2deg(np.arccos((cos_a * cos_g - cos_b) / (sin_a * sin_g)))
    gamma_star = np.rad2deg(np.arccos((cos_a * cos_b - cos_g) / (sin_a * sin_b)))

    return a_star, b_star, c_star, alpha_star, beta_star, gamma_star
929
+
930
+
931
def convert_to_inverse_angstroms(data, lattice_params):
    """
    Rescale the axes of a 3D NXdata object from reciprocal lattice units
    (r.l.u.) to inverse angstroms.

    Each axis is multiplied by the matching reciprocal lattice length
    returned by `reciprocal_lattice_params`; the signal is passed through
    unchanged.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        A 3D NXdata object with axes in reciprocal lattice units.

    lattice_params : tuple of float
        The real-space lattice parameters (a, b, c, alpha, beta, gamma)
        in angstroms and degrees.

    Returns
    -------
    NXdata
        A new NXdata object whose axes are named 'a_star', 'b_star' and
        'c_star' and hold the scaled coordinates.
    """
    # Only the reciprocal lengths are needed; the reciprocal angles are unused
    a_, b_, c_, *_ = reciprocal_lattice_params(lattice_params)

    signal = data.nxsignal

    scaled_axes = tuple(
        NXfield(data.nxaxes[position].nxdata * length, name=label)
        for position, (length, label) in enumerate(
            ((a_, 'a_star'), (b_, 'b_star'), (c_, 'c_star'))
        )
    )

    return NXdata(signal, scaled_axes)
959
+
960
+
961
def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=False):
    """
    Rotates 3D data around a specified axis.

    The dataset is processed one 2D layer at a time (layers perpendicular to
    `rotation_axis`). Each layer is zero-padded, sheared and rescaled
    according to `lattice_angle`, rotated by `rotation_angle`, then the
    rescaling, shear and padding are undone in reverse order. All
    resampling uses ``order=0``.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.
    rotation_axis : int
        Axis of rotation (0, 1, or 2).
    printout : bool, optional
        Enables printout of rotation progress. If set to True, information
        about each rotation slice will be printed to the console, indicating
        the axis being rotated and the corresponding coordinate value.
        Defaults to False.

    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object, on the same axes as the input.

    Raises
    ------
    ValueError
        If `rotation_axis` is not 0, 1, or 2.
    """
    # Fail early with a clear message instead of crashing on a None slice
    if rotation_axis not in (0, 1, 2):
        raise ValueError(
            f"rotation_axis must be 0, 1, or 2; got {rotation_axis!r}"
        )

    # Define output array
    output_array = np.zeros(data[data.signal].shape)

    # Define shear transformation
    skew_angle_adj = 90 - lattice_angle
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()

    # Everything below depends only on the angles, so build it once rather
    # than once per layer
    sin_skew = np.sin(skew_angle_adj * np.pi / 180)
    scale1 = np.cos(skew_angle_adj * np.pi / 180)
    shear_matrix = t.get_matrix()[:2, :2]
    shear_inverse_matrix = t.inverted().get_matrix()[:2, :2]
    scale1_matrix = Affine2D().scale(scale1, 1).get_matrix()[:2, :2]
    scale1_inverse_matrix = Affine2D().scale(scale1, 1).inverted().get_matrix()[:2, :2]

    axis_name = data.axes[rotation_axis]

    # Iterate over all layers perpendicular to the rotation axis
    for i in range(len(data[axis_name])):
        # Print progress
        if printout:
            print(f'\rRotating {axis_name}'
                  f'={data[axis_name][i]}... ',
                  end='', flush=True)

        # Index selecting layer i: integer on the rotation axis, full
        # slices on the other two
        layer_index = [slice(None)] * 3
        layer_index[rotation_axis] = i
        layer_index = tuple(layer_index)

        sliced_data = data[layer_index]

        # Add padding to avoid data cutoff during rotation; each edge is
        # padded by the full length of the corresponding axis
        p = Padder(sliced_data)
        padding = tuple(len(sliced_data[axis]) for axis in sliced_data.axes)
        p.pad(padding)
        counts = p.padded[p.padded.signal]
        rows = counts.shape[0]

        # Perform shear operation
        counts_skewed = ndimage.affine_transform(counts,
                                                 shear_inverse_matrix,
                                                 offset=[rows / 2 * sin_skew, 0],
                                                 order=0,
                                                 )
        # Scale data based on skew angle
        counts_scaled1 = ndimage.affine_transform(counts_skewed,
                                                  scale1_matrix,
                                                  offset=[(1 - scale1) * rows / 2, 0],
                                                  order=0,
                                                  )
        # Scale data based on ratio of array dimensions
        scale2 = counts.shape[0] / counts.shape[1]
        counts_scaled2 = ndimage.affine_transform(counts_scaled1,
                                                  Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
                                                  offset=[(1 - scale2) * rows / 2, 0],
                                                  order=0,
                                                  )

        # Perform rotation
        counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)

        # Undo scaling 2
        counts_unscaled2 = ndimage.affine_transform(counts_rotated,
                                                    Affine2D().scale(
                                                        scale2, 1
                                                    ).inverted().get_matrix()[:2, :2],
                                                    offset=[-(1 - scale2) * rows / 2 / scale2, 0],
                                                    order=0,
                                                    )
        # Undo scaling 1
        counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
                                                    scale1_inverse_matrix,
                                                    offset=[-(1 - scale1) * rows / 2 / scale1, 0],
                                                    order=0,
                                                    )
        # Undo shear operation
        counts_unskewed = ndimage.affine_transform(counts_unscaled1,
                                                   shear_matrix,
                                                   offset=[(-rows / 2 * sin_skew), 0],
                                                   order=0,
                                                   )
        # Remove padding
        counts_unpadded = p.unpad(counts_unskewed)

        # Write current slice
        output_array[layer_index] = counts_unpadded

    print('\nDone.')
    # The padded layers carry the signal name of the input, so name the
    # result from the input directly; this also works when no layer was
    # processed
    return NXdata(NXfield(output_array, name=data.signal),
                  (data[data.axes[0]], data[data.axes[1]], data[data.axes[2]]))
1090
+
1091
+
1092
def rotate_data_2D(data, lattice_angle, rotation_angle):
    """
    Rotates 2D data.

    The data is zero-padded, sheared and rescaled according to
    `lattice_angle`, rotated by `rotation_angle`, and then the rescaling,
    shear and padding are undone in reverse order. All resampling uses
    ``order=0``.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.

    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object, on the same axes as the input.
    """
    skew_angle_adj = 90 - lattice_angle
    sin_skew = np.sin(skew_angle_adj * np.pi / 180)

    # Shear transform: scale y to preserve the norm, shear along x, then
    # restore the original y scaling
    t = (Affine2D()
         + Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
         + Affine2D().skew_deg(skew_angle_adj, 0)
         + Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted())

    # Pad every edge by the full axis length so nothing is cut off while
    # the image is rotated
    p = Padder(data)
    p.pad(tuple(len(data[axis]) for axis in data.axes))
    counts = p.padded[p.padded.signal]
    rows = counts.shape[0]

    # Scale factors: one from the skew angle, one from the array aspect
    scale1 = np.cos(skew_angle_adj * np.pi / 180)
    scale2 = counts.shape[0] / counts.shape[1]

    # Forward pass: shear, scale by skew angle, scale by aspect ratio
    image = ndimage.affine_transform(
        counts,
        t.inverted().get_matrix()[:2, :2],
        offset=[rows / 2 * sin_skew, 0],
        order=0,
    )
    image = ndimage.affine_transform(
        image,
        Affine2D().scale(scale1, 1).get_matrix()[:2, :2],
        offset=[(1 - scale1) * rows / 2, 0],
        order=0,
    )
    image = ndimage.affine_transform(
        image,
        Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
        offset=[(1 - scale2) * rows / 2, 0],
        order=0,
    )

    # Rotate in the rescaled frame
    image = ndimage.rotate(image, rotation_angle, reshape=False, order=0)

    # Backward pass: undo aspect scaling, undo skew scaling, undo shear
    image = ndimage.affine_transform(
        image,
        Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
        offset=[-(1 - scale2) * rows / 2 / scale2, 0],
        order=0,
    )
    image = ndimage.affine_transform(
        image,
        Affine2D().scale(scale1, 1).inverted().get_matrix()[:2, :2],
        offset=[-(1 - scale1) * rows / 2 / scale1, 0],
        order=0,
    )
    image = ndimage.affine_transform(
        image,
        t.get_matrix()[:2, :2],
        offset=[(-rows / 2 * sin_skew), 0],
        order=0,
    )

    # Strip the padding added above
    counts_unpadded = p.unpad(image)

    print('\nDone.')
    rotated_signal = NXfield(counts_unpadded, name=p.padded.signal)
    original_axes = (data[data.axes[0]], data[data.axes[1]])
    return NXdata(rotated_signal, original_axes)
1185
+
1186
+
1187
class Padder:
    """
    A class to symmetrically pad and unpad datasets with a region of zeros.

    Attributes
    ----------
    data : NXdata or None
        The input data to be padded.
    padded : NXdata or None
        The padded data with symmetric zero padding.
    padding : tuple or None
        The number of zero-value pixels added along each edge of the array.
    steps : tuple or None
        The step sizes along each axis of the dataset.
    maxes : tuple or None
        The maximum values along each axis of the dataset.

    Methods
    -------
    set_data(data)
        Set the input data for padding.
    pad(padding)
        Symmetrically pads the data with zero values.
    save(fout_name=None)
        Saves the padded dataset to a .nxs file.
    unpad(data)
        Removes the padded region from the data.
    """

    def __init__(self, data=None):
        """
        Initialize the Padder object.

        Parameters
        ----------
        data : NXdata, optional
            The input data to be padded. If provided, the `set_data` method
            is called to set the data.
        """
        self.maxes = None
        self.steps = None
        self.data = None
        self.padded = None
        self.padding = None
        if data is not None:
            self.set_data(data)

    @staticmethod
    def _interior(padding):
        """
        Build the index selecting the unpadded interior of a padded array.

        A padding of 0 on an axis keeps that whole axis. The naive
        ``slice(p, -p)`` would instead become ``slice(0, 0)`` and select
        nothing.
        """
        return tuple(slice(pad, -pad if pad else None) for pad in padding)

    def set_data(self, data):
        """
        Set the input data for padding.

        Parameters
        ----------
        data : NXdata
            The input data to be padded.
        """
        self.data = data

        # Step size is taken from the first two points of each axis
        self.steps = tuple((data[axis].nxdata[1] - data[axis].nxdata[0])
                           for axis in data.axes)

        # Absolute value of the maximum value; assumes the domain of the input
        # is symmetric (eg, -H_min = H_max)
        self.maxes = tuple(data[axis].nxdata.max() for axis in data.axes)

    def pad(self, padding):
        """
        Symmetrically pads the data with zero values.

        Parameters
        ----------
        padding : tuple
            The number of zero-value pixels to add along each edge of the
            array, one entry per axis. An entry of 0 leaves that axis
            unpadded.

        Returns
        -------
        NXdata
            The padded data with symmetric zero padding. The new axes run
            from ``-(max + padding * step)`` to ``+(max + padding * step)``,
            which relies on the input axes being symmetric about zero.
        """
        data = self.data
        self.padding = padding

        padded_shape = tuple(data[data.signal].nxdata.shape[i]
                             + self.padding[i] * 2 for i in range(data.ndim))

        # Create padded dataset and copy the original into its interior
        padded = np.zeros(padded_shape)
        interior = self._interior(self.padding[i] for i in range(data.ndim))
        padded[interior] = data[data.signal].nxdata

        padmaxes = tuple(self.maxes[i] + self.padding[i] * self.steps[i]
                         for i in range(data.ndim))

        padded = NXdata(NXfield(padded, name=data.signal),
                        tuple(NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
                                      name=data.axes[i])
                              for i in range(data.ndim)))

        self.padded = padded
        return padded

    def save(self, fout_name=None):
        """
        Saves the padded dataset to a .nxs file.

        Parameters
        ----------
        fout_name : str, optional
            The output file name. By default the name is ``padded_``
            followed by the padding of each axis joined with underscores,
            e.g. padded_(Hpadding)_(Kpadding)_(Lpadding).nxs for 3D data.
        """
        # Save padded dataset
        print("Saving padded dataset...")
        f = NXroot()
        f['entry'] = NXentry()
        f['entry']['data'] = self.padded
        if fout_name is None:
            # Join however many paddings there are so 2D data works too
            fout_name = 'padded_' + '_'.join(str(pad) for pad in self.padding) + '.nxs'
        nxsave(fout_name, f)
        print("Output file saved to: " + os.path.join(os.getcwd(), fout_name))

    def unpad(self, data):
        """
        Removes the padded region from the data.

        Parameters
        ----------
        data : ndarray or NXdata
            The padded data from which to remove the padding.

        Returns
        -------
        ndarray or NXdata
            The unpadded data, with the symmetric padding region removed.
            Axes whose padding is 0 are returned in full.
        """
        return data[self._interior(self.padding[i] for i in range(data.ndim))]
1332
+
1333
+
1334
def load_discus_nxs(path):
    """
    Load .nxs format data from the DISCUS program (by T. Proffen and R. Neder)
    and convert it to the CHESS format.

    The axes are rebuilt from the file's `lower_limits` and `step_sizes`
    entries. Each axis runs from its lower limit to the negative of that
    limit, so the grid is assumed to be symmetric about zero.

    Parameters
    ----------
    path : str
        The file path to the .nxs file generated by DISCUS.

    Returns
    -------
    NXdata
        The data converted to the CHESS format, with axes labeled 'H', 'K', and 'L',
        and the signal labeled 'counts'.

    """
    root = nxload(path)
    hlim, klim, llim = root.lower_limits
    hstep, kstep, lstep = root.step_sizes

    def _symmetric_axis(lower, step, name):
        # Round to the nearest integer instead of truncating: a quotient
        # such as 2.9999999999999996 would otherwise drop one grid point
        n_points = int(np.rint(np.abs(lower * 2) / step)) + 1
        return NXfield(np.linspace(lower, -lower, n_points), name=name)

    h_axis = _symmetric_axis(hlim, hstep, 'H')
    k_axis = _symmetric_axis(klim, kstep, 'K')
    l_axis = _symmetric_axis(llim, lstep, 'L')

    data = NXdata(NXfield(root.data[:, :, :], name='counts'),
                  (h_axis, k_axis, l_axis))

    return data