nxs-analysis-tools 0.0.32__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nxs-analysis-tools might be problematic. Click here for more details.

@@ -1,887 +1,989 @@
1
- """
2
- Reduces scattering data into 2D and 1D datasets.
3
- """
4
- import os
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- from matplotlib.transforms import Affine2D
8
- from matplotlib.markers import MarkerStyle
9
- from matplotlib.ticker import MultipleLocator
10
- from matplotlib import colors
11
- from matplotlib import patches
12
- from IPython.display import display, Markdown
13
- from nexusformat.nexus import NXfield, NXdata, nxload, NeXusError, NXroot, NXentry, nxsave
14
- from scipy import ndimage
15
-
16
- # Specify items on which users are allowed to perform standalone imports
17
- __all__ = ['load_data', 'plot_slice', 'Scissors', 'reciprocal_lattice_params', 'rotate_data', 'array_to_nxdata', 'Padder']
18
-
19
-
20
- def load_data(path):
21
- """
22
- Load data from a specified path.
23
-
24
- Parameters
25
- ----------
26
- path : str
27
- The path to the data file.
28
-
29
- Returns
30
- -------
31
- data : nxdata object
32
- The loaded data stored in a nxdata object.
33
-
34
- """
35
- g = nxload(path)
36
- try:
37
- print(g.entry.data.tree)
38
- except NeXusError:
39
- pass
40
-
41
- return g.entry.data
42
-
43
-
44
- def array_to_nxdata(array, data_template, signal_name='counts'):
45
- """
46
- Create an NXdata object from an input array and an NXdata template, with an optional signal name.
47
-
48
- Parameters
49
- ----------
50
- array : array-like
51
- The data array to be included in the NXdata object.
52
-
53
- data_template : NXdata
54
- An NXdata object serving as a template, which provides information about axes and other metadata.
55
-
56
- signal_name : str, optional
57
- The name of the signal within the NXdata object. If not provided,
58
- the default signal name 'counts' is used.
59
-
60
- Returns
61
- -------
62
- NXdata
63
- An NXdata object containing the input data array and associated axes based on the template.
64
- """
65
- d = data_template
66
- return NXdata(NXfield(array, name=signal_name), tuple([d[d.axes[i]] for i in range(len(d.axes))]))
67
-
68
-
69
- def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew_angle=90,
70
- ax=None, xlim=None, ylim=None, xticks=None, yticks=None, cbar=True, logscale=False,
71
- symlogscale=False, cmap='viridis', linthresh=1, title=None, mdheading=None, cbartitle=None,
72
- **kwargs):
73
- """
74
- Parameters
75
- ----------
76
- data : :class:`nexusformat.nexus.NXdata` object or ndarray
77
- The NXdata object containing the dataset to plot.
78
-
79
- X : NXfield, optional
80
- The X axis values. Default is first axis of `data`.
81
-
82
- Y : NXfield, optional
83
- The y axis values. Default is second axis of `data`.
84
-
85
- transpose : bool, optional
86
- If True, tranpose the dataset and its axes before plotting. Default is False.
87
-
88
- vmin : float, optional
89
- The minimum value to plot in the dataset.
90
- If not provided, the minimum of the dataset will be used.
91
-
92
- vmax : float, optional
93
- The maximum value to plot in the dataset.
94
- If not provided, the maximum of the dataset will be used.
95
-
96
- skew_angle : float, optional
97
- The angle to shear the plot in degrees. Defaults to 90 degrees (no skewing).
98
-
99
- ax : matplotlib.axes.Axes, optional
100
- An optional axis object to plot the heatmap onto.
101
-
102
- xlim : tuple, optional
103
- The limits of the x-axis. If not provided, the limits will be automatically set.
104
-
105
- ylim : tuple, optional
106
- The limits of the y-axis. If not provided, the limits will be automatically set.
107
-
108
- xticks : float, optional
109
- The major tick interval for the x-axis.
110
- If not provided, the function will use a default minor tick interval of 1.
111
-
112
- yticks : float, optional
113
- The major tick interval for the y-axis.
114
- If not provided, the function will use a default minor tick interval of 1.
115
-
116
- cbar : bool, optional
117
- Whether to include a colorbar in the plot. Defaults to True.
118
-
119
- logscale : bool, optional
120
- Whether to use a logarithmic color scale. Defaults to False.
121
-
122
- symlogscale : bool, optional
123
- Whether to use a symmetrical logarithmic color scale. Defaults to False.
124
-
125
- cmap : str or Colormap, optional
126
- The color map to use. Defaults to 'viridis'.
127
-
128
- linthresh : float, optional
129
- The linear threshold for the symmetrical logarithmic color scale. Defaults to 1.
130
-
131
- mdheading : str, optional
132
- A string containing the Markdown heading for the plot. Default `None`.
133
-
134
- Returns
135
- -------
136
- p : :class:`matplotlib.collections.QuadMesh`
137
-
138
- A :class:`matplotlib.collections.QuadMesh` object, to mimick behavior of
139
- :class:`matplotlib.pyplot.pcolormesh`.
140
-
141
- """
142
- if type(data) == np.ndarray:
143
- if X is None:
144
- X = NXfield(np.linspace(0, data.shape[1], data.shape[1]), name='x')
145
- if Y is None:
146
- Y = NXfield(np.linspace(0, data.shape[0], data.shape[0]), name='y')
147
- if transpose:
148
- X, Y = Y, X
149
- data = data.transpose()
150
- data = NXdata(NXfield(data, name='value'), (X, Y))
151
- data_arr = data
152
- elif type(data) == NXdata or type(data) == NXfield:
153
- if X is None:
154
- X = data[data.axes[0]]
155
- if Y is None:
156
- Y = data[data.axes[1]]
157
- if transpose:
158
- X, Y = Y, X
159
- data = data.transpose()
160
- data_arr = data[data.signal].nxdata.transpose()
161
- else:
162
- raise TypeError(f"Unexpected data type: {type(data)}. Supported types are np.ndarray and NXdata.")
163
-
164
- # Display Markdown heading
165
- if mdheading is None:
166
- pass
167
- elif mdheading == "None":
168
- display(Markdown('### Figure'))
169
- else:
170
- display(Markdown('### Figure - ' + mdheading))
171
-
172
- # Inherit axes if user provides some
173
- if ax is not None:
174
- ax = ax
175
- fig = ax.get_figure()
176
- # Otherwise set up some default axes
177
- else:
178
- fig = plt.figure()
179
- ax = fig.add_axes([0, 0, 1, 1])
180
-
181
- # If limits not provided, use extrema
182
- if vmin is None:
183
- vmin = data_arr.min()
184
- if vmax is None:
185
- vmax = data_arr.max()
186
-
187
- # Set norm (linear scale, logscale, or symlogscale)
188
- norm = colors.Normalize(vmin=vmin, vmax=vmax) # Default: linear scale
189
-
190
- if symlogscale:
191
- norm = colors.SymLogNorm(linthresh=linthresh, vmin=-1 * vmax, vmax=vmax)
192
- elif logscale:
193
- norm = colors.LogNorm(vmin=vmin, vmax=vmax)
194
-
195
- # Plot data
196
- p = ax.pcolormesh(X.nxdata, Y.nxdata, data_arr, shading='auto', norm=norm, cmap=cmap, **kwargs)
197
-
198
- ## Transform data to new coordinate system if necessary
199
- # Correct skew angle
200
- skew_angle_adj = 90 - skew_angle
201
- # Create blank 2D affine transformation
202
- t = Affine2D()
203
- # Scale y-axis to preserve norm while shearing
204
- t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
205
- # Shear along x-axis
206
- t += Affine2D().skew_deg(skew_angle_adj, 0)
207
- # Return to original y-axis scaling
208
- t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()
209
- ## Correct for x-displacement after shearing
210
- # If ylims provided, use those
211
- if ylim is not None:
212
- # Set ylims
213
- ax.set(ylim=ylim)
214
- ymin, ymax = ylim
215
- # Else, use current ylims
216
- else:
217
- ymin, ymax = ax.get_ylim()
218
- # Use ylims to calculate translation (necessary to display axes in correct position)
219
- p.set_transform(t + Affine2D().translate(-ymin * np.sin(skew_angle_adj * np.pi / 180), 0) + ax.transData)
220
-
221
- # Set x limits
222
- if xlim is not None:
223
- xmin, xmax = xlim
224
- else:
225
- xmin, xmax = ax.get_xlim()
226
- if skew_angle <= 90:
227
- ax.set(xlim=(xmin, xmax + (ymax - ymin) / np.tan((90 - skew_angle_adj) * np.pi / 180)))
228
- else:
229
- ax.set(xlim=(xmin - (ymax - ymin) / np.tan((skew_angle_adj - 90) * np.pi / 180), xmax))
230
-
231
- # Correct aspect ratio for the x/y axes after transformation
232
- ax.set(aspect=np.cos(skew_angle_adj * np.pi / 180))
233
-
234
- # Add tick marks all around
235
- ax.tick_params(direction='in', top=True, right=True, which='both')
236
-
237
- # Set tick locations
238
- if xticks is None:
239
- # Add default minor ticks
240
- ax.xaxis.set_minor_locator(MultipleLocator(1))
241
- else:
242
- # Otherwise use user provided values
243
- ax.xaxis.set_major_locator(MultipleLocator(xticks))
244
- ax.xaxis.set_minor_locator(MultipleLocator(1))
245
- if yticks is None:
246
- # Add default minor ticks
247
- ax.yaxis.set_minor_locator(MultipleLocator(1))
248
- else:
249
- # Otherwise use user provided values
250
- ax.yaxis.set_major_locator(MultipleLocator(yticks))
251
- ax.yaxis.set_minor_locator(MultipleLocator(1))
252
-
253
- # Apply transform to tick marks
254
- for i in range(0, len(ax.xaxis.get_ticklines())):
255
- # Tick marker
256
- m = MarkerStyle(3)
257
- line = ax.xaxis.get_majorticklines()[i]
258
- if i % 2:
259
- # Top ticks (translation here makes their direction="in")
260
- m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
261
- # This first method shifts the top ticks horizontally to match the skew angle.
262
- # This does not look good in all cases.
263
- # line.set_transform(Affine2D().translate((ymax-ymin)*np.sin(skew_angle*np.pi/180),0) +
264
- # line.get_transform())
265
- # This second method skews the tick marks in place and
266
- # can sometimes lead to them being misaligned.
267
- line.set_transform(line.get_transform()) # This does nothing
268
- else:
269
- # Bottom ticks
270
- m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))
271
-
272
- line.set_marker(m)
273
-
274
- for i in range(0, len(ax.xaxis.get_minorticklines())):
275
- m = MarkerStyle(2)
276
- line = ax.xaxis.get_minorticklines()[i]
277
- if i % 2:
278
- m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
279
- else:
280
- m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))
281
-
282
- line.set_marker(m)
283
-
284
- if cbar:
285
- colorbar = fig.colorbar(p)
286
- if cbartitle is None:
287
- colorbar.set_label(data.signal)
288
-
289
- ax.set(
290
- xlabel=X.nxname,
291
- ylabel=Y.nxname,
292
- )
293
-
294
- if title is not None:
295
- ax.set_title(title)
296
-
297
- # Return the quadmesh object
298
- return p
299
-
300
-
301
- class Scissors:
302
- """
303
- Scissors class provides functionality for reducing data to a 1D linecut using an integration
304
- window.
305
-
306
- Attributes
307
- ----------
308
- data : :class:`nexusformat.nexus.NXdata` or None
309
- Input :class:`nexusformat.nexus.NXdata`.
310
- center : tuple or None
311
- Central coordinate around which to perform the linecut.
312
- window : tuple or None
313
- Extents of the window for integration along each axis.
314
- axis : int or None
315
- Axis along which to perform the integration.
316
- data_cut : ndarray or None
317
- Data array after applying the integration window.
318
- integrated_axes : tuple or None
319
- Indices of axes that were integrated.
320
- linecut : :class:`nexusformat.nexus.NXdata` or None
321
- 1D linecut data after integration.
322
- window_plane_slice_obj : list or None
323
- Slice object representing the integration window in the data array.
324
-
325
- Methods
326
- -------
327
- set_data(data)
328
- Set the input :class:`nexusformat.nexus.NXdata`
329
- get_data()
330
- Get the input :class:`nexusformat.nexus.NXdata`.
331
- set_center(center)
332
- Set the central coordinate for the linecut.
333
- set_window(window)
334
- Set the extents of the integration window.
335
- get_window()
336
- Get the extents of the integration window.
337
- cut_data(axis=None)
338
- Reduce data to a 1D linecut using the integration window.
339
- show_integration_window(label=None)
340
- Plot the integration window highlighted on a 2D heatmap of the full dataset.
341
- plot_window()
342
- Plot a 2D heatmap of the integration window data.
343
- """
344
-
345
- def __init__(self, data=None, center=None, window=None, axis=None):
346
- """
347
- Initializes a Scissors object.
348
-
349
- Parameters
350
- ----------
351
- data : :class:`nexusformat.nexus.NXdata` or None, optional
352
- Input NXdata. Default is None.
353
- center : tuple or None, optional
354
- Central coordinate around which to perform the linecut. Default is None.
355
- window : tuple or None, optional
356
- Extents of the window for integration along each axis. Default is None.
357
- axis : int or None, optional
358
- Axis along which to perform the integration. Default is None.
359
- """
360
-
361
- self.data = data
362
- self.center = tuple([float(i) for i in center]) if center is not None else None
363
- self.window = tuple([float(i) for i in window]) if window is not None else None
364
- self.axis = axis
365
-
366
- self.integration_volume = None
367
- self.integrated_axes = None
368
- self.linecut = None
369
- self.integration_window = None
370
-
371
- def set_data(self, data):
372
- """
373
- Set the input NXdata.
374
-
375
- Parameters
376
- ----------
377
- data : :class:`nexusformat.nexus.NXdata`
378
- Input data array.
379
- """
380
- self.data = data
381
-
382
- def get_data(self):
383
- """
384
- Get the input data array.
385
-
386
- Returns
387
- -------
388
- ndarray or None
389
- Input data array.
390
- """
391
- return self.data
392
-
393
- def set_center(self, center):
394
- """
395
- Set the central coordinate for the linecut.
396
-
397
- Parameters
398
- ----------
399
- center : tuple
400
- Central coordinate around which to perform the linecut.
401
- """
402
- self.center = tuple([float(i) for i in center]) if center is not None else None
403
-
404
- def set_window(self, window):
405
- """
406
- Set the extents of the integration window.
407
-
408
- Parameters
409
- ----------
410
- window : tuple
411
- Extents of the window for integration along each axis.
412
- """
413
- self.window = tuple([float(i) for i in window]) if window is not None else None
414
-
415
- # Determine the axis for integration
416
- self.axis = window.index(max(window))
417
- print("Linecut axis: " + str(self.data.axes[self.axis]))
418
-
419
- # Determine the integrated axes (axes other than the integration axis)
420
- self.integrated_axes = tuple(i for i in range(self.data.ndim) if i != self.axis)
421
- print("Integrated axes: " + str([self.data.axes[axis] for axis in self.integrated_axes]))
422
-
423
- def get_window(self):
424
- """
425
- Get the extents of the integration window.
426
-
427
- Returns
428
- -------
429
- tuple or None
430
- Extents of the integration window.
431
- """
432
- return self.window
433
-
434
- def cut_data(self, center=None, window=None, axis=None):
435
- """
436
- Reduces data to a 1D linecut with integration extents specified by the window about a central
437
- coordinate.
438
-
439
- Parameters
440
- -----------
441
- center : float or None, optional
442
- Central coordinate for the linecut. If not specified, the value from the object's
443
- attribute will be used.
444
- window : tuple or None, optional
445
- Integration window extents around the central coordinate. If not specified, the value
446
- from the object's attribute will be used.
447
- axis : int or None, optional
448
- The axis along which to perform the linecut. If not specified, the value from the
449
- object's attribute will be used.
450
-
451
- Returns
452
- --------
453
- integrated_data : :class:`nexusformat.nexus.NXdata`
454
- 1D linecut data after integration.
455
- """
456
-
457
- # Extract necessary attributes from the object
458
- data = self.data
459
- center = center if center is not None else self.center
460
- self.set_center(center)
461
- window = window if window is not None else self.window
462
- self.set_window(window)
463
- axis = axis if axis is not None else self.axis
464
-
465
- # Convert the center to a tuple of floats
466
- center = tuple(float(c) for c in center)
467
-
468
- # Calculate the start and stop indices for slicing the data
469
- start = np.subtract(center, window)
470
- stop = np.add(center, window)
471
- slice_obj = tuple(slice(s, e) for s, e in zip(start, stop))
472
- self.integration_window = slice_obj
473
-
474
- # Perform the data cut
475
- self.integration_volume = data[slice_obj]
476
- self.integration_volume.nxname = data.nxname
477
-
478
- # Perform integration along the integrated axes
479
- integrated_data = np.sum(self.integration_volume[self.integration_volume.signal].nxdata,
480
- axis=self.integrated_axes)
481
-
482
- # Create an NXdata object for the linecut data
483
- self.linecut = NXdata(NXfield(integrated_data, name=self.integration_volume.signal),
484
- self.integration_volume[self.integration_volume.axes[axis]])
485
- self.linecut.nxname = self.integration_volume.nxname
486
-
487
- return self.linecut
488
-
489
- def highlight_integration_window(self, data=None, label=None, highlight_color='red', **kwargs):
490
- """
491
- Plots integration window highlighted on the three principal cross sections of the first
492
- temperature dataset.
493
-
494
- Parameters
495
- ----------
496
- data : array-like, optional
497
- The 2D heatmap dataset to plot. If not provided, the dataset stored in `self.data` will
498
- be used.
499
- label : str, optional
500
- The label for the integration window plot.
501
- highlight_color : str, optional
502
- The edge color used to highlight the integration window. Default is 'red'.
503
- **kwargs : keyword arguments, optional
504
- Additional keyword arguments to customize the plot.
505
-
506
- """
507
- data = self.data if data is None else data
508
- center = self.center
509
- window = self.window
510
- integrated_axes = self.integrated_axes
511
-
512
- # Create a figure and subplots
513
- fig, axes = plt.subplots(1, 3, figsize=(15, 4))
514
-
515
- # Plot cross section 1
516
- slice_obj = [slice(None)] * data.ndim
517
- slice_obj[2] = center[2]
518
-
519
- p1 = plot_slice(data[slice_obj],
520
- X=data[data.axes[0]],
521
- Y=data[data.axes[1]],
522
- ax=axes[0],
523
- **kwargs)
524
- ax = axes[0]
525
- rect_diffuse = patches.Rectangle(
526
- (center[0] - window[0],
527
- center[1] - window[1]),
528
- 2 * window[0], 2 * window[1],
529
- linewidth=1, edgecolor=highlight_color, facecolor='none', transform=p1.get_transform(), label=label,
530
- )
531
- ax.add_patch(rect_diffuse)
532
-
533
- # Plot cross section 2
534
- slice_obj = [slice(None)] * data.ndim
535
- slice_obj[1] = center[1]
536
-
537
- p2 = plot_slice(data[slice_obj],
538
- X=data[data.axes[0]],
539
- Y=data[data.axes[2]],
540
- ax=axes[1],
541
- **kwargs)
542
- ax = axes[1]
543
- rect_diffuse = patches.Rectangle(
544
- (center[0] - window[0],
545
- center[2] - window[2]),
546
- 2 * window[0], 2 * window[2],
547
- linewidth=1, edgecolor=highlight_color, facecolor='none', transform=p2.get_transform(), label=label,
548
- )
549
- ax.add_patch(rect_diffuse)
550
-
551
- # Plot cross section 3
552
- slice_obj = [slice(None)] * data.ndim
553
- slice_obj[0] = center[0]
554
-
555
- p3 = plot_slice(data[slice_obj],
556
- X=data[data.axes[1]],
557
- Y=data[data.axes[2]],
558
- ax=axes[2],
559
- **kwargs)
560
- ax = axes[2]
561
- rect_diffuse = patches.Rectangle(
562
- (center[1] - window[1],
563
- center[2] - window[2]),
564
- 2 * window[1], 2 * window[2],
565
- linewidth=1, edgecolor=highlight_color, facecolor='none', transform=p3.get_transform(), label=label,
566
- )
567
- ax.add_patch(rect_diffuse)
568
-
569
- # Adjust subplot padding
570
- fig.subplots_adjust(wspace=0.5)
571
-
572
- if label is not None:
573
- [ax.legend() for ax in axes]
574
-
575
- plt.show()
576
-
577
- return p1, p2, p3
578
-
579
- def plot_integration_window(self, **kwargs):
580
- """
581
- Plots the three principal cross-sections of the integration volume on a single figure.
582
-
583
- Parameters
584
- ----------
585
- **kwargs : keyword arguments, optional
586
- Additional keyword arguments to customize the plot.
587
- """
588
- data = self.integration_volume
589
- axis = self.axis
590
- center = self.center
591
- window = self.window
592
- integrated_axes = self.integrated_axes
593
-
594
- fig, axes = plt.subplots(1, 3, figsize=(15, 4))
595
-
596
- # Plot cross section 1
597
- slice_obj = [slice(None)] * data.ndim
598
- slice_obj[2] = center[2]
599
- p1 = plot_slice(data[slice_obj],
600
- X=data[data.axes[0]],
601
- Y=data[data.axes[1]],
602
- ax=axes[0],
603
- **kwargs)
604
- axes[0].set_aspect(len(data[data.axes[0]].nxdata) / len(data[data.axes[1]].nxdata))
605
-
606
- # Plot cross section 2
607
- slice_obj = [slice(None)] * data.ndim
608
- slice_obj[1] = center[1]
609
- p3 = plot_slice(data[slice_obj],
610
- X=data[data.axes[0]],
611
- Y=data[data.axes[2]],
612
- ax=axes[1],
613
- **kwargs)
614
- axes[1].set_aspect(len(data[data.axes[0]].nxdata) / len(data[data.axes[2]].nxdata))
615
-
616
- # Plot cross section 3
617
- slice_obj = [slice(None)] * data.ndim
618
- slice_obj[0] = center[0]
619
- p2 = plot_slice(data[slice_obj],
620
- X=data[data.axes[1]],
621
- Y=data[data.axes[2]],
622
- ax=axes[2],
623
- **kwargs)
624
- axes[2].set_aspect(len(data[data.axes[1]].nxdata) / len(data[data.axes[2]].nxdata))
625
-
626
- # Adjust subplot padding
627
- fig.subplots_adjust(wspace=0.3)
628
-
629
- plt.show()
630
-
631
- return p1, p2, p3
632
-
633
-
634
- def reciprocal_lattice_params(lattice_params):
635
- a_mag, b_mag, c_mag, alpha, beta, gamma = lattice_params
636
- # Convert angles to radians
637
- alpha = np.deg2rad(alpha)
638
- beta = np.deg2rad(beta)
639
- gamma = np.deg2rad(gamma)
640
-
641
- # Calculate unit cell volume
642
- V = a_mag * b_mag * c_mag * np.sqrt(
643
- 1 - np.cos(alpha) ** 2 - np.cos(beta) ** 2 - np.cos(gamma) ** 2 + 2 * np.cos(alpha) * np.cos(beta) * np.cos(
644
- gamma)
645
- )
646
-
647
- # Calculate reciprocal lattice parameters
648
- a_star = (b_mag * c_mag * np.sin(alpha)) / V
649
- b_star = (a_mag * c_mag * np.sin(beta)) / V
650
- c_star = (a_mag * b_mag * np.sin(gamma)) / V
651
- alpha_star = np.rad2deg(np.arccos((np.cos(beta) * np.cos(gamma) - np.cos(alpha)) / (np.sin(beta) * np.sin(gamma))))
652
- beta_star = np.rad2deg(np.arccos((np.cos(alpha) * np.cos(gamma) - np.cos(beta)) / (np.sin(alpha) * np.sin(gamma))))
653
- gamma_star = np.rad2deg(np.arccos((np.cos(alpha) * np.cos(beta) - np.cos(gamma)) / (np.sin(alpha) * np.sin(beta))))
654
-
655
- return a_star, b_star, c_star, alpha_star, beta_star, gamma_star
656
-
657
-
658
- def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=False):
659
- """
660
- Rotates 3D data around a specified axis.
661
-
662
- Parameters
663
- ----------
664
- data : :class:`nexusformat.nexus.NXdata`
665
- Input data.
666
- lattice_angle : float
667
- Angle between the two in-plane lattice axes in degrees.
668
- rotation_angle : float
669
- Angle of rotation in degrees.
670
- rotation_axis : int
671
- Axis of rotation (0, 1, or 2).
672
- printout : bool, optional
673
- Enables printout of rotation progress. If set to True, information about each rotation slice will be printed
674
- to the console, indicating the axis being rotated and the corresponding
675
- coordinate value. Defaults to False.
676
-
677
-
678
- Returns
679
- -------
680
- rotated_data : :class:`nexusformat.nexus.NXdata`
681
- Rotated data as an NXdata object.
682
- """
683
- # Define output array
684
- output_array = np.zeros(data[data.signal].shape)
685
-
686
- # Define transformation
687
- skew_angle_adj = 90 - lattice_angle
688
- t = Affine2D()
689
- # Scale y-axis to preserve norm while shearing
690
- t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
691
- # Shear along x-axis
692
- t += Affine2D().skew_deg(skew_angle_adj, 0)
693
- # Return to original y-axis scaling
694
- t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()
695
-
696
- for i in range(len(data[data.axes[rotation_axis]])):
697
- if printout:
698
- print(f'\rRotating {data.axes[rotation_axis]}={data[data.axes[rotation_axis]][i]}... ',
699
- end='', flush=True)
700
- # Identify current slice
701
- if rotation_axis == 0:
702
- sliced_data = data[i, :, :]
703
- elif rotation_axis == 1:
704
- sliced_data = data[:, i, :]
705
- elif rotation_axis == 2:
706
- sliced_data = data[:, :, i]
707
- else:
708
- sliced_data = None
709
-
710
- p = Padder(sliced_data)
711
- padding = tuple([len(sliced_data[axis]) for axis in sliced_data.axes])
712
- counts = p.pad(padding).counts
713
-
714
- counts_skewed = ndimage.affine_transform(counts,
715
- t.inverted().get_matrix()[:2, :2],
716
- offset=[counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180), 0],
717
- order=0,
718
- )
719
- scale1 = np.cos(skew_angle_adj * np.pi / 180)
720
- counts_scaled1 = ndimage.affine_transform(counts_skewed,
721
- Affine2D().scale(scale1, 1).get_matrix()[:2, :2],
722
- offset=[(1 - scale1) * counts.shape[0] / 2, 0],
723
- order=0,
724
- )
725
- scale2 = counts.shape[0] / counts.shape[1]
726
- counts_scaled2 = ndimage.affine_transform(counts_scaled1,
727
- Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
728
- offset=[(1 - scale2) * counts.shape[0] / 2, 0],
729
- order=0,
730
- )
731
-
732
- counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)
733
-
734
- counts_unscaled2 = ndimage.affine_transform(counts_rotated,
735
- Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
736
- offset=[-(1 - scale2) * counts.shape[
737
- 0] / 2 / scale2, 0],
738
- order=0,
739
- )
740
-
741
- counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
742
- Affine2D().scale(scale1,
743
- 1).inverted().get_matrix()[:2, :2],
744
- offset=[-(1 - scale1) * counts.shape[
745
- 0] / 2 / scale1, 0],
746
- order=0,
747
- )
748
-
749
- counts_unskewed = ndimage.affine_transform(counts_unscaled1,
750
- t.get_matrix()[:2, :2],
751
- offset=[
752
- (-counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180)),
753
- 0],
754
- order=0,
755
- )
756
-
757
- counts_unpadded = p.unpad(counts_unskewed)
758
-
759
- # Write current slice
760
- if rotation_axis == 0:
761
- output_array[i, :, :] = counts_unpadded
762
- elif rotation_axis == 1:
763
- output_array[:, i, :] = counts_unpadded
764
- elif rotation_axis == 2:
765
- output_array[:, :, i] = counts_unpadded
766
- print('\nDone.')
767
- return NXdata(NXfield(output_array, name='counts'),
768
- (data[data.axes[0]], data[data.axes[1]], data[data.axes[2]]))
769
-
770
-
771
- class Padder():
772
- """
773
- A class to pad and unpad datasets with a symmetric region of zeros.
774
- """
775
-
776
- def __init__(self, data=None):
777
- """
778
- Initialize the Symmetrizer3D object.
779
-
780
- Parameters
781
- ----------
782
- data : NXdata, optional
783
- The input data to be symmetrized. If provided, the `set_data` method is called to set the data.
784
-
785
- """
786
- self.padded = None
787
- self.padding = None
788
- if data is not None:
789
- self.set_data(data)
790
-
791
- def set_data(self, data):
792
- """
793
- Set the input data for symmetrization.
794
-
795
- Parameters
796
- ----------
797
- data : NXdata
798
- The input data to be symmetrized.
799
-
800
- """
801
- self.data = data
802
-
803
- self.steps = tuple([(data[axis].nxdata[1] - data[axis].nxdata[0]) for axis in data.axes])
804
-
805
- # Absolute value of the maximum value; assumes the domain of the input is symmetric (eg, -H_min = H_max)
806
- self.maxes = tuple([data[axis].nxdata.max() for axis in data.axes])
807
-
808
- def pad(self, padding):
809
- """
810
- Symmetrically pads the data with zero values.
811
-
812
- Parameters
813
- ----------
814
- padding : tuple
815
- The number of zero-value pixels to add along each edge of the array.
816
- """
817
- data = self.data
818
- self.padding = padding
819
-
820
- padded_shape = tuple([data[data.signal].nxdata.shape[i] + self.padding[i] * 2 for i in range(data.ndim)])
821
-
822
- # Create padded dataset
823
- padded = np.zeros(padded_shape)
824
-
825
- slice_obj = [slice(None)] * data.ndim
826
- for i, _ in enumerate(slice_obj):
827
- slice_obj[i] = slice(self.padding[i], -self.padding[i], None)
828
- slice_obj = tuple(slice_obj)
829
- padded[slice_obj] = data[data.signal].nxdata
830
-
831
- padmaxes = tuple([self.maxes[i] + self.padding[i] * self.steps[i] for i in range(data.ndim)])
832
-
833
- padded = NXdata(NXfield(padded, name=data.signal),
834
- tuple([NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
835
- name=data.axes[i])
836
- for i in range(data.ndim)]))
837
-
838
- self.padded = padded
839
- return padded
840
-
841
- def save(self, fout_name=None):
842
- """
843
- Saves the padded dataset to a .nxs file.
844
-
845
- Parameters
846
- ----------
847
- fout_name : str, optional
848
- The output file name. Default is padded_(Hpadding)_(Kpadding)_(Lpadding).nxs
849
- """
850
- padH, padK, padL = self.padding
851
-
852
- # Save padded dataset
853
- print("Saving padded dataset...")
854
- f = NXroot()
855
- f['entry'] = NXentry()
856
- f['entry']['data'] = self.padded
857
- if fout_name is None:
858
- fout_name = 'padded_' + str(padH) + '_' + str(padK) + '_' + str(padL) + '.nxs'
859
- nxsave(fout_name, f)
860
- print("Output file saved to: " + os.path.join(os.getcwd(), fout_name))
861
-
862
- def unpad(self, data):
863
- """
864
- Removes the padded region from the data.
865
-
866
- Parameters
867
- ----------
868
- data : ndarray or NXdata
869
- The padded data from which to remove the padding.
870
-
871
- Returns
872
- -------
873
- ndarray or NXdata
874
- The unpadded data, with the symmetric padding region removed.
875
-
876
- Notes
877
- -----
878
- This method removes the symmetric padding region that was added using the `pad` method. It returns the data
879
- without the padded region.
880
-
881
-
882
- """
883
- slice_obj = [slice(None)] * data.ndim
884
- for i in range(data.ndim):
885
- slice_obj[i] = slice(self.padding[i], -self.padding[i], None)
886
- slice_obj = tuple(slice_obj)
887
- return data[slice_obj]
1
+ """
2
+ Reduces scattering data into 2D and 1D datasets.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.transforms import Affine2D
8
+ from matplotlib.markers import MarkerStyle
9
+ from matplotlib.ticker import MultipleLocator
10
+ from matplotlib import colors
11
+ from matplotlib import patches
12
+ from IPython.display import display, Markdown
13
+ from nexusformat.nexus import NXfield, NXdata, nxload, NeXusError, NXroot, NXentry, nxsave
14
+ from scipy import ndimage
15
+
16
+ # Specify items on which users are allowed to perform standalone imports
17
+ __all__ = ['load_data', 'load_transform', 'plot_slice', 'Scissors', 'reciprocal_lattice_params', 'rotate_data',
18
+ 'array_to_nxdata', 'Padder']
19
+
20
+
21
+ def load_data(path):
22
+ """
23
+ Load data from a specified path.
24
+
25
+ Parameters
26
+ ----------
27
+ path : str
28
+ The path to the data file.
29
+
30
+ Returns
31
+ -------
32
+ data : nxdata object
33
+ The loaded data stored in a nxdata object.
34
+
35
+ """
36
+
37
+ g = nxload(path)
38
+ try:
39
+ print(g.entry.data.tree)
40
+ except NeXusError:
41
+ pass
42
+
43
+ return g.entry.data
44
+
45
+
46
def load_transform(path):
    """
    Load nxrefine-transformed data from a specified path.

    The array stored at ``entry/transform/data`` is transposed with axis order
    ``(2, 1, 0)`` and paired with the ``Qh``, ``Qk`` and ``Ql`` fields of the
    same group.

    Parameters
    ----------
    path : str
        The path to the data file.

    Returns
    -------
    data : nxdata object
        A new NXdata object whose signal is named ``'counts'`` and whose axes
        are ``(Qh, Qk, Ql)``.
    """
    transform = nxload(path).entry.transform
    counts = NXfield(transform.data.nxdata.transpose(2, 1, 0), name='counts')
    return NXdata(counts, (transform.Qh, transform.Qk, transform.Ql))
61
+
62
+
63
def array_to_nxdata(array, data_template, signal_name='counts'):
    """
    Wrap an array in an NXdata object that reuses the axes of a template.

    Parameters
    ----------
    array : array-like
        Values to store as the signal of the new NXdata object.

    data_template : NXdata
        Existing NXdata object whose axes (looked up by the names listed in
        its ``axes`` attribute) are attached to the new object.

    signal_name : str, optional
        Name given to the signal field. Defaults to ``'counts'``.

    Returns
    -------
    NXdata
        An NXdata object holding `array` as its signal together with the axes
        taken from `data_template`.
    """
    signal = NXfield(array, name=signal_name)
    template_axes = tuple(data_template[axis_name] for axis_name in data_template.axes)
    return NXdata(signal, template_axes)
86
+
87
+
88
def plot_slice(data, X=None, Y=None, transpose=False, vmin=None, vmax=None, skew_angle=90,
               ax=None, xlim=None, ylim=None, xticks=None, yticks=None, cbar=True, logscale=False,
               symlogscale=False, cmap='viridis', linthresh=1, title=None, mdheading=None, cbartitle=None,
               **kwargs):
    """
    Plot a 2D dataset as a heatmap, optionally sheared to a non-orthogonal axis system.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata` object or ndarray
        The dataset to plot. An NXdata signal is indexed ``[x, y]`` and is
        transposed before being handed to ``pcolormesh``. A bare ndarray is
        passed to ``pcolormesh`` as given, i.e. rows run along Y and columns
        along X.

    X : NXfield, optional
        The X axis values. Default is first axis of `data` (or a generated
        axis named ``'x'`` for an ndarray).

    Y : NXfield, optional
        The Y axis values. Default is second axis of `data` (or a generated
        axis named ``'y'`` for an ndarray).

    transpose : bool, optional
        If True, transpose the dataset and its axes before plotting. Default is False.

    vmin : float, optional
        The minimum value to plot in the dataset.
        If not provided, the minimum of the dataset will be used.

    vmax : float, optional
        The maximum value to plot in the dataset.
        If not provided, the maximum of the dataset will be used.

    skew_angle : float, optional
        The angle to shear the plot in degrees. Defaults to 90 degrees (no skewing).

    ax : matplotlib.axes.Axes, optional
        An optional axis object to plot the heatmap onto. When omitted, a new
        figure with a single set of axes is created.

    xlim : tuple, optional
        The limits of the x-axis. If not provided, the limits will be automatically set.

    ylim : tuple, optional
        The limits of the y-axis. If not provided, the limits will be automatically set.

    xticks : float, optional
        The major tick interval for the x-axis.
        If not provided, the function will use a default minor tick interval of 1.

    yticks : float, optional
        The major tick interval for the y-axis.
        If not provided, the function will use a default minor tick interval of 1.

    cbar : bool, optional
        Whether to include a colorbar in the plot. Defaults to True.

    logscale : bool, optional
        Whether to use a logarithmic color scale. Defaults to False.

    symlogscale : bool, optional
        Whether to use a symmetrical logarithmic color scale spanning
        ``-vmax`` to ``vmax``. Takes precedence over `logscale`. Defaults to False.

    cmap : str or Colormap, optional
        The color map to use. Defaults to 'viridis'.

    linthresh : float, optional
        The linear threshold for the symmetrical logarithmic color scale. Defaults to 1.

    title : str, optional
        Title placed above the axes. Default `None` (no title).

    mdheading : str, optional
        A string containing the Markdown heading for the plot. Default `None`
        (no heading). The literal string ``"None"`` displays a bare
        ``### Figure`` heading.

    cbartitle : str, optional
        Label for the colorbar. Defaults to the name of the signal
        (``'value'`` for ndarray input).

    **kwargs
        Additional keyword arguments forwarded to ``Axes.pcolormesh``.

    Returns
    -------
    p : :class:`matplotlib.collections.QuadMesh`

        A :class:`matplotlib.collections.QuadMesh` object, to mimic behavior of
        :class:`matplotlib.pyplot.pcolormesh`.

    Raises
    ------
    TypeError
        If `data` is neither an ndarray nor an NXdata/NXfield object.
    """
    if isinstance(data, np.ndarray):
        if X is None:
            X = NXfield(np.linspace(0, data.shape[1], data.shape[1]), name='x')
        if Y is None:
            Y = NXfield(np.linspace(0, data.shape[0], data.shape[0]), name='y')
        if transpose:
            X, Y = Y, X
            data = data.transpose()
        # Plot the raw array itself: the min/max lookups and pcolormesh below
        # need array values, not an NXdata group wrapping them.
        data_arr = data
        signal_label = 'value'
    elif isinstance(data, (NXdata, NXfield)):
        if X is None:
            X = data[data.axes[0]]
        if Y is None:
            Y = data[data.axes[1]]
        if transpose:
            X, Y = Y, X
            data = data.transpose()
        # The signal is indexed [x, y]; pcolormesh expects [row=y, column=x]
        data_arr = data[data.signal].nxdata.transpose()
        signal_label = data.signal
    else:
        raise TypeError(f"Unexpected data type: {type(data)}. Supported types are np.ndarray and NXdata.")

    # Display Markdown heading
    if mdheading is None:
        pass
    elif mdheading == "None":
        display(Markdown('### Figure'))
    else:
        display(Markdown('### Figure - ' + mdheading))

    # Inherit axes if user provides some
    if ax is not None:
        fig = ax.get_figure()
    # Otherwise set up some default axes
    else:
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])

    # If limits not provided, use extrema
    if vmin is None:
        vmin = data_arr.min()
    if vmax is None:
        vmax = data_arr.max()

    # Set norm (linear scale, logscale, or symlogscale)
    norm = colors.Normalize(vmin=vmin, vmax=vmax)  # Default: linear scale

    if symlogscale:
        norm = colors.SymLogNorm(linthresh=linthresh, vmin=-1 * vmax, vmax=vmax)
    elif logscale:
        norm = colors.LogNorm(vmin=vmin, vmax=vmax)

    # Plot data
    p = ax.pcolormesh(X.nxdata, Y.nxdata, data_arr, shading='auto', norm=norm, cmap=cmap, **kwargs)

    ## Transform data to new coordinate system if necessary
    # Correct skew angle
    skew_angle_adj = 90 - skew_angle
    # Create blank 2D affine transformation
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()
    ## Correct for x-displacement after shearing
    # If ylims provided, use those
    if ylim is not None:
        # Set ylims
        ax.set(ylim=ylim)
        ymin, ymax = ylim
    # Else, use current ylims
    else:
        ymin, ymax = ax.get_ylim()
    # Use ylims to calculate translation (necessary to display axes in correct position)
    p.set_transform(t + Affine2D().translate(-ymin * np.sin(skew_angle_adj * np.pi / 180), 0) + ax.transData)

    # Set x limits
    if xlim is not None:
        xmin, xmax = xlim
    else:
        xmin, xmax = ax.get_xlim()
    # Widen the x range on one side so the sheared mesh stays inside the axes
    if skew_angle <= 90:
        ax.set(xlim=(xmin, xmax + (ymax - ymin) / np.tan((90 - skew_angle_adj) * np.pi / 180)))
    else:
        ax.set(xlim=(xmin - (ymax - ymin) / np.tan((skew_angle_adj - 90) * np.pi / 180), xmax))

    # Correct aspect ratio for the x/y axes after transformation
    ax.set(aspect=np.cos(skew_angle_adj * np.pi / 180))

    # Add tick marks all around
    ax.tick_params(direction='in', top=True, right=True, which='both')

    # Set tick locations
    if xticks is None:
        # Add default minor ticks
        ax.xaxis.set_minor_locator(MultipleLocator(1))
    else:
        # Otherwise use user provided values
        ax.xaxis.set_major_locator(MultipleLocator(xticks))
        ax.xaxis.set_minor_locator(MultipleLocator(1))
    if yticks is None:
        # Add default minor ticks
        ax.yaxis.set_minor_locator(MultipleLocator(1))
    else:
        # Otherwise use user provided values
        ax.yaxis.set_major_locator(MultipleLocator(yticks))
        ax.yaxis.set_minor_locator(MultipleLocator(1))

    # Apply the shear to the x tick marks. Tick lines alternate between the
    # two sides of the axes: even entries are treated as bottom ticks, odd
    # entries as top ticks.
    for i, line in enumerate(ax.xaxis.get_majorticklines()):
        m = MarkerStyle(3)
        if i % 2:
            # Top ticks (translation here makes their direction="in").
            # The marks are skewed in place, which can sometimes leave them
            # slightly misaligned with the sheared data.
            m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
        else:
            # Bottom ticks
            m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))

        line.set_marker(m)

    for i, line in enumerate(ax.xaxis.get_minorticklines()):
        m = MarkerStyle(2)
        if i % 2:
            m._transform.set(Affine2D().translate(0, -1) + Affine2D().skew_deg(skew_angle_adj, 0))
        else:
            m._transform.set(Affine2D().skew_deg(skew_angle_adj, 0))

        line.set_marker(m)

    if cbar:
        colorbar = fig.colorbar(p)
        # Honour a user-supplied colorbar label; fall back to the signal name
        colorbar.set_label(signal_label if cbartitle is None else cbartitle)

    ax.set(
        xlabel=X.nxname,
        ylabel=Y.nxname,
    )

    if title is not None:
        ax.set_title(title)

    # Return the quadmesh object
    return p
318
+
319
+
320
class Scissors:
    """
    Scissors class provides functionality for reducing data to a 1D linecut using an integration
    window.

    Attributes
    ----------
    data : :class:`nexusformat.nexus.NXdata` or None
        Input :class:`nexusformat.nexus.NXdata`.
    center : tuple or None
        Central coordinate around which to perform the linecut.
    window : tuple or None
        Extents of the window for integration along each axis.
    axis : int or None
        Index of the axis along which the linecut runs.
    integration_volume : :class:`nexusformat.nexus.NXdata` or None
        Data inside the integration window, set by `cut_data`.
    integrated_axes : tuple or None
        Indices of axes that were integrated.
    linecut : :class:`nexusformat.nexus.NXdata` or None
        1D linecut data after integration.
    integration_window : tuple or None
        Tuple of slices (in axis coordinates) describing the integration window.

    Methods
    -------
    set_data(data)
        Set the input :class:`nexusformat.nexus.NXdata`
    get_data()
        Get the input :class:`nexusformat.nexus.NXdata`.
    set_center(center)
        Set the central coordinate for the linecut.
    set_window(window)
        Set the extents of the integration window.
    get_window()
        Get the extents of the integration window.
    cut_data(center=None, window=None, axis=None)
        Reduce data to a 1D linecut using the integration window.
    highlight_integration_window(data=None, label=None, highlight_color='red')
        Plot the integration window highlighted on three cross sections of the dataset.
    plot_integration_window()
        Plot three cross sections of the integration volume.
    """

    def __init__(self, data=None, center=None, window=None, axis=None):
        """
        Initializes a Scissors object.

        Parameters
        ----------
        data : :class:`nexusformat.nexus.NXdata` or None, optional
            Input NXdata. Default is None.
        center : tuple or None, optional
            Central coordinate around which to perform the linecut. Default is None.
        window : tuple or None, optional
            Extents of the window for integration along each axis. Default is None.
        axis : int or None, optional
            Index of the linecut axis. Default is None. Note that `set_window`
            and `cut_data` replace this value.
        """

        self.data = data
        self.center = tuple(float(i) for i in center) if center is not None else None
        self.window = tuple(float(i) for i in window) if window is not None else None
        self.axis = axis

        self.integration_volume = None
        self.integrated_axes = None
        self.linecut = None
        self.integration_window = None

    def set_data(self, data):
        """
        Set the input NXdata.

        Parameters
        ----------
        data : :class:`nexusformat.nexus.NXdata`
            Input data array.
        """
        self.data = data

    def get_data(self):
        """
        Get the input data.

        Returns
        -------
        :class:`nexusformat.nexus.NXdata` or None
            Input data.
        """
        return self.data

    def set_center(self, center):
        """
        Set the central coordinate for the linecut.

        Parameters
        ----------
        center : tuple or None
            Central coordinate around which to perform the linecut. Values are
            stored as floats.
        """
        self.center = tuple(float(i) for i in center) if center is not None else None

    def _apply_window(self, window, axis=None):
        """Store `window` as floats and derive the linecut and integrated axes from it."""
        self.window = tuple(float(i) for i in window)

        # An explicit axis wins; otherwise the linecut runs along the axis
        # with the widest window extent.
        self.axis = axis if axis is not None else self.window.index(max(self.window))

        # Every other axis is summed over
        ndim = self.data.ndim if self.data is not None else len(self.window)
        self.integrated_axes = tuple(i for i in range(ndim) if i != self.axis)

        if self.data is not None:
            names = self.data.axes
            print("Linecut axis: " + str(names[self.axis]))
            print("Integrated axes: " + str([names[i] for i in self.integrated_axes]))
        else:
            # No dataset attached yet, so only the axis indices are known
            print("Linecut axis: " + str(self.axis))
            print("Integrated axes: " + str(list(self.integrated_axes)))

    def set_window(self, window):
        """
        Set the extents of the integration window.

        The linecut axis is chosen as the axis with the largest window extent,
        and all remaining axes become the integrated axes.

        Parameters
        ----------
        window : sequence of float or None
            Extents of the window for integration along each axis. Any
            iterable of numbers (tuple, list, ndarray) is accepted. `None`
            clears the stored window.
        """
        if window is None:
            self.window = None
            return
        self._apply_window(window)

    def get_window(self):
        """
        Get the extents of the integration window.

        Returns
        -------
        tuple or None
            Extents of the integration window.
        """
        return self.window

    def cut_data(self, center=None, window=None, axis=None):
        """
        Reduces data to a 1D linecut with integration extents specified by the window about a central
        coordinate.

        Parameters
        -----------
        center : tuple or None, optional
            Central coordinate for the linecut. If not specified, the value from the object's
            attribute will be used.
        window : tuple or None, optional
            Integration window extents around the central coordinate. If not specified, the value
            from the object's attribute will be used.
        axis : int or None, optional
            The axis along which to perform the linecut. If not specified, the axis with the
            largest window extent is used.

        Returns
        --------
        integrated_data : :class:`nexusformat.nexus.NXdata`
            1D linecut data after integration.

        Raises
        ------
        ValueError
            If no data, center or window is available.
        """
        data = self.data
        center = center if center is not None else self.center
        window = window if window is not None else self.window

        if data is None:
            raise ValueError("No data to cut: provide data via the constructor or set_data().")
        if center is None or window is None:
            raise ValueError("Both a center and a window are required to perform a linecut.")

        self.set_center(center)
        self._apply_window(window, axis)

        # Window bounds in axis coordinates: center -/+ window along each axis
        start = np.subtract(self.center, self.window)
        stop = np.add(self.center, self.window)
        slice_obj = tuple(slice(s, e) for s, e in zip(start, stop))
        self.integration_window = slice_obj

        # Perform the data cut
        volume = data[slice_obj]
        volume.nxname = data.nxname
        self.integration_volume = volume

        # Sum over every axis except the linecut axis
        integrated_data = np.sum(volume[volume.signal].nxdata, axis=self.integrated_axes)

        # Create an NXdata object for the linecut data
        self.linecut = NXdata(NXfield(integrated_data, name=volume.signal),
                              volume[volume.axes[self.axis]])
        self.linecut.nxname = volume.nxname

        return self.linecut

    def highlight_integration_window(self, data=None, label=None, highlight_color='red', **kwargs):
        """
        Plots the integration window highlighted on three cross sections of the dataset,
        each taken through the window center.

        Parameters
        ----------
        data : :class:`nexusformat.nexus.NXdata`, optional
            The dataset to plot. If not provided, the dataset stored in `self.data` will
            be used. It is indexed along three axes.
        label : str, optional
            The label for the integration window plot. When given, a legend is
            added to every subplot.
        highlight_color : str, optional
            The edge color used to highlight the integration window. Default is 'red'.
        **kwargs : keyword arguments, optional
            Additional keyword arguments passed on to `plot_slice`.

        Returns
        -------
        p1, p2, p3
            The objects returned by `plot_slice` for the (axis 0, axis 1),
            (axis 0, axis 2) and (axis 1, axis 2) cross sections, in that order.
        """
        data = self.data if data is None else data
        center = self.center
        window = self.window

        # Create a figure and subplots
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        # (x axis, y axis, axis held fixed at the window center)
        planes = ((0, 1, 2), (0, 2, 1), (1, 2, 0))
        meshes = []
        for ax, (ix, iy, ifixed) in zip(axes, planes):
            slice_obj = [slice(None)] * data.ndim
            slice_obj[ifixed] = center[ifixed]

            mesh = plot_slice(data[slice_obj],
                              X=data[data.axes[ix]],
                              Y=data[data.axes[iy]],
                              ax=ax,
                              **kwargs)
            # Reuse the mesh transform so the rectangle follows any skew
            rect_diffuse = patches.Rectangle(
                (center[ix] - window[ix],
                 center[iy] - window[iy]),
                2 * window[ix], 2 * window[iy],
                linewidth=1, edgecolor=highlight_color, facecolor='none',
                transform=mesh.get_transform(), label=label,
            )
            ax.add_patch(rect_diffuse)
            meshes.append(mesh)

        # Adjust subplot padding
        fig.subplots_adjust(wspace=0.5)

        if label is not None:
            for ax in axes:
                ax.legend()

        plt.show()

        p1, p2, p3 = meshes
        return p1, p2, p3

    def plot_integration_window(self, **kwargs):
        """
        Plots three cross sections of the integration volume, each taken through the
        window center, on a single figure.

        Parameters
        ----------
        **kwargs : keyword arguments, optional
            Additional keyword arguments passed on to `plot_slice`.

        Returns
        -------
        p1, p2, p3
            The objects returned by `plot_slice` for the (axis 0, axis 1),
            (axis 0, axis 2) and (axis 1, axis 2) cross sections, in that
            order (the same order as `highlight_integration_window`).

        Raises
        ------
        ValueError
            If `cut_data` has not been called yet.
        """
        data = self.integration_volume
        center = self.center

        if data is None:
            raise ValueError("No integration volume available: call cut_data() first.")

        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        # (x axis, y axis, axis held fixed at the window center)
        planes = ((0, 1, 2), (0, 2, 1), (1, 2, 0))
        meshes = []
        for ax, (ix, iy, ifixed) in zip(axes, planes):
            slice_obj = [slice(None)] * data.ndim
            slice_obj[ifixed] = center[ifixed]

            mesh = plot_slice(data[slice_obj],
                              X=data[data.axes[ix]],
                              Y=data[data.axes[iy]],
                              ax=ax,
                              **kwargs)
            # Aspect is set from the ratio of the number of points on each axis
            ax.set_aspect(len(data[data.axes[ix]].nxdata) / len(data[data.axes[iy]].nxdata))
            meshes.append(mesh)

        # Adjust subplot padding
        fig.subplots_adjust(wspace=0.3)

        plt.show()

        p1, p2, p3 = meshes
        return p1, p2, p3
651
+
652
+
653
def reciprocal_lattice_params(lattice_params):
    """
    Compute reciprocal lattice parameters from direct lattice parameters.

    Parameters
    ----------
    lattice_params : sequence of six floats
        ``(a, b, c, alpha, beta, gamma)`` with the three lengths followed by
        the three angles in degrees.

    Returns
    -------
    tuple
        ``(a_star, b_star, c_star, alpha_star, beta_star, gamma_star)``. The
        reciprocal lengths are the inverse of the input length unit (no factor
        of 2*pi is applied) and the reciprocal angles are in degrees.
    """
    a_mag, b_mag, c_mag, alpha, beta, gamma = lattice_params

    # Work in radians and evaluate each trigonometric term once
    alpha, beta, gamma = np.deg2rad(alpha), np.deg2rad(beta), np.deg2rad(gamma)
    cos_a, cos_b, cos_g = np.cos(alpha), np.cos(beta), np.cos(gamma)
    sin_a, sin_b, sin_g = np.sin(alpha), np.sin(beta), np.sin(gamma)

    # Unit cell volume
    radicand = 1 - cos_a ** 2 - cos_b ** 2 - cos_g ** 2 + 2 * cos_a * cos_b * cos_g
    V = a_mag * b_mag * c_mag * np.sqrt(radicand)

    # Reciprocal lengths
    a_star = (b_mag * c_mag * sin_a) / V
    b_star = (a_mag * c_mag * sin_b) / V
    c_star = (a_mag * b_mag * sin_g) / V

    # Reciprocal angles, converted back to degrees
    alpha_star = np.rad2deg(np.arccos((cos_b * cos_g - cos_a) / (sin_b * sin_g)))
    beta_star = np.rad2deg(np.arccos((cos_a * cos_g - cos_b) / (sin_a * sin_g)))
    gamma_star = np.rad2deg(np.arccos((cos_a * cos_b - cos_g) / (sin_a * sin_b)))

    return a_star, b_star, c_star, alpha_star, beta_star, gamma_star
675
+
676
+
677
def rotate_data(data, lattice_angle, rotation_angle, rotation_axis, printout=False):
    """
    Rotates 3D data around a specified axis.

    Each 2D slice perpendicular to `rotation_axis` is zero-padded, mapped
    through a shear/scale chain built from `lattice_angle`, rotated, mapped
    back through the inverse chain, and unpadded.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.
    rotation_axis : int
        Axis of rotation (0, 1, or 2).
    printout : bool, optional
        Enables printout of rotation progress. If set to True, information about each rotation slice will be printed
        to the console, indicating the axis being rotated and the corresponding
        coordinate value. Defaults to False.

    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object with a signal named ``'counts'`` and
        the axes of the input.

    Raises
    ------
    ValueError
        If `rotation_axis` is not 0, 1 or 2.
    """
    # Reject a bad axis up front instead of failing later on a None slice
    if rotation_axis not in (0, 1, 2):
        raise ValueError(f"rotation_axis must be 0, 1, or 2, got {rotation_axis!r}")

    # Define output array
    output_array = np.zeros(data[data.signal].shape)

    # Define transformation
    skew_angle_adj = 90 - lattice_angle
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()

    # Loop-invariant pieces of the transform chain
    axis_name = data.axes[rotation_axis]
    scale1 = np.cos(skew_angle_adj * np.pi / 180)
    skew_matrix = t.inverted().get_matrix()[:2, :2]
    unskew_matrix = t.get_matrix()[:2, :2]

    for i in range(len(data[axis_name])):
        if printout:
            print(f'\rRotating {data.axes[rotation_axis]}={data[data.axes[rotation_axis]][i]}...                      ',
                  end='', flush=True)

        # Index selecting the current slice: position i along the rotation
        # axis, everything along the other two.
        index = [slice(None)] * 3
        index[rotation_axis] = i
        index = tuple(index)
        sliced_data = data[index]

        # Pad each side by the full axis length
        p = Padder(sliced_data)
        padding = tuple([len(sliced_data[axis]) for axis in sliced_data.axes])
        padded = p.pad(padding)
        # Read the signal by its own name rather than assuming 'counts'
        counts = padded[padded.signal]

        # order=0 throughout: nearest-neighbour sampling, no interpolation
        counts_skewed = ndimage.affine_transform(counts,
                                                 skew_matrix,
                                                 offset=[counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180), 0],
                                                 order=0,
                                                 )
        counts_scaled1 = ndimage.affine_transform(counts_skewed,
                                                  Affine2D().scale(scale1, 1).get_matrix()[:2, :2],
                                                  offset=[(1 - scale1) * counts.shape[0] / 2, 0],
                                                  order=0,
                                                  )
        # Compensate for a non-square padded array before rotating
        scale2 = counts.shape[0] / counts.shape[1]
        counts_scaled2 = ndimage.affine_transform(counts_scaled1,
                                                  Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
                                                  offset=[(1 - scale2) * counts.shape[0] / 2, 0],
                                                  order=0,
                                                  )

        counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)

        # Undo the scalings and the skew in reverse order
        counts_unscaled2 = ndimage.affine_transform(counts_rotated,
                                                    Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
                                                    offset=[-(1 - scale2) * counts.shape[0] / 2 / scale2, 0],
                                                    order=0,
                                                    )

        counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
                                                    Affine2D().scale(scale1, 1).inverted().get_matrix()[:2, :2],
                                                    offset=[-(1 - scale1) * counts.shape[0] / 2 / scale1, 0],
                                                    order=0,
                                                    )

        counts_unskewed = ndimage.affine_transform(counts_unscaled1,
                                                   unskew_matrix,
                                                   offset=[
                                                       (-counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180)),
                                                       0],
                                                   order=0,
                                                   )

        counts_unpadded = p.unpad(counts_unskewed)

        # Write current slice
        output_array[index] = counts_unpadded

    print('\nDone.')
    return NXdata(NXfield(output_array, name='counts'),
                  (data[data.axes[0]], data[data.axes[1]], data[data.axes[2]]))
788
+
789
def rotate_data2D(data, lattice_angle, rotation_angle):
    """
    Rotates a single 2D dataset in its own plane.

    The data is zero-padded, mapped through a shear/scale chain built from
    `lattice_angle`, rotated, mapped back through the inverse chain, and
    unpadded.

    Parameters
    ----------
    data : :class:`nexusformat.nexus.NXdata`
        Input data. The padded signal is read through a ``counts`` attribute,
        so the signal is presumably expected to be named ``'counts'``.
    lattice_angle : float
        Angle between the two in-plane lattice axes in degrees.
    rotation_angle : float
        Angle of rotation in degrees.


    Returns
    -------
    rotated_data : :class:`nexusformat.nexus.NXdata`
        Rotated data as an NXdata object with a signal named ``'counts'`` and
        the two axes of the input.
    """

    # Define transformation
    skew_angle_adj = 90 - lattice_angle
    t = Affine2D()
    # Scale y-axis to preserve norm while shearing
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180))
    # Shear along x-axis
    t += Affine2D().skew_deg(skew_angle_adj, 0)
    # Return to original y-axis scaling
    t += Affine2D().scale(1, np.cos(skew_angle_adj * np.pi / 180)).inverted()

    # Pad each side by the full axis length before transforming
    p = Padder(data)
    padding = tuple([len(data[axis]) for axis in data.axes])
    counts = p.pad(padding).counts

    # order=0 throughout: nearest-neighbour sampling, no interpolation
    counts_skewed = ndimage.affine_transform(counts,
                                             t.inverted().get_matrix()[:2, :2],
                                             offset=[counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180), 0],
                                             order=0,
                                             )
    scale1 = np.cos(skew_angle_adj * np.pi / 180)
    counts_scaled1 = ndimage.affine_transform(counts_skewed,
                                              Affine2D().scale(scale1, 1).get_matrix()[:2, :2],
                                              offset=[(1 - scale1) * counts.shape[0] / 2, 0],
                                              order=0,
                                              )
    # Ratio of the padded array's two dimensions
    scale2 = counts.shape[0] / counts.shape[1]
    counts_scaled2 = ndimage.affine_transform(counts_scaled1,
                                              Affine2D().scale(scale2, 1).get_matrix()[:2, :2],
                                              offset=[(1 - scale2) * counts.shape[0] / 2, 0],
                                              order=0,
                                              )

    # reshape=False keeps the array shape unchanged by the rotation
    counts_rotated = ndimage.rotate(counts_scaled2, rotation_angle, reshape=False, order=0)

    # Undo the two scalings and the skew in reverse order
    counts_unscaled2 = ndimage.affine_transform(counts_rotated,
                                                Affine2D().scale(scale2, 1).inverted().get_matrix()[:2, :2],
                                                offset=[-(1 - scale2) * counts.shape[
                                                    0] / 2 / scale2, 0],
                                                order=0,
                                                )

    counts_unscaled1 = ndimage.affine_transform(counts_unscaled2,
                                                Affine2D().scale(scale1,
                                                                 1).inverted().get_matrix()[:2, :2],
                                                offset=[-(1 - scale1) * counts.shape[
                                                    0] / 2 / scale1, 0],
                                                order=0,
                                                )

    counts_unskewed = ndimage.affine_transform(counts_unscaled1,
                                               t.get_matrix()[:2, :2],
                                               offset=[
                                                   (-counts.shape[0] / 2 * np.sin(skew_angle_adj * np.pi / 180)),
                                                   0],
                                               order=0,
                                               )

    # Strip the padding that was added above
    counts_unpadded = p.unpad(counts_unskewed)

    print('\nDone.')
    return NXdata(NXfield(counts_unpadded, name='counts'),
                  (data[data.axes[0]], data[data.axes[1]]))
871
+
872
+
873
class Padder():
    """
    A class to pad and unpad datasets with a symmetric region of zeros.

    Attributes
    ----------
    data : NXdata or None
        The dataset to pad.
    steps : tuple or None
        Spacing between the first two points of each axis of `data`.
    maxes : tuple or None
        Maximum value of each axis of `data`.
    padding : tuple or None
        Number of pixels added along each edge by the last call to `pad`.
    padded : NXdata or None
        Result of the last call to `pad`.
    """

    def __init__(self, data=None):
        """
        Initialize the Padder object.

        Parameters
        ----------
        data : NXdata, optional
            The input data to be padded. If provided, the `set_data` method is called to set the data.

        """
        self.data = None
        self.steps = None
        self.maxes = None
        self.padded = None
        self.padding = None
        if data is not None:
            self.set_data(data)

    def set_data(self, data):
        """
        Set the input data for padding.

        Parameters
        ----------
        data : NXdata
            The input data to be padded.

        """
        self.data = data

        # Axis spacing, taken from the first two points of each axis
        self.steps = tuple([(data[axis].nxdata[1] - data[axis].nxdata[0]) for axis in data.axes])

        # Absolute value of the maximum value; assumes the domain of the input is symmetric (eg, -H_min = H_max)
        self.maxes = tuple([data[axis].nxdata.max() for axis in data.axes])

    def pad(self, padding):
        """
        Symmetrically pads the data with zero values.

        Parameters
        ----------
        padding : tuple
            The number of zero-value pixels to add along each edge of the array,
            one entry per axis. An entry of 0 leaves that axis unpadded.

        Returns
        -------
        NXdata
            The padded dataset, also stored in `self.padded`. Each axis is
            rebuilt as an evenly spaced range from ``-padmax`` to ``padmax``,
            where ``padmax`` is the original axis maximum extended by
            ``padding * step``.
        """
        data = self.data
        self.padding = padding

        padded_shape = tuple([data[data.signal].nxdata.shape[i] + self.padding[i] * 2 for i in range(data.ndim)])

        # Create padded dataset
        padded = np.zeros(padded_shape)

        # Interior region that receives the original values. `-pad or None`
        # keeps the whole axis when pad is 0 (slice(0, -0) would be empty).
        slice_obj = tuple(slice(pad, -pad or None) for pad in self.padding[:data.ndim])
        padded[slice_obj] = data[data.signal].nxdata

        padmaxes = tuple([self.maxes[i] + self.padding[i] * self.steps[i] for i in range(data.ndim)])

        padded = NXdata(NXfield(padded, name=data.signal),
                        tuple([NXfield(np.linspace(-padmaxes[i], padmaxes[i], padded_shape[i]),
                                       name=data.axes[i])
                               for i in range(data.ndim)]))

        self.padded = padded
        return padded

    def save(self, fout_name=None):
        """
        Saves the padded dataset to a .nxs file.

        Parameters
        ----------
        fout_name : str, optional
            The output file name. Default is ``padded_`` followed by the padding
            of each axis joined with underscores, e.g.
            padded_(Hpadding)_(Kpadding)_(Lpadding).nxs
        """
        # Save padded dataset
        print("Saving padded dataset...")
        f = NXroot()
        f['entry'] = NXentry()
        f['entry']['data'] = self.padded
        if fout_name is None:
            # Works for any number of axes, not only three
            fout_name = 'padded_' + '_'.join(str(pad) for pad in self.padding) + '.nxs'
        nxsave(fout_name, f)
        print("Output file saved to: " + os.path.join(os.getcwd(), fout_name))

    def unpad(self, data):
        """
        Removes the padded region from the data.

        Parameters
        ----------
        data : ndarray or NXdata
            The padded data from which to remove the padding.

        Returns
        -------
        ndarray or NXdata
            The unpadded data, with the symmetric padding region removed.

        Notes
        -----
        This method removes the symmetric padding region that was added using the `pad` method. It returns the data
        without the padded region. The padding amounts are read from `self.padding`.
        """
        # `-pad or None` keeps the whole axis when pad is 0 (slice(0, -0) would be empty)
        slice_obj = tuple(slice(self.padding[i], -self.padding[i] or None) for i in range(data.ndim))
        return data[slice_obj]