valyte 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- valyte/band.py +135 -177
- valyte/band_plot.py +25 -66
- valyte/cli.py +66 -81
- valyte/dos_plot.py +137 -306
- valyte/ipr.py +186 -0
- valyte/kpoints.py +26 -54
- valyte/potcar.py +29 -54
- valyte/supercell.py +5 -22
- {valyte-0.1.9.dist-info → valyte-0.1.11.dist-info}/METADATA +60 -1
- valyte-0.1.11.dist-info/RECORD +19 -0
- valyte-0.1.9.dist-info/RECORD +0 -18
- {valyte-0.1.9.dist-info → valyte-0.1.11.dist-info}/WHEEL +0 -0
- {valyte-0.1.9.dist-info → valyte-0.1.11.dist-info}/entry_points.txt +0 -0
- {valyte-0.1.9.dist-info → valyte-0.1.11.dist-info}/top_level.txt +0 -0
valyte/band.py
CHANGED
|
@@ -1,289 +1,247 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Band structure KPOINTS generation module for Valyte.
|
|
3
|
-
"""
|
|
1
|
+
"""Band structure KPOINTS generation."""
|
|
4
2
|
|
|
5
3
|
import os
|
|
6
4
|
import json
|
|
7
5
|
import numpy as np
|
|
8
6
|
import seekpath
|
|
9
|
-
import spglib
|
|
10
7
|
from pymatgen.core import Structure
|
|
11
8
|
from pymatgen.symmetry.bandstructure import HighSymmKpath
|
|
12
9
|
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
|
13
10
|
try:
|
|
14
11
|
from importlib.resources import files as ilr_files
|
|
15
12
|
except ImportError:
|
|
16
|
-
|
|
13
|
+
from importlib_resources import files as ilr_files
|
|
17
14
|
|
|
18
15
|
from valyte.potcar import generate_potcar
|
|
19
16
|
|
|
20
17
|
|
|
21
18
|
def generate_band_kpoints(poscar_path="POSCAR", npoints=40, output="KPOINTS", symprec=0.01, mode="bradcrack"):
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
Uses SeeK-path method for high-symmetry path determination.
|
|
25
|
-
|
|
26
|
-
IMPORTANT: Writes a standardized POSCAR (POSCAR_standard) that MUST be used
|
|
27
|
-
for the band structure calculation to ensure K-points are valid.
|
|
28
|
-
|
|
29
|
-
Args:
|
|
30
|
-
poscar_path (str): Path to input POSCAR file.
|
|
31
|
-
npoints (int): Number of points per segment (default: 40).
|
|
32
|
-
output (str): Output filename for KPOINTS.
|
|
33
|
-
symprec (float): Symmetry precision for standardization (default: 0.01).
|
|
34
|
-
mode (str): Standardization convention (default: "bradcrack").
|
|
35
|
-
"""
|
|
36
|
-
|
|
19
|
+
"""Generate a line-mode KPOINTS file for band structure calculations."""
|
|
20
|
+
|
|
37
21
|
if not os.path.exists(poscar_path):
|
|
38
22
|
raise FileNotFoundError(f"{poscar_path} not found")
|
|
39
|
-
|
|
40
|
-
|
|
23
|
+
|
|
24
|
+
mode = (mode or "bradcrack").lower()
|
|
25
|
+
|
|
41
26
|
structure = Structure.from_file(poscar_path)
|
|
42
|
-
|
|
43
|
-
# --- K-Point Generation Logic ---
|
|
27
|
+
|
|
44
28
|
if mode == "bradcrack":
|
|
45
29
|
try:
|
|
46
30
|
kpath = BradCrackKpath(structure, symprec=symprec)
|
|
47
31
|
prim_std = kpath.prim
|
|
48
32
|
path = kpath.path
|
|
49
33
|
kpoints = kpath.kpoints
|
|
50
|
-
|
|
51
|
-
# Write standardized POSCAR from BradCrack logic
|
|
34
|
+
|
|
52
35
|
standard_filename = "POSCAR_standard"
|
|
53
36
|
prim_std.to(filename=standard_filename)
|
|
54
37
|
except Exception as e:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
else:
|
|
59
|
-
# Fallback to Pymatgen logic for other modes
|
|
38
|
+
raise RuntimeError(f"Error generating Bradley-Cracknell path: {e}")
|
|
39
|
+
else:
|
|
60
40
|
try:
|
|
61
|
-
# Map 'seekpath' alias to 'hinuma' which pymatgen uses (wrapper around seekpath)
|
|
62
41
|
if mode == "seekpath":
|
|
63
42
|
mode = "hinuma"
|
|
64
43
|
|
|
65
|
-
# Standardize structure first using SpacegroupAnalyzer
|
|
66
44
|
sga = SpacegroupAnalyzer(structure, symprec=symprec)
|
|
67
45
|
prim_std = sga.get_primitive_standard_structure()
|
|
68
46
|
except Exception as e:
|
|
69
|
-
|
|
70
|
-
return
|
|
47
|
+
raise RuntimeError(f"Error during standardization: {e}")
|
|
71
48
|
|
|
72
|
-
# Get high-symmetry path for the STANDARDIZED structure
|
|
73
49
|
try:
|
|
74
50
|
kpath = HighSymmKpath(prim_std, path_type=mode, symprec=symprec)
|
|
75
|
-
|
|
76
|
-
# Write the standardized primitive structure
|
|
51
|
+
|
|
77
52
|
standard_filename = "POSCAR_standard"
|
|
78
53
|
prim_std.to(filename=standard_filename)
|
|
79
|
-
|
|
80
|
-
# Get the path
|
|
54
|
+
|
|
81
55
|
path = kpath.kpath["path"]
|
|
82
56
|
kpoints = kpath.kpath["kpoints"]
|
|
83
57
|
except Exception as e:
|
|
84
|
-
|
|
85
|
-
return
|
|
58
|
+
raise RuntimeError(f"Error generating K-path: {e}")
|
|
86
59
|
|
|
87
|
-
# Write KPOINTS file
|
|
88
60
|
try:
|
|
89
61
|
with open(output, "w") as f:
|
|
90
62
|
f.write("KPOINTS for Band Structure\n")
|
|
91
63
|
f.write(f"{npoints}\n")
|
|
92
64
|
f.write("Line-mode\n")
|
|
93
65
|
f.write("Reciprocal\n")
|
|
94
|
-
|
|
66
|
+
|
|
95
67
|
for subpath in path:
|
|
96
68
|
for i in range(len(subpath) - 1):
|
|
97
69
|
start_label = subpath[i]
|
|
98
|
-
end_label = subpath[i+1]
|
|
99
|
-
|
|
70
|
+
end_label = subpath[i + 1]
|
|
71
|
+
|
|
100
72
|
start_coords = kpoints[start_label]
|
|
101
73
|
end_coords = kpoints[end_label]
|
|
102
|
-
|
|
103
|
-
f.write(f"{start_coords[0]:10.6f} {start_coords[1]:10.6f} {start_coords[2]:10.6f} ! {start_label}\n")
|
|
104
|
-
f.write(f"{end_coords[0]:10.6f} {end_coords[1]:10.6f} {end_coords[2]:10.6f} ! {end_label}\n")
|
|
105
|
-
f.write("\n") # Optional newline between segments
|
|
106
|
-
|
|
107
|
-
print(f"✅ Generated {output} ({' - '.join([' - '.join(seg) for seg in path])})")
|
|
108
|
-
print(f"✅ Generated {standard_filename} (Standardized Primitive Cell)")
|
|
109
|
-
print(f"\n⚠️ IMPORTANT: You MUST use '{standard_filename}' for your band calculation!")
|
|
110
|
-
print(f" The K-points are generated for this standardized orientation.")
|
|
111
|
-
print(f" Using your original POSCAR may result in incorrect paths or 'Reciprocal lattice' errors.")
|
|
112
74
|
|
|
75
|
+
f.write(
|
|
76
|
+
f"{start_coords[0]:10.6f} {start_coords[1]:10.6f} {start_coords[2]:10.6f} ! {start_label}\n"
|
|
77
|
+
)
|
|
78
|
+
f.write(
|
|
79
|
+
f"{end_coords[0]:10.6f} {end_coords[1]:10.6f} {end_coords[2]:10.6f} ! {end_label}\n"
|
|
80
|
+
)
|
|
81
|
+
f.write("\n")
|
|
82
|
+
|
|
83
|
+
print(f"Generated {output} ({' - '.join([' - '.join(seg) for seg in path])})")
|
|
84
|
+
print(f"Generated {standard_filename} (Standardized Primitive Cell)")
|
|
85
|
+
print("IMPORTANT: Use POSCAR_standard for the band calculation.")
|
|
113
86
|
except Exception as e:
|
|
114
|
-
|
|
87
|
+
raise RuntimeError(f"Error writing KPOINTS file: {e}")
|
|
115
88
|
|
|
116
|
-
# --- POTCAR Generation ---
|
|
117
89
|
try:
|
|
118
|
-
print("
|
|
90
|
+
print("Generating default POTCAR (PBE)...")
|
|
119
91
|
generate_potcar(poscar_path=poscar_path, functional="PBE", output="POTCAR")
|
|
120
92
|
except Exception as e:
|
|
121
|
-
print(f"
|
|
122
|
-
print("
|
|
123
|
-
|
|
93
|
+
print(f"Warning: could not generate POTCAR: {e}")
|
|
94
|
+
print("Proceeding without POTCAR generation.")
|
|
124
95
|
|
|
125
96
|
|
|
126
97
|
class BradCrackKpath:
|
|
127
|
-
"""
|
|
128
|
-
|
|
129
|
-
Replicates logic from Sumo/SeeK-path to determine standard paths.
|
|
130
|
-
"""
|
|
98
|
+
"""Bradley-Cracknell K-path generation using SeeK-path output."""
|
|
99
|
+
|
|
131
100
|
def __init__(self, structure, symprec=0.01):
|
|
132
101
|
self.structure = structure
|
|
133
102
|
self.symprec = symprec
|
|
134
|
-
|
|
135
|
-
# Use SpacegroupAnalyzer for basic data
|
|
103
|
+
|
|
136
104
|
sga = SpacegroupAnalyzer(structure, symprec=symprec)
|
|
137
105
|
self._spg_data = sga.get_symmetry_dataset()
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
cell = (structure.lattice.matrix, structure.frac_coords, [s.specie.number for s in structure])
|
|
146
|
-
|
|
147
|
-
# Sumo uses spglib.refine_cell on the cell first?
|
|
148
|
-
# "std = spglib.refine_cell(sym._cell, symprec=symprec)"
|
|
149
|
-
# pymatgen sga._cell is (lattice, positions, numbers)
|
|
150
|
-
|
|
151
|
-
# seekpath.get_path takes the cell structure
|
|
152
|
-
# output is dictionary
|
|
106
|
+
|
|
107
|
+
cell = (
|
|
108
|
+
structure.lattice.matrix,
|
|
109
|
+
structure.frac_coords,
|
|
110
|
+
[s.specie.number for s in structure],
|
|
111
|
+
)
|
|
112
|
+
|
|
153
113
|
self._seek_data = seekpath.get_path(cell, symprec=symprec)
|
|
154
|
-
|
|
155
|
-
# Reconstruct primitive structure from seekpath output
|
|
114
|
+
|
|
156
115
|
prim_lattice = self._seek_data["primitive_lattice"]
|
|
157
116
|
prim_pos = self._seek_data["primitive_positions"]
|
|
158
117
|
prim_types = self._seek_data["primitive_types"]
|
|
159
|
-
|
|
160
|
-
# We need a map from number to Element.
|
|
161
|
-
# unique_species from sga?
|
|
162
|
-
# Let's just use explicit element list from input structure, assuming types are consistent?
|
|
163
|
-
# Or better, use sga to map Z to elements.
|
|
164
|
-
|
|
165
|
-
# Setup element mapping
|
|
166
|
-
# Create a map from atomic number to Element object from input structure
|
|
118
|
+
|
|
167
119
|
z_to_specie = {s.specie.number: s.specie for s in structure}
|
|
168
120
|
prim_species = [z_to_specie[z] for z in prim_types]
|
|
169
|
-
|
|
121
|
+
|
|
170
122
|
self.prim = Structure(prim_lattice, prim_species, prim_pos)
|
|
171
|
-
|
|
123
|
+
|
|
172
124
|
conv_lattice = self._seek_data["conv_lattice"]
|
|
173
125
|
conv_pos = self._seek_data["conv_positions"]
|
|
174
126
|
conv_types = self._seek_data["conv_types"]
|
|
175
127
|
conv_species = [z_to_specie[z] for z in conv_types]
|
|
176
128
|
self.conv = Structure(conv_lattice, conv_species, conv_pos)
|
|
177
|
-
|
|
178
|
-
# Now determine Bravais lattice for BradCrack
|
|
129
|
+
|
|
179
130
|
self._get_bradcrack_path()
|
|
180
131
|
|
|
181
132
|
def _get_bradcrack_path(self):
|
|
182
|
-
|
|
183
|
-
# Determine lattice parameters from CONVENTIONAL cell
|
|
184
133
|
a, b, c = self.conv.lattice.abc
|
|
185
134
|
angles = self.conv.lattice.angles
|
|
186
|
-
|
|
187
|
-
# logic from BradCrackKpath.__init__
|
|
188
|
-
# "unique = angles.index(min(angles, key=angles.count))"
|
|
189
|
-
# usually 90, 90, beta. So unique is beta (non-90) index? No.
|
|
190
|
-
# Monoclinic: alpha=gamma=90, beta!=90. 90 appears twice. non-90 appears once.
|
|
191
|
-
# min count of angle values?
|
|
192
|
-
# if angles are [90, 90, 105], counts are {90:2, 105:1}. min count is 1. value is 105. index is 2.
|
|
193
|
-
# so unique is index of non-90 degree angle.
|
|
194
|
-
|
|
195
|
-
# Round angles to avoid float issues
|
|
135
|
+
|
|
196
136
|
angles_r = [round(x, 3) for x in angles]
|
|
197
137
|
unique_val = min(angles_r, key=angles_r.count)
|
|
198
138
|
unique = angles_r.index(unique_val)
|
|
199
139
|
|
|
200
|
-
# Get Space Group Symbol and Number
|
|
201
|
-
# From seekpath? or sga?
|
|
202
|
-
# Sumo uses: "spg_symbol = self.spg_symbol" which is "self._spg_data['international']"
|
|
203
|
-
# spglib dataset returns 'international'
|
|
204
140
|
spg_symbol = self._spg_data["international"]
|
|
205
141
|
spg_number = self._spg_data["number"]
|
|
206
|
-
|
|
142
|
+
|
|
207
143
|
lattice_type = self.get_lattice_type(spg_number)
|
|
208
|
-
|
|
209
144
|
bravais = self._get_bravais_lattice(spg_symbol, lattice_type, a, b, c, unique)
|
|
210
|
-
|
|
211
|
-
# Load JSON
|
|
212
|
-
|
|
145
|
+
|
|
213
146
|
json_file = ilr_files("valyte.data").joinpath("bradcrack.json")
|
|
214
|
-
with open(
|
|
147
|
+
with json_file.open("r") as f:
|
|
215
148
|
data = json.load(f)
|
|
216
|
-
|
|
149
|
+
|
|
217
150
|
if bravais not in data:
|
|
218
151
|
raise ValueError(f"Bravais lattice code '{bravais}' not found in BradCrack data.")
|
|
219
|
-
|
|
152
|
+
|
|
220
153
|
self.bradcrack_data = data[bravais]
|
|
221
154
|
self.kpoints = self.bradcrack_data["kpoints"]
|
|
222
155
|
self.path = self.bradcrack_data["path"]
|
|
223
156
|
|
|
224
157
|
def get_lattice_type(self, number):
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
if 3 <= number <= 15:
|
|
228
|
-
|
|
229
|
-
if
|
|
230
|
-
|
|
231
|
-
|
|
158
|
+
if 1 <= number <= 2:
|
|
159
|
+
return "triclinic"
|
|
160
|
+
if 3 <= number <= 15:
|
|
161
|
+
return "monoclinic"
|
|
162
|
+
if 16 <= number <= 74:
|
|
163
|
+
return "orthorhombic"
|
|
164
|
+
if 75 <= number <= 142:
|
|
165
|
+
return "tetragonal"
|
|
166
|
+
if 143 <= number <= 167:
|
|
167
|
+
if number in [146, 148, 155, 160, 161, 166, 167]:
|
|
168
|
+
return "rhombohedral"
|
|
232
169
|
return "trigonal"
|
|
233
|
-
if 168 <= number <= 194:
|
|
234
|
-
|
|
170
|
+
if 168 <= number <= 194:
|
|
171
|
+
return "hexagonal"
|
|
172
|
+
if 195 <= number <= 230:
|
|
173
|
+
return "cubic"
|
|
235
174
|
return "unknown"
|
|
236
175
|
|
|
237
176
|
def _get_bravais_lattice(self, spg_symbol, lattice_type, a, b, c, unique):
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
177
|
+
if lattice_type == "triclinic":
|
|
178
|
+
return "triclinic"
|
|
179
|
+
|
|
180
|
+
if lattice_type == "monoclinic":
|
|
181
|
+
if "P" in spg_symbol:
|
|
182
|
+
if unique == 0:
|
|
183
|
+
return "mon_p_a"
|
|
184
|
+
if unique == 1:
|
|
185
|
+
return "mon_p_b"
|
|
186
|
+
if unique == 2:
|
|
187
|
+
return "mon_p_c"
|
|
188
|
+
if "C" in spg_symbol:
|
|
189
|
+
if unique == 0:
|
|
190
|
+
return "mon_c_a"
|
|
191
|
+
if unique == 1:
|
|
192
|
+
return "mon_c_b"
|
|
193
|
+
if unique == 2:
|
|
194
|
+
return "mon_c_c"
|
|
195
|
+
|
|
196
|
+
if lattice_type == "orthorhombic":
|
|
242
197
|
if "P" in spg_symbol:
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
if
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
elif "A" in spg_symbol or "C" in spg_symbol:
|
|
254
|
-
if a > b: return "orth_c_a"
|
|
255
|
-
elif b > a: return "orth_c_b"
|
|
256
|
-
elif "F" in spg_symbol:
|
|
257
|
-
# 1/a^2 etc conditions... need to replicate exact math
|
|
258
|
-
# Copied from Sumo source view
|
|
259
|
-
inv_a2 = 1/a**2; inv_b2 = 1/b**2; inv_c2 = 1/c**2
|
|
198
|
+
return "orth_p"
|
|
199
|
+
if "A" in spg_symbol or "C" in spg_symbol:
|
|
200
|
+
if a > b:
|
|
201
|
+
return "orth_c_a"
|
|
202
|
+
if b > a:
|
|
203
|
+
return "orth_c_b"
|
|
204
|
+
if "F" in spg_symbol:
|
|
205
|
+
inv_a2 = 1 / a**2
|
|
206
|
+
inv_b2 = 1 / b**2
|
|
207
|
+
inv_c2 = 1 / c**2
|
|
260
208
|
if (inv_a2 < inv_b2 + inv_c2) and (inv_b2 < inv_c2 + inv_a2) and (inv_c2 < inv_a2 + inv_b2):
|
|
261
209
|
return "orth_f_1"
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
if
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
if
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
210
|
+
if inv_c2 > inv_a2 + inv_b2:
|
|
211
|
+
return "orth_f_2"
|
|
212
|
+
if inv_b2 > inv_a2 + inv_c2:
|
|
213
|
+
return "orth_f_3"
|
|
214
|
+
if inv_a2 > inv_c2 + inv_b2:
|
|
215
|
+
return "orth_f_4"
|
|
216
|
+
if "I" in spg_symbol:
|
|
217
|
+
if a > b and a > c:
|
|
218
|
+
return "orth_i_a"
|
|
219
|
+
if b > a and b > c:
|
|
220
|
+
return "orth_i_b"
|
|
221
|
+
if c > a and c > b:
|
|
222
|
+
return "orth_i_c"
|
|
223
|
+
|
|
224
|
+
if lattice_type == "tetragonal":
|
|
225
|
+
if "P" in spg_symbol:
|
|
226
|
+
return "tet_p"
|
|
227
|
+
if "I" in spg_symbol:
|
|
228
|
+
return "tet_i_a" if a > c else "tet_i_c"
|
|
229
|
+
|
|
230
|
+
if lattice_type in ["trigonal", "hexagonal", "rhombohedral"]:
|
|
277
231
|
if "R" in spg_symbol:
|
|
278
|
-
if a > np.sqrt(2)*c
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
232
|
+
return "trig_r_a" if a > np.sqrt(2) * c else "trig_r_c"
|
|
233
|
+
if "P" in spg_symbol:
|
|
234
|
+
if unique == 0:
|
|
235
|
+
return "trig_p_a"
|
|
236
|
+
if unique == 2:
|
|
237
|
+
return "trig_p_c"
|
|
238
|
+
|
|
239
|
+
if lattice_type == "cubic":
|
|
240
|
+
if "P" in spg_symbol:
|
|
241
|
+
return "cubic_p"
|
|
242
|
+
if "I" in spg_symbol:
|
|
243
|
+
return "cubic_i"
|
|
244
|
+
if "F" in spg_symbol:
|
|
245
|
+
return "cubic_f"
|
|
246
|
+
|
|
289
247
|
return "unknown"
|
valyte/band_plot.py
CHANGED
|
@@ -4,25 +4,17 @@ import matplotlib as mpl
|
|
|
4
4
|
mpl.use("agg")
|
|
5
5
|
mpl.rcParams["axes.unicode_minus"] = False
|
|
6
6
|
import matplotlib.pyplot as plt
|
|
7
|
-
from pymatgen.io.vasp import
|
|
7
|
+
from pymatgen.io.vasp import BSVasprun
|
|
8
8
|
from pymatgen.electronic_structure.plotter import BSPlotter
|
|
9
9
|
|
|
10
|
+
|
|
10
11
|
def plot_band_structure(vasprun_path, kpoints_path=None, output="valyte_band.png",
|
|
11
12
|
ylim=None, figsize=(4, 4), dpi=400, font="Arial"):
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
kpoints_path (str, optional): Path to KPOINTS file (for labels).
|
|
18
|
-
output (str): Output filename.
|
|
19
|
-
ylim (tuple, optional): Energy range (min, max).
|
|
20
|
-
figsize (tuple): Figure size in inches.
|
|
21
|
-
dpi (int): Resolution of the output image.
|
|
22
|
-
font (str): Font family.
|
|
23
|
-
"""
|
|
24
|
-
|
|
25
|
-
# --- Font configuration ---
|
|
13
|
+
"""Plot the electronic band structure from a VASP vasprun.xml."""
|
|
14
|
+
|
|
15
|
+
if os.path.isdir(vasprun_path):
|
|
16
|
+
vasprun_path = os.path.join(vasprun_path, "vasprun.xml")
|
|
17
|
+
|
|
26
18
|
font_map = {
|
|
27
19
|
"arial": "Arial",
|
|
28
20
|
"helvetica": "Helvetica",
|
|
@@ -37,91 +29,58 @@ def plot_band_structure(vasprun_path, kpoints_path=None, output="valyte_band.png
|
|
|
37
29
|
mpl.rcParams["xtick.major.width"] = 1.2
|
|
38
30
|
mpl.rcParams["ytick.major.width"] = 1.2
|
|
39
31
|
|
|
40
|
-
# print(f"🔍 Reading {vasprun_path} ...") # Silent mode
|
|
41
|
-
|
|
42
32
|
try:
|
|
43
|
-
# Load VASP output
|
|
44
|
-
# BSVasprun is optimized for band structures
|
|
45
33
|
vr = BSVasprun(vasprun_path, parse_projected_eigen=False)
|
|
46
34
|
bs = vr.get_band_structure(kpoints_filename=kpoints_path, line_mode=True)
|
|
47
35
|
except Exception as e:
|
|
48
36
|
raise ValueError(f"Failed to load band structure: {e}")
|
|
49
37
|
|
|
50
|
-
# Use BSPlotter to get the data in a plot-friendly format
|
|
51
38
|
bs_plotter = BSPlotter(bs)
|
|
52
39
|
data = bs_plotter.bs_plot_data(zero_to_efermi=True)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
# Setup plot
|
|
40
|
+
|
|
41
|
+
distances = data["distances"]
|
|
42
|
+
energies = data["energy"]
|
|
43
|
+
ticks = data["ticks"]
|
|
44
|
+
|
|
60
45
|
fig, ax = plt.subplots(figsize=figsize)
|
|
61
|
-
|
|
62
|
-
#
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
# Plot bands
|
|
67
|
-
# Iterate over segments
|
|
46
|
+
|
|
47
|
+
color_vb = "#8e44ad"
|
|
48
|
+
color_cb = "#2a9d8f"
|
|
49
|
+
|
|
68
50
|
for i in range(len(distances)):
|
|
69
51
|
d = distances[i]
|
|
70
|
-
|
|
71
|
-
# Handle different energy data structures
|
|
52
|
+
|
|
72
53
|
if isinstance(energies, dict):
|
|
73
|
-
# Structure: {'1': [seg1, seg2, ...], '-1': ...}
|
|
74
|
-
# Iterate over spins
|
|
75
54
|
for spin in energies:
|
|
76
|
-
# energies[spin] is a list of segments
|
|
77
|
-
# energies[spin][i] is the list of bands for segment i
|
|
78
55
|
for band in energies[spin][i]:
|
|
79
|
-
|
|
80
|
-
if np.mean(band) <= 0:
|
|
81
|
-
c = color_vb
|
|
82
|
-
else:
|
|
83
|
-
c = color_cb
|
|
56
|
+
c = color_vb if np.mean(band) <= 0 else color_cb
|
|
84
57
|
ax.plot(d, band, color=c, lw=1.5, alpha=1.0)
|
|
85
58
|
else:
|
|
86
|
-
# Structure: [{'1': bands, ...}, {'1': bands, ...}] (List of dicts)
|
|
87
|
-
# Iterate over spin channels in this segment
|
|
88
59
|
for spin in energies[i]:
|
|
89
|
-
# energies[i][spin] is a list of arrays (one per band)
|
|
90
60
|
for band in energies[i][spin]:
|
|
91
|
-
if np.mean(band) <= 0
|
|
92
|
-
c = color_vb
|
|
93
|
-
else:
|
|
94
|
-
c = color_cb
|
|
61
|
+
c = color_vb if np.mean(band) <= 0 else color_cb
|
|
95
62
|
ax.plot(d, band, color=c, lw=1.5, alpha=1.0)
|
|
96
63
|
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
# Clean up labels (remove formatting like $ if needed, but pymatgen usually does a good job)
|
|
100
|
-
clean_labels = [l.replace("$\\mid$", "|") for l in ticks['label']]
|
|
64
|
+
ax.set_xticks(ticks["distance"])
|
|
65
|
+
clean_labels = [(l or "").replace("$\\mid$", "|") for l in ticks["label"]]
|
|
101
66
|
ax.set_xticklabels(clean_labels, fontsize=14, fontweight="bold")
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
for d in ticks['distance']:
|
|
67
|
+
|
|
68
|
+
for d in ticks["distance"]:
|
|
105
69
|
ax.axvline(d, color="k", lw=0.8, ls="-", alpha=0.3)
|
|
106
|
-
|
|
107
|
-
# Draw VBM line (E=0)
|
|
70
|
+
|
|
108
71
|
ax.axhline(0, color="k", lw=0.8, ls="--", alpha=0.5)
|
|
109
72
|
|
|
110
|
-
# Setup Y-axis
|
|
111
73
|
ax.set_ylabel("Energy (eV)", fontsize=16, fontweight="bold", labelpad=8)
|
|
112
74
|
if ylim:
|
|
113
75
|
ax.set_ylim(ylim)
|
|
114
|
-
# Set y-ticks with 1 eV spacing
|
|
115
76
|
yticks = np.arange(np.ceil(ylim[0]), np.floor(ylim[1]) + 1, 1)
|
|
116
77
|
ax.set_yticks(yticks)
|
|
117
78
|
else:
|
|
118
|
-
# Default zoom around gap
|
|
119
79
|
ax.set_ylim(-4, 4)
|
|
120
80
|
ax.set_yticks(np.arange(-4, 5, 1))
|
|
121
|
-
|
|
81
|
+
|
|
122
82
|
ax.set_xlim(distances[0][0], distances[-1][-1])
|
|
123
83
|
|
|
124
84
|
plt.tight_layout()
|
|
125
85
|
plt.savefig(output, dpi=dpi)
|
|
126
86
|
plt.close(fig)
|
|
127
|
-
# print(f"✅ Band structure saved to {output}") # Silent mode
|