bandu 1.3.6__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bandu/bandu.py CHANGED
@@ -1,314 +1,321 @@
1
- import numpy as np
2
- from typing import Generator
3
- from copy import copy
4
- from . import brillouin_zone as brlzn
5
- from . import translate as trnslt
6
- from . import wfk_class as wc
7
-
8
class BandU():
    def __init__(
        self, wfks:Generator, energy_level:float, width:float, grid:bool=True, fft:bool=True, norm:bool=True,
        sym:bool=True, low_mem:bool=False, plot:bool=True, real_imag_sep:bool=True
    )->None:
        '''
        BandU object with methods for finding states and computing BandU functions from states.

        All work is done in the constructor: states inside the energy window are collected,
        principal components are computed, the eigenvalues are optionally plotted, and a
        summary is written to 'eigenvalues.out' in the current working directory.

        Parameters
        ----------
        wfks : Generator
            An iterable generator of WFK objects with wavefunction coefficients, k-points, and eigenvalue attributes.
        energy_level : float
            The energy level of interest relative to the Fermi energy.
        width : float
            Defines how far above and below the energy_level is searched for states.
            Search is done width/2 above and below, so total states captured are within 'width' energy.
        grid : bool
            Determines whether or not wavefunction coefficients are converted to 3D numpy grid.
            Default converts to grid (True).
        fft : bool
            Determines whether or not wavefunction coefficients are transformed to real space.
            Note that this flag applies WFK.IFFT (the inverse transform).
            Default converts from reciprocal space to real space (True).
        norm : bool
            Determines whether or not wavefunction coefficients are normalized.
            Default normalizes coefficients (True).
        sym : bool
            Use every symmetrically equivalent wavefunction returned by WFK.SymWFKs for each state found.
            When False the k-point is instead shifted back into the Brillouin zone.
            Default generates symmetric equivalents (True).
        low_mem : bool
            Stored as self.low_mem.
            NOTE(review): no method of this class reads it.
            Default (False).
        plot : bool
            Plots eigenvalues from principal component analysis to 'bandu_eigenvalues.png'.
            Default will save plot (True).
        real_imag_sep : bool
            Treat the real and imaginary parts of each state as separate vectors in the principal
            component analysis. This doubles the number of Band-U functions and skips the Omega check.
            Default separates them (True).

        Raises
        ------
        ValueError
            If no states are found within the requested energy window.

        Methods
        -------
        ToXSF
            Writes BandU functions to XSF files (via WFK.XSFFormat / WFK.WriteXSF).
        '''
        self.grid:bool=grid
        self.fft:bool=fft
        self.norm:bool=norm
        self.sym:bool=sym
        self.low_mem:bool=low_mem
        self.found_states:int=0
        self.bandu_fxns:list[wc.WFK]=[]
        self.plot=plot
        # find all states within width (raises ValueError when nothing is found)
        self._FindStates(energy_level, width, wfks)
        print(f'{self.found_states} states found within specified energy range')
        # construct principal orbital components
        if real_imag_sep:
            principal_vals = self._RealImagPrnplComps()
        else:
            principal_vals = self._PrincipalComponents()
        # plot eigenvalues from PCA
        if plot:
            self._PlotEigs(principal_vals)
        # normalize bandu functions
        for i in range(self.found_states):
            self.bandu_fxns[i] = self.bandu_fxns[i].Normalize()
        # compute ratios
        if not real_imag_sep:
            omega_vals, omega_check = self._CheckOmega()
        else:
            omega_vals = 0
            omega_check = 0
        # write output file
        fermi = self.bandu_fxns[0].fermi_energy
        with open('eigenvalues.out', 'w') as f:
            print(f'Width: {width}, found states: {self.found_states}', file=f)
            print(f'Energy level: {energy_level+fermi}, Fermi energy: {fermi}', file=f)
            print(np.abs(principal_vals), file=f)
            print('Omega Values', file=f)
            print(omega_vals, file=f)
            print('Omega Check (value of 0 indicates Omega is at local extremum)', file=f)
            print(omega_check, file=f)
    #---------------------------------------------------------------------------------------------------------------------#
    #------------------------------------------------------ METHODS ------------------------------------------------------#
    #---------------------------------------------------------------------------------------------------------------------#
    # collect every band whose eigenvalue lies inside the energy window
    def _FindStates(
        self, energy_level:float, width:float, wfks:Generator[wc.WFK,None,None]
    ):
        '''
        Collect every band of every WFK in wfks whose eigenvalue lies in
        [fermi_energy + energy_level - width/2, fermi_energy + energy_level + width/2].

        Each matching band is expanded by _Process and the results are appended to
        self.bandu_fxns; self.found_states is kept equal to the number appended.

        Raises
        ------
        ValueError
            If no band falls inside the window.
        '''
        for state in wfks:
            min_en = state.fermi_energy + energy_level - width/2
            max_en = state.fermi_energy + energy_level + width/2
            for i, band in enumerate(state.eigenvalues):
                if min_en <= band <= max_en:
                    self.found_states += 1
                    # shallow copy: only wfk_coeffs is rebound here, other attributes stay shared with `state`
                    coeffs = copy(state)
                    coeffs.wfk_coeffs = coeffs.wfk_coeffs[i]
                    self.bandu_fxns.extend(self._Process(coeffs))
        # `x is []` compares against a brand-new list and is always False; test emptiness instead
        if not self.bandu_fxns:
            raise ValueError(
                '''Identified 0 states within provided width.
                Action: Increase width or increase fineness of kpoint grid.
                ''')
    #-----------------------------------------------------------------------------------------------------------------#
    # method for processing planewave coefficient data from FindStates
    def _Process(
        self, state:wc.WFK
    )->Generator[wc.WFK,None,None]:
        '''
        Expand a single-band WFK into the WFK objects used for the PCA and yield them after
        applying the requested transformations in the order grid, fft (WFK.IFFT), norm.

        With self.sym every WFK from state.SymWFKs is used and self.found_states grows by
        (number of copies - 1), because _FindStates already counted the state itself.
        '''
        funcs:list[wc.WFK] = []
        if self.sym:
            # generate symmetrically equivalent coefficients
            for sym_coeffs in state.SymWFKs(kpoint=state.kpoints):
                self.found_states += 1
                funcs.append(sym_coeffs)
            self.found_states -= 1
        else:
            # shift point back into Brillouin Zone as necessary
            rec_latt = state.Real2Reciprocal()
            shift = brlzn.BZ(rec_latt=rec_latt).GetShifts(state.kpoints)
            # `state` is a shallow copy, so pw_indices may be the parent's array (assuming a NumPy
            # array); an in-place `+=` would shift it again for every other band at this k-point
            state.pw_indices = state.pw_indices + shift
            state.kpoints = state.kpoints.reshape((-1,3)) - shift
            funcs.append(state)
        # apply desired transformations
        for wfk in funcs:
            if self.grid:
                wfk = wfk.GridWFK()
            if self.fft:
                wfk = wfk.IFFT()
            if self.norm:
                wfk = wfk.Normalize()
            yield wfk
    #-----------------------------------------------------------------------------------------------------------------#
    # check if states along BZ edge have been collected
    # --not functional--
    def _CheckEdgeCase(
        self
    ):
        '''
        Intended to duplicate states whose k-point lies on the Brillouin-zone edge, shifted by
        the periodic images from trnslt.TranslatePoints.

        NOTE(review): marked "not functional" by the author and not called by this class.
        Known issues left as-is: debug prints of every k-point, and each flagged state is
        shifted by the shifts of *all* flagged states (all_shifts is not paired with wfk_to_shift).
        '''
        bz = brlzn.BZ(self.bandu_fxns[0].Real2Reciprocal())
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        all_shifts:list[np.ndarray] = []
        wfk_to_shift:list[int] = []
        dupes:list[wc.WFK] = []
        for i, wfk in enumerate(self.bandu_fxns):
            # if point on edge, translate to find other periodic points
            edge_pts = bz._BZEdgePt(wfk.kpoints)
            print(wfk.kpoints)
            print(edge_pts)
            if len(edge_pts[edge_pts > 0]) > 0:
                shifts, _ = trnslt.TranslatePoints(np.zeros((1,3)), np.zeros(1), np.identity(3))
                shifts = shifts[edge_pts > 0].astype(int)
                all_shifts.append(shifts)
                wfk_to_shift.append(i)
        for i in wfk_to_shift:
            for shifts in all_shifts:
                for shift in shifts:
                    shifted_wfk = copy(self.bandu_fxns[i])
                    shifted_wfk = shifted_wfk.RemoveGrid()
                    shifted_wfk.pw_indices += shift
                    shifted_wfk = shifted_wfk.GridWFK()
                    shifted_wfk.wfk_coeffs = shifted_wfk.wfk_coeffs.reshape((1,x*y*z))
                    shifted_dupe = copy(self.bandu_fxns)
                    shifted_dupe[i] = shifted_wfk
                    dupes.extend(shifted_dupe)
                    # duped_states is never initialised elsewhere; start from 0 instead of raising AttributeError
                    self.duped_states = getattr(self, 'duped_states', 0) + self.found_states
        self.bandu_fxns.extend(dupes)
    #-----------------------------------------------------------------------------------------------------------------#
    # find principal components
    def _PrincipalComponents(
        self
    )->np.ndarray:
        '''
        Principal component analysis of the collected states.

        Each state's coefficients are flattened into row i of a matrix M. The overlap
        matrix S = conj(M) M^T is diagonalised, and self.bandu_fxns[k].wfk_coeffs is replaced
        by the k-th principal component (sum_i v_k[i] * M[i], eigenvalues descending).

        Returns
        -------
        np.ndarray
            Eigenvalues of S in descending order.
        '''
        total_states = self.found_states
        # organize wfk coefficients
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        mat = np.zeros((total_states,x*y*z), dtype=complex)
        for i in range(total_states):
            mat[i,:] = self.bandu_fxns[i].wfk_coeffs.reshape(x*y*z)
        # compute overlap matrix
        print('Computing overlap matrix')
        overlap_mat = np.matmul(np.conj(mat), mat.T)
        # S is Hermitian by construction: eigh gives real eigenvalues and an orthonormal
        # eigenbasis even for degenerate eigenvalues, which np.linalg.eig does not guarantee
        principal_vals, principal_vecs = np.linalg.eigh(overlap_mat)
        # eigh returns ascending values with eigenvectors as columns: reverse, one vector per row
        principal_vals = principal_vals[::-1]
        principal_vecs = principal_vecs.T[::-1]
        mat = np.matmul(principal_vecs, mat)
        for i in range(total_states):
            self.bandu_fxns[i].wfk_coeffs = mat[i,:]
        return principal_vals
    #-----------------------------------------------------------------------------------------------------------------#
    # find principal components with real and imaginary separated
    def _RealImagPrnplComps(
        self
    )->np.ndarray:
        '''
        Principal component analysis with the real and imaginary parts of every state as
        separate vectors (rows 2i and 2i+1 of the data matrix).

        The first found_states components overwrite the existing self.bandu_fxns entries. The
        rest are appended as shallow copies of self.bandu_fxns[0] carrying the new coefficients.
        self.found_states is doubled.

        Returns
        -------
        np.ndarray
            Eigenvalues of the overlap matrix in descending order.
        '''
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        mat = np.zeros((2*self.found_states,x*y*z), dtype=complex)
        for i in range(self.found_states):
            ind = 2*i
            mat[ind,:] = self.bandu_fxns[i].wfk_coeffs.real.reshape(x*y*z)
            mat[ind+1,:] = self.bandu_fxns[i].wfk_coeffs.imag.reshape(x*y*z)
        print('Computing overlap matrix')
        overlap_mat = np.matmul(np.conj(mat),mat.T)
        # Hermitian overlap matrix: see _PrincipalComponents for why eigh is used
        principal_vals, principal_vecs = np.linalg.eigh(overlap_mat)
        principal_vals = principal_vals[::-1]
        principal_vecs = principal_vecs.T[::-1]
        mat = np.matmul(principal_vecs, mat)
        for i in range(2*self.found_states):
            if i < self.found_states:
                self.bandu_fxns[i].wfk_coeffs = mat[i,:]
            else:
                new_coeffs = copy(self.bandu_fxns[0])
                new_coeffs.wfk_coeffs = mat[i,:]
                self.bandu_fxns.append(new_coeffs)
        self.found_states *= 2
        return principal_vals
    #-----------------------------------------------------------------------------------------------------------------#
    # find ratio of real and imaginary components
    def _CheckOmega(
        self
    )->tuple[np.ndarray, np.ndarray]:
        '''
        For each Band-U function, evaluate |sum(Re(c)*c) / sum(Im(c)*c)| after multiplying the
        coefficients by exp(i*pi*t) for t in (-0.01, 0, 0.01).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Omega at t=0 for every function, and sign(d1)+sign(d2) of the two finite
            differences (0 means the t=0 value is a local extremum).
        '''
        total_states = self.found_states
        omega_vals = np.zeros((total_states, 3), dtype=float)
        vals = np.linspace(start=-0.01, stop=0.01, num=3)
        # a purely real function has a zero denominator; report inf/nan without RuntimeWarnings
        with np.errstate(divide='ignore', invalid='ignore'):
            for i, val in enumerate(vals):
                phase = np.exp(1j*val*np.pi)
                for j in range(total_states):
                    # out-of-place multiply: also valid when wfk_coeffs has a real dtype
                    coeffs = self.bandu_fxns[j].wfk_coeffs * phase
                    omega = np.sum(coeffs.real*coeffs)/np.sum(coeffs.imag*coeffs)
                    omega_vals[j,i] = np.abs(omega)
        omega_diff1 = (omega_vals[:,1] - omega_vals[:,0])
        omega_diff2 = (omega_vals[:,2] - omega_vals[:,1])
        omega_check = np.sign(omega_diff1) + np.sign(omega_diff2)
        return omega_vals[:,1], omega_check
    #-----------------------------------------------------------------------------------------------------------------#
    # plot eigenvalues from PCA
    def _PlotEigs(
        self, eigvals:np.ndarray
    ):
        '''
        Plot |eigenvalue| against component number and save the figure to
        'bandu_eigenvalues.png' (dpi=500) in the current working directory.
        '''
        import matplotlib.pyplot as plt
        y = np.abs(eigvals.flatten())
        x = np.arange(len(y)) + 1
        # scope the font to this figure rather than mutating global rcParams after plotting
        with plt.rc_context({'font.family': 'Times New Roman'}):
            fig = plt.figure(figsize=(12,6))
            ax = fig.add_subplot()
            ax.plot(
                x,
                y,
                color='black',
                linestyle='-',
                marker='o',
                mfc='red',
                markersize=8
            )
            ax.spines[['right','top']].set_visible(False)
            ax.tick_params(axis='both', labelsize=12)
            ax.set_xlim(1.0,len(y)+5.0)
            # `plt.ylim=(...)` rebound pyplot's ylim function instead of setting the limits
            ax.set_ylim(0,np.max(y))
            # at least 1, otherwise `val % mod_val` divides by zero for <= 5 states
            mod_val = max(1, round(len(y)/5 - 0.5))
            ax.set_xticks([val for val in range(0,len(y)+1) if val % mod_val == 0])
            fig.savefig('bandu_eigenvalues.png',dpi=500)
        # release the figure so repeated runs do not accumulate open figures
        plt.close(fig)
    #-----------------------------------------------------------------------------------------------------------------#
    # make xsf of BandU functions
    def ToXSF(
        self, nums:list[int]=[], xsf_name:str='Principal_orbital_component'
    ):
        '''
        Write Band-U functions to XSF files named f'{xsf_name}_{n}'.

        Parameters
        ----------
        nums : list[int]
            Two 1-based function numbers [first, last], inclusive. When empty (default) or None,
            all self.found_states functions are written. Out-of-range values are clamped.
            The list passed in is never modified.
        xsf_name : str
            Prefix of the output file names.

        Raises
        ------
        ValueError
            If nums is non-empty and does not hold exactly two values.
        '''
        total_states = self.found_states
        # `nums is []` was always False, so the default used to hit the ValueError below
        if nums is None or len(nums) == 0:
            nums = [1,total_states]
        else:
            # work on a copy so the caller's list is not clamped in place
            nums = list(nums)
            # check if list has only 2 elements
            if len(nums) != 2:
                raise ValueError(f'nums should contain two values, {len(nums)} were received.')
            # check if function number is within defined range
            if nums[0] < 1:
                print('First element of nums cannot be lower than 1, changing to 1 now.')
                nums[0] = 1
            # update function number list if it exceeds maximum number of bandu functions
            if nums[1] > total_states:
                print(f'Printing up to max Band-U function number: {total_states}')
                nums[1] = total_states
            # check if lower limit is within defined range
            if nums[0] > nums[1]:
                nums[0] = nums[1]
        print(f'Writing XSF files for Band-U functions {nums[0]} through {nums[1]}.')
        # write xsf files
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        for i in range(nums[0]-1, nums[1]):
            file_name = xsf_name + f'_{i+1}'
            wfk = copy(self.bandu_fxns[i])
            wfk.wfk_coeffs = wfk.wfk_coeffs.reshape((x,y,z))
            wfk = wfk.XSFFormat()
            wfk.WriteXSF(xsf_file=file_name)
1
+ import numpy as np
2
+ from typing import Generator
3
+ from copy import copy
4
+ from . import brillouin_zone as brlzn
5
+ from . import translate as trnslt
6
+ from . import wfk_class as wc
7
+
8
class BandU():
    def __init__(
        self, wfks:Generator, energy_level:float, width:float, grid:bool=True, ifft:bool=True, fft:bool=False,
        norm:bool=True, sym:bool=True, low_mem:bool=False, plot:bool=True, real_imag_sep:bool=False, opt:bool=True
    )->None:
        '''
        BandU object with methods for finding states and computing BandU functions from states.

        All work is done in the constructor: states inside the energy window are collected,
        principal components are computed, the eigenvalues are optionally plotted, and a
        summary is written to 'eigenvalues.out' in the current working directory.

        Parameters
        ----------
        wfks : Generator
            An iterable generator of WFK objects with wavefunction coefficients, k-points, and eigenvalue attributes.
        energy_level : float
            The energy level of interest relative to the Fermi energy.
        width : float
            Defines how far above and below the energy_level is searched for states.
            Search is done width/2 above and below, so total states captured are within 'width' energy.
        grid : bool
            Determines whether or not wavefunction coefficients are converted to 3D numpy grid.
            Default converts to grid (True).
        ifft : bool
            Determines whether or not wavefunction coefficients are Inverse Fourier transformed (WFK.IFFT) to real space.
            Default converts from reciprocal space to real space (True).
        fft : bool
            Determines whether or not the forward Fourier transform (WFK.FFT) is applied.
            When both flags are set, it runs after the inverse transform.
            Generally, the Inverse Fourier transform should be applied to wavefunctions, see ifft argument.
            Default does not apply FFT (False).
        norm : bool
            Determines whether or not wavefunction coefficients are normalized.
            Default normalizes coefficients (True).
        sym : bool
            Use every symmetrically equivalent wavefunction returned by WFK.SymWFKs for each state found.
            When False the k-point is instead shifted back into the Brillouin zone.
            Default generates symmetric equivalents (True).
        low_mem : bool
            Stored as self.low_mem.
            NOTE(review): no method of this class reads it.
            Default (False).
        plot : bool
            Plots eigenvalues from principal component analysis to 'bandu_eigenvalues.png'.
            Default will save plot (True).
        real_imag_sep : bool
            Treat the real and imaginary parts of each state as separate vectors in the principal
            component analysis. This doubles the number of Band-U functions and skips the Omega check.
            Default does not separate them (False).
        opt : bool
            Intended to make each Band-U function *mostly* real by applying a phase factor.
            NOTE(review): only stored as self.opt; no method of this class reads it yet.
            Default (True).

        Raises
        ------
        ValueError
            If no states are found within the requested energy window.

        Methods
        -------
        ToXSF
            Writes BandU functions to XSF files (via WFK.XSFFormat / WFK.WriteXSF).
        '''
        self.grid:bool=grid
        self.ifft:bool=ifft
        self.fft:bool=fft
        self.norm:bool=norm
        self.sym:bool=sym
        self.low_mem:bool=low_mem
        self.found_states:int=0
        self.bandu_fxns:list[wc.WFK]=[]
        self.plot=plot
        self.opt=opt
        # find all states within width (raises ValueError when nothing is found)
        self._FindStates(energy_level, width, wfks)
        print(f'{self.found_states} states found within specified energy range')
        # construct principal orbital components
        if real_imag_sep:
            principal_vals = self._RealImagPrnplComps()
        else:
            principal_vals = self._PrincipalComponents()
        # plot eigenvalues from PCA
        if plot:
            self._PlotEigs(principal_vals)
        # normalize bandu functions
        for i in range(self.found_states):
            self.bandu_fxns[i] = self.bandu_fxns[i].Normalize()
        # compute ratios
        if not real_imag_sep:
            omega_vals, omega_check = self._CheckOmega()
        else:
            omega_vals = 0
            omega_check = 0
        # write output file
        fermi = self.bandu_fxns[0].fermi_energy
        with open('eigenvalues.out', 'w') as f:
            print(f'Width: {width}, found states: {self.found_states}', file=f)
            print(f'Energy level: {energy_level+fermi}, Fermi energy: {fermi}', file=f)
            print(np.abs(principal_vals), file=f)
            print('Omega Values', file=f)
            print(omega_vals, file=f)
            print('Omega Check (value of 0 indicates Omega is at local extremum)', file=f)
            print(omega_check, file=f)
    #---------------------------------------------------------------------------------------------------------------------#
    #------------------------------------------------------ METHODS ------------------------------------------------------#
    #---------------------------------------------------------------------------------------------------------------------#
    # collect every band whose eigenvalue lies inside the energy window
    def _FindStates(
        self, energy_level:float, width:float, wfks:Generator[wc.WFK,None,None]
    ):
        '''
        Collect every band of every WFK in wfks whose eigenvalue lies in
        [fermi_energy + energy_level - width/2, fermi_energy + energy_level + width/2].

        Each matching band is expanded by _Process and the results are appended to
        self.bandu_fxns; self.found_states is kept equal to the number appended.

        Raises
        ------
        ValueError
            If no band falls inside the window.
        '''
        for state in wfks:
            min_en = state.fermi_energy + energy_level - width/2
            max_en = state.fermi_energy + energy_level + width/2
            for i, band in enumerate(state.eigenvalues):
                if min_en <= band <= max_en:
                    self.found_states += 1
                    # shallow copy: only wfk_coeffs is rebound here, other attributes stay shared with `state`
                    coeffs = copy(state)
                    coeffs.wfk_coeffs = coeffs.wfk_coeffs[i]
                    self.bandu_fxns.extend(self._Process(coeffs))
        # `x is []` compares against a brand-new list and is always False; test emptiness instead
        if not self.bandu_fxns:
            raise ValueError(
                '''Identified 0 states within provided width.
                Action: Increase width or increase fineness of kpoint grid.
                ''')
    #-----------------------------------------------------------------------------------------------------------------#
    # method for processing planewave coefficient data from FindStates
    def _Process(
        self, state:wc.WFK
    )->Generator[wc.WFK,None,None]:
        '''
        Expand a single-band WFK into the WFK objects used for the PCA and yield them after
        applying the requested transformations in the order grid, ifft, fft, norm.

        With self.sym every WFK from state.SymWFKs is used and self.found_states grows by
        (number of copies - 1), because _FindStates already counted the state itself.
        '''
        funcs:list[wc.WFK] = []
        if self.sym:
            # generate symmetrically equivalent coefficients
            for sym_coeffs in state.SymWFKs(kpoint=state.kpoints):
                self.found_states += 1
                funcs.append(sym_coeffs)
            self.found_states -= 1
        else:
            # shift point back into Brillouin Zone as necessary
            rec_latt = state.rec_latt
            shift = brlzn.BZ(rec_latt=rec_latt).GetShifts(state.kpoints)
            # `state` is a shallow copy, so pw_indices may be the parent's array (assuming a NumPy
            # array); an in-place `+=` would shift it again for every other band at this k-point
            state.pw_indices = state.pw_indices + shift
            state.kpoints = state.kpoints.reshape((-1,3)) - shift
            funcs.append(state)
        # apply desired transformations
        for wfk in funcs:
            if self.grid:
                wfk = wfk.GridWFK()
            if self.ifft:
                wfk = wfk.IFFT()
            if self.fft:
                wfk = wfk.FFT()
            if self.norm:
                wfk = wfk.Normalize()
            yield wfk
    #-----------------------------------------------------------------------------------------------------------------#
    # check if states along BZ edge have been collected
    # --not functional--
    def _CheckEdgeCase(
        self
    ):
        '''
        Intended to duplicate states whose k-point lies on the Brillouin-zone edge, shifted by
        the periodic images from trnslt.TranslatePoints.

        NOTE(review): marked "not functional" by the author and not called by this class.
        Known issues left as-is: debug prints of every k-point, and each flagged state is
        shifted by the shifts of *all* flagged states (all_shifts is not paired with wfk_to_shift).
        '''
        bz = brlzn.BZ(self.bandu_fxns[0].rec_latt)
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        all_shifts:list[np.ndarray] = []
        wfk_to_shift:list[int] = []
        dupes:list[wc.WFK] = []
        for i, wfk in enumerate(self.bandu_fxns):
            # if point on edge, translate to find other periodic points
            edge_pts = bz._BZEdgePt(wfk.kpoints)
            print(wfk.kpoints)
            print(edge_pts)
            if len(edge_pts[edge_pts > 0]) > 0:
                shifts, _ = trnslt.TranslatePoints(np.zeros((1,3)), np.zeros(1), np.identity(3))
                shifts = shifts[edge_pts > 0].astype(int)
                all_shifts.append(shifts)
                wfk_to_shift.append(i)
        for i in wfk_to_shift:
            for shifts in all_shifts:
                for shift in shifts:
                    shifted_wfk = copy(self.bandu_fxns[i])
                    shifted_wfk = shifted_wfk.RemoveGrid()
                    shifted_wfk.pw_indices += shift
                    shifted_wfk = shifted_wfk.GridWFK()
                    shifted_wfk.wfk_coeffs = shifted_wfk.wfk_coeffs.reshape((1,x*y*z))
                    shifted_dupe = copy(self.bandu_fxns)
                    shifted_dupe[i] = shifted_wfk
                    dupes.extend(shifted_dupe)
                    # duped_states is never initialised elsewhere; start from 0 instead of raising AttributeError
                    self.duped_states = getattr(self, 'duped_states', 0) + self.found_states
        self.bandu_fxns.extend(dupes)
    #-----------------------------------------------------------------------------------------------------------------#
    # find principal components
    def _PrincipalComponents(
        self
    )->np.ndarray:
        '''
        Principal component analysis of the real parts of the collected states.

        Re(coefficients) of each state is flattened into row i of a matrix M. The overlap
        matrix S = conj(M) M^T is diagonalised, and self.bandu_fxns[k].wfk_coeffs is replaced
        by the k-th principal component (sum_i v_k[i] * M[i], eigenvalues descending).

        Returns
        -------
        np.ndarray
            Eigenvalues of S in descending order.
        '''
        total_states = self.found_states
        # organize wfk coefficients
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        mat = np.zeros((total_states,x*y*z), dtype=complex)
        for i in range(total_states):
            mat[i,:] = self.bandu_fxns[i].wfk_coeffs.reshape(x*y*z).real
        # compute overlap matrix
        print('Computing overlap matrix')
        overlap_mat = np.matmul(np.conj(mat), mat.T)
        # S is Hermitian by construction: eigh gives real eigenvalues and an orthonormal
        # eigenbasis even for degenerate eigenvalues, which np.linalg.eig does not guarantee
        principal_vals, principal_vecs = np.linalg.eigh(overlap_mat)
        # eigh returns ascending values with eigenvectors as columns: reverse, one vector per row
        principal_vals = principal_vals[::-1]
        principal_vecs = principal_vecs.T[::-1]
        mat = np.matmul(principal_vecs, mat)
        for i in range(total_states):
            self.bandu_fxns[i].wfk_coeffs = mat[i,:]
        return principal_vals
    #-----------------------------------------------------------------------------------------------------------------#
    # find principal components with real and imaginary separated
    def _RealImagPrnplComps(
        self
    )->np.ndarray:
        '''
        Principal component analysis with the real and imaginary parts of every state as
        separate vectors (rows 2i and 2i+1 of the data matrix).

        The first found_states components overwrite the existing self.bandu_fxns entries. The
        rest are appended as shallow copies of self.bandu_fxns[0] carrying the new coefficients.
        self.found_states is doubled.

        Returns
        -------
        np.ndarray
            Eigenvalues of the overlap matrix in descending order.
        '''
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        mat = np.zeros((2*self.found_states,x*y*z), dtype=complex)
        for i in range(self.found_states):
            ind = 2*i
            mat[ind,:] = self.bandu_fxns[i].wfk_coeffs.real.reshape(x*y*z)
            mat[ind+1,:] = self.bandu_fxns[i].wfk_coeffs.imag.reshape(x*y*z)
        print('Computing overlap matrix')
        overlap_mat = np.matmul(np.conj(mat),mat.T)
        # Hermitian overlap matrix: see _PrincipalComponents for why eigh is used
        principal_vals, principal_vecs = np.linalg.eigh(overlap_mat)
        principal_vals = principal_vals[::-1]
        principal_vecs = principal_vecs.T[::-1]
        mat = np.matmul(principal_vecs, mat)
        for i in range(2*self.found_states):
            if i < self.found_states:
                self.bandu_fxns[i].wfk_coeffs = mat[i,:]
            else:
                new_coeffs = copy(self.bandu_fxns[0])
                new_coeffs.wfk_coeffs = mat[i,:]
                self.bandu_fxns.append(new_coeffs)
        self.found_states *= 2
        return principal_vals
    #-----------------------------------------------------------------------------------------------------------------#
    # find ratio of real and imaginary components
    def _CheckOmega(
        self
    )->tuple[np.ndarray, np.ndarray]:
        '''
        For each Band-U function, evaluate |sum(Re(c)*c) / sum(Im(c)*c)| after multiplying the
        coefficients by exp(i*pi*t) for t in (-0.01, 0, 0.01).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Omega at t=0 for every function, and sign(d1)+sign(d2) of the two finite
            differences (0 means the t=0 value is a local extremum).
        '''
        total_states = self.found_states
        omega_vals = np.zeros((total_states, 3), dtype=float)
        vals = np.linspace(start=-0.01, stop=0.01, num=3)
        # a purely real function (e.g. from _PrincipalComponents) has a zero denominator at t=0;
        # report inf/nan without RuntimeWarnings
        with np.errstate(divide='ignore', invalid='ignore'):
            for i, val in enumerate(vals):
                phase = np.exp(1j*val*np.pi)
                for j in range(total_states):
                    # out-of-place multiply: also valid when wfk_coeffs has a real dtype
                    coeffs = self.bandu_fxns[j].wfk_coeffs * phase
                    omega = np.sum(coeffs.real*coeffs)/np.sum(coeffs.imag*coeffs)
                    omega_vals[j,i] = np.abs(omega)
        omega_diff1 = (omega_vals[:,1] - omega_vals[:,0])
        omega_diff2 = (omega_vals[:,2] - omega_vals[:,1])
        omega_check = np.sign(omega_diff1) + np.sign(omega_diff2)
        return omega_vals[:,1], omega_check
    #-----------------------------------------------------------------------------------------------------------------#
    # plot eigenvalues from PCA
    def _PlotEigs(
        self, eigvals:np.ndarray
    ):
        '''
        Plot |eigenvalue| against component number and save the figure to
        'bandu_eigenvalues.png' (dpi=500) in the current working directory.
        '''
        import matplotlib.pyplot as plt
        y = np.abs(eigvals.flatten())
        x = np.arange(len(y)) + 1
        # scope the font to this figure rather than mutating global rcParams after plotting
        with plt.rc_context({'font.family': 'Times New Roman'}):
            fig = plt.figure(figsize=(12,6))
            ax = fig.add_subplot()
            ax.plot(
                x,
                y,
                color='black',
                linestyle='-',
                marker='o',
                mfc='red',
                markersize=8
            )
            ax.spines[['right','top']].set_visible(False)
            ax.tick_params(axis='both', labelsize=12)
            ax.set_xlim(1.0,len(y)+5.0)
            # `plt.ylim=(...)` rebound pyplot's ylim function instead of setting the limits
            ax.set_ylim(0,np.max(y))
            # at least 1, otherwise `val % mod_val` divides by zero for <= 5 states
            mod_val = max(1, round(len(y)/5 - 0.5))
            ax.set_xticks([val for val in range(0,len(y)+1) if val % mod_val == 0])
            fig.savefig('bandu_eigenvalues.png',dpi=500)
        # release the figure so repeated runs do not accumulate open figures
        plt.close(fig)
    #-----------------------------------------------------------------------------------------------------------------#
    # make xsf of BandU functions
    def ToXSF(
        self, nums:list[int]=[], xsf_name:str='Principal_orbital_component'
    ):
        '''
        Write Band-U functions to XSF files named f'{xsf_name}_{n}'.

        Parameters
        ----------
        nums : list[int]
            Two 1-based function numbers [first, last], inclusive. When empty (default) or None,
            all self.found_states functions are written. Out-of-range values are clamped.
            The list passed in is never modified.
        xsf_name : str
            Prefix of the output file names.

        Raises
        ------
        ValueError
            If nums is non-empty and does not hold exactly two values.
        '''
        total_states = self.found_states
        # `nums is []` was always False, so the default used to hit the ValueError below
        if nums is None or len(nums) == 0:
            nums = [1,total_states]
        else:
            # work on a copy so the caller's list is not clamped in place
            nums = list(nums)
            # check if list has only 2 elements
            if len(nums) != 2:
                raise ValueError(f'nums should contain two values, {len(nums)} were received.')
            # check if function number is within defined range
            if nums[0] < 1:
                print('First element of nums cannot be lower than 1, changing to 1 now.')
                nums[0] = 1
            # update function number list if it exceeds maximum number of bandu functions
            if nums[1] > total_states:
                print(f'Printing up to max Band-U function number: {total_states}')
                nums[1] = total_states
            # check if lower limit is within defined range
            if nums[0] > nums[1]:
                nums[0] = nums[1]
        print(f'Writing XSF files for Band-U functions {nums[0]} through {nums[1]}.')
        # write xsf files
        x = self.bandu_fxns[0].ngfftx
        y = self.bandu_fxns[0].ngffty
        z = self.bandu_fxns[0].ngfftz
        for i in range(nums[0]-1, nums[1]):
            file_name = xsf_name + f'_{i+1}'
            wfk = copy(self.bandu_fxns[i])
            wfk.wfk_coeffs = wfk.wfk_coeffs.reshape((x,y,z))
            wfk = wfk.XSFFormat()
            wfk.WriteXSF(xsf_file=file_name)