pyTEMlib 0.2023.3.0__py2.py3-none-any.whl → 0.2023.8.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pyTEMlib might be problematic. Click here for more details.

pyTEMlib/graph_tools.py CHANGED
@@ -3,15 +3,19 @@
3
3
  """
4
4
  import numpy as np
5
5
  # import ase
6
+ import sys
6
7
 
7
8
  # from scipy.spatial import cKDTree, Voronoi, ConvexHull
8
9
  import scipy.spatial
9
10
  import scipy.optimize
11
+ import scipy.interpolate
10
12
 
11
- # from skimage.measure import grid_points_in_poly, points_in_poly
13
+ from skimage.measure import grid_points_in_poly, points_in_poly
12
14
 
13
15
  # import plotly.graph_objects as go
14
16
  # import plotly.express as px
17
+ import matplotlib.patches as patches
18
+
15
19
  import pyTEMlib.crystal_tools
16
20
  from tqdm.auto import tqdm, trange
17
21
 
@@ -136,10 +140,13 @@ def get_bond_radii(atoms, bond_type='bond'):
136
140
 
137
141
  r_a = []
138
142
  for atom in atoms:
139
- if bond_type == 'covalent':
140
- r_a.append(pyTEMlib.crystal_tools.electronFF[atom.symbol]['bond_length'][0])
143
+ if atom.symbol == 'X':
144
+ r_a.append(1.2)
141
145
  else:
142
- r_a.append(pyTEMlib.crystal_tools.electronFF[atom.symbol]['bond_length'][1])
146
+ if bond_type == 'covalent':
147
+ r_a.append(pyTEMlib.crystal_tools.electronFF[atom.symbol]['bond_length'][0])
148
+ else:
149
+ r_a.append(pyTEMlib.crystal_tools.electronFF[atom.symbol]['bond_length'][1])
143
150
  if atoms.info is None:
144
151
  atoms.info = {}
145
152
  atoms.info['bond_radii'] = r_a
@@ -179,7 +186,7 @@ def set_bond_radii(atoms, bond_type='bond'):
179
186
  return r_a
180
187
 
181
188
 
182
- def get_voronoi(tetrahedra, atoms, optimize=True):
189
+ def get_voronoi(tetrahedra, atoms, bond_radii=None, optimize=True):
183
190
  """
184
191
  Find Voronoi vertices and keep track of associated tetrahedrons and interstitial radii
185
192
 
@@ -207,8 +214,12 @@ def get_voronoi(tetrahedra, atoms, optimize=True):
207
214
  extent = atoms.cell.lengths()
208
215
  if atoms.info is None:
209
216
  atoms.info = {}
210
- if 'bond_radii' in atoms.info:
217
+
218
+ if bond_radii is not None:
219
+ bond_radii = [bond_radii]*len(atoms)
220
+ elif 'bond_radii' in atoms.info:
211
221
  bond_radii = atoms.info['bond_radii']
222
+
212
223
  else:
213
224
  bond_radii = get_bond_radii(atoms)
214
225
 
@@ -399,8 +410,271 @@ def make_polyhedrons(atoms, voronoi_vertices, voronoi_tetrahedrons, clusters, vi
399
410
  # polyhedra functions
400
411
  ##################################################################
401
412
 
413
def get_non_periodic_supercell(super_cell, margin=5.):
    """Build a non-periodic cluster of atoms around a periodic cell.

    The cell is wrapped (in place), replicated 3x3x3, and shifted back by the
    cell lengths so that the central replica coincides with the original cell.
    Every atom lying more than ``margin`` outside the box spanned by the
    diagonal of the cell matrix is removed.

    Parameters
    ----------
    super_cell: ase.Atoms
        periodic structure; note that ``super_cell.wrap()`` modifies it in place
    margin: float
        width of the shell of periodic images kept around the original cell,
        in the units of the positions (defaults to the previously hard-coded 5)

    Returns
    -------
    atoms: ase.Atoms
        atoms of the original cell plus the periodic images within ``margin``
    """
    super_cell.wrap()
    atoms = super_cell * 3
    # move the central copy of the 3x3x3 replica onto the original cell
    atoms.positions -= super_cell.cell.lengths()

    # upper bounds come from the diagonal elements of the cell matrix
    upper = np.array([super_cell.cell[0, 0], super_cell.cell[1, 1], super_cell.cell[2, 2]]) + margin
    positions = atoms.positions
    # an atom is dropped as soon as any coordinate is outside [-margin, cell_ii + margin]
    outside = np.any((positions < -margin) | (positions > upper), axis=1)
    del atoms[outside]
    return atoms
425
+
426
def get_connectivity_matrix(crystal, atoms, polyhedra):
    """Mark which atoms share a polyhedron edge with an atom of the crystal.

    The vertices of each polyhedron are shifted by ``-crystal.cell.lengths()``
    and matched to the nearest crystal position.  Vertices closer than 0.001
    count as crystal atoms, and every triangle edge touching at least one such
    vertex is recorded symmetrically in the returned matrix.

    Parameters
    ----------
    crystal: ase.Atoms
        reference structure (uses ``positions`` and ``cell.lengths()``)
    atoms: ase.Atoms or sequence
        only its length is used, to size the matrix
    polyhedra: dict
        values are dictionaries with 'vertices', 'indices' and 'triangles'
        items; each one additionally receives an 'atom_indices' item (index of
        the nearest crystal atom for every vertex)

    Returns
    -------
    numpy array (len(atoms), len(atoms)) of int
        1 where two atoms are connected, 0 otherwise
    """
    tree = scipy.spatial.cKDTree(crystal.positions)
    n_atoms = len(atoms)
    matrix = np.zeros([n_atoms, n_atoms], dtype=int)

    for poly in polyhedra.values():
        shifted = poly['vertices'] - crystal.cell.lengths()
        index_map = np.array(poly['indices'])
        gaps, poly['atom_indices'] = tree.query(shifted, k=1)
        in_crystal = np.where(gaps < 0.001)[0]

        for corner_ids in map(np.array, poly['triangles']):
            # the three edges of the triangle
            for pair in ([0, 1], [1, 2], [0, 2]):
                low = np.min(corner_ids[pair])
                high = np.max(corner_ids[pair])
                if not (low in in_crystal or high in in_crystal):
                    continue
                matrix[index_map[high], index_map[low]] = 1
                matrix[index_map[low], index_map[high]] = 1
    return matrix
446
+
447
+
402
448
 
403
- def find_polyhedra(atoms, optimize=True, cheat=1.0):
449
def get_bonds(crystal, shift= 0., verbose = False, cheat=1.0):
    """
    Get polyhedra and bonds (edges and edge lengths of each polyhedron) of a crystal
    and store them in the info dictionary of a new ase.Atoms object.

    The crystal is turned into a non-periodic cluster (get_non_periodic_supercell),
    polyhedra are determined with find_polyhedra, and only polyhedra whose vertex
    average lies inside the original cell are kept.

    Note: ``crystal`` is modified in place.  Its positions are shifted by
    ``shift * crystal.cell[0, 0]``, added to all three coordinates, and it is wrapped.

    Parameters
    ----------
    crystal: ase.Atoms
        periodic structure to analyse
    shift: float
        multiple of ``crystal.cell[0, 0]`` added to all atom positions before wrapping
    verbose: bool
        print the 'length' of every kept polyhedron that is neither a
        tetrahedron (4) nor an octahedron (6)
    cheat: float
        passed on to find_polyhedra

    Returns
    -------
    atoms: ase.Atoms
        non-periodic atoms whose info dictionary holds 'polyhedra'
        (polyhedra, tetrahedra, octahedra, other), 'bonds' (connectivity_matrix
        with bond lengths, super_cell_atoms, super_cell_dimensions, coordination)
        and 'supercell' (the input crystal)
    """
    # NOTE(review): the scalar is added to x, y and z alike; presumably only x was intended — confirm
    crystal.positions += shift * crystal.cell[0, 0]
    crystal.wrap()

    atoms = get_non_periodic_supercell(crystal)
    # order atoms by atomic number
    atoms = atoms[atoms.numbers.argsort()]


    # work in a frame shifted by the cell lengths; undone near the end of this function
    atoms.positions += crystal.cell.lengths()
    polyhedra = find_polyhedra(atoms, cheat=cheat)

    connectivity_matrix = get_connectivity_matrix(crystal, atoms, polyhedra)
    coord = connectivity_matrix.sum(axis=1)

    # drop atoms that share no polyhedron edge with an atom of the original crystal
    del(atoms[np.where(coord==0)])
    new_polyhedra = {}
    index = 0
    octahedra =[]
    tetrahedra = []
    other = []
    super_cell_atoms =[]

    # lookup trees in the unshifted frame: remaining atoms and original crystal atoms
    atoms_tree = scipy.spatial.cKDTree(atoms.positions-crystal.cell.lengths())
    crystal_tree = scipy.spatial.cKDTree(crystal.positions)
    # replaces the integer matrix above: bond lengths between remaining atoms (0 = no bond)
    connectivity_matrix = np.zeros([len(atoms),len(atoms)], dtype=float)

    for polyhedron in polyhedra.values():
        # back to the unshifted frame (modifies the polyhedron's vertex array in place)
        polyhedron['vertices'] -= crystal.cell.lengths()
        vertices = polyhedron['vertices']
        center = np.average(polyhedron['vertices'], axis=0)

        # re-index the vertices against the reduced atoms object
        dd, polyhedron['indices'] = atoms_tree.query(vertices , k=1)
        atom_ind = (np.array(polyhedron['indices']))
        dd, polyhedron['atom_indices'] = crystal_tree.query(vertices , k=1)

        # vertices that coincide with an atom of the original crystal
        to_bond = np.where(dd<0.001)[0]
        super_cell_atoms.extend(list(atom_ind[to_bond]))

        edges = []
        lengths = []
        for triangle in polyhedron['triangles']:
            triangle = np.array(triangle)
            # the three edges of a triangle, each stored as [smaller, larger] vertex index
            for permut in [[0,1], [1,2], [0,2]]:
                vertex = [np.min(triangle[permut]), np.max(triangle[permut])]
                length = np.linalg.norm(vertices[vertex[0]]-vertices[vertex[1]])
                if vertex[0] in to_bond or vertex[1] in to_bond:
                    connectivity_matrix[atom_ind[vertex[1]], atom_ind[vertex[0]]] = length
                    connectivity_matrix[atom_ind[vertex[0]], atom_ind[vertex[1]]] = length
                    # bonded partners that are not crystal atoms get the symbol 'Be'
                    # NOTE(review): presumably a visual marker for periodic images — confirm
                    if vertex[0] not in to_bond:
                        atoms[atom_ind[vertex[0]]].symbol = 'Be'
                    if vertex[1] not in to_bond:
                        atoms[atom_ind[vertex[1]]].symbol = 'Be'
                if vertex not in edges:
                    edges.append(vertex)
                    lengths.append(np.linalg.norm(vertices[vertex[0]]-vertices[vertex[1]] ))
        polyhedron['edges'] = edges
        polyhedron['edge_lengths'] = lengths
        # keep only polyhedra whose vertex average lies inside the original cell
        if all(center > -0.000001) and all(center < crystal.cell.lengths()-0.01):
            new_polyhedra[str(index)]=polyhedron
            # 'length' comes from find_polyhedra (presumably the number of vertices — TODO confirm)
            if polyhedron['length'] == 4:
                tetrahedra.append(str(index))
            elif polyhedron['length'] == 6:
                octahedra.append(str(index))
            else:
                other.append(str(index))
                if verbose:
                    print(polyhedron['length'])
            index += 1
    atoms.positions -= crystal.cell.lengths()
    # coordination: entries above 0.1 are set to 1 before summing each row
    coord = connectivity_matrix.copy()
    coord[np.where(coord>.1)] = 1
    coord = coord.sum(axis=1)

    super_cell_atoms = np.sort(np.unique(super_cell_atoms))
    atoms.info.update({'polyhedra': {'polyhedra': new_polyhedra,
                                     'tetrahedra': tetrahedra,
                                     'octahedra': octahedra,
                                     'other' : other}})
    atoms.info.update({'bonds': {'connectivity_matrix': connectivity_matrix,
                                 'super_cell_atoms': super_cell_atoms,
                                 'super_cell_dimensions': crystal.cell.array,
                                 'coordination': coord}})
    atoms.info.update({'supercell': crystal})
    return atoms
541
+
542
def plot_atoms(atoms: 'ase.Atoms', polyhedra_indices=None, plot_bonds=False, color='', template=None,
               atom_size=None, max_size=35, title=r"Al-GB $") -> 'go.Figure':
    """
    Plot the structure of an ase.Atoms object with plotly.

    If the info dictionary of ``atoms`` contains bond ('bonds') or polyhedra
    ('polyhedra') information, these can be plotted as well.

    Parameters
    ----------
    atoms: ase.Atoms object
        structure of supercell
    polyhedra_indices: list of integers or -1
        indices of polyhedra to be plotted; -1 plots all polyhedra
    plot_bonds: boolean
        whether to plot bonds (via get_plot_bonds) or not
    color: str
        'coordination', 'layer' (z position), 'energy'; anything else colors by atomic number
    template: str or None
        plotly layout template
    atom_size: None, float, int or sequence
        None: 4 x atomic number; float: scale factor applied to 4 x atomic number and
        capped at ``max_size``; int: same size for all atoms; a sequence must have
        one entry per atom (otherwise all sizes fall back to 10)
    max_size: float
        upper limit of the marker size when ``atom_size`` is a float
    title: str
        figure title (default is the previously hard-coded value)

    Returns
    -------
    fig: plotly figure object
        handle to figure needed to modify appearance

    Notes
    -----
    The annotations are strings so that defining this function does not need
    ``ase`` or ``plotly`` at module level.  Calling it still uses the names
    ``go`` (plotly.graph_objects), ``px`` (plotly.express), plot_polyhedron and
    get_plot_bonds.  NOTE(review): go/px are not imported in the visible part of
    this module — confirm they are available before calling.
    """
    has_bonds = 'bonds' in atoms.info

    energies = np.zeros(len(atoms))
    if has_bonds and 'atom_energy' in atoms.info['bonds']:
        bond_info = atoms.info['bonds']
        energies = np.round((np.asarray(bond_info['atom_energy'])
                             - 12 * bond_info['ideal_bond_energy']) * 1000, 0)
        # only atoms of the original super cell carry an energy
        outside = ~np.isin(np.arange(len(atoms)), bond_info['super_cell_atoms'])
        energies[outside] = 0.

    if color == 'coordination':
        colors = atoms.info['bonds']['coordination']
    elif color == 'layer':
        colors = atoms.positions[:, 2]
    elif color == 'energy':
        # clip a copy: ``energies`` itself is reused unclipped for the hover text
        colors = np.log(1 + np.minimum(energies, 50))
    else:
        colors = atoms.get_atomic_numbers()

    if atom_size is None:
        atom_size = atoms.get_atomic_numbers() * 4
    elif isinstance(atom_size, float):
        atom_size = atoms.get_atomic_numbers() * 4 * atom_size
        atom_size[atom_size > max_size] = max_size
    elif isinstance(atom_size, int):
        atom_size = [atom_size] * len(atoms)
    if len(atom_size) != len(atoms):
        atom_size = [10] * len(atoms)
        print('wrong length of atom_size parameter')

    plot_polyhedra = False
    data = []
    if polyhedra_indices is not None and 'polyhedra' in atoms.info:
        all_polyhedra = atoms.info['polyhedra']['polyhedra']
        if polyhedra_indices == -1:
            data = plot_polyhedron(all_polyhedra, range(len(all_polyhedra)))
            plot_polyhedra = True
        elif isinstance(polyhedra_indices, list):
            data = plot_polyhedron(all_polyhedra, polyhedra_indices)
            plot_polyhedra = True

    if has_bonds:
        coord = atoms.info['bonds']['coordination']
        super_cell_atoms = atoms.info['bonds']['super_cell_atoms']
        with_energy = 'atom_energy' in atoms.info['bonds']
        text = []
        for index, position in enumerate(atoms.positions):
            if index not in super_cell_atoms:
                text.append('')
                continue
            x, y, z = position
            # plotly hover labels break lines with <br>, not with '\n'
            label = (f'Atom {index}: coordination={coord[index]}'
                     f'<br>x:{x:.2f} <br> y:{y:.2f} <br> z:{z:.2f}')
            if with_energy:
                label += f"<br> energy: {energies[index]:.0f} meV"
            text.append(label)
    else:
        text = [''] * len(atoms)

    if plot_bonds:
        data += get_plot_bonds(atoms)
    if plot_polyhedra or plot_bonds:
        fig = go.Figure(data=data)
    else:
        fig = go.Figure()

    trace_kwargs = dict(mode='markers',
                        x=atoms.positions[:, 0], y=atoms.positions[:, 1], z=atoms.positions[:, 2])
    marker = dict(color=colors, size=atom_size, sizemode='diameter')
    if color == 'energy':
        trace_kwargs.update(hovertemplate='<b>%{text}</b><extra></extra>', text=text)
        marker.update(colorscale='Rainbow', colorbar=dict(thickness=10, orientation='h'))
    elif has_bonds:
        trace_kwargs.update(hovertemplate='<b>%{text}</b><extra></extra>', text=text)
        marker.update(colorscale=px.colors.qualitative.Light24)
    else:
        marker.update(colorbar=dict(thickness=10), colorscale=px.colors.qualitative.Light24)
    fig.add_trace(go.Scatter3d(marker=marker, **trace_kwargs))

    fig.update_layout(width=1000, height=700, showlegend=False, template=template)
    fig.update_layout(scene_aspectmode='data',
                      scene_aspectratio=dict(x=1, y=1, z=1))

    camera = {'up': {'x': 0, 'y': 1, 'z': 0},
              'center': {'x': 0, 'y': 0, 'z': 0},
              'eye': {'x': 0, 'y': 0, 'z': 1}}
    fig.update_coloraxes(showscale=True)
    fig.update_layout(scene_camera=camera, title=title)
    fig.update_scenes(camera_projection_type="orthographic")
    fig.show()
    return fig
673
+
674
+
675
+
676
+
677
+ def find_polyhedra(atoms, optimize=True, cheat=1.0, bond_radii=None):
404
678
  """ get polyhedra information from an ase.Atoms object
405
679
 
406
680
  This is following the method of Banadaki and Patala
@@ -429,19 +703,50 @@ def find_polyhedra(atoms, optimize=True, cheat=1.0):
429
703
  else:
430
704
  tetrahedra = scipy.spatial.Delaunay(atoms.positions)
431
705
 
432
- voronoi_vertices, voronoi_tetrahedrons, r_vv, r_a = get_voronoi(tetrahedra, atoms, optimize=optimize)
433
-
706
+ voronoi_vertices, voronoi_tetrahedrons, r_vv, r_a = get_voronoi(tetrahedra, atoms, optimize=optimize, bond_radii=bond_radii)
707
+ if np.abs(atoms.positions[:, 2]).sum() <= 0.01:
708
+ r_vv = np.array(r_vv)*3.
434
709
  overlapping_pairs = find_overlapping_spheres(voronoi_vertices, r_vv, r_a, cheat=cheat)
435
710
 
436
711
  clusters, visited_all = find_interstitial_clusters(overlapping_pairs)
437
712
 
438
713
  if np.abs(atoms.positions[:, 2]).sum() <= 0.01:
439
- polyhedra = make_polygons(atoms, voronoi_vertices, voronoi_tetrahedrons, clusters, visited_all)
714
+ rings = get_polygons(atoms, clusters, voronoi_tetrahedrons)
715
+ return rings
440
716
  else:
441
717
  polyhedra = make_polyhedrons(atoms, voronoi_vertices, voronoi_tetrahedrons, clusters, visited_all)
442
718
  return polyhedra
443
719
 
444
720
 
721
def polygon_sort(corners):
    """Order polygon corners by their angle around the corners' 2D centroid.

    Parameters
    ----------
    corners: numpy array (N, 2) or (N, 3)
        corner positions; only the first two columns are used for sorting

    Returns
    -------
    numpy array
        the rows of ``corners`` ordered by increasing angle in [0, 2 pi), where the
        angle is ``arctan2(dx, dy)``, i.e. measured from the +y axis toward +x
    """
    two_pi = 2.0 * np.pi
    centroid = np.average(corners[:, :2], axis=0)
    dx = corners[:, 0] - centroid[0]
    dy = corners[:, 1] - centroid[1]
    angle = (np.arctan2(dx, dy) + two_pi) % two_pi
    order = np.argsort(angle)
    return corners[order]
725
+
726
def get_polygons(atoms, clusters, voronoi_tetrahedrons):
    """Collect the atoms of each cluster of simplices into a sorted 2D ring.

    For every cluster, the atom indices of all its simplices are pooled and the
    unique atoms' (x, y) positions are ordered with polygon_sort.

    Parameters
    ----------
    atoms: ase.Atoms
        structure (only ``positions[:, :2]`` is used)
    clusters: iterable of iterables
        each cluster holds indices into ``voronoi_tetrahedrons``
    voronoi_tetrahedrons: indexable
        each item is a sequence of atom indices

    Returns
    -------
    rings: dict
        'atoms': (x, y) of all atoms; 'cyclicity': number of corners per ring;
        'centers': mean corner position per ring; 'corners': sorted corners per ring;
        'polygons': matplotlib Polygon patch per ring
    """
    polygons, cyclicity, centers, corners = [], [], [], []
    for cluster in clusters:
        vertex_ids = []
        for simplex in cluster:
            vertex_ids.extend(list(voronoi_tetrahedrons[simplex]))

        ring = polygon_sort(atoms.positions[list(set(vertex_ids)), :2])
        cyclicity.append(len(ring))
        corners.append(ring)
        centers.append(np.mean(ring[:, :2], axis=0))
        polygons.append(patches.Polygon(np.array(ring)[:, :2], closed=True, fill=True, edgecolor='red'))

    return {'atoms': atoms.positions[:, :2],
            'cyclicity': np.array(cyclicity),
            'centers': np.array(centers),
            'corners': corners,
            'polygons': polygons}
748
+
749
+
445
750
  def sort_polyhedra_by_vertices(polyhedra, visible=range(4, 100), z_lim=[0, 100], verbose=False):
446
751
  indices = []
447
752
 
@@ -459,3 +764,404 @@ def sort_polyhedra_by_vertices(polyhedra, visible=range(4, 100), z_lim=[0, 100],
459
764
 
460
765
  # color_scheme = ['lightyellow', 'silver', 'rosybrown', 'lightsteelblue', 'orange', 'cyan', 'blue', 'magenta',
461
766
  # 'firebrick', 'forestgreen']
767
+
768
+
769
+
770
+ ##########################
771
+ # New Graph Stuff
772
+ ##########################
773
def breadth_first_search(graph, initial, projected_crystal, number_of_neighbours=8):
    """ breadth first search of atoms viewed as a graph

    Starting at atom ``initial``, hop along the projected lattice vectors to
    neighbouring atoms and record the ideal lattice position of every atom reached.

    ``projected_crystal.info['projection']`` has to contain the items
    'near_base', 'allowed_variation' and 'distance_unit_cell'.

    Parameters
    ----------
    graph: numpy array (Nx2)
        the atom positions
    initial: int
        index of starting atom
    projected_crystal: ase.Atoms (or object with ``cell`` and ``info``)
        projected unit cell; rows 0 and 1 of ``cell[:2, :2]`` divided by 10 are the
        in-plane lattice vectors (NOTE(review): presumably Angstrom -> nm to match
        ``graph`` — confirm); ``info['projection']`` holds the items above
    number_of_neighbours: int
        neighbours queried per atom (the node itself included); must be >= 2

    Returns
    -------
    graph[visited]: numpy array (M,2) with M<=N
        positions of atoms hopped in unit cell lattice, in visiting order
    ideal: list of numpy arrays (2,)
        ideal atom positions in the same order
    """
    import collections

    projection_tags = projected_crystal.info['projection']

    # lattice vectors to hop along through the graph
    projected_unit_cell = projected_crystal.cell[:2, :2]
    a_lattice_vector = projected_unit_cell[0] / 10
    b_lattice_vector = projected_unit_cell[1] / 10
    main = np.array([a_lattice_vector, -a_lattice_vector, b_lattice_vector, -b_lattice_vector])
    # directions 0-3 are lattice hops, higher directions belong to the base
    near = np.append(main, projection_tags['near_base'], axis=0)

    neighbour_tree = scipy.spatial.cKDTree(graph)
    distances, indices = neighbour_tree.query(graph, k=number_of_neighbours)
    n_atoms = len(graph)

    visited = []          # visiting order, used to index graph on return
    visited_set = set()   # O(1) membership test
    ideal = []
    queue = collections.deque([(initial, graph[initial])])

    while queue:
        node, ideal_node = queue.popleft()
        if node in visited_set:
            continue
        visited.append(node)
        visited_set.add(node)
        ideal.append(ideal_node)

        for i, neighbour in enumerate(indices[node]):
            # cKDTree reports missing neighbours (fewer atoms than k) with index n_atoms
            if neighbour >= n_atoms or neighbour in visited_set:
                continue
            distance_to_ideal = np.linalg.norm(near + graph[node] - graph[neighbour], axis=1)
            if np.min(distance_to_ideal) >= projection_tags['allowed_variation']:
                continue
            direction = np.argmin(distance_to_ideal)
            # only hops along the four lattice vectors (directions 0-3) are followed
            if direction <= 3 and distances[node, i] < projection_tags['distance_unit_cell'] * 1.05:
                queue.append((neighbour, ideal_node + near[direction]))

    return graph[visited], ideal
838
+
839
+ ####################
840
+ # Distortion Matrix
841
+ ####################
842
def get_distortion_matrix(atoms, ideal_lattice):
    """ Calculates distortion matrix

    Calculates the distortion matrix by comparing ideal and distorted Voronoi tiles.
    The Voronoi tile of every measured atom is mapped onto the Voronoi tile of the
    ideal lattice point closest to the centre of the ideal lattice.  Every grid
    point inside the (slightly enlarged) distorted tile is recorded together with
    the position it has to be moved to.

    Parameters
    ----------
    atoms: numpy array (N, 2)
        measured atom positions
    ideal_lattice: array-like (N, 2)
        ideal positions; row i is where atom i has to be moved to

    Returns
    -------
    numpy array (M, 4)
        rows of [x, y, x_target, y_target]

    NOTE(review): Voronoi regions of border atoms are unbounded and contain the
    index -1 (scipy convention), which here selects the last Voronoi vertex.
    Confirm that border tiles are acceptable.
    """
    ideal_lattice = np.asarray(ideal_lattice)
    vor = scipy.spatial.Voronoi(atoms)

    # determine a middle Voronoi tile
    ideal_vor = scipy.spatial.Voronoi(ideal_lattice)
    near_center = np.average(ideal_lattice, axis=0)
    # axis=1: one distance per lattice point (axis=0 would reduce over the points
    # and always yield index 0 or 1)
    center_index = np.argmin(np.linalg.norm(ideal_lattice - near_center, axis=1))

    # the ideal vertices of such a Voronoi tile, centred on the origin
    ideal_vertices = ideal_vor.vertices[ideal_vor.regions[ideal_vor.point_region[center_index]]]
    ideal_vertices = get_significant_vertices(ideal_vertices - np.average(ideal_vertices, axis=0))

    n_points = vor.points.shape[0]
    distortion_matrix = []
    for index in range(n_points):
        # progress output
        done = int((index + 1) / n_points * 50)
        sys.stdout.write('\r')
        sys.stdout.write("[%-50s] %d%%" % ('=' * done, 2 * done))
        sys.stdout.flush()

        # vertices of the Voronoi polygon of atom ``index``, relative to the atom
        poly_point = vor.points[index]
        poly_vertices = get_significant_vertices(vor.vertices[vor.regions[vor.point_region[index]]] - poly_point)

        # where the ATOM has to be moved (not pixel)
        ideal_point = ideal_lattice[index]

        # transform the Voronoi tile to the ideal one
        uncorrected, corrected, _ = transform_voronoi(poly_vertices, ideal_vertices)

        # pixel positions
        rounded_point = np.rint(poly_point)
        corrected = corrected + ideal_point + (rounded_point - poly_point)
        for (x, y), (x_target, y_target) in zip(uncorrected + rounded_point, corrected):
            distortion_matrix.append([x, y, x_target, y_target])
    print()
    return np.array(distortion_matrix)
886
+
887
+
888
def undistort(distortion_matrix, image_data):
    """ Undistort image according to distortion matrix

    The intensity of every original pixel (columns 0 and 1 of the matrix, truncated
    to int) is placed at its target position (columns 2 and 3).  The scattered values
    are then interpolated linearly with scipy.interpolate.griddata onto a regular grid
    whose size is the nearest power of two of each image dimension.

    Parameters
    ----------
    distortion_matrix: numpy array (N x 4)
        rows of [x, y, x_target, y_target]
    image_data: numpy array or sidpy.Dataset
        image

    Returns
    -------
    interpolated: numpy array
        undistorted image; grid points outside the convex hull of the targets are NaN
    """
    rows = distortion_matrix[:, 0].astype(int)
    cols = distortion_matrix[:, 1].astype(int)
    intensity_values = image_data[(rows, cols)]
    targets = distortion_matrix[:, 2:4]

    # nearest power of 2 for each of the first two image dimensions
    size_x, size_y = (int(size) for size in 2 ** np.round(np.log2(image_data.shape[0:2])))
    grid_x, grid_y = np.mgrid[0:size_x - 1:size_x * 1j, 0:size_y - 1:size_y * 1j]
    print('interpolate')

    return scipy.interpolate.griddata(np.array(targets), np.array(intensity_values),
                                      (grid_x, grid_y), method='linear')
920
+
921
+
922
def transform_voronoi(vertices, ideal_voronoi):
    """ find transformation matrix A between a distorted polygon and a perfect reference one

    NOTE(review): a second definition of transform_voronoi with the same code
    appears later in this module and replaces this one when the module is imported.

    Parameters
    ----------
    vertices: array-like (n, 2)
        vertices of the distorted polygon (presumably centred on the origin, which
        the outward np.sign expansion below relies on)
    ideal_voronoi: numpy array (m, 2)
        vertices of the ideal polygon

    Returns
    -------
    uncorrected: numpy array (k, 2) of int
        all integer grid points within the original polygon, enlarged by one unit
        per coordinate
    corrected: numpy array (k, 2)
        coordinates of these points where pixel have to move to (uncorrected @ A)
    aa: 2x2 matrix A:
        transformation matrix, least-squares solution of vertices @ A = ideal_voronoi
    """
    # pair every ideal vertex with its closest distorted vertex (polygons must be ordered)
    matched = [np.argmin(np.linalg.norm(vertices - ideal_vertex, axis=1)) for ideal_vertex in ideal_voronoi]
    vertices = np.array(vertices)[matched]

    # least squares: vertices @ aa ~ ideal_voronoi
    aa = np.linalg.lstsq(vertices, ideal_voronoi, rcond=None)[0]

    # push every vertex one unit outward so that border pixels are included
    enlarged = vertices + np.sign(vertices)
    half_width = int(np.abs(enlarged).max() + 1)
    side = 2 * half_width + 1

    grid = np.mgrid[0:side, :side] - half_width
    grid = np.swapaxes(grid, 0, 2)
    candidates = grid.reshape(-1, grid.shape[-1])

    uncorrected = candidates[points_in_poly(candidates, enlarged)]
    corrected = uncorrected @ aa
    return uncorrected, corrected, aa
960
+
961
+
962
def get_maximum_view(distortion_matrix):
    """Find the largest view that contains only valid distortion-matrix entries.

    Entries whose first component equals -1000. are invalid.  The view is shrunk
    by one pixel on every side until it contains no invalid entry.  Each side is
    then pushed out again by the number of rows (or columns) beyond it that are
    entirely valid within the current view's span.

    Parameters
    ----------
    distortion_matrix: numpy array (2, P, Q)

    Returns
    -------
    numpy array of int
        [row_start, row_end, column_start, column_end]
    """
    valid = np.ones(distortion_matrix.shape[1:], dtype=int)
    valid[distortion_matrix[0] == -1000.] = 0

    top, bottom = 0, distortion_matrix.shape[1] - 1
    left, right = 0, distortion_matrix.shape[2] - 1
    window = valid
    while (window == 0).any():
        top, bottom, left, right = top + 1, bottom - 1, left + 1, right - 1
        window = valid[top:bottom, left:right]

    rows = slice(top, bottom)
    cols = slice(left, right)
    grow_top = int(np.sum(np.min(valid[:top, cols], axis=1)))
    grow_bottom = int(np.sum(np.min(valid[bottom:, cols], axis=1)))
    grow_left = int(np.sum(np.min(valid[rows, :left], axis=0)))
    grow_right = int(np.sum(np.min(valid[rows, right:], axis=0)))

    return np.array([top - grow_top, bottom + grow_bottom, left - grow_left, right + grow_right])
978
+
979
+
980
def get_significant_vertices(vertices, distance=3):
    """Merge points that lie within ``distance`` of each other.

    Every point is replaced by the average of all points within ``distance`` of it
    (itself included).  Duplicates are then removed and the result is sorted by the
    angle ``arctan2(y, x)`` around the origin.

    Parameters
    ----------
    vertices: array-like (n, 2)
        list of points; a numpy array or a plain list of pairs
    distance: float
        merge radius (in same scale as points)

    Returns
    -------
    ideal_vertices: numpy array (m, 2)
        merged points ordered by angle.  Chains of close points are averaged per
        point, so the result is not guaranteed to be ``distance`` apart.
    """
    # fancy indexing below needs an array; lists used to fail at vertices[indices]
    vertices = np.asarray(vertices)

    tree = scipy.spatial.cKDTree(vertices)
    near = tree.query_ball_point(vertices, distance)
    ideal_vertices = []
    for indices in near:
        if len(indices) == 1:
            ideal_vertices.append(vertices[indices][0])
        else:
            ideal_vertices.append(np.average(vertices[indices], axis=0))
    ideal_vertices = np.unique(np.array(ideal_vertices), axis=0)

    angles = np.arctan2(ideal_vertices[:, 1], ideal_vertices[:, 0])
    return ideal_vertices[np.argsort(angles)]
1011
+
1012
+
1013
def transform_voronoi(vertices, ideal_voronoi):
    """
    find transformation matrix A between a polygon and a perfect one

    NOTE(review): this is the second definition of transform_voronoi in this
    module.  Being the later one, it is the object bound to the name after import;
    it behaves the same as the earlier definition.

    Parameters
    ----------
    vertices: array-like (n, 2)
        vertices of the distorted polygon (presumably centred on the origin)
    ideal_voronoi: numpy array (m, 2)
        vertices of the ideal polygon

    Returns
    -------
    list of points: all integer grid points within the (enlarged) original polygon
    list of points: coordinates of these points where pixel have to move to
    2x2 matrix aa: transformation matrix (least squares of vertices @ aa = ideal_voronoi)
    """
    # order the distorted vertices like the ideal ones (nearest match per ideal vertex)
    order = []
    for target in ideal_voronoi:
        gaps = np.linalg.norm(vertices - target, axis=1)
        order.append(np.argmin(gaps))
    sorted_vertices = np.array(vertices)[order]

    solution = np.linalg.lstsq(sorted_vertices, ideal_voronoi, rcond=None)
    transform = solution[0]

    # expand the polygon to include more points in the distortion matrix
    expanded = sorted_vertices + np.sign(sorted_vertices)
    extent = int(np.abs(expanded).max() + 1)

    mesh = np.mgrid[0:extent * 2 + 1, :extent * 2 + 1] - extent
    mesh = np.swapaxes(mesh, 0, 2)
    flat_mesh = mesh.reshape(-1, mesh.shape[-1])

    mask = points_in_poly(flat_mesh, expanded)
    uncorrected = flat_mesh[mask]
    return uncorrected, np.dot(uncorrected, transform), transform
1047
+
1048
+
1049
+
1050
def undistort_sitk(image_data, distortion_matrix):
    """ use simple ITK to undistort image

    The distortion matrix is cropped to the image size.  Its two components are
    negated and composed in swapped order (component 1 first, then component 0)
    into a displacement field.  The image is then resampled with B-spline
    interpolation and a default pixel value of 0.

    NOTE(review): ``sitk`` (SimpleITK) is not imported in the visible part of this
    module — confirm it is bound at module level before calling.

    Parameters
    ----------
    image_data: numpy array with size NxM
    distortion_matrix: sidpy.Dataset or numpy array with size 2 x P x Q
        with P, Q >= N, M

    Returns
    -------
    image: numpy array NxM
    """
    n_rows, n_cols = image_data.shape[0], image_data.shape[1]
    field = distortion_matrix[:, :n_rows, :n_cols]
    displacement = sitk.Compose([sitk.GetImageFromArray(-field[1]),
                                 sitk.GetImageFromArray(-field[0])])

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitk.GetImageFromArray(image_data))
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(sitk.DisplacementFieldTransform(displacement))

    result = resampler.Execute(sitk.GetImageFromArray(image_data))
    return sitk.GetArrayFromImage(result)
1077
+
1078
+
1079
def undistort_stack_sitk(distortion_matrix, image_stack):
    """
    use simple ITK to undistort stack of image

    All images share one displacement field built from the negated components of
    ``distortion_matrix`` (component 1 first, then component 0).  They are resampled
    with B-spline interpolation and a default pixel value of 0.

    NOTE(review): ``sitk`` (SimpleITK) and ``pyTEMlib.sidpy_tools`` are not imported in
    the visible part of this module — confirm they are available.

    Parameters
    ----------
    distortion_matrix: h5 Dataset or numpy array with size 2 x P x Q
        with P, Q >= M, N
    image_stack: numpy array with size n_images x M x N

    Returns
    -------
    interpolated: numpy array with the shape of image_stack
    """
    # QT_available is not defined in the visible part of this module; without it
    # no progress dialog is used instead of raising NameError
    qt_available = globals().get('QT_available', False)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(sitk.GetImageFromArray(image_stack[0]))
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(0)

    displ2 = sitk.Compose(
        [sitk.GetImageFromArray(-distortion_matrix[1]), sitk.GetImageFromArray(-distortion_matrix[0])])
    out_tx = sitk.DisplacementFieldTransform(displ2)
    resampler.SetTransform(out_tx)

    interpolated = np.zeros(image_stack.shape)
    nimages = image_stack.shape[0]

    progress = None
    if qt_available:
        progress = pyTEMlib.sidpy_tools.ProgressDialog("Correct Scan Distortions", nimages)

    for i in range(nimages):
        if progress is not None:
            progress.setValue(i)
        out = resampler.Execute(sitk.GetImageFromArray(image_stack[i]))
        interpolated[i] = sitk.GetArrayFromImage(out)

    # previously also called unconditionally, which failed when no dialog existed
    if progress is not None:
        progress.setValue(nimages)

    return interpolated
1120
+
1121
+
1122
def undistort_stack(distortion_matrix, data):
    """ Undistort stack with distortion matrix

    Uses scipy.interpolate.griddata to apply the distortion matrix to every image.
    The distortion matrix holds, for each listed pixel, its original coordinates
    and the (float) coordinates it has to be moved to.

    Parameters
    ----------
    distortion_matrix: numpy array (N, 4)
        rows of [x, y, x_target, y_target], as produced by get_distortion_matrix
    data: numpy array or sidpy.Dataset (n_images, size_x, size_y)
        image stack

    Returns
    -------
    interpolated: numpy array
        shape (n_images, 2**round(log2(size_x)), 2**round(log2(size_y)));
        grid points outside the convex hull of the targets are NaN
    """
    # QT_available is not defined in the visible part of this module; without it
    # the text progress bar is used instead of raising NameError
    qt_available = globals().get('QT_available', False)

    corrected = distortion_matrix[:, 2:4]
    intensity_values = data[:, distortion_matrix[:, 0].astype(int), distortion_matrix[:, 1].astype(int)]

    # nearest power of 2 of each image dimension
    size_x, size_y = (int(size) for size in 2 ** np.round(np.log2(data.shape[1:])))

    grid_x, grid_y = np.mgrid[0:size_x - 1:size_x * 1j, 0:size_y - 1:size_y * 1j]
    print('interpolate')

    nimages = data.shape[0]
    interpolated = np.zeros([nimages, size_x, size_y])
    done = 0

    progress = None
    if qt_available:
        # NOTE(review): ``ft`` is not imported in the visible part of this module — confirm
        progress = ft.ProgressDialog("Correct Scan Distortions", nimages)
    for i in range(nimages):
        if progress is not None:
            progress.set_value(i)
        elif done < int((i + 1) / nimages * 50):
            done = int((i + 1) / nimages * 50)
            sys.stdout.write('\r')
            # progress output :
            sys.stdout.write("[%-50s] %d%%" % ('=' * done, 2 * done))
            sys.stdout.flush()

        # the bare name ``griddata`` was never imported; use the scipy module the file imports
        interpolated[i, :, :] = scipy.interpolate.griddata(corrected, intensity_values[i, :],
                                                           (grid_x, grid_y), method='linear')
    if progress is not None:
        progress.set_value(nimages)
    print(':-)')
    print('You have successfully completed undistortion of image stack')
    return interpolated