nxs-analysis-tools 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1542 @@
1
+ """
2
+ Tools for reducing data into 2D and 1D, and visualization functions for plotting and animating
3
+ data.
4
+ """
5
+ import os
6
+ import io
7
+ import warnings
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib.transforms import Affine2D
11
+ from matplotlib.markers import MarkerStyle
12
+ from matplotlib.ticker import MultipleLocator
13
+ import matplotlib.animation as animation
14
+ from matplotlib import colors
15
+ from matplotlib import patches
16
+ from IPython.display import display, Markdown, HTML, Image
17
+ from nexusformat.nexus import NXfield, NXdata, nxload, NeXusError, NXroot, NXentry, nxsave
18
+ from scipy.ndimage import rotate, zoom
19
+
20
+ from .lineartransformations import ShearTransformer
21
+
22
+
23
# Public API of this module: the names exported by ``from <module> import *``.
# rebin_2d is listed alongside its rebin_1d/rebin_3d siblings.
__all__ = ['load_data', 'load_transform', 'plot_slice', 'Scissors',
           'reciprocal_lattice_params', 'rotate_data',
           'convert_to_inverse_angstroms', 'array_to_nxdata', 'Padder',
           'rebin_nxdata', 'rebin_3d', 'rebin_2d', 'rebin_1d', 'animate_slice_temp',
           'animate_slice_axis']
29
+
30
+
31
def load_data(path, print_tree=True):
    """
    Read a NeXus file and return its ``entry/data`` group.

    The file is expected to follow the CHESS layout
    (root/entry/data/counts, ...).

    Parameters
    ----------
    path : str
        Location of the NeXus file on disk.

    print_tree : bool, optional
        If True (default), print the tree of ``entry/data`` after loading.

    Returns
    -------
    data : nxdata object
        The ``root.entry.data`` group of the loaded file.
    """
    root = nxload(path)

    if print_tree:
        try:
            print(root.entry.data.tree)
        except NeXusError:
            # Printing is best-effort only; never let it block the load.
            pass

    return root.entry.data
58
+
59
+
60
def load_transform(path, print_tree=True, use_nxlink=False):
    """
    Read the transform group written by nxrefine.

    Parameters
    ----------
    path : str
        Location of the transform NeXus file.

    print_tree : bool, optional
        If True (default), print the tree of the returned data.

    use_nxlink : bool, optional
        If True, return ``root.entry.transform`` as-is. This keeps the NXlink
        that references the raw data in the transform.nxs file and saves memory
        when working with many datasets, but the axes are then in reverse order.
        If False (default), build a new NXdata whose signal ('counts') is
        transposed with ``transpose(2, 1, 0)`` and whose axes are
        (Qh, Qk, Ql).

    Returns
    -------
    data : NXdata
        The loaded transform data.
    """
    root = nxload(path)
    transform = root.entry.transform

    if use_nxlink:
        data = transform
    else:
        counts = NXfield(transform.data.nxdata.transpose(2, 1, 0), name='counts')
        data = NXdata(counts, (transform.Qh, transform.Qk, transform.Ql))

    if print_tree:
        print(data.tree)

    return data
94
+
95
+
96
def array_to_nxdata(array, data_template, signal_name=None):
    """
    Wrap an array in an NXdata that reuses the axes of a template NXdata.

    Parameters
    ----------
    array : array-like
        Values for the new signal.

    data_template : NXdata
        Group whose ``nxaxes`` are attached to the result and whose signal name
        is used when `signal_name` is not given.

    signal_name : str, optional
        Name for the new signal. Defaults to the template signal's name.

    Returns
    -------
    NXdata
        A new NXdata with `array` as its signal and the template's axes.
    """
    name = data_template.nxsignal.nxname if signal_name is None else signal_name
    return NXdata(NXfield(array, name=name), data_template.nxaxes)
124
+
125
+
126
def rebin_3d(array):
    """
    Halve the resolution of a 3D array by averaging 2x2x2 blocks.

    Parameters
    ----------
    array : np.ndarray
        Three-dimensional array whose every dimension is even.

    Returns
    -------
    np.ndarray
        Array of shape (N//2, M//2, L//2) for an input of shape (N, M, L).

    Raises
    ------
    ValueError
        If any dimension is odd.
    """
    dims = array.shape
    if not all(d % 2 == 0 for d in dims):
        raise ValueError("Each dimension of the array must be divisible by 2 to rebin.")

    # Split every axis into (blocks, 2) and average over the size-2 axes.
    blocks = array.reshape(dims[0] // 2, 2, dims[1] // 2, 2, dims[2] // 2, 2)
    return blocks.mean(axis=(1, 3, 5))
156
+
157
def rebin_2d(array):
    """
    Halve the resolution of a 2D array by averaging 2x2 blocks.

    Parameters
    ----------
    array : np.ndarray
        Two-dimensional array whose every dimension is even.

    Returns
    -------
    np.ndarray
        Array of shape (N//2, M//2) for an input of shape (N, M).

    Raises
    ------
    ValueError
        If any dimension is odd.
    """
    dims = array.shape
    if not all(d % 2 == 0 for d in dims):
        raise ValueError("Each dimension of the array must be divisible by 2 to rebin.")

    # (rows, cols) -> (rows//2, 2, cols//2, 2), then average each 2x2 tile.
    tiles = array.reshape(dims[0] // 2, 2, dims[1] // 2, 2)
    return tiles.mean(axis=(1, 3))
187
+
188
def rebin_1d(array):
    """
    Halve the length of a 1D array by averaging neighbouring pairs.

    Parameters
    ----------
    array : np.ndarray
        One-dimensional array of even length.

    Returns
    -------
    np.ndarray
        Array of length N//2 for an input of length N.

    Raises
    ------
    ValueError
        If the length is odd.
    """
    length = len(array)
    if length % 2:
        raise ValueError("The length of the array must be divisible by 2 to rebin.")

    # One row per pair; averaging across the row collapses each pair.
    return array.reshape(length // 2, 2).mean(axis=1)
217
+
218
+
219
def rebin_nxdata(data):
    """
    Rebins the signal and axes of an NXdata object by a factor of 2 along each dimension.

    Each axis of the input `NXdata` object is handled as follows:
    - If the corresponding data dimension is odd, the last axis element is excluded.
    - The (possibly cropped) axis is then rebinned with `rebin_1d`.

    The signal array is similarly cropped to remove the last element along any
    dimension with an odd length. It is then averaged over 2x2x... blocks.
    Signals of any dimensionality are supported.

    Parameters
    ----------
    data : NXdata
        The NeXus data group containing the signal and axes to be rebinned.

    Returns
    -------
    NXdata
        A new NXdata object with signal and axes rebinned by a factor of 2 along each dimension.
    """
    # First, rebin axes (drop the trailing point where the dimension is odd)
    new_axes = []
    for i, dim in enumerate(data.shape):
        axis = data.nxaxes[i]
        values = axis.nxdata[:-1] if dim % 2 == 1 else axis.nxdata[:]
        new_axes.append(NXfield(rebin_1d(values), name=axis.nxname))

    # Second, rebin signal
    signal = data.nxsignal
    data_arr = signal.nxdata

    # Crop the array so that every dimension has even length
    data_arr = data_arr[tuple(slice(0, dim - 1) if dim % 2 == 1 else slice(None)
                              for dim in data_arr.shape)]

    # Generic block average: (n0, n1, ...) -> (n0//2, 2, n1//2, 2, ...), then average
    # over every second axis. For 1D/2D/3D this matches rebin_1d/rebin_2d/rebin_3d,
    # and it also keeps signal and axes consistent for ndim > 3.
    ndim = data_arr.ndim
    paired_shape = tuple(n for dim in data_arr.shape for n in (dim // 2, 2))
    data_arr = data_arr.reshape(paired_shape).mean(axis=tuple(range(1, 2 * ndim, 2)))

    return NXdata(NXfield(data_arr, name=signal.nxname), tuple(new_axes))
282
+
283
+
284
def plot_slice(data, X=None, Y=None, sum_axis=None, transpose=False, vmin=None, vmax=None,
               skew_angle=90, ax=None, xlim=None, ylim=None,
               xticks=None, yticks=None, cbar=True, logscale=False,
               symlogscale=False, cmap='viridis', linthresh=1,
               title=None, mdheading=None, cbartitle=None,
               **kwargs):
    """
    Plot a 2D slice of the provided dataset, with optional transformations
    and customizations.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata` or ndarray
        The dataset to plot. Can be an `NXdata` object or a `numpy` array.

    X : ndarray or :class:`nexusformat.nexus.NXfield`, optional
        Coordinates along the first data dimension. If None, NXdata input uses
        its first axis. For ndarray input, ``np.arange(data.shape[0])`` (named
        'x') is used.

    Y : ndarray or :class:`nexusformat.nexus.NXfield`, optional
        Coordinates along the second data dimension. If None, NXdata input uses
        its second axis. For ndarray input, ``np.arange(data.shape[1])`` (named
        'y') is used.

    sum_axis : int, optional
        If the input data is 3D, this specifies the axis to sum over in order
        to reduce the data to 2D for plotting. Required if `data` has three dimensions.

    transpose : bool, optional
        If True, transpose the dataset and its axes before plotting.
        Default is False.

    vmin : float, optional
        The minimum value for the color scale. If not provided, the minimum
        value of the dataset is used.

    vmax : float, optional
        The maximum value for the color scale. If not provided, the maximum
        value of the dataset is used.

    skew_angle : float, optional
        The angle in degrees to shear the plot. Default is 90 degrees (no skew).

    ax : matplotlib.axes.Axes, optional
        The `matplotlib` axis to plot on. If None, a new figure and axis will
        be created.

    xlim : tuple, optional
        The limits for the x-axis. If None, the limits are set automatically
        based on the data.

    ylim : tuple, optional
        The limits for the y-axis. If None, the limits are set automatically
        based on the data.

    xticks : float or sequence of float, optional
        A scalar sets the major tick interval for the x-axis. A sequence sets
        explicit tick locations. Minor ticks use an interval of 1 whenever ticks
        are placed automatically (NXdata input, or ndarray input with both X and Y).

    yticks : float or sequence of float, optional
        Same as `xticks`, for the y-axis.

    cbar : bool, optional
        Whether to include a colorbar. Default is True.

    logscale : bool, optional
        Whether to use a logarithmic color scale. Default is False.

    symlogscale : bool, optional
        Whether to use a symmetrical logarithmic color scale spanning
        [-vmax, vmax]. Takes precedence over `logscale`. Default is False.

    cmap : str or Colormap, optional
        The colormap to use for the plot. Default is 'viridis'.

    linthresh : float, optional
        The linear threshold for symmetrical logarithmic scaling. Default is 1.

    title : str, optional
        The title for the plot. If None, no title is set.

    mdheading : str, optional
        A Markdown heading to display above the plot:
        - None (default): no heading is displayed.
        - The string 'None': a generic '### Figure' heading is displayed.
        - Any other string: '### Figure - <mdheading>' is displayed.

    cbartitle : str, optional
        The label for the colorbar. If None, the colorbar label is set to
        the name of the signal.

    **kwargs
        Additional keyword arguments passed to `pcolormesh`.

    Returns
    -------
    p : :class:`matplotlib.collections.QuadMesh`
        The `matplotlib` QuadMesh object representing the plotted data.

    Raises
    ------
    TypeError
        If `data` has an unsupported type, or (for ndarray input) `X`/`Y` are
        neither ndarray nor NXfield.
    ValueError
        If `data` is 3D and `sum_axis` is None, or the data is not 2D after reduction.
    """
    # Examine the data type to be plotted
    is_array = False
    is_nxdata = False
    if isinstance(data, np.ndarray):
        is_array = True
    elif isinstance(data, (NXdata, NXfield)):
        is_nxdata = True
    else:
        raise TypeError(f"Unexpected data type: {type(data)}. "
                        f"Supported types are np.ndarray and NXdata.")

    # Whether the user supplied both coordinate arrays (controls automatic tick placement)
    xy_provided = X is not None and Y is not None

    # If three-dimensional, demand sum_axis to reduce to two dimensions.
    if data.ndim == 3:
        if sum_axis is None:
            raise ValueError("sum_axis must be specified when data.ndim == 3.")

        if is_array:
            data = data.sum(axis=sum_axis)
        else:
            arr = data.nxsignal.nxdata.sum(axis=sum_axis)
            # Index the summed axis at 0 to obtain a 2D template carrying the
            # remaining axes. Use a tuple so the index is multidimensional;
            # NumPy-style indexing does not accept a list of slices.
            template_index = [slice(None)] * len(data.shape)
            template_index[sum_axis] = 0
            data = array_to_nxdata(arr, data[tuple(template_index)])

    if data.ndim != 2:
        raise ValueError("Slice data must be 2D.")

    if is_array:
        # Build NXfield coordinates for a plain ndarray
        if X is None:
            X = NXfield(np.arange(data.shape[0]), name='x')
        elif isinstance(X, np.ndarray):
            X = NXfield(X, name='x')
        elif not isinstance(X, NXfield):
            raise TypeError("X must be of type np.ndarray or NXfield")

        if Y is None:
            Y = NXfield(np.arange(data.shape[1]), name='y')
        elif isinstance(Y, np.ndarray):
            Y = NXfield(Y, name='y')
        elif not isinstance(Y, NXfield):
            raise TypeError("Y must be of type np.ndarray or NXfield")

        if transpose:
            X, Y = Y, X
            data = data.transpose()

        data = NXdata(NXfield(data, name='value'), (X, Y))
    else:
        # NXdata: inherit axes unless the user supplied arrays/fields
        if X is None:
            X = data.nxaxes[0]
        elif isinstance(X, np.ndarray):
            X = NXfield(X, name='x')
        if Y is None:
            Y = data.nxaxes[1]
        elif isinstance(Y, np.ndarray):
            Y = NXfield(Y, name='y')

        if transpose:
            X, Y = Y, X
            data = data.transpose()

    # pcolormesh expects C with shape (len(Y), len(X))
    data_arr = data.nxsignal.nxdata.transpose()

    # Display Markdown heading
    if mdheading is not None:
        heading = '### Figure' if mdheading == "None" else '### Figure - ' + mdheading
        display(Markdown(heading))

    # Inherit axes if user provides some, otherwise set up default axes
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])

    # If limits not provided, use extrema
    if vmin is None:
        vmin = data_arr.min()
    if vmax is None:
        vmax = data_arr.max()

    # Set norm (linear scale, logscale, or symlogscale)
    if symlogscale:
        norm = colors.SymLogNorm(linthresh=linthresh, vmin=-1 * vmax, vmax=vmax)
    elif logscale:
        norm = colors.LogNorm(vmin=vmin, vmax=vmax)
    else:
        norm = colors.Normalize(vmin=vmin, vmax=vmax)

    # Plot data
    p = ax.pcolormesh(X.nxdata, Y.nxdata, data_arr, shading='auto', norm=norm, cmap=cmap, **kwargs)

    # Transform data to the sheared coordinate system
    t = ShearTransformer(skew_angle)

    if ylim is not None:
        ax.set(ylim=ylim)
        ymin, ymax = ylim
    else:
        ymin, ymax = ax.get_ylim()
    # The translation by ymin keeps the sheared mesh aligned with the axes
    p.set_transform(t.t
                    + Affine2D().translate(-ymin * np.sin(t.shear_angle * np.pi / 180), 0)
                    + ax.transData)

    # Widen x limits so the sheared mesh is fully visible
    if xlim is not None:
        xmin, xmax = xlim
    else:
        xmin, xmax = ax.get_xlim()
    if skew_angle <= 90:
        ax.set(xlim=(xmin, xmax + (ymax - ymin) / np.tan((90 - t.shear_angle) * np.pi / 180)))
    else:
        ax.set(xlim=(xmin - (ymax - ymin) / np.tan((t.shear_angle - 90) * np.pi / 180), xmax))

    # Correct aspect ratio for the x/y axes after transformation
    ax.set(aspect=np.cos(t.shear_angle * np.pi / 180))

    # Automatically set tick locations, only if NXdata or if X,Y axes are provided for an array
    if is_nxdata or (is_array and xy_provided):
        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.tick_params(direction='in', top=True, right=True, which='both')
        if xticks is not None:
            if np.ndim(xticks) == 0:
                ax.xaxis.set_major_locator(MultipleLocator(xticks))
            else:
                ax.xaxis.set_ticks(xticks)

        ax.yaxis.set_minor_locator(MultipleLocator(1))
        if yticks is not None:
            if np.ndim(yticks) == 0:
                ax.yaxis.set_major_locator(MultipleLocator(yticks))
            else:
                ax.yaxis.set_ticks(yticks)
    else:
        ax.tick_params(direction='in', top=True, right=True, which='major')

    # Skew the x tick markers to follow the sheared y direction.
    # Even indices are bottom ticks, odd indices are top ticks.
    # NOTE: MarkerStyle._transform is private matplotlib API.
    for i, line in enumerate(ax.xaxis.get_majorticklines()):
        marker = MarkerStyle(3)
        if i % 2:
            # Top ticks: translate so they point inward, then skew in place
            # (skewing in place can occasionally misalign them).
            marker._transform.set(Affine2D().translate(0, -1)
                                  + Affine2D().skew_deg(t.shear_angle, 0))
            # Re-assigns the tick line's current transform (kept from the original code).
            line.set_transform(line.get_transform())
        else:
            marker._transform.set(Affine2D().skew_deg(t.shear_angle, 0))
        line.set_marker(marker)

    for i, line in enumerate(ax.xaxis.get_minorticklines()):
        marker = MarkerStyle(2)
        if i % 2:
            marker._transform.set(Affine2D().translate(0, -1)
                                  + Affine2D().skew_deg(t.shear_angle, 0))
        else:
            marker._transform.set(Affine2D().skew_deg(t.shear_angle, 0))
        line.set_marker(marker)

    if cbar:
        colorbar = fig.colorbar(p)
        # Honour a user-supplied colorbar title; fall back to the signal name
        colorbar.set_label(data.nxsignal.nxname if cbartitle is None else cbartitle)

    ax.set(
        xlabel=X.nxname,
        ylabel=Y.nxname,
    )

    if title is not None:
        ax.set_title(title)

    # Return the quadmesh object
    return p
595
+
596
def _format_temperature(temp, title_fmt):
    """
    Format a temperature value with a format spec for use in a frame title.

    Integral temperatures are formatted as ``int`` so integer codes such as
    'd' work. Non-integral temperatures keep their decimals for float codes
    such as '.2f'. For an integer-only code they fall back to the truncated
    ``int`` value.

    Raises
    ------
    ValueError
        If `temp` is not numeric or `title_fmt` cannot format it.
    """
    message = f"Invalid title_fmt '{title_fmt}' for temperature value '{temp}'"
    try:
        value = float(temp)
    except ValueError as err:
        raise ValueError(message) from err

    if value.is_integer():
        value = int(value)

    try:
        return f"{value:{title_fmt}}"
    except ValueError:
        pass
    # Integer-only format codes reject floats; retry with the truncated value.
    try:
        return f"{int(value):{title_fmt}}"
    except (ValueError, OverflowError) as err:
        raise ValueError(message) from err


def animate_slice_temp(temp_dependence, slice_obj, ax=None, reverse_temps=False, interval=500,
                       save_gif=False, filename='animation', title=True, title_fmt='d',
                       plot_slice_kwargs=None, ax_kwargs=None):
    """
    Animate 2D slices from a temperature-dependent dataset.

    Creates a matplotlib animation by extracting 2D slices from each dataset
    in a TempDependence object and animating them in sequence by temperature.
    Displays the animation inline and optionally saves it as a GIF.

    Parameters
    ----------
    temp_dependence : nxs_analysis_tools.chess.TempDependence
        Object holding datasets at various temperatures.
    slice_obj : list of slice or None
        Slice object to apply to each dataset; None entries are treated as ':'.
    ax : matplotlib.axes.Axes, optional
        The axes object to plot on. If None, a new figure and axes will be created.
    reverse_temps : bool, optional
        If True, frames follow the order of ``temp_dependence.temperatures``.
        If False (default), that order is reversed.
    interval : int, optional
        Delay between frames in milliseconds. Default is 500.
    save_gif : bool, optional
        If True, saves the animation to a .gif file. Default is False.
    filename : str, optional
        Filename (without extension) for saved .gif. Default is 'animation'.
    title : bool, optional
        If True, displays the temperature in the title of each frame. Default is True.
    title_fmt : str, optional
        Format spec for temperature values (e.g., '.2f' for 2 decimals). Default is 'd' (integer).
    plot_slice_kwargs : dict, optional
        Additional keyword arguments passed to `plot_slice`. The dict is copied,
        not modified. Colorbars are always disabled.
    ax_kwargs : dict, optional
        Keyword arguments passed to `ax.set`.

    Returns
    -------
    ani : matplotlib.animation.FuncAnimation
        The resulting animation object.

    Raises
    ------
    ValueError
        If `title_fmt` cannot format a temperature value.
    """
    if ax is None:
        fig, ax = plt.subplots()  # Generate a new figure and axis
    else:
        fig = ax.figure  # Get the figure from the provided axis

    # Copy so the caller's dict is never modified when forcing cbar=False
    plot_slice_kwargs = dict(plot_slice_kwargs) if plot_slice_kwargs else {}
    if ax_kwargs is None:
        ax_kwargs = {}

    # Normalize the slice object (None -> ':')
    normalized_slice = tuple(slice(None) if s is None else s for s in slice_obj)

    # Colorbars would accumulate across frames, so they are disabled
    if plot_slice_kwargs.get('cbar', False):
        warnings.warn("Colorbar is not supported in animation and will be ignored.", UserWarning)
    plot_slice_kwargs['cbar'] = False

    def update(temp):
        ax.clear()
        dataset = temp_dependence.datasets[temp]
        plot_slice(dataset[normalized_slice], ax=ax, **plot_slice_kwargs)
        ax.set(**ax_kwargs)
        if title:
            ax.set(title=f'$T$={_format_temperature(temp, title_fmt)}')

    frames = list(temp_dependence.temperatures)
    if not reverse_temps:
        frames.reverse()

    ani = animation.FuncAnimation(fig, update, frames=frames,
                                  interval=interval, repeat=False)

    display(HTML(ani.to_jshtml()))

    if save_gif:
        gif_file = f'{filename}.gif'
        writer = animation.PillowWriter(fps=1000 / interval)
        ani.save(gif_file, writer=writer)
        with open(gif_file, 'rb') as f:
            display(Image(f.read(), format='gif'))

    return ani
693
+
694
def animate_slice_axis(data, axis, axis_values, ax=None, interval=500, save_gif=False,
                       filename='animation', title=True, title_fmt='.2f',
                       plot_slice_kwargs=None, ax_kwargs=None):
    """
    Animate 2D slices of a 3D dataset along a given axis.

    Creates a matplotlib animation by sweeping through 2D slices of a 3D
    dataset along the specified axis. Displays the animation inline
    (e.g., in Jupyter) and optionally saves it as a GIF.

    Parameters
    ----------
    data : nexusformat.nexus.NXdata
        The 3D dataset to visualize.
    axis : int
        The axis along which to animate (must be 0, 1, or 2).
    axis_values : iterable
        The values along the animation axis to use as animation frames. Each
        value is used directly as the index along `axis`.
    ax : matplotlib.axes.Axes, optional
        The axes object to plot on. If None, a new figure and axes will be created.
    interval : int, optional
        Delay between frames in milliseconds. Default is 500.
    save_gif : bool, optional
        If True, saves the animation as a .gif file. Default is False.
    filename : str, optional
        Filename (without extension) to use for the saved .gif. Default is 'animation'.
    title : bool, optional
        If True, displays the axis value as a title for each frame. Default is True.
    title_fmt : str, optional
        Format spec for axis value in the title (e.g., '.2f' for 2 decimals). Default is '.2f'.
    plot_slice_kwargs : dict, optional
        Additional keyword arguments passed to `plot_slice`. The dict is copied,
        not modified. Colorbars are always disabled.
    ax_kwargs : dict, optional
        Keyword arguments passed to `ax.set` to update axis settings.

    Returns
    -------
    ani : matplotlib.animation.FuncAnimation
        The animation object.

    Raises
    ------
    ValueError
        If `axis` is not 0, 1, or 2.
    """
    # Validate before creating any figure so a bad call leaves no stray figure behind
    if axis not in [0, 1, 2]:
        raise ValueError("axis must be either 0, 1, or 2.")

    if ax is None:
        fig, ax = plt.subplots()  # Generate a new figure and axis
    else:
        fig = ax.figure  # Get the figure from the provided axis

    # Fresh per-call dicts: never share a default or modify the caller's dict
    plot_slice_kwargs = dict(plot_slice_kwargs) if plot_slice_kwargs else {}
    if ax_kwargs is None:
        ax_kwargs = {}

    # Colorbars would accumulate across frames, so they are disabled
    if plot_slice_kwargs.get('cbar', False):
        warnings.warn("Colorbar is not supported in animation and will be ignored.", UserWarning)
    plot_slice_kwargs['cbar'] = False

    def update(parameter):
        ax.clear()

        # Select the 2D slice at `parameter` along the animation axis
        slice_obj = [slice(None)] * 3
        slice_obj[axis] = parameter
        plot_slice(data[tuple(slice_obj)], ax=ax, **plot_slice_kwargs)
        ax.set(**ax_kwargs)

        if title:
            axis_label = data.nxaxes[axis].nxname
            ax.set(title=f'{axis_label}={parameter:{title_fmt}}')

    ani = animation.FuncAnimation(fig, update, frames=axis_values, interval=interval, repeat=False)

    display(HTML(ani.to_jshtml()))

    if save_gif:
        gif_file = f'{filename}.gif'
        writergif = animation.PillowWriter(fps=1000 / interval)
        ani.save(gif_file, writer=writergif)
        # Show the saved GIF (the interactive animation was already displayed above)
        with open(gif_file, 'rb') as file:
            display(Image(file.read(), format='gif'))

    return ani
776
+
777
+
778
+ class Scissors:
779
+ """
780
+ Scissors class provides functionality for reducing data to a 1D linecut using an integration
781
+ window.
782
+
783
+ Attributes
784
+ ----------
785
+ data : :class:`nexusformat.nexus.NXdata` or None
786
+ Input :class:`nexusformat.nexus.NXdata`.
787
+ center : tuple or None
788
+ Central coordinate around which to perform the linecut.
789
+ window : tuple or None
790
+ Extents of the window for integration along each axis.
791
+ axis : int or None
792
+ Axis along which to perform the integration.
793
+ integration_volume : :class:`nexusformat.nexus.NXdata` or None
794
+ Data array after applying the integration window.
795
+ integrated_axes : tuple or None
796
+ Indices of axes that were integrated.
797
+ linecut : :class:`nexusformat.nexus.NXdata` or None
798
+ 1D linecut data after integration.
799
+ integration_window : tuple or None
800
+ Slice object representing the integration window in the data array.
801
+
802
+ Methods
803
+ -------
804
+ set_data(data)
805
+ Set the input :class:`nexusformat.nexus.NXdata`.
806
+ get_data()
807
+ Get the input :class:`nexusformat.nexus.NXdata`.
808
+ set_center(center)
809
+ Set the central coordinate for the linecut.
810
+ set_window(window, axis=None, verbose=False)
811
+ Set the extents of the integration window.
812
+ get_window()
813
+ Get the extents of the integration window.
814
+ cut_data(center=None, window=None, axis=None, verbose=False)
815
+ Reduce data to a 1D linecut using the integration window.
816
+ highlight_integration_window(data=None, label=None, highlight_color='red', **kwargs)
817
+ Plot the integration window highlighted on a 2D heatmap of the full dataset.
818
+ plot_integration_window(**kwargs)
819
+ Plot a 2D heatmap of the integration window data.
820
+ """
821
+
822
+ def __init__(self, data=None, center=None, window=None, axis=None):
823
+ """
824
+ Initializes a Scissors object.
825
+
826
+ Parameters
827
+ ----------
828
+ data : :class:`nexusformat.nexus.NXdata` or None, optional
829
+ Input NXdata. Default is None.
830
+ center : tuple or None, optional
831
+ Central coordinate around which to perform the linecut. Default is None.
832
+ window : tuple or None, optional
833
+ Extents of the window for integration along each axis. Default is None.
834
+ axis : int or None, optional
835
+ Axis along which to perform the integration. Default is None.
836
+ """
837
+
838
+ self.data = data
839
+ self.center = tuple(float(i) for i in center) if center is not None else None
840
+ self.window = tuple(float(i) for i in window) if window is not None else None
841
+ self.axis = axis
842
+
843
+ self.integration_volume = None
844
+ self.integrated_axes = None
845
+ self.linecut = None
846
+ self.integration_window = None
847
+
848
+ def set_data(self, data):
849
+ """
850
+ Set the input NXdata.
851
+
852
+ Parameters
853
+ ----------
854
+ data : :class:`nexusformat.nexus.NXdata`
855
+ Input data array.
856
+ """
857
+ self.data = data
858
+
859
+ def get_data(self):
860
+ """
861
+ Get the input data array.
862
+
863
+ Returns
864
+ -------
865
+ ndarray or None
866
+ Input data array.
867
+ """
868
+ return self.data
869
+
870
+ def set_center(self, center):
871
+ """
872
+ Set the central coordinate for the linecut.
873
+
874
+ Parameters
875
+ ----------
876
+ center : tuple
877
+ Central coordinate around which to perform the linecut.
878
+ """
879
+ self.center = tuple(float(i) for i in center) if center is not None else None
880
+
881
+ def set_window(self, window, axis=None, verbose=False):
882
+ """
883
+ Set the extents of the integration window.
884
+
885
+ Parameters
886
+ ----------
887
+ window : tuple
888
+ Extents of the window for integration along each axis.
889
+ axis : int or None, optional
890
+ The axis along which to perform the linecut. If not specified, the value from the
891
+ object's attribute will be used.
892
+ verbose : bool, optional
893
+ Enables printout of linecut axis and integrated axes. Default False.
894
+
895
+ """
896
+ self.window = tuple(float(i) for i in window) if window is not None else None
897
+
898
+ # Determine the axis for integration
899
+ self.axis = window.index(max(window)) if axis is None else axis
900
+
901
+ # Determine the integrated axes (axes other than the integration axis)
902
+ self.integrated_axes = tuple(i for i in range(self.data.ndim) if i != self.axis)
903
+
904
+ if verbose:
905
+ print("Linecut axis: " + str(self.data.nxaxes[self.axis].nxname))
906
+ print("Integrated axes: " + str([self.data.nxaxes[axis].nxname
907
+ for axis in self.integrated_axes]))
908
+
909
+ def get_window(self):
910
+ """
911
+ Get the extents of the integration window.
912
+
913
+ Returns
914
+ -------
915
+ tuple or None
916
+ Extents of the integration window.
917
+ """
918
+ return self.window
919
+
920
def cut_data(self, center=None, window=None, axis=None, verbose=False):
    """
    Reduce the data to a 1D linecut by integrating over a window about a center.

    Parameters
    ----------
    center : sequence of float or None, optional
        Central coordinate (one value per axis) of the linecut. If not
        specified, the value stored on the object is used.
    window : sequence of float or None, optional
        Half-width of the integration window along each axis; the cut spans
        ``center - window`` to ``center + window``. If not specified, the value
        stored on the object is used.
    axis : int or None, optional
        The axis along which to perform the linecut. If not specified, it is
        inferred by ``set_window`` (the axis with the largest window extent).
    verbose : bool, optional
        Enables printout of linecut axis and integrated axes. Default False.

    Returns
    -------
    integrated_data : :class:`nexusformat.nexus.NXdata`
        1D linecut data after integration.

    Raises
    ------
    ValueError
        If no center (or window) is passed and none has been stored on the object.
    """
    data = self.data

    # Fall back to the stored parameters, and fail early with a clear message
    # (before touching any state) if neither source provides them.
    center = center if center is not None else self.center
    window = window if window is not None else self.window
    if center is None:
        raise ValueError("No center given and none stored; pass center or call set_center().")
    if window is None:
        raise ValueError("No window given and none stored; pass window or call set_window().")

    # Store the parameters actually used so later plotting calls see them.
    self.set_center(center)
    self.set_window(window, axis, verbose)

    center = tuple(float(c) for c in center)

    # Integration box: center +/- window along every axis.
    start = np.subtract(center, window)
    stop = np.add(center, window)
    slice_obj = tuple(slice(s, e) for s, e in zip(start, stop))
    self.integration_window = slice_obj

    # Float slice bounds -- presumably interpreted by nexusformat as axis
    # coordinates rather than array indices (confirm against nexusformat docs).
    self.integration_volume = data[slice_obj]
    self.integration_volume.nxname = data.nxname

    # Sum the signal over every axis except the linecut axis.
    integrated_data = np.sum(self.integration_volume.nxsignal.nxdata,
                             axis=self.integrated_axes)

    # Pair the summed signal with the linecut axis of the cut volume.
    linecut_axis_name = self.integration_volume.nxaxes[self.axis].nxname
    self.linecut = NXdata(NXfield(integrated_data, name=self.integration_volume.nxsignal.nxname),
                          self.integration_volume[linecut_axis_name])
    self.linecut.nxname = self.integration_volume.nxname

    return self.linecut
976
+
977
def highlight_integration_window(self, data=None, width=None, height=None,
                                 label=None, highlight_color='red', **kwargs):
    """
    Plots the integration window highlighted on the three principal 2D cross-sections of a 3D dataset.

    Parameters
    ----------
    data : array-like, optional
        The 3D dataset to visualize. If not provided, uses `self.data`.
    width : float, optional
        Width of the visible x-axis range in each subplot. Used to zoom in on the integration region.
    height : float, optional
        Height of the visible y-axis range in each subplot. Used to zoom in on the integration region.
    label : str, optional
        Label for the rectangle patch marking the integration window, used in the legend.
    highlight_color : str, optional
        Color of the rectangle edges highlighting the integration window. Default is 'red'.
    **kwargs : dict, optional
        Additional keyword arguments passed to `plot_slice` for customizing the plot
        (e.g., cmap, vmin, vmax). If 'xlim'/'ylim' are given they take precedence
        over `width`/`height`.

    Returns
    -------
    p1, p2, p3 : matplotlib.collections.QuadMesh
        The plotted QuadMesh objects for the three cross-sections:
        XY at fixed Z, XZ at fixed Y, and YZ at fixed X.

    Raises
    ------
    ValueError
        If the center or window has not been set.
    """
    data = self.data if data is None else data
    center = self.center
    window = self.window
    if center is None or window is None:
        raise ValueError("Set the center and window (e.g. via cut_data) before highlighting.")

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # (axis held fixed, horizontal axis, vertical axis) for each panel:
    # XY at fixed Z, XZ at fixed Y, YZ at fixed X.
    panels = ((2, 0, 1), (1, 0, 2), (0, 1, 2))

    meshes = []
    for ax, (fixed_dim, x_dim, y_dim) in zip(axes, panels):
        # Use a tuple so the index is treated as one multidimensional index,
        # consistent with the tuple indexing used elsewhere in this module.
        slice_obj = [slice(None)] * data.ndim
        slice_obj[fixed_dim] = center[fixed_dim]

        mesh = plot_slice(data[tuple(slice_obj)],
                          X=data.nxaxes[x_dim],
                          Y=data.nxaxes[y_dim],
                          ax=ax,
                          **kwargs)
        meshes.append(mesh)

        # Outline of the integration window: center +/- window on both in-plane axes.
        ax.add_patch(patches.Rectangle(
            (center[x_dim] - window[x_dim], center[y_dim] - window[y_dim]),
            2 * window[x_dim], 2 * window[y_dim],
            linewidth=1, edgecolor=highlight_color,
            facecolor='none', transform=mesh.get_transform(), label=label,
        ))

        # If width/height are provided, center the view on the linecut area.
        if 'xlim' not in kwargs and width is not None:
            ax.set(xlim=(center[x_dim] - width / 2, center[x_dim] + width / 2))
        if 'ylim' not in kwargs and height is not None:
            ax.set(ylim=(center[y_dim] - height / 2, center[y_dim] + height / 2))

    # Adjust subplot padding
    fig.subplots_adjust(wspace=0.5)

    if label is not None:
        for ax in axes:
            ax.legend()

    plt.show()

    return tuple(meshes)
1093
+
1094
def plot_integration_window(self, **kwargs):
    """
    Plots the three principal cross-sections of the integration volume on a single figure.

    The integration volume is the sub-volume extracted by the most recent
    ``cut_data`` call; each panel is a slice through ``self.center``.

    Parameters
    ----------
    **kwargs : keyword arguments, optional
        Additional keyword arguments passed to `plot_slice` to customize the plot.

    Returns
    -------
    tuple
        The objects returned by `plot_slice`, ordered (XY at fixed Z,
        YZ at fixed X, XZ at fixed Y). Note this order differs from the
        left-to-right panel order (XY, XZ, YZ); it is kept for backward
        compatibility.

    Raises
    ------
    ValueError
        If no integration volume exists yet (``cut_data`` has not been run).
    """
    data = getattr(self, 'integration_volume', None)
    if data is None:
        raise ValueError("No integration volume available; run cut_data() first.")
    center = self.center

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # (axis held fixed, horizontal axis, vertical axis) per panel, left to right:
    # XY at fixed Z, XZ at fixed Y, YZ at fixed X.
    panels = ((2, 0, 1), (1, 0, 2), (0, 1, 2))

    meshes = []
    for ax, (fixed_dim, x_dim, y_dim) in zip(axes, panels):
        # Tuple index, consistent with the tuple indexing used elsewhere in this module.
        slice_obj = [slice(None)] * data.ndim
        slice_obj[fixed_dim] = center[fixed_dim]
        meshes.append(plot_slice(data[tuple(slice_obj)],
                                 X=data.nxaxes[x_dim],
                                 Y=data.nxaxes[y_dim],
                                 ax=ax,
                                 **kwargs))
        # Aspect ratio taken from the number of points along each in-plane axis.
        ax.set_aspect(len(data.nxaxes[x_dim].nxdata) / len(data.nxaxes[y_dim].nxdata))

    # Adjust subplot padding
    fig.subplots_adjust(wspace=0.3)

    plt.show()

    xy_mesh, xz_mesh, yz_mesh = meshes
    # NOTE(review): historical return order (XY, YZ, XZ) preserved; it differs
    # from highlight_integration_window, which returns (XY, XZ, YZ).
    return xy_mesh, yz_mesh, xz_mesh
1144
+
1145
+
1146
def reciprocal_lattice_params(lattice_params):
    """
    Compute reciprocal lattice parameters from direct lattice parameters.

    Uses the crystallographic convention without a factor of 2*pi, i.e.
    a* = b c sin(alpha) / V.

    Parameters
    ----------
    lattice_params : tuple
        Direct lattice parameters (a, b, c, alpha, beta, gamma): the lengths
        of the lattice vectors and the angles between them in degrees.

    Returns
    -------
    tuple
        Reciprocal lattice parameters (a*, b*, c*, alpha*, beta*, gamma*):
        reciprocal vector lengths and the angles between them in degrees.
    """
    a_mag, b_mag, c_mag, alpha_deg, beta_deg, gamma_deg = lattice_params

    alpha, beta, gamma = np.deg2rad(alpha_deg), np.deg2rad(beta_deg), np.deg2rad(gamma_deg)
    cos_a, cos_b, cos_g = np.cos(alpha), np.cos(beta), np.cos(gamma)
    sin_a, sin_b, sin_g = np.sin(alpha), np.sin(beta), np.sin(gamma)

    # Unit cell volume for a general triclinic cell.
    volume = a_mag * b_mag * c_mag * np.sqrt(
        1 - cos_a ** 2 - cos_b ** 2 - cos_g ** 2 + 2 * cos_a * cos_b * cos_g
    )

    def _reciprocal_angle(cos_i, cos_j, cos_k, sin_i, sin_j):
        # cos(k*) = (cos i cos j - cos k) / (sin i sin j), returned in degrees.
        return np.rad2deg(np.arccos((cos_i * cos_j - cos_k) / (sin_i * sin_j)))

    a_star = b_mag * c_mag * sin_a / volume
    b_star = a_mag * c_mag * sin_b / volume
    c_star = a_mag * b_mag * sin_g / volume

    alpha_star = _reciprocal_angle(cos_b, cos_g, cos_a, sin_b, sin_g)
    beta_star = _reciprocal_angle(cos_a, cos_g, cos_b, sin_a, sin_g)
    gamma_star = _reciprocal_angle(cos_a, cos_b, cos_g, sin_a, sin_b)

    return a_star, b_star, c_star, alpha_star, beta_star, gamma_star
1188
+
1189
+
1190
def convert_to_inverse_angstroms(data, lattice_params):
    """
    Rescale the axes of a 3D NXdata object from r.l.u. to inverse angstroms.

    Axis i is multiplied by the i-th reciprocal lattice length (a*, b*, c*)
    from :func:`reciprocal_lattice_params`, which carries no factor of 2*pi.
    The reciprocal angles are not used, and the signal is passed through unchanged.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        A 3D NXdata object with axes in reciprocal lattice units.

    lattice_params : tuple of float
        Real-space lattice parameters (a, b, c, alpha, beta, gamma)
        in angstroms and degrees.

    Returns
    -------
    NXdata
        A new NXdata object whose axes, named 'a_star', 'b_star' and 'c_star',
        are scaled to inverse angstroms.
    """
    reciprocal = reciprocal_lattice_params(lattice_params)
    axis_names = ('a_star', 'b_star', 'c_star')

    scaled_axes = tuple(
        NXfield(data.nxaxes[i].nxdata * reciprocal[i], name=axis_names[i])
        for i in range(3)
    )

    return NXdata(data.nxsignal, scaled_axes)
1218
+
1219
+
1220
def _rotate_plane(plane, lattice_angle, rotation_angle, aspect, aspect_order, rotation_order):
    """
    Rotate a single 2D NXdata plane and return the rotated counts.

    The plane is zero-padded, sheared to the lattice angle, stretched to
    correct the coordinate aspect ratio and rotated. Each of those steps is
    then undone in reverse order and the padding is removed.
    """
    # Pad by the full length of each axis so rotated content is not clipped.
    padder = Padder(plane)
    padder.pad(tuple(len(axis) for axis in plane.nxaxes))
    counts = padder.padded.nxsignal

    # Skew data to match lattice angle
    shear = ShearTransformer(lattice_angle)
    counts = shear.apply(counts)

    # Apply coordinate aspect ratio correction
    # More resolution along y = more squeezing needed along y
    # More resolution along x = less squeezing needed along y
    y_res = plane.shape[1] / (plane.nxaxes[1].max() - plane.nxaxes[1].min())
    x_res = plane.shape[0] / (plane.nxaxes[0].max() - plane.nxaxes[0].min())
    stretch = aspect * x_res / y_res
    n_cols = np.shape(counts)[1]
    counts = zoom(counts, zoom=(1, stretch), order=aspect_order)

    # Perform rotation (reshape=False keeps the array shape unchanged)
    counts = rotate(counts, rotation_angle, reshape=False, order=rotation_order)

    # Undo the aspect correction by zooming back to exactly n_cols columns.
    # zoom() rounds the output shape, so zooming by 1/stretch can land one
    # column off (e.g. 3 -> 2 -> 4 for stretch 0.5). In scipy's default
    # (non-grid) mode the interpolation depends only on the input/output
    # shapes, so results are unchanged whenever 1/stretch gave the right shape.
    counts = zoom(counts, zoom=(1, n_cols / counts.shape[1]), order=aspect_order)

    # Undo skew transformation
    counts = shear.invert(counts)

    # Remove padding
    return padder.unpad(counts)


def rotate_data(data, lattice_angle, rotation_angle, rotation_axis=None, rotation_order=None,
                aspect=None, aspect_order=None, printout=False):
    """
    Rotates slices of data around the normal axis.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data (two- or three-dimensional).
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.
    rotation_axis : int, optional
        Axis of rotation (0, 1, or 2). Only necessary when data is three-dimensional;
        ignored for two-dimensional data.
    rotation_order : int, optional
        Interpolation order passed to :func:`scipy.ndimage.rotate`.
        Determines the spline interpolation used during rotation.
        Valid values are integers from 0 (nearest-neighbor) to 5 (higher-order splines).
        Defaults to 0 if not specified.
    aspect : float, optional
        True aspect ratio between the lengths of the basis vectors of the two principal
        axes of the plane to be rotated. Calculated as
        aspect = (length of y) / (length of x). Defaults to 1.
    aspect_order : int, optional
        Interpolation order passed to :func:`scipy.ndimage.zoom` when applying and undoing
        the coordinate aspect ratio correction. Determines the spline interpolation used
        during resampling. Valid values are integers from 0 (nearest-neighbor) to 5.
        Defaults to 0 if not specified.
    printout : bool, optional
        If True, print the progress of each rotated slice (three-dimensional data only)
        and a completion message. Defaults to False, in which case nothing is printed.

    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object, with the input's signal name and axes.

    Raises
    ------
    ValueError
        If the data is not 2D or 3D, or if `rotation_axis` is missing for 3D data.
    """
    aspect = 1 if aspect is None else aspect
    aspect_order = 0 if aspect_order is None else aspect_order
    rotation_order = 0 if rotation_order is None else rotation_order

    if data.ndim == 3 and rotation_axis is None:
        raise ValueError('rotation_axis must be specified for three-dimensional datasets.')
    if data.ndim not in (2, 3):
        raise ValueError('Data must be 2 or 3 dimensional.')

    plane_options = dict(lattice_angle=lattice_angle, rotation_angle=rotation_angle,
                         aspect=aspect, aspect_order=aspect_order,
                         rotation_order=rotation_order)

    if data.ndim == 2:
        output_array = _rotate_plane(data, **plane_options)
    else:
        output_array = np.zeros(data.nxsignal.shape)
        normal_axis = data.nxaxes[rotation_axis]
        # Rotate every layer perpendicular to the rotation axis independently.
        for i in range(len(normal_axis)):
            if printout:
                print(f'\rRotating {normal_axis.nxname}={normal_axis[i]}... ',
                      end='', flush=True)
            index = [slice(None)] * 3
            index[rotation_axis] = i
            index = tuple(index)
            output_array[index] = _rotate_plane(data[index], **plane_options)

    if printout:
        print('\nRotation completed.')

    # The padded signal used per plane carries the input signal's name, so the
    # name is taken from the input directly (works even with zero slices).
    return NXdata(NXfield(output_array, name=data.nxsignal.nxname),
                  [axis for axis in data.nxaxes])
1335
+
1336
+
1337
+
1338
def rotate_data_2D(data, lattice_angle, rotation_angle):
    """
    DEPRECATED: Use `rotate_data` instead.

    Rotate two-dimensional data by delegating to `rotate_data`, after
    emitting a DeprecationWarning attributed to the caller.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.

    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object.
    """
    deprecation_message = (
        "rotate_data_2D is deprecated and will be removed in a future release. "
        "Use rotate_data instead."
    )
    # stacklevel=2 points the warning at the caller rather than this wrapper.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

    return rotate_data(data, lattice_angle=lattice_angle, rotation_angle=rotation_angle)
1367
+
1368
+
1369
class Padder:
    """
    A class to symmetrically pad and unpad datasets with a region of zeros.

    Attributes
    ----------
    data : NXdata or None
        The input data to be padded.
    padded : NXdata or None
        The padded data with symmetric zero padding.
    padding : tuple or None
        The number of zero-value pixels added along each edge of the array.
    steps : tuple or None
        The step sizes along each axis of the dataset.
    maxes : tuple or None
        The maximum values along each axis of the dataset.

    Methods
    -------
    set_data(data)
        Set the input data for padding.
    pad(padding)
        Symmetrically pads the data with zero values.
    save(fout_name=None)
        Saves the padded dataset to a .nxs file.
    unpad(data)
        Removes the padded region from the data.
    """

    def __init__(self, data=None):
        """
        Initialize the Padder object.

        Parameters
        ----------
        data : NXdata, optional
            The input data to be padded. If provided, the `set_data` method
            is called to set the data.
        """
        self.maxes = None
        self.steps = None
        self.data = None
        self.padded = None
        self.padding = None
        if data is not None:
            self.set_data(data)

    def set_data(self, data):
        """
        Set the input data for padding.

        The step along each axis is taken from its first two points, so each
        axis is assumed to be evenly spaced with at least two points.

        Parameters
        ----------
        data : NXdata
            The input data to be padded.
        """
        self.data = data

        self.steps = tuple((axis.nxdata[1] - axis.nxdata[0])
                           for axis in data.nxaxes)

        # Absolute value of the maximum value; assumes the domain of the input
        # is symmetric (eg, -H_min = H_max)
        self.maxes = tuple(axis.nxdata.max() for axis in data.nxaxes)

    def pad(self, padding):
        """
        Symmetrically pads the data with zero values.

        Parameters
        ----------
        padding : tuple or int
            The number of zero-value pixels to add along each edge of the array,
            one entry per dimension. A single int applies to every dimension.
            Zero padding along a dimension is allowed.

        Returns
        -------
        NXdata
            The padded data with symmetric zero padding. Padded axes span
            ``-(max + padding * step)`` to ``+(max + padding * step)``.

        Raises
        ------
        ValueError
            If no data has been set.
        """
        data = self.data
        if data is None:
            raise ValueError("No data set; pass data to Padder() or call set_data() first.")
        ndim = data.ndim

        if np.isscalar(padding):
            padding = (padding,) * ndim
        self.padding = padding

        signal = data.nxsignal.nxdata
        shape = signal.shape
        padded_shape = tuple(shape[i] + 2 * padding[i] for i in range(ndim))

        # Create padded dataset
        padded = np.zeros(padded_shape)

        # Explicit stop index: slice(p, -p) would be empty when p == 0.
        inner = tuple(slice(padding[i], padding[i] + shape[i]) for i in range(ndim))
        padded[inner] = signal

        padmaxes = tuple(self.maxes[i] + padding[i] * self.steps[i]
                         for i in range(ndim))

        padded = NXdata(NXfield(padded, name=data.nxsignal.nxname),
                        tuple(NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
                                      name=data.nxaxes[i].nxname)
                              for i in range(ndim)))

        self.padded = padded
        return padded

    def save(self, fout_name=None):
        """
        Saves the padded dataset to a .nxs file.

        Parameters
        ----------
        fout_name : str, optional
            The output file name. Default is ``padded_<p0>_<p1>[_<p2>...].nxs``,
            one entry per padded dimension (e.g. padded_(Hpadding)_(Kpadding)_(Lpadding).nxs
            for 3D data).

        Raises
        ------
        ValueError
            If `pad` has not been called yet.
        """
        if self.padded is None:
            raise ValueError("Nothing to save; call pad() first.")

        # Save padded dataset
        print("Saving padded dataset...")
        f = NXroot()
        f['entry'] = NXentry()
        f['entry']['data'] = self.padded
        if fout_name is None:
            fout_name = 'padded_' + '_'.join(str(p) for p in self.padding) + '.nxs'
        nxsave(fout_name, f)
        print("Output file saved to: " + os.path.join(os.getcwd(), fout_name))

    def unpad(self, data):
        """
        Removes the padded region from the data.

        Parameters
        ----------
        data : ndarray or NXdata
            The padded data from which to remove the padding.

        Returns
        -------
        ndarray or NXdata
            The unpadded data, with the symmetric padding region removed.

        Raises
        ------
        ValueError
            If no padding has been recorded (``pad`` not called and
            ``padding`` not set).
        """
        if self.padding is None:
            raise ValueError("No padding recorded; call pad() first.")
        # Explicit stop index: slice(p, -p) would be empty when p == 0.
        slice_obj = tuple(slice(self.padding[i], data.shape[i] - self.padding[i])
                          for i in range(data.ndim))
        return data[slice_obj]
1514
+
1515
+
1516
def load_discus_nxs(path):
    """
    Load .nxs format data from the DISCUS program (by T. Proffen and R. Neder)
    and convert it to the CHESS format.

    The file's ``lower_limits`` and ``step_sizes`` define the axes. Each axis
    is assumed symmetric, running from the lower limit to its negative.

    Parameters
    ----------
    path : str
        The file path to the .nxs file generated by DISCUS.

    Returns
    -------
    NXdata
        The data converted to the CHESS format, with axes labeled 'H', 'K', and 'L',
        and the signal labeled 'counts'.
    """
    root = nxload(path)
    hlim, klim, llim = root.lower_limits
    hstep, kstep, lstep = root.step_sizes

    def _symmetric_axis(lower, step, name):
        # round() rather than int(): the quotient can land just below an integer
        # (0.3 / 0.1 == 2.9999999999999996), and truncation would drop a point.
        n_points = int(round(np.abs(lower * 2) / step)) + 1
        return NXfield(np.linspace(lower, -lower, n_points), name=name)

    h_axis = _symmetric_axis(hlim, hstep, 'H')
    k_axis = _symmetric_axis(klim, kstep, 'K')
    l_axis = _symmetric_axis(llim, lstep, 'L')

    return NXdata(NXfield(root.data[:, :, :], name='counts'), (h_axis, k_axis, l_axis))