pyTEMlib 0.2020.11.0__py3-none-any.whl → 0.2024.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyTEMlib might be problematic. Click here for more details.

Files changed (59)
  1. pyTEMlib/__init__.py +11 -11
  2. pyTEMlib/animation.py +631 -0
  3. pyTEMlib/atom_tools.py +240 -222
  4. pyTEMlib/config_dir.py +57 -29
  5. pyTEMlib/core_loss_widget.py +658 -0
  6. pyTEMlib/crystal_tools.py +1255 -0
  7. pyTEMlib/diffraction_plot.py +756 -0
  8. pyTEMlib/dynamic_scattering.py +293 -0
  9. pyTEMlib/eds_tools.py +609 -0
  10. pyTEMlib/eels_dialog.py +749 -486
  11. pyTEMlib/{interactive_eels.py → eels_dialog_utilities.py} +1199 -1524
  12. pyTEMlib/eels_tools.py +2031 -1731
  13. pyTEMlib/file_tools.py +1276 -491
  14. pyTEMlib/file_tools_qt.py +193 -0
  15. pyTEMlib/graph_tools.py +1166 -450
  16. pyTEMlib/graph_viz.py +449 -0
  17. pyTEMlib/image_dialog.py +158 -0
  18. pyTEMlib/image_dlg.py +146 -0
  19. pyTEMlib/image_tools.py +1399 -956
  20. pyTEMlib/info_widget.py +933 -0
  21. pyTEMlib/interactive_image.py +1 -0
  22. pyTEMlib/kinematic_scattering.py +1196 -0
  23. pyTEMlib/low_loss_widget.py +176 -0
  24. pyTEMlib/microscope.py +61 -78
  25. pyTEMlib/peak_dialog.py +1047 -350
  26. pyTEMlib/peak_dlg.py +286 -248
  27. pyTEMlib/probe_tools.py +653 -202
  28. pyTEMlib/sidpy_tools.py +153 -129
  29. pyTEMlib/simulation_tools.py +104 -87
  30. pyTEMlib/version.py +6 -3
  31. pyTEMlib/xrpa_x_sections.py +20972 -0
  32. {pyTEMlib-0.2020.11.0.dist-info → pyTEMlib-0.2024.8.4.dist-info}/LICENSE +21 -21
  33. pyTEMlib-0.2024.8.4.dist-info/METADATA +93 -0
  34. pyTEMlib-0.2024.8.4.dist-info/RECORD +37 -0
  35. {pyTEMlib-0.2020.11.0.dist-info → pyTEMlib-0.2024.8.4.dist-info}/WHEEL +6 -5
  36. {pyTEMlib-0.2020.11.0.dist-info → pyTEMlib-0.2024.8.4.dist-info}/entry_points.txt +0 -1
  37. pyTEMlib/KinsCat.py +0 -2685
  38. pyTEMlib/__version__.py +0 -2
  39. pyTEMlib/data/TEMlibrc +0 -68
  40. pyTEMlib/data/edges_db.csv +0 -189
  41. pyTEMlib/data/edges_db.pkl +0 -0
  42. pyTEMlib/data/fparam.txt +0 -103
  43. pyTEMlib/data/microscopes.csv +0 -7
  44. pyTEMlib/data/microscopes.xml +0 -167
  45. pyTEMlib/data/path.txt +0 -1
  46. pyTEMlib/defaults_parser.py +0 -86
  47. pyTEMlib/dm3_reader.py +0 -609
  48. pyTEMlib/edges_db.py +0 -76
  49. pyTEMlib/eels_dlg.py +0 -240
  50. pyTEMlib/hdf_utils.py +0 -481
  51. pyTEMlib/image_tools1.py +0 -2194
  52. pyTEMlib/info_dialog.py +0 -227
  53. pyTEMlib/info_dlg.py +0 -205
  54. pyTEMlib/nion_reader.py +0 -293
  55. pyTEMlib/nsi_reader.py +0 -165
  56. pyTEMlib/structure_tools.py +0 -316
  57. pyTEMlib-0.2020.11.0.dist-info/METADATA +0 -20
  58. pyTEMlib-0.2020.11.0.dist-info/RECORD +0 -42
  59. {pyTEMlib-0.2020.11.0.dist-info → pyTEMlib-0.2024.8.4.dist-info}/top_level.txt +0 -0
pyTEMlib/image_tools.py CHANGED
@@ -1,956 +1,1399 @@
1
- ##################################
2
- #
3
- # image_tools.py
4
- # by Gerd Duscher, UTK
5
- # part of pyTEMlib
6
- # MIT license except where stated differently
7
- #
8
- ###############################
9
- import numpy as np
10
-
11
- import matplotlib as mpl
12
- import matplotlib.pylab as plt
13
- from matplotlib.patches import Polygon # plotting of polygons -- graph rings
14
-
15
- import matplotlib.widgets as mwidgets
16
- from matplotlib.widgets import RectangleSelector
17
-
18
- import sidpy
19
- from .file_tools import *
20
- from .probe_tools import *
21
- import sys
22
-
23
- import itertools
24
- from itertools import product
25
-
26
- from scipy import fftpack
27
- from scipy import signal
28
- from scipy.interpolate import interp1d, interp2d
29
- from scipy.optimize import leastsq
30
- import scipy.optimize as optimization
31
-
32
- # Multidimensional Image library
33
- import scipy.ndimage as ndimage
34
- import scipy.constants as const
35
-
36
- import scipy.spatial as sp
37
- from scipy.spatial import Voronoi, KDTree, cKDTree
38
-
39
- import skimage
40
- import skimage.registration as registration
41
- from skimage.feature import register_translation # blob_dog, blob_doh
42
- from skimage.feature import peak_local_max
43
- from skimage.measure import points_in_poly
44
-
45
- # our blob detectors from the scipy image package
46
- from skimage.feature import blob_log # blob_dog, blob_doh
47
-
48
- from sklearn.feature_extraction import image
49
- from sklearn.utils.extmath import randomized_svd
50
-
51
- _SimpleITK_present = True
52
- try:
53
- import SimpleITK as sITK
54
- except ModuleNotFoundError:
55
- _SimpleITK_present = False
56
-
57
- if not _SimpleITK_present:
58
- print('SimpleITK not installed; Registration Functions for Image Stacks not available')
59
-
60
-
61
- # Wavelength in 1/nm
62
- def get_wavelength(e0):
63
- """
64
- Calculates the relativistic corrected de Broglie wave length of an electron
65
-
66
- Input:
67
- ------
68
- acceleration voltage in volt
69
- Output:
70
- -------
71
- wave length in 1/nm
72
- """
73
-
74
- eV = const.e * e0
75
- return const.h/np.sqrt(2*const.m_e*eV*(1+eV/(2*const.m_e*const.c**2)))*10**9
76
-
77
-
78
- def fourier_transform(dset):
79
- """
80
- Reads information into dictionary 'tags', performs 'FFT', and provides a smoothed FT and reciprocal
81
- and intensity limits for visualization.
82
-
83
- Input
84
- -----
85
- dset: sidp Dataset
86
-
87
- Usage
88
- -----
89
-
90
- fft_dataset = fourier_transform(sidpy_dataset)
91
- fft_dataset.plot()
92
- """
93
-
94
- assert isinstance(dset, sidpy.Dataset), 'Expected a sidpy Dataset'
95
-
96
- selection = []
97
- image_dim = []
98
- # image_dim = get_image_dims(sidpy.DimensionTypes.SPATIAL)
99
- if dset.data_type == sidpy.DataTypes.IMAGE_STACK:
100
- for dim, axis in dset._axes.items():
101
- if axis.dimension_type == sidpy.DimensionTypes.SPATIAL:
102
- selection.append(slice(None))
103
- image_dim.append(dim)
104
- elif axis.dimension_type == sidpy.DimensionTypes.TEMPORAL or len(dset) == 3:
105
- selection.append(slice(None))
106
- stack_dim = dim
107
- else:
108
- selection.append(slice(0, 1))
109
- if len(image_dim) != 2:
110
- raise ValueError('need at least two SPATIAL dimension for an image stack')
111
- image_stack = np.squeeze(np.array(dset)[selection])
112
- image = np.sum(np.array(image_stack), axis=stack_dim)
113
- elif dset.data_type == sidpy.DataTypes.IMAGE:
114
- image = np.array(dset)
115
- else:
116
- return
117
-
118
- image = image - image.min()
119
- fft_transform = (np.fft.fftshift(np.fft.fft2(image)))
120
-
121
- image_dims = get_image_dims(dset)
122
- extent = dset.get_extent(image_dims)
123
- scale_x = 1 / abs(extent[1] - extent[0])
124
- scale_y = 1 / abs(extent[2] - extent[3])
125
-
126
- units_x = '1/' + dset._axes[image_dims[0]].units
127
- units_y = '1/' + dset._axes[image_dims[1]].units
128
-
129
- fft_dset = sidpy.Dataset.from_array(fft_transform)
130
- fft_dset.quantity = dset.quantity
131
- fft_dset.units = 'a.u.'
132
- fft_dset.data_type = 'IMAGE'
133
- fft_dset.source = dset.title
134
- fft_dset.modality = 'fft'
135
- fft_dset.set_dimension(0, sidpy.Dimension((np.arange(fft_dset.shape[0]) - fft_dset.shape[0] / 2) * scale_x,
136
- name='u', units=units_x, dimension_type='RECIPROCAL',
137
- quantity='reciprocal_length'))
138
- fft_dset.set_dimension(1, sidpy.Dimension((np.arange(fft_dset.shape[1]) - fft_dset.shape[1] / 2) * scale_y,
139
- name='v', units=units_y, dimension_type='RECIPROCAL',
140
- quantity='reciprocal_length'))
141
-
142
- return fft_dset
143
-
144
-
145
- def power_spectrum(dset, smoothing=3):
146
- """
147
- Calculate power spectrum
148
-
149
- Input:
150
- ======
151
- channel: channel in h5f file with image content
152
-
153
- Output:
154
- =======
155
- tags: dictionary with
156
- ['data']: fourier transformed image
157
- ['axis']: scale of reciprocal image
158
- ['power_spectrum']: power_spectrum
159
- ['FOV']: field of view for extent parameter in plotting
160
- ['minimum_intensity']: suggested minimum intensity for plotting
161
- ['maximum_intensity']: suggested maximum intensity for plotting
162
-
163
- """
164
- fft_transform = fourier_transform(dset)
165
- fft_mag = np.abs(fft_transform)
166
- fft_mag2 = ndimage.gaussian_filter(fft_mag, sigma=(smoothing, smoothing), order=0)
167
-
168
- power_spec = fft_transform.like_data(np.log(1.+fft_mag2))
169
-
170
- # prepare mask
171
-
172
- x, y = np.meshgrid(power_spec.u.values, power_spec.v.values)
173
- mask = np.zeros(power_spec.shape)
174
-
175
- mask_spot = x ** 2 + y ** 2 > 1 ** 2
176
- mask = mask + mask_spot
177
- mask_spot = x ** 2 + y ** 2 < 11 ** 2
178
- mask = mask + mask_spot
179
-
180
- mask[np.where(mask == 1)] = 0 # just in case of overlapping disks
181
-
182
- # minimum_intensity = np.log2(1 + fft_mag2)[np.where(mask == 2)].min() * 0.95
183
- # maximum_intensity = np.log2(1 + fft_mag2)[np.where(mask == 2)].max() * 1.05
184
- power_spec.metadata = {'smoothing': smoothing}
185
- # 'minimum_intensity': minimum_intensity, 'maximum_intensity': maximum_intensity}
186
- power_spec.title = 'power spectrum ' + power_spec.source
187
-
188
- return power_spec
189
-
190
-
191
- def diffractogram_spots(dset, spot_threshold):
192
- """
193
- Find spots in diffractogram and sort them by distance from center
194
-
195
- Input:
196
- ======
197
- fft_tags: dictionary with
198
- ['spatial_***']: information of scale of fourier pattern
199
- ['data']: power_spectrum
200
- spot_threshold: threshold for blob finder
201
- Output:
202
- =======
203
- spots: numpy array with sorted position (x,y) and radius (r) of all spots
204
-
205
- """
206
- # Needed for conversion from pixel to Reciprocal space
207
- # we'll have to switch x- and y-coordinates due to the differences in numpy and matrix
208
- center = np.array([int(dset.shape[0]/2.), int(dset.shape[1]/2.), 1])
209
- rec_scale = np.array([get_slope(dset.u.values), get_slope(dset.v.values), 1])
210
-
211
- # spot detection ( for future reference there is no symmetry assumed here)
212
- data = np.array(dset).T
213
- data = (data - data.min())
214
- data = data/data.max()
215
- # some images are strange and blob_log does not work on the power spectrum
216
- try:
217
- spots_random = (blob_log(data, max_sigma=5, threshold=spot_threshold) - center) * rec_scale
218
- except ValueError:
219
- spots_random = (peak_local_max(np.array(data.T), min_distance=3, threshold_rel=spot_threshold) - center[:2]) \
220
- * rec_scale
221
- spots_random = np.hstack(spots_random,np.zeros((spots_random.shape[0],1)))
222
-
223
- print(f'Found {spots_random.shape[0]} reflections')
224
-
225
- # sort reflections
226
- spots_random[:, 2] = np.linalg.norm(spots_random[:, 0:2], axis=1)
227
- spots_index = np.argsort(spots_random[:, 2])
228
- spots = spots_random[spots_index]
229
- # third row is angles
230
- spots[:, 2] = np.arctan2(spots[:, 0], spots[:, 1])
231
- return spots
232
-
233
-
234
- def adaptive_fourier_filter(dset, spots, low_pass=3, reflection_radius=0.3):
235
- """
236
- Use spots in diffractogram for a Fourier Filter
237
-
238
- Input:
239
- ======
240
- image: image to be filtered
241
- tags: dictionary with
242
- ['spatial_***']: information of scale of fourier pattern
243
- ['spots']: sorted spots in diffractogram in 1/nm
244
- low_pass: low pass filter in center of diffractogram
245
-
246
- Output:
247
- =======
248
- Fourier filtered image
249
- """
250
- # prepare mask
251
-
252
- fft_transform = fourier_transform(dset)
253
- x, y = np.meshgrid(fft_transform.u.values, fft_transform.v.values)
254
- mask = np.zeros(dset.shape)
255
-
256
- # mask reflections
257
- # reflection_radius = 0.3 # in 1/nm
258
- for spot in spots:
259
- mask_spot = (x - spot[0]) ** 2 + (y - spot[1]) ** 2 < reflection_radius ** 2 # make a spot
260
- mask = mask + mask_spot # add spot to mask
261
-
262
- # mask zero region larger (low-pass filter = intensity variations)
263
- # low_pass = 3 # in 1/nm
264
- mask_spot = x ** 2 + y ** 2 < low_pass ** 2
265
- mask = mask + mask_spot
266
- mask[np.where(mask > 1)] = 1
267
- fft_filtered = fft_transform * mask
268
-
269
- filtered_image = dset.like_data(np.fft.ifft2(np.fft.fftshift(fft_filtered)).real)
270
- filtered_image.title = 'Fourier filtered ' + dset.title
271
- filtered_image.source = dset.title
272
- filtered_image.metadata = {'analysis': 'adaptive fourier filtered', 'spots': spots,
273
- 'low_pass': low_pass, 'reflection_radius': reflection_radius}
274
-
275
- return filtered_image
276
-
277
-
278
- def rotational_symmetry_diffractogram(spots):
279
- rotation_symmetry = []
280
- for n in [2, 3, 4, 6]:
281
- cc = np.array(
282
- [[np.cos(2 * np.pi / n), np.sin(2 * np.pi / n), 0], [-np.sin(2 * np.pi / n), np.cos(2 * np.pi / n), 0],
283
- [0, 0, 1]])
284
- sym_spots = np.dot(spots, cc)
285
- dif = []
286
- for p0, p1 in product(sym_spots[:, 0:2], spots[:, 0:2]):
287
- dif.append(np.linalg.norm(p0 - p1))
288
- dif = np.array(sorted(dif))
289
-
290
- if dif[int(spots.shape[0] * .7)] < 0.2:
291
- rotation_symmetry.append(n)
292
- return rotation_symmetry
293
-
294
- #####################################################
295
- # Registration Functions
296
- #####################################################
297
-
298
-
299
- def complete_registration(main_dataset):
300
- rigid_registered_dataset = rigid_registration(main_dataset)
301
- current_channel = main_dataset.h5_dataset.parent
302
- registration_channel = log_results(current_channel, rigid_registered_dataset)
303
-
304
- print('Non-Rigid_Registration')
305
-
306
- non_rigid_registered = demon_registration(rigid_registered_dataset)
307
-
308
- registration_channel = log_results(current_channel, non_rigid_registered)
309
-
310
- return non_rigid_registered, rigid_registered_dataset
311
-
312
-
313
- def demon_registration(dataset, verbose=False):
314
- """
315
- Diffeomorphic Demon Non-Rigid Registration
316
- Usage:
317
- ------
318
- dem_reg = demon_reg(cube, verbose = False)
319
-
320
- Input:
321
- cube: stack of image after rigid registration and cropping
322
- Output:
323
- dem_reg: stack of images with non-rigid registration
324
-
325
- Depends on:
326
- simpleITK and numpy
327
-
328
- Please Cite: http://www.simpleitk.org/SimpleITK/project/parti.html
329
- and T. Vercauteren, X. Pennec, A. Perchant and N. Ayache
330
- Diffeomorphic Demons Using ITK\'s Finite Difference Solver Hierarchy
331
- The Insight Journal, http://hdl.handle.net/1926/510 2007
332
- """
333
-
334
- dem_reg = np.zeros(dataset.shape)
335
- nimages = dataset.shape[0]
336
- if verbose:
337
- print(nimages)
338
- # create fixed image by summing over rigid registration
339
-
340
- fixed_np = np.average(np.array(dataset), axis=0)
341
-
342
- fixed = sITK.GetImageFromArray(fixed_np)
343
- fixed = sITK.DiscreteGaussian(fixed, 2.0)
344
-
345
- # demons = sITK.SymmetricForcesDemonsRegistrationFilter()
346
- demons = sITK.DiffeomorphicDemonsRegistrationFilter()
347
-
348
- demons.SetNumberOfIterations(200)
349
- demons.SetStandardDeviations(1.0)
350
-
351
- resampler = sITK.ResampleImageFilter()
352
- resampler.SetReferenceImage(fixed)
353
- resampler.SetInterpolator(sITK.sitkBSpline)
354
- resampler.SetDefaultPixelValue(0)
355
-
356
- done = 0
357
-
358
- if QT_available:
359
- progress = ProgressDialog("Non-Rigid Registration", nimages)
360
- for i in range(nimages):
361
- if QT_available:
362
- progress.set_value(i)
363
- else:
364
- if done < int((i + 1) / nimages * 50):
365
- done = int((i + 1) / nimages * 50)
366
- sys.stdout.write('\r')
367
- # progress output :
368
- sys.stdout.write("[%-50s] %d%%" % ('*' * done, 2 * done))
369
- sys.stdout.flush()
370
-
371
- moving = sITK.GetImageFromArray(dataset[i])
372
- moving_f = sITK.DiscreteGaussian(moving, 2.0)
373
- displacement_field = demons.Execute(fixed, moving_f)
374
- out_tx = sITK.DisplacementFieldTransform(displacement_field)
375
- resampler.SetTransform(out_tx)
376
- out = resampler.Execute(moving)
377
- dem_reg[i, :, :] = sITK.GetArrayFromImage(out)
378
- # print('image ', i)
379
-
380
- if QT_available:
381
- progress.close()
382
-
383
- print(':-)')
384
- print('You have successfully completed Diffeomorphic Demons Registration')
385
-
386
- demon_registered = dataset.like_data(dem_reg)
387
- demon_registered.title = 'Non-Rigid Registration'
388
- demon_registered.source = dataset.title
389
-
390
- demon_registered.metadata = {'analysis': 'non-rigid demon registration'}
391
- if 'boundaries' in dataset.metadata:
392
- demon_registered.metadata['boundaries'] = dataset.metadata['boundaries']
393
-
394
- return demon_registered
395
-
396
-
397
- ###############################
398
- # Rigid Registration New 05/09/2020
399
-
400
- def rigid_registration(dataset):
401
- """
402
- Rigid registration of image stack with sub-pixel accuracy
403
- used phase_cross_correlation from skimage.registration
404
- (we determine drift from one image to next)
405
-
406
- Input hdf5 group with image_stack dataset
407
-
408
- Output Registered Stack and drift (with respect to center image)
409
-
410
- """
411
-
412
- nopix = dataset.shape[1]
413
- nopiy = dataset.shape[2]
414
- nimages = dataset.shape[0]
415
-
416
- print('Stack contains ', nimages, ' images, each with', nopix, ' pixels in x-direction and ', nopiy,
417
- ' pixels in y-direction')
418
- fixed = np.array(dataset[0])
419
- fft_fixed = np.fft.fft2(fixed)
420
-
421
- relative_drift = [[0., 0.]]
422
- done = 0
423
-
424
- if QT_available:
425
- progress = ProgressDialog("Rigid Registration", nimages)
426
- for i in range(nimages):
427
- if QT_available:
428
- progress.set_value(i)
429
- else:
430
- if done < int((i + 1) / nimages * 50):
431
- done = int((i + 1) / nimages * 50)
432
- sys.stdout.write('\r')
433
- # progress output :
434
- sys.stdout.write("[%-50s] %d%%" % ('*' * done, 2 * done))
435
- sys.stdout.flush()
436
-
437
- moving = np.array(dataset[i])
438
- fft_moving = np.fft.fft2(moving)
439
- if skimage.__version__[:4] == '0.16':
440
- shift = register_translation(fft_fixed, fft_moving, upsample_factor=1000, space='fourier')
441
- else:
442
- shift = registration.phase_cross_correlation(fft_fixed, fft_moving, upsample_factor=1000, space='fourier')
443
-
444
- fft_fixed = fft_moving
445
- # print(f'Image number {i:2} xshift = {shift[0][0]:6.3f} y-shift = {shift[0][1]:6.3f}')
446
-
447
- relative_drift.append(shift[0])
448
- if QT_available:
449
- progress.close()
450
- rig_reg, drift = rig_reg_drift(dataset, relative_drift)
451
-
452
- crop_reg, boundaries = crop_image_stack(rig_reg, drift)
453
-
454
- rigid_registered = dataset.like_data(crop_reg)
455
- rigid_registered.title = 'Rigid Registration'
456
- rigid_registered.source = dataset.title
457
- rigid_registered.metadata = {'analysis': 'rigid sub-pixel registration', 'drift': drift, 'boundaries': boundaries}
458
-
459
- return rigid_registered
460
-
461
-
462
- def rig_reg_drift(dset, rel_drift):
463
- """
464
- Uses relative drift to shift images on top of each other
465
- Shifting is done with shift routine of ndimage from scipy
466
-
467
- is used by Rigid_Registration routine
468
-
469
- Input image_channel with image_stack numpy array
470
- relative_drift from image to image as list of [shiftx, shifty]
471
-
472
- output stack and drift
473
- """
474
-
475
- rig_reg = np.zeros(dset.shape)
476
- # absolute drift
477
- drift = np.array(rel_drift).copy()
478
-
479
- drift[0] = [0, 0]
480
- for i in range(drift.shape[0]):
481
- drift[i] = drift[i - 1] + rel_drift[i]
482
- center_drift = drift[int(drift.shape[0] / 2)]
483
- drift = drift - center_drift
484
- # Shift images
485
- for i in range(rig_reg.shape[0]):
486
- # Now we shift
487
- rig_reg[i, :, :] = ndimage.shift(dset[i], [drift[i, 0], drift[i, 1]], order=3)
488
- return rig_reg, drift
489
-
490
-
491
- def crop_image_stack(rig_reg, drift):
492
- """
493
- ## Crop images
494
- """
495
- xpmin = int(-np.floor(np.min(np.array(drift)[:, 0])))
496
- xpmax = int(rig_reg.shape[1] - np.ceil(np.max(np.array(drift)[:, 0])))
497
- ypmin = int(-np.floor(np.min(np.array(drift)[:, 1])))
498
- ypmax = int(rig_reg.shape[2] - np.ceil(np.max(np.array(drift)[:, 1])))
499
-
500
- return rig_reg[:, xpmin:xpmax, ypmin:ypmax], [xpmin, xpmax, ypmin, ypmax]
501
-
502
-
503
- class ImageWithLineProfile:
504
- def __init__(self, data, extent, title=''):
505
- fig, ax = plt.subplots(1, 1)
506
- self.figure = fig
507
- self.title = title
508
- self.line_plot = False
509
- self.ax = ax
510
- self.data = data
511
- self.extent = extent
512
- self.ax.imshow(data, extent=extent)
513
- self.ax.set_title(title)
514
- self.line, = self.ax.plot([0], [0], color='orange') # empty line
515
- self.end_x = self.line.get_xdata()
516
- self.end_y = self.line.get_ydata()
517
- self.cid = self.line.figure.canvas.mpl_connect('button_press_event', self)
518
-
519
- def __call__(self, event):
520
- if event.inaxes != self.line.axes:
521
- return
522
- self.start_x = self.end_x
523
- self.start_y = self.end_y
524
-
525
- self.line.set_data([self.start_x, event.xdata], [self.start_y, event.ydata])
526
- self.line.figure.canvas.draw()
527
-
528
- self.end_x = event.xdata
529
- self.end_y = event.ydata
530
-
531
- self.update()
532
-
533
- def update(self):
534
-
535
- if not self.line_plot:
536
- self.line_plot = True
537
- self.figure.clear()
538
- self.ax = self.figure.subplots(2, 1)
539
- self.ax[0].imshow(self.data, extent=self.extent)
540
- self.ax[0].set_title(self.title)
541
-
542
- self.line, = self.ax[0].plot([0], [0], color='orange') # empty line
543
- self.line_plot, = self.ax[1].plot([], [], color='orange')
544
- self.ax[1].set_xlabel('distance [nm]')
545
-
546
- x0 = self.start_x
547
- x1 = self.end_x
548
- y0 = self.start_y
549
- y1 = self.end_y
550
- length_plot = np.sqrt((x1-x0)**2+(y1-y0)**2)
551
-
552
- num = length_plot*(self.data.shape[0]/self.extent[1])
553
- x = np.linspace(x0, x1, num)*(self.data.shape[0]/self.extent[1])
554
- y = np.linspace(y0, y1, num)*(self.data.shape[0]/self.extent[1])
555
-
556
- # Extract the values along the line, using cubic interpolation
557
- zi2 = ndimage.map_coordinates(self.data.T, np.vstack((x, y)))
558
-
559
- x_axis = np.linspace(0, length_plot, len(zi2))
560
-
561
- self.x = x_axis
562
- self.z = zi2
563
-
564
- self.line_plot.set_xdata(x_axis)
565
- self.line_plot.set_ydata(zi2)
566
- self.ax[1].set_xlim(0, x_axis.max())
567
- self.ax[1].set_ylim(zi2.min(), zi2.max())
568
- self.ax[1].draw()
569
-
570
-
571
- def histogram_plot(image_tags):
572
- nbins = 75
573
- minbin = 0.
574
- maxbin = 1.
575
- color_map_list = ['gray', 'viridis', 'jet', 'hot']
576
-
577
- if 'minimum_intensity' not in image_tags:
578
- image_tags['minimum_intensity'] = image_tags['plotimage'].min()
579
- minimum_intensity = image_tags['minimum_intensity']
580
- if 'maximum_intensity' not in image_tags:
581
- image_tags['maximum_intensity'] = image_tags['plotimage'].max()
582
- data = image_tags['plotimage']
583
- vmin = image_tags['minimum_intensity']
584
- vmax = image_tags['maximum_intensity']
585
- if 'color_map' not in image_tags:
586
- image_tags['color_map'] = color_map_list[0]
587
- cmap = plt.cm.get_cmap(image_tags['color_map'])
588
-
589
- colors = cmap(np.linspace(0., 1., nbins))
590
-
591
- norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
592
- hist, bin_edges = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)
593
-
594
- width = bin_edges[1]-bin_edges[0]
595
-
596
- def onselect(vmin, vmax):
597
-
598
- ax1.clear()
599
- cmap = plt.cm.get_cmap(image_tags['color_map'])
600
-
601
- colors = cmap(np.linspace(0., 1., nbins))
602
-
603
- norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
604
- hist2, bin_edges2 = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)
605
-
606
- width2 = (bin_edges2[1]-bin_edges2[0])
607
-
608
- for i in range(nbins-1):
609
- histogram[i].xy = (bin_edges2[i], 0)
610
- histogram[i].set_height(hist2[i])
611
- histogram[i].set_width(width2)
612
- histogram[i].set_facecolor(colors[i])
613
- ax.set_xlim(vmin, vmax)
614
- ax.set_ylim(0, hist2.max()*1.01)
615
-
616
- cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')
617
-
618
- image_tags['minimum_intensity'] = vmin
619
- image_tags['maximum_intensity'] = vmax
620
-
621
- def onclick(event):
622
- global event2
623
- event2 = event
624
- print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
625
- ('double' if event.dblclick else 'single', event.button,
626
- event.x, event.y, event.xdata, event.ydata))
627
- if event.inaxes == ax1:
628
- if event.button == 3:
629
- ind = color_map_list.index(image_tags['color_map'])+1
630
- if ind == len(color_map_list):
631
- ind = 0
632
- image_tags['color_map'] = color_map_list[ind] # 'viridis'
633
- vmin = image_tags['minimum_intensity']
634
- vmax = image_tags['maximum_intensity']
635
- else:
636
- vmax = data.max()
637
- vmin = data.min()
638
- onselect(vmin, vmax)
639
-
640
- fig2 = plt.figure()
641
-
642
- ax = fig2.add_axes([0., 0.2, 0.9, 0.7])
643
- ax1 = fig2.add_axes([0., 0.15, 0.9, 0.05])
644
-
645
- histogram = ax.bar(bin_edges[0:-1], hist, width=width, color=colors, edgecolor='black', alpha=0.8)
646
- onselect(vmin, vmax)
647
- cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')
648
-
649
- rectprops = dict(facecolor='blue', alpha=0.5)
650
-
651
- span = mwidgets.SpanSelector(ax, onselect, 'horizontal', rectprops=rectprops)
652
-
653
- cid = fig2.canvas.mpl_connect('button_press_event', onclick)
654
- return span
655
-
656
-
657
- def clean_svd(im, pixel_size=1, source_size=5):
658
- patch_size = int(source_size/pixel_size)
659
- if patch_size < 3:
660
- patch_size = 3
661
- print(patch_size)
662
-
663
- patches = image.extract_patches_2d(im, (patch_size, patch_size))
664
- patches = patches.reshape(patches.shape[0], patches.shape[1]*patches.shape[2])
665
-
666
- num_components = 32
667
-
668
- u, s, v = randomized_svd(patches, num_components)
669
- u_im_size = int(np.sqrt(u.shape[0]))
670
- reduced_image = u[:, 0].reshape(u_im_size, u_im_size)
671
- reduced_image = reduced_image/reduced_image.sum()*im.sum()
672
- return reduced_image
673
-
674
-
675
- def rebin(im, binning=2):
676
- """
677
- rebin an image by the number of pixels in x and y direction given by binning
678
-
679
- Input:
680
- ======
681
- image: numpy array in 2 dimensions
682
-
683
- Output:
684
- =======
685
- binned image
686
- """
687
- if len(im.shape) == 2:
688
- return im.reshape((im.shape[0]//binning, binning, im.shape[1]//binning, binning)).mean(axis=3).mean(1)
689
- else:
690
- print('not a 2D image')
691
- return im
692
-
693
-
694
- def cart2pol(points):
695
- rho = np.linalg.norm(points[:, 0:2], axis=1)
696
- phi = np.arctan2(points[:, 1], points[:, 0])
697
- return rho, phi
698
-
699
-
700
- def pol2cart(rho, phi):
701
- x = rho * np.cos(phi)
702
- y = rho * np.sin(phi)
703
- return x, y
704
-
705
-
706
- def xy2polar(points, rounding=1e-3):
707
- """
708
- Conversion from cartesian to polar coordinates
709
-
710
- the angles and distances are sorted by r and then phi
711
- The indices of this sort is also returned
712
-
713
- input points: numpy array with number of points in axis 0 first two elements in axis 1 are x and y
714
-
715
- optional rounding in significant digits
716
-
717
- returns r,phi, sorted_indices
718
- """
719
-
720
- r, phi = cart2pol(points)
721
-
722
- phi = phi-phi.min() # only positive angles
723
- r = (np.floor(r/rounding))*rounding # Remove rounding error differences
724
-
725
- sorted_indices = np.lexsort((phi, r)) # sort first by r and then by phi
726
- r = r[sorted_indices]
727
- phi = phi[sorted_indices]
728
-
729
- return r, phi, sorted_indices
730
-
731
-
732
- def cartesian2polar(x, y, grid, r, t, order=3):
733
-
734
- rr, tt = np.meshgrid(r, t)
735
-
736
- new_x = rr*np.cos(tt)
737
- new_y = rr*np.sin(tt)
738
-
739
- ix = interp1d(x, np.arange(len(x)))
740
- iy = interp1d(y, np.arange(len(y)))
741
-
742
- new_ix = ix(new_x.ravel())
743
- new_iy = iy(new_y.ravel())
744
-
745
- return ndimage.map_coordinates(grid, np.array([new_ix, new_iy]), order=order).reshape(new_x.shape)
746
-
747
-
748
- def warp(diff, center):
749
- # Define original polar grid
750
- nx = diff.shape[0]
751
- ny = diff.shape[1]
752
-
753
- x = np.linspace(1, nx, nx, endpoint=True)-center[1]
754
- y = np.linspace(1, ny, ny, endpoint=True)-center[0]
755
- z = np.abs(diff)
756
-
757
- # Define new polar grid
758
- nr = min([center[0], center[1], diff.shape[0]-center[0], diff.shape[1]-center[1]])-1
759
- nt = 360*3
760
-
761
- r = np.linspace(1, nr, nr)
762
- t = np.linspace(0., np.pi, nt, endpoint=False)
763
- return cartesian2polar(x, y, z, r, t, order=3).T
764
-
765
-
766
- def calculate_ctf(wavelength, cs, defocus, k):
767
- """ Calculate Contrast Transfer Function
768
- everything in nm
769
- """
770
- ctf = np.sin(np.pi*defocus*wavelength*k**2+0.5*np.pi*cs*wavelength**3*k**4)
771
- return ctf
772
-
773
-
774
- def calculate_scherzer(wavelength, cs):
775
- """
776
- Calculate the Scherzer defocus. Cs is in mm, lambda is in nm
777
- # EInput and output in nm
778
- """
779
- scherzer = -1.155*(cs*wavelength)**0.5 # in m
780
- return scherzer
781
-
782
-
783
- def calibrate_image_scale(fft_tags, spots_reference, spots_experiment):
784
- gx = fft_tags['spatial_scale_x']
785
- gy = fft_tags['spatial_scale_y']
786
-
787
- dist_reference = np.linalg.norm(spots_reference, axis=1)
788
- distance_experiment = np.linalg.norm(spots_experiment, axis=1)
789
-
790
- first_reflections = abs(distance_experiment - dist_reference.min()) < .2
791
- print('Evaluate ', first_reflections.sum(), 'reflections')
792
- closest_exp_reflections = spots_experiment[first_reflections]
793
-
794
- def func(params, xdata, ydata):
795
- dgx, dgy = params
796
- return np.sqrt((xdata * dgx) ** 2 + (ydata * dgy) ** 2) - dist_reference.min()
797
-
798
- x0 = [1.001, 0.999]
799
- dg, sig = optimization.leastsq(func, x0, args=(closest_exp_reflections[:, 0], closest_exp_reflections[:, 1]))
800
- return dg
801
-
802
-
803
- def align_crystal_reflections(spots, crystals):
804
- crystal_reflections_polar = []
805
- angles = []
806
- exp_r, exp_phi = cart2pol(spots) # just in polar coordinates
807
- spots_polar = np.array([exp_r, exp_phi])
808
-
809
- for i in range(len(crystals)):
810
- tags = crystals[i]
811
- r, phi, indices = xy2polar(tags['allowed']['g']) # sorted by r and phi , only positive angles
812
- # we mask the experimental values that are found already
813
- angle = 0.
814
-
815
- angle_i = np.argmin(np.abs(exp_r - r[1]))
816
- angle = exp_phi[angle_i] - phi[0]
817
- angles.append(angle) # Determine rotation angle
818
-
819
- crystal_reflections_polar.append([r, angle + phi, indices])
820
- tags['allowed']['g_rotated'] = pol2cart(r, angle + phi)
821
- for spot in tags['allowed']['g']:
822
- dif = np.linalg.norm(spots[:, 0:2]-spot[0:2], axis=1)
823
- # print(dif.min())
824
- if dif.min() < 1.5:
825
- ind = np.argmin(dif)
826
-
827
- return crystal_reflections_polar, angles
828
-
829
- # Deconvolution
830
def decon_lr(o_image, probe, tags, verbose=False):
    """
    # This task generates a restored image from an input image and point spread function (PSF) using
    # the algorithm developed independently by Lucy (1974, Astron. J. 79, 745) and Richardson
    # (1972, J. Opt. Soc. Am. 62, 55) and adapted for HST imagery by Snyder
    # (1990, in Restoration of HST Images and Spectra, ST ScI Workshop Proceedings; see also
    # Snyder, Hammoud, & White, JOSA, v. 10, no. 5, May 1993, in press).
    # Additional options developed by Rick White (STScI) are also included.
    #
    # The Lucy-Richardson method can be derived from the maximum likelihood expression for data
    # with a Poisson noise distribution. Thus, it naturally applies to optical imaging data such as HST.
    # The method forces the restored image to be positive, in accord with photon-counting statistics.
    #
    # The Lucy-Richardson algorithm generates a restored image through an iterative method. The essence
    # of the iteration is as follows: the (n+1)th estimate of the restored image is given by the nth estimate
    # of the restored image multiplied by a correction image. That is,
    #
    #                            original data
    #       image    = image    ---------------  * reflect(PSF)
    #            n+1        n     image * PSF
    #                                  n

    # where the *'s represent convolution operators and reflect(PSF) is the reflection of the PSF, i.e.
    # reflect((PSF)(x,y)) = PSF(-x,-y). When the convolutions are carried out using fast Fourier transforms
    # (FFTs), one can use the fact that FFT(reflect(PSF)) = conj(FFT(PSF)), where conj is the complex conjugate
    # operator.
    """
    # Parameters: o_image - recorded image (2d array); probe - PSF, same shape;
    # tags - dict with 'ImageScanned' or 'aberrations' sub-dict (needs 'EHT' in
    # volt, optional 'convAngle' in mrad) plus 'pixel_size' and 'fov';
    # verbose - print convergence information per iteration.
    # Returns the aperture-filtered, deconvolved real-valued image.

    # degenerate input: nothing to deconvolve
    if len(o_image) < 1:
        return o_image

    if o_image.shape != probe.shape:
        print('Weirdness ', o_image.shape, ' != ', probe.shape)

    # complex copies; only the real parts carry data
    probe_c = np.ones(probe.shape, dtype=np.complex64)
    probe_c.real = probe

    error = np.ones(o_image.shape, dtype=np.complex64)
    est = np.ones(o_image.shape, dtype=np.complex64)
    source = np.ones(o_image.shape, dtype=np.complex64)
    source.real = o_image

    response_ft = fftpack.fft2(probe_c)

    # NOTE(review): if neither key is present, 'ab' is undefined below (NameError)
    if 'ImageScanned' in tags:
        ab = tags['ImageScanned']
    elif 'aberrations' in tags:
        ab = tags['aberrations']
    if 'convAngle' not in ab:
        ab['convAngle'] = 30
    ap_angle = ab['convAngle'] / 1000.0  # mrad -> rad

    e0 = float(ab['EHT'])

    wl = get_wavelength(e0)
    ab['wavelength'] = wl

    # aperture diameter in reciprocal space
    over_d = 2 * ap_angle / wl

    dx = tags['pixel_size']
    dk = 1.0 / float(tags['fov'])
    screen_width = 1 / dx

    aperture = np.ones(o_image.shape, dtype=np.complex64)
    # Mask for the aperture before the Fourier transform
    n = o_image.shape[0]
    size_x = o_image.shape[0]
    size_y = o_image.shape[1]
    app_ratio = over_d / screen_width * n  # aperture radius in pixels

    theta_x = np.array(-size_x / 2. + np.arange(size_x))
    theta_y = np.array(-size_y / 2. + np.arange(size_y))
    t_xv, t_yv = np.meshgrid(theta_x, theta_y)

    # zero everything outside the aperture disk
    tp1 = t_xv ** 2 + t_yv ** 2 >= app_ratio ** 2
    aperture[tp1.T] = 0.
    print(app_ratio, screen_width, dk)

    if QT_available:
        progress = ProgressDialog("Lucy-Richardson", 100)

    # de = 100
    dest = 100
    i = 0
    # iterate until the relative change of the estimate is < 1e-4 percent
    while abs(dest) > 0.0001:  # or abs(de) > .025:
        i += 1
        error_old = np.sum(error.real)
        est_old = est.copy()
        # correction factor: data / (estimate convolved with PSF)
        error = source / np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est) * response_ft)))
        # multiply estimate by correction convolved with reflected PSF
        # (conj of the PSF transform)
        est = est * np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(error) * np.conjugate(response_ft))))
        # est = est_old * est
        # est = np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est)*fftpack.fftshift(aperture) )))

        error_new = np.real(np.sum(np.power(error, 2))) - error_old
        # percentage change of the estimate
        dest = np.sum(np.power((est - est_old).real, 2)) / np.sum(est) * 100
        # print(np.sum((est.real - est_old.real)* (est.real - est_old.real) )/np.sum(est.real)*100 )

        if error_old != 0:
            de = error_new / error_old * 1.0
        else:
            de = error_new

        if verbose:
            print(
                ' LR Deconvolution - Iteration: {0:d} Error: {1:.2f} = change: {2:.5f}%, {3:.5f}%'.format(i, error_new,
                                                                                                          de,
                                                                                                          abs(dest)))

        if QT_available:
            count = (0.1 - abs(dest)) * 1000.
            if count < 0:
                count = 0
            progress.set_value(count)

        # hard iteration cap to guarantee termination
        if i > 1000:
            dest = 0.0
            print('terminate')
    if QT_available:
        progress.close()
    print('\n Lucy-Richardson deconvolution converged in ' + str(i) + ' Iterations')
    # apply the aperture as a low-pass filter to the final estimate
    est2 = np.real(fftpack.ifft2(fftpack.fft2(est) * fftpack.fftshift(aperture)))
    # plt.imshow(np.real(np.log10(np.abs(fftpack.fftshift(fftpack.fft2(est)))+1)+aperture), origin='lower',)
    # plt.show()
    print(est2.shape)
    return est2
956
-
1
+ """
2
+ image_tools.py
3
+ by Gerd Duscher, UTK
4
+ part of pyTEMlib
5
+ MIT license except where stated differently
6
+ """
7
+
8
+ import numpy as np
9
+ import matplotlib
10
+ import matplotlib as mpl
11
+ import matplotlib.pylab as plt
12
+ import matplotlib.widgets as mwidgets
13
+ # from matplotlib.widgets import RectangleSelector
14
+
15
+ import sidpy
16
+ import pyTEMlib.file_tools as ft
17
+ import pyTEMlib.sidpy_tools
18
+ # import pyTEMlib.probe_tools
19
+
20
+ from tqdm.auto import trange, tqdm
21
+
22
+ # import itertools
23
+ from itertools import product
24
+
25
+ from scipy import fftpack
26
+ import scipy
27
+ # from scipy import signal
28
+ from scipy.interpolate import interp1d # , interp2d
29
+ import scipy.optimize as optimization
30
+
31
+ # Multidimensional Image library
32
+ import scipy.ndimage as ndimage
33
+ import scipy.constants as const
34
+
35
+ # from scipy.spatial import Voronoi, KDTree, cKDTree
36
+
37
+ import skimage
38
+
39
+ import skimage.registration as registration
40
+ # from skimage.feature import register_translation # blob_dog, blob_doh
41
+ from skimage.feature import peak_local_max
42
+ # from skimage.measure import points_in_poly
43
+
44
+ # our blob detectors from the scipy image package
45
+ from skimage.feature import blob_log # blob_dog, blob_doh
46
+
47
+ from sklearn.feature_extraction import image
48
+ from sklearn.utils.extmath import randomized_svd
49
+ from sklearn.cluster import DBSCAN
50
+
51
+ from collections import Counter
52
+
53
+ # center diff function
54
+ from skimage.filters import threshold_otsu, sobel
55
+ from scipy.optimize import leastsq
56
+ from sklearn.cluster import DBSCAN
57
+
58
+
59
+ _SimpleITK_present = True
60
+ try:
61
+ import SimpleITK as sitk
62
+ except ImportError:
63
+ sitk = False
64
+ _SimpleITK_present = False
65
+
66
+ if not _SimpleITK_present:
67
+ print('SimpleITK not installed; Registration Functions for Image Stacks not available\n' +
68
+ 'install with: conda install -c simpleitk simpleitk ')
69
+
70
+
71
+ # Wavelength in 1/nm
72
def get_wavelength(e0):
    """
    Calculate the relativistically corrected de Broglie wavelength of an electron.

    Note: despite the historic comment above this function, the value returned
    is a length in nm (the expression is h / p in metre, scaled by 1e9),
    not a reciprocal length.

    Parameters
    ----------
    e0: float
        acceleration voltage in volt

    Returns
    -------
    float
        wavelength in nm

    Raises
    ------
    ValueError
        if e0 is not positive (the original silently produced inf/NaN)
    """
    if e0 <= 0:
        raise ValueError('acceleration voltage e0 must be positive')

    eV = const.e * e0  # kinetic energy in joule
    # relativistic de Broglie wavelength: h / sqrt(2 m0 E (1 + E/(2 m0 c^2))),
    # converted from metre to nanometre
    return const.h / np.sqrt(2 * const.m_e * eV * (1 + eV / (2 * const.m_e * const.c ** 2))) * 10 ** 9
88
+
89
+
90
def fourier_transform(dset: sidpy.Dataset) -> sidpy.Dataset:
    """
    Reads information into dictionary 'tags', performs 'FFT', and provides a smoothed FT and reciprocal
    and intensity limits for visualization.

    For an IMAGE_STACK the frames are summed first; for an IMAGE the data is
    used directly; for any other data type the function returns None.

    Parameters
    ----------
    dset: sidpy.Dataset
        image

    Returns
    -------
    fft_dset: sidpy.Dataset
        Fourier transform with correct dimensions

    Example
    -------
    >>> fft_dataset = fourier_transform(sidpy_dataset)
    >>> fft_dataset.plot()
    """

    assert isinstance(dset, sidpy.Dataset), 'Expected a sidpy Dataset'

    selection = []
    image_dim = []
    # image_dim = get_image_dims(sidpy.DimensionTypes.SPATIAL)

    if dset.data_type == sidpy.DataType.IMAGE_STACK:
        image_dim = dset.get_image_dims()
        stack_dim = dset.get_dimensions_by_type('TEMPORAL')

        if len(image_dim) != 2:
            raise ValueError('need at least two SPATIAL dimension for an image stack')

        # build a per-axis selection: whole axes for image/stack dims,
        # first element only for everything else
        for i in range(dset.ndim):
            if i in image_dim:
                selection.append(slice(None))
                # NOTE(review): when no TEMPORAL dimension was found, the first
                # image dim is re-used as stack dim and a second slice is
                # appended for the same axis - verify this is intended.
                if len(stack_dim) == 0:
                    stack_dim = i
                    selection.append(slice(None))
            elif i in stack_dim:
                stack_dim = i
                selection.append(slice(None))
            else:
                selection.append(slice(0, 1))

        image_stack = np.squeeze(np.array(dset)[selection])
        # collapse the stack to a single image by summation
        new_image = np.sum(np.array(image_stack), axis=stack_dim)
    elif dset.data_type == sidpy.DataType.IMAGE:
        new_image = np.array(dset)
    else:
        # unsupported data type: returns None implicitly
        return

    # shift minimum to zero before transforming
    new_image = new_image - new_image.min()
    fft_transform = (np.fft.fftshift(np.fft.fft2(new_image)))

    image_dims = pyTEMlib.sidpy_tools.get_image_dims(dset)

    units_x = '1/' + dset._axes[image_dims[0]].units
    units_y = '1/' + dset._axes[image_dims[1]].units

    fft_dset = sidpy.Dataset.from_array(fft_transform)
    fft_dset.quantity = dset.quantity
    fft_dset.units = 'a.u.'
    fft_dset.data_type = 'IMAGE'
    fft_dset.source = dset.title
    fft_dset.modality = 'fft'

    # reciprocal-space dimensions derived from the real-space sampling
    fft_dset.set_dimension(0, sidpy.Dimension(np.fft.fftshift(np.fft.fftfreq(new_image.shape[0],
                                                                             d=ft.get_slope(dset.x.values))),
                                              name='u', units=units_x, dimension_type='RECIPROCAL',
                                              quantity='reciprocal_length'))
    fft_dset.set_dimension(1, sidpy.Dimension(np.fft.fftshift(np.fft.fftfreq(new_image.shape[1],
                                                                             d=ft.get_slope(dset.y.values))),
                                              name='v', units=units_y, dimension_type='RECIPROCAL',
                                              quantity='reciprocal_length'))

    return fft_dset
169
+
170
+
171
def power_spectrum(dset, smoothing=3):
    """
    Calculate the power spectrum of an image.

    The magnitude of the Fourier transform is Gaussian smoothed and
    log-scaled; sensible display limits are derived from the intensities in
    an annulus (radius 1 to 11 reciprocal units) around the central beam and
    stored in the metadata.

    Parameters
    ----------
    dset: sidpy.Dataset
        image
    smoothing: int
        sigma of the Gaussian smoothing of the FT magnitude

    Returns
    -------
    power_spec: sidpy.Dataset
        power spectrum with correct dimensions
    """
    fft_dset = fourier_transform(dset)  # dset.fft()
    smoothed_magnitude = ndimage.gaussian_filter(np.abs(fft_dset),
                                                 sigma=(smoothing, smoothing), order=0)
    power_spec = fft_dset.like_data(np.log(1. + smoothed_magnitude))

    # annulus mask between radius 1 and 11 used to gauge display intensities
    uu, vv = np.meshgrid(power_spec.v.values, power_spec.u.values)
    radius_sq = uu ** 2 + vv ** 2
    annulus = np.logical_and(radius_sq > 1 ** 2, radius_sq < 11 ** 2)

    ring_values = np.array(power_spec)[annulus]
    power_spec.metadata = {'fft': {'smoothing': smoothing,
                                   'minimum_intensity': ring_values.min() * 0.95,
                                   'maximum_intensity': ring_values.max() * 1.05}}
    power_spec.title = 'power spectrum ' + power_spec.source

    return power_spec
213
+
214
+
215
def diffractogram_spots(dset, spot_threshold, return_center=True, eps=0.1):
    """Find spots in diffractogram and sort them by distance from center

    Uses blob_log from skimage.feature, with peak_local_max as a fallback
    for images where blob_log fails.

    Parameters
    ----------
    dset: sidpy.Dataset
        diffractogram
    spot_threshold: float
        threshold for blob finder
    return_center: bool, optional
        if True, estimate the pattern center from midpoint clustering
    eps: float, optional
        DBSCAN eps for the midpoint clustering

    Returns
    -------
    spots: numpy array
        sorted positions (x, y) and azimuthal angle of all spots
    center: list or numpy array (2,)
        estimated center, [0, 0] if return_center is False
    """

    # spot detection (for future reference there is no symmetry assumed here)
    data = np.array(np.log(1+np.abs(dset)))
    data = data - data.min()
    data = data/data.max()
    # some images are strange and blob_log does not work on the power spectrum
    try:
        spots_random = blob_log(data, max_sigma=5, threshold=spot_threshold)
    except ValueError:
        spots_random = peak_local_max(np.array(data.T), min_distance=3, threshold_rel=spot_threshold)
        # BUG FIX: np.hstack takes a single sequence of arrays; the original
        # np.hstack(a, b) two-argument call raised a TypeError. The appended
        # zero column makes the fallback result (N, 3) like blob_log's output.
        spots_random = np.hstack((spots_random, np.zeros((spots_random.shape[0], 1))))

    print(f'Found {spots_random.shape[0]} reflections')

    # Needed for conversion from pixel to Reciprocal space
    rec_scale = np.array([ft.get_slope(dset.u.values), ft.get_slope(dset.v.values)])
    spots_random[:, :2] = spots_random[:, :2]*rec_scale+[dset.u.values[0], dset.v.values[0]]
    # sort reflections by distance from the origin
    spots_random[:, 2] = np.linalg.norm(spots_random[:, 0:2], axis=1)
    spots_index = np.argsort(spots_random[:, 2])
    spots = spots_random[spots_index]
    # third column now holds the azimuthal angle of each spot
    spots[:, 2] = np.arctan2(spots[:, 0], spots[:, 1])

    center = [0, 0]

    if return_center:
        points = spots[:, 0:2]

        # Calculate the midpoints between all pairs of points; for a
        # centro-symmetric pattern these pile up at the pattern center
        reshaped_points = points[:, np.newaxis, :]
        midpoints = (reshaped_points + reshaped_points.transpose(1, 0, 2)) / 2.0
        midpoints = midpoints.reshape(-1, 2)

        # Find the most dense cluster of midpoints
        dbscan = DBSCAN(eps=eps, min_samples=2)
        labels = dbscan.fit_predict(midpoints)
        cluster_counter = Counter(labels)
        largest_cluster_label = max(cluster_counter, key=cluster_counter.get)
        largest_cluster_points = midpoints[labels == largest_cluster_label]

        # Average of these midpoints must be the center
        center = np.mean(largest_cluster_points, axis=0)

    return spots, center
281
+
282
+
283
def center_diffractogram(dset, return_plot = True, histogram_factor = None, smoothing = 1, min_samples = 100):
    """Find the center of a ring diffraction pattern.

    Otsu-thresholds the pattern, smooths and re-thresholds, detects the ring
    edge with a Sobel filter, clusters edge pixels with DBSCAN, and fits a
    circle to the largest cluster by least squares.

    Parameters
    ----------
    dset: sidpy.Dataset
        diffraction pattern
    return_plot: bool
        if True, plot the intermediate analysis steps
    histogram_factor: float or None
        if given, scale the histogram handed to the Otsu threshold
    smoothing: float
        sigma of the Gaussian smoothing of the binary image
    min_samples: int
        DBSCAN min_samples for edge-point clustering

    Returns
    -------
    center: numpy array (2,)
        fitted center position in pixels

    Notes
    -----
    NOTE(review): try/finally without except - if an early step raises, the
    finally block references names that were never bound (binary, center, ...)
    and a NameError masks the original exception. Verify intended behavior.
    """
    try:
        # float16 saves memory but costs precision - TODO confirm acceptable
        diff = np.array(dset).T.astype(np.float16)
        diff[diff < 0] = 0

        if histogram_factor is not None:
            hist, bins = np.histogram(np.ravel(diff), bins=256, range=(0, 1), density=True)
            threshold = threshold_otsu(diff, hist = hist * histogram_factor)
        else:
            threshold = threshold_otsu(diff)
        binary = (diff > threshold).astype(float)
        smoothed_image = ndimage.gaussian_filter(binary, sigma=smoothing)  # Smooth before edge detection
        smooth_threshold = threshold_otsu(smoothed_image)
        smooth_binary = (smoothed_image > smooth_threshold).astype(float)
        # Find the edges using the Sobel operator
        edges = sobel(smooth_binary)
        edge_points = np.argwhere(edges)

        # Use DBSCAN to cluster the edge points
        db = DBSCAN(eps=10, min_samples=min_samples).fit(edge_points)
        labels = db.labels_
        if len(set(labels)) == 1:
            raise ValueError("DBSCAN clustering resulted in only one group, check the parameters.")

        # Get the largest group of edge points
        unique, counts = np.unique(labels, return_counts=True)
        counts = dict(zip(unique, counts))
        largest_group = max(counts, key=counts.get)
        edge_points = edge_points[labels == largest_group]

        # Fit a circle to the diffraction ring
        def calc_distance(c, x, y):
            # radial residuals of points (x, y) about candidate center c
            Ri = np.sqrt((x - c[0])**2 + (y - c[1])**2)
            return Ri - Ri.mean()
        x_m = np.mean(edge_points[:, 1])
        y_m = np.mean(edge_points[:, 0])
        center_guess = x_m, y_m
        center, ier = leastsq(calc_distance, center_guess, args=(edge_points[:, 1], edge_points[:, 0]))
        # NOTE(review): the first term is ~0 by construction (mean of residuals);
        # mean_radius is effectively the mean distance to the fitted center
        mean_radius = np.mean(calc_distance(center, edge_points[:, 1], edge_points[:, 0])) + np.sqrt((edge_points[:, 1] - center[0])**2 + (edge_points[:, 0] - center[1])**2).mean()

    finally:
        if return_plot:
            fig, ax = plt.subplots(1, 4, figsize=(10, 4))
            ax[0].set_title('Diffractogram')
            ax[0].imshow(dset.T, cmap='viridis')
            ax[1].set_title('Otsu Binary Image')
            ax[1].imshow(binary, cmap='gray')
            ax[2].set_title('Smoothed Binary Image')
            ax[2].imshow(smooth_binary, cmap='gray')
            ax[3].set_title('Edge Detection and Fitting')
            ax[3].imshow(edges, cmap='gray')
            ax[3].scatter(center[0], center[1], c='r', s=10)
            circle = plt.Circle(center, mean_radius, color='red', fill=False)
            ax[3].add_artist(circle)
            for axis in ax:
                axis.axis('off')
            fig.tight_layout()

    return center
342
+
343
+
344
def adaptive_fourier_filter(dset, spots, low_pass=3, reflection_radius=0.3):
    """
    Fourier filter an image using the spots found in its diffractogram.

    A mask is built that keeps small disks around every given reflection plus
    a low-pass disk around the origin; everything else is zeroed before the
    inverse transform.

    Parameters
    ----------
    dset: sidpy.Dataset
        image to be filtered
    spots: np.ndarray(N, 2)
        sorted spots in diffractogram in 1/nm
    low_pass: float
        radius of the low-pass disk at the center of the diffractogram in 1/nm
    reflection_radius: float
        radius of the kept disks around each reflection in 1/nm

    Returns
    -------
    sidpy.Dataset
        Fourier filtered image
    """
    if not isinstance(dset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    fft_dset = fourier_transform(dset)

    # reciprocal-space coordinate grids
    uu, vv = np.meshgrid(fft_dset.v.values, fft_dset.u.values)
    mask = np.zeros(dset.shape)

    # keep a disk around every reflection
    for reflection in spots:
        mask += (uu - reflection[1]) ** 2 + (vv - reflection[0]) ** 2 < reflection_radius ** 2

    # keep the central region as well (low-pass = intensity variations)
    mask += uu ** 2 + vv ** 2 < low_pass ** 2
    mask[mask > 1] = 1  # clip overlapping disks

    masked_fft = np.array(fft_dset * mask)
    filtered_image = dset.like_data(np.fft.ifft2(np.fft.fftshift(masked_fft)).real)
    filtered_image.title = 'Fourier filtered ' + dset.title
    filtered_image.source = dset.title
    filtered_image.metadata = {'analysis': 'adaptive fourier filtered', 'spots': spots,
                               'low_pass': low_pass, 'reflection_radius': reflection_radius}
    return filtered_image
389
+
390
+
391
def rotational_symmetry_diffractogram(spots):
    """Test rotational symmetry of diffraction spots.

    For each candidate n-fold symmetry (n = 2, 3, 4, 6) the spots are rotated
    by 2*pi/n; the symmetry is accepted when roughly 70% of the rotated spots
    fall within 0.2 (reciprocal units) of an original spot.

    Parameters
    ----------
    spots: numpy array (N, 3)
        spot positions; only the first two columns (x, y) are compared

    Returns
    -------
    list of int
        the n-fold symmetries found
    """
    found_folds = []
    for fold in [2, 3, 4, 6]:
        theta = 2 * np.pi / fold
        rotation = np.array([[np.cos(theta), np.sin(theta), 0],
                             [-np.sin(theta), np.cos(theta), 0],
                             [0, 0, 1]])
        rotated = np.dot(spots, rotation)

        # all pairwise distances between rotated and original spots, ascending
        distances = np.array(sorted(np.linalg.norm(p - q)
                                    for p, q in product(rotated[:, 0:2], spots[:, 0:2])))

        # ~70% of spots must have a close partner after rotation
        if distances[int(spots.shape[0] * .7)] < 0.2:
            found_folds.append(fold)
    return found_folds
408
+
409
+ #####################################################
410
+ # Registration Functions
411
+ #####################################################
412
+
413
+
414
def complete_registration(main_dataset, storage_channel=None):
    """Rigid followed by non-rigid (demon) registration of an image stack.

    See the individual routines for details:
    - rigid_registration
    - demon_registration

    Parameters
    ----------
    main_dataset: sidpy.Dataset
        dataset of data_type 'IMAGE_STACK' to be registered
    storage_channel: h5py.Group
        optional - location in hdf5 file to store datasets
        (currently not used by this function)

    Returns
    -------
    non_rigid: sidpy.Dataset
        demon-registered stack
    rigid: sidpy.Dataset
        rigidly registered stack
    """
    if not isinstance(main_dataset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    if main_dataset.data_type.name != 'IMAGE_STACK':
        raise TypeError('Registration makes only sense for an image stack')

    print('Rigid_Registration')
    rigid = rigid_registration(main_dataset)

    print('Non-Rigid_Registration')
    non_rigid = demon_registration(rigid)

    return non_rigid, rigid
449
+
450
+
451
def demon_registration(dataset, verbose=False):
    """
    Diffeomorphic Demon Non-Rigid Registration

    Depends on:
        simpleITK and numpy
    Please Cite: http://www.simpleitk.org/SimpleITK/project/parti.html
    and T. Vercauteren, X. Pennec, A. Perchant and N. Ayache
    Diffeomorphic Demons Using ITK\'s Finite Difference Solver Hierarchy
    The Insight Journal, http://hdl.handle.net/1926/510 2007

    Parameters
    ----------
    dataset: sidpy.Dataset
        stack of image after rigid registration and cropping
    verbose: boolean
        optional for increased output
    Returns
    -------
    dem_reg: stack of images with non-rigid registration

    Example
    -------
    dem_reg = demon_reg(stack_dataset, verbose=False)
    """

    if not isinstance(dataset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    if dataset.data_type.name != 'IMAGE_STACK':
        raise TypeError('Registration makes only sense for an image stack')

    dem_reg = np.zeros(dataset.shape)
    nimages = dataset.shape[0]
    if verbose:
        print(nimages)
    # create fixed image by summing over rigid registration

    # average frame serves as the fixed (reference) image
    fixed_np = np.average(np.array(dataset), axis=0)

    # NOTE(review): when SimpleITK is missing this only prints a warning and
    # then fails on the next line (sitk is False) - verify intended behavior.
    if not _SimpleITK_present:
        print('This feature is not available: \n Please install simpleITK with: conda install simpleitk -c simpleitk')

    fixed = sitk.GetImageFromArray(fixed_np)
    fixed = sitk.DiscreteGaussian(fixed, 2.0)  # denoise the reference

    # demons = sitk.SymmetricForcesDemonsRegistrationFilter()
    demons = sitk.DiffeomorphicDemonsRegistrationFilter()

    demons.SetNumberOfIterations(200)
    demons.SetStandardDeviations(1.0)

    # resampler applies the per-frame displacement field to the raw frame
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(0)

    for i in trange(nimages):

        moving = sitk.GetImageFromArray(dataset[i])
        # register the smoothed frame, but resample the original one
        moving_f = sitk.DiscreteGaussian(moving, 2.0)
        displacement_field = demons.Execute(fixed, moving_f)
        out_tx = sitk.DisplacementFieldTransform(displacement_field)
        resampler.SetTransform(out_tx)
        out = resampler.Execute(moving)
        dem_reg[i, :, :] = sitk.GetArrayFromImage(out)

    print(':-)')
    print('You have successfully completed Diffeomorphic Demons Registration')

    demon_registered = dataset.like_data(dem_reg)
    demon_registered.title = 'Non-Rigid Registration'
    demon_registered.source = dataset.title

    # propagate provenance from the rigid-registration step
    demon_registered.metadata = {'analysis': 'non-rigid demon registration'}
    if 'input_crop' in dataset.metadata:
        demon_registered.metadata['input_crop'] = dataset.metadata['input_crop']
    if 'input_shape' in dataset.metadata:
        demon_registered.metadata['input_shape'] = dataset.metadata['input_shape']
    demon_registered.metadata['input_dataset'] = dataset.source
    return demon_registered
531
+
532
+
533
+ ###############################
534
+ # Rigid Registration New 05/09/2020
535
+
536
def rigid_registration(dataset, sub_pixel=True):
    """
    Rigid registration of image stack with pixel accuracy

    Uses simple cross_correlation
    (we determine drift from one image to next)

    Parameters
    ----------
    dataset: sidpy.Dataset
        sidpy dataset with image_stack dataset
    sub_pixel: bool
        if True use sub-pixel phase cross-correlation (upsampled 1000x),
        otherwise an integer-pixel correlation-maximum search

    Returns
    -------
    rigid_registered: sidpy.Dataset
        Registered Stack and drift (with respect to center image)
    """

    if not isinstance(dataset, sidpy.Dataset):
        raise TypeError('We need a sidpy.Dataset')
    if dataset.data_type.name != 'IMAGE_STACK':
        raise TypeError('Registration makes only sense for an image stack')

    frame_dim = []
    spatial_dim = []
    selection = []

    # classify axes: spatial axes are taken whole, every other axis is
    # treated as the frame axis and sliced to one element at a time
    for i, axis in dataset._axes.items():
        if axis.dimension_type.name == 'SPATIAL':
            spatial_dim.append(i)
            selection.append(slice(None))
        else:
            frame_dim.append(i)
            selection.append(slice(0, 1))

    if len(spatial_dim) != 2:
        print('need two spatial dimensions')
    if len(frame_dim) != 1:
        print('need one frame dimensions')

    nopix = dataset.shape[spatial_dim[0]]
    nopiy = dataset.shape[spatial_dim[1]]
    nimages = dataset.shape[frame_dim[0]]

    print('Stack contains ', nimages, ' images, each with', nopix, ' pixels in x-direction and ', nopiy,
          ' pixels in y-direction')

    # first frame is the initial reference; .compute() materializes the dask slice
    fixed = dataset[tuple(selection)].squeeze().compute()
    fft_fixed = np.fft.fft2(fixed)

    relative_drift = [[0., 0.]]

    for i in trange(nimages):
        selection[frame_dim[0]] = slice(i, i+1)
        moving = dataset[tuple(selection)].squeeze().compute()
        fft_moving = np.fft.fft2(moving)
        if sub_pixel:
            shift = skimage.registration.phase_cross_correlation(fft_fixed, fft_moving, upsample_factor=1000,
                                                                 space='fourier')[0]
        else:
            # integer-pixel shift from the cross-correlation maximum
            image_product = fft_fixed * fft_moving.conj()
            cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
            shift = np.array(ndimage.maximum_position(cc_image.real))-cc_image.shape[0]/2
        # drift is measured image-to-image, not relative to frame 0
        fft_fixed = fft_moving
        relative_drift.append(shift)
    # NOTE(review): relative_drift holds nimages+1 entries (seed [0, 0] plus
    # one per frame, the first being a self-comparison) - confirm this matches
    # what rig_reg_drift expects.
    rig_reg, drift = rig_reg_drift(dataset, relative_drift)
    crop_reg, input_crop = crop_image_stack(rig_reg, drift)

    rigid_registered = sidpy.Dataset.from_array(crop_reg,
                                                title='Rigid Registration',
                                                data_type='IMAGE_STACK',
                                                quantity=dataset.quantity,
                                                units=dataset.units)
    rigid_registered.title = 'Rigid_Registration'
    rigid_registered.source = dataset.title
    rigid_registered.metadata = {'analysis': 'rigid sub-pixel registration', 'drift': drift,
                                 'input_crop': input_crop, 'input_shape': dataset.shape[1:]}
    rigid_registered.set_dimension(0, sidpy.Dimension(np.arange(rigid_registered.shape[0]),
                                                      name='frame', units='frame', quantity='time',
                                                      dimension_type='temporal'))

    # spatial axes are cropped to the common field of view
    array_x = dataset._axes[spatial_dim[0]][input_crop[0]:input_crop[1]].values
    rigid_registered.set_dimension(1, sidpy.Dimension(array_x,
                                                      'x', units='nm', quantity='Length',
                                                      dimension_type='spatial'))
    array_y = dataset._axes[spatial_dim[1]][input_crop[2]:input_crop[3]].values
    rigid_registered.set_dimension(2, sidpy.Dimension(array_y,
                                                      'y', units='nm', quantity='Length',
                                                      dimension_type='spatial'))
    return rigid_registered.rechunk({0: 'auto', 1: -1, 2: -1})
626
+
627
+
628
def rig_reg_drift(dset, rel_drift):
    """ Shifting images on top of each other

    Uses relative drift to shift images on top of each other,
    with center image as reference.
    Shifting is done with shift routine of ndimage from scipy.
    This function is used by rigid_registration routine

    Parameters
    ----------
    dset: sidpy.Dataset
        dataset with image_stack
    rel_drift:
        relative_drift from image to image as list of [shiftx, shifty]

    Returns
    -------
    stack: numpy array
    drift: list of drift in pixel
    """

    frame_dim = []
    spatial_dim = []
    selection = []

    # classify axes exactly as in rigid_registration
    for i, axis in dset._axes.items():
        if axis.dimension_type.name == 'SPATIAL':
            spatial_dim.append(i)
            selection.append(slice(None))
        else:
            frame_dim.append(i)
            selection.append(slice(0, 1))

    if len(spatial_dim) != 2:
        print('need two spatial dimensions')
    if len(frame_dim) != 1:
        print('need one frame dimensions')

    rig_reg = np.zeros([dset.shape[frame_dim[0]], dset.shape[spatial_dim[0]], dset.shape[spatial_dim[1]]])

    # absolute drift
    print(rel_drift)
    # NOTE(review): drift inherits len(rel_drift) rows, which is one more than
    # the number of frames when called from rigid_registration; the extra last
    # row is never used below - verify.
    drift = np.array(rel_drift).copy()

    # accumulate relative shifts into absolute drift per frame
    drift[0] = [0, 0]
    for i in range(1, drift.shape[0]):
        drift[i] = drift[i - 1] + rel_drift[i]
    # re-reference all drifts to the center frame
    center_drift = drift[int(drift.shape[0] / 2)]
    drift = drift - center_drift
    # Shift images
    for i in range(rig_reg.shape[0]):
        selection[frame_dim[0]] = slice(i, i+1)
        # Now we shift (cubic spline interpolation)
        rig_reg[i, :, :] = ndimage.shift(dset[tuple(selection)].squeeze().compute(),
                                         [drift[i, 0], drift[i, 1]], order=3)
    return rig_reg, drift
684
+
685
+
686
def crop_image_stack(rig_reg, drift):
    """Crop all images of a stack to the area common to every frame.

    Given the per-frame drift, the largest rectangle that is covered by all
    shifted frames is determined and the stack is cropped to it.
    This function is used by the rigid_registration routine.

    Parameters
    ----------
    rig_reg: numpy array (N, x, y)
        registered image stack
    drift: list (N, 2)
        drift of each frame in pixels

    Returns
    -------
    numpy array
        cropped stack
    list of int
        crop bounds [xpmin, xpmax, ypmin, ypmax]
    """
    shifts = np.array(drift)

    # the most negative drift eats into the low edge, the most positive
    # drift into the high edge of each axis
    x_low = int(-np.floor(shifts[:, 0].min()))
    x_high = int(rig_reg.shape[1] - np.ceil(shifts[:, 0].max()))
    y_low = int(-np.floor(shifts[:, 1].min()))
    y_high = int(rig_reg.shape[2] - np.ceil(shifts[:, 1].max()))

    return rig_reg[:, x_low:x_high, y_low:y_high], [x_low, x_high, y_low, y_high]
707
+
708
+
709
class ImageWithLineProfile:
    """Image with line profile

    Interactive matplotlib figure: each click into the image defines a new
    end point of a line (the start point is the previous end point) and the
    intensity profile along that line is shown in a second axis.
    """

    def __init__(self, data, extent, title=''):
        # data: 2d image array; extent: matplotlib imshow extent; title: plot title
        fig, ax = plt.subplots(1, 1)
        self.figure = fig
        self.title = title
        self.line_plot = False  # becomes the profile Line2D after first click
        self.ax = ax
        self.data = data
        self.extent = extent
        self.ax.imshow(data, extent=extent)
        self.ax.set_title(title)
        self.line, = self.ax.plot([0], [0], color='orange')  # empty line
        self.end_x = self.line.get_xdata()
        self.end_y = self.line.get_ydata()
        # this object is its own click handler (see __call__)
        self.cid = self.line.figure.canvas.mpl_connect('button_press_event', self)

    def __call__(self, event):
        # mouse-click handler: extend the line from the previous end point
        if event.inaxes != self.line.axes:
            return
        self.start_x = self.end_x
        self.start_y = self.end_y

        self.line.set_data([self.start_x, event.xdata], [self.start_y, event.ydata])
        self.line.figure.canvas.draw()

        self.end_x = event.xdata
        self.end_y = event.ydata

        self.update()

    def update(self):
        # (re)build the two-axis layout on first use, then refresh the profile
        if not self.line_plot:
            self.line_plot = True
            self.figure.clear()
            self.ax = self.figure.subplots(2, 1)
            self.ax[0].imshow(self.data, extent=self.extent)
            self.ax[0].set_title(self.title)

            self.line, = self.ax[0].plot([0], [0], color='orange')  # empty line
            self.line_plot, = self.ax[1].plot([], [], color='orange')
            self.ax[1].set_xlabel('distance [nm]')

        x0 = self.start_x
        x1 = self.end_x
        y0 = self.start_y
        y1 = self.end_y
        length_plot = np.sqrt((x1-x0)**2+(y1-y0)**2)

        # number of sample points along the line (pixels per extent unit)
        # NOTE(review): num is a float here; np.linspace requires an integer
        # sample count on current numpy - verify against the numpy version used.
        num = length_plot*(self.data.shape[0]/self.extent[1])
        x = np.linspace(x0, x1, num)*(self.data.shape[0]/self.extent[1])
        y = np.linspace(y0, y1, num)*(self.data.shape[0]/self.extent[1])

        # Extract the values along the line, using cubic interpolation
        zi2 = ndimage.map_coordinates(self.data.T, np.vstack((x, y)))

        x_axis = np.linspace(0, length_plot, len(zi2))
        self.x = x_axis
        self.z = zi2

        self.line_plot.set_xdata(x_axis)
        self.line_plot.set_ydata(zi2)
        self.ax[1].set_xlim(0, x_axis.max())
        self.ax[1].set_ylim(zi2.min(), zi2.max())
        # NOTE(review): Axes.draw needs a renderer argument; this was probably
        # meant to be self.figure.canvas.draw() - verify.
        self.ax[1].draw()
776
+
777
class LineSelector(matplotlib.widgets.PolygonSelector):
    """Polygon selector restricted to a line with a selectable width.

    The four polygon vertices are maintained as a thin rectangle of width
    ``line_width`` around the line from vertex 0 to vertex 3; whenever the
    user drags a vertex the rectangle is re-derived so the selection always
    stays a line of the requested width.
    """

    def __init__(self, ax, onselect, line_width=1, **kwargs):
        # ax: axes to select in; onselect: PolygonSelector callback;
        # line_width: width of the selection band in data units
        super().__init__(ax, onselect, **kwargs)
        bounds = ax.viewLim.get_points()
        np.max(bounds[0])  # NOTE(review): result discarded - dead statement?
        # initial rectangle placed relative to the current view limits
        self.line_verts = np.array([[np.max(bounds[1])/2, np.max(bounds[0])/5], [np.max(bounds[1])/2,
                                                                                 np.max(bounds[0])/5+1],
                                    [np.max(bounds[1])/5, np.max(bounds[0])/2], [np.max(bounds[1])/5,
                                                                                 np.max(bounds[0])/2]])
        self.verts = self.line_verts
        self.line_width = line_width

    def set_linewidth(self, line_width=None):
        # recompute vertices 1 and 2 so they sit line_width away from
        # vertices 0 and 3, perpendicular to the 0-3 direction
        if line_width is not None:
            self.line_width = line_width

        m = -(self.line_verts[0, 1]-self.line_verts[3, 1])/(self.line_verts[0, 0]-self.line_verts[3, 0])
        c = 1/np.sqrt(1+m**2)  # cosine of the line angle
        s = c*m                # sine of the line angle
        self.line_verts[1] = [self.line_verts[0, 0]+self.line_width*s, self.line_verts[0, 1]+self.line_width*c]
        self.line_verts[2] = [self.line_verts[3, 0]+self.line_width*s, self.line_verts[3, 1]+self.line_width*c]

        self.verts = self.line_verts.copy()

    def onmove(self, event):
        # after the base class moved a vertex, snap the polygon back to a
        # rectangle: identify the dragged vertex, map it onto an end point
        # (0 or 3) and re-derive the width vertices
        super().onmove(event)
        if np.max(np.linalg.norm(self.line_verts-self.verts, axis=1)) > 1:
            self.moved_point = np.argmax(np.linalg.norm(self.line_verts-self.verts, axis=1))

            self.new_point = self.verts[self.moved_point]
            moved_point = int(np.floor(self.moved_point/2)*3)  # 0/1 -> 0, 2/3 -> 3
            self.moved_point = moved_point
            self.line_verts[moved_point] = self.new_point
            self.set_linewidth()
812
def get_profile(dataset, line, spline_order=-1):
    """
    Extract a line profile from a dataset along a LineSelector selection.

    The line profile is a representation of the data values along a specified
    line in the dataset.  Works for both image and spectral-image data types.

    Args:
        dataset (sidpy.Dataset): The input dataset from which to extract the line profile.
        line (LineSelector): Selector specifying the line (uses its ``line_verts``
            and ``line_width``).
        spline_order (int, optional): The order of the spline interpolation to use.
            Default is -1, which means no interpolation (nearest-pixel summation).

    Returns:
        profile_dataset (sidpy.Dataset): A new sidpy.Dataset containing the line
            profile ('spectrum' for images, 'spectral_image' for spectral images).
            Returns None for other data types.
    """
    xv, yv = get_line_selection_points(line)

    if dataset.data_type.name == 'IMAGE':
        dataset.get_image_dims()
        # convert data coordinates to pixel coordinates
        xv /= (dataset.x[1] - dataset.x[0])
        yv /= (dataset.y[1] - dataset.y[0])
        profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv, yv])

        profile_dataset = sidpy.Dataset.from_array(profile.sum(axis=0))
        profile_dataset.data_type = 'spectrum'
        profile_dataset.units = dataset.units
        profile_dataset.quantity = dataset.quantity
        profile_dataset.set_dimension(0, sidpy.Dimension(np.linspace(xv[0, 0], xv[-1, -1], profile_dataset.shape[0]),
                                                         name='x', units=dataset.x.units,
                                                         quantity=dataset.x.quantity,
                                                         dimension_type='spatial'))
        # BUG FIX: the original ended this branch with a bare `profile_dataset`
        # expression, so the function silently returned None for images.
        return profile_dataset

    if dataset.data_type.name == 'SPECTRAL_IMAGE':
        spectral_axis = dataset.get_spectral_dims(return_axis=True)[0]
        if spline_order > -1:
            # interpolated extraction over the full (x, y, energy) cube
            xv, yv, zv = get_line_selection_points_interpolated(line, z_length=dataset.shape[2])
            profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv, yv, zv], order=spline_order)
            profile = profile.sum(axis=0)
            # duplicate to keep a (length, 2, spectrum) shape for plotting
            profile = np.stack([profile, profile], axis=1)
            start = xv[0, 0, 0]
        else:
            # nearest-pixel summation (second y-channel stays empty)
            profile = get_line_profile(np.array(dataset), xv, yv, len(spectral_axis))
            start = xv[0, 0]

        profile_dataset = sidpy.Dataset.from_array(profile)
        profile_dataset.data_type = 'spectral_image'
        profile_dataset.units = dataset.units
        profile_dataset.quantity = dataset.quantity
        profile_dataset.set_dimension(0, sidpy.Dimension(np.arange(profile_dataset.shape[0]) + start,
                                                         name='x', units=dataset.x.units,
                                                         quantity=dataset.x.quantity,
                                                         dimension_type='spatial'))
        profile_dataset.set_dimension(1, sidpy.Dimension([0, 1],
                                                         name='y', units=dataset.x.units,
                                                         quantity=dataset.x.quantity,
                                                         dimension_type='spatial'))
        profile_dataset.set_dimension(2, spectral_axis)
        return profile_dataset
869
+
870
+
871
+
872
def get_line_selection_points_interpolated(line, z_length=1):
    """Build coordinate grids that cover the band selected by a LineSelector.

    Parameters
    ----------
    line: LineSelector-like object with ``line_verts`` (4x2 array) and ``line_width``
    z_length: int
        number of planes along a third (e.g. spectral) axis; 1 means 2D only

    Returns
    -------
    (xv, yv) for z_length == 1, otherwise (xv, yv, zv): coordinate arrays
    suitable for scipy.ndimage.map_coordinates
    """
    verts = line.line_verts
    anchor, far_end, lower = verts[3], verts[0], verts[2]
    # make sure we walk the line in the direction of increasing x
    if anchor[0] > far_end[0]:
        anchor, far_end, lower = verts[0], verts[3], verts[1]

    slope = (far_end[1] - anchor[1]) / (far_end[0] - anchor[0])
    span_x = int(abs(anchor[0] - far_end[0]))
    n_samples = int(np.linalg.norm(anchor - far_end))
    band_width = int(abs(anchor[1] - lower[1]))

    along = np.linspace(0, span_x, n_samples)
    across = np.linspace(0, band_width, line.line_width)

    if z_length > 1:
        xv, yv, zv = np.meshgrid(along, across, np.arange(z_length))
        # replicate the 1-D rulers along z so the shear below broadcasts
        along = np.atleast_2d(along).repeat(z_length, axis=0).T
        across = np.atleast_2d(across).repeat(z_length, axis=0).T
    else:
        xv, yv = np.meshgrid(along, across)
        zv = None

    # shear the grid so it follows the (possibly tilted) selection line
    yv = yv + along * slope + anchor[1]
    xv = (xv.swapaxes(0, 1) - across * slope).swapaxes(0, 1) + anchor[0]

    if z_length > 1:
        return xv, yv, zv
    return xv, yv
905
+
906
+
907
def get_line_selection_points(line):
    """Return 2D pixel-coordinate grids covering the band selected by a LineSelector.

    Parameters
    ----------
    line: LineSelector-like object with ``line_verts`` (4x2 array) and ``line_width``

    Returns
    -------
    xx, yy: 2D numpy arrays of x and y coordinates (line_width rows,
        one column per sample along the line)
    """
    verts = line.line_verts
    anchor, far_end, lower = verts[3], verts[0], verts[2]
    # walk the line in the direction of increasing x
    if anchor[0] > far_end[0]:
        anchor, far_end, lower = verts[0], verts[3], verts[1]

    slope = (far_end[1] - anchor[1]) / (far_end[0] - anchor[0])
    span_x = int(abs(anchor[0] - far_end[0]))
    n_samples = int(np.linalg.norm(anchor - far_end))
    band_width = int(abs(anchor[1] - lower[1]))

    along = np.linspace(0, span_x, n_samples)
    across = np.linspace(0, band_width, line.line_width)
    xv, yv = np.meshgrid(along, across)

    # shear the rectangular grid onto the tilted selection band
    yy = yv + along * slope + anchor[1]
    xx = (xv.T - across * slope).T + anchor[0]

    return xx, yy
930
+
931
+
932
def get_line_profile(data, xv, yv, z_length):
    """Sum spectra of ``data`` over the pixels addressed by the coordinate grids.

    Parameters
    ----------
    data: numpy array (rows, cols, z_length) spectral-image cube
    xv, yv: 2D coordinate grids (width-rows x length-columns) as produced by
        get_line_selection_points
    z_length: int, length of the spectral axis

    Returns
    -------
    profile: numpy array of shape (length, 2, z_length); spectra summed across
        the band width are stored in channel 0, channel 1 stays zero.
    """
    profile = np.zeros([xv.shape[1], 2, z_length])
    for index_x in range(xv.shape[1]):
        for index_y in range(xv.shape[0]):
            x = int(xv[index_y, index_x])
            y = int(yv[index_y, index_x])
            # BUG FIX: was `x > 0 and y > 0`, which silently dropped the first
            # row and column of the image; index 0 is a valid pixel.
            if 0 <= x < data.shape[0] and 0 <= y < data.shape[1]:
                profile[index_x, 0] += data[x, y]
    return profile
941
+
942
+
943
def histogram_plot(image_tags):
    """Interactive histogram of an image's intensities with a colorbar.

    Dragging a horizontal span on the histogram updates the contrast limits
    stored in ``image_tags['minimum_intensity'/'maximum_intensity']``; a
    right-click on the colorbar cycles through the available color maps, any
    other click on it resets the limits to the full data range.

    Parameters
    ----------
    image_tags: dict
        must contain 'plotimage' (2D numpy array); 'minimum_intensity',
        'maximum_intensity' and 'color_map' are added/updated in place.

    Returns
    -------
    span: matplotlib SpanSelector — keep a reference alive or the
        interactivity is garbage-collected.
    """
    nbins = 75  # fixed number of histogram bins
    color_map_list = ['gray', 'viridis', 'jet', 'hot']  # cycled by right-click
    # seed contrast limits from the data if not already stored
    if 'minimum_intensity' not in image_tags:
        image_tags['minimum_intensity'] = image_tags['plotimage'].min()
    minimum_intensity = image_tags['minimum_intensity']  # NOTE(review): unused local
    if 'maximum_intensity' not in image_tags:
        image_tags['maximum_intensity'] = image_tags['plotimage'].max()
    data = image_tags['plotimage']
    vmin = image_tags['minimum_intensity']
    vmax = image_tags['maximum_intensity']
    if 'color_map' not in image_tags:
        image_tags['color_map'] = color_map_list[0]

    # NOTE(review): plt.cm.get_cmap is deprecated in matplotlib >= 3.7 — confirm target version
    cmap = plt.cm.get_cmap(image_tags['color_map'])
    colors = cmap(np.linspace(0., 1., nbins))  # one color per histogram bar
    norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    hist, bin_edges = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)

    width = bin_edges[1]-bin_edges[0]

    def onselect(vmin, vmax):
        """Re-bin the histogram and redraw bars + colorbar for new limits;
        stores the new limits back into image_tags."""
        ax1.clear()
        cmap = plt.cm.get_cmap(image_tags['color_map'])
        colors = cmap(np.linspace(0., 1., nbins))
        norm2 = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        hist2, bin_edges2 = np.histogram(data, np.linspace(vmin, vmax, nbins), density=True)

        width2 = (bin_edges2[1]-bin_edges2[0])

        # move/recolor the existing bar artists instead of re-plotting
        for i in range(nbins-1):
            histogram[i].xy = (bin_edges2[i], 0)
            histogram[i].set_height(hist2[i])
            histogram[i].set_width(width2)
            histogram[i].set_facecolor(colors[i])
        ax.set_xlim(vmin, vmax)
        ax.set_ylim(0, hist2.max()*1.01)

        cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')

        image_tags['minimum_intensity'] = vmin
        image_tags['maximum_intensity'] = vmax

    def onclick(event):
        """Mouse handler: right-click on the colorbar cycles the color map,
        any other click on it resets the limits to the full data range."""
        global event2
        event2 = event
        print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, event.xdata, event.ydata))
        if event.inaxes == ax1:
            if event.button == 3:  # right mouse button: next color map
                ind = color_map_list.index(image_tags['color_map'])+1
                if ind == len(color_map_list):
                    ind = 0
                image_tags['color_map'] = color_map_list[ind]  # 'viridis'
                vmin = image_tags['minimum_intensity']
                vmax = image_tags['maximum_intensity']
            else:  # any other button: reset contrast to full range
                vmax = data.max()
                vmin = data.min()
            onselect(vmin, vmax)

    fig2 = plt.figure()

    # main histogram axis on top, thin colorbar axis below it
    ax = fig2.add_axes([0., 0.2, 0.9, 0.7])
    ax1 = fig2.add_axes([0., 0.15, 0.9, 0.05])

    histogram = ax.bar(bin_edges[0:-1], hist, width=width, color=colors, edgecolor='black', alpha=0.8)
    onselect(vmin, vmax)
    cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm2, orientation='horizontal')

    rectprops = dict(facecolor='blue', alpha=0.5)

    # NOTE(review): SpanSelector's `rectprops` keyword was renamed `props` in
    # matplotlib 3.5 and removed later — confirm the pinned matplotlib version
    span = mwidgets.SpanSelector(ax, onselect, 'horizontal', rectprops=rectprops)

    cid = fig2.canvas.mpl_connect('button_press_event', onclick)
    return span
1021
+
1022
+
1023
def clean_svd(im, pixel_size=1, source_size=5):
    """De-noise an image using the first component of a patch-wise singular
    value decomposition.

    Parameters
    ----------
    im: 2D numpy array, the image to de-noise
    pixel_size: float, pixel size (same units as source_size)
    source_size: float, size of the source/feature; sets the patch size

    Returns
    -------
    de-noised image (numpy array), rescaled to the total intensity of ``im``
    """
    # patch edge in pixels, never smaller than 3
    patch_size = max(int(source_size / pixel_size), 3)

    patches = image.extract_patches_2d(im, (patch_size, patch_size))
    flat_patches = patches.reshape(patches.shape[0], patches.shape[1] * patches.shape[2])

    num_components = 32

    u, s, v = randomized_svd(flat_patches, num_components)
    side = int(np.sqrt(u.shape[0]))
    # first left-singular vector, one coefficient per patch position
    denoised = u[:, 0].reshape(side, side)
    # preserve the total intensity of the input image
    return denoised / denoised.sum() * im.sum()
1038
+
1039
+
1040
def rebin(im, binning=2):
    """
    Rebin a 2D image by averaging blocks of ``binning`` x ``binning`` pixels.

    Parameter
    ---------
    im: numpy array in 2 dimensions
    binning: int, number of pixels summed per axis (default 2)

    Returns
    -------
    binned image as numpy array

    Raises
    ------
    TypeError if ``im`` is not two-dimensional
    """
    if len(im.shape) != 2:
        raise TypeError('not a 2D image')
    rows = im.shape[0] // binning
    cols = im.shape[1] // binning
    # group pixels into (rows, binning, cols, binning) blocks, then average
    blocks = im.reshape(rows, binning, cols, binning)
    return blocks.mean(axis=3).mean(axis=1)
1056
+
1057
+
1058
def cart2pol(points):
    """Cartesian to polar coordinate conversion

    Parameters
    ---------
    points: float or numpy array
        points to be converted (Nx2); only the first two columns are used

    Returns
    -------
    rho: float or numpy array
        distance
    phi: float or numpy array
        angle
    """
    xy = points[:, 0:2]
    rho = np.linalg.norm(xy, axis=1)
    phi = np.arctan2(xy[:, 1], xy[:, 0])
    return rho, phi
1078
+
1079
+
1080
def pol2cart(rho, phi):
    """Polar to Cartesian coordinate conversion

    Parameters
    ----------
    rho: float or numpy array
        distance
    phi: float or numpy array
        angle

    Returns
    -------
    x, y: float or numpy array
        Cartesian coordinates of the converted points
    """
    return rho * np.cos(phi), rho * np.sin(phi)
1099
+
1100
+
1101
def xy2polar(points, rounding=1e-3):
    """ Conversion from Cartesian to polar coordinates

    the angles and distances are sorted by r and then phi
    The indices of this sort are also returned

    Parameters
    ----------
    points: numpy array
        number of points in axis 0; first two elements in axis 1 are x and y
    rounding: float
        optional rounding in significant digits

    Returns
    -------
    r, phi, sorted_indices
    """
    # polar conversion (same computation as cart2pol, inlined)
    r = np.linalg.norm(points[:, 0:2], axis=1)
    phi = np.arctan2(points[:, 1], points[:, 0])

    # quantize radii so near-identical rings compare equal in the sort
    r = np.floor(r / rounding) * rounding

    sorted_indices = np.lexsort((phi, r))  # primary key r, secondary key phi
    return r[sorted_indices], phi[sorted_indices], sorted_indices
1129
+
1130
+
1131
def cartesian2polar(x, y, grid, r, t, order=3):
    """Resample a Cartesian grid onto a polar (r, t) grid.

    Used by warp.

    Parameters
    ----------
    x, y: 1D arrays of the Cartesian axis values of ``grid``
    grid: 2D array to resample
    r, t: 1D arrays of target radii and angles
    order: spline interpolation order for map_coordinates

    Returns
    -------
    2D array of shape (len(t), len(r)) with the resampled values
    """
    rad, ang = np.meshgrid(r, t)

    # target sample positions back in Cartesian space
    cart_x = rad * np.cos(ang)
    cart_y = rad * np.sin(ang)

    # map axis values to fractional array indices
    to_index_x = interp1d(x, np.arange(len(x)))
    to_index_y = interp1d(y, np.arange(len(y)))

    coords = np.array([to_index_x(cart_x.ravel()), to_index_y(cart_y.ravel())])
    return ndimage.map_coordinates(grid, coords, order=order).reshape(cart_x.shape)
1149
+
1150
+
1151
def warp(diff, center):
    """Warp a diffraction pattern (sidpy dataset) onto a polar grid.

    Parameters
    ----------
    diff: sidpy dataset (2D) with a ``u`` dimension
    center: sequence of two floats, pixel position of the pattern center

    Returns
    -------
    numpy array of shape (nr, 360*3): intensities over radius and angle
    """
    # original Cartesian grid size
    nx = np.shape(diff)[0]
    ny = np.shape(diff)[1]

    # pixel size from the u-axis; kept although unused below (preserves the
    # attribute access of the original implementation)
    pix2nm = np.gradient(diff.u.values)[0]

    # axes centered on the given center pixel
    x = np.linspace(1, nx, nx, endpoint=True) - center[0]
    y = np.linspace(1, ny, ny, endpoint=True) - center[1]
    z = diff

    # largest radius fully contained in the pattern
    nr = int(min([center[0], center[1], diff.shape[0] - center[0], diff.shape[1] - center[1]]) - 1)
    nt = 360 * 3

    r = np.linspace(1, nr, nr)
    # NOTE(review): only covers angles 0..pi (half circle) — presumably relies
    # on inversion symmetry of the pattern; confirm with callers
    t = np.linspace(0., np.pi, nt, endpoint=False)

    return cartesian2polar(x, y, z, r, t, order=3).T
1173
+
1174
+
1175
def calculate_ctf(wavelength, cs, defocus, k):
    """ Calculate Contrast Transfer Function

    everything in nm

    Parameters
    ----------
    wavelength: float
        deBroglie wavelength of electrons
    cs: float
        spherical aberration coefficient
    defocus: float
        defocus
    k: numpy array
        reciprocal scale

    Returns
    -------
    ctf: numpy array
        contrast transfer function
    """
    # aberration phase: defocus term plus spherical-aberration term
    chi = np.pi * defocus * wavelength * k ** 2 + 0.5 * np.pi * cs * wavelength ** 3 * k ** 4
    return np.sin(chi)
1199
+
1200
+
1201
def calculate_scherzer(wavelength, cs):
    """
    Calculate the Scherzer defocus.

    # Input and output in nm
    """
    # Scherzer condition: defocus = -1.155 * sqrt(Cs * lambda)
    scherzer_defocus = -1.155 * (cs * wavelength) ** 0.5
    return scherzer_defocus
1210
+
1211
+
1212
def get_rotation(experiment_spots, crystal_spots):
    """Get rotation by comparing spots in diffractogram to diffraction Bragg spots

    Parameter
    ---------
    experiment_spots: numpy array (nx2)
        positions (in 1/nm) of spots in diffractogram
    crystal_spots: numpy array (nx2)
        positions (in 1/nm) of Bragg spots according to kinematic scattering theory

    Returns
    -------
    rotation_matrix: 2x2 numpy array
    rotation_angle: float (radians)
    """
    r_experiment, phi_experiment = cart2pol(experiment_spots)

    # crystal spots in polar coordinates, sorted by radius and then angle
    r_crystal, phi_crystal, crystal_indices = xy2polar(crystal_spots)

    # experimental spot whose radius is closest to the first crystal ring
    angle_index = np.argmin(np.abs(r_experiment - r_crystal[1]))
    rotation_angle = phi_experiment[angle_index] % (2 * np.pi) - phi_crystal[1]
    # (removed leftover debug print of phi_experiment[angle_index])

    st = np.sin(rotation_angle)
    ct = np.cos(rotation_angle)
    rotation_matrix = np.array([[ct, -st], [st, ct]])

    return rotation_matrix, rotation_angle
1236
+
1237
+
1238
def calibrate_image_scale(fft_tags, spots_reference, spots_experiment):
    """depreciated get change of scale from comparison of spots to Bragg angles """
    # reads kept for parity with the original (values themselves unused below)
    gx = fft_tags['spatial_scale_x']
    gy = fft_tags['spatial_scale_y']

    dist_reference = np.linalg.norm(spots_reference, axis=1)
    distance_experiment = np.linalg.norm(spots_experiment, axis=1)

    # experimental spots on (or near) the innermost reference ring
    first_reflections = abs(distance_experiment - dist_reference.min()) < .2
    print('Evaluate ', first_reflections.sum(), 'reflections')
    closest_exp_reflections = spots_experiment[first_reflections]

    def func(params, xdata, ydata):
        # residual: scaled radius minus the innermost reference radius
        dgx, dgy = params
        return np.sqrt((xdata * dgx) ** 2 + (ydata * dgy) ** 2) - dist_reference.min()

    initial_guess = [1.001, 0.999]
    dg, sig = optimization.leastsq(func, initial_guess,
                                   args=(closest_exp_reflections[:, 0], closest_exp_reflections[:, 1]))
    return dg
1257
+
1258
+
1259
def align_crystal_reflections(spots, crystals):
    """ Depreciated - use diffraction spots

    Rotate each crystal's allowed reflections so they align with the
    experimental spots.

    Parameters
    ----------
    spots: numpy array (Nx2+), experimental spot positions
    crystals: list of dicts; each needs ['allowed']['g'] (Bragg spot positions).
        Each dict gets a new key ['allowed']['g_rotated'] written in place.

    Returns
    -------
    crystal_reflections_polar: list of [r, rotated phi, sort indices] per crystal
    angles: list of rotation angles (radians), one per crystal
    """

    crystal_reflections_polar = []
    angles = []
    exp_r, exp_phi = cart2pol(spots)  # just in polar coordinates
    spots_polar = np.array([exp_r, exp_phi])  # NOTE(review): unused local

    for i in range(len(crystals)):
        tags = crystals[i]
        r, phi, indices = xy2polar(tags['allowed']['g'])  # sorted by r and phi , only positive angles
        # we mask the experimental values that are found already
        angle = 0.  # NOTE(review): immediately overwritten below

        # experimental spot on the first crystal ring determines the rotation
        angle_i = np.argmin(np.abs(exp_r - r[1]))
        angle = exp_phi[angle_i] - phi[0]
        angles.append(angle)  # Determine rotation angle

        crystal_reflections_polar.append([r, angle + phi, indices])
        tags['allowed']['g_rotated'] = pol2cart(r, angle + phi)
        for spot in tags['allowed']['g']:
            dif = np.linalg.norm(spots[:, 0:2]-spot[0:2], axis=1)
            # print(dif.min())
            if dif.min() < 1.5:
                # NOTE(review): `ind` is computed but never used — this loop has
                # no effect; presumably a leftover from a matching step
                ind = np.argmin(dif)

    return crystal_reflections_polar, angles
1286
+
1287
+
1288
+ # Deconvolution
1289
def decon_lr(o_image, probe, verbose=False):
    """Lucy-Richardson deconvolution of an image with a known probe (PSF).

    Parameters
    ----------
    o_image: sidpy dataset (image) with metadata['experiment'] containing
        'convergence_angle' (mrad) and 'acceleration_voltage' (V)
    probe: numpy array, same shape as o_image; the point spread function
    verbose: bool, print per-iteration convergence information

    Returns
    -------
    out_dataset: sidpy dataset with the restored image

    # This task generates a restored image from an input image and point spread function (PSF) using
    # the algorithm developed independently by Lucy (1974, Astron. J. 79, 745) and Richardson
    # (1972, J. Opt. Soc. Am. 62, 55) and adapted for HST imagery by Snyder
    # (1990, in Restoration of HST Images and Spectra, ST ScI Workshop Proceedings; see also
    # Snyder, Hammoud, & White, JOSA, v. 10, no. 5, May 1993, in press).
    # Additional options developed by Rick White (STScI) are also included.
    #
    # The Lucy-Richardson method can be derived from the maximum likelihood expression for data
    # with a Poisson noise distribution. Thus, it naturally applies to optical imaging data such as HST.
    # The method forces the restored image to be positive, in accord with photon-counting statistics.
    #
    # The Lucy-Richardson algorithm generates a restored image through an iterative method. The essence
    # of the iteration is as follows: the (n+1)th estimate of the restored image is given by the nth estimate
    # of the restored image multiplied by a correction image. That is,
    #
    #                            original data
    #       image    = image  ---------------  * reflect(PSF)
    #            n+1        n     image * PSF
    #                                  n

    # where the *'s represent convolution operators and reflect(PSF) is the reflection of the PSF, i.e.
    # reflect((PSF)(x,y)) = PSF(-x,-y). When the convolutions are carried out using fast Fourier transforms
    # (FFTs), one can use the fact that FFT(reflect(PSF)) = conj(FFT(PSF)), where conj is the complex conjugate
    # operator.
    """

    if len(o_image) < 1:
        return o_image

    if o_image.shape != probe.shape:
        print('Weirdness ', o_image.shape, ' != ', probe.shape)

    # complex copies of the probe and image (imaginary parts stay zero)
    probe_c = np.ones(probe.shape, dtype=np.complex64)
    probe_c.real = probe

    error = np.ones(o_image.shape, dtype=np.complex64)
    est = np.ones(o_image.shape, dtype=np.complex64)  # current estimate, starts flat
    source = np.ones(o_image.shape, dtype=np.complex64)
    source.real = o_image

    response_ft = fftpack.fft2(probe_c)  # FFT of the PSF, reused every iteration

    ap_angle = o_image.metadata['experiment']['convergence_angle'] / 1000.0  # now in rad

    e0 = float(o_image.metadata['experiment']['acceleration_voltage'])

    wl = get_wavelength(e0)
    o_image.metadata['experiment']['wavelength'] = wl

    # information limit of the optics: aperture diameter in reciprocal space
    over_d = 2 * ap_angle / wl

    dx = o_image.x[1]-o_image.x[0]
    dk = 1.0 / float(o_image.x[-1])  # last value of x-axis is field of view
    screen_width = 1 / dx

    aperture = np.ones(o_image.shape, dtype=np.complex64)
    # Mask for the aperture before the Fourier transform
    n = o_image.shape[0]
    size_x = o_image.shape[0]
    size_y = o_image.shape[1]
    app_ratio = over_d / screen_width * n  # aperture radius in pixels

    theta_x = np.array(-size_x / 2. + np.arange(size_x))
    theta_y = np.array(-size_y / 2. + np.arange(size_y))
    t_xv, t_yv = np.meshgrid(theta_x, theta_y)

    # zero everything outside the aperture radius
    tp1 = t_xv ** 2 + t_yv ** 2 >= app_ratio ** 2
    aperture[tp1.T] = 0.
    # print(app_ratio, screen_width, dk)

    progress = tqdm(total=500)
    # de = 100
    dest = 100  # percent change of the estimate; drives convergence
    i = 0
    while abs(dest) > 0.0001:  # or abs(de) > .025:
        i += 1
        error_old = np.sum(error.real)
        est_old = est.copy()
        # LR update: ratio of data to blurred estimate, then correlate with PSF
        error = source / np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est) * response_ft)))
        est = est * np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(error) * np.conjugate(response_ft))))
        # est = est_old * est
        # est = np.real(fftpack.fftshift(fftpack.ifft2(fftpack.fft2(est)*fftpack.fftshift(aperture) )))

        error_new = np.real(np.sum(np.power(error, 2))) - error_old
        # relative change of the estimate in percent
        dest = np.sum(np.power((est - est_old).real, 2)) / np.sum(est) * 100
        # print(np.sum((est.real - est_old.real)* (est.real - est_old.real) )/np.sum(est.real)*100 )

        if error_old != 0:
            de = error_new / error_old * 1.0
        else:
            de = error_new

        if verbose:
            print(
                ' LR Deconvolution - Iteration: {0:d} Error: {1:.2f} = change: {2:.5f}%, {3:.5f}%'.format(i, error_new,
                                                                                                          de,
                                                                                                          abs(dest)))
        # hard stop after 500 iterations
        if i > 500:
            dest = 0.0
            print('terminate')
        progress.update(1)
    progress.write(f"converged in {i} iterations")
    # progress.close()
    print('\n Lucy-Richardson deconvolution converged in ' + str(i) + ' iterations')
    # final low-pass with the objective aperture to suppress noise beyond the
    # information limit
    est2 = np.real(fftpack.ifft2(fftpack.fft2(est) * fftpack.fftshift(aperture)))
    out_dataset = o_image.like_data(est2)
    out_dataset.title = 'Lucy Richardson deconvolution'
    out_dataset.data_type = 'image'
    return out_dataset