pyTEMlib 0.2024.2.1__py2.py3-none-any.whl → 0.2024.6.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyTEMlib might be problematic. See the registry's advisory page for this release for more details.

pyTEMlib/image_tools.py CHANGED
@@ -50,6 +50,11 @@ from sklearn.cluster import DBSCAN
50
50
 
51
51
  from collections import Counter
52
52
 
53
+ # center diff function
54
+ from skimage.filters import threshold_otsu, sobel
55
+ from scipy.optimize import leastsq
56
+ from sklearn.cluster import DBSCAN
57
+
53
58
 
54
59
  _SimpleITK_present = True
55
60
  try:
@@ -275,6 +280,67 @@ def diffractogram_spots(dset, spot_threshold, return_center=True, eps=0.1):
275
280
  return spots, center
276
281
 
277
282
 
283
def center_diffractogram(dset, return_plot=True, histogram_factor=None, smoothing=1, min_samples=100):
    """Find the center of a diffractogram by fitting a circle to its largest edge cluster.

    Pipeline: Otsu threshold of the (transposed, clipped) image -> Gaussian smoothing of
    the binary mask -> second Otsu threshold -> Sobel edges -> DBSCAN clustering of the
    edge pixels -> least-squares circle fit to the largest (non-noise) cluster.

    Args:
        dset (array-like, e.g. sidpy.Dataset): 2D diffractogram.
        return_plot (bool): if True, draw a 4-panel diagnostic figure. The figure is also
            drawn when the fit fails, showing whichever stages were computed before the
            error; the error is then re-raised unchanged.
        histogram_factor (float or array of 256 weights, optional): weight applied to a
            256-bin density histogram of intensities in the range [0, 1] before Otsu
            thresholding. If None, Otsu thresholds the image directly.
            NOTE(review): intensities above 1 fall outside that histogram range; this
            path presumably expects normalized data — confirm.
        smoothing (float): sigma of the Gaussian filter applied to the binary mask.
        min_samples (int): ``min_samples`` passed to DBSCAN (``eps`` is fixed at 10 px).

    Returns:
        numpy.ndarray: fitted center ``(x, y)``, i.e. (column, row) in the transposed
        image ``np.array(dset).T``.

    Raises:
        ValueError: if DBSCAN produces only a single label (one group, or only noise).
    """
    # Pre-bind every stage so the diagnostic plot never hits an unbound name
    # (which would otherwise mask the real exception raised in the try body).
    binary = smooth_binary = edges = center = mean_radius = None
    try:
        # float32, not float16: float16 overflows to inf above 65504.
        diff = np.array(dset).T.astype(np.float32)
        diff[diff < 0] = 0

        if histogram_factor is not None:
            hist, bin_edges = np.histogram(np.ravel(diff), bins=256, range=(0, 1), density=True)
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            # Pass the bin centers as well: given counts alone, threshold_otsu assumes
            # centers 0..255 and returns a bin index instead of an intensity.
            threshold = threshold_otsu(diff, hist=(hist * histogram_factor, bin_centers))
        else:
            threshold = threshold_otsu(diff)
        binary = (diff > threshold).astype(float)

        # Smooth before edge detection so the Sobel edges form continuous rings.
        smoothed_image = ndimage.gaussian_filter(binary, sigma=smoothing)
        smooth_binary = (smoothed_image > threshold_otsu(smoothed_image)).astype(float)
        edges = sobel(smooth_binary)
        edge_points = np.argwhere(edges)  # (row, col) of every non-zero edge pixel

        labels = DBSCAN(eps=10, min_samples=min_samples).fit(edge_points).labels_
        if len(set(labels)) == 1:
            raise ValueError("DBSCAN clustering resulted in only one group, check the parameters.")

        # Largest real cluster; DBSCAN labels noise as -1, which must never be chosen.
        # With at least two distinct labels, at least one of them is a real cluster.
        cluster_labels, cluster_sizes = np.unique(labels[labels != -1], return_counts=True)
        largest_group = cluster_labels[np.argmax(cluster_sizes)]
        ring = edge_points[labels == largest_group]
        ring_x = ring[:, 1].astype(float)
        ring_y = ring[:, 0].astype(float)

        def radial_residuals(c, x, y):
            """Deviation of each point's distance to c from the mean distance."""
            radii = np.hypot(x - c[0], y - c[1])
            return radii - radii.mean()

        center, _ = leastsq(radial_residuals, (ring_x.mean(), ring_y.mean()), args=(ring_x, ring_y))
        mean_radius = np.hypot(ring_x - center[0], ring_y - center[1]).mean()
    finally:
        if return_plot:
            _plot_center_diffractogram(dset, binary, smooth_binary, edges, center, mean_radius)

    return center


def _plot_center_diffractogram(dset, binary, smooth_binary, edges, center, radius):
    """Draw the four diagnostic panels of center_diffractogram; stages that are None stay blank."""
    fig, ax = plt.subplots(1, 4, figsize=(10, 4))
    panels = [('Diffractogram', np.array(dset).T, 'viridis'),
              ('Otsu Binary Image', binary, 'gray'),
              ('Smoothed Binary Image', smooth_binary, 'gray'),
              ('Edge Detection and Fitting', edges, 'gray')]
    for axis, (title, image, cmap) in zip(ax, panels):
        axis.set_title(title)
        if image is not None:
            axis.imshow(image, cmap=cmap)
        axis.axis('off')
    if center is not None:
        ax[3].scatter(center[0], center[1], c='r', s=10)
        ax[3].add_artist(plt.Circle(center, radius, color='red', fill=False))
    fig.tight_layout()
342
+
343
+
278
344
  def adaptive_fourier_filter(dset, spots, low_pass=3, reflection_radius=0.3):
279
345
  """
280
346
  Use spots in diffractogram for a Fourier Filter
@@ -375,16 +441,10 @@ def complete_registration(main_dataset, storage_channel=None):
375
441
 
376
442
  rigid_registered_dataset = rigid_registration(main_dataset)
377
443
 
378
- if storage_channel is not None:
379
- registration_channel = ft.log_results(storage_channel, rigid_registered_dataset)
380
-
444
+
381
445
  print('Non-Rigid_Registration')
382
446
 
383
447
  non_rigid_registered = demon_registration(rigid_registered_dataset)
384
- if storage_channel is not None:
385
- registration_channel = ft.log_results(storage_channel, non_rigid_registered)
386
-
387
- non_rigid_registered.h5_dataset = registration_channel
388
448
  return non_rigid_registered, rigid_registered_dataset
389
449
 
390
450
 
@@ -473,7 +533,7 @@ def demon_registration(dataset, verbose=False):
473
533
  ###############################
474
534
  # Rigid Registration New 05/09/2020
475
535
 
476
- def rigid_registration(dataset):
536
+ def rigid_registration(dataset, sub_pixel=True):
477
537
  """
478
538
  Rigid registration of image stack with pixel accuracy
479
539
 
@@ -529,9 +589,13 @@ def rigid_registration(dataset):
529
589
  selection[frame_dim[0]] = slice(i, i+1)
530
590
  moving = dataset[tuple(selection)].squeeze().compute()
531
591
  fft_moving = np.fft.fft2(moving)
532
- image_product = fft_fixed * fft_moving.conj()
533
- cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
534
- shift = np.array(ndimage.maximum_position(cc_image.real))-cc_image.shape[0]/2
592
+ if sub_pixel:
593
+ shift = skimage.registration.phase_cross_correlation(fft_fixed, fft_moving, upsample_factor=1000,
594
+ space='fourier')[0]
595
+ else:
596
+ image_product = fft_fixed * fft_moving.conj()
597
+ cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
598
+ shift = np.array(ndimage.maximum_position(cc_image.real))-cc_image.shape[0]/2
535
599
  fft_fixed = fft_moving
536
600
  relative_drift.append(shift)
537
601
  rig_reg, drift = rig_reg_drift(dataset, relative_drift)
@@ -542,9 +606,18 @@ def rigid_registration(dataset):
542
606
  rigid_registered.source = dataset.title
543
607
  rigid_registered.metadata = {'analysis': 'rigid sub-pixel registration', 'drift': drift,
544
608
  'input_crop': input_crop, 'input_shape': dataset.shape[1:]}
545
- rigid_registered.set_dimension(0, dataset._axes[frame_dim[0]])
546
- rigid_registered.set_dimension(1, dataset._axes[spatial_dim[0]][input_crop[0]:input_crop[1]])
547
- rigid_registered.set_dimension(2, dataset._axes[spatial_dim[1]][input_crop[2]:input_crop[3]])
609
+ rigid_registered.set_dimension(0, sidpy.Dimension(np.arange(rigid_registered.shape[0]),
610
+ name='frame', units='frame', quantity='time',
611
+ dimension_type='temporal'))
612
+
613
+ array_x = dataset._axes[spatial_dim[0]][input_crop[0]:input_crop[1]].values
614
+ rigid_registered.set_dimension(1, sidpy.Dimension(array_x,
615
+ 'x', units='nm', quantity='Length',
616
+ dimension_type='spatial'))
617
+ array_y = dataset._axes[spatial_dim[1]][input_crop[2]:input_crop[3]].values
618
+ rigid_registered.set_dimension(2, sidpy.Dimension(array_y,
619
+ 'y', units='nm', quantity='Length',
620
+ dimension_type='spatial'))
548
621
  return rigid_registered.rechunk({0: 'auto', 1: -1, 2: -1})
549
622
 
550
623
 
@@ -589,6 +662,7 @@ def rig_reg_drift(dset, rel_drift):
589
662
  rig_reg = np.zeros([dset.shape[frame_dim[0]], dset.shape[spatial_dim[0]], dset.shape[spatial_dim[1]]])
590
663
 
591
664
  # absolute drift
665
+ print(rel_drift)
592
666
  drift = np.array(rel_drift).copy()
593
667
 
594
668
  drift[0] = [0, 0]
@@ -731,15 +805,27 @@ class LineSelector(matplotlib.widgets.PolygonSelector):
731
805
  self.line_verts[moved_point] = self.new_point
732
806
  self.set_linewidth()
733
807
 
734
- def get_profile(dataset, line):
808
+ def get_profile(dataset, line, spline_order=-1):
809
+ """
810
+ This function extracts a line profile from a given dataset. The line profile is a representation of the data values
811
+ along a specified line in the dataset. This function works for both image and spectral image data types.
812
+
813
+ Args:
814
+ dataset (sidpy.Dataset): The input dataset from which to extract the line profile.
815
+ line (list): A list specifying the line along which the profile should be extracted.
816
+ spline_order (int, optional): The order of the spline interpolation to use. Default is -1, which means no interpolation.
817
+
818
+ Returns:
819
+ profile_dataset (sidpy.Dataset): A new sidpy.Dataset containing the line profile.
820
+
821
+
822
+ """
735
823
  xv, yv = get_line_selection_points(line)
736
-
737
-
738
824
  if dataset.data_type.name == 'IMAGE':
739
825
  dataset.get_image_dims()
740
826
  xv /= (dataset.x[1] - dataset.x[0])
741
827
  yv /= (dataset.y[1] - dataset.y[0])
742
- profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv,yv])
828
+ profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv, yv])
743
829
 
744
830
  profile_dataset = sidpy.Dataset.from_array(profile.sum(axis=0))
745
831
  profile_dataset.data_type='spectrum'
@@ -753,19 +839,21 @@ def get_profile(dataset, line):
753
839
 
754
840
  if dataset.data_type.name == 'SPECTRAL_IMAGE':
755
841
  spectral_axis = dataset.get_spectral_dims(return_axis=True)[0]
756
- profile = np.zeros([xv.shape[1], 2, len(spectral_axis)])
757
- data =np.array(dataset)
758
-
759
- for index_x in range(xv.shape[1]):
760
- for index_y in range(xv.shape[0]):
761
- x = xv[index_y, index_x]
762
- y = yv[index_y, index_x]
763
- profile[index_x, 0] +=data[int(x),int(y)]
842
+ if spline_order > -1:
843
+ xv, yv, zv = get_line_selection_points_interpolated(line, z_length=dataset.shape[2])
844
+ profile = scipy.ndimage.map_coordinates(np.array(dataset), [xv, yv, zv], order=spline_order)
845
+ profile = profile.sum(axis=0)
846
+ profile = np.stack([profile, profile], axis=1)
847
+ start = xv[0, 0, 0]
848
+ else:
849
+ profile = get_line_profile(np.array(dataset), xv, yv, len(spectral_axis))
850
+ start = xv[0, 0]
851
+ print(profile.shape)
764
852
  profile_dataset = sidpy.Dataset.from_array(profile)
765
853
  profile_dataset.data_type='spectral_image'
766
854
  profile_dataset.units = dataset.units
767
855
  profile_dataset.quantity = dataset.quantity
768
- profile_dataset.set_dimension(0, sidpy.Dimension(np.linspace(xv[0,0], xv[-1,-1], profile_dataset.shape[0]),
856
+ profile_dataset.set_dimension(0, sidpy.Dimension(np.arange(profile_dataset.shape[0])+start,
769
857
  name='x', units=dataset.x.units, quantity=dataset.x.quantity,
770
858
  dimension_type='spatial'))
771
859
  profile_dataset.set_dimension(1, sidpy.Dimension([0, 1],
@@ -776,6 +864,42 @@ def get_profile(dataset, line):
776
864
  return profile_dataset
777
865
 
778
866
 
867
+
868
def get_line_selection_points_interpolated(line, z_length=1):
    """Return sub-pixel sampling coordinates covering a rectangular line selection.

    The grid runs along the line from its left-most end and across it for the
    selection's width, sheared by the line slope.

    Args:
        line: selector object with ``line_verts`` (four corner points, indexable as
            arrays supporting subtraction) and ``line_width`` (number of samples across
            the line).
        z_length (int): number of spectral channels. If > 1, a third (channel) grid is
            built as well.

    Returns:
        (xv, yv) of shape (line.line_width, length_v) when z_length <= 1, otherwise
        (xv, yv, zv) of shape (line.line_width, length_v, z_length), where length_v is
        the integer length of the line in pixels.

    NOTE(review): a perfectly vertical selection makes the x-extent zero, so the slope
    ``m`` divides by zero (inf/nan with numpy scalars) — confirm callers never pass one.
    """
    start_point = line.line_verts[3]
    right_point = line.line_verts[0]
    low_point = line.line_verts[2]

    # Always walk the line starting from its left-most end.
    if start_point[0] > right_point[0]:
        start_point = line.line_verts[0]
        right_point = line.line_verts[3]
        low_point = line.line_verts[1]
    m = (right_point[1] - start_point[1]) / (right_point[0] - start_point[0])
    length_x = int(abs(start_point[0] - right_point[0]))
    length_v = int(np.linalg.norm(start_point - right_point))

    linewidth = int(abs(start_point[1] - low_point[1]))
    x = np.linspace(0, length_x, length_v)
    y = np.linspace(0, linewidth, line.line_width)
    if z_length > 1:
        xv, yv, zv = np.meshgrid(x, y, np.arange(z_length))
        # Give x and y a trailing channel axis so they broadcast against the 3D grids.
        x = np.atleast_2d(x).repeat(z_length, axis=0).T
        y = np.atleast_2d(y).repeat(z_length, axis=0).T
    else:
        xv, yv = np.meshgrid(x, y)

    # Shear the rectangular grid onto the line: y follows the slope along the line,
    # and samples across the line are shifted in x by -m * y.
    yv = yv + x * m + start_point[1]
    xv = (xv.swapaxes(0, 1) - y * m).swapaxes(0, 1) + start_point[0]

    if z_length > 1:
        return xv, yv, zv
    return xv, yv
901
+
902
+
779
903
  def get_line_selection_points(line):
780
904
 
781
905
  start_point = line.line_verts[3]
@@ -801,6 +925,16 @@ def get_line_selection_points(line):
801
925
  return xx, yy
802
926
 
803
927
 
928
def get_line_profile(data, xv, yv, z_length):
    """Sum the spectra sampled by a coordinate grid, one summed spectrum per grid column.

    Each sample (xv[i, j], yv[i, j]) is mapped to the pixel containing it
    (floor of the coordinate). Samples outside the first two axes of ``data`` are ignored.

    Args:
        data (numpy.ndarray): array indexed as ``data[x, y]``, yielding a spectrum of
            length ``z_length``.
        xv, yv (numpy.ndarray): 2D sampling-coordinate grids; columns are positions
            along the line and rows are samples across it.
        z_length (int): number of spectral channels.

    Returns:
        numpy.ndarray of shape (xv.shape[1], 2, z_length): channel 0 holds the summed
        spectrum for each position along the line; channel 1 stays zero.
    """
    n_x, n_y = data.shape[0], data.shape[1]
    profile = np.zeros([xv.shape[1], 2, z_length])
    for index_x in range(xv.shape[1]):
        for index_y in range(xv.shape[0]):
            # floor instead of int(): int() truncates -0.5 to 0, whereas floor rejects it.
            # The bounds test includes pixel 0, which the old `x > 0` check silently dropped.
            x = int(np.floor(xv[index_y, index_x]))
            y = int(np.floor(yv[index_y, index_x]))
            if 0 <= x < n_x and 0 <= y < n_y:
                profile[index_x, 0] += data[x, y]
    return profile
937
+
804
938
 
805
939
  def histogram_plot(image_tags):
806
940
  """interactive histogram"""
@@ -1010,9 +1144,8 @@ def cartesian2polar(x, y, grid, r, t, order=3):
1010
1144
  return ndimage.map_coordinates(grid, np.array([new_ix, new_iy]), order=order).reshape(new_x.shape)
1011
1145
 
1012
1146
 
1013
- def warp(diff):
1014
- """Takes a centered diffraction pattern (as a sidpy dataset)and warps it to a polar grid"""
1015
- """Centered diff can be produced with it.diffractogram_spots(return_center = True)"""
1147
+ def warp(diff, center):
1148
+ """Takes a diffraction pattern (as a sidpy dataset)and warps it to a polar grid"""
1016
1149
 
1017
1150
  # Define original polar grid
1018
1151
  nx = np.shape(diff)[0]
@@ -1020,20 +1153,19 @@ def warp(diff):
1020
1153
 
1021
1154
  # Define center pixel
1022
1155
  pix2nm = np.gradient(diff.u.values)[0]
1023
- center_pixel = [abs(min(diff.u.values)), abs(min(diff.v.values))]//pix2nm
1024
1156
 
1025
- x = np.linspace(1, nx, nx, endpoint=True)-center_pixel[0]
1026
- y = np.linspace(1, ny, ny, endpoint=True)-center_pixel[1]
1157
+ x = np.linspace(1, nx, nx, endpoint=True)-center[0]
1158
+ y = np.linspace(1, ny, ny, endpoint=True)-center[1]
1027
1159
  z = diff
1028
1160
 
1029
1161
  # Define new polar grid
1030
- nr = int(min([center_pixel[0], center_pixel[1], diff.shape[0]-center_pixel[0], diff.shape[1]-center_pixel[1]])-1)
1031
- nt = 360*3
1162
+ nr = int(min([center[0], center[1], diff.shape[0]-center[0], diff.shape[1]-center[1]])-1)
1163
+ nt = 360 * 3
1032
1164
 
1033
1165
  r = np.linspace(1, nr, nr)
1034
1166
  t = np.linspace(0., np.pi, nt, endpoint=False)
1035
1167
 
1036
- return cartesian2polar(x, y, z, r, t, order=3)
1168
+ return cartesian2polar(x, y, z, r, t, order=3).T
1037
1169
 
1038
1170
 
1039
1171
  def calculate_ctf(wavelength, cs, defocus, k):