bandu 1.3.6__py3-none-any.whl → 1.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bandu/wfk_class.py CHANGED
@@ -1,557 +1,558 @@
1
- import numpy as np
2
- from scipy.fft import fftn, ifftn
3
- import sys
4
- from typing import Self, Generator
5
- from copy import copy
6
- from . import brillouin_zone as brlzn
7
- np.set_printoptions(threshold=sys.maxsize)
8
-
9
- class WFK():
10
- '''
11
- A class for working with wavefunctions from DFT calculations
12
-
13
- Parameters
14
- ----------
15
- wfk_coeffs : np.ndarray
16
- The planewave coefficients of the wavefunction
17
- These should be complex values
18
- kpoints : np.ndarray
19
- A multidimensional array of 3D kpoints
20
- Entries along axis 0 should be individual kpoints
21
- Entries along axis 1 should be the kx, ky, and kz components, in that order
22
- The kpoints should be in reduced form
23
- pw_indices : np.ndarray
24
- Array of H, K, L indices for the planwave basis set.
25
- Arrays of (1,3) [H,K,L] should fill axis 0, and H, K, L values fill axis 1 in that order.
26
- Necessary for arranging wavefunction coefficients in 3D array.
27
- syrmel : np.ndarray
28
- A multidimensional array of 3x3 arrays of symmetry operations
29
- non_symm_vec : np.ndarray
30
- A multidimensional array of 1x3 arrays of the nonsymmorphic translation vectors for each-
31
- symmetry operation.
32
- nsym : int
33
- Total number of symmetry operations
34
- nkpt : int
35
- Total number of kpoints
36
- If a kpoints array is provided, then nkpt will be acquired the its length
37
- nbands : int
38
- Total number of bands
39
- ngfftx : int
40
- x dimension of Fourier transform grid
41
- ngffty : int
42
- y dimension of Fourier transform grid
43
- ngfftz : int
44
- z dimension of Fourier transform grid
45
- eigenvalues : list
46
- List of the eigenvalues for wavefunction at each band
47
- Should be ordered from least -> greatest
48
- fermi_energy : float
49
- Fermi energy
50
- lattice : np.ndarray
51
- 3x3 array containing lattice parameters
52
- natom : int
53
- Total number of atoms in unit cell
54
- xred : np.ndarray
55
- Reduced coordinates of all atoms in unit cell
56
- Individual atomic coordinates fill along axis 0
57
- X, Y, and Z components fill along axis 1, in that order
58
- typat : list
59
- Numeric labels starting from 1 and incrementing up to natom
60
- Order of labels should follow xred
61
- znucltypat : list
62
- List of element names
63
- First element of list should correspond to typat label 1, second element to label 2 and so on
64
- time_reversal : bool
65
- Select whether system has time reversal symmetry or not
66
- If the system is time reversal symmetric, then reciprocal space electronic states will share inversion
67
- symmetry even if the real space symmetries do not include inversion
68
- Default assumes noncentrosymmetric systems have time reversal symmetry (True)
69
-
70
- Methods
71
- -------
72
- GridWFK
73
- Assembles plane wave coefficients on FFT grid
74
- RemoveGrid
75
- Undoes FFT grid and returns coefficients to a flat array
76
- FFT
77
- Applies Fast Fourier Transform to plane wave coefficients
78
- IFFT
79
- Applies Inverse Fast Fourier Transform to plane wave coefficients
80
- Normalize
81
- Calculates and applies normalization factor to coefficients
82
- Real2Reciprocal
83
- Calculates reciprocal lattice vectors from real space vectors
84
- Symmetrize
85
- Generates symmetrical copies from symmetry matrix operations
86
- SymWFK
87
- Generates symmetrical plane wave coefficients from operations
88
- XSFFormat
89
- Converts plane wave coefficients grid into XSF formatted grid
90
- RemoveXSF
91
- Converts XSF formatted grid into regular FFT grid
92
- WriteXSF
93
- Prints out XSF files for both the real and imaginary parts of the coefficients
94
- '''
95
- def __init__(
96
- self,
97
- wfk_coeffs:np.ndarray=np.zeros(1), kpoints:np.ndarray=np.zeros(1), symrel:np.ndarray=np.zeros(1),
98
- nsym:int=0, nkpt:int=0, nbands:int=0, ngfftx:int=0, ngffty:int=0, ngfftz:int=0,
99
- eigenvalues:np.ndarray=np.zeros(1),fermi_energy:float=0.0, lattice:np.ndarray=np.zeros(1), natom:int=0,
100
- xred:np.ndarray=np.zeros(1), typat:list=[], znucltypat:list=[], pw_indices:np.ndarray=np.zeros(1),
101
- non_symm_vecs:np.ndarray=np.zeros(1), time_reversal:bool=True
102
- )->None:
103
- self.wfk_coeffs=wfk_coeffs
104
- self.kpoints=kpoints
105
- self.pw_indices=pw_indices
106
- self.symrel=symrel
107
- self.nsym=nsym
108
- self.non_symm_vecs=non_symm_vecs
109
- self.nkpt=nkpt
110
- self.nbands=nbands
111
- self.ngfftx=ngfftx
112
- self.ngffty=ngffty
113
- self.ngfftz=ngfftz
114
- self.eigenvalues=eigenvalues
115
- self.fermi_energy=fermi_energy
116
- self.lattice=lattice
117
- self.natom=natom
118
- self.xred=xred
119
- self.typat=typat
120
- self.znucltypat=znucltypat
121
- self.time_reversal=time_reversal
122
- #---------------------------------------------------------------------------------------------------------------------#
123
- #------------------------------------------------------ METHODS ------------------------------------------------------#
124
- #---------------------------------------------------------------------------------------------------------------------#
125
- # method for putting plane wave coefficients onto 3D gridded array
126
- def GridWFK(
127
- self, band_index:int=-1
128
- )->Self:
129
- '''
130
- Returns copy of WFK object with coefficients in numpy 3D array grid.
131
- Grid is organized in (ngfftz, ngfftx, ngffty) dimensions.
132
- Where ngfft_ represents the _ Fourier transform grid dimension.
133
-
134
- Parameters
135
- ----------
136
- band_index : int
137
- Integer represent the band index of the wavefunction coefficients to be transformed.
138
- If nothing is passed, it is assumed the coefficients of a single band are supplied.
139
- '''
140
- # initialize 3D grid
141
- gridded_wfk = np.zeros((self.ngfftx, self.ngffty, self.ngfftz), dtype=complex)
142
- # update grid with wfk coefficients
143
- for k, kpt in enumerate(self.pw_indices):
144
- kx = kpt[0]
145
- ky = kpt[1]
146
- kz = kpt[2]
147
- if band_index >= 0:
148
- gridded_wfk[kx, ky, kz] = self.wfk_coeffs[band_index][k]
149
- else:
150
- gridded_wfk[kx, ky, kz] = self.wfk_coeffs[k]
151
- new_WFK = copy(self)
152
- new_WFK.wfk_coeffs = gridded_wfk
153
- return new_WFK
154
- #-----------------------------------------------------------------------------------------------------------------#
155
- # method for undoing grid
156
- def RemoveGrid(
157
- self, band_index:int=-1
158
- )->Self:
159
- '''
160
- Returns copy of WFK object with coefficients removed from the 3D gridded array.
161
-
162
- Parameters
163
- ----------
164
- band_index : int
165
- Integer represent the band index of the wavefunction coefficients to be transformed.
166
- If nothing is passed, it is assumed the coefficients of a single band are supplied.
167
- '''
168
- # check if coefficients are gridded before undoing grid format
169
- if self.wfk_coeffs.shape != (self.ngfftx,self.ngffty,self.ngfftz):
170
- raise ValueError((
171
- f'Plane wave coefficients must be in 3D grid with shape ({self.ngfftx}, {self.ngffty}, {self.ngfftz})'
172
- ' in order to remove the gridded format'
173
- ))
174
- if band_index >= 0:
175
- coeffs_no_grid = self.wfk_coeffs[band_index]
176
- else:
177
- coeffs_no_grid = self.wfk_coeffs
178
- # returns values at each plane wave index, undoing grid
179
- coeffs_no_grid = coeffs_no_grid[tuple(self.pw_indices.T)]
180
- new_WFK = copy(self)
181
- new_WFK.wfk_coeffs = coeffs_no_grid
182
- return new_WFK
183
- #-----------------------------------------------------------------------------------------------------------------#
184
- # method transforming reciprocal space wfks to real space
185
- def FFT(
186
- self
187
- )->Self:
188
- '''
189
- Returns copy of WFK with wavefunction coefficients expressed in real space.
190
- Assumes existing wavefunction coefficients are expressed in reciprocal space.
191
- '''
192
- # Fourier transform real grid to reciprocal grid
193
- reciprocal_coeffs = fftn(self.wfk_coeffs, norm='ortho')
194
- new_WFK = copy(self)
195
- new_WFK.wfk_coeffs = np.array(reciprocal_coeffs).reshape((self.ngfftx, self.ngffty, self.ngfftz))
196
- return new_WFK
197
- #-----------------------------------------------------------------------------------------------------------------#
198
- # method transforming real space wfks to reciprocal space
199
- def IFFT(
200
- self
201
- )->Self:
202
- '''
203
- Returns copy of WFK with wavefunction coefficients in expressed in reciprocal space.
204
- Assumes existing wavefunction coefficients are expressed in real space.
205
- '''
206
- # Fourier transform reciprocal grid to real grid
207
- real_coeffs = ifftn(self.wfk_coeffs, norm='ortho')
208
- new_WFK = copy(self)
209
- new_WFK.wfk_coeffs = np.array(real_coeffs).reshape((self.ngfftx,self.ngffty,self.ngfftz))
210
- return new_WFK
211
- #-----------------------------------------------------------------------------------------------------------------#
212
- # method for normalizing wfks
213
- def Normalize(
214
- self
215
- )->Self:
216
- '''
217
- Returns copy of WFK object with normalized wavefunction coefficients such that <psi|psi> = 1.
218
- '''
219
- coeffs = np.array(self.wfk_coeffs)
220
- # calculate normalization constant and apply to wfk
221
- norm = np.dot(coeffs.flatten(), np.conj(coeffs).flatten())
222
- norm = np.sqrt(norm)
223
- new_WFK = copy(self)
224
- new_WFK.wfk_coeffs /= norm
225
- return new_WFK
226
- #-----------------------------------------------------------------------------------------------------------------#
227
- # method for converting real space lattice vectors to reciprocal space vectors
228
- def Real2Reciprocal(
229
- self
230
- )->np.ndarray:
231
- '''
232
- Method for converting the real space lattice parameters to reciprocal lattice parameters.
233
- '''
234
- # conversion by default converts Angstrom to Bohr since ABINIT uses Bohr
235
- a = self.lattice[0,:]
236
- b = self.lattice[1,:]
237
- c = self.lattice[2,:]
238
- vol = np.dot(a,np.cross(b,c))
239
- b1 = 2*np.pi*(np.cross(b,c))/vol
240
- b2 = 2*np.pi*(np.cross(c,a))/vol
241
- b3 = 2*np.pi*(np.cross(a,b))/vol
242
- return np.array([b1,b2,b3]).reshape((3,3))
243
- #-----------------------------------------------------------------------------------------------------------------#
244
- # method for checking for time reversal symmetry
245
- def _CheckTimeRevSym(
246
- self
247
- ):
248
- if self.time_reversal:
249
- # if system is centrosymmetric, do not double reciprocal symmetry operations
250
- if -3.0 in [np.trace(mat) for mat in self.symrel]:
251
- self.time_reversal = False
252
- else:
253
- print((
254
- 'Noncentrosymmetric system identified, assuming time reversal symmetry\n'
255
- 'To change this, set time_reversal attribute to False'
256
- ))
257
- #-----------------------------------------------------------------------------------------------------------------#
258
- # method for finding symmetrically distinct k points
259
- def _FindOrbit(
260
- self, sym_kpts:np.ndarray
261
- )->tuple[list,list]:
262
- sym_kpts = np.round(sym_kpts, decimals=15)
263
- _, unique_inds = np.unique(sym_kpts, return_index=True, axis=0)
264
- # for each unique kpoint check original point is related by reciprocal lattice vector
265
- dupes = []
266
- for i, ind1 in enumerate(unique_inds):
267
- if i in dupes:
268
- continue
269
- for j, ind2 in enumerate(unique_inds):
270
- if i == j or j in dupes:
271
- continue
272
- diff = np.abs(sym_kpts[ind1] - sym_kpts[ind2])
273
- diff[diff < 10**(-12)] = 0.0
274
- diff[diff > 0.999] = 1.0
275
- mask = np.isin(diff, np.array([0.0,1.0]))
276
- if mask.all():
277
- dupes.append(j)
278
- return dupes, unique_inds.tolist()
279
- #-----------------------------------------------------------------------------------------------------------------#
280
- # function for calculating phase imparted by nonsymmorphic translation
281
- def _FindPhase(
282
- self, nonsymmvec:np.ndarray, g_vecs:np.ndarray, kpt:np.ndarray
283
- )->np.ndarray:
284
- if self.non_symm_vecs is np.zeros(1):
285
- return np.ones(len(g_vecs))
286
- elif np.sum(np.abs(nonsymmvec)) < 10**(-8):
287
- return np.ones(len(g_vecs))
288
- else:
289
- return np.exp(-1j*np.dot((kpt+g_vecs), nonsymmvec.T))
290
- #-----------------------------------------------------------------------------------------------------------------#
291
- # method for creating symmetrically equivalent points
292
- def Symmetrize(
293
- self, points:np.ndarray, values:np.ndarray=np.empty([]), unique:bool=True, reciprocal:bool=False,
294
- inverse:bool=False
295
- )->tuple[np.ndarray, np.ndarray]:
296
- '''
297
- Method for generating symmetric data from irreducible data.
298
-
299
- Parameters
300
- ----------
301
- points : np.ndarray
302
- Irreducible set of points.
303
- Shape of (N,3).
304
- values : np.ndarray
305
- Values corresponding to irreducible points (such as energy eigenvalues w/ kpoints).
306
- Shape of (N,1).
307
- unique : bool
308
- Check for duplicate points.
309
- Default is to check (True).
310
- reciprocal : bool
311
- Calculate reciprocal space symmetry matrices from real space matrices.
312
- Default uses real space matrices (False).
313
- inverse : bool
314
- Use inverse symmetry operations.
315
- Default applies forwards operation (False).
316
- '''
317
- # check if reciprocal or real space symmetries will be used
318
- sym_num = self.nsym
319
- if reciprocal:
320
- # nosymmorphic translations do not apply to reciprocal space
321
- tnons = False
322
- sym_mats = [np.linalg.inv(mat).T for mat in self.symrel]
323
- # time reversal only adds to reciprocal space symmetries
324
- if self.time_reversal:
325
- sym_mats = np.concatenate((sym_mats, [-mat for mat in sym_mats]), axis=0)
326
- sym_num *= 2
327
- else:
328
- tnons = True
329
- sym_mats = self.symrel
330
- # initialize symmetrically equivalent point and value arrays
331
- if len(points.shape) == 1:
332
- points.reshape((1,points.shape[0]))
333
- ind_len = np.shape(points)[0]
334
- if values is np.empty([]):
335
- values = np.zeros((ind_len,1))
336
- sym_pts = np.zeros((sym_num*ind_len,3))
337
- sym_vals = np.zeros((sym_num*ind_len,self.nbands))
338
- if self.non_symm_vecs.all() == np.zeros(1):
339
- self.non_symm_vecs = np.zeros(self.nsym)
340
- # apply symmetry operations to points
341
- if inverse:
342
- for i, op in enumerate(sym_mats):
343
- if tnons:
344
- points += self.non_symm_vecs[i]
345
- new_pts:np.ndarray = np.matmul(np.linalg.inv(op), points.T).T
346
- sym_pts[i*ind_len:(i+1)*ind_len,:] = new_pts
347
- sym_vals[i*ind_len:(i+1)*ind_len,:] = values
348
- else:
349
- for i, op in enumerate(sym_mats):
350
- if tnons:
351
- points += self.non_symm_vecs[i]
352
- new_pts:np.ndarray = np.matmul(op, points.T).T
353
- sym_pts[i*ind_len:(i+1)*ind_len,:] = new_pts
354
- sym_vals[i*ind_len:(i+1)*ind_len,:] = values
355
- # points overlap on at edges of each symmetric block, remove duplicates
356
- if unique:
357
- dupes, unique_inds = self._FindOrbit(sym_pts)
358
- unique_kpts = np.array([sym_pts[ind,:] for i, ind in enumerate(unique_inds) if i not in dupes])
359
- unique_vals = np.array([sym_vals[ind,:] for i, ind in enumerate(unique_inds) if i not in dupes])
360
- return unique_kpts, unique_vals
361
- return sym_pts, sym_vals
362
- #-----------------------------------------------------------------------------------------------------------------#
363
- # method for creating symmetrically equivalent functions at specified kpoint
364
- def SymWFKs(
365
- self, kpoint:np.ndarray, band:int=-1
366
- )->Generator[Self, None, None]:
367
- '''
368
- Method for generating wavefunction planewave coefficients from coefficients of the irreducible BZ.
369
-
370
- Parameters
371
- ----------
372
- kpoint : np.ndarray
373
- A single reciprocal space point is provided to generate symmetrically equivalent coefficients.
374
- Shape (1,3).
375
- band : int
376
- Choose which band to pull coefficients from (indexed starting from zero).
377
- Default assumes coefficients from a single band are provided (-1).
378
- '''
379
- # find symmetric kpoints
380
- kpoint = kpoint.reshape((1,3))
381
- sym_kpoints, _ = self.Symmetrize(kpoint, unique=False, reciprocal=True)
382
- dupes, unique_inds = self._FindOrbit(sym_kpoints)
383
- # find symmetric planewave indices
384
- sym_pw_inds, _ = self.Symmetrize(self.pw_indices, unique=False, reciprocal=True)
385
- sym_pw_inds = sym_pw_inds.astype(int)
386
- ind_range = self.pw_indices.shape[0]
387
- # find reciprocal lattice shifts to move all points into BZ
388
- bz = brlzn.BZ(rec_latt=self.Real2Reciprocal())
389
- shifts = bz.GetShifts(sym_kpoints)
390
- # create WFK copies with new planewave indices
391
- for i, ind in enumerate(unique_inds):
392
- if i in dupes:
393
- continue
394
- ind1 = ind*ind_range
395
- ind2 = (ind+1)*ind_range
396
- new_pw_inds = sym_pw_inds[ind1:ind2,:]
397
- new_pw_inds += shifts[ind,:]
398
- new_coeffs = copy(self)
399
- new_coeffs.pw_indices = new_pw_inds
400
- new_coeffs.kpoints = sym_kpoints[ind,:] - shifts[ind,:]
401
- phase_factor = self._FindPhase(
402
- self.non_symm_vecs[ind % len(self.non_symm_vecs)],
403
- self.pw_indices,
404
- sym_kpoints[ind,:]
405
- )
406
- if band >= 0:
407
- new_coeffs.wfk_coeffs = new_coeffs.wfk_coeffs[band] * phase_factor
408
- yield new_coeffs
409
- else:
410
- new_coeffs.wfk_coeffs *= phase_factor
411
- yield new_coeffs
412
- #-----------------------------------------------------------------------------------------------------------------#
413
- # method that returns BZ kpoints and eigenvalues
414
- def GetBZPtsEigs(
415
- self
416
- )->tuple[np.ndarray,np.ndarray]:
417
- bz = brlzn.BZ(rec_latt=self.Real2Reciprocal())
418
- bz_kpts, bz_eigs = self.Symmetrize(points=self.kpoints, values=self.eigenvalues, reciprocal=True)
419
- bz_kpts -= bz.GetShifts(bz_kpts)
420
- return bz_kpts, bz_eigs
421
- #-----------------------------------------------------------------------------------------------------------------#
422
- # method for expanding a grid into XSF format
423
- def XSFFormat(
424
- self
425
- )->Self:
426
- '''
427
- Returns copy of WFK object XSF formatted coefficients.
428
- Requires wfk_coeffs to be in gridded format, i.e. (ngfftz, ngfftx, ngffty) shape.
429
- '''
430
- # append zeros to ends of all axes in grid_wfk
431
- # zeros get replaced by values at beginning of each axis
432
- # this repetition is required by XSF format
433
- if np.shape(self.wfk_coeffs) != (self.ngfftx, self.ngffty, self.ngfftz):
434
- raise ValueError(
435
- f'''Passed array is not the correct shape:
436
- Expected: ({self.ngfftx}, {self.ngffty}, {self.ngfftz}),
437
- Received: {np.shape(self.wfk_coeffs)}
438
- ''')
439
- else:
440
- grid_wfk = self.wfk_coeffs
441
- grid_wfk = np.append(grid_wfk, np.zeros((1, self.ngffty, self.ngfftz)), axis=0)
442
- grid_wfk = np.append(grid_wfk, np.zeros((self.ngfftx+1, 1, self.ngfftz)), axis=1)
443
- grid_wfk = np.append(grid_wfk, np.zeros((self.ngfftx+1, self.ngffty+1, 1)), axis=2)
444
- for x in range(self.ngfftx+1):
445
- for y in range(self.ngffty+1):
446
- for z in range(self.ngfftz+1):
447
- if x == self.ngfftx:
448
- grid_wfk[x,y,z] = grid_wfk[0,y,z]
449
- if y == self.ngffty:
450
- grid_wfk[x,y,z] = grid_wfk[x,0,z]
451
- if z == self.ngfftz:
452
- grid_wfk[x,y,z] = grid_wfk[x,y,0]
453
- if x == self.ngfftx and y == self.ngffty:
454
- grid_wfk[x,y,z] = grid_wfk[0,0,z]
455
- if x == self.ngfftx and z == self.ngfftz:
456
- grid_wfk[x,y,z] = grid_wfk[0,y,0]
457
- if z == self.ngfftz and y == self.ngffty:
458
- grid_wfk[x,y,z] = grid_wfk[x,0,0]
459
- if x == self.ngfftx and y == self.ngffty and z == self.ngfftz:
460
- grid_wfk[x,y,z] = grid_wfk[0,0,0]
461
- new_WFK = copy(self)
462
- new_WFK.wfk_coeffs = grid_wfk
463
- new_WFK.ngfftx += 1
464
- new_WFK.ngffty += 1
465
- new_WFK.ngfftz += 1
466
- return new_WFK
467
- #-----------------------------------------------------------------------------------------------------------------#
468
- # method removing XSF formatting from density grid
469
- def RemoveXSF(
470
- self
471
- )->Self:
472
- '''
473
- Returns copy of WFK object without XSF formatting.
474
- '''
475
- grid = self.wfk_coeffs
476
- # to_be_del will be used to remove all extra data points added for XSF formatting
477
- to_be_del = np.ones((self.ngfftx, self.ngffty, self.ngfftz), dtype=bool)
478
- for z in range(self.ngfftz):
479
- for y in range(self.ngffty):
480
- for x in range(self.ngfftx):
481
- # any time you reach the last density point it is a repeat of the first point
482
- # remove the end points along each axis
483
- if y == self.ngffty - 1 or x == self.ngfftx - 1 or z == self.ngfftz - 1:
484
- to_be_del[x,y,z] = False
485
- # remove xsf entries from array
486
- grid = grid[to_be_del]
487
- # restore grid shape
488
- grid = grid.reshape((self.ngfftx-1, self.ngffty-1, self.ngfftz-1))
489
- new_WFK = copy(self)
490
- new_WFK.wfk_coeffs = grid
491
- new_WFK.ngfftx -= 1
492
- new_WFK.ngffty -= 1
493
- new_WFK.ngfftz -= 1
494
- return new_WFK
495
- #-----------------------------------------------------------------------------------------------------------------#
496
- # method for writing wavefunctions to XSF file
497
- def WriteXSF(
498
- self, xsf_file:str, _component:bool=True
499
- )->None:
500
- '''
501
- A method for writing numpy grids to an XSF formatted file.
502
-
503
- Parameters
504
- ----------
505
- xsf_file : str
506
- The file name.
507
- '''
508
- # first run writes out real part of eigenfunction to xsf
509
- if _component:
510
- xsf_file += '_real.xsf'
511
- # second run writes out imaginary part
512
- else:
513
- xsf_file += '_imag.xsf'
514
- with open(xsf_file, 'w') as xsf:
515
- print('DIM-GROUP', file=xsf)
516
- print('3 1', file=xsf)
517
- print('PRIMVEC', file=xsf)
518
- print(f'{self.lattice[0,0]} {self.lattice[0,1]} {self.lattice[0,2]}', file=xsf)
519
- print(f'{self.lattice[1,0]} {self.lattice[1,1]} {self.lattice[1,2]}', file=xsf)
520
- print(f'{self.lattice[2,0]} {self.lattice[2,1]} {self.lattice[2,2]}', file=xsf)
521
- print('PRIMCOORD', file=xsf)
522
- print(f'{self.natom} 1', file=xsf)
523
- for i, coord in enumerate(self.xred):
524
- atomic_num = int(self.znucltypat[self.typat[i] - 1])
525
- cart_coord = np.matmul(coord, self.lattice)
526
- print(f'{atomic_num} {cart_coord[0]} {cart_coord[1]} {cart_coord[2]}', file=xsf)
527
- print('ATOMS', file=xsf)
528
- for i, coord in enumerate(self.xred):
529
- atomic_num = int(self.znucltypat[self.typat[i] - 1])
530
- cart_coord = np.matmul(coord, self.lattice)
531
- print(f'{atomic_num} {cart_coord[0]} {cart_coord[1]} {cart_coord[2]}', file=xsf)
532
- print('BEGIN_BLOCK_DATAGRID_3D', file=xsf)
533
- print('datagrids', file=xsf)
534
- print('BEGIN_DATAGRID_3D_principal_orbital_component', file=xsf)
535
- print(f'{self.ngfftx} {self.ngffty} {self.ngfftz}', file=xsf)
536
- print('0.0 0.0 0.0', file=xsf)
537
- print(f'{self.lattice[0,0]} {self.lattice[0,1]} {self.lattice[0,2]}', file=xsf)
538
- print(f'{self.lattice[1,0]} {self.lattice[1,1]} {self.lattice[1,2]}', file=xsf)
539
- print(f'{self.lattice[2,0]} {self.lattice[2,1]} {self.lattice[2,2]}', file=xsf)
540
- count = 0
541
- for z in range(self.ngfftz):
542
- for y in range(self.ngffty):
543
- for x in range(self.ngfftx):
544
- count += 1
545
- if _component:
546
- print(self.wfk_coeffs[x,y,z].real, file=xsf, end=' ')
547
- else:
548
- print(self.wfk_coeffs[x,y,z].imag, file=xsf, end=' ')
549
- if count == 6:
550
- count = 0
551
- print('\n', file=xsf, end='')
552
- print('END_DATAGRID_3D', file=xsf)
553
- print('END_BLOCK_DATAGRID_3D', file=xsf)
554
- # rerun method to write out imaginary part
555
- if _component:
556
- xsf_file = xsf_file.split('_real')[0]
1
+ import numpy as np
2
+ from scipy.fft import fftn, ifftn
3
+ import sys
4
+ from typing import Self, Generator
5
+ from copy import copy
6
+ from . import brillouin_zone as brlzn
7
+ np.set_printoptions(threshold=sys.maxsize)
8
+
9
+ class WFK():
10
+ '''
11
+ A class for working with wavefunctions from DFT calculations
12
+
13
+ Parameters
14
+ ----------
15
+ wfk_coeffs : np.ndarray
16
+ The planewave coefficients of the wavefunction
17
+ These should be complex values
18
+ kpoints : np.ndarray
19
+ A multidimensional array of 3D kpoints
20
+ Entries along axis 0 should be individual kpoints
21
+ Entries along axis 1 should be the kx, ky, and kz components, in that order
22
+ The kpoints should be in reduced form
23
+ pw_indices : np.ndarray
24
+ Array of H, K, L indices for the planwave basis set.
25
+ Arrays of (1,3) [H,K,L] should fill axis 0, and H, K, L values fill axis 1 in that order.
26
+ Necessary for arranging wavefunction coefficients in 3D array.
27
+ syrmel : np.ndarray
28
+ A multidimensional array of 3x3 arrays of symmetry operations
29
+ non_symm_vec : np.ndarray
30
+ A multidimensional array of 1x3 arrays of the nonsymmorphic translation vectors for each-
31
+ symmetry operation.
32
+ nsym : int
33
+ Total number of symmetry operations
34
+ nkpt : int
35
+ Total number of kpoints
36
+ If a kpoints array is provided, then nkpt will be acquired the its length
37
+ nbands : int
38
+ Total number of bands
39
+ ngfftx : int
40
+ x dimension of Fourier transform grid
41
+ ngffty : int
42
+ y dimension of Fourier transform grid
43
+ ngfftz : int
44
+ z dimension of Fourier transform grid
45
+ eigenvalues : list
46
+ List of the eigenvalues for wavefunction at each band
47
+ Should be ordered from least -> greatest
48
+ fermi_energy : float
49
+ Fermi energy
50
+ lattice : np.ndarray
51
+ 3x3 array containing lattice parameters
52
+ natom : int
53
+ Total number of atoms in unit cell
54
+ xred : np.ndarray
55
+ Reduced coordinates of all atoms in unit cell
56
+ Individual atomic coordinates fill along axis 0
57
+ X, Y, and Z components fill along axis 1, in that order
58
+ typat : list
59
+ Numeric labels starting from 1 and incrementing up to natom
60
+ Order of labels should follow xred
61
+ znucltypat : list
62
+ List of element names
63
+ First element of list should correspond to typat label 1, second element to label 2 and so on
64
+ time_reversal : bool
65
+ Select whether system has time reversal symmetry or not
66
+ If the system is time reversal symmetric, then reciprocal space electronic states will share inversion
67
+ symmetry even if the real space symmetries do not include inversion
68
+ Default assumes noncentrosymmetric systems have time reversal symmetry (True)
69
+
70
+ Methods
71
+ -------
72
+ GridWFK
73
+ Assembles plane wave coefficients on FFT grid
74
+ RemoveGrid
75
+ Undoes FFT grid and returns coefficients to a flat array
76
+ FFT
77
+ Applies Fast Fourier Transform to plane wave coefficients
78
+ IFFT
79
+ Applies Inverse Fast Fourier Transform to plane wave coefficients
80
+ Normalize
81
+ Calculates and applies normalization factor to coefficients
82
+ Real2Reciprocal
83
+ Calculates reciprocal lattice vectors from real space vectors
84
+ Symmetrize
85
+ Generates symmetrical copies from symmetry matrix operations
86
+ SymWFK
87
+ Generates symmetrical plane wave coefficients from operations
88
+ XSFFormat
89
+ Converts plane wave coefficients grid into XSF formatted grid
90
+ RemoveXSF
91
+ Converts XSF formatted grid into regular FFT grid
92
+ WriteXSF
93
+ Prints out XSF files for both the real and imaginary parts of the coefficients
94
+ '''
95
+ def __init__(
96
+ self,
97
+ wfk_coeffs:np.ndarray=np.zeros(1), kpoints:np.ndarray=np.zeros(1), symrel:np.ndarray=np.zeros(1),
98
+ nsym:int=0, nkpt:int=0, nbands:int=0, ngfftx:int=0, ngffty:int=0, ngfftz:int=0,
99
+ eigenvalues:np.ndarray=np.zeros(1),fermi_energy:float=0.0, lattice:np.ndarray=np.zeros(1), natom:int=0,
100
+ xred:np.ndarray=np.zeros(1), typat:list=[], znucltypat:list=[], pw_indices:np.ndarray=np.zeros(1),
101
+ non_symm_vecs:np.ndarray=np.zeros(1), time_reversal:bool=True
102
+ )->None:
103
+ self.wfk_coeffs=wfk_coeffs
104
+ self.kpoints=kpoints
105
+ self.pw_indices=pw_indices
106
+ self.symrel=symrel
107
+ self.nsym=nsym
108
+ self.non_symm_vecs=non_symm_vecs
109
+ self.nkpt=nkpt
110
+ self.nbands=nbands
111
+ self.ngfftx=ngfftx
112
+ self.ngffty=ngffty
113
+ self.ngfftz=ngfftz
114
+ self.eigenvalues=eigenvalues
115
+ self.fermi_energy=fermi_energy
116
+ self.lattice=lattice
117
+ self.natom=natom
118
+ self.xred=xred
119
+ self.typat=typat
120
+ self.znucltypat=znucltypat
121
+ self.time_reversal=time_reversal
122
+ #---------------------------------------------------------------------------------------------------------------------#
123
+ #------------------------------------------------------ METHODS ------------------------------------------------------#
124
+ #---------------------------------------------------------------------------------------------------------------------#
125
+ # method for putting plane wave coefficients onto 3D gridded array
126
+ def GridWFK(
127
+ self, band_index:int=-1
128
+ )->Self:
129
+ '''
130
+ Returns copy of WFK object with coefficients in numpy 3D array grid.
131
+ Grid is organized in (ngfftz, ngfftx, ngffty) dimensions.
132
+ Where ngfft_ represents the _ Fourier transform grid dimension.
133
+
134
+ Parameters
135
+ ----------
136
+ band_index : int
137
+ Integer represent the band index of the wavefunction coefficients to be transformed.
138
+ If nothing is passed, it is assumed the coefficients of a single band are supplied.
139
+ '''
140
+ # initialize 3D grid
141
+ gridded_wfk = np.zeros((self.ngfftx, self.ngffty, self.ngfftz), dtype=complex)
142
+ # update grid with wfk coefficients
143
+ for k, kpt in enumerate(self.pw_indices):
144
+ kx = kpt[0]
145
+ ky = kpt[1]
146
+ kz = kpt[2]
147
+ if band_index >= 0:
148
+ gridded_wfk[kx, ky, kz] = self.wfk_coeffs[band_index][k]
149
+ else:
150
+ gridded_wfk[kx, ky, kz] = self.wfk_coeffs[k]
151
+ new_WFK = copy(self)
152
+ new_WFK.wfk_coeffs = gridded_wfk
153
+ return new_WFK
154
+ #-----------------------------------------------------------------------------------------------------------------#
155
+ # method for undoing grid
156
+ def RemoveGrid(
157
+ self, band_index:int=-1
158
+ )->Self:
159
+ '''
160
+ Returns copy of WFK object with coefficients removed from the 3D gridded array.
161
+
162
+ Parameters
163
+ ----------
164
+ band_index : int
165
+ Integer represent the band index of the wavefunction coefficients to be transformed.
166
+ If nothing is passed, it is assumed the coefficients of a single band are supplied.
167
+ '''
168
+ # check if coefficients are gridded before undoing grid format
169
+ if self.wfk_coeffs.shape != (self.ngfftx,self.ngffty,self.ngfftz):
170
+ raise ValueError((
171
+ f'Plane wave coefficients must be in 3D grid with shape ({self.ngfftx}, {self.ngffty}, {self.ngfftz})'
172
+ ' in order to remove the gridded format'
173
+ ))
174
+ if band_index >= 0:
175
+ coeffs_no_grid = self.wfk_coeffs[band_index]
176
+ else:
177
+ coeffs_no_grid = self.wfk_coeffs
178
+ # returns values at each plane wave index, undoing grid
179
+ coeffs_no_grid = coeffs_no_grid[tuple(self.pw_indices.T)]
180
+ new_WFK = copy(self)
181
+ new_WFK.wfk_coeffs = coeffs_no_grid
182
+ return new_WFK
183
+ #-----------------------------------------------------------------------------------------------------------------#
184
+ # method transforming reciprocal space wfks to real space
185
+ def FFT(
186
+ self
187
+ )->Self:
188
+ '''
189
+ Returns copy of WFK with wavefunction coefficients expressed in real space.
190
+ Assumes existing wavefunction coefficients are expressed in reciprocal space.
191
+ '''
192
+ # Fourier transform real grid to reciprocal grid
193
+ reciprocal_coeffs = fftn(self.wfk_coeffs, norm='ortho')
194
+ new_WFK = copy(self)
195
+ new_WFK.wfk_coeffs = np.array(reciprocal_coeffs).reshape((self.ngfftx, self.ngffty, self.ngfftz))
196
+ return new_WFK
197
+ #-----------------------------------------------------------------------------------------------------------------#
198
+ # method transforming real space wfks to reciprocal space
199
+ def IFFT(
200
+ self
201
+ )->Self:
202
+ '''
203
+ Returns copy of WFK with wavefunction coefficients in expressed in reciprocal space.
204
+ Assumes existing wavefunction coefficients are expressed in real space.
205
+ '''
206
+ # Fourier transform reciprocal grid to real grid
207
+ real_coeffs = ifftn(self.wfk_coeffs, norm='ortho')
208
+ new_WFK = copy(self)
209
+ new_WFK.wfk_coeffs = np.array(real_coeffs).reshape((self.ngfftx,self.ngffty,self.ngfftz))
210
+ return new_WFK
211
+ #-----------------------------------------------------------------------------------------------------------------#
212
+ # method for normalizing wfks
213
+ def Normalize(
214
+ self
215
+ )->Self:
216
+ '''
217
+ Returns copy of WFK object with normalized wavefunction coefficients such that <psi|psi> = 1.
218
+ '''
219
+ coeffs = np.array(self.wfk_coeffs)
220
+ # calculate normalization constant and apply to wfk
221
+ norm = np.dot(coeffs.flatten(), np.conj(coeffs).flatten())
222
+ norm = np.sqrt(norm)
223
+ new_WFK = copy(self)
224
+ new_WFK.wfk_coeffs /= norm
225
+ return new_WFK
226
+ #-----------------------------------------------------------------------------------------------------------------#
227
+ # method for converting real space lattice vectors to reciprocal space vectors
228
@property
def rec_latt(
    self
) -> np.ndarray:
    '''
    Reciprocal lattice vectors computed from the real space lattice.

    Rows of the returned (3, 3) array are b1, b2, b3, with b_i = 2*pi * (a_j x a_k) / V,
    where a1, a2, a3 are the rows of ``self.lattice`` and V = a1 . (a2 x a3).
    No unit conversion is applied; the result is in inverse units of ``self.lattice``.
    '''
    a1, a2, a3 = (self.lattice[row, :] for row in range(3))
    cell_volume = np.dot(a1, np.cross(a2, a3))
    # cyclic pairs (a2, a3), (a3, a1), (a1, a2) give b1, b2, b3 respectively
    rec_vectors = [
        2*np.pi*(np.cross(u, v))/cell_volume
        for u, v in ((a2, a3), (a3, a1), (a1, a2))
    ]
    return np.array(rec_vectors).reshape((3, 3))
244
+ #-----------------------------------------------------------------------------------------------------------------#
245
+ # method for checking for time reversal symmetry
246
+ def _CheckTimeRevSym(
247
+ self
248
+ ):
249
+ if self.time_reversal:
250
+ # if system is centrosymmetric, do not double reciprocal symmetry operations
251
+ if -3.0 in [np.trace(mat) for mat in self.symrel]:
252
+ self.time_reversal = False
253
+ else:
254
+ print((
255
+ 'Noncentrosymmetric system identified, assuming time reversal symmetry\n'
256
+ 'To change this, set time_reversal attribute to False'
257
+ ))
258
+ #-----------------------------------------------------------------------------------------------------------------#
259
+ # method for finding symmetrically distinct k points
260
+ def _FindOrbit(
261
+ self, sym_kpts:np.ndarray
262
+ )->tuple[list,list]:
263
+ sym_kpts = np.round(sym_kpts, decimals=15)
264
+ _, unique_inds = np.unique(sym_kpts, return_index=True, axis=0)
265
+ # for each unique kpoint check original point is related by reciprocal lattice vector
266
+ dupes = []
267
+ for i, ind1 in enumerate(unique_inds):
268
+ if i in dupes:
269
+ continue
270
+ for j, ind2 in enumerate(unique_inds):
271
+ if i == j or j in dupes:
272
+ continue
273
+ diff = np.abs(sym_kpts[ind1] - sym_kpts[ind2])
274
+ diff[diff < 10**(-12)] = 0.0
275
+ diff[diff > 0.999] = 1.0
276
+ mask = np.isin(diff, np.array([0.0,1.0]))
277
+ if mask.all():
278
+ dupes.append(j)
279
+ return dupes, unique_inds.tolist()
280
+ #-----------------------------------------------------------------------------------------------------------------#
281
+ # function for calculating phase imparted by nonsymmorphic translation
282
+ def _FindPhase(
283
+ self, nonsymmvec:np.ndarray, g_vecs:np.ndarray, kpt:np.ndarray
284
+ )->np.ndarray:
285
+ if self.non_symm_vecs is np.zeros(1):
286
+ return np.ones(len(g_vecs))
287
+ elif np.sum(np.abs(nonsymmvec)) < 10**(-8):
288
+ return np.ones(len(g_vecs))
289
+ else:
290
+ return np.exp(-1j*np.dot((kpt+g_vecs), nonsymmvec.T))
291
+ #-----------------------------------------------------------------------------------------------------------------#
292
+ # method for creating symmetrically equivalent points
293
+ def Symmetrize(
294
+ self, points:np.ndarray, values:np.ndarray=np.empty([]), unique:bool=True, reciprocal:bool=False,
295
+ inverse:bool=False
296
+ )->tuple[np.ndarray, np.ndarray]:
297
+ '''
298
+ Method for generating symmetric data from irreducible data.
299
+
300
+ Parameters
301
+ ----------
302
+ points : np.ndarray
303
+ Irreducible set of points.
304
+ Shape of (N,3).
305
+ values : np.ndarray
306
+ Values corresponding to irreducible points (such as energy eigenvalues w/ kpoints).
307
+ Shape of (N,1).
308
+ unique : bool
309
+ Check for duplicate points.
310
+ Default is to check (True).
311
+ reciprocal : bool
312
+ Calculate reciprocal space symmetry matrices from real space matrices.
313
+ Default uses real space matrices (False).
314
+ inverse : bool
315
+ Use inverse symmetry operations.
316
+ Default applies forwards operation (False).
317
+ '''
318
+ # check if reciprocal or real space symmetries will be used
319
+ sym_num = self.nsym
320
+ if reciprocal:
321
+ # nosymmorphic translations do not apply to reciprocal space
322
+ tnons = False
323
+ sym_mats = [np.linalg.inv(mat).T for mat in self.symrel]
324
+ # time reversal only adds to reciprocal space symmetries
325
+ if self.time_reversal:
326
+ sym_mats = np.concatenate((sym_mats, [-mat for mat in sym_mats]), axis=0)
327
+ sym_num *= 2
328
+ else:
329
+ tnons = True
330
+ sym_mats = self.symrel
331
+ # initialize symmetrically equivalent point and value arrays
332
+ if len(points.shape) == 1:
333
+ points.reshape((1,points.shape[0]))
334
+ ind_len = np.shape(points)[0]
335
+ if values is np.empty([]):
336
+ values = np.zeros((ind_len,1))
337
+ sym_pts = np.zeros((sym_num*ind_len,3))
338
+ sym_vals = np.zeros((sym_num*ind_len,self.nbands))
339
+ if self.non_symm_vecs.all() == np.zeros(1):
340
+ self.non_symm_vecs = np.zeros(self.nsym)
341
+ # apply symmetry operations to points
342
+ if inverse:
343
+ for i, op in enumerate(sym_mats):
344
+ if tnons:
345
+ points += self.non_symm_vecs[i]
346
+ new_pts:np.ndarray = np.matmul(np.linalg.inv(op), points.T).T
347
+ sym_pts[i*ind_len:(i+1)*ind_len,:] = new_pts
348
+ sym_vals[i*ind_len:(i+1)*ind_len,:] = values
349
+ else:
350
+ for i, op in enumerate(sym_mats):
351
+ if tnons:
352
+ points += self.non_symm_vecs[i]
353
+ new_pts:np.ndarray = np.matmul(op, points.T).T
354
+ sym_pts[i*ind_len:(i+1)*ind_len,:] = new_pts
355
+ sym_vals[i*ind_len:(i+1)*ind_len,:] = values
356
+ # points overlap on at edges of each symmetric block, remove duplicates
357
+ if unique:
358
+ dupes, unique_inds = self._FindOrbit(sym_pts)
359
+ unique_kpts = np.array([sym_pts[ind,:] for i, ind in enumerate(unique_inds) if i not in dupes])
360
+ unique_vals = np.array([sym_vals[ind,:] for i, ind in enumerate(unique_inds) if i not in dupes])
361
+ return unique_kpts, unique_vals
362
+ return sym_pts, sym_vals
363
+ #-----------------------------------------------------------------------------------------------------------------#
364
+ # method for creating symmetrically equivalent functions at specified kpoint
365
def SymWFKs(
    self, kpoint: np.ndarray, band: int = -1
) -> Generator[Self, None, None]:
    '''
    Method for generating wavefunction planewave coefficients from coefficients of the irreducible BZ.

    Parameters
    ----------
    kpoint : np.ndarray
        A single reciprocal space point is provided to generate symmetrically equivalent coefficients.
        Shape (1,3).
    band : int
        Choose which band to pull coefficients from (indexed starting from zero).
        Default assumes coefficients from a single band are provided (-1).

    Yields
    ------
    WFK
        One copy per symmetrically distinct image of ``kpoint``. Each copy carries the
        rotated and shifted plane wave indices, the shifted kpoint, and the
        coefficients multiplied by the nonsymmorphic phase factor.
        ``self`` is not modified.
    '''
    # find symmetric kpoints
    kpoint = np.asarray(kpoint).reshape((1, 3))
    sym_kpoints, _ = self.Symmetrize(kpoint, unique=False, reciprocal=True)
    dupes, unique_inds = self._FindOrbit(sym_kpoints)
    dupe_set = set(dupes)
    # find symmetric planewave indices; round before casting so float round-off from
    # the matrix inversion (e.g. 0.9999999) does not truncate to the wrong integer
    sym_pw_inds, _ = self.Symmetrize(self.pw_indices, unique=False, reciprocal=True)
    sym_pw_inds = np.rint(sym_pw_inds).astype(int)
    ind_range = self.pw_indices.shape[0]
    # find reciprocal lattice shifts to move all points into BZ
    bz = brlzn.BZ(self.rec_latt)
    shifts = bz.GetShifts(sym_kpoints)
    # create WFK copies with new planewave indices
    for i, ind in enumerate(unique_inds):
        if i in dupe_set:
            continue
        new_pw_inds = sym_pw_inds[ind*ind_range:(ind+1)*ind_range, :]
        new_pw_inds += shifts[ind, :]
        new_coeffs = copy(self)
        new_coeffs.pw_indices = new_pw_inds
        new_coeffs.kpoints = sym_kpoints[ind, :] - shifts[ind, :]
        phase_factor = self._FindPhase(
            self.non_symm_vecs[ind % len(self.non_symm_vecs)],
            self.pw_indices,
            sym_kpoints[ind, :]
        )
        coeffs = self.wfk_coeffs[band] if band >= 0 else self.wfk_coeffs
        # build a new array rather than multiplying in place: copy() is shallow, so `*=`
        # would alter self.wfk_coeffs and compound phases across yielded copies
        new_coeffs.wfk_coeffs = coeffs * phase_factor
        yield new_coeffs
413
+ #-----------------------------------------------------------------------------------------------------------------#
414
+ # method that returns BZ kpoints and eigenvalues
415
def GetBZPtsEigs(
    self
) -> tuple[np.ndarray, np.ndarray]:
    '''
    Unfold the irreducible kpoints and eigenvalues to the full Brillouin zone.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The symmetry-expanded, duplicate-free kpoints, moved into the BZ by the shifts
        from ``BZ.GetShifts``, and the eigenvalues attached to each of them.
    '''
    brillouin = brlzn.BZ(self.rec_latt)
    full_kpts, full_eigs = self.Symmetrize(
        points=self.kpoints,
        values=self.eigenvalues,
        reciprocal=True,
    )
    lattice_shifts = brillouin.GetShifts(full_kpts)
    full_kpts -= lattice_shifts
    return full_kpts, full_eigs
422
+ #-----------------------------------------------------------------------------------------------------------------#
423
+ # method for expanding a grid into XSF format
424
+ def XSFFormat(
425
+ self
426
+ )->Self:
427
+ '''
428
+ Returns copy of WFK object XSF formatted coefficients.
429
+ Requires wfk_coeffs to be in gridded format, i.e. (ngfftz, ngfftx, ngffty) shape.
430
+ '''
431
+ # append zeros to ends of all axes in grid_wfk
432
+ # zeros get replaced by values at beginning of each axis
433
+ # this repetition is required by XSF format
434
+ if np.shape(self.wfk_coeffs) != (self.ngfftx, self.ngffty, self.ngfftz):
435
+ raise ValueError(
436
+ f'''Passed array is not the correct shape:
437
+ Expected: ({self.ngfftx}, {self.ngffty}, {self.ngfftz}),
438
+ Received: {np.shape(self.wfk_coeffs)}
439
+ ''')
440
+ else:
441
+ grid_wfk = self.wfk_coeffs
442
+ grid_wfk = np.append(grid_wfk, np.zeros((1, self.ngffty, self.ngfftz)), axis=0)
443
+ grid_wfk = np.append(grid_wfk, np.zeros((self.ngfftx+1, 1, self.ngfftz)), axis=1)
444
+ grid_wfk = np.append(grid_wfk, np.zeros((self.ngfftx+1, self.ngffty+1, 1)), axis=2)
445
+ for x in range(self.ngfftx+1):
446
+ for y in range(self.ngffty+1):
447
+ for z in range(self.ngfftz+1):
448
+ if x == self.ngfftx:
449
+ grid_wfk[x,y,z] = grid_wfk[0,y,z]
450
+ if y == self.ngffty:
451
+ grid_wfk[x,y,z] = grid_wfk[x,0,z]
452
+ if z == self.ngfftz:
453
+ grid_wfk[x,y,z] = grid_wfk[x,y,0]
454
+ if x == self.ngfftx and y == self.ngffty:
455
+ grid_wfk[x,y,z] = grid_wfk[0,0,z]
456
+ if x == self.ngfftx and z == self.ngfftz:
457
+ grid_wfk[x,y,z] = grid_wfk[0,y,0]
458
+ if z == self.ngfftz and y == self.ngffty:
459
+ grid_wfk[x,y,z] = grid_wfk[x,0,0]
460
+ if x == self.ngfftx and y == self.ngffty and z == self.ngfftz:
461
+ grid_wfk[x,y,z] = grid_wfk[0,0,0]
462
+ new_WFK = copy(self)
463
+ new_WFK.wfk_coeffs = grid_wfk
464
+ new_WFK.ngfftx += 1
465
+ new_WFK.ngffty += 1
466
+ new_WFK.ngfftz += 1
467
+ return new_WFK
468
+ #-----------------------------------------------------------------------------------------------------------------#
469
+ # method removing XSF formatting from density grid
470
+ def RemoveXSF(
471
+ self
472
+ )->Self:
473
+ '''
474
+ Returns copy of WFK object without XSF formatting.
475
+ '''
476
+ grid = self.wfk_coeffs
477
+ # to_be_del will be used to remove all extra data points added for XSF formatting
478
+ to_be_del = np.ones((self.ngfftx, self.ngffty, self.ngfftz), dtype=bool)
479
+ for z in range(self.ngfftz):
480
+ for y in range(self.ngffty):
481
+ for x in range(self.ngfftx):
482
+ # any time you reach the last density point it is a repeat of the first point
483
+ # remove the end points along each axis
484
+ if y == self.ngffty - 1 or x == self.ngfftx - 1 or z == self.ngfftz - 1:
485
+ to_be_del[x,y,z] = False
486
+ # remove xsf entries from array
487
+ grid = grid[to_be_del]
488
+ # restore grid shape
489
+ grid = grid.reshape((self.ngfftx-1, self.ngffty-1, self.ngfftz-1))
490
+ new_WFK = copy(self)
491
+ new_WFK.wfk_coeffs = grid
492
+ new_WFK.ngfftx -= 1
493
+ new_WFK.ngffty -= 1
494
+ new_WFK.ngfftz -= 1
495
+ return new_WFK
496
+ #-----------------------------------------------------------------------------------------------------------------#
497
+ # method for writing wavefunctions to XSF file
498
def WriteXSF(
    self, xsf_file: str, _component: bool = True
) -> None:
    '''
    A method for writing numpy grids to XSF formatted files.

    Writes the real part of ``wfk_coeffs`` to ``<xsf_file>_real.xsf`` and the imaginary
    part to ``<xsf_file>_imag.xsf``. ``wfk_coeffs`` is read as an
    (ngfftx, ngffty, ngfftz) grid. Atoms are written from ``xred``, ``typat`` and
    ``znucltypat``, converted to Cartesian coordinates via ``xred @ lattice``.

    Parameters
    ----------
    xsf_file : str
        The base file name; ``_real.xsf`` / ``_imag.xsf`` are appended.
    _component : bool
        Internal flag kept for backward compatibility: True writes both files,
        False writes only the imaginary-part file.
    '''
    parts = ('real', 'imag') if _component else ('imag',)
    lattice_lines = [
        f'{self.lattice[row, 0]} {self.lattice[row, 1]} {self.lattice[row, 2]}'
        for row in range(3)
    ]
    atom_lines = []
    for i, coord in enumerate(self.xred):
        atomic_num = int(self.znucltypat[self.typat[i] - 1])
        cart_coord = np.matmul(coord, self.lattice)
        atom_lines.append(f'{atomic_num} {cart_coord[0]} {cart_coord[1]} {cart_coord[2]}')
    header = [
        'DIM-GROUP', '3 1',
        'PRIMVEC', *lattice_lines,
        'PRIMCOORD', f'{self.natom} 1', *atom_lines,
        'ATOMS', *atom_lines,
        'BEGIN_BLOCK_DATAGRID_3D', 'datagrids',
        'BEGIN_DATAGRID_3D_principal_orbital_component',
        f'{self.ngfftx} {self.ngffty} {self.ngfftz}',
        '0.0 0.0 0.0', *lattice_lines,
    ]
    coeffs = np.asarray(self.wfk_coeffs)[:self.ngfftx, :self.ngffty, :self.ngfftz]
    for part in parts:
        component = coeffs.real if part == 'real' else coeffs.imag
        # XSF lists grid values with x varying fastest, then y, then z (Fortran order)
        values = component.ravel(order='F')
        # six values per line, each followed by a space; every line, including a final
        # partial one, ends with a newline so END_DATAGRID_3D starts on its own line
        data_lines = [
            ''.join(f'{str(v)} ' for v in values[start:start + 6])
            for start in range(0, len(values), 6)
        ]
        with open(f'{xsf_file}_{part}.xsf', 'w') as xsf:
            xsf.write('\n'.join(header) + '\n')
            xsf.writelines(line + '\n' for line in data_lines)
            xsf.write('END_DATAGRID_3D\n')
            xsf.write('END_BLOCK_DATAGRID_3D\n')