pyTEMlib 0.2020.11.1__py3-none-any.whl → 0.2024.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyTEMlib might be problematic. Click here for more details.

Files changed (60) hide show
  1. pyTEMlib/__init__.py +11 -11
  2. pyTEMlib/animation.py +631 -0
  3. pyTEMlib/atom_tools.py +240 -245
  4. pyTEMlib/config_dir.py +57 -33
  5. pyTEMlib/core_loss_widget.py +658 -0
  6. pyTEMlib/crystal_tools.py +1255 -0
  7. pyTEMlib/diffraction_plot.py +756 -0
  8. pyTEMlib/dynamic_scattering.py +293 -0
  9. pyTEMlib/eds_tools.py +609 -0
  10. pyTEMlib/eels_dialog.py +749 -491
  11. pyTEMlib/{interactive_eels.py → eels_dialog_utilities.py} +1199 -1177
  12. pyTEMlib/eels_tools.py +2031 -1698
  13. pyTEMlib/file_tools.py +1276 -560
  14. pyTEMlib/file_tools_qt.py +193 -0
  15. pyTEMlib/graph_tools.py +1166 -450
  16. pyTEMlib/graph_viz.py +449 -0
  17. pyTEMlib/image_dialog.py +158 -0
  18. pyTEMlib/image_dlg.py +146 -232
  19. pyTEMlib/image_tools.py +1399 -1028
  20. pyTEMlib/info_widget.py +933 -0
  21. pyTEMlib/interactive_image.py +1 -226
  22. pyTEMlib/kinematic_scattering.py +1196 -0
  23. pyTEMlib/low_loss_widget.py +176 -0
  24. pyTEMlib/microscope.py +61 -81
  25. pyTEMlib/peak_dialog.py +1047 -410
  26. pyTEMlib/peak_dlg.py +286 -242
  27. pyTEMlib/probe_tools.py +653 -207
  28. pyTEMlib/sidpy_tools.py +153 -136
  29. pyTEMlib/simulation_tools.py +104 -87
  30. pyTEMlib/version.py +6 -3
  31. pyTEMlib/xrpa_x_sections.py +20972 -0
  32. {pyTEMlib-0.2020.11.1.dist-info → pyTEMlib-0.2024.9.0.dist-info}/LICENSE +21 -21
  33. pyTEMlib-0.2024.9.0.dist-info/METADATA +92 -0
  34. pyTEMlib-0.2024.9.0.dist-info/RECORD +37 -0
  35. {pyTEMlib-0.2020.11.1.dist-info → pyTEMlib-0.2024.9.0.dist-info}/WHEEL +5 -5
  36. {pyTEMlib-0.2020.11.1.dist-info → pyTEMlib-0.2024.9.0.dist-info}/entry_points.txt +0 -1
  37. pyTEMlib/KinsCat.py +0 -2758
  38. pyTEMlib/__version__.py +0 -2
  39. pyTEMlib/data/TEMlibrc +0 -68
  40. pyTEMlib/data/edges_db.csv +0 -189
  41. pyTEMlib/data/edges_db.pkl +0 -0
  42. pyTEMlib/data/fparam.txt +0 -103
  43. pyTEMlib/data/microscopes.csv +0 -7
  44. pyTEMlib/data/microscopes.xml +0 -167
  45. pyTEMlib/data/path.txt +0 -1
  46. pyTEMlib/defaults_parser.py +0 -90
  47. pyTEMlib/dm3_reader.py +0 -613
  48. pyTEMlib/edges_db.py +0 -76
  49. pyTEMlib/eels_dlg.py +0 -224
  50. pyTEMlib/hdf_utils.py +0 -483
  51. pyTEMlib/image_tools1.py +0 -2194
  52. pyTEMlib/info_dialog.py +0 -237
  53. pyTEMlib/info_dlg.py +0 -202
  54. pyTEMlib/nion_reader.py +0 -297
  55. pyTEMlib/nsi_reader.py +0 -170
  56. pyTEMlib/structure_tools.py +0 -316
  57. pyTEMlib/test.py +0 -2072
  58. pyTEMlib-0.2020.11.1.dist-info/METADATA +0 -20
  59. pyTEMlib-0.2020.11.1.dist-info/RECORD +0 -45
  60. {pyTEMlib-0.2020.11.1.dist-info → pyTEMlib-0.2024.9.0.dist-info}/top_level.txt +0 -0
pyTEMlib/image_tools.py CHANGED
@@ -1,1028 +1,1399 @@
1
- """
2
- Image Processing and Analysis Tools of pyTEMlib
3
- """
4
- #
5
- # image_tools.py
6
- # by Gerd Duscher, UTK
7
- # part of pyTEMlib
8
- # MIT license except where stated differently
9
- #
10
-
11
- import numpy as np
12
-
13
- import matplotlib as mpl
14
- import matplotlib.pylab as plt
15
- from matplotlib.patches import Polygon # plotting of polygons -- graph rings
16
-
17
- import matplotlib.widgets as mwidgets
18
- from matplotlib.widgets import RectangleSelector
19
-
20
- from .file_tools import *
21
- from .probe_tools import *
22
- import sys
23
-
24
- import itertools
25
- from itertools import product
26
-
27
- from scipy import fftpack
28
- from scipy import signal
29
- from scipy.interpolate import interp1d, interp2d
30
- from scipy.optimize import leastsq
31
- import scipy.optimize as optimization
32
-
33
- # Multidimensional Image library
34
- import scipy.ndimage as ndimage
35
- import scipy.constants as const
36
-
37
- import scipy.spatial as sp
38
- from scipy.spatial import Voronoi, KDTree, cKDTree
39
-
40
- import skimage
41
- import skimage.registration as registration
42
- from skimage.feature import register_translation # blob_dog, blob_doh
43
- from skimage.feature import peak_local_max
44
- from skimage.measure import points_in_poly
45
-
46
- # our blob detectors from the scipy image package
47
- from skimage.feature import blob_log # blob_dog, blob_doh
48
-
49
- from sklearn.feature_extraction import image
50
- from sklearn.utils.extmath import randomized_svd
51
-
52
- _SimpleITK_present = True
53
- try:
54
- import SimpleITK as sITK
55
- except ModuleNotFoundError:
56
- _SimpleITK_present = False
57
-
58
- if not _SimpleITK_present:
59
- print('SimpleITK not installed; Registration Functions for Image Stacks not available')
60
-
61
- sys.path.insert(0,'../../sidpy')
62
def get_wavelength(e0):
    """
    Calculate the relativistically corrected de Broglie wavelength of an electron.

    Parameters
    ----------
    e0 : float
        Acceleration voltage in volt.

    Returns
    -------
    float
        Wavelength in nm.

    Notes
    -----
    The previous docstring (and the comment above the function) claimed the
    output unit was 1/nm; h / sqrt(2 m eV (1 + eV / 2mc^2)) is a length in
    meters, and the factor 10**9 converts it to nm (e.g. ~0.0037 nm at 100 kV).
    """
    eV = const.e * e0  # kinetic energy in joule
    return const.h / np.sqrt(2 * const.m_e * eV * (1 + eV / (2 * const.m_e * const.c ** 2))) * 10 ** 9
75
-
76
-
77
def read_dm3_image_info(original_metadata):
    """Extract experimental conditions from DigitalMicrograph (dm3) metadata.

    Parameters
    ----------
    original_metadata : dict
        Original metadata dictionary as read from a dm3 file; the tags of the
        image selected by original_metadata['DM']['chosen_image'] are parsed.

    Returns
    -------
    dict
        Experiment dictionary (exposure_time, microscope,
        acceleration_voltage, convergence_angle, collection_angle as
        available); empty dict when no 'DM' key is present.

    Raises
    ------
    TypeError
        If original_metadata is not a dict.
    """
    if not isinstance(original_metadata, dict):
        raise TypeError('We need a python dictionary to read the original metadata')
    if 'DM' not in original_metadata:
        return {}

    chosen = original_metadata['DM']['chosen_image']
    tags = original_metadata['ImageList'][str(chosen)]['ImageTags']
    experiment = {}

    acquisition = tags.get('Acquisition', {})
    if 'Parameters' in acquisition and 'High Level' in acquisition['Parameters']:
        high_level = acquisition['Parameters']['High Level']
        if 'Exposure (s)' in high_level:
            experiment['exposure_time'] = high_level['Exposure (s)']

    scope_info = tags.get('Microscope Info', {})
    if 'Microscope' in scope_info:
        experiment['microscope'] = scope_info['Microscope']
    if 'Voltage' in scope_info:
        experiment['acceleration_voltage'] = scope_info['Voltage']
    if 'Illumination Mode' in scope_info:
        # default angle guesses per illumination mode (mrad)
        if scope_info['Illumination Mode'] == 'TEM':
            experiment['convergence_angle'] = 0.0
            experiment['collection_angle'] = 100.0
        if scope_info['Illumination Mode'] == 'SPOT':
            experiment['convergence_angle'] = 20.0
            experiment['collection_angle'] = 50.0

    return experiment
107
-
108
def fourier_transform(dset):
    """
    Fourier transform of a sidpy image or image-stack dataset.

    For an 'IMAGE_STACK' the spatial dimensions are kept, the stack
    dimension is summed into a single image, and that image is transformed.
    The result is a new sidpy Dataset with zero-centered RECIPROCAL
    dimensions 'u' and 'v' calibrated from the real-space extent.

    Parameters
    ----------
    dset : sidpy.Dataset
        Dataset of data_type 'IMAGE' or 'IMAGE_STACK'.

    Returns
    -------
    sidpy.Dataset or None
        fft-shifted 2D Fourier transform; None for any other data type.

    Usage
    -----
    fft_dataset = fourier_transform(sidpy_dataset)
    fft_dataset.plot()
    """

    assert isinstance(dset, sidpy.Dataset), 'Expected a sidpy Dataset'

    selection = []
    image_dim = []
    # image_dim = get_image_dims(sidpy.DimensionTypes.SPATIAL)
    if dset.data_type.name == 'IMAGE_STACK':
        # build a per-dimension slice: keep SPATIAL and stack dimensions,
        # take only the first element of anything else
        for dim, axis in dset._axes.items():
            if axis.dimension_type == sidpy.DimensionTypes.SPATIAL:
                selection.append(slice(None))
                image_dim.append(dim)
            elif axis.dimension_type == sidpy.DimensionTypes.TEMPORAL or len(dset) == 3:
                selection.append(slice(None))
                stack_dim = dim
            else:
                selection.append(slice(0, 1))
        if len(image_dim) != 2:
            raise ValueError('need at least two SPATIAL dimension for an image stack')
        image_stack = np.squeeze(np.array(dset)[selection])
        # collapse the stack into a single (summed) image
        image = np.sum(np.array(image_stack), axis=stack_dim)
    elif dset.data_type.name == 'IMAGE':
        image = np.array(dset)
    else:
        # NOTE(review): silently returns None for unsupported data types
        return

    image = image - image.min()  # remove offset to reduce the DC peak
    fft_transform = (np.fft.fftshift(np.fft.fft2(image)))

    # reciprocal-space pixel size from the real-space extent
    image_dims = pyTEMlib.sidpy_tools.get_image_dims(dset)
    extent = dset.get_extent(image_dims)
    scale_x = 1 / abs(extent[1] - extent[0])
    scale_y = 1 / abs(extent[2] - extent[3])

    units_x = '1/' + dset._axes[image_dims[0]].units
    units_y = '1/' + dset._axes[image_dims[1]].units

    fft_dset = sidpy.Dataset.from_array(fft_transform)
    fft_dset.quantity = dset.quantity
    fft_dset.units = 'a.u.'
    fft_dset.data_type = 'IMAGE'
    fft_dset.source = dset.title
    fft_dset.modality = 'fft'
    # frequency axes centered on zero, matching the fftshift above
    fft_dset.set_dimension(0, sidpy.Dimension((np.arange(fft_dset.shape[0]) - fft_dset.shape[0] / 2) * scale_x,
                                              name='u', units=units_x, dimension_type='RECIPROCAL',
                                              quantity='reciprocal_length'))
    fft_dset.set_dimension(1, sidpy.Dimension((np.arange(fft_dset.shape[1]) - fft_dset.shape[1] / 2) * scale_y,
                                              name='v', units=units_y, dimension_type='RECIPROCAL',
                                              quantity='reciprocal_length'))

    return fft_dset
170
-
171
-
172
def power_spectrum(dset, smoothing=3):
    """
    Calculate the (smoothed, log-scaled) power spectrum of an image dataset.

    Parameters
    ----------
    dset : sidpy.Dataset
        Image (or image stack) dataset to transform.
    smoothing : int
        Gaussian smoothing width in pixels applied to the FFT magnitude.

    Returns
    -------
    sidpy.Dataset
        log(1 + |FFT|) power spectrum; metadata records the smoothing.
    """
    transform = fourier_transform(dset)
    magnitude = np.abs(transform)
    smoothed = ndimage.gaussian_filter(magnitude, sigma=(smoothing, smoothing), order=0)

    power_spec = transform.like_data(np.log(1. + smoothed))

    # annular mask between radii 1 and 11 in reciprocal units
    # (vestigial: only consumed by the commented-out intensity suggestion)
    u_grid, v_grid = np.meshgrid(power_spec.u.values, power_spec.v.values)
    radius_sq = u_grid ** 2 + v_grid ** 2
    mask = np.zeros(power_spec.shape)
    mask = mask + (radius_sq > 1 ** 2)
    mask = mask + (radius_sq < 11 ** 2)
    mask[np.where(mask == 1)] = 0  # just in case of overlapping disks

    # minimum_intensity = np.log2(1 + smoothed)[np.where(mask == 2)].min() * 0.95
    # maximum_intensity = np.log2(1 + smoothed)[np.where(mask == 2)].max() * 1.05
    power_spec.metadata = {'smoothing': smoothing}
    # 'minimum_intensity': minimum_intensity, 'maximum_intensity': maximum_intensity}
    power_spec.title = 'power spectrum ' + power_spec.source

    return power_spec
213
-
214
-
215
def diffractogram_spots(dset, spot_threshold):
    """
    Find spots in a diffractogram and sort them by distance from the center.

    Parameters
    ----------
    dset : sidpy.Dataset
        Power spectrum / diffractogram with reciprocal dimensions u and v.
    spot_threshold : float
        Threshold for the blob finder.

    Returns
    -------
    numpy.ndarray
        Spots sorted by distance from the center; columns are
        (x, y, angle) with angle = arctan2(x, y).
    """
    # Needed for conversion from pixel to reciprocal space;
    # x- and y-coordinates are switched due to numpy/matrix conventions
    center = np.array([int(dset.shape[0] / 2.), int(dset.shape[1] / 2.), 1])
    rec_scale = np.array([get_slope(dset.u.values), get_slope(dset.v.values), 1])

    # spot detection (for future reference, no symmetry is assumed here)
    data = np.abs(dset).T
    data = (data - data.min())
    data = data / data.max()
    # some images are strange and blob_log does not work on the power spectrum
    try:
        spots_random = (blob_log(data, max_sigma=5, threshold=spot_threshold) - center) * rec_scale
    except ValueError:
        # peak_local_max returns (N, 2) pixel coordinates: scale only x and y.
        # BUG FIX: the original multiplied the (N, 2) array by the 3-element
        # rec_scale (broadcast error) and called np.hstack with two positional
        # arrays instead of a single tuple.
        spots_random = (peak_local_max(np.array(data.T), min_distance=3, threshold_rel=spot_threshold)
                        - center[:2]) * rec_scale[:2]
        spots_random = np.hstack((spots_random, np.zeros((spots_random.shape[0], 1))))

    print(f'Found {spots_random.shape[0]} reflections')

    # sort reflections by distance from the center
    spots_random[:, 2] = np.linalg.norm(spots_random[:, 0:2], axis=1)
    spots_index = np.argsort(spots_random[:, 2])
    spots = spots_random[spots_index]
    # third column becomes the azimuthal angle
    spots[:, 2] = np.arctan2(spots[:, 0], spots[:, 1])
    return spots
253
-
254
-
255
def adaptive_fourier_filter(dset, spots, low_pass=3, reflection_radius=0.3):
    """
    Use spots in a diffractogram for an adaptive Fourier filter.

    Parameters
    ----------
    dset : sidpy.Dataset
        Image to be filtered.
    spots : numpy.ndarray
        Sorted diffractogram spots in 1/nm; first two columns are x and y.
    low_pass : float
        Radius of the pass-through region around the center, in 1/nm
        (keeps the slow intensity variations).
    reflection_radius : float
        Radius of the pass-through disk around each reflection, in 1/nm.

    Returns
    -------
    sidpy.Dataset
        Fourier-filtered image with filter parameters in metadata.
    """
    # prepare binary mask in reciprocal space
    fft_transform = fourier_transform(dset)
    x, y = np.meshgrid(fft_transform.u.values, fft_transform.v.values)
    mask = np.zeros(dset.shape)

    # mask reflections: one disk per spot
    for spot in spots:
        mask_spot = (x - spot[0]) ** 2 + (y - spot[1]) ** 2 < reflection_radius ** 2  # make a spot
        mask = mask + mask_spot  # add spot to mask

    # mask zero region larger (low-pass filter = intensity variations)
    mask_spot = x ** 2 + y ** 2 < low_pass ** 2
    mask = mask + mask_spot
    mask[np.where(mask > 1)] = 1  # clip overlapping disks to 1
    fft_filtered = fft_transform * mask

    # NOTE(review): the exact inverse of the forward fftshift would be
    # ifftshift; for even-sized images the two coincide — confirm inputs
    filtered_image = dset.like_data(np.fft.ifft2(np.fft.fftshift(fft_filtered)).real)
    filtered_image.title = 'Fourier filtered ' + dset.title
    filtered_image.source = dset.title
    filtered_image.metadata = {'analysis': 'adaptive fourier filtered', 'spots': spots,
                               'low_pass': low_pass, 'reflection_radius': reflection_radius}

    return filtered_image
295
-
296
-
297
def rotational_symmetry_diffractogram(spots):
    """
    Determine which rotational symmetries (2-, 3-, 4-, 6-fold) a spot
    pattern obeys.

    A candidate n-fold symmetry is accepted when, after rotating all spots
    by 2*pi/n, at least ~70% of the pairwise spot distances match the
    original pattern to within 0.2.

    Parameters
    ----------
    spots : numpy.ndarray
        Spot positions; first two columns are x and y (third column is
        carried through the rotation but ignored for distances).

    Returns
    -------
    list of int
        The n-fold symmetries found among [2, 3, 4, 6].
    """
    found = []
    for fold in [2, 3, 4, 6]:
        theta = 2 * np.pi / fold
        rotation = np.array([[np.cos(theta), np.sin(theta), 0],
                             [-np.sin(theta), np.cos(theta), 0],
                             [0, 0, 1]])
        rotated = np.dot(spots, rotation)
        # all pairwise distances between rotated and original spots, sorted
        distances = np.array(sorted(np.linalg.norm(p0 - p1)
                                    for p0, p1 in product(rotated[:, 0:2], spots[:, 0:2])))
        if distances[int(spots.shape[0] * .7)] < 0.2:
            found.append(fold)
    return found
315
-
316
- # ####################################################
317
- # Registration Functions
318
- # ####################################################
319
-
320
-
321
def complete_registration(main_dataset, storage_channel=None):
    """
    Rigid followed by non-rigid registration of an image stack.

    Both intermediate and final results are logged to an h5py group.

    Parameters
    ----------
    main_dataset : sidpy.Dataset
        Image stack to register; must have an associated h5_dataset when
        storage_channel is None.
    storage_channel : h5py.Group, optional
        Group to log results to; defaults to the parent group of
        main_dataset.h5_dataset.

    Returns
    -------
    tuple of sidpy.Dataset
        (non_rigid_registered, rigid_registered_dataset)

    Raises
    ------
    ValueError
        If storage_channel is given but is not an h5py.Group.
    """
    rigid_registered_dataset = rigid_registration(main_dataset)
    if storage_channel is None:
        current_channel = main_dataset.h5_dataset.parent
    else:
        if not isinstance(storage_channel, h5py.Group):
            raise ValueError('storage channel needs to be a h5py Group')
        current_channel = storage_channel

    registration_channel = log_results(current_channel, rigid_registered_dataset)

    print('Non-Rigid_Registration')

    non_rigid_registered = demon_registration(rigid_registered_dataset)

    # NOTE(review): registration_channel is reassigned and never read
    registration_channel = log_results(current_channel, non_rigid_registered)

    return non_rigid_registered, rigid_registered_dataset
347
-
348
-
349
def demon_registration(dataset, verbose=False):
    """
    Diffeomorphic Demons non-rigid registration of an image stack.

    The fixed (reference) image is the average over the stack; each image
    is registered to it with SimpleITK's
    DiffeomorphicDemonsRegistrationFilter and resampled with a B-spline.

    Usage
    -----
    dem_reg = demon_registration(cube, verbose=False)

    Parameters
    ----------
    dataset : sidpy.Dataset
        Stack of images after rigid registration and cropping
        (stack dimension first).
    verbose : bool
        If True, print the number of images.

    Returns
    -------
    sidpy.Dataset
        Stack of non-rigidly registered images; 'boundaries' metadata is
        carried over from the input when present.

    Notes
    -----
    Depends on SimpleITK and numpy.

    Please Cite: http://www.simpleitk.org/SimpleITK/project/parti.html
    and T. Vercauteren, X. Pennec, A. Perchant and N. Ayache
    Diffeomorphic Demons Using ITK's Finite Difference Solver Hierarchy
    The Insight Journal, http://hdl.handle.net/1926/510 2007
    """

    dem_reg = np.zeros(dataset.shape)
    nimages = dataset.shape[0]
    if verbose:
        print(nimages)
    # create fixed image by summing over rigid registration
    fixed_np = np.average(np.array(dataset), axis=0)

    fixed = sITK.GetImageFromArray(fixed_np)
    fixed = sITK.DiscreteGaussian(fixed, 2.0)  # smooth to stabilize demon forces

    # demons = sITK.SymmetricForcesDemonsRegistrationFilter()
    demons = sITK.DiffeomorphicDemonsRegistrationFilter()

    demons.SetNumberOfIterations(200)
    demons.SetStandardDeviations(1.0)

    resampler = sITK.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sITK.sitkBSpline)
    resampler.SetDefaultPixelValue(0)

    done = 0

    if QT_available:
        progress = pyTEMlib.sidpy_tools.ProgressDialog("Non-Rigid Registration", nimages)
    for i in range(nimages):
        if QT_available:
            progress.set_value(i)
        else:
            # 50-character wide text progress bar on stdout
            if done < int((i + 1) / nimages * 50):
                done = int((i + 1) / nimages * 50)
                sys.stdout.write('\r')
                # progress output :
                sys.stdout.write("[%-50s] %d%%" % ('*' * done, 2 * done))
                sys.stdout.flush()

        moving = sITK.GetImageFromArray(dataset[i])
        moving_f = sITK.DiscreteGaussian(moving, 2.0)
        displacement_field = demons.Execute(fixed, moving_f)
        out_tx = sITK.DisplacementFieldTransform(displacement_field)
        resampler.SetTransform(out_tx)
        # apply the deformation to the *unsmoothed* moving image
        out = resampler.Execute(moving)
        dem_reg[i, :, :] = sITK.GetArrayFromImage(out)
        # print('image ', i)

    if QT_available:
        progress.close()

    print(':-)')
    print('You have successfully completed Diffeomorphic Demons Registration')

    demon_registered = dataset.like_data(dem_reg)
    demon_registered.title = 'Non-Rigid Registration'
    demon_registered.source = dataset.title

    demon_registered.metadata = {'analysis': 'non-rigid demon registration'}
    if 'boundaries' in dataset.metadata:
        demon_registered.metadata['boundaries'] = dataset.metadata['boundaries']

    return demon_registered
430
-
431
-
432
- ###############################
433
- # Rigid Registration New 05/09/2020
434
-
435
def rigid_registration(dataset):
    """
    Rigid registration of an image stack with sub-pixel accuracy.

    Uses Fourier-space phase cross-correlation from skimage.registration
    (register_translation on skimage 0.16); drift is measured from each
    image to the next, accumulated, referenced to the center image, and
    the stack is cropped to the common field of view.

    Parameters
    ----------
    dataset : sidpy.Dataset
        Image stack, stack dimension first.

    Returns
    -------
    sidpy.Dataset
        Registered and cropped stack; metadata carries 'drift' (relative
        to the center image) and the crop 'boundaries'.
    """

    nopix = dataset.shape[1]
    nopiy = dataset.shape[2]
    nimages = dataset.shape[0]

    print('Stack contains ', nimages, ' images, each with', nopix, ' pixels in x-direction and ', nopiy,
          ' pixels in y-direction')
    fixed = np.array(dataset[0])
    fft_fixed = np.fft.fft2(fixed)

    # NOTE(review): relative_drift starts with [0, 0] and iteration i=0
    # registers image 0 against itself (another [0, 0]), so the list ends
    # up one entry longer than the stack — confirm intended
    relative_drift = [[0., 0.]]
    done = 0

    if QT_available:
        progress = pyTEMlib.sidpy_tools.ProgressDialog("Rigid Registration", nimages)
    for i in range(nimages):
        if QT_available:
            progress.set_value(i)
        else:
            # 50-character wide text progress bar on stdout
            if done < int((i + 1) / nimages * 50):
                done = int((i + 1) / nimages * 50)
                sys.stdout.write('\r')
                # progress output :
                sys.stdout.write("[%-50s] %d%%" % ('*' * done, 2 * done))
                sys.stdout.flush()

        moving = np.array(dataset[i])
        fft_moving = np.fft.fft2(moving)
        # register_translation was renamed phase_cross_correlation after skimage 0.16
        if skimage.__version__[:4] == '0.16':
            shift = register_translation(fft_fixed, fft_moving, upsample_factor=1000, space='fourier')
        else:
            shift = registration.phase_cross_correlation(fft_fixed, fft_moving, upsample_factor=1000, space='fourier')

        fft_fixed = fft_moving  # next drift is measured relative to this image
        # print(f'Image number {i:2} xshift = {shift[0][0]:6.3f} y-shift = {shift[0][1]:6.3f}')

        relative_drift.append(shift[0])
    if QT_available:
        progress.close()
    rig_reg, drift = rig_reg_drift(dataset, relative_drift)

    crop_reg, boundaries = crop_image_stack(rig_reg, drift)

    rigid_registered = dataset.like_data(crop_reg)
    rigid_registered.title = 'Rigid Registration'
    rigid_registered.source = dataset.title
    rigid_registered.metadata = {'analysis': 'rigid sub-pixel registration', 'drift': drift, 'boundaries': boundaries}

    return rigid_registered
496
-
497
-
498
def rig_reg_drift(dset, rel_drift):
    """
    Shift the images of a stack on top of each other using relative drift.

    Shifting is done with scipy.ndimage.shift (cubic interpolation);
    used by the rigid_registration routine.

    Parameters
    ----------
    dset : array-like
        Image stack (stack dimension first).
    rel_drift : list of [shift_x, shift_y]
        Relative drift from each image to the next; the first entry is
        ignored (the first image is the reference).

    Returns
    -------
    tuple
        (shifted stack, absolute drift per image relative to the center
        image of the stack)
    """

    rig_reg = np.zeros(dset.shape)
    # integrate relative drift into absolute drift
    drift = np.array(rel_drift).copy()

    drift[0] = [0, 0]
    # BUG FIX: the loop previously started at i=0, so the first iteration
    # computed drift[0] = drift[-1] + rel_drift[0], clobbering the zeroed
    # first entry with the *last* relative drift.
    for i in range(1, drift.shape[0]):
        drift[i] = drift[i - 1] + rel_drift[i]
    # reference all drift values to the center image of the stack
    center_drift = drift[int(drift.shape[0] / 2)]
    drift = drift - center_drift
    # Shift images
    for i in range(rig_reg.shape[0]):
        # Now we shift
        rig_reg[i, :, :] = ndimage.shift(dset[i], [drift[i, 0], drift[i, 1]], order=3)
    return rig_reg, drift
527
-
528
-
529
def crop_image_stack(rig_reg, drift):
    """
    Crop a registered stack to the field of view common to all images.

    Parameters
    ----------
    rig_reg : numpy.ndarray
        Registered image stack (stack dimension first).
    drift : array-like
        Per-image [x, y] drift used during registration.

    Returns
    -------
    tuple
        (cropped stack, [xpmin, xpmax, ypmin, ypmax] crop boundaries)
    """
    shifts = np.array(drift)
    x_low = int(-np.floor(np.min(shifts[:, 0])))
    x_high = int(rig_reg.shape[1] - np.ceil(np.max(shifts[:, 0])))
    y_low = int(-np.floor(np.min(shifts[:, 1])))
    y_high = int(rig_reg.shape[2] - np.ceil(np.max(shifts[:, 1])))

    return rig_reg[:, x_low:x_high, y_low:y_high], [x_low, x_high, y_low, y_high]
539
-
540
-
541
class ImageWithLineProfile:
    """Interactive matplotlib image with a click-defined line profile.

    Click twice on the image to define a line; the intensity profile along
    that line (cubic interpolation via ndimage.map_coordinates) is shown
    in a second subplot below the image.
    """
    def __init__(self, data, extent, title=''):
        # data: 2D numpy array to display; extent: imshow extent tuple
        fig, ax = plt.subplots(1, 1)
        self.figure = fig
        self.title = title
        self.line_plot = False  # flag; becomes the profile Line2D after first update()
        self.ax = ax
        self.data = data
        self.extent = extent
        self.ax.imshow(data, extent=extent)
        self.ax.set_title(title)
        self.line, = self.ax.plot([0], [0], color='orange')  # empty line
        self.end_x = self.line.get_xdata()
        self.end_y = self.line.get_ydata()
        self.cid = self.line.figure.canvas.mpl_connect('button_press_event', self)

    def __call__(self, event):
        # mouse-click handler: previous end point becomes the new start point
        if event.inaxes != self.line.axes:
            return
        self.start_x = self.end_x
        self.start_y = self.end_y

        self.line.set_data([self.start_x, event.xdata], [self.start_y, event.ydata])
        self.line.figure.canvas.draw()

        self.end_x = event.xdata
        self.end_y = event.ydata

        self.update()

    def update(self):
        # Recompute and redraw the intensity profile along the current line.
        if not self.line_plot:
            # first call: switch the figure to a two-panel layout
            self.line_plot = True
            self.figure.clear()
            self.ax = self.figure.subplots(2, 1)
            self.ax[0].imshow(self.data, extent=self.extent)
            self.ax[0].set_title(self.title)

            self.line, = self.ax[0].plot([0], [0], color='orange')  # empty line
            # NOTE(review): line_plot doubles as a boolean flag and as the
            # profile Line2D object — confusing but functional
            self.line_plot, = self.ax[1].plot([], [], color='orange')
            self.ax[1].set_xlabel('distance [nm]')

        x0 = self.start_x
        x1 = self.end_x
        y0 = self.start_y
        y1 = self.end_y
        length_plot = np.sqrt((x1-x0)**2+(y1-y0)**2)

        # number of sample points along the line, in pixel units
        # NOTE(review): num is a float here; recent numpy requires an int
        # for np.linspace's num argument — confirm against the numpy in use
        num = length_plot*(self.data.shape[0]/self.extent[1])
        x = np.linspace(x0, x1, num)*(self.data.shape[0]/self.extent[1])
        y = np.linspace(y0, y1, num)*(self.data.shape[0]/self.extent[1])

        # Extract the values along the line, using cubic interpolation
        zi2 = ndimage.map_coordinates(self.data.T, np.vstack((x, y)))

        x_axis = np.linspace(0, length_plot, len(zi2))

        self.x = x_axis
        self.z = zi2

        self.line_plot.set_xdata(x_axis)
        self.line_plot.set_ydata(zi2)
        self.ax[1].set_xlim(0, x_axis.max())
        self.ax[1].set_ylim(zi2.min(), zi2.max())
        # NOTE(review): Axes.draw normally requires a renderer argument;
        # figure.canvas.draw() may have been intended here
        self.ax[1].draw()
610
-
611
-
612
def histogram_plot(image_tags):
    """
    Interactive histogram with colorbar for display-contrast selection.

    Drag a horizontal span on the histogram to set the display intensity
    range (written back into image_tags['minimum_intensity'] and
    ['maximum_intensity']); right-click the colorbar to cycle through the
    colormap list, any other click on it resets to the full data range.

    Parameters
    ----------
    image_tags : dict
        Must contain 'plotimage' (2D array); 'minimum_intensity',
        'maximum_intensity' and 'color_map' are read and updated in place.

    Returns
    -------
    matplotlib.widgets.SpanSelector
        Keep a reference to the returned selector, otherwise the
        interactivity is garbage collected.
    """
    nbins = 75
    # NOTE(review): minbin/maxbin and minimum_intensity below are unused
    minbin = 0.
    maxbin = 1.
    color_map_list = ['gray', 'viridis', 'jet', 'hot']

    # default display range: full data range
    if 'minimum_intensity' not in image_tags:
        image_tags['minimum_intensity'] = image_tags['plotimage'].min()
    minimum_intensity = image_tags['minimum_intensity']
    if 'maximum_intensity' not in image_tags:
        image_tags['maximum_intensity'] = image_tags['plotimage'].max()
    data = image_tags['plotimage']
    vmin = image_tags['minimum_intensity']
    vmax = image_tags['maximum_intensity']
    if 'color_map' not in image_tags:
        image_tags['color_map'] = color_map_list[0]
    cmap = plt.cm.get_cmap(image_tags['color_map'])

    colors = cmap(np.linspace(0., 1., nbins))

    norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    hist, bin_edges = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)

    width = bin_edges[1]-bin_edges[0]

    def onselect(vmin, vmax):
        # redraw histogram bars and colorbar for the new intensity range
        ax1.clear()
        cmap = plt.cm.get_cmap(image_tags['color_map'])

        colors = cmap(np.linspace(0., 1., nbins))

        norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        hist2, bin_edges2 = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)

        width2 = (bin_edges2[1]-bin_edges2[0])

        # move/recolor the existing bar patches instead of re-plotting
        for i in range(nbins-1):
            histogram[i].xy = (bin_edges2[i], 0)
            histogram[i].set_height(hist2[i])
            histogram[i].set_width(width2)
            histogram[i].set_facecolor(colors[i])
        ax.set_xlim(vmin, vmax)
        ax.set_ylim(0, hist2.max()*1.01)

        cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')

        # persist the selection
        image_tags['minimum_intensity'] = vmin
        image_tags['maximum_intensity'] = vmax

    def onclick(event):
        # NOTE(review): stores the event in a global, apparently for debugging
        global event2
        event2 = event
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))
        if event.inaxes == ax1:
            if event.button == 3:
                # right-click on the colorbar: cycle the colormap
                ind = color_map_list.index(image_tags['color_map'])+1
                if ind == len(color_map_list):
                    ind = 0
                image_tags['color_map'] = color_map_list[ind]  # 'viridis'
                vmin = image_tags['minimum_intensity']
                vmax = image_tags['maximum_intensity']
            else:
                # any other click: reset to the full data range
                vmax = data.max()
                vmin = data.min()
            onselect(vmin, vmax)

    fig2 = plt.figure()

    ax = fig2.add_axes([0., 0.2, 0.9, 0.7])
    ax1 = fig2.add_axes([0., 0.15, 0.9, 0.05])

    histogram = ax.bar(bin_edges[0:-1], hist, width=width, color=colors, edgecolor='black', alpha=0.8)
    onselect(vmin, vmax)
    cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')

    rectprops = dict(facecolor='blue', alpha=0.5)

    span = mwidgets.SpanSelector(ax, onselect, 'horizontal', rectprops=rectprops)

    cid = fig2.canvas.mpl_connect('button_press_event', onclick)
    return span
699
-
700
-
701
def clean_svd(im, pixel_size=1, source_size=5):
    """
    Denoise an image via a truncated SVD of overlapping patches.

    Parameters
    ----------
    im : numpy.ndarray
        2D image to clean.
    pixel_size : float
        Pixel size; together with source_size it sets the patch size.
    source_size : float
        Source size in the same units as pixel_size.

    Returns
    -------
    numpy.ndarray
        Image reconstructed from the first SVD component, rescaled to the
        total intensity of the input.
    """
    patch_size = max(int(source_size / pixel_size), 3)  # at least 3x3 patches
    print(patch_size)

    patches = image.extract_patches_2d(im, (patch_size, patch_size))
    patches = patches.reshape(patches.shape[0], patches.shape[1] * patches.shape[2])

    num_components = 32

    u, s, v = randomized_svd(patches, num_components)
    side = int(np.sqrt(u.shape[0]))
    cleaned = u[:, 0].reshape(side, side)
    # rescale so the total intensity matches the original image
    cleaned = cleaned / cleaned.sum() * im.sum()
    return cleaned
720
-
721
-
722
def rebin(im, binning=2):
    """
    Rebin a 2D image by averaging over binning x binning pixel blocks.

    Parameters
    ----------
    im : numpy.ndarray
        2D image; each dimension should be divisible by binning.
    binning : int
        Block size in pixels along both axes.

    Returns
    -------
    numpy.ndarray
        Binned image, or the unchanged input (with a message) if im is
        not 2D.
    """
    if len(im.shape) != 2:
        print('not a 2D image')
        return im
    rows = im.shape[0] // binning
    cols = im.shape[1] // binning
    return im.reshape((rows, binning, cols, binning)).mean(axis=3).mean(1)
736
-
737
-
738
def cart2pol(points):
    """
    Convert cartesian point coordinates to polar coordinates.

    Parameters
    ----------
    points : numpy.ndarray
        Array whose first two columns are x and y.

    Returns
    -------
    tuple of numpy.ndarray
        (rho, phi) with phi = arctan2(y, x) in radians.
    """
    radii = np.linalg.norm(points[:, 0:2], axis=1)
    angles = np.arctan2(points[:, 1], points[:, 0])
    return radii, angles
745
-
746
-
747
def pol2cart(rho, phi):
    """
    Convert polar coordinates to cartesian coordinates.

    Parameters
    ----------
    rho : array-like
        Radii.
    phi : array-like
        Angles in radians.

    Returns
    -------
    tuple
        (x, y) cartesian coordinates.
    """
    return rho * np.cos(phi), rho * np.sin(phi)
754
-
755
-
756
def xy2polar(points, rounding=1e-3):
    """
    Convert cartesian points to polar coordinates, sorted by r then phi.

    Angles are shifted so the smallest one is zero, radii are floored to
    the given rounding to suppress floating-point noise, and the result
    is lexicographically sorted by (r, phi). The sort permutation is
    returned as well.

    Parameters
    ----------
    points : numpy.ndarray
        Points; first two columns of axis 1 are x and y.
    rounding : float
        Rounding granularity for the radii (significant digits).

    Returns
    -------
    tuple
        (r, phi, sorted_indices)
    """
    # polar conversion (inlined cart2pol)
    radii = np.linalg.norm(points[:, 0:2], axis=1)
    angles = np.arctan2(points[:, 1], points[:, 0])

    angles = angles - angles.min()              # only positive angles
    radii = (np.floor(radii / rounding)) * rounding  # remove rounding-error differences

    order = np.lexsort((angles, radii))         # sort first by r, then by phi
    return radii[order], angles[order], order
781
-
782
-
783
def cartesian2polar(x, y, grid, r, t, order=3):
    """
    Resample a cartesian grid onto a polar (r, t) grid.

    Parameters
    ----------
    x, y : array-like
        Monotonic coordinate axes of the cartesian grid.
    grid : numpy.ndarray
        2D data sampled on (x, y).
    r, t : array-like
        Target radii and angles (radians).
    order : int
        Spline interpolation order for ndimage.map_coordinates.

    Returns
    -------
    numpy.ndarray
        Data resampled at every (r, t) combination; shape (len(t), len(r)).
    """
    radius, theta = np.meshgrid(r, t)

    # polar sample positions back in cartesian coordinates
    sample_x = radius * np.cos(theta)
    sample_y = radius * np.sin(theta)

    # map physical coordinates to fractional array indices
    index_of_x = interp1d(x, np.arange(len(x)))
    index_of_y = interp1d(y, np.arange(len(y)))

    flat_ix = index_of_x(sample_x.ravel())
    flat_iy = index_of_y(sample_y.ravel())

    resampled = ndimage.map_coordinates(grid, np.array([flat_ix, flat_iy]), order=order)
    return resampled.reshape(sample_x.shape)
800
-
801
-
802
def warp(diff, center):
    """
    Unwrap a diffraction pattern into polar coordinates around a center.

    Parameters
    ----------
    diff : numpy.ndarray
        2D diffraction pattern (may be complex; magnitude is used).
    center : sequence of int
        (row, column) position of the pattern center.

    Returns
    -------
    numpy.ndarray
        Polar-resampled magnitude, radius along axis 0, angle (0..pi,
        1080 steps) along axis 1.
    """
    # original cartesian axes, shifted so the center sits at the origin
    n_x = diff.shape[0]
    n_y = diff.shape[1]
    x = np.linspace(1, n_x, n_x, endpoint=True) - center[1]
    y = np.linspace(1, n_y, n_y, endpoint=True) - center[0]
    intensity = np.abs(diff)

    # largest radius that stays fully inside the pattern
    n_r = min([center[0], center[1], n_x - center[0], n_y - center[1]]) - 1
    n_t = 360 * 3  # 1/3-degree angular steps over half a turn

    r = np.linspace(1, n_r, n_r)
    t = np.linspace(0., np.pi, n_t, endpoint=False)
    return cartesian2polar(x, y, intensity, r, t, order=3).T
821
-
822
-
823
def calculate_ctf(wavelength, cs, defocus, k):
    """
    Calculate the contrast transfer function sin(chi).

    All lengths in nm; k is spatial frequency in 1/nm.

    Parameters
    ----------
    wavelength : float
        Electron wavelength in nm.
    cs : float
        Spherical aberration coefficient in nm.
    defocus : float
        Defocus in nm.
    k : array-like
        Spatial frequencies in 1/nm.

    Returns
    -------
    numpy.ndarray
        CTF values sin(chi(k)).
    """
    chi = np.pi * defocus * wavelength * k ** 2 + 0.5 * np.pi * cs * wavelength ** 3 * k ** 4
    return np.sin(chi)
831
-
832
-
833
def calculate_scherzer(wavelength, cs):
    """
    Calculate the Scherzer defocus.

    Parameters
    ----------
    wavelength : float
        Electron wavelength in nm.
    cs : float
        Spherical aberration coefficient in nm.

    Returns
    -------
    float
        Scherzer defocus in nm (negative = underfocus).

    Notes
    -----
    The previous docstring claimed mixed units (Cs in mm, lambda in nm,
    result in m); the formula -1.155 * sqrt(cs * wavelength) is
    unit-consistent, so the result is in the same length unit as the
    inputs (nm here, matching the module's other optics helpers).
    """
    scherzer = -1.155 * (cs * wavelength) ** 0.5
    return scherzer
841
-
842
-
843
def calibrate_image_scale(fft_tags, spots_reference, spots_experiment):
    """
    Refine the reciprocal image scale from measured vs. reference Bragg spots.

    A least-squares fit finds per-axis scale factors (dgx, dgy) that move
    the experimental spots of the innermost ring onto the shortest
    reference reflection distance.

    Parameters
    ----------
    fft_tags : dict
        Must contain 'spatial_scale_x' and 'spatial_scale_y'.
    spots_reference : numpy.ndarray
        Reference reflection positions (x, y per row).
    spots_experiment : numpy.ndarray
        Experimental reflection positions (x, y per row).

    Returns
    -------
    numpy.ndarray
        Fitted scale correction factors [dgx, dgy].
    """
    # NOTE(review): gx/gy are read but not used further — kept for
    # identical behavior (a missing key still raises KeyError)
    gx = fft_tags['spatial_scale_x']
    gy = fft_tags['spatial_scale_y']

    dist_reference = np.linalg.norm(spots_reference, axis=1)
    distance_experiment = np.linalg.norm(spots_experiment, axis=1)

    # select experimental spots belonging to the innermost reference ring
    first_reflections = abs(distance_experiment - dist_reference.min()) < .2
    print('Evaluate ', first_reflections.sum(), 'reflections')
    closest_exp_reflections = spots_experiment[first_reflections]

    def residuals(params, xdata, ydata):
        # distance of the rescaled spots from the reference ring radius
        dgx, dgy = params
        return np.sqrt((xdata * dgx) ** 2 + (ydata * dgy) ** 2) - dist_reference.min()

    initial_guess = [1.001, 0.999]
    dg, sig = optimization.leastsq(residuals, initial_guess,
                                   args=(closest_exp_reflections[:, 0], closest_exp_reflections[:, 1]))
    return dg
864
-
865
-
866
def align_crystal_reflections(spots, crystals):
    """
    Find rotation between diffraction spots and diffractogram spots

    Parameters
    ----------
    spots: numpy array (N, >=2)
        experimental spot positions (x, y, ...)
    crystals: list of dict
        one tag dictionary per crystal; needs tags['allowed']['g'].
        Modified in place: tags['allowed']['g_rotated'] is added.

    Returns
    -------
    crystal_reflections_polar: list
        per crystal: [r, rotated phi, sort indices] in polar coordinates
    angles: list of float
        rotation angle per crystal (radian)
    """
    crystal_reflections_polar = []
    angles = []
    exp_r, exp_phi = cart2pol(spots)  # just in polar coordinates
    spots_polar = np.array([exp_r, exp_phi])  # NOTE(review): never used below

    for i in range(len(crystals)):
        tags = crystals[i]
        r, phi, indices = xy2polar(tags['allowed']['g'])  # sorted by r and phi , only positive angles
        # we mask the experimental values that are found already
        angle = 0.

        # rotation angle: experimental spot closest in radius to the first
        # non-zero reflection, compared against the first theoretical angle
        angle_i = np.argmin(np.abs(exp_r - r[1]))
        angle = exp_phi[angle_i] - phi[0]
        angles.append(angle)  # Determine rotation angle

        crystal_reflections_polar.append([r, angle + phi, indices])
        tags['allowed']['g_rotated'] = pol2cart(r, angle + phi)
        for spot in tags['allowed']['g']:
            dif = np.linalg.norm(spots[:, 0:2]-spot[0:2], axis=1)
            # print(dif.min())
            if dif.min() < 1.5:
                ind = np.argmin(dif)  # NOTE(review): `ind` is never used - this loop has no effect

    return crystal_reflections_polar, angles
894
-
895
-
896
- # Deconvolution
897
def decon_lr(o_image, probe, verbose=False):
    """
    Lucy-Richardson deconvolution of an image with a point spread function (PSF).

    This task generates a restored image from an input image and point spread function (PSF)
    using the algorithm developed independently by Lucy (1974, Astron. J. 79, 745) and Richardson
    (1972, J. Opt. Soc. Am. 62, 55) and adapted for HST imagery by Snyder
    (1990, in Restoration of HST Images and Spectra, ST ScI Workshop Proceedings; see also
    Snyder, Hammoud, & White, JOSA, v. 10, no. 5, May 1993, in press).
    Additional options developed by Rick White (STScI) are also included.

    The Lucy-Richardson method can be derived from the maximum likelihood expression for data
    with a Poisson noise distribution. Thus, it naturally applies to optical imaging data such as HST.
    The method forces the restored image to be positive, in accord with photon-counting statistics.

    The Lucy-Richardson algorithm generates a restored image through an iterative method. The essence
    of the iteration is as follows: the (n+1)th estimate of the restored image is given by the nth
    estimate multiplied by a correction image:

                               original data
        image    = image    * --------------- * reflect(PSF)
             n+1        n      image  * PSF
                                    n

    where * represents convolution and reflect(PSF)(x, y) = PSF(-x, -y). When the convolutions are
    carried out with FFTs one can use the fact that FFT(reflect(PSF)) = conj(FFT(PSF)).

    Parameters
    ----------
    o_image: sidpy.Dataset
        image to be restored; needs metadata['experiment'] and dimension `dim_0`
    probe: numpy array
        point spread function, same shape as o_image
    verbose: bool
        print convergence information for every iteration

    Returns
    -------
    dset: sidpy.Dataset
        restored image
    """

    if len(o_image) < 1:
        return o_image

    if o_image.shape != probe.shape:
        print('Weirdness ', o_image.shape, ' != ', probe.shape)

    # work in single-precision complex arrays for the FFTs
    probe_c = np.ones(probe.shape, dtype=np.complex64)
    probe_c.real = probe

    error = np.ones(o_image.shape, dtype=np.complex64)
    est = np.ones(o_image.shape, dtype=np.complex64)
    source = np.ones(o_image.shape, dtype=np.complex64)
    # shift data to strictly positive values as required by Lucy-Richardson
    source.real = o_image-o_image.min()+1e-9

    response_ft = fftpack.fft2(probe_c)

    # microscope parameters define the objective aperture in reciprocal space
    ab = o_image.metadata['experiment']
    if 'convergence_angle' not in ab:
        ab['convergence_angle'] = 30  # default convergence angle, presumably mrad - TODO confirm
    ap_angle = ab['convergence_angle'] / 1000.0

    if 'acceleration_voltage' not in ab:
        e0 = 200000.  # default: 200 kV
    else:
        e0 = float(ab['acceleration_voltage'])

    wl = get_wavelength(e0)
    over_d = 2 * ap_angle / wl  # aperture diameter in reciprocal units

    dx = get_slope(o_image.dim_0)  # real-space pixel size
    dk = 1.0 / float(o_image.dim_0[-1] - o_image.dim_0[0])  # NOTE(review): dk is never used
    screen_width = 1 / dx

    aperture = np.ones(o_image.shape, dtype=np.complex64)
    # Mask for the aperture before the Fourier transform
    n = o_image.shape[0]
    size_x = o_image.shape[0]
    size_y = o_image.shape[1]
    app_ratio = over_d / screen_width * n  # aperture radius in pixel

    theta_x = np.array(-size_x / 2. + np.arange(size_x))
    theta_y = np.array(-size_y / 2. + np.arange(size_y))
    t_xv, t_yv = np.meshgrid(theta_x, theta_y)

    # zero the aperture outside the disk of radius app_ratio
    tp1 = t_xv ** 2 + t_yv ** 2 >= app_ratio ** 2
    aperture[tp1.T] = 0.

    if QT_available:
        progress = pyTEMlib.sidpy_tools.ProgressDialog("Lucy-Richardson", 100)

    # de = 100
    dest = 100  # percent change of the estimate: convergence criterion
    i = 0
    while abs(dest) > 0.001:  # or abs(de) > .025:
        i += 1
        error_old = np.sum(error.real)
        est_old = est.copy()
        # ratio of the data to the current estimate blurred with the PSF ...
        error = source / np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est) * response_ft)))
        # ... correlated with the reflected PSF (conjugate in Fourier space)
        est = est * np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(error) * np.conjugate(response_ft))))
        # est = est_old * est
        # est = np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est)*fftpack.fftshift(aperture) )))

        error_new = np.real(np.sum(np.power(error, 2))) - error_old
        dest = np.sum(np.power((est - est_old).real, 2)) / np.sum(est) * 100
        # print(np.sum((est.real - est_old.real)* (est.real - est_old.real) )/np.sum(est.real)*100 )

        if error_old != 0:
            de = error_new / error_old * 1.0
        else:
            de = error_new

        if verbose:
            print(
                ' LR Deconvolution - Iteration: {0:d} Error: {1:.2f} = change: {2:.5f}%, {3:.5f}%'.format(i, error_new,
                                                                                                          de,
                                                                                                          abs(dest)))

        if QT_available:
            # map remaining change onto the 0..100 progress range
            count = (0.1 - abs(dest)) * 1000.
            if count < 0:
                count = 0
            progress.set_value(count)

        if i > 200:
            # hard stop after 200 iterations
            dest = 0.0
            print('terminate')
    if QT_available:
        progress.close()
    print('\n Lucy-Richardson deconvolution converged in ' + str(i) + ' Iterations')
    # final low-pass with the objective aperture to suppress amplified noise
    est2 = np.real(fftpack.ifft2(fftpack.fft2(est) * fftpack.fftshift(aperture)))
    # plt.imshow(np.real(np.log10(np.abs(fftpack.fftshift(fftpack.fft2(est)))+1)+aperture), origin='lower',)
    # plt.show()
    dset =o_image.like_data(est2.real, name='Lucy_Richardson')
    dset.data_type = 'image'
    dset.quantity = 'intensity'
    dset.source = o_image.title

    return dset
1
+ """
2
+ image_tools.py
3
+ by Gerd Duscher, UTK
4
+ part of pyTEMlib
5
+ MIT license except where stated differently
6
+ """
7
+
8
+ import numpy as np
9
+ import matplotlib
10
+ import matplotlib as mpl
11
+ import matplotlib.pylab as plt
12
+ import matplotlib.widgets as mwidgets
13
+ # from matplotlib.widgets import RectangleSelector
14
+
15
+ import sidpy
16
+ import pyTEMlib.file_tools as ft
17
+ import pyTEMlib.sidpy_tools
18
+ # import pyTEMlib.probe_tools
19
+
20
+ from tqdm.auto import trange, tqdm
21
+
22
+ # import itertools
23
+ from itertools import product
24
+
25
+ from scipy import fftpack
26
+ import scipy
27
+ # from scipy import signal
28
+ from scipy.interpolate import interp1d # , interp2d
29
+ import scipy.optimize as optimization
30
+
31
+ # Multidimensional Image library
32
+ import scipy.ndimage as ndimage
33
+ import scipy.constants as const
34
+
35
+ # from scipy.spatial import Voronoi, KDTree, cKDTree
36
+
37
+ import skimage
38
+
39
+ import skimage.registration as registration
40
+ # from skimage.feature import register_translation # blob_dog, blob_doh
41
+ from skimage.feature import peak_local_max
42
+ # from skimage.measure import points_in_poly
43
+
44
+ # our blob detectors from the scipy image package
45
+ from skimage.feature import blob_log # blob_dog, blob_doh
46
+
47
+ from sklearn.feature_extraction import image
48
+ from sklearn.utils.extmath import randomized_svd
49
+ from sklearn.cluster import DBSCAN
50
+
51
+ from collections import Counter
52
+
53
+ # center diff function
54
+ from skimage.filters import threshold_otsu, sobel
55
+ from scipy.optimize import leastsq
56
+ from sklearn.cluster import DBSCAN
57
+
58
+
59
+ _SimpleITK_present = True
60
+ try:
61
+ import SimpleITK as sitk
62
+ except ImportError:
63
+ sitk = False
64
+ _SimpleITK_present = False
65
+
66
+ if not _SimpleITK_present:
67
+ print('SimpleITK not installed; Registration Functions for Image Stacks not available\n' +
68
+ 'install with: conda install -c simpleitk simpleitk ')
69
+
70
+
71
+ # Wavelength in 1/nm
72
def get_wavelength(e0):
    """
    Calculates the relativistic corrected de Broglie wavelength of an electron

    Parameters
    ----------
    e0: float
        acceleration voltage in volt

    Returns
    -------
    float
        wavelength in nm
    """
    kinetic_energy = const.e * e0  # electron energy in joule
    # relativistic momentum p = sqrt(2 m E (1 + E / (2 m c^2)))
    momentum = np.sqrt(2 * const.m_e * kinetic_energy *
                       (1 + kinetic_energy / (2 * const.m_e * const.c ** 2)))
    return const.h / momentum * 1e9  # de Broglie wavelength, converted from m to nm
88
+
89
+
90
def fourier_transform(dset: sidpy.Dataset) -> sidpy.Dataset:
    """
    Reads information into dictionary 'tags', performs 'FFT', and provides a smoothed FT and reciprocal
    and intensity limits for visualization.

    Parameters
    ----------
    dset: sidpy.Dataset
        image

    Returns
    -------
    fft_dset: sidpy.Dataset
        Fourier transform with correct dimensions

    Example
    -------
    >>> fft_dataset = fourier_transform(sidpy_dataset)
    >>> fft_dataset.plot()
    """

    assert isinstance(dset, sidpy.Dataset), 'Expected a sidpy Dataset'

    selection = []
    image_dim = []
    # image_dim = get_image_dims(sidpy.DimensionTypes.SPATIAL)

    if dset.data_type == sidpy.DataType.IMAGE_STACK:
        # an image stack is summed along its frame (temporal) dimension first
        image_dim = dset.get_image_dims()
        stack_dim = dset.get_dimensions_by_type('TEMPORAL')

        if len(image_dim) != 2:
            raise ValueError('need at least two SPATIAL dimension for an image stack')

        # Build an index tuple: keep the two image dimensions and the stack
        # dimension, reduce every other dimension to its first entry.
        # NOTE(review): `stack_dim` is re-bound from list to int inside the
        # loop; once that happens, a later `len(stack_dim)` / `in stack_dim`
        # would raise. This only works for certain dimension orders - verify.
        for i in range(dset.ndim):
            if i in image_dim:
                selection.append(slice(None))
                if len(stack_dim) == 0:
                    stack_dim = i
                    selection.append(slice(None))
            elif i in stack_dim:
                stack_dim = i
                selection.append(slice(None))
            else:
                selection.append(slice(0, 1))

        image_stack = np.squeeze(np.array(dset)[selection])
        new_image = np.sum(np.array(image_stack), axis=stack_dim)
    elif dset.data_type == sidpy.DataType.IMAGE:
        new_image = np.array(dset)
    else:
        # unsupported data type: silently returns None
        return

    new_image = new_image - new_image.min()
    fft_transform = (np.fft.fftshift(np.fft.fft2(new_image)))

    image_dims = pyTEMlib.sidpy_tools.get_image_dims(dset)

    units_x = '1/' + dset._axes[image_dims[0]].units
    units_y = '1/' + dset._axes[image_dims[1]].units

    fft_dset = sidpy.Dataset.from_array(fft_transform)
    fft_dset.quantity = dset.quantity
    fft_dset.units = 'a.u.'
    fft_dset.data_type = 'IMAGE'
    fft_dset.source = dset.title
    fft_dset.modality = 'fft'

    # reciprocal-space axes: fft-shifted frequencies from the real-space sampling
    fft_dset.set_dimension(0, sidpy.Dimension(np.fft.fftshift(np.fft.fftfreq(new_image.shape[0],
                                                                             d=ft.get_slope(dset.x.values))),
                                              name='u', units=units_x, dimension_type='RECIPROCAL',
                                              quantity='reciprocal_length'))
    fft_dset.set_dimension(1, sidpy.Dimension(np.fft.fftshift(np.fft.fftfreq(new_image.shape[1],
                                                                             d=ft.get_slope(dset.y.values))),
                                              name='v', units=units_y, dimension_type='RECIPROCAL',
                                              quantity='reciprocal_length'))

    return fft_dset
169
+
170
+
171
def power_spectrum(dset, smoothing=3):
    """
    Calculate power spectrum

    Parameters
    ----------
    dset: sidpy.Dataset
        image
    smoothing: int
        Gaussian smoothing

    Returns
    -------
    power_spec: sidpy.Dataset
        power spectrum with correct dimensions; metadata['fft'] carries
        display intensity limits derived from an annulus around the center
    """

    fft_transform = fourier_transform(dset)  # dset.fft()
    fft_mag = np.abs(fft_transform)
    # Gaussian-smooth the magnitude so single-pixel noise does not dominate
    fft_mag2 = ndimage.gaussian_filter(fft_mag, sigma=(smoothing, smoothing), order=0)

    # logarithmic intensity scale for display
    power_spec = fft_transform.like_data(np.log(1.+fft_mag2))

    # prepare mask
    x, y = np.meshgrid(power_spec.v.values, power_spec.u.values)
    mask = np.zeros(power_spec.shape)

    # annulus between radius 1 and 11 in reciprocal units (presumably 1/nm -
    # TODO confirm); used to derive display limits away from the central beam
    mask_spot = x ** 2 + y ** 2 > 1 ** 2
    mask = mask + mask_spot
    mask_spot = x ** 2 + y ** 2 < 11 ** 2
    mask = mask + mask_spot

    mask[np.where(mask == 1)] = 0  # just in case of overlapping disks

    # 5% margin around min/max intensity inside the annulus (mask == 2)
    minimum_intensity = np.array(power_spec)[np.where(mask == 2)].min() * 0.95
    maximum_intensity = np.array(power_spec)[np.where(mask == 2)].max() * 1.05
    power_spec.metadata = {'fft': {'smoothing': smoothing,
                                   'minimum_intensity': minimum_intensity, 'maximum_intensity': maximum_intensity}}
    power_spec.title = 'power spectrum ' + power_spec.source

    return power_spec
213
+
214
+
215
def diffractogram_spots(dset, spot_threshold, return_center=True, eps=0.1):
    """Find spots in diffractogram and sort them by distance from center

    Uses blob_log from scikit-image; falls back to peak_local_max if blob
    detection fails on the power spectrum.

    Parameters
    ----------
    dset: sidpy.Dataset
        diffractogram
    spot_threshold: float
        threshold for blob finder
    return_center: bool, optional
        also estimate the pattern center if True
    eps: float, optional
        DBSCAN eps used when clustering spot-pair midpoints for the center

    Returns
    -------
    spots: numpy array (N, 3)
        position (x, y) and angle (rad) of all spots, sorted by radius
    center: numpy array (2,) or [0, 0]
        estimated pattern center (only meaningful if return_center is True)
    """

    # spot detection (for future reference there is no symmetry assumed here)
    data = np.array(np.log(1+np.abs(dset)))
    data = data - data.min()
    data = data/data.max()
    # some images are strange and blob_log does not work on the power spectrum
    try:
        spots_random = blob_log(data, max_sigma=5, threshold=spot_threshold)
    except ValueError:
        spots_random = peak_local_max(np.array(data.T), min_distance=3, threshold_rel=spot_threshold)
        # BUG FIX: np.hstack takes a single tuple of arrays; the previous call
        # np.hstack(a, b) raised a TypeError. Pad a zero column so the
        # fallback has the same (N, 3) shape as the blob_log result.
        spots_random = np.hstack((spots_random, np.zeros((spots_random.shape[0], 1))))

    print(f'Found {spots_random.shape[0]} reflections')

    # Needed for conversion from pixel to Reciprocal space
    rec_scale = np.array([ft.get_slope(dset.u.values), ft.get_slope(dset.v.values)])
    spots_random[:, :2] = spots_random[:, :2]*rec_scale+[dset.u.values[0], dset.v.values[0]]
    # sort reflections by distance from the origin
    spots_random[:, 2] = np.linalg.norm(spots_random[:, 0:2], axis=1)
    spots_index = np.argsort(spots_random[:, 2])
    spots = spots_random[spots_index]
    # third column now holds the azimuthal angle of each spot
    spots[:, 2] = np.arctan2(spots[:, 0], spots[:, 1])

    center = [0, 0]

    if return_center:
        points = spots[:, 0:2]

        # Calculate the midpoints between all pairs of points
        reshaped_points = points[:, np.newaxis, :]
        midpoints = (reshaped_points + reshaped_points.transpose(1, 0, 2)) / 2.0
        midpoints = midpoints.reshape(-1, 2)

        # Bragg spots come in centro-symmetric pairs, so pair midpoints pile
        # up at the pattern center: take the densest midpoint cluster.
        dbscan = DBSCAN(eps=eps, min_samples=2)
        labels = dbscan.fit_predict(midpoints)
        cluster_counter = Counter(labels)
        largest_cluster_label = max(cluster_counter, key=cluster_counter.get)
        largest_cluster_points = midpoints[labels == largest_cluster_label]

        # Average of these midpoints must be the center
        center = np.mean(largest_cluster_points, axis=0)

    return spots, center
281
+
282
+
283
def center_diffractogram(dset, return_plot = True, histogram_factor = None, smoothing = 1, min_samples = 100):
    """Find the center of a ring diffraction pattern.

    Otsu-thresholds the diffractogram, smooths and re-thresholds, detects
    ring edges with a Sobel filter, clusters the edge pixels with DBSCAN,
    and least-squares fits a circle to the largest cluster.

    Parameters
    ----------
    dset: sidpy.Dataset
        diffraction pattern
    return_plot: bool
        if True, plot the intermediate steps and the fitted circle
    histogram_factor: float or None
        if given, scale the image histogram before Otsu thresholding
    smoothing: float
        Gaussian sigma used to smooth the binary image
    min_samples: int
        DBSCAN min_samples for the edge-point clustering

    Returns
    -------
    center: numpy array (2,)
        fitted center in pixel coordinates (x, y)
    """
    # NOTE(review): this try has no except clause - if anything inside raises
    # before `center` is bound, the plotting code in `finally` and the final
    # return raise NameError and mask the original exception; verify intent.
    try:
        diff = np.array(dset).T.astype(np.float16)  # NOTE(review): float16 limits precision - confirm adequate
        diff[diff < 0] = 0

        if histogram_factor is not None:
            # a rescaled histogram biases the Otsu threshold
            hist, bins = np.histogram(np.ravel(diff), bins=256, range=(0, 1), density=True)
            threshold = threshold_otsu(diff, hist = hist * histogram_factor)
        else:
            threshold = threshold_otsu(diff)
        binary = (diff > threshold).astype(float)
        smoothed_image = ndimage.gaussian_filter(binary, sigma=smoothing)  # Smooth before edge detection
        smooth_threshold = threshold_otsu(smoothed_image)
        smooth_binary = (smoothed_image > smooth_threshold).astype(float)
        # Find the edges using the Sobel operator
        edges = sobel(smooth_binary)
        edge_points = np.argwhere(edges)

        # Use DBSCAN to cluster the edge points
        db = DBSCAN(eps=10, min_samples=min_samples).fit(edge_points)
        labels = db.labels_
        if len(set(labels)) == 1:
            raise ValueError("DBSCAN clustering resulted in only one group, check the parameters.")

        # Get the largest group of edge points
        unique, counts = np.unique(labels, return_counts=True)
        counts = dict(zip(unique, counts))
        largest_group = max(counts, key=counts.get)
        edge_points = edge_points[labels == largest_group]

        # Fit a circle to the diffraction ring
        def calc_distance(c, x, y):
            # residual: radial distance of each edge point from the mean radius
            Ri = np.sqrt((x - c[0])**2 + (y - c[1])**2)
            return Ri - Ri.mean()
        x_m = np.mean(edge_points[:, 1])
        y_m = np.mean(edge_points[:, 0])
        center_guess = x_m, y_m
        center, ier = leastsq(calc_distance, center_guess, args=(edge_points[:, 1], edge_points[:, 0]))
        # mean residual is ~0 after the fit, so this is effectively the mean radius
        mean_radius = np.mean(calc_distance(center, edge_points[:, 1], edge_points[:, 0])) + np.sqrt((edge_points[:, 1] - center[0])**2 + (edge_points[:, 0] - center[1])**2).mean()

    finally:
        if return_plot:
            # diagnostic figure: input, both binaries, edges + fitted circle
            fig, ax = plt.subplots(1, 4, figsize=(10, 4))
            ax[0].set_title('Diffractogram')
            ax[0].imshow(dset.T, cmap='viridis')
            ax[1].set_title('Otsu Binary Image')
            ax[1].imshow(binary, cmap='gray')
            ax[2].set_title('Smoothed Binary Image')
            ax[2].imshow(smooth_binary, cmap='gray')
            ax[3].set_title('Edge Detection and Fitting')
            ax[3].imshow(edges, cmap='gray')
            ax[3].scatter(center[0], center[1], c='r', s=10)
            circle = plt.Circle(center, mean_radius, color='red', fill=False)
            ax[3].add_artist(circle)
            for axis in ax:
                axis.axis('off')
            fig.tight_layout()

    return center
342
+
343
+
344
def adaptive_fourier_filter(dset, spots, low_pass=3, reflection_radius=0.3):
    """
    Use spots in diffractogram for a Fourier Filter

    Parameters:
    -----------
    dset: sidpy.Dataset
        image to be filtered
    spots: np.ndarray(N,2)
        sorted spots in diffractogram in 1/nm
    low_pass: float
        low pass filter in center of diffractogram in 1/nm
    reflection_radius: float
        radius of masked reflections in 1/nm

    Output:
    -------
    Fourier filtered image as sidpy.Dataset
    """

    if not isinstance(dset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    fft_transform = fourier_transform(dset)

    # prepare mask
    x, y = np.meshgrid(fft_transform.v.values, fft_transform.u.values)
    mask = np.zeros(dset.shape)

    # mask reflections: pass a disk around every spot position
    for spot in spots:
        mask_spot = (x - spot[1]) ** 2 + (y - spot[0]) ** 2 < reflection_radius ** 2  # make a spot
        mask = mask + mask_spot  # add spot to mask

    # mask zero region larger (low-pass filter = intensity variations)
    mask_spot = x ** 2 + y ** 2 < low_pass ** 2
    mask = mask + mask_spot
    mask[np.where(mask > 1)] = 1  # clip overlapping disks to a binary mask
    fft_filtered = np.array(fft_transform * mask)

    # back to real space; the imaginary part is numerical noise only
    filtered_image = dset.like_data(np.fft.ifft2(np.fft.fftshift(fft_filtered)).real)
    filtered_image.title = 'Fourier filtered ' + dset.title
    filtered_image.source = dset.title
    filtered_image.metadata = {'analysis': 'adaptive fourier filtered', 'spots': spots,
                               'low_pass': low_pass, 'reflection_radius': reflection_radius}
    return filtered_image
389
+
390
+
391
def rotational_symmetry_diffractogram(spots):
    """Test which rotational symmetries (2-, 3-, 4-, 6-fold) the spot pattern obeys.

    Parameters
    ----------
    spots: numpy array (N, 3)
        spot positions (third column is carried through the rotation)

    Returns
    -------
    list of int: the symmetry orders the pattern is compatible with
    """
    found_orders = []
    for order in [2, 3, 4, 6]:
        angle = 2 * np.pi / order
        rotation = np.array([[np.cos(angle), np.sin(angle), 0],
                             [-np.sin(angle), np.cos(angle), 0],
                             [0, 0, 1]])
        rotated = np.dot(spots, rotation)
        # all pairwise distances between rotated and original spot positions
        distances = np.array(sorted(np.linalg.norm(p0 - p1)
                                    for p0, p1 in product(rotated[:, 0:2], spots[:, 0:2])))
        # symmetric if ~70% of the spots find a close partner after rotation
        if distances[int(spots.shape[0] * .7)] < 0.2:
            found_orders.append(order)
    return found_orders
408
+
409
+ #####################################################
410
+ # Registration Functions
411
+ #####################################################
412
+
413
+
414
def complete_registration(main_dataset, storage_channel=None):
    """Rigid and then non-rigid (demon) registration

    Convenience wrapper that chains the two registration steps; see
    `rigid_registration` and `demon_registration` for the details of each.

    Parameters
    ----------
    main_dataset: sidpy.Dataset
        dataset of data_type 'IMAGE_STACK' to be registered
    storage_channel: h5py.Group
        optional - location in hdf5 file to store datasets
        (accepted for API compatibility; not used in this function)

    Returns
    -------
    non_rigid_registered: sidpy.Dataset
    rigid_registered_dataset: sidpy.Dataset
    """
    if not isinstance(main_dataset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    if main_dataset.data_type.name != 'IMAGE_STACK':
        raise TypeError('Registration makes only sense for an image stack')

    # step 1: remove frame-to-frame translation
    print('Rigid_Registration')
    rigid_registered_dataset = rigid_registration(main_dataset)

    # step 2: correct local distortions with diffeomorphic demons
    print('Non-Rigid_Registration')
    non_rigid_registered = demon_registration(rigid_registered_dataset)

    return non_rigid_registered, rigid_registered_dataset
449
+
450
+
451
def demon_registration(dataset, verbose=False):
    """
    Diffeomorphic Demon Non-Rigid Registration

    Depends on:
        simpleITK and numpy
    Please Cite: http://www.simpleitk.org/SimpleITK/project/parti.html
    and T. Vercauteren, X. Pennec, A. Perchant and N. Ayache
    Diffeomorphic Demons Using ITK\'s Finite Difference Solver Hierarchy
    The Insight Journal, http://hdl.handle.net/1926/510 2007

    Parameters
    ----------
    dataset: sidpy.Dataset
        stack of image after rigid registration and cropping
    verbose: boolean
        optional for increased output
    Returns
    -------
    dem_reg: stack of images with non-rigid registration

    Example
    -------
    dem_reg = demon_reg(stack_dataset, verbose=False)
    """

    if not isinstance(dataset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    if dataset.data_type.name != 'IMAGE_STACK':
        raise TypeError('Registration makes only sense for an image stack')

    dem_reg = np.zeros(dataset.shape)
    nimages = dataset.shape[0]
    if verbose:
        print(nimages)
    # create fixed image by summing over rigid registration

    fixed_np = np.average(np.array(dataset), axis=0)

    if not _SimpleITK_present:
        # NOTE(review): this only warns - the sitk calls below will still fail
        # when SimpleITK is missing; consider returning early here.
        print('This feature is not available: \n Please install simpleITK with: conda install simpleitk -c simpleitk')

    fixed = sitk.GetImageFromArray(fixed_np)
    fixed = sitk.DiscreteGaussian(fixed, 2.0)  # smooth the reference to stabilize the demons force

    # demons = sitk.SymmetricForcesDemonsRegistrationFilter()
    demons = sitk.DiffeomorphicDemonsRegistrationFilter()

    demons.SetNumberOfIterations(200)
    demons.SetStandardDeviations(1.0)

    # resampler applies the computed displacement field to each moving image
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(0)

    for i in trange(nimages):

        moving = sitk.GetImageFromArray(dataset[i])
        moving_f = sitk.DiscreteGaussian(moving, 2.0)
        # displacement field computed on smoothed images, applied to the original
        displacement_field = demons.Execute(fixed, moving_f)
        out_tx = sitk.DisplacementFieldTransform(displacement_field)
        resampler.SetTransform(out_tx)
        out = resampler.Execute(moving)
        dem_reg[i, :, :] = sitk.GetArrayFromImage(out)

    print(':-)')
    print('You have successfully completed Diffeomorphic Demons Registration')

    demon_registered = dataset.like_data(dem_reg)
    demon_registered.title = 'Non-Rigid Registration'
    demon_registered.source = dataset.title

    demon_registered.metadata = {'analysis': 'non-rigid demon registration'}
    # carry the rigid-registration crop bookkeeping forward
    if 'input_crop' in dataset.metadata:
        demon_registered.metadata['input_crop'] = dataset.metadata['input_crop']
    if 'input_shape' in dataset.metadata:
        demon_registered.metadata['input_shape'] = dataset.metadata['input_shape']
    demon_registered.metadata['input_dataset'] = dataset.source
    return demon_registered
531
+
532
+
533
+ ###############################
534
+ # Rigid Registration New 05/09/2020
535
+
536
def rigid_registration(dataset, sub_pixel=True):
    """
    Rigid registration of image stack with pixel accuracy

    Uses simple cross_correlation
    (we determine drift from one image to next)

    Parameters
    ----------
    dataset: sidpy.Dataset
        sidpy dataset with image_stack dataset
    sub_pixel: bool
        if True, use phase cross-correlation with upsample factor 1000;
        otherwise use the integer-pixel cross-correlation maximum

    Returns
    -------
    rigid_registered: sidpy.Dataset
        Registered Stack and drift (with respect to center image)
    """

    if not isinstance(dataset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    if dataset.data_type.name != 'IMAGE_STACK':
        raise TypeError('Registration makes only sense for an image stack')

    frame_dim = []
    spatial_dim = []
    selection = []

    # classify dimensions: spatial dims are kept, every other dim counts as frame
    for i, axis in dataset._axes.items():
        if axis.dimension_type.name == 'SPATIAL':
            spatial_dim.append(i)
            selection.append(slice(None))
        else:
            frame_dim.append(i)
            selection.append(slice(0, 1))

    if len(spatial_dim) != 2:
        print('need two spatial dimensions')
    if len(frame_dim) != 1:
        print('need one frame dimensions')

    nopix = dataset.shape[spatial_dim[0]]
    nopiy = dataset.shape[spatial_dim[1]]
    nimages = dataset.shape[frame_dim[0]]

    print('Stack contains ', nimages, ' images, each with', nopix, ' pixels in x-direction and ', nopiy,
          ' pixels in y-direction')

    # first frame is the initial reference; drift is tracked frame-to-frame
    fixed = dataset[tuple(selection)].squeeze().compute()
    fft_fixed = np.fft.fft2(fixed)

    relative_drift = [[0., 0.]]

    for i in trange(nimages):
        selection[frame_dim[0]] = slice(i, i+1)
        moving = dataset[tuple(selection)].squeeze().compute()
        fft_moving = np.fft.fft2(moving)
        if sub_pixel:
            shift = skimage.registration.phase_cross_correlation(fft_fixed, fft_moving, upsample_factor=1000,
                                                                 space='fourier')[0]
        else:
            # integer-pixel shift from the cross-correlation maximum
            image_product = fft_fixed * fft_moving.conj()
            cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
            shift = np.array(ndimage.maximum_position(cc_image.real))-cc_image.shape[0]/2
        fft_fixed = fft_moving  # next frame is compared against this one
        relative_drift.append(shift)
    rig_reg, drift = rig_reg_drift(dataset, relative_drift)
    # crop to the common field of view of all shifted frames
    crop_reg, input_crop = crop_image_stack(rig_reg, drift)

    rigid_registered = sidpy.Dataset.from_array(crop_reg,
                                                title='Rigid Registration',
                                                data_type='IMAGE_STACK',
                                                quantity=dataset.quantity,
                                                units=dataset.units)
    rigid_registered.title = 'Rigid_Registration'
    rigid_registered.source = dataset.title
    rigid_registered.metadata = {'analysis': 'rigid sub-pixel registration', 'drift': drift,
                                 'input_crop': input_crop, 'input_shape': dataset.shape[1:]}
    rigid_registered.set_dimension(0, sidpy.Dimension(np.arange(rigid_registered.shape[0]),
                                                      name='frame', units='frame', quantity='time',
                                                      dimension_type='temporal'))

    # spatial axes cropped to the registered region
    # NOTE(review): units are hard-coded to 'nm' here, ignoring the input
    # dataset's axis units - verify against callers.
    array_x = dataset._axes[spatial_dim[0]][input_crop[0]:input_crop[1]].values
    rigid_registered.set_dimension(1, sidpy.Dimension(array_x,
                                                      'x', units='nm', quantity='Length',
                                                      dimension_type='spatial'))
    array_y = dataset._axes[spatial_dim[1]][input_crop[2]:input_crop[3]].values
    rigid_registered.set_dimension(2, sidpy.Dimension(array_y,
                                                      'y', units='nm', quantity='Length',
                                                      dimension_type='spatial'))
    return rigid_registered.rechunk({0: 'auto', 1: -1, 2: -1})
626
+
627
+
628
def rig_reg_drift(dset, rel_drift):
    """ Shifting images on top of each other

    Uses relative drift to shift images on top of each other,
    with center image as reference.
    Shifting is done with shift routine of ndimage from scipy.
    This function is used by rigid_registration routine

    Parameters
    ----------
    dset: sidpy.Dataset
        dataset with image_stack
    rel_drift:
        relative_drift from image to image as list of [shiftx, shifty]

    Returns
    -------
    stack: numpy array
    drift: list of drift in pixel
    """

    frame_dim = []
    spatial_dim = []
    selection = []

    # classify dimensions: spatial dims are kept, every other dim counts as frame
    for i, axis in dset._axes.items():
        if axis.dimension_type.name == 'SPATIAL':
            spatial_dim.append(i)
            selection.append(slice(None))
        else:
            frame_dim.append(i)
            selection.append(slice(0, 1))

    if len(spatial_dim) != 2:
        print('need two spatial dimensions')
    if len(frame_dim) != 1:
        print('need one frame dimensions')

    rig_reg = np.zeros([dset.shape[frame_dim[0]], dset.shape[spatial_dim[0]], dset.shape[spatial_dim[1]]])

    # absolute drift
    print(rel_drift)  # NOTE(review): debug output - candidate for removal
    drift = np.array(rel_drift).copy()

    # integrate relative (frame-to-frame) drift into absolute drift
    drift[0] = [0, 0]
    for i in range(1, drift.shape[0]):
        drift[i] = drift[i - 1] + rel_drift[i]
    # reference everything to the center frame
    center_drift = drift[int(drift.shape[0] / 2)]
    drift = drift - center_drift
    # Shift images
    for i in range(rig_reg.shape[0]):
        selection[frame_dim[0]] = slice(i, i+1)
        # Now we shift
        rig_reg[i, :, :] = ndimage.shift(dset[tuple(selection)].squeeze().compute(),
                                         [drift[i, 0], drift[i, 1]], order=3)
    return rig_reg, drift
684
+
685
+
686
def crop_image_stack(rig_reg, drift):
    """Crop images in stack according to drift

    This function is used by rigid_registration routine

    Parameters
    ----------
    rig_reg: numpy array (N, x, y)
        shifted image stack
    drift: list/array (N, 2)
        absolute drift per frame in pixel

    Returns
    -------
    numpy array: cropped stack
    list: crop bounds [x_min, x_max, y_min, y_max]
    """
    shifts = np.array(drift)

    # the largest shift in each direction determines how much must be trimmed
    x_min = int(-np.floor(shifts[:, 0].min()))
    x_max = int(rig_reg.shape[1] - np.ceil(shifts[:, 0].max()))
    y_min = int(-np.floor(shifts[:, 1].min()))
    y_max = int(rig_reg.shape[2] - np.ceil(shifts[:, 1].max()))

    crop = [x_min, x_max, y_min, y_max]
    return rig_reg[:, x_min:x_max, y_min:y_max], crop
707
+
708
+
709
class ImageWithLineProfile:
    """Interactive image with a click-defined line profile.

    Clicking into the image sets the end point of a line (the previous end
    point becomes the start point); the intensity profile along that line
    is then shown in a second panel below the image.
    """

    def __init__(self, data, extent, title=''):
        """
        Parameters
        ----------
        data: 2D numpy array
            image data
        extent: sequence
            matplotlib imshow extent [left, right, bottom, top]
        title: str
            plot title
        """
        fig, ax = plt.subplots(1, 1)
        self.figure = fig
        self.title = title
        self.line_plot = False  # becomes the profile Line2D after first update()
        self.ax = ax
        self.data = data
        self.extent = extent
        self.ax.imshow(data, extent=extent)
        self.ax.set_title(title)
        self.line, = self.ax.plot([0], [0], color='orange')  # empty line
        self.end_x = self.line.get_xdata()
        self.end_y = self.line.get_ydata()
        self.cid = self.line.figure.canvas.mpl_connect('button_press_event', self)

    def __call__(self, event):
        """Mouse-click callback: move the profile line and refresh the plot."""
        if event.inaxes != self.line.axes:
            return
        # previous end point becomes the new start point
        self.start_x = self.end_x
        self.start_y = self.end_y

        self.line.set_data([self.start_x, event.xdata], [self.start_y, event.ydata])
        self.line.figure.canvas.draw()

        self.end_x = event.xdata
        self.end_y = event.ydata

        self.update()

    def update(self):
        """Redraw the intensity profile for the current start/end points."""
        if not self.line_plot:
            # first call: switch to a two-panel layout (image + profile)
            self.line_plot = True
            self.figure.clear()
            self.ax = self.figure.subplots(2, 1)
            self.ax[0].imshow(self.data, extent=self.extent)
            self.ax[0].set_title(self.title)

            self.line, = self.ax[0].plot([0], [0], color='orange')  # empty line
            self.line_plot, = self.ax[1].plot([], [], color='orange')
            self.ax[1].set_xlabel('distance [nm]')

        x0 = self.start_x
        x1 = self.end_x
        y0 = self.start_y
        y1 = self.end_y
        length_plot = np.sqrt((x1-x0)**2+(y1-y0)**2)

        # BUG FIX: np.linspace requires an integer sample count; the old code
        # passed a float (TypeError on modern numpy). Guarantee >= 2 samples.
        num = int(length_plot*(self.data.shape[0]/self.extent[1]))
        num = max(num, 2)
        x = np.linspace(x0, x1, num)*(self.data.shape[0]/self.extent[1])
        y = np.linspace(y0, y1, num)*(self.data.shape[0]/self.extent[1])

        # Extract the values along the line, using cubic interpolation
        zi2 = ndimage.map_coordinates(self.data.T, np.vstack((x, y)))

        x_axis = np.linspace(0, length_plot, len(zi2))
        self.x = x_axis
        self.z = zi2

        self.line_plot.set_xdata(x_axis)
        self.line_plot.set_ydata(zi2)
        self.ax[1].set_xlim(0, x_axis.max())
        self.ax[1].set_ylim(zi2.min(), zi2.max())
        # BUG FIX: Axes.draw() needs a renderer argument; redraw via the
        # figure canvas instead (the old call raised TypeError).
        self.figure.canvas.draw_idle()
775
+
776
+
777
class LineSelector(matplotlib.widgets.PolygonSelector):
    """Interactive selection of a line of finite width on a matplotlib axes.

    Wraps ``matplotlib.widgets.PolygonSelector`` with four vertices kept as a
    parallelogram: two line end points (indices 0 and 3) plus two points
    (indices 1 and 2) offset by ``line_width``.  Moving any handle snaps the
    shape back to a line of constant width.
    """

    def __init__(self, ax, onselect, line_width=1, **kwargs):
        super().__init__(ax, onselect, **kwargs)
        # place the initial vertices relative to the current view limits
        bounds = ax.viewLim.get_points()
        np.max(bounds[0])  # NOTE(review): result unused — looks like leftover debug code
        self.line_verts = np.array([[np.max(bounds[1])/2, np.max(bounds[0])/5], [np.max(bounds[1])/2,
                                    np.max(bounds[0])/5+1],
                                    [np.max(bounds[1])/5, np.max(bounds[0])/2], [np.max(bounds[1])/5,
                                    np.max(bounds[0])/2]])
        self.verts = self.line_verts
        self.line_width = line_width

    def set_linewidth(self, line_width=None):
        """Recompute vertices 1 and 2 so they sit ``line_width`` away from
        vertices 0 and 3, perpendicular to the line direction."""
        if line_width is not None:
            self.line_width = line_width

        # slope of the line through vertices 0 and 3 (sign flipped for image coords)
        m = -(self.line_verts[0, 1]-self.line_verts[3, 1])/(self.line_verts[0, 0]-self.line_verts[3, 0])
        c = 1/np.sqrt(1+m**2)  # cosine of the line angle
        s = c*m                # sine of the line angle
        self.line_verts[1] = [self.line_verts[0, 0]+self.line_width*s, self.line_verts[0, 1]+self.line_width*c]
        self.line_verts[2] = [self.line_verts[3, 0]+self.line_width*s, self.line_verts[3, 1]+self.line_width*c]

        self.verts = self.line_verts.copy()

    def onmove(self, event):
        """After the base class handled the move, detect which vertex was
        dragged (displacement > 1, in data units) and restore the
        constant-width line shape."""
        super().onmove(event)
        if np.max(np.linalg.norm(self.line_verts-self.verts, axis=1)) > 1:
            self.moved_point = np.argmax(np.linalg.norm(self.line_verts-self.verts, axis=1))

            self.new_point = self.verts[self.moved_point]
            # map the dragged handle onto the corresponding end point (0 or 3)
            moved_point = int(np.floor(self.moved_point/2)*3)
            self.moved_point = moved_point
            self.line_verts[moved_point] = self.new_point
            self.set_linewidth()
811
+
812
def get_profile(dataset, line, spline_order=-1):
    """
    This function extracts a line profile from a given dataset. The line profile is a representation of the data values
    along a specified line in the dataset. This function works for both image and spectral image data types.

    Args:
        dataset (sidpy.Dataset): The input dataset from which to extract the line profile.
        line (list): A list specifying the line along which the profile should be extracted.
        spline_order (int, optional): The order of the spline interpolation to use. Default is -1, which means no interpolation.

    Returns:
        profile_dataset (sidpy.Dataset): A new sidpy.Dataset containing the line profile.
    """
    xv, yv = get_line_selection_points(line)
    if dataset.data_type.name == 'IMAGE':
        dataset.get_image_dims()
        # convert physical coordinates to fractional pixel indices
        xv /= (dataset.x[1] - dataset.x[0])
        yv /= (dataset.y[1] - dataset.y[0])
        profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv, yv])

        profile_dataset = sidpy.Dataset.from_array(profile.sum(axis=0))
        profile_dataset.data_type = 'spectrum'
        profile_dataset.units = dataset.units
        profile_dataset.quantity = dataset.quantity
        profile_dataset.set_dimension(0, sidpy.Dimension(np.linspace(xv[0, 0], xv[-1, -1], profile_dataset.shape[0]),
                                                         name='x', units=dataset.x.units,
                                                         quantity=dataset.x.quantity,
                                                         dimension_type='spatial'))
        # bug fix: the image branch built the profile but never returned it
        return profile_dataset

    if dataset.data_type.name == 'SPECTRAL_IMAGE':
        spectral_axis = dataset.get_spectral_dims(return_axis=True)[0]
        if spline_order > -1:
            # interpolated sampling of the full (x, y, spectrum) cube
            xv, yv, zv = get_line_selection_points_interpolated(line, z_length=dataset.shape[2])
            profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv, yv, zv], order=spline_order)
            profile = profile.sum(axis=0)
            # duplicate row so the result keeps a (length, 2, spectrum) image shape
            profile = np.stack([profile, profile], axis=1)
            start = xv[0, 0, 0]
        else:
            # nearest-pixel summation without interpolation
            profile = get_line_profile(np.array(dataset), xv, yv, len(spectral_axis))
            start = xv[0, 0]
        profile_dataset = sidpy.Dataset.from_array(profile)
        profile_dataset.data_type = 'spectral_image'
        profile_dataset.units = dataset.units
        profile_dataset.quantity = dataset.quantity
        profile_dataset.set_dimension(0, sidpy.Dimension(np.arange(profile_dataset.shape[0])+start,
                                                         name='x', units=dataset.x.units,
                                                         quantity=dataset.x.quantity,
                                                         dimension_type='spatial'))
        profile_dataset.set_dimension(1, sidpy.Dimension([0, 1],
                                                         name='y', units=dataset.x.units,
                                                         quantity=dataset.x.quantity,
                                                         dimension_type='spatial'))
        profile_dataset.set_dimension(2, spectral_axis)
        return profile_dataset
869
+
870
+
871
+
872
def get_line_selection_points_interpolated(line, z_length=1):
    """Return interpolation coordinates covering a LineSelector selection.

    Parameters
    ----------
    line: LineSelector
        selector providing ``line_verts`` (4x2 corner points) and ``line_width``
    z_length: int
        length of a spectral dimension; values > 1 add a third coordinate grid

    Returns
    -------
    xv, yv (and zv when z_length > 1): numpy arrays
        fractional pixel coordinates suitable for map_coordinates
    """
    p_start = line.line_verts[3]
    p_end = line.line_verts[0]
    p_low = line.line_verts[2]
    # always walk the line in the direction of increasing x
    if p_start[0] > p_end[0]:
        p_start, p_end, p_low = line.line_verts[0], line.line_verts[3], line.line_verts[1]
    slope = (p_end[1] - p_start[1]) / (p_end[0] - p_start[0])
    extent_x = int(abs(p_start[0] - p_end[0]))
    n_samples = int(np.linalg.norm(p_start - p_end))

    width = int(abs(p_start[1] - p_low[1]))
    along = np.linspace(0, extent_x, n_samples)
    across = np.linspace(0, width, line.line_width)
    if z_length > 1:
        xv, yv, zv = np.meshgrid(along, across, np.arange(z_length))
        along = np.atleast_2d(along).repeat(z_length, axis=0).T
        across = np.atleast_2d(across).repeat(z_length, axis=0).T
    else:
        xv, yv = np.meshgrid(along, across)

    # shear the rectangular grid so it follows the selected line
    yv = yv + along * slope + p_start[1]
    xv = (xv.swapaxes(0, 1) - across * slope).swapaxes(0, 1) + p_start[0]

    if z_length > 1:
        return xv, yv, zv
    return xv, yv
905
+
906
+
907
def get_line_selection_points(line):
    """Return x and y sampling coordinates covering a LineSelector selection.

    Parameters
    ----------
    line: LineSelector
        selector providing ``line_verts`` (4x2 corner points) and ``line_width``

    Returns
    -------
    xx, yy: numpy arrays
        coordinate grids along the selected line area
    """
    p_start = line.line_verts[3]
    p_end = line.line_verts[0]
    p_low = line.line_verts[2]
    # always walk the line in the direction of increasing x
    if p_start[0] > p_end[0]:
        p_start, p_end, p_low = line.line_verts[0], line.line_verts[3], line.line_verts[1]
    slope = (p_end[1] - p_start[1]) / (p_end[0] - p_start[0])
    extent_x = int(abs(p_start[0] - p_end[0]))
    n_samples = int(np.linalg.norm(p_start - p_end))

    width = int(abs(p_start[1] - p_low[1]))
    along = np.linspace(0, extent_x, n_samples)
    across = np.linspace(0, width, line.line_width)
    xv, yv = np.meshgrid(along, across)

    # shear the rectangular grid so it follows the selected line
    yy = yv + along * slope + p_start[1]
    xx = (xv.T - across * slope).T + p_start[0]

    return xx, yy
930
+
931
+
932
def get_line_profile(data, xv, yv, z_length):
    """Sum spectra of a (x, y, spectrum) data cube along a sampling grid.

    Coordinates are truncated to integer pixel indices; samples at index 0 or
    outside the field of view are skipped.

    Parameters
    ----------
    data: numpy array (x, y, z_length)
    xv, yv: numpy arrays
        sampling coordinate grids (width index first, position index second)
    z_length: int
        length of the spectral dimension

    Returns
    -------
    numpy array (positions, 2, z_length)
        summed spectra; second axis is a dummy spatial dimension
    """
    n_points = xv.shape[1]
    profile = np.zeros([n_points, 2, z_length])
    for col in range(n_points):
        for row in range(xv.shape[0]):
            ix = int(xv[row, col])
            iy = int(yv[row, col])
            if 0 < ix < data.shape[0] and 0 < iy < data.shape[1]:
                profile[col, 0] += data[ix, iy]
    return profile
941
+
942
+
943
def histogram_plot(image_tags):
    """Interactive histogram of the image in ``image_tags['plotimage']``.

    Draws a histogram with a colorbar underneath.  Dragging a horizontal span
    selects the displayed intensity range (written back into
    ``image_tags['minimum_intensity']`` / ``'maximum_intensity'``).  Clicking
    on the colorbar resets the range to the full data range; a right click
    additionally cycles through the available color maps.

    Parameters
    ----------
    image_tags: dict
        needs key 'plotimage'; 'minimum_intensity', 'maximum_intensity' and
        'color_map' are read if present and updated by the interaction

    Returns
    -------
    matplotlib SpanSelector
        the caller must keep a reference alive for the interaction to work
    """
    nbins = 75
    color_map_list = ['gray', 'viridis', 'jet', 'hot']
    if 'minimum_intensity' not in image_tags:
        image_tags['minimum_intensity'] = image_tags['plotimage'].min()
    minimum_intensity = image_tags['minimum_intensity']  # NOTE(review): unused — value is re-read below
    if 'maximum_intensity' not in image_tags:
        image_tags['maximum_intensity'] = image_tags['plotimage'].max()
    data = image_tags['plotimage']
    vmin = image_tags['minimum_intensity']
    vmax = image_tags['maximum_intensity']
    if 'color_map' not in image_tags:
        image_tags['color_map'] = color_map_list[0]

    cmap = plt.cm.get_cmap(image_tags['color_map'])
    colors = cmap(np.linspace(0., 1., nbins))
    norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    hist, bin_edges = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)

    width = bin_edges[1]-bin_edges[0]

    def onselect(vmin, vmax):
        # redraw the histogram bars and the colorbar for the new intensity range
        ax1.clear()
        cmap = plt.cm.get_cmap(image_tags['color_map'])
        colors = cmap(np.linspace(0., 1., nbins))
        norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        hist2, bin_edges2 = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)

        width2 = (bin_edges2[1]-bin_edges2[0])

        # move/recolor the existing bar patches instead of re-plotting
        for i in range(nbins-1):
            histogram[i].xy = (bin_edges2[i], 0)
            histogram[i].set_height(hist2[i])
            histogram[i].set_width(width2)
            histogram[i].set_facecolor(colors[i])
        ax.set_xlim(vmin, vmax)
        ax.set_ylim(0, hist2.max()*1.01)

        cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')

        # persist the selected display range
        image_tags['minimum_intensity'] = vmin
        image_tags['maximum_intensity'] = vmax

    def onclick(event):
        # right click on the colorbar cycles color maps; other clicks reset range
        global event2
        event2 = event
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))
        if event.inaxes == ax1:
            if event.button == 3:
                ind = color_map_list.index(image_tags['color_map'])+1
                if ind == len(color_map_list):
                    ind = 0
                image_tags['color_map'] = color_map_list[ind]  # 'viridis'
                vmin = image_tags['minimum_intensity']
                vmax = image_tags['maximum_intensity']
            else:
                vmax = data.max()
                vmin = data.min()
            onselect(vmin, vmax)

    fig2 = plt.figure()

    # histogram axes on top, colorbar strip below
    ax = fig2.add_axes([0., 0.2, 0.9, 0.7])
    ax1 = fig2.add_axes([0., 0.15, 0.9, 0.05])

    histogram = ax.bar(bin_edges[0:-1], hist, width=width, color=colors, edgecolor='black', alpha=0.8)
    onselect(vmin, vmax)
    cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')

    rectprops = dict(facecolor='blue', alpha=0.5)

    # NOTE(review): 'rectprops' was renamed 'props' in newer matplotlib — confirm version support
    span = mwidgets.SpanSelector(ax, onselect, 'horizontal', rectprops=rectprops)

    cid = fig2.canvas.mpl_connect('button_press_event', onclick)
    return span
1021
+
1022
+
1023
def clean_svd(im, pixel_size=1, source_size=5):
    """De-noise an image with the first component of a patch-wise singular value decomposition.

    Parameters
    ----------
    im: numpy array (2D)
        image to de-noise
    pixel_size: float
        pixel size used to convert source_size to pixels
    source_size: float
        size of the source (same units as pixel_size)

    Returns
    -------
    numpy array
        de-noised image, normalized to the total intensity of the input
    """
    # patch edge length in pixels, at least 3
    patch_size = max(int(source_size / pixel_size), 3)
    patches = image.extract_patches_2d(im, (patch_size, patch_size))
    patches = patches.reshape(patches.shape[0], -1)

    num_components = 32

    u, _s, _v = randomized_svd(patches, num_components)
    side = int(np.sqrt(u.shape[0]))
    # first left-singular vector, rescaled to the original total intensity
    denoised = u[:, 0].reshape(side, side)
    return denoised / denoised.sum() * im.sum()
1038
+
1039
+
1040
def rebin(im, binning=2):
    """
    Bin a 2D image by averaging blocks of binning x binning pixels.

    Parameter
    ---------
    im: numpy array in 2 dimensions
    binning: int
        number of pixels to average in each direction

    Returns
    -------
    binned image as numpy array

    Raises
    ------
    TypeError
        if the input is not a 2D array
    """
    if len(im.shape) != 2:
        raise TypeError('not a 2D image')
    rows = im.shape[0] // binning
    cols = im.shape[1] // binning
    return im.reshape((rows, binning, cols, binning)).mean(axis=3).mean(1)
1056
+
1057
+
1058
def cart2pol(points):
    """Cartesian to polar coordinate conversion

    Parameters
    ---------
    points: float or numpy array
        points to be converted (Nx2)

    Returns
    -------
    rho: float or numpy array
        distance
    phi: float or numpy array
        angle
    """
    xy = points[:, 0:2]
    rho = np.linalg.norm(xy, axis=1)
    phi = np.arctan2(xy[:, 1], xy[:, 0])
    return rho, phi
1078
+
1079
+
1080
def pol2cart(rho, phi):
    """Polar to Cartesian coordinate conversion

    Parameters
    ----------
    rho: float or numpy array
        distance
    phi: float or numpy array
        angle

    Returns
    -------
    x, y: floats or numpy arrays
        Cartesian coordinates of converted points (Nx2)
    """
    return rho * np.cos(phi), rho * np.sin(phi)
1099
+
1100
+
1101
def xy2polar(points, rounding=1e-3):
    """Conversion from Cartesian to polar coordinates

    The angles and distances are sorted by r and then phi.
    The index array of this sort is also returned.

    Parameters
    ----------
    points: numpy array
        number of points in axis 0, first two elements in axis 1 are x and y
    rounding: float
        optional rounding of r to suppress floating point differences

    Returns
    -------
    r, phi, sorted_indices
    """
    r, phi = cart2pol(points)

    # quantize r so nearly equal radii compare equal in the sort
    r = (np.floor(r/rounding))*rounding

    order = np.lexsort((phi, r))  # primary key r, secondary key phi
    return r[order], phi[order], order
1129
+
1130
+
1131
def cartesian2polar(x, y, grid, r, t, order=3):
    """Resample a Cartesian grid onto a polar (r, theta) grid.

    Used by warp.

    Parameters
    ----------
    x, y: numpy arrays
        physical coordinates of the Cartesian grid axes
    grid: numpy array (2D)
        data on the Cartesian grid
    r, t: numpy arrays
        radii and angles of the target polar grid
    order: int
        spline order for the interpolation

    Returns
    -------
    numpy array (len(t), len(r))
        data resampled on the polar grid
    """
    rr, tt = np.meshgrid(r, t)
    polar_x = rr*np.cos(tt)
    polar_y = rr*np.sin(tt)

    # map physical coordinates to fractional array indices
    index_of_x = interp1d(x, np.arange(len(x)))
    index_of_y = interp1d(y, np.arange(len(y)))
    coords = np.array([index_of_x(polar_x.ravel()), index_of_y(polar_y.ravel())])

    return ndimage.map_coordinates(grid, coords, order=order).reshape(polar_x.shape)
1149
+
1150
+
1151
def warp(diff, center):
    """Warp a diffraction pattern (sidpy dataset) onto a polar grid.

    Parameters
    ----------
    diff: sidpy.Dataset
        diffraction pattern; must provide a ``u`` dimension
    center: sequence of two numbers
        pixel coordinates of the pattern center

    Returns
    -------
    numpy array
        intensity on a polar grid, transposed so radius is the last axis
    """

    # Define original polar grid
    nx = np.shape(diff)[0]
    ny = np.shape(diff)[1]

    # Define center pixel
    pix2nm = np.gradient(diff.u.values)[0]  # NOTE(review): computed but never used

    # physical coordinates relative to the center
    x = np.linspace(1, nx, nx, endpoint=True)-center[0]
    y = np.linspace(1, ny, ny, endpoint=True)-center[1]
    z = diff

    # Define new polar grid; radius limited by distance of center to nearest edge
    nr = int(min([center[0], center[1], diff.shape[0]-center[0], diff.shape[1]-center[1]])-1)
    nt = 360 * 3  # 1/3 degree angular steps

    r = np.linspace(1, nr, nr)
    # NOTE(review): t spans only 0..pi (half circle) — confirm this is intended
    t = np.linspace(0., np.pi, nt, endpoint=False)

    return cartesian2polar(x, y, z, r, t, order=3).T
1173
+
1174
+
1175
def calculate_ctf(wavelength, cs, defocus, k):
    """ Calculate Contrast Transfer Function

    everything in nm

    Parameters
    ----------
    wavelength: float
        deBroglie wavelength of electrons
    cs: float
        spherical aberration coefficient
    defocus: float
        defocus
    k: numpy array
        reciprocal scale

    Returns
    -------
    ctf: numpy array
        contrast transfer function
    """
    # aberration phase shift: pi*lambda*k^2*(df + 0.5*Cs*lambda^2*k^2)
    chi = np.pi * wavelength * k**2 * (defocus + 0.5 * cs * wavelength**2 * k**2)
    return np.sin(chi)
1199
+
1200
+
1201
def calculate_scherzer(wavelength, cs):
    """
    Calculate the Scherzer defocus.

    Input and output in nm.

    Parameters
    ----------
    wavelength: float
        electron wavelength in nm
    cs: float
        spherical aberration coefficient in nm

    Returns
    -------
    float
        Scherzer defocus in nm
    """
    return -1.155 * np.sqrt(cs * wavelength)
1210
+
1211
+
1212
def get_rotation(experiment_spots, crystal_spots):
    """Get rotation by comparing spots in diffractogram to diffraction Bragg spots

    Parameter
    ---------
    experiment_spots: numpy array (nx2)
        positions (in 1/nm) of spots in diffractogram
    crystal_spots: numpy array (nx2)
        positions (in 1/nm) of Bragg spots according to kinematic scattering theory

    Returns
    -------
    rotation_matrix: numpy array (2x2)
        rotation matrix for the determined angle
    rotation_angle: float
        rotation angle in radians
    """

    r_experiment, phi_experiment = cart2pol(experiment_spots)

    # get crystal spots of same length and sort them by angle as well
    r_crystal, phi_crystal, _ = xy2polar(crystal_spots)
    # experimental spot closest in length to the shortest crystal reflection
    angle_index = np.argmin(np.abs(r_experiment - r_crystal[1]))
    rotation_angle = phi_experiment[angle_index] % (2*np.pi) - phi_crystal[1]
    # (debug print of the matched spot angle removed)
    st = np.sin(rotation_angle)
    ct = np.cos(rotation_angle)
    rotation_matrix = np.array([[ct, -st], [st, ct]])

    return rotation_matrix, rotation_angle
1236
+
1237
+
1238
def calibrate_image_scale(fft_tags, spots_reference, spots_experiment):
    """depreciated get change of scale from comparison of spots to Bragg angles """
    gx = fft_tags['spatial_scale_x']
    gy = fft_tags['spatial_scale_y']

    # shortest reference reflection sets the target distance
    shortest_g = np.linalg.norm(spots_reference, axis=1).min()
    distance_experiment = np.linalg.norm(spots_experiment, axis=1)

    # experimental spots within 0.2 of the shortest reference reflection
    first_reflections = abs(distance_experiment - shortest_g) < .2
    print('Evaluate ', first_reflections.sum(), 'reflections')
    closest_exp_reflections = spots_experiment[first_reflections]

    def _residuals(params, xdata, ydata):
        # distance mismatch after applying the trial scale factors
        dgx, dgy = params
        return np.sqrt((xdata * dgx) ** 2 + (ydata * dgy) ** 2) - shortest_g

    dg, _ier = optimization.leastsq(_residuals, [1.001, 0.999],
                                    args=(closest_exp_reflections[:, 0],
                                          closest_exp_reflections[:, 1]))
    return dg
1257
+
1258
+
1259
def align_crystal_reflections(spots, crystals):
    """Deprecated - use diffraction spots

    Determine rotation angles that align the simulated Bragg reflections of
    each crystal with the experimental diffraction spots.

    Parameters
    ----------
    spots: numpy array (Nx2)
        experimental spot positions
    crystals: list of dict
        each needs ['allowed']['g']; ['allowed']['g_rotated'] is added in place

    Returns
    -------
    crystal_reflections_polar: list
        [r, rotated phi, sort indices] per crystal
    angles: list of float
        rotation angle per crystal (radians)
    """

    crystal_reflections_polar = []
    angles = []
    exp_r, exp_phi = cart2pol(spots)  # just in polar coordinates
    spots_polar = np.array([exp_r, exp_phi])  # NOTE(review): unused — kept for reference

    for i in range(len(crystals)):
        tags = crystals[i]
        r, phi, indices = xy2polar(tags['allowed']['g'])  # sorted by r and phi , only positive angles
        # we mask the experimental values that are found already
        angle = 0.

        # match the experimental spot closest in length to the shortest reflection
        angle_i = np.argmin(np.abs(exp_r - r[1]))
        angle = exp_phi[angle_i] - phi[0]
        angles.append(angle)  # Determine rotation angle

        crystal_reflections_polar.append([r, angle + phi, indices])
        tags['allowed']['g_rotated'] = pol2cart(r, angle + phi)
        for spot in tags['allowed']['g']:
            dif = np.linalg.norm(spots[:, 0:2]-spot[0:2], axis=1)
            # print(dif.min())
            if dif.min() < 1.5:
                ind = np.argmin(dif)  # NOTE(review): result unused — this loop has no effect

    return crystal_reflections_polar, angles
1286
+
1287
+
1288
+ # Deconvolution
1289
def decon_lr(o_image, probe, verbose=False):
    """Lucy-Richardson deconvolution of an image with a probe (PSF).

    This task generates a restored image from an input image and point spread
    function (PSF) using the algorithm developed independently by
    Lucy (1974, Astron. J. 79, 745) and Richardson (1972, J. Opt. Soc. Am. 62, 55)
    and adapted for HST imagery by Snyder (1990, in Restoration of HST Images and
    Spectra, ST ScI Workshop Proceedings; see also Snyder, Hammoud, & White,
    JOSA, v. 10, no. 5, May 1993). Additional options developed by
    Rick White (STScI) are also included.

    The Lucy-Richardson method can be derived from the maximum likelihood
    expression for data with a Poisson noise distribution. Thus, it naturally
    applies to optical imaging data such as HST. The method forces the restored
    image to be positive, in accord with photon-counting statistics.

    The Lucy-Richardson algorithm generates a restored image through an
    iterative method. The essence of the iteration is as follows: the (n+1)th
    estimate of the restored image is given by the nth estimate of the restored
    image multiplied by a correction image. That is,

                            original data
        image    = image  ---------------  * reflect(PSF)
             n+1        n   image * PSF
                                 n

    where the *'s represent convolution operators and reflect(PSF) is the
    reflection of the PSF, i.e. reflect(PSF(x, y)) = PSF(-x, -y). When the
    convolutions are carried out using fast Fourier transforms (FFTs), one can
    use the fact that FFT(reflect(PSF)) = conj(FFT(PSF)), where conj is the
    complex conjugate operator.

    Parameters
    ----------
    o_image: sidpy.Dataset
        image to restore; needs metadata['experiment'] with keys
        'convergence_angle' (mrad) and 'acceleration_voltage' (V)
    probe: numpy array
        point spread function; expected to have the same shape as o_image
    verbose: bool
        print convergence information for every iteration

    Returns
    -------
    sidpy.Dataset
        restored image, titled 'Lucy Richardson deconvolution'
    """

    if len(o_image) < 1:
        return o_image

    if o_image.shape != probe.shape:
        print('Weirdness ', o_image.shape, ' != ', probe.shape)

    # work in complex arrays so FFTs do not copy; real parts carry the data
    probe_c = np.ones(probe.shape, dtype=np.complex64)
    probe_c.real = probe

    error = np.ones(o_image.shape, dtype=np.complex64)
    est = np.ones(o_image.shape, dtype=np.complex64)
    source = np.ones(o_image.shape, dtype=np.complex64)
    source.real = o_image

    response_ft = fftpack.fft2(probe_c)

    ap_angle = o_image.metadata['experiment']['convergence_angle'] / 1000.0  # now in rad

    e0 = float(o_image.metadata['experiment']['acceleration_voltage'])

    wl = get_wavelength(e0)
    o_image.metadata['experiment']['wavelength'] = wl

    # information limit of the optics (reciprocal length)
    over_d = 2 * ap_angle / wl

    dx = o_image.x[1]-o_image.x[0]
    dk = 1.0 / float(o_image.x[-1])  # last value of x-axis is field of view
    screen_width = 1 / dx

    aperture = np.ones(o_image.shape, dtype=np.complex64)
    # Mask for the aperture before the Fourier transform
    n = o_image.shape[0]
    size_x = o_image.shape[0]
    size_y = o_image.shape[1]
    app_ratio = over_d / screen_width * n  # aperture radius in pixels

    theta_x = np.array(-size_x / 2. + np.arange(size_x))
    theta_y = np.array(-size_y / 2. + np.arange(size_y))
    t_xv, t_yv = np.meshgrid(theta_x, theta_y)

    # zero out frequencies beyond the information limit
    tp1 = t_xv ** 2 + t_yv ** 2 >= app_ratio ** 2
    aperture[tp1.T] = 0.
    # print(app_ratio, screen_width, dk)

    progress = tqdm(total=500)
    # de = 100
    dest = 100
    i = 0
    # iterate until the relative change of the estimate drops below 1e-4 %
    while abs(dest) > 0.0001:  # or abs(de) > .025:
        i += 1
        error_old = np.sum(error.real)
        est_old = est.copy()
        # correction image: data divided by current estimate convolved with PSF
        error = source / np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est) * response_ft)))
        # multiply estimate by correction convolved with reflected PSF (conj in k-space)
        est = est * np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(error) * np.conjugate(response_ft))))
        # est = est_old * est
        # est = np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est)*fftpack.fftshift(aperture) )))

        error_new = np.real(np.sum(np.power(error, 2))) - error_old
        # percentage change of the estimate between iterations
        dest = np.sum(np.power((est - est_old).real, 2)) / np.sum(est) * 100
        # print(np.sum((est.real - est_old.real)* (est.real - est_old.real) )/np.sum(est.real)*100 )

        if error_old != 0:
            de = error_new / error_old * 1.0
        else:
            de = error_new

        if verbose:
            print(
                ' LR Deconvolution - Iteration: {0:d} Error: {1:.2f} = change: {2:.5f}%, {3:.5f}%'.format(i, error_new,
                                                                                                          de,
                                                                                                          abs(dest)))
        if i > 500:
            # safety stop after 500 iterations
            dest = 0.0
            print('terminate')
        progress.update(1)
    progress.write(f"converged in {i} iterations")
    # progress.close()
    print('\n Lucy-Richardson deconvolution converged in ' + str(i) + ' iterations')
    # final low-pass with the aperture mask to suppress noise beyond the information limit
    est2 = np.real(fftpack.ifft2(fftpack.fft2(est) * fftpack.fftshift(aperture)))
    out_dataset = o_image.like_data(est2)
    out_dataset.title = 'Lucy Richardson deconvolution'
    out_dataset.data_type = 'image'
    return out_dataset